use crate::device::DeviceType;
use crate::error::Result;
use std::collections::HashMap;
use std::sync::{Arc, Condvar, Mutex};
use std::time::{Duration, Instant};
/// A host-side event that can be recorded and then queried or waited on.
///
/// Cloning a `DeviceEvent` yields another handle to the same shared state
/// (the `inner` field is an [`Arc`]), so recording through one handle is
/// visible through every clone.
#[derive(Debug)]
pub struct DeviceEvent {
    /// Device this event was created for.
    device: DeviceType,
    /// Lock-protected state shared between all clones of this event.
    inner: Arc<EventInner>,
}
/// Shared state behind every handle of a `DeviceEvent`.
#[derive(Debug)]
struct EventInner {
    /// Current lifecycle state of the event.
    state: Mutex<EventState>,
    /// Condition variable used to wake threads waiting for `state` to change.
    cond: Condvar,
    /// When the event was last recorded, if it has been recorded.
    recorded_time: Mutex<Option<Instant>>,
    /// When the last recording was marked complete, if it has completed.
    completed_time: Mutex<Option<Instant>>,
}
/// Lifecycle of a `DeviceEvent`: `Created` -> `Recorded` -> `Completed`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum EventState {
    /// Newly created (or reset) and not yet recorded.
    Created,
    /// Recorded; completion is still pending.
    Recorded,
    /// The recorded work has been marked complete.
    Completed,
}
impl DeviceEvent {
pub fn new(device: DeviceType) -> Result<Self> {
Ok(DeviceEvent {
device,
inner: Arc::new(EventInner {
state: Mutex::new(EventState::Created),
cond: Condvar::new(),
recorded_time: Mutex::new(None),
completed_time: Mutex::new(None),
}),
})
}
pub fn device(&self) -> DeviceType {
self.device
}
pub fn record(&self) -> Result<()> {
let mut state = self
.inner
.state
.lock()
.expect("lock should not be poisoned");
*state = EventState::Recorded;
*self
.inner
.recorded_time
.lock()
.expect("lock should not be poisoned") = Some(Instant::now());
self.complete_async();
Ok(())
}
pub fn wait(&self) -> Result<()> {
let mut state = self
.inner
.state
.lock()
.expect("lock should not be poisoned");
while *state != EventState::Completed {
state = self
.inner
.cond
.wait(state)
.expect("condvar wait should not be poisoned");
}
Ok(())
}
pub fn wait_timeout(&self, timeout: Duration) -> Result<bool> {
let mut state = self
.inner
.state
.lock()
.expect("lock should not be poisoned");
while *state != EventState::Completed {
let (new_state, timeout_result) = self
.inner
.cond
.wait_timeout(state, timeout)
.expect("condvar wait_timeout should not be poisoned");
state = new_state;
if timeout_result.timed_out() {
return Ok(false);
}
}
Ok(true)
}
pub fn query(&self) -> Result<bool> {
let state = self
.inner
.state
.lock()
.expect("lock should not be poisoned");
Ok(*state == EventState::Completed)
}
pub fn elapsed_time(&self) -> Option<Duration> {
let recorded = self
.inner
.recorded_time
.lock()
.expect("lock should not be poisoned");
let completed = self
.inner
.completed_time
.lock()
.expect("lock should not be poisoned");
match (*recorded, *completed) {
(Some(start), Some(end)) => Some(end.duration_since(start)),
_ => None,
}
}
pub fn reset(&self) -> Result<()> {
let mut state = self
.inner
.state
.lock()
.expect("lock should not be poisoned");
*state = EventState::Created;
*self
.inner
.recorded_time
.lock()
.expect("lock should not be poisoned") = None;
*self
.inner
.completed_time
.lock()
.expect("lock should not be poisoned") = None;
Ok(())
}
fn complete_async(&self) {
let inner = self.inner.clone();
std::thread::spawn(move || {
std::thread::sleep(Duration::from_millis(1));
let mut state = inner.state.lock().expect("lock should not be poisoned");
*state = EventState::Completed;
*inner
.completed_time
.lock()
.expect("lock should not be poisoned") = Some(Instant::now());
inner.cond.notify_all();
});
}
}
impl Clone for DeviceEvent {
fn clone(&self) -> Self {
DeviceEvent {
device: self.device,
inner: self.inner.clone(),
}
}
}
/// A queue of host-side operations associated with a device.
#[derive(Debug)]
pub struct DeviceStream {
    /// Device the stream belongs to.
    device: DeviceType,
    /// Process-wide identifier assigned at construction.
    id: u64,
    /// Priority requested when the stream was created.
    priority: StreamPriority,
    /// Queue and worker-synchronization state shared with background workers.
    inner: Arc<StreamInner>,
}
/// Shared state behind a `DeviceStream`: the pending-operation queue plus the
/// flag/condvar pair that `synchronize` waits on.
struct StreamInner {
    /// Submitted operations that have not been started yet.
    operation_queue: Mutex<Vec<Box<dyn FnOnce() + Send + 'static>>>,
    /// Set while a background worker is processing operations.
    is_synchronizing: Mutex<bool>,
    /// Notified when a worker clears `is_synchronizing`.
    sync_cond: Condvar,
}

