use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use parking_lot::{Condvar, Mutex};
use crate::error::{Result, RuntimeSnafu};
use snafu::ensure;
/// A `u64` counter that can be read, written, and waited on.
///
/// Implementors must be `Send + Sync` and implement `Debug`.
pub trait TimelineSignal: Send + Sync + std::fmt::Debug {
    /// Returns the signal's current value.
    fn value(&self) -> u64;

    /// Signals `value` on this timeline.
    ///
    /// Whether a smaller `value` can move the signal backwards is up to the
    /// implementation.
    fn set(&self, value: u64);

    /// Blocks the calling thread until the signal has reached `value`.
    ///
    /// The implementations in this file treat `timeout_ms == 0` as "no timeout".
    ///
    /// # Errors
    ///
    /// The implementations in this file return an error when a non-zero
    /// `timeout_ms` elapses before `value` is reached.
    fn wait(&self, value: u64, timeout_ms: u64) -> Result<()>;

    /// Returns `true` when the current value is at least `value`.
    fn is_reached(&self, value: u64) -> bool {
        let current = self.value();
        current >= value
    }
}
/// Host-side timeline signal: an atomic counter plus a mutex/condvar pair
/// that lets threads block until the counter reaches a target.
#[derive(Debug)]
pub struct CpuTimelineSignal {
    /// Current signal value.
    value: AtomicU64,
    /// Lock held by waiters while they park on `condvar`; it guards no data
    /// of its own.
    mutex: Mutex<()>,
    /// Notified by `set` to wake blocked waiters.
    condvar: Condvar,
}
impl Default for CpuTimelineSignal {
fn default() -> Self {
Self::new()
}
}
impl CpuTimelineSignal {
    /// Creates a signal whose value starts at `0`.
    pub fn new() -> Self {
        Self::with_initial(0)
    }

    /// Creates a signal whose value starts at `initial`.
    pub fn with_initial(initial: u64) -> Self {
        let value = AtomicU64::new(initial);
        Self {
            value,
            mutex: Mutex::new(()),
            condvar: Condvar::new(),
        }
    }
}
impl TimelineSignal for CpuTimelineSignal {
    /// Returns the current value (acquire load).
    fn value(&self) -> u64 {
        self.value.load(Ordering::Acquire)
    }

    /// Overwrites the current value with `value` and wakes every thread
    /// blocked in [`wait`](TimelineSignal::wait).
    ///
    /// The value is stored unconditionally, so passing a smaller `value`
    /// moves the signal backwards.
    fn set(&self, value: u64) {
        {
            // Publish while holding the waiter mutex. A waiter checks the
            // value and parks on the condvar under this same mutex, so taking
            // it here guarantees the waiter either observes the new value in
            // its check or is already parked when `notify_all` runs. Without
            // the lock, a store + notify landing between the waiter's check
            // and its park would be lost and the waiter could sleep forever.
            let _guard = self.mutex.lock();
            self.value.store(value, Ordering::Release);
        }
        self.condvar.notify_all();
    }

