// oximedia_gpu/async_compute.rs

//! Async compute queue for overlapping compute and transfer operations.
//!
//! Provides a CPU-side task queue where compute jobs are submitted under a
//! caller-chosen `task_id` and polled for completion. On the CPU fallback
//! backend, tasks complete synchronously on submission; the queue
//! abstraction future-proofs the API for true async GPU execution when a
//! WGPU device is available.
//!
//! # Example
//!
//! ```rust
//! use oximedia_gpu::async_compute::AsyncComputeQueue;
//!
//! let mut queue = AsyncComputeQueue::new();
//! queue.submit(1, vec![0x01, 0x02]);
//! let result = queue.poll(1);
//! assert!(result.is_some());
//! assert_eq!(result.unwrap(), vec![0x01, 0x02]);
//! // Polling a second time returns None (result already consumed).
//! assert!(queue.poll(1).is_none());
//! ```
22
23use std::collections::HashMap;
24
25// ── Task state ────────────────────────────────────────────────────────────────
26
/// Lifecycle state of a compute task tracked by the queue.
///
/// The CPU-stub `submit` path in this module records tasks directly as
/// [`TaskState::Complete`]; `Pending` and `Running` exist so the same API can
/// describe tasks that execute asynchronously.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TaskState {
    /// The task has been submitted and is waiting for execution to begin.
    Pending,
    /// Execution is in progress.
    Running,
    /// Execution finished; the result data is available.
    Complete,
    /// Execution failed; the payload is the error message.
    Failed(String),
}
39
40/// Internal record for a tracked compute task.
41#[derive(Debug)]
42struct TaskRecord {
43    state: TaskState,
44    /// Payload supplied at submit time (also used as the result on CPU path).
45    data: Vec<u8>,
46}
47
48// ── AsyncComputeQueue ─────────────────────────────────────────────────────────
49
50/// Lightweight async compute task queue.
51///
52/// In the CPU-stub backend, tasks complete synchronously; the API is
53/// designed to be drop-in replaceable with an actual GPU async queue once
54/// a WGPU device is available.
55#[derive(Debug, Default)]
56pub struct AsyncComputeQueue {
57    /// Active tasks, keyed by caller-defined `task_id`.
58    tasks: HashMap<u64, TaskRecord>,
59    /// Monotonically increasing submission counter.
60    pub submission_count: u64,
61    /// Number of tasks that have been polled and returned a result.
62    pub completed_count: u64,
63}
64
65impl AsyncComputeQueue {
66    /// Create a new, empty async compute queue.
67    #[must_use]
68    pub fn new() -> Self {
69        Self::default()
70    }
71
72    /// Submit a compute task.
73    ///
74    /// * `task_id`  – Caller-defined identifier for this task.
75    /// * `data`     – Input payload (or pre-computed output on CPU path).
76    ///
77    /// If a task with the same `task_id` already exists it is replaced.
78    pub fn submit(&mut self, task_id: u64, data: Vec<u8>) {
79        self.submission_count += 1;
80        // On the CPU-stub path, execution is synchronous → mark as Complete
81        // immediately so `poll()` can return the result on the next call.
82        self.tasks.insert(
83            task_id,
84            TaskRecord {
85                state: TaskState::Complete,
86                data,
87            },
88        );
89    }
90
91    /// Poll for the result of a previously submitted task.
92    ///
93    /// Returns `Some(result)` if the task has completed, consuming the
94    /// result from the queue (subsequent polls for the same `task_id`
95    /// return `None`).  Returns `None` if the task is still pending/running
96    /// or has already been consumed.
97    pub fn poll(&mut self, task_id: u64) -> Option<Vec<u8>> {
98        if let Some(record) = self.tasks.get(&task_id) {
99            if record.state == TaskState::Complete {
100                // Remove and return.
101                let record = self.tasks.remove(&task_id)?;
102                self.completed_count += 1;
103                return Some(record.data);
104            }
105        }
106        None
107    }
108
109    /// Query the current state of a task without consuming the result.
110    ///
111    /// Returns `None` if no task with that `task_id` exists (either never
112    /// submitted or already consumed by [`Self::poll`]).
113    #[must_use]
114    pub fn state(&self, task_id: u64) -> Option<&TaskState> {
115        self.tasks.get(&task_id).map(|r| &r.state)
116    }
117
118    /// Cancel a pending or running task.
119    ///
120    /// Returns `true` if the task was found and removed.
121    pub fn cancel(&mut self, task_id: u64) -> bool {
122        self.tasks.remove(&task_id).is_some()
123    }
124
125    /// Number of tasks currently tracked (pending, running, or complete).
126    #[must_use]
127    pub fn active_count(&self) -> usize {
128        self.tasks.len()
129    }
130
131    /// `true` if no tasks are currently tracked.
132    #[must_use]
133    pub fn is_empty(&self) -> bool {
134        self.tasks.is_empty()
135    }
136
137    /// Mark a task as failed with an error message (useful for testing error
138    /// paths).
139    pub fn fail_task(&mut self, task_id: u64, error: String) {
140        if let Some(record) = self.tasks.get_mut(&task_id) {
141            record.state = TaskState::Failed(error);
142        }
143    }
144
145    /// Returns `true` if the task with `task_id` has failed.
146    #[must_use]
147    pub fn is_failed(&self, task_id: u64) -> bool {
148        matches!(
149            self.tasks.get(&task_id).map(|r| &r.state),
150            Some(TaskState::Failed(_))
151        )
152    }
153}
154
155// ── Tests ─────────────────────────────────────────────────────────────────────
156
#[cfg(test)]
mod tests {
    //! Unit tests for the synchronous CPU-stub behaviour of
    //! `AsyncComputeQueue`.