impl std::fmt::Debug for StreamInner {
    // The boxed closures cannot be formatted, so fixed placeholders stand in
    // for the queue and the condvar; only the worker flag is printed for real.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("StreamInner");
        builder.field("operation_queue", &"<operation_queue>");
        builder.field("is_synchronizing", &self.is_synchronizing);
        builder.field("sync_cond", &"<condvar>");
        builder.finish()
    }
}
/// Priority hint attached to a `DeviceStream`.
///
/// Explicit discriminants order the variants from lowest (`0`) to highest (`2`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StreamPriority {
    /// Lowest priority (discriminant `0`).
    Low = 0,
    /// Default priority used by `DeviceStream::new` (discriminant `1`).
    Normal = 1,
    /// Highest priority (discriminant `2`).
    High = 2,
}
impl DeviceStream {
pub fn new(device: DeviceType) -> Result<Self> {
Self::with_priority(device, StreamPriority::Normal)
}
pub fn with_priority(device: DeviceType, priority: StreamPriority) -> Result<Self> {
static STREAM_COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
Ok(DeviceStream {
device,
id: STREAM_COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
priority,
inner: Arc::new(StreamInner {
operation_queue: Mutex::new(Vec::new()),
is_synchronizing: Mutex::new(false),
sync_cond: Condvar::new(),
}),
})
}
pub fn device(&self) -> DeviceType {
self.device
}
pub fn id(&self) -> u64 {
self.id
}
pub fn priority(&self) -> StreamPriority {
self.priority
}
pub fn submit_operation<F>(&self, operation: F) -> Result<()>
where
F: FnOnce() + Send + 'static,
{
let mut queue = self
.inner
.operation_queue
.lock()
.expect("lock should not be poisoned");
queue.push(Box::new(operation));
self.process_operations_async();
Ok(())
}
pub fn synchronize(&self) -> Result<()> {
let mut is_sync = self
.inner
.is_synchronizing
.lock()
.expect("lock should not be poisoned");
while !self.is_empty() || *is_sync {
is_sync = self
.inner
.sync_cond
.wait(is_sync)
.expect("condvar wait should not be poisoned");
}
Ok(())
}
pub fn is_empty(&self) -> bool {
let queue = self
.inner
.operation_queue
.lock()
.expect("lock should not be poisoned");
queue.is_empty()
}
pub fn pending_operations(&self) -> usize {
let queue = self
.inner
.operation_queue
.lock()
.expect("lock should not be poisoned");
queue.len()
}
pub fn record_event(&self) -> Result<DeviceEvent> {
let event = DeviceEvent::new(self.device)?;
event.record()?;
Ok(event)
}
pub fn wait_event(&self, event: &DeviceEvent) -> Result<()> {
if event.device() != self.device {
return Err(crate::error::TorshError::InvalidArgument(
"Event device does not match stream device".to_string(),
));
}
event.wait()
}
fn process_operations_async(&self) {
let inner = self.inner.clone();
std::thread::spawn(move || {
{
let mut is_sync = inner
.is_synchronizing
.lock()
.expect("lock should not be poisoned");
*is_sync = true;
}
loop {
let operation = {
let mut queue = inner
.operation_queue
.lock()
.expect("lock should not be poisoned");
queue.pop()
};
match operation {
Some(op) => {
op(); }
None => break,
}
}
{
let mut is_sync = inner
.is_synchronizing
.lock()
.expect("lock should not be poisoned");
*is_sync = false;
inner.sync_cond.notify_all();
}
});
}
}
/// A reusable rendezvous point for a fixed set of devices.
#[derive(Debug)]
pub struct DeviceBarrier {
    /// Devices that participate in the barrier.
    devices: Vec<DeviceType>,
    /// Arrival bookkeeping shared behind an `Arc`.
    inner: Arc<BarrierInner>,
}
/// Arrival bookkeeping for a `DeviceBarrier`.
#[derive(Debug)]
struct BarrierInner {
    /// Number of devices that have arrived in the current round.
    count: Mutex<usize>,
    /// Number of arrivals required to release the barrier.
    total: usize,
    /// Round counter, bumped each time the barrier releases.
    generation: Mutex<usize>,
    /// Waiters block on this until `generation` changes.
    cond: Condvar,
    /// Devices that have arrived in the current round.
    arrived_devices: Mutex<Vec<DeviceType>>,
}
impl DeviceBarrier {
    /// Creates a barrier that releases once every device in `devices` has
    /// called [`wait`](Self::wait).
    ///
    /// # Errors
    ///
    /// Returns `TorshError::InvalidArgument` in two cases:
    /// - `devices` is empty.
    /// - `devices` lists a device more than once. `wait` rejects a second
    ///   arrival from the same device, so such a barrier could never release.
    pub fn new(devices: Vec<DeviceType>) -> Result<Self> {
        let total = devices.len();
        if total == 0 {
            return Err(crate::error::TorshError::InvalidArgument(
                "Barrier must have at least one device".to_string(),
            ));
        }
        // Quadratic, but device lists are expected to be short. This only
        // needs `PartialEq`, which `wait` already relies on.
        let has_duplicates = devices
            .iter()
            .enumerate()
            .any(|(i, device)| devices[..i].contains(device));
        if has_duplicates {
            return Err(crate::error::TorshError::InvalidArgument(
                "Barrier devices must be unique".to_string(),
            ));
        }
        Ok(DeviceBarrier {
            devices,
            inner: Arc::new(BarrierInner {
                count: Mutex::new(0),
                total,
                generation: Mutex::new(0),
                cond: Condvar::new(),
                arrived_devices: Mutex::new(Vec::new()),
            }),
        })
    }

