Skip to main content

cuda_rust_wasm/runtime/
stream.rs

//! CUDA stream abstraction for asynchronous operations
//!
//! Streams provide ordered execution queues. On CPU backends all operations
//! are synchronous, so `synchronize` and `is_complete` simply report the
//! number of in-flight operations tracked by an atomic counter.
6
7use crate::Result;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicU64, Ordering};
10use super::Device;
11
/// Stream for asynchronous GPU operations.
///
/// Keeps two atomic counters:
///
/// * `pending` — the number of operations currently in flight. It goes up by
///   one in `record_submit` and down by one in `record_complete`, so it is
///   *not* monotonic.
/// * `total_ops` — the number of operations ever submitted. It is only ever
///   incremented (in `record_submit`).
///
/// On CPU backends, where operations are expected to run synchronously,
/// `pending` should be back at zero once each operation has been recorded as
/// complete.
pub struct Stream {
    /// Device this stream is associated with, shared through an `Arc`.
    device: Arc<Device>,
    /// Number of in-flight operations.
    pending: AtomicU64,
    /// Total operations submitted through this stream.
    total_ops: AtomicU64,
}
25
26impl Stream {
27    /// Create a new stream associated with `device`.
28    pub fn new(device: Arc<Device>) -> Result<Self> {
29        Ok(Self {
30            device,
31            pending: AtomicU64::new(0),
32            total_ops: AtomicU64::new(0),
33        })
34    }
35
36    /// Get the device associated with this stream.
37    pub fn device(&self) -> Arc<Device> {
38        self.device.clone()
39    }
40
41    /// Record a submitted operation (increment pending counter).
42    pub fn record_submit(&self) {
43        self.pending.fetch_add(1, Ordering::SeqCst);
44        self.total_ops.fetch_add(1, Ordering::SeqCst);
45    }
46
47    /// Record a completed operation (decrement pending counter).
48    pub fn record_complete(&self) {
49        self.pending.fetch_sub(1, Ordering::SeqCst);
50    }
51
52    /// Synchronize the stream — block until all pending operations complete.
53    ///
54    /// On CPU backends all operations are already synchronous, so this
55    /// simply verifies the pending counter is zero.
56    pub fn synchronize(&self) -> Result<()> {
57        // CPU backend: operations complete inline so counter should be 0.
58        // Spin briefly to handle any race on decrement.
59        let mut spins = 0u32;
60        while self.pending.load(Ordering::SeqCst) > 0 {
61            std::thread::yield_now();
62            spins += 1;
63            if spins > 10_000 {
64                return Err(crate::runtime_error!(
65                    "Stream synchronize timed out with {} pending operations",
66                    self.pending.load(Ordering::SeqCst)
67                ));
68            }
69        }
70        Ok(())
71    }
72
73    /// Check if all stream operations are complete.
74    pub fn is_complete(&self) -> bool {
75        self.pending.load(Ordering::SeqCst) == 0
76    }
77
78    /// Get the number of pending operations.
79    pub fn pending_ops(&self) -> u64 {
80        self.pending.load(Ordering::SeqCst)
81    }
82
83    /// Get the total number of operations submitted.
84    pub fn total_ops(&self) -> u64 {
85        self.total_ops.load(Ordering::SeqCst)
86    }
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::Device;

    /// Build a stream on the default device, panicking if either step fails.
    fn default_stream() -> Stream {
        let device = Device::get_default().unwrap();
        Stream::new(device).unwrap()
    }

    /// A freshly created stream has nothing in flight.
    #[test]
    fn test_stream_creation() {
        let stream = default_stream();

        assert!(stream.is_complete());
        assert_eq!(stream.pending_ops(), 0);
    }

    /// Submitting then completing one operation moves the counters as expected.
    #[test]
    fn test_stream_operation_tracking() {
        let stream = default_stream();

        // One operation in flight.
        stream.record_submit();
        assert!(!stream.is_complete());
        assert_eq!(stream.pending_ops(), 1);

        // Operation finished: nothing pending, lifetime total still counts it.
        stream.record_complete();
        assert!(stream.is_complete());
        assert_eq!(stream.total_ops(), 1);
    }

    /// Synchronizing an idle stream succeeds.
    #[test]
    fn test_stream_synchronize() {
        let stream = default_stream();

        let outcome = stream.synchronize();
        assert!(outcome.is_ok());
    }
}