#![allow(unused_variables)]
use crate::cuda::error::{CudaError, CudaResult};
use crate::cuda::{CudaEvent, CudaStream};
use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::{Arc, Mutex, RwLock};
use std::thread;
use std::time::{Duration, Instant};
/// Kind of device work an operation represents; used to decide whether a
/// timing-enabled event is needed for the operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OperationType {
    Kernel,
    MemoryTransfer,
    Synchronization,
    Reduction,
    Broadcast,
    AllReduce,
    Barrier,
}
/// Scheduling priority of an operation.
///
/// NOTE: the derived `Ord` follows declaration order, so `Critical` compares
/// lowest and `Cleanup` highest — do not reorder variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum EventPriority {
    Critical,
    High,
    Normal,
    Low,
    Cleanup,
}
/// Bookkeeping record attached to every registered operation.
#[derive(Debug, Clone)]
pub struct EventMetadata {
    /// What kind of work this operation performs.
    pub operation_type: OperationType,
    /// Scheduling priority assigned at registration time.
    pub priority: EventPriority,
    /// Id of the stream the operation was registered against.
    pub stream_id: u64,
    /// Unique id handed out by the coordinator.
    pub operation_id: u64,
    /// When the operation was registered (host clock).
    pub creation_time: Instant,
    /// Ids of operations that must complete before this one may begin.
    pub dependencies: Vec<u64>,
    /// Free-form human-readable label for diagnostics.
    pub description: String,
}
/// Reusable pool of CUDA events, split into plain and timing-enabled events.
///
/// The pool grows on demand: `acquire_event` creates a fresh event when its
/// queue is empty, and `release_event` drops events that would exceed the
/// configured capacity instead of pooling them.
#[derive(Debug)]
pub struct EventPool {
    /// Idle events created without timing support.
    available_events: Mutex<VecDeque<Arc<CudaEvent>>>,
    /// Idle events created with timing enabled.
    timing_events: Mutex<VecDeque<Arc<CudaEvent>>>,
    /// Identity keys (the `Arc` data address, cast to `usize`) of events
    /// currently handed out. Stored as `usize` rather than a raw pointer so
    /// the pool keeps its `Send`/`Sync` auto traits; the value is only an
    /// identity token and is never dereferenced.
    in_use: Mutex<HashSet<usize>>,
    /// Maximum number of plain events retained by `release_event`.
    pool_size: usize,
    /// Maximum number of timing events retained by `release_event`.
    timing_pool_size: usize,
}
impl EventPool {
    /// Creates a pool pre-populated with `pool_size` plain events and
    /// `timing_pool_size` timing-enabled events.
    ///
    /// # Errors
    /// Propagates the CUDA error if any event creation fails.
    pub fn new(pool_size: usize, timing_pool_size: usize) -> CudaResult<Self> {
        let mut available_events = VecDeque::with_capacity(pool_size);
        for _ in 0..pool_size {
            available_events.push_back(Arc::new(CudaEvent::new()?));
        }
        let mut timing_events = VecDeque::with_capacity(timing_pool_size);
        for _ in 0..timing_pool_size {
            timing_events.push_back(Arc::new(CudaEvent::new_with_timing()?));
        }
        Ok(Self {
            available_events: Mutex::new(available_events),
            timing_events: Mutex::new(timing_events),
            in_use: Mutex::new(HashSet::new()),
            pool_size,
            timing_pool_size,
        })
    }

    /// Hands out an event, reusing a pooled one when available and creating a
    /// fresh one when the pool is empty (the pool grows on demand).
    ///
    /// # Errors
    /// Returns the CUDA error when a fresh event must be created and creation
    /// fails. (The previous implementation panicked in that case despite the
    /// `CudaResult` return type.)
    pub fn acquire_event(&self, with_timing: bool) -> CudaResult<Arc<CudaEvent>> {
        // Pop while holding the pool lock, but create replacements outside it.
        let pooled = if with_timing {
            self.timing_events
                .lock()
                .expect("lock should not be poisoned")
                .pop_front()
        } else {
            self.available_events
                .lock()
                .expect("lock should not be poisoned")
                .pop_front()
        };
        let event = match pooled {
            Some(event) => event,
            // Pool exhausted: create a fresh event, propagating any failure
            // to the caller instead of panicking.
            None if with_timing => Arc::new(CudaEvent::new_with_timing()?),
            None => Arc::new(CudaEvent::new()?),
        };
        let mut in_use = self.in_use.lock().expect("lock should not be poisoned");
        in_use.insert(Arc::as_ptr(&event) as usize);
        Ok(event)
    }