    /// Blocks until every participating device has called `wait` in this round.
    ///
    /// The last device to arrive resets the barrier and wakes the others, so
    /// the barrier can be reused for subsequent rounds.
    ///
    /// # Errors
    ///
    /// Returns `TorshError::InvalidArgument` in two cases:
    /// - `device` is not part of this barrier.
    /// - `device` is already waiting in the current round.
    pub fn wait(&self, device: DeviceType) -> Result<()> {
        if !self.devices.contains(&device) {
            return Err(crate::error::TorshError::InvalidArgument(format!(
                "Device {:?} is not part of this barrier",
                device
            )));
        }
        // `count` is the barrier's primary lock. Every update to `count`,
        // `arrived_devices` and `generation` happens while it is held, which
        // keeps the generation check in `wait_while` below free of lost wakeups.
        let mut count = self
            .inner
            .count
            .lock()
            .expect("lock should not be poisoned");
        let generation = *self
            .inner
            .generation
            .lock()
            .expect("lock should not be poisoned");
        {
            let mut arrived = self
                .inner
                .arrived_devices
                .lock()
                .expect("lock should not be poisoned");
            if arrived.contains(&device) {
                return Err(crate::error::TorshError::InvalidArgument(
                    "Device already waiting at barrier".to_string(),
                ));
            }
            arrived.push(device);
            *count += 1;
            if *count == self.inner.total {
                // Last arrival: reset for the next round and release everyone.
                *count = 0;
                arrived.clear();
                *self
                    .inner
                    .generation
                    .lock()
                    .expect("lock should not be poisoned") += 1;
                self.inner.cond.notify_all();
                return Ok(());
            }
            // `arrived` is released at the end of this scope. Holding it
            // across the condvar wait would block every later arrival on this
            // mutex forever.
        }
        let _count = self
            .inner
            .cond
            .wait_while(count, |_| {
                *self
                    .inner
                    .generation
                    .lock()
                    .expect("lock should not be poisoned")
                    == generation
            })
            .expect("condvar wait should not be poisoned");
        Ok(())
    }

