svod_device/sync.rs
1//! Timeline synchronization primitives for parallel execution.
2//!
3//! This module provides device-agnostic synchronization using timeline signals,
4//! which are monotonically increasing counters that enable ordering of operations
5//! across devices.
6//!
7//! # Design
8//!
9//! Timeline signals abstract over:
10//! - CPU: `AtomicU64` with parking_lot condvar for waiting
11//! - CUDA: a fixed-size ring of recorded events, each tagged with a timeline value
12//! - Metal: `MTLSharedEvent` (future)
13//! - HIP: Similar to CUDA (future)
14//!
15//! # Example
16//!
17//! ```ignore
18//! let signal = CpuTimelineSignal::new();
19//!
20//! // Producer thread
21//! signal.set(1); // Signal completion of operation 1
22//!
23//! // Consumer thread
24//! signal.wait(1, 1000)?; // Wait for operation 1 to complete
25//! ```
26
27use std::any::Any;
28use std::sync::Arc;
29use std::sync::atomic::{AtomicU64, Ordering};
30use std::time::{Duration, Instant};
31
32use parking_lot::{Condvar, Mutex};
33
34use crate::error::{Result, RuntimeSnafu};
35use snafu::ensure;
36
37/// Monotonic timeline signal for synchronization.
38///
39/// Timeline signals provide a way to order operations across different execution
40/// contexts (threads, devices, queues). The signal value only increases, and
41/// waiters block until the signal reaches or exceeds the target value.
42///
43/// # Thread Safety
44///
45/// All implementations must be `Send + Sync` for cross-thread use.
46pub trait TimelineSignal: Send + Sync + std::fmt::Debug + Any {
47 /// Return this signal as `Any` for checked type-erased queue dispatch.
48 fn as_any(&self) -> &dyn Any;
49
50 /// Get the current signal value.
51 fn value(&self) -> u64;
52
53 /// Set the signal to a new value.
54 ///
55 /// # Panics
56 ///
57 /// May panic if `value` is less than the current value (implementation-defined).
58 fn set(&self, value: u64);
59
60 /// Wait for the signal to reach or exceed `value`.
61 ///
62 /// # Arguments
63 ///
64 /// * `value` - The target value to wait for
65 /// * `timeout_ms` - Maximum time to wait in milliseconds (0 = infinite)
66 ///
67 /// # Returns
68 ///
69 /// `Ok(())` if the signal reached the target value, or `Err` on timeout.
70 fn wait(&self, value: u64, timeout_ms: u64) -> Result<()>;
71
72 /// Check if the signal has reached `value` without blocking.
73 fn is_reached(&self, value: u64) -> bool {
74 self.value() >= value
75 }
76}
77
78/// CPU-based timeline signal using atomics and condvar.
79///
80/// Efficient for CPU-only workloads. Uses `AtomicU64` for the counter and
81/// `parking_lot::Condvar` for efficient waiting.
82#[derive(Debug, Clone)]
83pub struct CpuTimelineSignal {
84 inner: Arc<CpuTimelineSignalInner>,
85}
86
87#[derive(Debug)]
88struct CpuTimelineSignalInner {
89 /// Current timeline value (monotonically increasing).
90 value: AtomicU64,
91 /// Mutex for condvar waiting (protects nothing, just for condvar).
92 mutex: Mutex<()>,
93 /// Condvar for waiting threads.
94 condvar: Condvar,
95}
96
97impl Default for CpuTimelineSignal {
98 fn default() -> Self {
99 Self::new()
100 }
101}
102
103impl CpuTimelineSignal {
104 /// Create a new CPU timeline signal starting at 0.
105 pub fn new() -> Self {
106 Self {
107 inner: Arc::new(CpuTimelineSignalInner {
108 value: AtomicU64::new(0),
109 mutex: Mutex::new(()),
110 condvar: Condvar::new(),
111 }),
112 }
113 }
114
115 /// Create a new CPU timeline signal with an initial value.
116 pub fn with_initial(initial: u64) -> Self {
117 Self {
118 inner: Arc::new(CpuTimelineSignalInner {
119 value: AtomicU64::new(initial),
120 mutex: Mutex::new(()),
121 condvar: Condvar::new(),
122 }),
123 }
124 }
125}
126
127impl TimelineSignal for CpuTimelineSignal {
128 fn as_any(&self) -> &dyn Any {
129 self
130 }
131
132 fn value(&self) -> u64 {
133 self.inner.value.load(Ordering::Acquire)
134 }
135
136 fn set(&self, value: u64) {
137 let previous = self.inner.value.fetch_max(value, Ordering::AcqRel);
138 if value > previous {
139 self.inner.condvar.notify_all();
140 }
141 }
142
143 fn wait(&self, target: u64, timeout_ms: u64) -> Result<()> {
144 // Fast path: already reached
145 if self.inner.value.load(Ordering::Acquire) >= target {
146 return Ok(());
147 }
148
149 let mut guard = self.inner.mutex.lock();
150
151 if timeout_ms == 0 {
152 // Infinite wait
153 while self.inner.value.load(Ordering::Acquire) < target {
154 self.inner.condvar.wait(&mut guard);
155 }
156 Ok(())
157 } else {
158 // Timed wait
159 let deadline = Instant::now() + Duration::from_millis(timeout_ms);
160
161 while self.inner.value.load(Ordering::Acquire) < target {
162 let remaining = deadline.saturating_duration_since(Instant::now());
163 if remaining.is_zero() {
164 ensure!(
165 self.inner.value.load(Ordering::Acquire) >= target,
166 RuntimeSnafu {
167 message: format!(
168 "timeline signal timeout: waited {}ms for value {}, current {}",
169 timeout_ms,
170 target,
171 self.inner.value.load(Ordering::Acquire)
172 )
173 }
174 );
175 return Ok(());
176 }
177
178 let result = self.inner.condvar.wait_for(&mut guard, remaining);
179 if result.timed_out() && self.inner.value.load(Ordering::Acquire) < target {
180 return RuntimeSnafu {
181 message: format!(
182 "timeline signal timeout: waited {}ms for value {}, current {}",
183 timeout_ms,
184 target,
185 self.inner.value.load(Ordering::Acquire)
186 ),
187 }
188 .fail();
189 }
190 }
191 Ok(())
192 }
193 }
194}
195
#[cfg(feature = "cuda")]
pub mod cuda {
    //! CUDA-specific timeline signal backed by a ring of recorded events.