    /// Returns `event` to the pool. If the corresponding queue is already at
    /// capacity the event is simply dropped.
    pub fn release_event(&self, event: Arc<CudaEvent>) {
        let event_key = Arc::as_ptr(&event) as usize;
        {
            let mut in_use = self.in_use.lock().expect("lock should not be poisoned");
            in_use.remove(&event_key);
        }
        if event.timing_enabled() {
            let mut timing_events = self
                .timing_events
                .lock()
                .expect("lock should not be poisoned");
            if timing_events.len() < self.timing_pool_size {
                timing_events.push_back(event);
            }
        } else {
            let mut available_events = self
                .available_events
                .lock()
                .expect("lock should not be poisoned");
            if available_events.len() < self.pool_size {
                available_events.push_back(event);
            }
        }
    }

    /// Returns `(idle_plain, idle_timing, in_use)` counts. The three locks are
    /// taken in sequence, so the snapshot is only approximate under
    /// concurrent acquire/release.
    pub fn utilization(&self) -> (usize, usize, usize) {
        let available = self
            .available_events
            .lock()
            .expect("lock should not be poisoned")
            .len();
        let timing = self
            .timing_events
            .lock()
            .expect("lock should not be poisoned")
            .len();
        let in_use = self.in_use.lock().expect("lock should not be poisoned").len();
        (available, timing, in_use)
    }
}
/// Tracks registered GPU operations, their CUDA events, and the dependency
/// edges between them; backs dependency waiting and deadlock detection.
pub struct OperationCoordinator {
/// Metadata for every registered, not-yet-cleaned-up operation.
operations: RwLock<HashMap<u64, EventMetadata>>,
/// CUDA event associated with each live operation.
operation_events: RwLock<HashMap<u64, Arc<CudaEvent>>>,
/// Forward edges: operation id -> ids it depends on.
dependency_graph: RwLock<HashMap<u64, Vec<u64>>>,
/// Reverse edges: operation id -> ids that depend on it.
reverse_dependencies: RwLock<HashMap<u64, Vec<u64>>>,
/// Callbacks fired (on the completing thread) when an operation completes.
completion_callbacks: Mutex<HashMap<u64, Vec<Box<dyn FnOnce() + Send + 'static>>>>,
/// Monotonic id source; starts at 1 so 0 never collides with a real id.
next_operation_id: std::sync::atomic::AtomicU64,
/// Pool the per-operation events are acquired from and released back to.
event_pool: Arc<EventPool>,
/// Aggregate counters exposed via `metrics()`.
coordination_metrics: Mutex<CoordinationMetrics>,
}
// Manual `Debug` impl: `completion_callbacks` holds boxed `FnOnce` closures,
// which have no `Debug` representation, so a placeholder string is printed
// for that field and everything else is delegated to the derived formatting.
impl std::fmt::Debug for OperationCoordinator {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OperationCoordinator")
.field("operations", &self.operations)
.field("operation_events", &self.operation_events)
.field("dependency_graph", &self.dependency_graph)
.field("reverse_dependencies", &self.reverse_dependencies)
.field("completion_callbacks", &"<completion callbacks>")
.field("next_operation_id", &self.next_operation_id)
.field("event_pool", &self.event_pool)
.field("coordination_metrics", &self.coordination_metrics)
.finish()
}
}
/// Aggregate counters maintained by `OperationCoordinator`.
#[derive(Debug, Clone, Default)]
pub struct CoordinationMetrics {
    /// Operations ever registered.
    pub total_operations: usize,
    /// Operations that reached `complete_operation`.
    pub completed_operations: usize,
    /// Operations currently blocked on dependencies.
    pub blocked_operations: usize,
    /// Mean time spent coordinating an operation.
    pub average_coordination_time: Duration,
    /// Number of dependency cycles detected so far.
    pub deadlock_detections: usize,
    /// Times a lower-priority operation blocked a higher-priority one.
    pub priority_inversions: usize,
}
impl OperationCoordinator {
    /// Creates a coordinator backed by `event_pool`. Operation ids start at 1
    /// so that 0 can never collide with a real id.
    pub fn new(event_pool: Arc<EventPool>) -> Self {
        Self {
            operations: RwLock::new(HashMap::new()),
            operation_events: RwLock::new(HashMap::new()),
            dependency_graph: RwLock::new(HashMap::new()),
            reverse_dependencies: RwLock::new(HashMap::new()),
            completion_callbacks: Mutex::new(HashMap::new()),
            next_operation_id: std::sync::atomic::AtomicU64::new(1),
            event_pool,
            coordination_metrics: Mutex::new(CoordinationMetrics::default()),
        }
    }

