use std::any::Any;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use parking_lot::{Condvar, Mutex};
use crate::error::{Result, RuntimeSnafu};
use snafu::ensure;
/// A 64-bit timeline value that can be signalled, polled, and waited on.
///
/// Implementors must be shareable across threads (`Send + Sync`) and, because
/// of the `Any` supertrait, `'static`.
pub trait TimelineSignal: Send + Sync + std::fmt::Debug + Any {
    /// Returns `self` as [`Any`] so a `dyn TimelineSignal` can be downcast to
    /// its concrete type.
    fn as_any(&self) -> &dyn Any;

    /// Returns the current timeline value.
    fn value(&self) -> u64;

    /// Signals `value` on the timeline.
    ///
    /// The implementations in this file only ever raise the stored value
    /// (they use an atomic maximum), so signalling a lower value is a no-op.
    fn set(&self, value: u64);

    /// Blocks until the timeline reaches `value`.
    ///
    /// `timeout_ms` is the time limit in milliseconds; the implementations in
    /// this file treat `0` as "no limit".
    ///
    /// # Errors
    ///
    /// Returns an error when the limit elapses before `value` is reached, or
    /// when the backend fails while waiting.
    fn wait(&self, value: u64, timeout_ms: u64) -> Result<()>;

    /// Non-blocking check: has the timeline already reached `value`?
    fn is_reached(&self, value: u64) -> bool {
        value <= self.value()
    }
}
/// Host-side [`TimelineSignal`] backed by an atomic counter and a condition
/// variable.
///
/// Cloning copies only the `Arc`, so every clone is another handle to the
/// same underlying timeline.
#[derive(Debug, Clone)]
pub struct CpuTimelineSignal {
// Shared state; all clones of this handle point at the same allocation.
inner: Arc<CpuTimelineSignalInner>,
}
/// State shared by every clone of a [`CpuTimelineSignal`].
#[derive(Debug)]
struct CpuTimelineSignalInner {
// Current timeline value; `set` updates it with `fetch_max`, so it never decreases.
value: AtomicU64,
// Carries no data of its own; it exists to pair with `condvar` for blocking waits.
mutex: Mutex<()>,
// Notified by `set` when the value advances; `wait` blocks on it.
condvar: Condvar,
}
impl Default for CpuTimelineSignal {
fn default() -> Self {
Self::new()
}
}
impl CpuTimelineSignal {
    /// Creates a signal whose timeline starts at zero.
    pub fn new() -> Self {
        Self::with_initial(0)
    }

    /// Creates a signal whose timeline starts at `initial`.
    pub fn with_initial(initial: u64) -> Self {
        let shared = CpuTimelineSignalInner {
            value: AtomicU64::new(initial),
            mutex: Mutex::new(()),
            condvar: Condvar::new(),
        };
        Self {
            inner: Arc::new(shared),
        }
    }
}
impl TimelineSignal for CpuTimelineSignal {
    /// Returns `self` as [`Any`] so the concrete type can be recovered from a
    /// `dyn TimelineSignal`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns the current timeline value.
    fn value(&self) -> u64 {
        self.inner.value.load(Ordering::Acquire)
    }

    /// Raises the timeline to `value` and wakes every waiter if it advanced.
    ///
    /// The stored value never decreases: a `value` at or below the current
    /// one changes nothing and wakes nobody.
    fn set(&self, value: u64) {
        let previous = self.inner.value.fetch_max(value, Ordering::AcqRel);
        if value > previous {
            // A waiter checks the value and then parks on the condvar while
            // holding this mutex. Taking the mutex here means we notify either
            // before the waiter's check (it then sees the new value) or after
            // it has parked (it then receives the notification). Notifying
            // without the lock could fall between the two and be lost.
            let _serialize_with_waiters = self.inner.mutex.lock();
            self.inner.condvar.notify_all();
        }
    }