    /// Devices participating in this barrier.
    pub fn devices(&self) -> &[DeviceType] {
        &self.devices
    }

    /// Number of devices that have arrived in the current round.
    pub fn arrived_count(&self) -> usize {
        let arrived = self
            .inner
            .arrived_devices
            .lock()
            .expect("lock should not be poisoned");
        arrived.len()
    }

    /// Returns whether every participating device is currently recorded as arrived.
    ///
    /// NOTE(review): the final arrival clears the arrival list while holding
    /// the same lock this reads, so a full barrier is never observable here.
    /// Confirm the intended semantics with callers.
    pub fn is_complete(&self) -> bool {
        self.arrived_count() == self.inner.total
    }
}
/// Runs a closure producing `T` on a newly spawned OS thread.
pub struct DeviceAsync<T> {
    // Stored at construction but not read anywhere in this module, hence the `allow`.
    #[allow(dead_code)]
    device: DeviceType,
    /// Ties the produced value type `T` to this handle without storing one.
    _phantom: std::marker::PhantomData<T>,
}
impl<T> DeviceAsync<T>
where
    T: Send + 'static,
{
    /// Creates an executor tagged with `device`.
    pub fn new(device: DeviceType) -> Self {
        Self {
            _phantom: std::marker::PhantomData,
            device,
        }
    }

    /// Consumes the executor and runs `operation` on a new thread.
    ///
    /// The returned handle yields the closure's result on `join`.
    ///
    /// # Panics
    ///
    /// Panics if the OS fails to spawn a thread (`std::thread::spawn` behavior).
    pub fn execute<F>(self, operation: F) -> std::thread::JoinHandle<T>
    where
        F: FnOnce() -> T + Send + 'static,
    {
        use std::thread::spawn as spawn_thread;
        spawn_thread(operation)
    }
}
/// A mutex tagged with the device its protected data belongs to.
///
/// Clones share the same underlying lock through an `Arc`.
#[derive(Debug)]
pub struct DeviceMutex<T> {
    /// Device associated with the protected data.
    device: DeviceType,
    /// The shared lock and the data it protects.
    data: Arc<Mutex<T>>,
}
impl<T> DeviceMutex<T> {
pub fn new(device: DeviceType, data: T) -> Self {
Self {
device,
data: Arc::new(Mutex::new(data)),
}
}
pub fn device(&self) -> DeviceType {
self.device
}
pub fn lock(&self) -> Result<std::sync::MutexGuard<'_, T>> {
self.data.lock().map_err(|_| {
crate::error::TorshError::DeviceError("Failed to acquire device mutex".to_string())
})
}
pub fn try_lock(&self) -> Result<Option<std::sync::MutexGuard<'_, T>>> {
match self.data.try_lock() {
Ok(guard) => Ok(Some(guard)),
Err(std::sync::TryLockError::WouldBlock) => Ok(None),
Err(_) => Err(crate::error::TorshError::DeviceError(
"Device mutex is poisoned".to_string(),
)),
}
}
}
impl<T> Clone for DeviceMutex<T> {
fn clone(&self) -> Self {
Self {
device: self.device,
data: self.data.clone(),
}
}
}
/// Registry of streams, events and barriers, used for device-wide synchronization.
#[derive(Debug)]
pub struct DeviceSyncManager {
    /// Registered streams keyed by `(device, stream id)`.
    streams: Mutex<HashMap<(DeviceType, u64), Arc<DeviceStream>>>,
    /// Registered events keyed by `(device, id)`.
    // NOTE(review): no method in this type's impl inserts into this map.
    events: Mutex<HashMap<(DeviceType, u64), Arc<DeviceEvent>>>,
    /// Barriers created through `create_barrier`.
    barriers: Mutex<Vec<Arc<DeviceBarrier>>>,
}
impl DeviceSyncManager {
    /// Creates an empty manager.
    pub fn new() -> Self {
        let streams = Mutex::new(HashMap::new());
        let events = Mutex::new(HashMap::new());
        let barriers = Mutex::new(Vec::new());
        Self {
            streams,
            events,
            barriers,
        }
    }