    /// Registers a new operation and returns its id.
    ///
    /// Kernel and memory-transfer operations get a timing-enabled event (so
    /// they can be profiled later); all other types use plain events.
    ///
    /// # Errors
    /// Fails if an event cannot be acquired from the pool.
    pub fn register_operation(
        &self,
        operation_type: OperationType,
        priority: EventPriority,
        stream: &CudaStream,
        dependencies: Vec<u64>,
        description: String,
    ) -> CudaResult<u64> {
        let operation_id = self
            .next_operation_id
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let metadata = EventMetadata {
            operation_type,
            priority,
            stream_id: stream.id(),
            operation_id,
            creation_time: Instant::now(),
            dependencies: dependencies.clone(),
            description,
        };
        let use_timing = matches!(
            operation_type,
            OperationType::Kernel | OperationType::MemoryTransfer
        );
        let event = self.event_pool.acquire_event(use_timing)?;
        {
            let mut operations = self
                .operations
                .write()
                .expect("lock should not be poisoned");
            operations.insert(operation_id, metadata);
        }
        {
            let mut operation_events = self
                .operation_events
                .write()
                .expect("lock should not be poisoned");
            operation_events.insert(operation_id, event);
        }
        if !dependencies.is_empty() {
            // Record both forward and reverse edges so deadlock detection and
            // dependent lookup are each a single map read.
            let mut dep_graph = self
                .dependency_graph
                .write()
                .expect("lock should not be poisoned");
            dep_graph.insert(operation_id, dependencies.clone());
            let mut reverse_deps = self
                .reverse_dependencies
                .write()
                .expect("lock should not be poisoned");
            for dep_id in dependencies {
                reverse_deps.entry(dep_id).or_default().push(operation_id);
            }
        }
        {
            let mut metrics = self
                .coordination_metrics
                .lock()
                .expect("lock should not be poisoned");
            metrics.total_operations += 1;
        }
        Ok(operation_id)
    }

    /// Blocks until all of the operation's dependencies have completed, then
    /// records the operation's event on `stream`.
    ///
    /// # Errors
    /// Fails if the operation is unknown, a dependency synchronize fails, or
    /// the event cannot be recorded.
    pub fn begin_operation(&self, operation_id: u64, stream: &CudaStream) -> CudaResult<()> {
        let event = {
            let operation_events = self
                .operation_events
                .read()
                .expect("lock should not be poisoned");
            operation_events
                .get(&operation_id)
                .cloned()
                .ok_or_else(|| CudaError::Context {
                    message: format!("Operation {} not found", operation_id),
                })?
        };
        self.wait_for_dependencies(operation_id)?;
        event.record_on_stream(stream)?;
        Ok(())
    }

    /// Blocks until the operation's event fires, runs any registered
    /// completion callbacks, bumps the completion counter, and releases all
    /// per-operation state (including the pooled event).
    ///
    /// # Errors
    /// Fails if the operation is unknown or event synchronization fails; in
    /// the latter case callbacks are not run and state is not cleaned up.
    pub fn complete_operation(&self, operation_id: u64) -> CudaResult<()> {
        let event = {
            let operation_events = self
                .operation_events
                .read()
                .expect("lock should not be poisoned");
            operation_events
                .get(&operation_id)
                .cloned()
                .ok_or_else(|| CudaError::Context {
                    message: format!("Operation {} not found", operation_id),
                })?
        };
        event.synchronize()?;
        // Take the callbacks out under the lock, run them after releasing it.
        let callbacks = {
            let mut completion_callbacks = self
                .completion_callbacks
                .lock()
                .expect("lock should not be poisoned");
            completion_callbacks
                .remove(&operation_id)
                .unwrap_or_default()
        };
        for callback in callbacks {
            callback();
        }
        {
            let mut metrics = self
                .coordination_metrics
                .lock()
                .expect("lock should not be poisoned");
            metrics.completed_operations += 1;
        }
        self.cleanup_operation(operation_id)?;
        Ok(())
    }

