Skip to main content

morok_device/
sync.rs

//! Timeline synchronization primitives for parallel execution.
//!
//! This module provides device-agnostic synchronization using timeline signals:
//! monotonically increasing counters that order operations across devices.
//!
//! # Design
//!
//! Timeline signals abstract over:
//! - CPU: `AtomicU64` with a `parking_lot` condvar for waiting
//! - CUDA: Event pools keyed by timeline value
//! - Metal: `MTLSharedEvent` (future)
//! - HIP: Similar to CUDA (future)
//!
//! # Example
//!
//! ```ignore
//! let signal = CpuTimelineSignal::new();
//!
//! // Producer thread
//! signal.set(1);  // Signal completion of operation 1
//!
//! // Consumer thread
//! signal.wait(1, 1000)?;  // Wait up to 1000 ms for operation 1 to complete
//! ```
26
27use std::sync::atomic::{AtomicU64, Ordering};
28use std::time::{Duration, Instant};
29
30use parking_lot::{Condvar, Mutex};
31
32use crate::error::{Result, RuntimeSnafu};
33use snafu::ensure;
34
35/// Monotonic timeline signal for synchronization.
36///
37/// Timeline signals provide a way to order operations across different execution
38/// contexts (threads, devices, queues). The signal value only increases, and
39/// waiters block until the signal reaches or exceeds the target value.
40///
41/// # Thread Safety
42///
43/// All implementations must be `Send + Sync` for cross-thread use.
44pub trait TimelineSignal: Send + Sync + std::fmt::Debug {
45    /// Get the current signal value.
46    fn value(&self) -> u64;
47
48    /// Set the signal to a new value.
49    ///
50    /// # Panics
51    ///
52    /// May panic if `value` is less than the current value (implementation-defined).
53    fn set(&self, value: u64);
54
55    /// Wait for the signal to reach or exceed `value`.
56    ///
57    /// # Arguments
58    ///
59    /// * `value` - The target value to wait for
60    /// * `timeout_ms` - Maximum time to wait in milliseconds (0 = infinite)
61    ///
62    /// # Returns
63    ///
64    /// `Ok(())` if the signal reached the target value, or `Err` on timeout.
65    fn wait(&self, value: u64, timeout_ms: u64) -> Result<()>;
66
67    /// Check if the signal has reached `value` without blocking.
68    fn is_reached(&self, value: u64) -> bool {
69        self.value() >= value
70    }
71}
72
73/// CPU-based timeline signal using atomics and condvar.
74///
75/// Efficient for CPU-only workloads. Uses `AtomicU64` for the counter and
76/// `parking_lot::Condvar` for efficient waiting.
77#[derive(Debug)]
78pub struct CpuTimelineSignal {
79    /// Current timeline value (monotonically increasing).
80    value: AtomicU64,
81    /// Mutex for condvar waiting (protects nothing, just for condvar).
82    mutex: Mutex<()>,
83    /// Condvar for waiting threads.
84    condvar: Condvar,
85}
86
87impl Default for CpuTimelineSignal {
88    fn default() -> Self {
89        Self::new()
90    }
91}
92
93impl CpuTimelineSignal {
94    /// Create a new CPU timeline signal starting at 0.
95    pub fn new() -> Self {
96        Self { value: AtomicU64::new(0), mutex: Mutex::new(()), condvar: Condvar::new() }
97    }
98
99    /// Create a new CPU timeline signal with an initial value.
100    pub fn with_initial(initial: u64) -> Self {
101        Self { value: AtomicU64::new(initial), mutex: Mutex::new(()), condvar: Condvar::new() }
102    }
103}
104
105impl TimelineSignal for CpuTimelineSignal {
106    fn value(&self) -> u64 {
107        self.value.load(Ordering::Acquire)
108    }
109
110    fn set(&self, value: u64) {
111        // Store the new value
112        self.value.store(value, Ordering::Release);
113
114        // Wake all waiters to check the new value
115        self.condvar.notify_all();
116    }
117
118    fn wait(&self, target: u64, timeout_ms: u64) -> Result<()> {
119        // Fast path: already reached
120        if self.value.load(Ordering::Acquire) >= target {
121            return Ok(());
122        }
123
124        let mut guard = self.mutex.lock();
125
126        if timeout_ms == 0 {
127            // Infinite wait
128            while self.value.load(Ordering::Acquire) < target {
129                self.condvar.wait(&mut guard);
130            }
131            Ok(())
132        } else {
133            // Timed wait
134            let deadline = Instant::now() + Duration::from_millis(timeout_ms);
135
136            while self.value.load(Ordering::Acquire) < target {
137                let remaining = deadline.saturating_duration_since(Instant::now());
138                if remaining.is_zero() {
139                    ensure!(
140                        self.value.load(Ordering::Acquire) >= target,
141                        RuntimeSnafu {
142                            message: format!(
143                                "timeline signal timeout: waited {}ms for value {}, current {}",
144                                timeout_ms,
145                                target,
146                                self.value.load(Ordering::Acquire)
147                            )
148                        }
149                    );
150                    return Ok(());
151                }
152
153                let result = self.condvar.wait_for(&mut guard, remaining);
154                if result.timed_out() && self.value.load(Ordering::Acquire) < target {
155                    return RuntimeSnafu {
156                        message: format!(
157                            "timeline signal timeout: waited {}ms for value {}, current {}",
158                            timeout_ms,
159                            target,
160                            self.value.load(Ordering::Acquire)
161                        ),
162                    }
163                    .fail();
164                }
165            }
166            Ok(())
167        }
168    }
169}
170
#[cfg(feature = "cuda")]
pub mod cuda {
    //! CUDA-specific timeline signal using event pools.