    /// Blocks until the timeline reaches `target`.
    ///
    /// `timeout_ms == 0` waits without a time limit. A limit so large that the
    /// deadline cannot be represented as an [`Instant`] is also treated as
    /// "no limit" instead of panicking on overflow.
    ///
    /// # Errors
    ///
    /// Returns a runtime error if `timeout_ms` milliseconds elapse while the
    /// timeline is still below `target`.
    fn wait(&self, target: u64, timeout_ms: u64) -> Result<()> {
        // Fast path: already reached, no need to touch the mutex.
        if self.value() >= target {
            return Ok(());
        }

        let deadline = if timeout_ms == 0 {
            None
        } else {
            Instant::now().checked_add(Duration::from_millis(timeout_ms))
        };

        let mut guard = self.inner.mutex.lock();
        // Re-check in a loop: condvar wake-ups may be spurious or for a value
        // that is still below our target.
        while self.value() < target {
            match deadline {
                None => self.inner.condvar.wait(&mut guard),
                Some(deadline) => {
                    let timed_out = self
                        .inner
                        .condvar
                        .wait_until(&mut guard, deadline)
                        .timed_out();
                    if timed_out {
                        // The value may have been set right at the deadline;
                        // only report a timeout if it is still short.
                        let current = self.value();
                        ensure!(
                            current >= target,
                            RuntimeSnafu {
                                message: format!(
                                    "timeline signal timeout: waited {}ms for value {}, current {}",
                                    timeout_ms, target, current
                                )
                            }
                        );
                        return Ok(());
                    }
                }
            }
        }
        Ok(())
    }
}
/// CUDA-backed [`TimelineSignal`] built from events recorded on a stream.
#[cfg(feature = "cuda")]
pub mod cuda {
    use std::any::Any;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{Duration, Instant};

    use cudarc::driver::{CudaContext, CudaEvent, CudaStream};
    use parking_lot::Mutex;
    use snafu::ResultExt;

    use super::TimelineSignal;
    use crate::error::{CudaSnafu, Result, RuntimeSnafu};

    /// Number of recorded events remembered at once; recording more
    /// overwrites the oldest slot.
    const EVENT_RING_SIZE: usize = 64;

    /// Pause between readiness polls while a timed wait watches a pending event.
    const EVENT_POLL_INTERVAL: Duration = Duration::from_micros(100);

    /// One recorded event together with the timeline value it stands for.
    #[derive(Debug)]
    struct EventSlot {
        timeline_value: u64,
        event: Arc<CudaEvent>,
    }

    /// Timeline signal whose progress is driven by CUDA events.
    ///
    /// `value` holds the highest timeline value known to have been reached:
    /// it is raised by [`TimelineSignal::set`] and whenever a recorded event
    /// is observed to be complete. Merely recording an event does not raise it.
    #[derive(Debug)]
    pub struct CudaTimelineSignal {
        value: AtomicU64,
        ring: Mutex<EventRing>,
        context: Arc<CudaContext>,
        stream: Arc<CudaStream>,
    }

    /// Fixed-size ring of the most recently recorded, not yet retired events.
    #[derive(Debug)]
    struct EventRing {
        slots: [Option<EventSlot>; EVENT_RING_SIZE],
        next: usize,
    }

    impl EventRing {
        /// Creates an empty ring.
        fn new() -> Self {
            Self { slots: std::array::from_fn(|_| None), next: 0 }
        }

        /// Stores `slot` at the write cursor, overwriting whatever was there.
        fn push(&mut self, slot: EventSlot) {
            let idx = self.next;
            self.slots[idx] = Some(slot);
            self.next = (idx + 1) % EVENT_RING_SIZE;
        }

        /// Removes every slot whose event reports ready and returns the
        /// highest timeline value among them, or `None` if nothing completed.
        fn retire_completed(&mut self) -> Option<u64> {
            let mut highest: Option<u64> = None;
            for slot in self.slots.iter_mut() {
                let ready = slot.as_ref().map_or(false, |s| s.event.is_ready());
                if !ready {
                    continue;
                }
                if let Some(done) = slot.take() {
                    highest = Some(highest.map_or(done.timeline_value, |h| h.max(done.timeline_value)));
                }
            }
            highest
        }

        /// Finds the stored event with the smallest timeline value that is
        /// still `>= target`, returning that value and a handle to the event.
        fn earliest_covering(&self, target: u64) -> Option<(u64, Arc<CudaEvent>)> {
            self.slots
                .iter()
                .flatten()
                .filter(|slot| slot.timeline_value >= target)
                .min_by_key(|slot| slot.timeline_value)
                .map(|slot| (slot.timeline_value, Arc::clone(&slot.event)))
        }
    }