    /// Registers `stream` under `(stream.device(), stream.id())`.
    ///
    /// Any stream previously registered under the same key is replaced.
    pub fn register_stream(&self, stream: Arc<DeviceStream>) {
        let key = (stream.device(), stream.id());
        self.streams
            .lock()
            .expect("lock should not be poisoned")
            .insert(key, stream);
    }

    /// Looks up a registered stream by device and id.
    pub fn get_stream(&self, device: DeviceType, id: u64) -> Option<Arc<DeviceStream>> {
        self.streams
            .lock()
            .expect("lock should not be poisoned")
            .get(&(device, id))
            .map(Arc::clone)
    }

    /// Synchronizes every registered stream that belongs to `device`.
    ///
    /// The registry lock is released before any stream is synchronized, so
    /// other threads can keep registering streams meanwhile.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error from `DeviceStream::synchronize`.
    pub fn synchronize_device(&self, device: DeviceType) -> Result<()> {
        let matching: Vec<Arc<DeviceStream>> = {
            let registry = self.streams.lock().expect("lock should not be poisoned");
            registry
                .values()
                .filter(|candidate| candidate.device() == device)
                .map(Arc::clone)
                .collect()
        };
        matching
            .iter()
            .try_for_each(|stream| stream.synchronize())
    }

    /// Creates a barrier over `devices` and keeps a handle to it in the registry.
    ///
    /// # Errors
    ///
    /// Propagates errors from `DeviceBarrier::new`.
    pub fn create_barrier(&self, devices: Vec<DeviceType>) -> Result<Arc<DeviceBarrier>> {
        let barrier = DeviceBarrier::new(devices).map(Arc::new)?;
        self.barriers
            .lock()
            .expect("lock should not be poisoned")
            .push(Arc::clone(&barrier));
        Ok(barrier)
    }