    /// Blocks until the value is at least `target`.
    ///
    /// `timeout_ms == 0` waits indefinitely. A non-zero `timeout_ms` bounds
    /// the total wait; if the resulting deadline cannot be represented as an
    /// `Instant`, the wait is treated as unbounded instead of panicking.
    ///
    /// # Errors
    ///
    /// Returns a runtime error when the timeout elapses while the value is
    /// still below `target`.
    fn wait(&self, target: u64, timeout_ms: u64) -> Result<()> {
        // Fast path: no locking when the target has already been reached.
        if self.value.load(Ordering::Acquire) >= target {
            return Ok(());
        }

        // `None` means "no deadline".
        let deadline = if timeout_ms == 0 {
            None
        } else {
            Instant::now().checked_add(Duration::from_millis(timeout_ms))
        };

        let mut guard = self.mutex.lock();
        loop {
            // Re-check after every wakeup: condvars may wake spuriously and
            // `set` may have stored a value that is still below `target`.
            let current = self.value.load(Ordering::Acquire);
            if current >= target {
                return Ok(());
            }

            match deadline {
                None => self.condvar.wait(&mut guard),
                Some(deadline) => {
                    let remaining = deadline.saturating_duration_since(Instant::now());
                    ensure!(
                        !remaining.is_zero(),
                        RuntimeSnafu {
                            message: format!(
                                "timeline signal timeout: waited {}ms for value {}, current {}",
                                timeout_ms, target, current
                            )
                        }
                    );
                    // The timed-out flag is deliberately ignored: the next
                    // iteration re-checks both the value and the deadline.
                    let _ = self.condvar.wait_for(&mut guard, remaining);
                }
            }
        }
    }
}
#[cfg(feature = "cuda")]
pub mod cuda {
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use cudarc::driver::{CudaContext, CudaEvent, CudaStream};
use parking_lot::Mutex;
use super::TimelineSignal;
use crate::error::{CudaSnafu, Result};
use snafu::ResultExt;
/// Timeline signal whose values are tied to CUDA events recorded on a stream.
#[derive(Debug)]
pub struct CudaTimelineSignal {
    /// Highest value recorded or set so far; only ever raised via `fetch_max`.
    value: AtomicU64,
    /// Recorded events, keyed by the timeline value they were recorded for.
    events: Mutex<HashMap<u64, Arc<CudaEvent>>>,
    /// Context used to create new events.
    context: Arc<CudaContext>,
    /// Stream on which events are recorded.
    stream: Arc<CudaStream>,
}
impl CudaTimelineSignal {
    /// Creates a signal starting at `0` with no recorded events, bound to the
    /// given context and stream.
    pub fn new(context: Arc<CudaContext>, stream: Arc<CudaStream>) -> Self {
        Self {
            value: AtomicU64::new(0),
            events: Mutex::new(HashMap::new()),
            context,
            stream,
        }
    }

    /// Records a new event on the stream and associates it with `value`.
    ///
    /// The event replaces any event previously stored for the same `value`,
    /// and the signal's value is raised to at least `value`. Once more than 32
    /// events are stored, events at or below `current - 16` are dropped.
    ///
    /// NOTE(review): the value is raised as soon as the event is enqueued,
    /// not when the GPU completes it — confirm this is the intended meaning
    /// of "reached" for this signal.
    ///
    /// # Errors
    ///
    /// Returns a CUDA error if creating or recording the event fails.
    pub fn record(&self, value: u64) -> Result<()> {
        // Event-count threshold that triggers pruning.
        const PRUNE_ABOVE: usize = 32;
        // How far below the current value an event may be and still be kept.
        const KEEP_WINDOW: u64 = 16;

        let event = self.context.create_event(None).context(CudaSnafu)?;
        self.stream.record(&event).context(CudaSnafu)?;

        let mut recorded = self.events.lock();
        recorded.insert(value, Arc::new(event));
        self.value.fetch_max(value, Ordering::Release);

        if recorded.len() > PRUNE_ABOVE {
            let cutoff = self.value.load(Ordering::Acquire).saturating_sub(KEEP_WINDOW);
            recorded.retain(|key, _| *key > cutoff);
        }
        Ok(())
    }

    /// Returns the stream this signal records its events on.
    pub fn stream(&self) -> &Arc<CudaStream> {
        &self.stream
    }
}
impl TimelineSignal for CudaTimelineSignal {
    /// Returns the current value (acquire load).
    fn value(&self) -> u64 {
        self.value.load(Ordering::Acquire)
    }

    /// Raises the value to at least `value`; a smaller `value` has no effect.
    fn set(&self, value: u64) {
        self.value.fetch_max(value, Ordering::Release);
    }