    impl CudaTimelineSignal {
        /// Creates a signal at timeline value zero that records its events on
        /// `stream` using events created from `context`.
        pub fn new(context: Arc<CudaContext>, stream: Arc<CudaStream>) -> Self {
            Self { value: AtomicU64::new(0), ring: Mutex::new(EventRing::new()), context, stream }
        }

        /// Records an event on the stream that stands for timeline `value`.
        ///
        /// The timeline itself is NOT advanced here: recording only enqueues
        /// the event. The value becomes visible through
        /// [`TimelineSignal::value`] / [`TimelineSignal::wait`] once the event
        /// has actually completed. Advancing it at record time would let
        /// `wait` return before the work ahead of the event has finished.
        ///
        /// # Errors
        ///
        /// Returns a CUDA error if creating or recording the event fails.
        pub fn record(&self, value: u64) -> Result<()> {
            let event = self.context.create_event(None).context(CudaSnafu)?;
            self.stream.record(&event).context(CudaSnafu)?;
            self.ring.lock().push(EventSlot { timeline_value: value, event: Arc::new(event) });
            Ok(())
        }

        /// Returns the stream this signal records its events on.
        pub fn stream(&self) -> &Arc<CudaStream> {
            &self.stream
        }

        /// Folds every completed event into `value` and returns the result.
        fn refresh(&self) -> u64 {
            let mut ring = self.ring.lock();
            if let Some(reached) = ring.retire_completed() {
                // Publish while still holding the ring lock so no other thread
                // can see the slot gone but the value not yet raised.
                self.value.fetch_max(reached, Ordering::AcqRel);
            }
            drop(ring);
            self.value.load(Ordering::Acquire)
        }
    }

    impl TimelineSignal for CudaTimelineSignal {
        /// Returns `self` as [`Any`] for downcasting.
        fn as_any(&self) -> &dyn Any {
            self
        }

        /// Returns the highest timeline value reached so far, after polling
        /// the recorded events for completion.
        fn value(&self) -> u64 {
            self.refresh()
        }

        /// Raises the timeline to `value` from the host; never lowers it.
        fn set(&self, value: u64) {
            self.value.fetch_max(value, Ordering::AcqRel);
        }

        /// Blocks until the timeline reaches `target`.
        ///
        /// With `timeout_ms == 0` (or a limit too large to represent as a
        /// deadline) the call waits without limit: it synchronizes on the
        /// earliest recorded event covering `target`, or spins until a host
        /// `set` or a later `record` makes the target reachable. With a limit
        /// it polls instead, so the deadline can be honoured.
        ///
        /// # Errors
        ///
        /// Returns a CUDA error if synchronizing on the event fails, or a
        /// runtime error if the limit elapses first.
        fn wait(&self, target: u64, timeout_ms: u64) -> Result<()> {
            let deadline = if timeout_ms == 0 {
                None
            } else {
                Instant::now().checked_add(Duration::from_millis(timeout_ms))
            };

            loop {
                // `value()` retires completed events, so this also notices an
                // event for `target` finishing between iterations.
                let current = self.value();
                if current >= target {
                    return Ok(());
                }

                let pending = self.ring.lock().earliest_covering(target);
                match deadline {
                    None => match pending {
                        Some((reached, event)) => {
                            event.synchronize().context(CudaSnafu)?;
                            self.value.fetch_max(reached, Ordering::AcqRel);
                            return Ok(());
                        }
                        // Nothing recorded for this target yet; another thread
                        // may still `set` or `record` it.
                        None => std::thread::yield_now(),
                    },
                    Some(deadline) => {
                        if Instant::now() >= deadline {
                            let message = if pending.is_some() {
                                format!(
                                    "CUDA timeline signal timeout: waited {}ms for value {}",
                                    timeout_ms, target
                                )
                            } else {
                                format!(
                                    "CUDA timeline signal timeout: waited {}ms for value {}, current {}",
                                    timeout_ms, target, current
                                )
                            };
                            return RuntimeSnafu { message }.fail();
                        }
                        if pending.is_some() {
                            std::thread::sleep(EVENT_POLL_INTERVAL);
                        } else {
                            std::thread::yield_now();
                        }
                    }
                }
            }
        }
    }
}
// Unit tests for this module; compiled only for test builds and loaded from
// the file named in the `#[path]` attribute.
#[cfg(test)]
#[path = "test/unit/sync.rs"]
mod tests;