    /// Blocks the calling thread until every dependency of `operation_id`
    /// that still has a live event has completed. Dependencies that were
    /// already cleaned up are treated as complete and skipped.
    ///
    /// # Errors
    /// Propagates the first failing `synchronize`.
    pub fn wait_for_dependencies(&self, operation_id: u64) -> CudaResult<()> {
        let dependencies = {
            let dep_graph = self
                .dependency_graph
                .read()
                .expect("lock should not be poisoned");
            dep_graph.get(&operation_id).cloned().unwrap_or_default()
        };
        for dep_id in dependencies {
            // Clone the Arc and drop the read lock BEFORE the potentially
            // long blocking synchronize; the original held the lock across
            // the wait, blocking writers (e.g. cleanup) for its duration.
            let dep_event = self
                .operation_events
                .read()
                .expect("lock should not be poisoned")
                .get(&dep_id)
                .cloned();
            if let Some(dep_event) = dep_event {
                dep_event.synchronize()?;
            }
        }
        Ok(())
    }

    /// Registers `callback` to run when `operation_id` completes. Callbacks
    /// are executed on the thread that calls `complete_operation`.
    pub fn add_completion_callback<F>(&self, operation_id: u64, callback: F)
    where
        F: FnOnce() + Send + 'static,
    {
        let mut completion_callbacks = self
            .completion_callbacks
            .lock()
            .expect("lock should not be poisoned");
        completion_callbacks
            .entry(operation_id)
            .or_default()
            .push(Box::new(callback));
    }

    /// Scans the dependency graph for cycles and returns each detected cycle
    /// as a list of operation ids. The deadlock counter is bumped once per
    /// detected cycle.
    pub fn detect_deadlocks(&self) -> Vec<Vec<u64>> {
        let dep_graph = self
            .dependency_graph
            .read()
            .expect("lock should not be poisoned");
        let mut deadlocks = Vec::new();
        let mut visited = HashSet::new();
        let mut rec_stack = HashSet::new();
        let mut current_path = Vec::new();
        for &operation_id in dep_graph.keys() {
            if !visited.contains(&operation_id) {
                if let Some(cycle) = self.detect_cycle(
                    operation_id,
                    &dep_graph,
                    &mut visited,
                    &mut rec_stack,
                    &mut current_path,
                ) {
                    deadlocks.push(cycle);
                    // `detect_cycle` returns early on success without
                    // unwinding its DFS state; clear it so stale entries
                    // cannot produce false cycles for later roots. (The
                    // original omitted this.)
                    rec_stack.clear();
                    current_path.clear();
                }
            }
        }
        if !deadlocks.is_empty() {
            // Update metrics once, after the scan, instead of locking inside
            // the loop for every cycle found.
            let mut metrics = self
                .coordination_metrics
                .lock()
                .expect("lock should not be poisoned");
            metrics.deadlock_detections += deadlocks.len();
        }
        deadlocks
    }

    /// Depth-first search for a back edge starting at `operation_id`.
    /// Returns the cycle (as the path slice from the revisited node) when one
    /// is found, otherwise `None`. On `Some`, `rec_stack`/`current_path` are
    /// left holding the aborted DFS state; the caller must reset them.
    fn detect_cycle(
        &self,
        operation_id: u64,
        dep_graph: &HashMap<u64, Vec<u64>>,
        visited: &mut HashSet<u64>,
        rec_stack: &mut HashSet<u64>,
        current_path: &mut Vec<u64>,
    ) -> Option<Vec<u64>> {
        visited.insert(operation_id);
        rec_stack.insert(operation_id);
        current_path.push(operation_id);
        if let Some(dependencies) = dep_graph.get(&operation_id) {
            for &dep_id in dependencies {
                if !visited.contains(&dep_id) {
                    if let Some(cycle) =
                        self.detect_cycle(dep_id, dep_graph, visited, rec_stack, current_path)
                    {
                        return Some(cycle);
                    }
                } else if rec_stack.contains(&dep_id) {
                    // Back edge: everything from the first occurrence of
                    // `dep_id` on the current path forms the cycle.
                    let cycle_start = current_path
                        .iter()
                        .position(|&id| id == dep_id)
                        .expect("dep_id should exist in current_path as it's in rec_stack");
                    return Some(current_path[cycle_start..].to_vec());
                }
            }
        }
        rec_stack.remove(&operation_id);
        current_path.pop();
        None
    }