    use std::any::Any;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{Duration, Instant};

    use cudarc::driver::{CudaContext, CudaEvent, CudaStream};
    use parking_lot::Mutex;

    use super::TimelineSignal;
    use crate::error::{CudaSnafu, Result, RuntimeSnafu};
    use snafu::ResultExt;

    /// Number of event slots in the ring. 64 is well above typical in-flight
    /// depth and bounds the worst-case event memory.
    const EVENT_RING_SIZE: usize = 64;

    /// Pause between `is_ready` polls while waiting on an event with a timeout.
    const EVENT_POLL_INTERVAL: Duration = Duration::from_micros(100);

    /// One occupied slot in the event ring.
    #[derive(Debug)]
    struct EventSlot {
        /// Timeline value the signal reaches once `event` has executed.
        timeline_value: u64,
        /// Event recorded on the signal's stream. Shared so that a waiter's
        /// clone outlives an overwrite of the slot.
        event: Arc<CudaEvent>,
    }

    /// CUDA timeline signal using a fixed-size ring of event slots.
    ///
    /// `record(n)` enqueues an event on the stream. The signal reaches `n` only
    /// once the GPU has executed up to that event. [`value`] therefore reports
    /// the highest timeline value known to be *complete*, not the highest value
    /// submitted so far. That value is raised by `set()`, by a finished
    /// `wait()`, or by polling recorded events.
    ///
    /// Waiters look up the smallest slot whose `timeline_value >= target` and
    /// synchronise on its event. When the ring wraps and a slot is overwritten,
    /// the previous `Arc<CudaEvent>` is released only after every waiter that
    /// fetched it drops its clone. Slots are never torn out from under
    /// outstanding waiters.
    ///
    /// [`value`]: TimelineSignal::value
    #[derive(Debug)]
    pub struct CudaTimelineSignal {
        /// Highest timeline value known to be reached (completed).
        value: AtomicU64,
        /// Recorded (timeline_value, event) slots, in stream submission order.
        ring: Mutex<EventRing>,
        /// CUDA context for creating events.
        context: Arc<CudaContext>,
        /// Stream on which events are recorded.
        stream: Arc<CudaStream>,
    }

    /// Fixed-size ring of recorded events plus the next-write cursor.
    #[derive(Debug)]
    struct EventRing {
        slots: [Option<EventSlot>; EVENT_RING_SIZE],
        /// Slot the next `record` writes; the slot just before it is the newest.
        next: usize,
    }

    impl EventRing {
        fn new() -> Self {
            Self { slots: std::array::from_fn(|_| None), next: 0 }
        }

        /// Occupied slots, newest first.
        fn newest_first(&self) -> impl Iterator<Item = &EventSlot> + '_ {
            (1..=EVENT_RING_SIZE)
                .map(move |back| (self.next + EVENT_RING_SIZE - back) % EVENT_RING_SIZE)
                .filter_map(move |idx| self.slots[idx].as_ref())
        }

        /// Highest timeline value among events that have already executed.
        ///
        /// Events on one stream complete in submission order, and `record`
        /// fills slots in that same order. So once the newest ready event is
        /// found, every older slot is complete too. `skip_while` stops
        /// querying the driver at that point.
        fn max_completed(&self) -> Option<u64> {
            self.newest_first()
                .skip_while(|slot| !slot.event.is_ready())
                .map(|slot| slot.timeline_value)
                .max()
        }

