Skip to main content

svod_device/
sync.rs

1//! Timeline synchronization primitives for parallel execution.
2//!
3//! This module provides device-agnostic synchronization using timeline signals,
4//! which are monotonically increasing counters that enable ordering of operations
5//! across devices.
6//!
7//! # Design
8//!
9//! Timeline signals abstract over:
10//! - CPU: `AtomicU64` with parking_lot condvar for waiting
//! - CUDA: A fixed-size ring of recorded events tagged with timeline values
12//! - Metal: `MTLSharedEvent` (future)
13//! - HIP: Similar to CUDA (future)
14//!
15//! # Example
16//!
17//! ```ignore
18//! let signal = CpuTimelineSignal::new();
19//!
20//! // Producer thread
21//! signal.set(1);  // Signal completion of operation 1
22//!
23//! // Consumer thread
24//! signal.wait(1, 1000)?;  // Wait for operation 1 to complete
25//! ```
26
27use std::any::Any;
28use std::sync::Arc;
29use std::sync::atomic::{AtomicU64, Ordering};
30use std::time::{Duration, Instant};
31
32use parking_lot::{Condvar, Mutex};
33
34use crate::error::{Result, RuntimeSnafu};
35use snafu::ensure;
36
37/// Monotonic timeline signal for synchronization.
38///
39/// Timeline signals provide a way to order operations across different execution
40/// contexts (threads, devices, queues). The signal value only increases, and
41/// waiters block until the signal reaches or exceeds the target value.
42///
43/// # Thread Safety
44///
45/// All implementations must be `Send + Sync` for cross-thread use.
46pub trait TimelineSignal: Send + Sync + std::fmt::Debug + Any {
47    /// Return this signal as `Any` for checked type-erased queue dispatch.
48    fn as_any(&self) -> &dyn Any;
49
50    /// Get the current signal value.
51    fn value(&self) -> u64;
52
53    /// Set the signal to a new value.
54    ///
55    /// # Panics
56    ///
57    /// May panic if `value` is less than the current value (implementation-defined).
58    fn set(&self, value: u64);
59
60    /// Wait for the signal to reach or exceed `value`.
61    ///
62    /// # Arguments
63    ///
64    /// * `value` - The target value to wait for
65    /// * `timeout_ms` - Maximum time to wait in milliseconds (0 = infinite)
66    ///
67    /// # Returns
68    ///
69    /// `Ok(())` if the signal reached the target value, or `Err` on timeout.
70    fn wait(&self, value: u64, timeout_ms: u64) -> Result<()>;
71
72    /// Check if the signal has reached `value` without blocking.
73    fn is_reached(&self, value: u64) -> bool {
74        self.value() >= value
75    }
76}
77
78/// CPU-based timeline signal using atomics and condvar.
79///
80/// Efficient for CPU-only workloads. Uses `AtomicU64` for the counter and
81/// `parking_lot::Condvar` for efficient waiting.
82#[derive(Debug, Clone)]
83pub struct CpuTimelineSignal {
84    inner: Arc<CpuTimelineSignalInner>,
85}
86
87#[derive(Debug)]
88struct CpuTimelineSignalInner {
89    /// Current timeline value (monotonically increasing).
90    value: AtomicU64,
91    /// Mutex for condvar waiting (protects nothing, just for condvar).
92    mutex: Mutex<()>,
93    /// Condvar for waiting threads.
94    condvar: Condvar,
95}
96
97impl Default for CpuTimelineSignal {
98    fn default() -> Self {
99        Self::new()
100    }
101}
102
103impl CpuTimelineSignal {
104    /// Create a new CPU timeline signal starting at 0.
105    pub fn new() -> Self {
106        Self {
107            inner: Arc::new(CpuTimelineSignalInner {
108                value: AtomicU64::new(0),
109                mutex: Mutex::new(()),
110                condvar: Condvar::new(),
111            }),
112        }
113    }
114
115    /// Create a new CPU timeline signal with an initial value.
116    pub fn with_initial(initial: u64) -> Self {
117        Self {
118            inner: Arc::new(CpuTimelineSignalInner {
119                value: AtomicU64::new(initial),
120                mutex: Mutex::new(()),
121                condvar: Condvar::new(),
122            }),
123        }
124    }
125}
126
127impl TimelineSignal for CpuTimelineSignal {
128    fn as_any(&self) -> &dyn Any {
129        self
130    }
131
132    fn value(&self) -> u64 {
133        self.inner.value.load(Ordering::Acquire)
134    }
135
136    fn set(&self, value: u64) {
137        let previous = self.inner.value.fetch_max(value, Ordering::AcqRel);
138        if value > previous {
139            self.inner.condvar.notify_all();
140        }
141    }
142
143    fn wait(&self, target: u64, timeout_ms: u64) -> Result<()> {
144        // Fast path: already reached
145        if self.inner.value.load(Ordering::Acquire) >= target {
146            return Ok(());
147        }
148
149        let mut guard = self.inner.mutex.lock();
150
151        if timeout_ms == 0 {
152            // Infinite wait
153            while self.inner.value.load(Ordering::Acquire) < target {
154                self.inner.condvar.wait(&mut guard);
155            }
156            Ok(())
157        } else {
158            // Timed wait
159            let deadline = Instant::now() + Duration::from_millis(timeout_ms);
160
161            while self.inner.value.load(Ordering::Acquire) < target {
162                let remaining = deadline.saturating_duration_since(Instant::now());
163                if remaining.is_zero() {
164                    ensure!(
165                        self.inner.value.load(Ordering::Acquire) >= target,
166                        RuntimeSnafu {
167                            message: format!(
168                                "timeline signal timeout: waited {}ms for value {}, current {}",
169                                timeout_ms,
170                                target,
171                                self.inner.value.load(Ordering::Acquire)
172                            )
173                        }
174                    );
175                    return Ok(());
176                }
177
178                let result = self.inner.condvar.wait_for(&mut guard, remaining);
179                if result.timed_out() && self.inner.value.load(Ordering::Acquire) < target {
180                    return RuntimeSnafu {
181                        message: format!(
182                            "timeline signal timeout: waited {}ms for value {}, current {}",
183                            timeout_ms,
184                            target,
185                            self.inner.value.load(Ordering::Acquire)
186                        ),
187                    }
188                    .fail();
189                }
190            }
191            Ok(())
192        }
193    }
194}
195
#[cfg(feature = "cuda")]
pub mod cuda {
    //! CUDA-specific timeline signal backed by a ring of recorded events.