    /// Waits until `target` is reached.
    ///
    /// If the value is already at least `target`, returns immediately.
    /// Otherwise the stored event with the smallest key `>= target` is waited
    /// on: synchronously when `timeout_ms == 0`, or by polling every 100µs
    /// when a timeout is given. When no such event exists, the value itself is
    /// polled (yielding between checks) until it reaches `target`.
    ///
    /// NOTE(review): `record` raises the value when an event is enqueued, so
    /// the fast path can return before that event has completed on the GPU —
    /// confirm this is intended.
    ///
    /// # Errors
    ///
    /// Returns a CUDA error if synchronizing the event fails, or a runtime
    /// error when a non-zero `timeout_ms` elapses first.
    fn wait(&self, target: u64, timeout_ms: u64) -> Result<()> {
        use std::time::{Duration, Instant};

        if self.value.load(Ordering::Acquire) >= target {
            return Ok(());
        }

        // Pick the closest event at or above the target; the lock is released
        // before any waiting happens.
        let nearest = {
            let recorded = self.events.lock();
            recorded
                .iter()
                .filter(|(key, _)| **key >= target)
                .min_by_key(|(key, _)| **key)
                .map(|(_, found)| Arc::clone(found))
        };

        match nearest {
            // No timeout: block on the event directly.
            Some(event) if timeout_ms == 0 => {
                event.synchronize().context(CudaSnafu)?;
            }
            // Timeout: poll the event until it is ready or time runs out.
            Some(event) => {
                let began = Instant::now();
                let limit = Duration::from_millis(timeout_ms);
                loop {
                    if event.is_ready() {
                        break;
                    }
                    if began.elapsed() > limit {
                        return crate::error::RuntimeSnafu {
                            message: format!(
                                "CUDA timeline signal timeout: waited {}ms for value {}",
                                timeout_ms, target
                            ),
                        }
                        .fail();
                    }
                    std::thread::sleep(Duration::from_micros(100));
                }
            }
            // No matching event: poll the counter itself.
            None => {
                let began = Instant::now();
                // `Duration::MAX` can never be exceeded, i.e. no timeout.
                let limit = match timeout_ms {
                    0 => Duration::MAX,
                    ms => Duration::from_millis(ms),
                };
                while self.value.load(Ordering::Acquire) < target {
                    if began.elapsed() > limit {
                        return crate::error::RuntimeSnafu {
                            message: format!(
                                "CUDA timeline signal timeout: waited {}ms for value {}, current {}",
                                timeout_ms,
                                target,
                                self.value.load(Ordering::Acquire)
                            ),
                        }
                        .fail();
                    }
                    std::thread::yield_now();
                }
            }
        }

        Ok(())
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// A new signal starts at zero and reflects `set` through `value` and
    /// `is_reached`.
    #[test]
    fn test_cpu_signal_basic() {
        let sig = CpuTimelineSignal::new();
        assert_eq!(sig.value(), 0);

        sig.set(5);
        assert_eq!(sig.value(), 5);

        for reached in [5, 3] {
            assert!(sig.is_reached(reached));
        }
        assert!(!sig.is_reached(10));
    }

    /// Waiting on a value that is already reached returns immediately.
    #[test]
    fn test_cpu_signal_wait_already_reached() {
        let sig = CpuTimelineSignal::new();
        sig.set(10);

        for target in [5, 10] {
            sig.wait(target, 100).unwrap();
        }
    }

    /// A waiter on another thread is released once the value is set.
    #[test]
    fn test_cpu_signal_wait_concurrent() {
        let shared = Arc::new(CpuTimelineSignal::new());

        let handle = {
            let for_waiter = Arc::clone(&shared);
            thread::spawn(move || {
                for_waiter.wait(5, 5000).unwrap();
                for_waiter.value()
            })
        };

        // Give the waiter a moment to start blocking before signalling.
        thread::sleep(std::time::Duration::from_millis(10));
        shared.set(5);

        let observed = handle.join().unwrap();
        assert!(observed >= 5);
    }

    /// Waiting on a value that is never set fails once the timeout elapses.
    #[test]
    fn test_cpu_signal_timeout() {
        let sig = CpuTimelineSignal::new();
        assert!(sig.wait(10, 50).is_err());
    }
}