    use super::*;

    /// A submitted payload is returned unchanged by the first poll.
    #[test]
    fn test_submit_and_poll_returns_data() {
        let mut q = AsyncComputeQueue::new();
        q.submit(1, vec![10, 20, 30]);
        assert_eq!(q.poll(1), Some(vec![10, 20, 30]));
    }

    /// Polling consumes the result, so a second poll finds nothing.
    #[test]
    fn test_poll_twice_returns_none_second_time() {
        let mut q = AsyncComputeQueue::new();
        q.submit(42, vec![1]);
        assert!(q.poll(42).is_some());
        assert!(q.poll(42).is_none());
    }

    /// Polling an id that was never submitted yields `None`.
    #[test]
    fn test_poll_unknown_task_returns_none() {
        let mut q = AsyncComputeQueue::new();
        assert!(q.poll(99).is_none());
    }

    /// Tasks are independent and can be polled in any order.
    #[test]
    fn test_multiple_tasks_independent() {
        let mut q = AsyncComputeQueue::new();
        q.submit(1, vec![0xAA]);
        q.submit(2, vec![0xBB]);
        assert_eq!(q.poll(2), Some(vec![0xBB]));
        assert_eq!(q.poll(1), Some(vec![0xAA]));
    }

    /// A cancelled task can no longer be polled.
    #[test]
    fn test_cancel_removes_task() {
        let mut q = AsyncComputeQueue::new();
        q.submit(7, vec![0xFF]);
        assert!(q.cancel(7));
        assert!(q.poll(7).is_none());
    }

    /// Cancelling an id that is not tracked reports `false`.
    #[test]
    fn test_cancel_unknown_task_returns_false() {
        let mut q = AsyncComputeQueue::new();
        assert!(!q.cancel(7));
    }

    /// Each `submit` call bumps the submission counter.
    #[test]
    fn test_submission_count_increments() {
        let mut q = AsyncComputeQueue::new();
        q.submit(1, vec![]);
        q.submit(2, vec![]);
        assert_eq!(q.submission_count, 2);
    }

    /// A successful poll bumps the completed counter.
    #[test]
    fn test_completed_count_increments_on_poll() {
        let mut q = AsyncComputeQueue::new();
        q.submit(1, vec![1]);
        let _ = q.poll(1);
        assert_eq!(q.completed_count, 1);
    }

    /// On the CPU-stub path a task is `Complete` immediately after submit.
    #[test]
    fn test_state_complete_after_submit() {
        let mut q = AsyncComputeQueue::new();
        q.submit(5, vec![5]);
        // Re-bind immutably: `state` only needs shared access.
        let q = q;
        assert_eq!(q.state(5), Some(&TaskState::Complete));
    }

    /// Polling a task removes it from the tracked set.
    #[test]
    fn test_active_count_decreases_on_poll() {
        let mut q = AsyncComputeQueue::new();
        q.submit(1, vec![]);
        q.submit(2, vec![]);
        assert_eq!(q.active_count(), 2);
        let _ = q.poll(1);
        assert_eq!(q.active_count(), 1);
    }

    /// Once every task has been polled the queue is empty again.
    #[test]
    fn test_is_empty_after_all_polled() {
        let mut q = AsyncComputeQueue::new();
        q.submit(1, vec![1]);
        let _ = q.poll(1);
        assert!(q.is_empty());
    }

    /// `fail_task` flips a tracked task into the failed state.
    #[test]
    fn test_fail_task_marks_failed() {
        let mut q = AsyncComputeQueue::new();
        q.submit(3, vec![]);
        q.fail_task(3, "shader compile error".into());
        assert!(q.is_failed(3));
    }

    /// A failed task is never handed out by `poll` and stays tracked.
    #[test]
    fn test_poll_failed_task_returns_none() {
        let mut q = AsyncComputeQueue::new();
        q.submit(3, vec![9]);
        q.fail_task(3, "shader compile error".into());
        assert!(q.poll(3).is_none());
        assert_eq!(q.completed_count, 0);
        assert_eq!(q.active_count(), 1);
    }

    /// Submitting under an existing id replaces the earlier payload.
    #[test]
    fn test_resubmit_replaces_previous() {
        let mut q = AsyncComputeQueue::new();
        q.submit(1, vec![0x01]);
        q.submit(1, vec![0x02]); // replace
        assert_eq!(q.poll(1), Some(vec![0x02]));
    }

    /// An empty payload is a valid task and round-trips as an empty result.
    #[test]
    fn test_empty_payload_allowed() {
        let mut q = AsyncComputeQueue::new();
        q.submit(0, vec![]);
        assert_eq!(q.poll(0), Some(vec![]));
    }
}