1use std::sync::atomic::{AtomicU64, Ordering};
28use std::time::{Duration, Instant};
29
30use parking_lot::{Condvar, Mutex};
31
32use crate::error::{Result, RuntimeSnafu};
33use snafu::ensure;
34
35pub trait TimelineSignal: Send + Sync + std::fmt::Debug {
45 fn value(&self) -> u64;
47
48 fn set(&self, value: u64);
54
55 fn wait(&self, value: u64, timeout_ms: u64) -> Result<()>;
66
67 fn is_reached(&self, value: u64) -> bool {
69 self.value() >= value
70 }
71}
72
/// Host-side [`TimelineSignal`] built from an atomic counter plus a mutex/condvar
/// pair that threads park on inside `wait`.
#[derive(Debug)]
pub struct CpuTimelineSignal {
    /// Current timeline value; read without locking on the fast path of `wait`.
    value: AtomicU64,
    /// Protects no data (`()`); it is paired with `condvar` so waiters can park.
    mutex: Mutex<()>,
    /// Woken with `notify_all` whenever the value is set.
    condvar: Condvar,
}
86
87impl Default for CpuTimelineSignal {
88 fn default() -> Self {
89 Self::new()
90 }
91}
92
93impl CpuTimelineSignal {
94 pub fn new() -> Self {
96 Self { value: AtomicU64::new(0), mutex: Mutex::new(()), condvar: Condvar::new() }
97 }
98
99 pub fn with_initial(initial: u64) -> Self {
101 Self { value: AtomicU64::new(initial), mutex: Mutex::new(()), condvar: Condvar::new() }
102 }
103}
104
105impl TimelineSignal for CpuTimelineSignal {
106 fn value(&self) -> u64 {
107 self.value.load(Ordering::Acquire)
108 }
109
110 fn set(&self, value: u64) {
111 self.value.store(value, Ordering::Release);
113
114 self.condvar.notify_all();
116 }
117
118 fn wait(&self, target: u64, timeout_ms: u64) -> Result<()> {
119 if self.value.load(Ordering::Acquire) >= target {
121 return Ok(());
122 }
123
124 let mut guard = self.mutex.lock();
125
126 if timeout_ms == 0 {
127 while self.value.load(Ordering::Acquire) < target {
129 self.condvar.wait(&mut guard);
130 }
131 Ok(())
132 } else {
133 let deadline = Instant::now() + Duration::from_millis(timeout_ms);
135
136 while self.value.load(Ordering::Acquire) < target {
137 let remaining = deadline.saturating_duration_since(Instant::now());
138 if remaining.is_zero() {
139 ensure!(
140 self.value.load(Ordering::Acquire) >= target,
141 RuntimeSnafu {
142 message: format!(
143 "timeline signal timeout: waited {}ms for value {}, current {}",
144 timeout_ms,
145 target,
146 self.value.load(Ordering::Acquire)
147 )
148 }
149 );
150 return Ok(());
151 }
152
153 let result = self.condvar.wait_for(&mut guard, remaining);
154 if result.timed_out() && self.value.load(Ordering::Acquire) < target {
155 return RuntimeSnafu {
156 message: format!(
157 "timeline signal timeout: waited {}ms for value {}, current {}",
158 timeout_ms,
159 target,
160 self.value.load(Ordering::Acquire)
161 ),
162 }
163 .fail();
164 }
165 }
166 Ok(())
167 }
168 }
169}
170
#[cfg(feature = "cuda")]
pub mod cuda {
    use std::collections::HashMap;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{Duration, Instant};

    use cudarc::driver::{CudaContext, CudaEvent, CudaStream};
    use parking_lot::Mutex;

    use super::TimelineSignal;
    use crate::error::{CudaSnafu, Result, RuntimeSnafu};
    use snafu::ResultExt;

    /// Pending-event count above which `record` sweeps out events that have completed.
    const EVENT_SWEEP_THRESHOLD: usize = 32;

    /// Sleep between readiness polls when `wait` cannot simply block on an event.
    const POLL_INTERVAL: Duration = Duration::from_micros(100);

    /// [`TimelineSignal`] whose values are completed by CUDA events recorded on a stream.
    ///
    /// `value` holds the highest value known to be *complete*. A value becomes complete
    /// either by being set from the host with `set`, or through a recorded event that
    /// has been observed as finished. A value passed to [`CudaTimelineSignal::record`]
    /// stays pending in `events` until its event is ready. As a result, `wait` and
    /// `is_reached` reflect GPU progress rather than mere submission.
    #[derive(Debug)]
    pub struct CudaTimelineSignal {
        /// Highest completed timeline value.
        value: AtomicU64,
        /// Recorded, not-yet-retired events keyed by the timeline value each completes.
        events: Mutex<HashMap<u64, Arc<CudaEvent>>>,
        /// Context used to create events.
        context: Arc<CudaContext>,
        /// Stream on which events are recorded.
        stream: Arc<CudaStream>,
    }

    impl CudaTimelineSignal {
        /// Creates a signal at value zero that records its events on `stream`.
        pub fn new(context: Arc<CudaContext>, stream: Arc<CudaStream>) -> Self {
            Self { value: AtomicU64::new(0), events: Mutex::new(HashMap::new()), context, stream }
        }