    /// Returns a snapshot of the aggregate counters.
    pub fn metrics(&self) -> CoordinationMetrics {
        self.coordination_metrics
            .lock()
            .expect("lock should not be poisoned")
            .clone()
    }

    /// Removes all per-operation state and returns the event to the pool.
    fn cleanup_operation(&self, operation_id: u64) -> CudaResult<()> {
        {
            let mut operations = self
                .operations
                .write()
                .expect("lock should not be poisoned");
            operations.remove(&operation_id);
        }
        if let Some(event) = self
            .operation_events
            .write()
            .expect("lock should not be poisoned")
            .remove(&operation_id)
        {
            self.event_pool.release_event(event);
        }
        // Remove the forward edges, remembering them so the matching reverse
        // edges can be pruned too (the original left stale dependent entries
        // behind in `reverse_dependencies`, leaking memory over time).
        let removed_deps = {
            let mut dep_graph = self
                .dependency_graph
                .write()
                .expect("lock should not be poisoned");
            dep_graph.remove(&operation_id)
        };
        {
            let mut reverse_deps = self
                .reverse_dependencies
                .write()
                .expect("lock should not be poisoned");
            reverse_deps.remove(&operation_id);
            if let Some(deps) = removed_deps {
                for dep_id in deps {
                    if let Some(dependents) = reverse_deps.get_mut(&dep_id) {
                        dependents.retain(|&id| id != operation_id);
                        if dependents.is_empty() {
                            reverse_deps.remove(&dep_id);
                        }
                    }
                }
            }
        }
        Ok(())
    }
}
/// One-shot barrier across a fixed set of CUDA streams: each participant
/// records an event, then all events are waited on together.
#[derive(Debug)]
pub struct CrossStreamBarrier {
/// Streams taking part in the barrier.
participants: Vec<Arc<CudaStream>>,
/// One pooled (non-timing) event per participant stream.
barrier_events: Vec<Arc<CudaEvent>>,
/// Timing-enabled event recorded after all participants have arrived.
completion_event: Arc<CudaEvent>,
/// Pool the events were acquired from; they are returned on `Drop`.
event_pool: Arc<EventPool>,
}
impl CrossStreamBarrier {
pub fn new(streams: Vec<Arc<CudaStream>>, event_pool: Arc<EventPool>) -> CudaResult<Self> {
let mut barrier_events = Vec::with_capacity(streams.len());
for _ in 0..streams.len() {
barrier_events.push(event_pool.acquire_event(false)?);
}
let completion_event = event_pool.acquire_event(true)?;
Ok(Self {
participants: streams,
barrier_events,
completion_event,
event_pool,
})
}
pub fn synchronize(&self) -> CudaResult<Duration> {
let start_time = Instant::now();
for (stream, event) in self.participants.iter().zip(self.barrier_events.iter()) {
event.record_on_stream(stream)?;
}
for event in &self.barrier_events {
event.synchronize()?;
}
if !self.participants.is_empty() {
self.completion_event
.record_on_stream(&self.participants[0])?;
self.completion_event.synchronize()?;
}
Ok(start_time.elapsed())
}
pub fn wait_on_stream(&self, stream: &CudaStream) -> CudaResult<()> {
for event in &self.barrier_events {
stream.wait_event(event)?;
}
Ok(())
}
}
impl Drop for CrossStreamBarrier {
fn drop(&mut self) {
for event in self.barrier_events.drain(..) {
self.event_pool.release_event(event);
}
self.event_pool.release_event(self.completion_event.clone());
}
}
/// Runs callbacks asynchronously once CUDA events complete, using a dedicated
/// polling thread (100µs poll interval; see `new`).
pub struct AsyncEventWaiter {
/// Events still being polled, keyed by wait id, each with its callback.
pending_events: Arc<Mutex<HashMap<u64, (Arc<CudaEvent>, Box<dyn FnOnce() + Send + 'static>)>>>,
/// Handle to the polling thread; taken and joined on drop.
worker_handle: Option<thread::JoinHandle<()>>,
/// Set on drop to make the polling thread exit its loop.
shutdown: Arc<std::sync::atomic::AtomicBool>,
/// Monotonic wait-id source; starts at 1 so 0 is never a valid id.
next_wait_id: std::sync::atomic::AtomicU64,
}
// Manual `Debug` impl: the pending-event map stores boxed `FnOnce` callbacks,
// which are not `Debug`, so a placeholder is printed and the thread handle is
// summarized as a boolean.
impl std::fmt::Debug for AsyncEventWaiter {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AsyncEventWaiter")
.field("pending_events", &"<pending events with callbacks>")
.field("worker_handle", &self.worker_handle.is_some())
.field("shutdown", &self.shutdown)
.field("next_wait_id", &self.next_wait_id)
.finish()
}
}
impl AsyncEventWaiter {
    /// Starts a background worker thread that polls registered events every
    /// 100µs and fires their callbacks once the events report ready.
    pub fn new() -> Self {
        let pending_events: Arc<
            Mutex<HashMap<u64, (Arc<CudaEvent>, Box<dyn FnOnce() + Send + 'static>)>>,
        > = Arc::new(Mutex::new(HashMap::new()));
        let shutdown = Arc::new(std::sync::atomic::AtomicBool::new(false));
        let worker_events = Arc::clone(&pending_events);
        let worker_shutdown = Arc::clone(&shutdown);
        let worker_handle = thread::spawn(move || {
            while !worker_shutdown.load(std::sync::atomic::Ordering::Relaxed) {
                // Collect ready callbacks while holding the lock, but run
                // them after releasing it so a slow callback cannot block
                // `wait_async`/`cancel_wait`.
                let ready_callbacks: Vec<Box<dyn FnOnce() + Send + 'static>> = {
                    let mut events = worker_events.lock().expect("lock should not be poisoned");
                    let mut ready_ids = Vec::new();
                    for (&wait_id, (event, _)) in events.iter() {
                        // `is_ready` errors are treated as "not ready yet".
                        if event.is_ready().unwrap_or(false) {
                            ready_ids.push(wait_id);
                        }
                    }
                    ready_ids
                        .into_iter()
                        .filter_map(|wait_id| events.remove(&wait_id).map(|(_, cb)| cb))
                        .collect()
                };
                for callback in ready_callbacks {
                    callback();
                }
                thread::sleep(Duration::from_micros(100));
            }
        });
        Self {
            pending_events,
            worker_handle: Some(worker_handle),
            shutdown,
            next_wait_id: std::sync::atomic::AtomicU64::new(1),
        }
    }