    use std::any::Any;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{Duration, Instant};

    use cudarc::driver::{CudaContext, CudaEvent, CudaStream};
    use parking_lot::Mutex;

    use super::TimelineSignal;
    use crate::error::{CudaSnafu, Result, RuntimeSnafu};
    use snafu::ResultExt;

    /// Number of event slots in the ring. 64 is well above typical in-flight
    /// depth and bounds the worst-case event memory.
    const EVENT_RING_SIZE: usize = 64;

    /// Sleep between `is_ready()` polls during a timed event wait.
    const EVENT_POLL_INTERVAL: Duration = Duration::from_micros(100);

    /// One occupied slot in the event ring.
    #[derive(Debug)]
    struct EventSlot {
        /// Timeline value passed to the `record()` call that created `event`.
        timeline_value: u64,
        /// Event recorded on the signal's stream. It is shared so waiters can
        /// keep it alive after the slot is overwritten.
        event: Arc<CudaEvent>,
    }

    /// CUDA timeline signal using a fixed-size ring of event slots.
    ///
    /// Each `record(n)` lands in the next slot of the ring. Waiters look up
    /// the smallest slot whose `timeline_value >= target` and synchronise on
    /// its event. When the ring wraps and a slot is overwritten, the previous
    /// `Arc<CudaEvent>` is released only after every waiter that fetched it
    /// drops its clone. Slots are never torn out from under outstanding
    /// waiters.
    ///
    /// The counter returned by `value()` advances when work is *submitted*
    /// (`record`) or when `set` is called, not when the GPU finishes. `wait`
    /// and `is_reached` therefore also consult the recorded event covering
    /// the target, and rely on the counter alone only when no event covers it
    /// (e.g. the value came from `set`).
    #[derive(Debug)]
    pub struct CudaTimelineSignal {
        /// Highest timeline value passed to `record()` or `set()`.
        value: AtomicU64,
        /// Ring of recorded (timeline_value, event) slots and a next-write cursor.
        ring: Mutex<EventRing>,
        /// CUDA context for creating events.
        context: Arc<CudaContext>,
        /// Stream for recording events.
        stream: Arc<CudaStream>,
    }

    /// Fixed-size ring of recorded events plus the index of the next slot to
    /// overwrite.
    #[derive(Debug)]
    struct EventRing {
        slots: [Option<EventSlot>; EVENT_RING_SIZE],
        next: usize,
    }

    impl EventRing {
        fn new() -> Self {
            Self { slots: std::array::from_fn(|_| None), next: 0 }
        }

        /// Clone of the event with the smallest `timeline_value >= target`,
        /// or `None` if no occupied slot qualifies.
        fn earliest_at_or_after(&self, target: u64) -> Option<Arc<CudaEvent>> {
            self.slots
                .iter()
                .flatten()
                .filter(|slot| slot.timeline_value >= target)
                .min_by_key(|slot| slot.timeline_value)
                .map(|slot| Arc::clone(&slot.event))
        }
    }

    impl CudaTimelineSignal {
        /// Create a new CUDA timeline signal.
        pub fn new(context: Arc<CudaContext>, stream: Arc<CudaStream>) -> Self {
            Self { value: AtomicU64::new(0), ring: Mutex::new(EventRing::new()), context, stream }
        }