    /// Takes a snapshot of registry sizes and of pending work across all streams.
    pub fn statistics(&self) -> SyncStatistics {
        // All three locks are held together (in this fixed order) so the counts
        // come from a single consistent snapshot.
        let streams = self.streams.lock().expect("lock should not be poisoned");
        let events = self.events.lock().expect("lock should not be poisoned");
        let barriers = self.barriers.lock().expect("lock should not be poisoned");
        let pending_operations = streams
            .values()
            .fold(0usize, |acc, stream| acc + stream.pending_operations());
        SyncStatistics {
            total_streams: streams.len(),
            total_events: events.len(),
            total_barriers: barriers.len(),
            pending_operations,
        }
    }
}
impl Default for DeviceSyncManager {
fn default() -> Self {
Self::new()
}
}
/// Point-in-time counters produced by `DeviceSyncManager::statistics`.
#[derive(Clone, Debug)]
pub struct SyncStatistics {
    /// Number of registered streams.
    pub total_streams: usize,
    /// Number of registered events.
    pub total_events: usize,
    /// Number of barriers created through the manager.
    pub total_barriers: usize,
    /// Operations queued but not yet started, summed over all streams.
    pub pending_operations: usize,
}

impl std::fmt::Display for SyncStatistics {
    /// Formats as `SyncStats(streams=.., events=.., barriers=.., pending_ops=..)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let SyncStatistics {
            total_streams,
            total_events,
            total_barriers,
            pending_operations,
        } = self;
        write!(
            f,
            "SyncStats(streams={}, events={}, barriers={}, pending_ops={})",
            total_streams, total_events, total_barriers, pending_operations
        )
    }
}
/// Convenience helpers over slices of streams, events and devices.
pub mod utils {
    use super::*;

    /// Builds a barrier over all of `devices`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `DeviceBarrier::new`.
    pub fn create_global_barrier(devices: &[DeviceType]) -> Result<DeviceBarrier> {
        let owned = devices.to_vec();
        DeviceBarrier::new(owned)
    }

    /// Synchronizes each stream in order.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error.
    pub fn synchronize_streams(streams: &[&DeviceStream]) -> Result<()> {
        streams.iter().try_for_each(|stream| stream.synchronize())
    }

    /// Waits for each event in order.
    ///
    /// An event that is never recorded keeps this call blocked.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error.
    pub fn wait_events(events: &[&DeviceEvent]) -> Result<()> {
        events.iter().try_for_each(|event| event.wait())
    }

    /// Returns `Ok(true)` only if every event reports completion.
    ///
    /// Queries stop at the first incomplete event or the first error.
    ///
    /// # Errors
    ///
    /// Returns the first error from `DeviceEvent::query`.
    pub fn all_events_complete(events: &[&DeviceEvent]) -> Result<bool> {
        events
            .iter()
            .map(|event| event.query())
            .find(|outcome| !matches!(outcome, Ok(true)))
            .unwrap_or(Ok(true))
    }

    /// Collects `elapsed_time` for each event, in input order.
    pub fn measure_event_times(events: &[&DeviceEvent]) -> Vec<Option<Duration>> {
        events
            .iter()
            .copied()
            .map(DeviceEvent::elapsed_time)
            .collect()
    }

