oximedia_gpu/
async_compute.rs1use std::collections::HashMap;
24
/// Lifecycle state of a task tracked by `AsyncComputeQueue`.
///
/// Within this module, `AsyncComputeQueue::submit` records tasks as
/// `Complete` and `AsyncComputeQueue::fail_task` sets `Failed`; nothing here
/// assigns `Pending` or `Running`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TaskState {
    /// Presumably: accepted but not yet started (not assigned by this module).
    Pending,
    /// Presumably: currently executing (not assigned by this module).
    Running,
    /// Result is ready; `AsyncComputeQueue::poll` only returns data in this state.
    Complete,
    /// The task was marked failed; carries the message passed to `fail_task`.
    Failed(String),
}
39
40#[derive(Debug)]
42struct TaskRecord {
43 state: TaskState,
44 data: Vec<u8>,
46}
47
48#[derive(Debug, Default)]
56pub struct AsyncComputeQueue {
57 tasks: HashMap<u64, TaskRecord>,
59 pub submission_count: u64,
61 pub completed_count: u64,
63}
64
65impl AsyncComputeQueue {
66 #[must_use]
68 pub fn new() -> Self {
69 Self::default()
70 }
71
72 pub fn submit(&mut self, task_id: u64, data: Vec<u8>) {
79 self.submission_count += 1;
80 self.tasks.insert(
83 task_id,
84 TaskRecord {
85 state: TaskState::Complete,
86 data,
87 },
88 );
89 }
90
91 pub fn poll(&mut self, task_id: u64) -> Option<Vec<u8>> {
98 if let Some(record) = self.tasks.get(&task_id) {
99 if record.state == TaskState::Complete {
100 let record = self.tasks.remove(&task_id)?;
102 self.completed_count += 1;
103 return Some(record.data);
104 }
105 }
106 None
107 }
108
109 #[must_use]
114 pub fn state(&self, task_id: u64) -> Option<&TaskState> {
115 self.tasks.get(&task_id).map(|r| &r.state)
116 }
117
118 pub fn cancel(&mut self, task_id: u64) -> bool {
122 self.tasks.remove(&task_id).is_some()
123 }
124
125 #[must_use]
127 pub fn active_count(&self) -> usize {
128 self.tasks.len()
129 }
130
131 #[must_use]
133 pub fn is_empty(&self) -> bool {
134 self.tasks.is_empty()
135 }
136
137 pub fn fail_task(&mut self, task_id: u64, error: String) {
140 if let Some(record) = self.tasks.get_mut(&task_id) {
141 record.state = TaskState::Failed(error);
142 }
143 }
144
145 #[must_use]
147 pub fn is_failed(&self, task_id: u64) -> bool {
148 matches!(
149 self.tasks.get(&task_id).map(|r| &r.state),
150 Some(TaskState::Failed(_))
151 )
152 }
153}
154
// Unit tests for `AsyncComputeQueue`, covering the happy path plus unknown-id
// and failed-task edge cases.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_submit_and_poll_returns_data() {
        let mut q = AsyncComputeQueue::new();
        q.submit(1, vec![10, 20, 30]);
        let result = q.poll(1);
        assert_eq!(result, Some(vec![10, 20, 30]));
    }

    #[test]
    fn test_poll_twice_returns_none_second_time() {
        let mut q = AsyncComputeQueue::new();
        q.submit(42, vec![1]);
        assert!(q.poll(42).is_some());
        assert!(q.poll(42).is_none());
    }

    #[test]
    fn test_poll_unknown_task_returns_none() {
        let mut q = AsyncComputeQueue::new();
        assert!(q.poll(99).is_none());
        assert_eq!(q.completed_count, 0);
    }

    #[test]
    fn test_multiple_tasks_independent() {
        let mut q = AsyncComputeQueue::new();
        q.submit(1, vec![0xAA]);
        q.submit(2, vec![0xBB]);
        assert_eq!(q.poll(2), Some(vec![0xBB]));
        assert_eq!(q.poll(1), Some(vec![0xAA]));
    }

    #[test]
    fn test_cancel_removes_task() {
        let mut q = AsyncComputeQueue::new();
        q.submit(7, vec![0xFF]);
        assert!(q.cancel(7));
        assert!(q.poll(7).is_none());
        assert!(q.is_empty());
        assert_eq!(q.completed_count, 0);
    }

    #[test]
    fn test_cancel_unknown_task_returns_false() {
        let mut q = AsyncComputeQueue::new();
        assert!(!q.cancel(123));
    }

    #[test]
    fn test_submission_count_increments() {
        let mut q = AsyncComputeQueue::new();
        q.submit(1, vec![]);
        q.submit(2, vec![]);
        assert_eq!(q.submission_count, 2);
    }

    #[test]
    fn test_completed_count_increments_on_poll() {
        let mut q = AsyncComputeQueue::new();
        q.submit(1, vec![1]);
        q.poll(1);
        assert_eq!(q.completed_count, 1);
    }

    #[test]
    fn test_state_complete_after_submit() {
        let mut q = AsyncComputeQueue::new();
        q.submit(5, vec![5]);
        assert_eq!(q.state(5), Some(&TaskState::Complete));
    }

    #[test]
    fn test_state_unknown_task_is_none() {
        let q = AsyncComputeQueue::new();
        assert!(q.state(8).is_none());
        assert!(!q.is_failed(8));
    }

    #[test]
    fn test_active_count_decreases_on_poll() {
        let mut q = AsyncComputeQueue::new();
        q.submit(1, vec![]);
        q.submit(2, vec![]);
        assert_eq!(q.active_count(), 2);
        q.poll(1);
        assert_eq!(q.active_count(), 1);
    }

    #[test]
    fn test_is_empty_after_all_polled() {
        let mut q = AsyncComputeQueue::new();
        q.submit(1, vec![1]);
        q.poll(1);
        assert!(q.is_empty());
    }

    #[test]
    fn test_fail_task_marks_failed() {
        let mut q = AsyncComputeQueue::new();
        q.submit(3, vec![]);
        q.fail_task(3, "shader compile error".into());
        assert!(q.is_failed(3));
    }

    #[test]
    fn test_completed_task_is_not_failed() {
        let mut q = AsyncComputeQueue::new();
        q.submit(6, vec![6]);
        assert!(!q.is_failed(6));
    }

    #[test]
    fn test_fail_unknown_task_is_noop() {
        let mut q = AsyncComputeQueue::new();
        q.fail_task(9, "missing".into());
        assert!(q.is_empty());
        assert!(!q.is_failed(9));
    }

    #[test]
    fn test_poll_failed_task_returns_none_and_keeps_it() {
        let mut q = AsyncComputeQueue::new();
        q.submit(4, vec![4]);
        q.fail_task(4, "device lost".into());
        assert!(q.poll(4).is_none());
        // A failed task is not consumed by poll; it stays tracked until cancelled.
        assert_eq!(q.active_count(), 1);
        assert_eq!(q.completed_count, 0);
        assert_eq!(q.state(4), Some(&TaskState::Failed("device lost".into())));
        assert!(q.cancel(4));
        assert!(q.is_empty());
    }

    #[test]
    fn test_resubmit_replaces_previous() {
        let mut q = AsyncComputeQueue::new();
        q.submit(1, vec![0x01]);
        q.submit(1, vec![0x02]);
        assert_eq!(q.poll(1), Some(vec![0x02]));
        // The earlier payload was replaced, not kept alongside the new one.
        assert!(q.poll(1).is_none());
        assert_eq!(q.submission_count, 2);
    }

    #[test]
    fn test_empty_payload_allowed() {
        let mut q = AsyncComputeQueue::new();
        q.submit(0, vec![]);
        assert_eq!(q.poll(0), Some(vec![]));
    }
}