cuda_rust_wasm/runtime/
stream.rs1use crate::Result;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicU64, Ordering};
10use super::Device;
11
12pub struct Stream {
19 device: Arc<Device>,
20 pending: AtomicU64,
22 total_ops: AtomicU64,
24}
25
26impl Stream {
27 pub fn new(device: Arc<Device>) -> Result<Self> {
29 Ok(Self {
30 device,
31 pending: AtomicU64::new(0),
32 total_ops: AtomicU64::new(0),
33 })
34 }
35
36 pub fn device(&self) -> Arc<Device> {
38 self.device.clone()
39 }
40
41 pub fn record_submit(&self) {
43 self.pending.fetch_add(1, Ordering::SeqCst);
44 self.total_ops.fetch_add(1, Ordering::SeqCst);
45 }
46
47 pub fn record_complete(&self) {
49 self.pending.fetch_sub(1, Ordering::SeqCst);
50 }
51
52 pub fn synchronize(&self) -> Result<()> {
57 let mut spins = 0u32;
60 while self.pending.load(Ordering::SeqCst) > 0 {
61 std::thread::yield_now();
62 spins += 1;
63 if spins > 10_000 {
64 return Err(crate::runtime_error!(
65 "Stream synchronize timed out with {} pending operations",
66 self.pending.load(Ordering::SeqCst)
67 ));
68 }
69 }
70 Ok(())
71 }
72
73 pub fn is_complete(&self) -> bool {
75 self.pending.load(Ordering::SeqCst) == 0
76 }
77
78 pub fn pending_ops(&self) -> u64 {
80 self.pending.load(Ordering::SeqCst)
81 }
82
83 pub fn total_ops(&self) -> u64 {
85 self.total_ops.load(Ordering::SeqCst)
86 }
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::Device;

    /// Builds a stream on the default device, panicking if either constructor errors.
    fn fresh_stream() -> Stream {
        let device = Device::get_default().unwrap();
        Stream::new(device).unwrap()
    }

    #[test]
    fn test_stream_creation() {
        let s = fresh_stream();
        assert!(s.is_complete());
        assert_eq!(s.pending_ops(), 0);
    }

    #[test]
    fn test_stream_operation_tracking() {
        let s = fresh_stream();

        // One submission leaves exactly one operation outstanding.
        s.record_submit();
        assert!(!s.is_complete());
        assert_eq!(s.pending_ops(), 1);

        // Completing it drains the stream but keeps the lifetime tally.
        s.record_complete();
        assert!(s.is_complete());
        assert_eq!(s.total_ops(), 1);
    }

    #[test]
    fn test_stream_synchronize() {
        let s = fresh_stream();
        assert!(s.synchronize().is_ok());
    }
}