        /// Record an event at the given timeline value.
        ///
        /// Call this after submitting work to the stream. The new (value, event)
        /// pair occupies the next ring slot, overwriting whatever was there.
        /// Overwriting is safe: any waiter that looked up that slot already
        /// holds an `Arc<CudaEvent>` clone, which keeps the event alive past
        /// the overwrite.
        ///
        /// The event is published to the ring *before* the counter is raised,
        /// so a thread that observes the new counter value can find the event.
        pub fn record(&self, value: u64) -> Result<()> {
            let event = self.context.create_event(None).context(CudaSnafu)?;
            self.stream.record(&event).context(CudaSnafu)?;

            let mut ring = self.ring.lock();
            let slot_idx = ring.next;
            ring.slots[slot_idx] = Some(EventSlot { timeline_value: value, event: Arc::new(event) });
            ring.next = (slot_idx + 1) % EVENT_RING_SIZE;
            drop(ring);

            // Update the timeline value. AcqRel keeps load-half non-Relaxed so concurrent
            // record/set observe each other's monotonic updates.
            self.value.fetch_max(value, Ordering::AcqRel);

            Ok(())
        }

        /// Get the stream for this signal.
        pub fn stream(&self) -> &Arc<CudaStream> {
            &self.stream
        }

        /// Look up the event to wait on for `target`.
        ///
        /// The ring lock is held only for the scan. The returned `Arc` keeps
        /// the event alive even if a later `record()` overwrites its slot.
        fn event_covering(&self, target: u64) -> Option<Arc<CudaEvent>> {
            self.ring.lock().earliest_at_or_after(target)
        }

        /// Wait for `event` to complete.
        ///
        /// `timeout_ms == 0` blocks in `synchronize()`. Otherwise `is_ready()`
        /// is polled until `timeout_ms` has elapsed since `start`.
        fn wait_on_event(
            event: &CudaEvent,
            target: u64,
            timeout_ms: u64,
            start: Instant,
        ) -> Result<()> {
            if timeout_ms == 0 {
                event.synchronize().context(CudaSnafu)?;
                return Ok(());
            }

            let timeout = Duration::from_millis(timeout_ms);
            while !event.is_ready() {
                if start.elapsed() > timeout {
                    return RuntimeSnafu {
                        message: format!(
                            "CUDA timeline signal timeout: waited {}ms for value {}",
                            timeout_ms, target
                        ),
                    }
                    .fail();
                }
                std::thread::sleep(EVENT_POLL_INTERVAL);
            }
            Ok(())
        }

        /// Spin (yielding) until the counter reaches `target`, or until
        /// `timeout_ms` has elapsed since `start` (`0` = no limit).
        fn wait_on_counter(&self, target: u64, timeout_ms: u64, start: Instant) -> Result<()> {
            let timeout = if timeout_ms == 0 {
                Duration::MAX
            } else {
                Duration::from_millis(timeout_ms)
            };

            while self.value.load(Ordering::Acquire) < target {
                if start.elapsed() > timeout {
                    return RuntimeSnafu {
                        message: format!(
                            "CUDA timeline signal timeout: waited {}ms for value {}, current {}",
                            timeout_ms,
                            target,
                            self.value.load(Ordering::Acquire)
                        ),
                    }
                    .fail();
                }
                std::thread::yield_now();
            }
            Ok(())
        }
    }

    impl TimelineSignal for CudaTimelineSignal {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn value(&self) -> u64 {
            self.value.load(Ordering::Acquire)
        }

        fn set(&self, value: u64) {
            // For CUDA, set() should be called via record() after stream work.
            // This is a fallback that just updates the counter.
            self.value.fetch_max(value, Ordering::AcqRel);
        }

        /// Wait until the work for `target` has completed.
        ///
        /// Both phases share one deadline measured from entry.
        fn wait(&self, target: u64, timeout_ms: u64) -> Result<()> {
            let start = Instant::now();

            // Phase 1: wait until `target` has at least been submitted
            // (record) or set. Returns immediately when the counter is
            // already there.
            self.wait_on_counter(target, timeout_ms, start)?;

            // Phase 2: the counter only tracks submission, so synchronise on
            // the earliest event covering `target`. If that exact slot was
            // overwritten by ring wrap, a later event is used instead; this
            // assumes record() values follow stream order on `self.stream`.
            // No event means the counter was advanced by set() alone.
            match self.event_covering(target) {
                Some(event) => Self::wait_on_event(&event, target, timeout_ms, start),
                None => Ok(()),
            }
        }

        /// Non-blocking: `true` once `value` has been submitted and, if an
        /// event covers it, that event has completed.
        fn is_reached(&self, value: u64) -> bool {
            if self.value.load(Ordering::Acquire) < value {
                return false;
            }
            match self.event_covering(value) {
                Some(event) => event.is_ready(),
                None => true,
            }
        }
    }
}
385
386#[cfg(test)]
387#[path = "test/unit/sync.rs"]
388mod tests;