    use std::collections::HashMap;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU64, Ordering};

    use cudarc::driver::{CudaContext, CudaEvent, CudaStream};
    use parking_lot::Mutex;

    use super::TimelineSignal;
    use crate::error::{CudaSnafu, Result};
    use snafu::ResultExt;

    /// CUDA timeline signal using event pools.
    ///
    /// [`record`](CudaTimelineSignal::record) enqueues a CUDA event on the
    /// stream and files it under a timeline value. `wait(n)` blocks until the
    /// event covering `n` has completed.
    ///
    /// The counter is advanced as soon as an event is *recorded*, so
    /// `value()` / `is_reached()` report what has been submitted to the stream,
    /// not what the device has finished; use `wait` to observe completion.
    #[derive(Debug)]
    pub struct CudaTimelineSignal {
        /// Highest timeline value recorded or set so far.
        value: AtomicU64,
        /// Event pool: timeline value → recorded event.
        events: Mutex<HashMap<u64, Arc<CudaEvent>>>,
        /// CUDA context for creating events.
        context: Arc<CudaContext>,
        /// Stream for recording events.
        stream: Arc<CudaStream>,
    }

    impl CudaTimelineSignal {
        /// Create a new CUDA timeline signal starting at 0 with an empty event pool.
        pub fn new(context: Arc<CudaContext>, stream: Arc<CudaStream>) -> Self {
            Self { value: AtomicU64::new(0), events: Mutex::new(HashMap::new()), context, stream }
        }

        /// Record an event at the given timeline value.
        ///
        /// This should be called after submitting work to the stream.
        ///
        /// # Errors
        ///
        /// Returns a CUDA error if creating or recording the event fails.
        pub fn record(&self, value: u64) -> Result<()> {
            let event = self.context.create_event(None).context(CudaSnafu)?;
            self.stream.record(&event).context(CudaSnafu)?;

            let mut events = self.events.lock();
            events.insert(value, Arc::new(event));

            // Publish the counter while still holding the pool lock: anyone who
            // observes the new value and then locks the pool is guaranteed to
            // find the event inserted above.
            self.value.fetch_max(value, Ordering::Release);

            // Prune once the pool exceeds 32 entries, keeping only events whose
            // value is greater than `current - 16`.
            if events.len() > 32 {
                let current = self.value.load(Ordering::Acquire);
                events.retain(|&v, _| v > current.saturating_sub(16));
            }

            Ok(())
        }

        /// Get the stream for this signal.
        pub fn stream(&self) -> &Arc<CudaStream> {
            &self.stream
        }
    }

    impl TimelineSignal for CudaTimelineSignal {
        /// Returns the highest timeline value recorded or set so far.
        fn value(&self) -> u64 {
            self.value.load(Ordering::Acquire)
        }

        /// Advances the counter without recording an event.
        ///
        /// For CUDA, timeline values should normally be published via
        /// `record()` after stream work; this is a host-only fallback. The
        /// counter never decreases.
        fn set(&self, value: u64) {
            self.value.fetch_max(value, Ordering::Release);
        }

        /// Blocks until `target` has been submitted and the event covering it
        /// has completed.
        ///
        /// * `target` - timeline value to wait for
        /// * `timeout_ms` - maximum wait in milliseconds; `0` waits without limit
        ///
        /// Returns a runtime error on timeout, or a CUDA error if synchronizing
        /// on the event fails.
        fn wait(&self, target: u64, timeout_ms: u64) -> Result<()> {
            let start = std::time::Instant::now();
            let timeout = if timeout_ms == 0 {
                std::time::Duration::MAX
            } else {
                std::time::Duration::from_millis(timeout_ms)
            };

            // Phase 1: wait until `target` has at least been submitted. The
            // counter reaching `target` only means an event was recorded (or
            // `set` was called), not that the device is done.
            while self.value.load(Ordering::Acquire) < target {
                if start.elapsed() > timeout {
                    return crate::error::RuntimeSnafu {
                        message: format!(
                            "CUDA timeline signal timeout: waited {}ms for value {}, current {}",
                            timeout_ms,
                            target,
                            self.value.load(Ordering::Acquire)
                        ),
                    }
                    .fail();
                }
                std::thread::yield_now();
            }

            // Phase 2: pick the earliest pooled event at or after `target`.
            // Returning on the counter alone would skip the device wait.
            let event = {
                let events = self.events.lock();
                events
                    .iter()
                    .filter(|(key, _)| **key >= target)
                    .min_by_key(|(key, _)| **key)
                    .map(|(_, event)| Arc::clone(event))
            };

            let event = match event {
                Some(event) => event,
                // No covering event: the value was published via `set()`, or its
                // event has already been pruned from the pool.
                None => return Ok(()),
            };

            if timeout_ms == 0 {
                // Unbounded wait: block in the driver.
                event.synchronize().context(CudaSnafu)?;
                return Ok(());
            }

            // Bounded wait: poll so the deadline (shared with phase 1) is honoured.
            while !event.is_ready() {
                if start.elapsed() > timeout {
                    return crate::error::RuntimeSnafu {
                        message: format!(
                            "CUDA timeline signal timeout: waited {}ms for value {}",
                            timeout_ms, target
                        ),
                    }
                    .fail();
                }
                std::thread::sleep(std::time::Duration::from_micros(100));
            }

            Ok(())
        }
    }
}
312
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;

    /// A fresh signal starts at 0; `set` is visible through `value` and `is_reached`.
    #[test]
    fn test_cpu_signal_basic() {
        let timeline = CpuTimelineSignal::new();
        assert_eq!(timeline.value(), 0);

        timeline.set(5);
        assert_eq!(timeline.value(), 5);

        // At, below, and above the current value respectively.
        assert!(timeline.is_reached(5));
        assert!(timeline.is_reached(3));
        assert!(!timeline.is_reached(10));
    }

    /// Waiting for a value that was already reached succeeds.
    #[test]
    fn test_cpu_signal_wait_already_reached() {
        let timeline = CpuTimelineSignal::new();
        timeline.set(10);

        for target in [5, 10] {
            timeline.wait(target, 100).unwrap();
        }
    }

    /// A waiter on another thread is released once the producer sets the value.
    #[test]
    fn test_cpu_signal_wait_concurrent() {
        let timeline = Arc::new(CpuTimelineSignal::new());

        let consumer = {
            let timeline = Arc::clone(&timeline);
            thread::spawn(move || {
                timeline.wait(5, 5000).unwrap();
                timeline.value()
            })
        };

        // Let the consumer get as far as blocking before the value is published.
        thread::sleep(Duration::from_millis(10));
        timeline.set(5);

        let observed = consumer.join().unwrap();
        assert!(observed >= 5);
    }

    /// A wait whose target is never set reports an error after the timeout.
    #[test]
    fn test_cpu_signal_timeout() {
        let timeline = CpuTimelineSignal::new();

        let outcome = timeline.wait(10, 50);
        assert!(outcome.is_err());
    }
}