    /// Creates `count` independent, unrecorded events on `device`.
    ///
    /// The events are not linked to one another in any way.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error from `DeviceEvent::new`.
    pub fn create_event_chain(device: DeviceType, count: usize) -> Result<Vec<DeviceEvent>> {
        (0..count)
            .map(|_| DeviceEvent::new(device))
            .collect::<Result<Vec<_>>>()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_device_event_basic() {
        let event = DeviceEvent::new(DeviceType::Cpu).expect("event creation should succeed");
        assert_eq!(event.device(), DeviceType::Cpu);
        assert!(!event.query().expect("query should succeed"));
        event.record().expect("record should succeed");
        event.wait().expect("wait should succeed");
        assert!(event.query().expect("query should succeed"));
    }

    #[test]
    fn test_device_event_timeout() {
        let event = DeviceEvent::new(DeviceType::Cpu).expect("event creation should succeed");
        event.record().expect("record should succeed");
        let completed = event
            .wait_timeout(Duration::from_millis(100))
            .expect("wait_timeout should succeed");
        assert!(completed);
    }

    #[test]
    fn test_device_stream() {
        let stream = DeviceStream::new(DeviceType::Cpu).expect("stream creation should succeed");
        assert_eq!(stream.device(), DeviceType::Cpu);
        assert_eq!(stream.priority(), StreamPriority::Normal);
        let executed = Arc::new(Mutex::new(false));
        let executed_clone = executed.clone();
        stream
            .submit_operation(move || {
                *executed_clone.lock().expect("lock should not be poisoned") = true;
            })
            .expect("submit_operation should succeed");
        stream.synchronize().expect("synchronize should succeed");
        assert!(*executed.lock().expect("lock should not be poisoned"));
    }

    #[test]
    fn test_device_barrier() {
        let devices = vec![DeviceType::Cpu, DeviceType::Cuda(0)];
        let barrier = DeviceBarrier::new(devices.clone()).expect("barrier creation should succeed");
        assert_eq!(barrier.devices(), &devices);
        assert_eq!(barrier.arrived_count(), 0);
        assert!(!barrier.is_complete());
    }

    #[test]
    fn test_device_mutex() {
        let mutex = DeviceMutex::new(DeviceType::Cpu, 42);
        assert_eq!(mutex.device(), DeviceType::Cpu);
        {
            let guard = mutex.lock().expect("lock should not be poisoned");
            assert_eq!(*guard, 42);
        }
        let try_guard = mutex.try_lock().expect("try_lock should succeed");
        assert!(try_guard.is_some());
        assert_eq!(*try_guard.expect("guard should be Some"), 42);
    }

    #[test]
    fn test_sync_manager() {
        let manager = DeviceSyncManager::new();
        let stream =
            Arc::new(DeviceStream::new(DeviceType::Cpu).expect("stream creation should succeed"));
        let stream_id = stream.id();
        manager.register_stream(stream.clone());
        let retrieved = manager.get_stream(DeviceType::Cpu, stream_id);
        assert!(retrieved.is_some());
        assert_eq!(retrieved.expect("stream should be found").id(), stream_id);
        let stats = manager.statistics();
        assert_eq!(stats.total_streams, 1);
    }

    #[test]
    fn test_stream_priorities() {
        let high_stream = DeviceStream::with_priority(DeviceType::Cpu, StreamPriority::High)
            .expect("stream creation should succeed");
        let low_stream = DeviceStream::with_priority(DeviceType::Cpu, StreamPriority::Low)
            .expect("stream creation should succeed");
        assert_eq!(high_stream.priority(), StreamPriority::High);
        assert_eq!(low_stream.priority(), StreamPriority::Low);
    }

    #[test]
    fn test_event_reset() {
        let event = DeviceEvent::new(DeviceType::Cpu).expect("event creation should succeed");
        event.record().expect("record should succeed");
        event.wait().expect("wait should succeed");
        assert!(event.query().expect("query should succeed"));
        event.reset().expect("reset should succeed");
        assert!(!event.query().expect("query should succeed"));
    }

    #[test]
    fn test_utils_functions() {
        let event1 = DeviceEvent::new(DeviceType::Cpu).expect("event creation should succeed");
        let event2 = DeviceEvent::new(DeviceType::Cpu).expect("event creation should succeed");
        let events = vec![&event1, &event2];
        event1.record().expect("record should succeed");
        event2.record().expect("record should succeed");
        utils::wait_events(&events).expect("wait_events should succeed");
        assert!(utils::all_events_complete(&events).expect("all_events_complete should succeed"));
        let times = utils::measure_event_times(&events);
        assert_eq!(times.len(), 2);
    }

    // `DeviceAsync::execute` returns a std `JoinHandle`, and `join` blocks the
    // calling thread. A plain test avoids blocking inside an async runtime and
    // needs no tokio.
    #[test]
    fn test_device_async() {
        let async_op = DeviceAsync::new(DeviceType::Cpu);
        let result = async_op
            .execute(|| 42)
            .join()
            .expect("async execute should succeed");
        assert_eq!(result, 42);
    }
}