        /// Smallest recorded timeline value `>= target`, with a clone of its
        /// event. The clone keeps the event alive if the slot is overwritten.
        fn first_at_or_after(&self, target: u64) -> Option<(u64, Arc<CudaEvent>)> {
            self.slots
                .iter()
                .flatten()
                .filter(|slot| slot.timeline_value >= target)
                .min_by_key(|slot| slot.timeline_value)
                .map(|slot| (slot.timeline_value, Arc::clone(&slot.event)))
        }
    }

    impl CudaTimelineSignal {
        /// Create a new CUDA timeline signal starting at 0.
        pub fn new(context: Arc<CudaContext>, stream: Arc<CudaStream>) -> Self {
            Self { value: AtomicU64::new(0), ring: Mutex::new(EventRing::new()), context, stream }
        }

        /// Record an event at the given timeline value.
        ///
        /// Call after submitting the work that `value` stands for. The signal
        /// reaches `value` when the GPU executes the event, not when this call
        /// returns. `value()` and `wait()` observe that completion through the
        /// event.
        ///
        /// The new (value, event) pair occupies the next ring slot, overwriting
        /// whatever was there. Overwriting is safe: any waiter that looked up
        /// that slot already holds an `Arc<CudaEvent>` clone, which keeps the
        /// event alive past the overwrite.
        ///
        /// # Errors
        ///
        /// Returns a CUDA error if the event cannot be created or recorded.
        pub fn record(&self, value: u64) -> Result<()> {
            let event = self.context.create_event(None).context(CudaSnafu)?;

            let mut ring = self.ring.lock();
            // Enqueue while holding the ring lock so that ring order matches
            // stream order, which `EventRing::max_completed` relies on.
            self.stream.record(&event).context(CudaSnafu)?;
            let slot_idx = ring.next;
            ring.slots[slot_idx] = Some(EventSlot { timeline_value: value, event: Arc::new(event) });
            ring.next = (slot_idx + 1) % EVENT_RING_SIZE;

            Ok(())
        }

        /// Get the stream for this signal.
        pub fn stream(&self) -> &Arc<CudaStream> {
            &self.stream
        }
    }

    impl TimelineSignal for CudaTimelineSignal {
        fn as_any(&self) -> &dyn Any {
            self
        }

        /// Highest timeline value known to be complete.
        ///
        /// Polls recorded events without blocking, so GPU progress is visible
        /// even if nobody has waited on it yet.
        fn value(&self) -> u64 {
            // Bind first so the ring lock is released before touching `value`.
            let completed = self.ring.lock().max_completed();
            match completed {
                Some(done) => self.value.fetch_max(done, Ordering::AcqRel).max(done),
                None => self.value.load(Ordering::Acquire),
            }
        }

        /// Host-side fallback: marks `value` reached immediately, without any
        /// GPU synchronisation. Prefer `record()` after stream work.
        fn set(&self, value: u64) {
            self.value.fetch_max(value, Ordering::AcqRel);
        }

        /// Block until the signal reaches `target`.
        ///
        /// `timeout_ms == 0` waits forever. A timeout too large to express as an
        /// `Instant` is also treated as "forever" instead of panicking.
        ///
        /// # Errors
        ///
        /// Returns a runtime error when the deadline passes, or a CUDA error if
        /// event synchronisation fails.
        fn wait(&self, target: u64, timeout_ms: u64) -> Result<()> {
            let deadline = if timeout_ms == 0 {
                None
            } else {
                Instant::now().checked_add(Duration::from_millis(timeout_ms))
            };
            let expired = || matches!(deadline, Some(limit) if Instant::now() >= limit);

            loop {
                // Already reached via `set()` or an earlier completed wait.
                if self.value.load(Ordering::Acquire) >= target {
                    return Ok(());
                }

                // Bind first: the lock guard is a temporary of this statement,
                // so it is released before any blocking below.
                let found = self.ring.lock().first_at_or_after(target);
                if let Some((timeline_value, event)) = found {
                    if deadline.is_none() {
                        event.synchronize().context(CudaSnafu)?;
                    } else {
                        while !event.is_ready() {
                            if expired() {
                                return RuntimeSnafu {
                                    message: format!(
                                        "CUDA timeline signal timeout: waited {}ms for value {}",
                                        timeout_ms, target
                                    ),
                                }
                                .fail();
                            }
                            std::thread::sleep(EVENT_POLL_INTERVAL);
                        }
                    }
                    // The event has executed, so the timeline has reached its value.
                    self.value.fetch_max(timeline_value, Ordering::AcqRel);
                    return Ok(());
                }

                // Nothing recorded at or after `target` yet. It can still be
                // reached through `set()` or a later `record()`, so keep polling
                // both until the deadline.
                if expired() {
                    return RuntimeSnafu {
                        message: format!(
                            "CUDA timeline signal timeout: waited {}ms for value {}, current {}",
                            timeout_ms,
                            target,
                            self.value.load(Ordering::Acquire)
                        ),
                    }
                    .fail();
                }
                std::thread::yield_now();
            }
        }
    }
}
385
386#[cfg(test)]
387#[path = "test/unit/sync.rs"]
388mod tests;