1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
//! CUDA stream abstraction for asynchronous operations
//!
//! Streams provide ordered execution queues. On CPU backends all operations
//! run synchronously, so `synchronize` and `is_complete` simply report the
//! state of an atomic counter of in-flight operations.
use crate::Result;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use super::Device;
/// Stream for asynchronous GPU operations.
///
/// Holds a shared handle to its device plus two atomic counters:
///
/// * `pending` goes up by one in `record_submit` and down by one in
///   `record_complete`, so it is a gauge of in-flight work, not a
///   monotonic value.
/// * `total_ops` only ever goes up (one per `record_submit`).
///
/// The module documentation states that CPU backends execute operations
/// synchronously, in which case `pending` is expected to be back at zero
/// once each operation returns.
pub struct Stream {
/// Device this stream was created for; handed out again by `device()`.
device: Arc<Device>,
/// Number of in-flight operations (submitted but not yet completed).
pending: AtomicU64,
/// Total operations submitted through this stream since creation.
total_ops: AtomicU64,
}
impl Stream {
    /// Maximum number of `yield_now` spins [`Stream::synchronize`] performs
    /// before it gives up and reports a timeout.
    const MAX_SYNC_SPINS: u32 = 10_000;

    /// Create a new stream associated with `device`.
    ///
    /// Both counters start at zero.
    ///
    /// # Errors
    ///
    /// The current implementation never fails; the `Result` return type is
    /// part of the existing interface and is kept for callers.
    pub fn new(device: Arc<Device>) -> Result<Self> {
        Ok(Self {
            device,
            pending: AtomicU64::new(0),
            total_ops: AtomicU64::new(0),
        })
    }

    /// Get the device associated with this stream.
    ///
    /// Returns a new `Arc` handle to the same device (a reference-count
    /// bump, not a copy of the device).
    pub fn device(&self) -> Arc<Device> {
        Arc::clone(&self.device)
    }

    /// Record a submitted operation.
    ///
    /// Increments both the pending counter and the lifetime total. The two
    /// increments are separate atomic operations, so a concurrent reader may
    /// briefly observe one without the other.
    pub fn record_submit(&self) {
        self.pending.fetch_add(1, Ordering::SeqCst);
        self.total_ops.fetch_add(1, Ordering::SeqCst);
    }

    /// Record a completed operation (decrement the pending counter).
    ///
    /// The decrement saturates at zero: a call with no matching
    /// [`Stream::record_submit`] is ignored.
    pub fn record_complete(&self) {
        // A plain `fetch_sub` wraps on underflow, which would leave the
        // counter at `u64::MAX` and make the stream look permanently busy.
        // `checked_sub` returns `None` at zero, so `fetch_update` leaves the
        // value untouched; that `Err` outcome is intentionally discarded.
        let _ = self
            .pending
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |pending| {
                pending.checked_sub(1)
            });
    }

    /// Synchronize the stream — wait until all pending operations complete.
    ///
    /// Polls the pending counter, yielding the thread between polls, for at
    /// most `MAX_SYNC_SPINS` yields. When the counter is already zero this
    /// returns immediately.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `crate::runtime_error!` if the counter
    /// is still non-zero once the spin budget is exhausted. The message
    /// carries the pending count observed by the final poll.
    pub fn synchronize(&self) -> Result<()> {
        let mut spins = 0u32;
        loop {
            // Load once per iteration so the value reported on timeout is
            // the same value that caused the timeout.
            let pending = self.pending.load(Ordering::SeqCst);
            if pending == 0 {
                return Ok(());
            }
            if spins >= Self::MAX_SYNC_SPINS {
                return Err(crate::runtime_error!(
                    "Stream synchronize timed out with {} pending operations",
                    pending
                ));
            }
            std::thread::yield_now();
            spins += 1;
        }
    }

    /// Check if all stream operations are complete (pending counter is zero).
    pub fn is_complete(&self) -> bool {
        self.pending_ops() == 0
    }

    /// Get the number of pending operations.
    pub fn pending_ops(&self) -> u64 {
        self.pending.load(Ordering::SeqCst)
    }

    /// Get the total number of operations submitted.
    pub fn total_ops(&self) -> u64 {
        self.total_ops.load(Ordering::SeqCst)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::Device;

    /// Builds a stream on the default device, panicking if either step fails.
    fn default_stream() -> Stream {
        let device = Device::get_default().unwrap();
        Stream::new(device).unwrap()
    }

    /// A freshly created stream has nothing in flight.
    #[test]
    fn test_stream_creation() {
        let stream = default_stream();

        assert!(stream.is_complete());
        assert_eq!(stream.pending_ops(), 0);
    }

    /// One submit/complete round trip is reflected in every counter.
    #[test]
    fn test_stream_operation_tracking() {
        let stream = default_stream();

        // After submitting, exactly one operation is in flight.
        stream.record_submit();
        assert!(!stream.is_complete());
        assert_eq!(stream.pending_ops(), 1);

        // After completing, the stream is idle but the total remains.
        stream.record_complete();
        assert!(stream.is_complete());
        assert_eq!(stream.total_ops(), 1);
    }

    /// Synchronizing an idle stream succeeds.
    #[test]
    fn test_stream_synchronize() {
        let stream = default_stream();

        assert!(stream.synchronize().is_ok());
    }
}