        /// Enqueues an event on the stream that completes timeline `value` once the
        /// GPU reaches it.
        ///
        /// `value` is not reported by `TimelineSignal::value` until the event has
        /// actually completed.
        ///
        /// # Errors
        ///
        /// Returns a CUDA error if the event cannot be created or recorded.
        pub fn record(&self, value: u64) -> Result<()> {
            let event = self.context.create_event(None).context(CudaSnafu)?;
            self.stream.record(&event).context(CudaSnafu)?;

            let mut events = self.events.lock();
            events.insert(value, Arc::new(event));
            // Bound the map by in-flight work: only events that are already done are dropped,
            // so a later `wait` can still find the event for any pending value.
            if events.len() > EVENT_SWEEP_THRESHOLD {
                self.retire_completed(&mut events);
            }
            Ok(())
        }

        /// Returns the stream this signal records its events on.
        pub fn stream(&self) -> &Arc<CudaStream> {
            &self.stream
        }

        /// Advances `value` past every ready event, then drops events at or below it.
        ///
        /// `events` is the locked contents of `self.events`. Returns the completed value.
        fn retire_completed(&self, events: &mut HashMap<u64, Arc<CudaEvent>>) -> u64 {
            for (&v, event) in events.iter() {
                if event.is_ready() {
                    self.value.fetch_max(v, Ordering::Release);
                }
            }
            let completed = self.value.load(Ordering::Acquire);
            events.retain(|&v, _| v > completed);
            completed
        }

        /// Locks `events`, retires completed ones, and returns the completed value.
        fn poll(&self) -> u64 {
            let mut events = self.events.lock();
            self.retire_completed(&mut events)
        }

        /// Returns the pending event with the smallest value `>= target`, if any.
        fn pending_event_for(&self, target: u64) -> Option<(u64, Arc<CudaEvent>)> {
            self.events
                .lock()
                .iter()
                .filter(|&(&v, _)| v >= target)
                .min_by_key(|&(&v, _)| v)
                .map(|(&v, event)| (v, Arc::clone(event)))
        }
    }

    impl TimelineSignal for CudaTimelineSignal {
        /// Returns the highest completed value.
        ///
        /// It first retires any recorded events the GPU has already finished, which
        /// takes the `events` lock.
        fn value(&self) -> u64 {
            self.poll()
        }

        /// Marks `value` complete from the host; never lowers the timeline.
        fn set(&self, value: u64) {
            self.value.fetch_max(value, Ordering::Release);
        }

        /// Waits until the completed value is at least `target`.
        ///
        /// With no time limit and a recorded event covering `target`, this blocks in
        /// `CudaEvent::synchronize`. "No time limit" means `timeout_ms == 0`, or a timeout
        /// too large to add to `Instant::now()`. Otherwise it polls every
        /// `POLL_INTERVAL`, which also picks up events recorded and values set after
        /// the wait began.
        ///
        /// # Errors
        ///
        /// Returns a CUDA error if synchronization fails, or a runtime error if the
        /// timeout elapses before `target` completes.
        fn wait(&self, target: u64, timeout_ms: u64) -> Result<()> {
            // Fast path without taking the events lock.
            if self.value.load(Ordering::Acquire) >= target {
                return Ok(());
            }

            // `None` means "no deadline".
            let deadline = if timeout_ms == 0 {
                None
            } else {
                Instant::now().checked_add(Duration::from_millis(timeout_ms))
            };

            loop {
                if self.poll() >= target {
                    return Ok(());
                }

                match deadline {
                    None => {
                        if let Some((v, event)) = self.pending_event_for(target) {
                            event.synchronize().context(CudaSnafu)?;
                            // `v >= target`, so the next poll returns Ok.
                            self.value.fetch_max(v, Ordering::Release);
                            continue;
                        }
                    }
                    Some(deadline) if Instant::now() >= deadline => {
                        return RuntimeSnafu {
                            message: format!(
                                "CUDA timeline signal timeout: waited {}ms for value {}, current {}",
                                timeout_ms,
                                target,
                                self.value.load(Ordering::Acquire)
                            ),
                        }
                        .fail();
                    }
                    Some(_) => {}
                }
                std::thread::sleep(POLL_INTERVAL);
            }
        }
    }
}
312
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_cpu_signal_basic() {
        let sig = CpuTimelineSignal::new();
        assert_eq!(sig.value(), 0);

        sig.set(5);
        assert_eq!(sig.value(), 5);

        // Targets at or below the current value count as reached; higher ones do not.
        for reached in [5, 3] {
            assert!(sig.is_reached(reached));
        }
        assert!(!sig.is_reached(10));
    }

    #[test]
    fn test_cpu_signal_wait_already_reached() {
        let sig = CpuTimelineSignal::new();
        sig.set(10);

        // Both a lower target and the exact current value must return immediately.
        for target in [5, 10] {
            sig.wait(target, 100).unwrap();
        }
    }

    #[test]
    fn test_cpu_signal_wait_concurrent() {
        let shared = Arc::new(CpuTimelineSignal::new());
        let waiter = {
            let remote = Arc::clone(&shared);
            thread::spawn(move || {
                remote.wait(5, 5000).unwrap();
                remote.value()
            })
        };

        // Give the waiter a head start so it is most likely blocked before the set.
        thread::sleep(Duration::from_millis(10));
        shared.set(5);

        let observed = waiter.join().unwrap();
        assert!(observed >= 5);
    }

    #[test]
    fn test_cpu_signal_timeout() {
        let sig = CpuTimelineSignal::new();
        // Nothing ever sets the signal, so a short bounded wait must fail.
        assert!(sig.wait(10, 50).is_err());
    }
}