    /// Registers `callback` to run on the worker thread once `event`
    /// completes. Returns a wait id usable with [`Self::cancel_wait`].
    pub fn wait_async<F>(&self, event: Arc<CudaEvent>, callback: F) -> u64
    where
        F: FnOnce() + Send + 'static,
    {
        let wait_id = self
            .next_wait_id
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let mut pending = self
            .pending_events
            .lock()
            .expect("lock should not be poisoned");
        pending.insert(wait_id, (event, Box::new(callback)));
        wait_id
    }

    /// Cancels a pending wait; returns `true` if the wait was still pending
    /// (its callback will never run), `false` if it already fired or the id
    /// is unknown.
    pub fn cancel_wait(&self, wait_id: u64) -> bool {
        let mut pending = self
            .pending_events
            .lock()
            .expect("lock should not be poisoned");
        pending.remove(&wait_id).is_some()
    }
}
// `new()` takes no arguments, so provide the conventional `Default` impl
// (clippy::new_without_default).
impl Default for AsyncEventWaiter {
    fn default() -> Self {
        Self::new()
    }
}
impl Drop for AsyncEventWaiter {
    /// Signals the polling thread to stop and waits for it to exit.
    fn drop(&mut self) {
        self.shutdown
            .store(true, std::sync::atomic::Ordering::Relaxed);
        if let Some(worker) = self.worker_handle.take() {
            // A panicked worker is ignored; nothing useful can be done here.
            let _ = worker.join();
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    #[ignore = "Requires CUDA hardware - run with --ignored flag"]
    fn test_event_pool() {
        if crate::cuda::is_available() {
            let _device =
                Arc::new(crate::cuda::device::CudaDevice::new(0).expect("Arc should succeed"));
            let pool = EventPool::new(4, 2).expect("Event Pool should succeed");
            let event1 = pool
                .acquire_event(false)
                .expect("event acquisition should succeed");
            let event2 = pool
                .acquire_event(false)
                .expect("event acquisition should succeed");
            // Only the in-use count is asserted; idle counts are incidental.
            let (_available, _timing, in_use) = pool.utilization();
            assert_eq!(in_use, 2);
            pool.release_event(event1);
            pool.release_event(event2);
            let (_available, _timing, in_use) = pool.utilization();
            assert_eq!(in_use, 0);
        }
    }
    #[test]
    #[ignore = "Requires CUDA hardware - run with --ignored flag"]
    fn test_operation_coordinator() {
        if crate::cuda::is_available() {
            let _device =
                Arc::new(crate::cuda::device::CudaDevice::new(0).expect("Arc should succeed"));
            let event_pool = Arc::new(EventPool::new(10, 5).expect("Arc should succeed"));
            let coordinator = OperationCoordinator::new(event_pool);
            let stream = CudaStream::new().expect("Cuda Stream should succeed");
            let op_id = coordinator
                .register_operation(
                    OperationType::Kernel,
                    EventPriority::High,
                    &stream,
                    vec![],
                    "Test kernel".to_string(),
                )
                .expect("operation should succeed");
            assert!(op_id > 0);
            coordinator
                .begin_operation(op_id, &stream)
                .expect("operation begin should succeed");
            coordinator
                .complete_operation(op_id)
                .expect("operation completion should succeed");
            let metrics = coordinator.metrics();
            assert_eq!(metrics.total_operations, 1);
            assert_eq!(metrics.completed_operations, 1);
        }
    }
    #[test]
    // This test also needs real hardware; mark it ignored like its siblings
    // instead of relying solely on the `is_available()` guard.
    #[ignore = "Requires CUDA hardware - run with --ignored flag"]
    fn test_cross_stream_barrier() {
        if crate::cuda::is_available() {
            let _device =
                Arc::new(crate::cuda::device::CudaDevice::new(0).expect("Arc should succeed"));
            let stream1 = Arc::new(CudaStream::new().expect("Arc should succeed"));
            let stream2 = Arc::new(CudaStream::new().expect("Arc should succeed"));
            let streams = vec![stream1, stream2];
            let event_pool = Arc::new(EventPool::new(10, 5).expect("Arc should succeed"));
            let barrier = CrossStreamBarrier::new(streams, event_pool)
                .expect("Cross Stream Barrier should succeed");
            let duration = barrier
                .synchronize()
                .expect("synchronization should succeed");
            assert!(duration < Duration::from_secs(1));
        }
    }
    #[test]
    #[ignore = "Async event waiter has CUDA context threading issues - worker thread lacks context"]
    fn test_async_event_waiter() {
        if crate::cuda::is_available() {
            let _device =
                Arc::new(crate::cuda::device::CudaDevice::new(0).expect("Arc should succeed"));
            let waiter = AsyncEventWaiter::new();
            let stream = CudaStream::new().expect("Cuda Stream should succeed");
            let event = Arc::new(CudaEvent::new().expect("Arc should succeed"));
            stream
                .record_event(&event)
                .expect("event recording should succeed");
            stream
                .synchronize()
                .expect("synchronization should succeed");
            let callback_executed = Arc::new(std::sync::atomic::AtomicBool::new(false));
            let callback_flag = Arc::clone(&callback_executed);
            let wait_id = waiter.wait_async(event, move || {
                callback_flag.store(true, std::sync::atomic::Ordering::Relaxed);
            });
            assert!(wait_id > 0);
            // Give the 100µs polling loop ample time to observe the event.
            thread::sleep(Duration::from_millis(500));
            assert!(callback_executed.load(std::sync::atomic::Ordering::Relaxed));
        }
    }
}