use crate::error::{RusTorchError, RusTorchResult};
use std::collections::HashMap;
use std::sync::{Arc, Barrier, Condvar, Mutex};
use std::time::{Duration, Instant};
/// Host-side bookkeeping for a GPU synchronization event.
///
/// `StreamManager::create_event` builds these with both completion flags set to
/// `false`, and `StreamManager::record_event` sets both to `true`. Because the
/// flags live behind `Arc`, a clone shares completion state with the original.
#[derive(Debug, Clone)]
pub struct GpuEvent {
/// Event identifier; `StreamManager` allocates these from an incrementing counter.
pub id: u64,
/// Device the event was created for.
pub device_id: usize,
/// Creation timestamp (`StreamManager::create_event` uses `Instant::now()`).
pub created_at: Instant,
/// Completion flag polled by `StreamManager::query_event`.
pub completed: Arc<Mutex<bool>>,
/// Completion flag paired with a condition variable: `StreamManager::wait_event`
/// blocks on it and `StreamManager::record_event` sets it and calls `notify_all`.
/// `record_event` updates this flag together with `completed`.
pub completion_notifier: Arc<(Mutex<bool>, Condvar)>,
}
/// A stream belonging to one device, as created by `StreamManager::create_stream`.
#[derive(Debug)]
pub struct GpuStream {
/// Stream identifier; `StreamManager` allocates these from a counter shared by all devices.
pub id: u64,
/// Device that owns this stream.
pub device_id: usize,
/// Requested scheduling priority.
/// NOTE(review): not read by `StreamManager` itself — confirm whether a scheduler
/// elsewhere consumes it.
pub priority: StreamPriority,
/// Events associated with this stream. `create_stream` starts it empty and
/// `record_event` does not append to it.
pub events: Vec<GpuEvent>,
/// Pending operations, shared behind `Arc<Mutex<_>>`.
/// `StreamManager::synchronize_device` scans it for `EventWait` entries.
pub operation_queue: Arc<Mutex<Vec<StreamOperation>>>,
}
/// Relative priority of a `GpuStream`.
///
/// The derived `Ord` follows declaration order, so `Low < Normal < High < Critical`;
/// the explicit discriminants (0 through 3) follow the same order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum StreamPriority {
Low = 0,
Normal = 1,
High = 2,
Critical = 3,
}
/// A unit of work that can be queued on a `GpuStream`'s `operation_queue`.
///
/// Of these variants, only `EventWait` is acted on by
/// `StreamManager::synchronize_device`; `StreamManager` does not execute the others.
#[derive(Debug, Clone)]
pub enum StreamOperation {
/// Copy of `size` units from `src_device` to `dst_device`
/// (units presumably bytes — TODO confirm).
MemoryCopy {
src_device: usize,
dst_device: usize,
size: usize,
},
/// Launch of `kernel_name` on `device_id`. `grid_size` and `block_size` are
/// (x, y, z) launch dimensions — presumably CUDA-style; confirm against the backend.
KernelExecution {
kernel_name: String,
device_id: usize,
grid_size: (u32, u32, u32),
block_size: (u32, u32, u32),
},
/// Synchronization point for the named group.
Barrier { group_name: String },
/// Records the event with id `event_id`.
EventRecord { event_id: u64 },
/// Waits for the event with id `event_id`; `synchronize_device` blocks on it.
EventWait { event_id: u64 },
}
/// Mutable bookkeeping shared by all participants of a [`MultiGpuBarrier`].
#[derive(Debug, Default)]
struct BarrierState {
    /// Number of participants that have arrived in the current round.
    arrived: usize,
    /// Incremented each time a round completes. Waiters compare it against the
    /// value they saw on arrival to detect release. `arrived` cannot serve this
    /// purpose because the releasing thread resets it to 0.
    generation: u64,
}

/// Reusable rendezvous point for a fixed set of GPU ids, with a timeout.
///
/// Each registered GPU id calls [`MultiGpuBarrier::wait`] once per round. The last
/// arrival releases every waiter and starts the next round.
pub struct MultiGpuBarrier {
    /// Number of distinct participants required to complete a round.
    num_gpus: usize,
    /// Per-device barriers keyed by GPU id. Each has a party count of 1, so waiting
    /// on it never blocks. The keys also define which ids may call `wait`.
    gpu_barriers: HashMap<usize, Arc<Barrier>>,
    /// Arrival count and round generation, guarded together by one mutex.
    state: Mutex<BarrierState>,
    /// Signalled when a round completes.
    completion_cv: Condvar,
    /// Maximum time a participant waits for the others.
    timeout: Duration,
}

impl MultiGpuBarrier {
    /// Creates a barrier for `gpu_ids` in which each participant waits at most `timeout`.
    ///
    /// Duplicate ids are counted once, so a round needs one arrival per distinct id.
    pub fn new(gpu_ids: Vec<usize>, timeout: Duration) -> Self {
        let gpu_barriers: HashMap<usize, Arc<Barrier>> = gpu_ids
            .into_iter()
            .map(|gpu_id| (gpu_id, Arc::new(Barrier::new(1))))
            .collect();
        Self {
            num_gpus: gpu_barriers.len(),
            gpu_barriers,
            state: Mutex::new(BarrierState::default()),
            completion_cv: Condvar::new(),
            timeout,
        }
    }

    /// Blocks until every registered GPU has called `wait` in the current round.
    ///
    /// # Errors
    ///
    /// - `gpu_id` was not passed to [`MultiGpuBarrier::new`]. Counting an unknown id
    ///   would release the round before all real participants arrive.
    /// - The timeout elapsed before the round completed. This participant's arrival
    ///   is withdrawn, so the barrier remains usable for later rounds.
    pub fn wait(&self, gpu_id: usize) -> RusTorchResult<()> {
        let start_time = Instant::now();
        let device_barrier = match self.gpu_barriers.get(&gpu_id) {
            Some(barrier) => barrier,
            None => {
                return Err(RusTorchError::gpu(format!(
                    "GPU {} is not registered with this multi-GPU barrier",
                    gpu_id
                )))
            }
        };
        device_barrier.wait();

        // The state is plain counters, so recovering from a poisoned lock is safe.
        let mut state = self
            .state
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        state.arrived += 1;
        if state.arrived >= self.num_gpus {
            // Last arrival: open the next round and release everyone waiting on this one.
            state.arrived = 0;
            state.generation = state.generation.wrapping_add(1);
            self.completion_cv.notify_all();
            return Ok(());
        }

        let my_generation = state.generation;
        loop {
            // Re-check on every wakeup; condvars may wake spuriously.
            if state.generation != my_generation {
                return Ok(());
            }
            let elapsed = start_time.elapsed();
            if elapsed >= self.timeout {
                // Our round is still open (checked above under the lock), so retract
                // our arrival; otherwise the next round would release one arrival early.
                state.arrived = state.arrived.saturating_sub(1);
                return Err(RusTorchError::gpu(format!(
                    "Multi-GPU barrier timeout after {:?}",
                    elapsed
                )));
            }
            let (guard, _timeout_result) = self
                .completion_cv
                .wait_timeout(state, self.timeout - elapsed)
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            state = guard;
        }
    }

    /// Clears the arrival count of the current round.
    ///
    /// Intended for recovery between rounds. Participants already blocked in `wait`
    /// are not released by this call.
    pub fn reset(&self) {
        let mut state = self
            .state
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        state.arrived = 0;
    }
}
/// Registry of GPU streams (grouped by device id) and events (keyed by event id).
pub struct StreamManager {
/// Streams grouped by the device they were created on.
streams: HashMap<usize, Vec<GpuStream>>,
/// Every event created by this manager, keyed by event id.
events: HashMap<u64, GpuEvent>,
/// Next stream id to hand out; one counter is shared across all devices.
next_stream_id: Arc<Mutex<u64>>,
/// Next event id to hand out.
/// NOTE(review): both counters are only touched from `&mut self` methods, so a
/// plain `u64` would suffice unless other code shares these `Arc`s.
next_event_id: Arc<Mutex<u64>>,
}
impl StreamManager {
    /// Creates a manager with no streams or events.
    pub fn new() -> Self {
        Self {
            streams: HashMap::new(),
            events: HashMap::new(),
            next_stream_id: Arc::new(Mutex::new(0)),
            next_event_id: Arc::new(Mutex::new(0)),
        }
    }

    /// Creates a stream on `device_id` with the given `priority` and returns its id.
    ///
    /// Stream ids come from a single counter shared by all devices. This never fails;
    /// the `Result` is kept for API stability.
    pub fn create_stream(
        &mut self,
        device_id: usize,
        priority: StreamPriority,
    ) -> RusTorchResult<u64> {
        let stream_id = {
            let mut next = self.next_stream_id.lock().unwrap();
            let id = *next;
            *next += 1;
            id
        };
        let stream = GpuStream {
            id: stream_id,
            device_id,
            priority,
            events: Vec::new(),
            operation_queue: Arc::new(Mutex::new(Vec::new())),
        };
        self.streams
            .entry(device_id)
            .or_insert_with(Vec::new)
            .push(stream);
        Ok(stream_id)
    }

    /// Creates a not-yet-completed event for `device_id` and returns its id.
    pub fn create_event(&mut self, device_id: usize) -> RusTorchResult<u64> {
        let event_id = {
            let mut next = self.next_event_id.lock().unwrap();
            let id = *next;
            *next += 1;
            id
        };
        let event = GpuEvent {
            id: event_id,
            device_id,
            created_at: Instant::now(),
            completed: Arc::new(Mutex::new(false)),
            completion_notifier: Arc::new((Mutex::new(false), Condvar::new())),
        };
        self.events.insert(event_id, event);
        Ok(event_id)
    }

    /// Marks `event_id` as completed and wakes every thread blocked on it in
    /// [`StreamManager::wait_event`].
    ///
    /// Unknown event ids are ignored. `stream_id` is currently unused: the event is
    /// completed immediately instead of being enqueued on the stream.
    pub fn record_event(&mut self, stream_id: u64, event_id: u64) -> RusTorchResult<()> {
        // Accepted for API compatibility; see the doc comment above.
        let _ = stream_id;
        if let Some(event) = self.events.get(&event_id) {
            *event.completed.lock().unwrap() = true;
            let (lock, cv) = &*event.completion_notifier;
            // Set the flag under the lock, release it, then notify; waiters re-check
            // the flag, so no wakeup can be lost.
            *lock.lock().unwrap() = true;
            cv.notify_all();
        }
        Ok(())
    }

    /// Blocks until `event_id` has been recorded, or until `timeout` (if given) elapses.
    ///
    /// Returns immediately if the event was already recorded. Unknown event ids are
    /// treated as satisfied.
    ///
    /// # Errors
    ///
    /// Returns an error if `timeout` elapses before the event is recorded.
    pub fn wait_event(&self, event_id: u64, timeout: Option<Duration>) -> RusTorchResult<()> {
        let event = match self.events.get(&event_id) {
            Some(event) => event,
            None => return Ok(()),
        };
        let (lock, cv) = &*event.completion_notifier;
        let notified = lock.lock().unwrap();
        // Wait on the flag rather than on a single notification. The event may
        // already have been recorded, in which case its notify_all has already
        // happened, and condvars may also wake spuriously.
        match timeout {
            Some(timeout_duration) => {
                let (_notified, timeout_result) = cv
                    .wait_timeout_while(notified, timeout_duration, |done| !*done)
                    .unwrap();
                if timeout_result.timed_out() {
                    return Err(RusTorchError::gpu(format!(
                        "Event {} wait timeout",
                        event_id
                    )));
                }
            }
            None => {
                let _notified = cv.wait_while(notified, |done| !*done).unwrap();
            }
        }
        Ok(())
    }

    /// Waits (30s per event) for every `EventWait` operation queued on `device_id`'s
    /// streams. Other operation kinds are skipped; an unknown device is a no-op.
    ///
    /// # Errors
    ///
    /// Propagates the first [`StreamManager::wait_event`] timeout.
    pub fn synchronize_device(&self, device_id: usize) -> RusTorchResult<()> {
        const EVENT_WAIT_TIMEOUT: Duration = Duration::from_secs(30);
        let streams = match self.streams.get(&device_id) {
            Some(streams) => streams,
            None => return Ok(()),
        };
        for stream in streams {
            // Snapshot the event ids so the queue lock is not held while blocking.
            let pending: Vec<u64> = stream
                .operation_queue
                .lock()
                .unwrap()
                .iter()
                .filter_map(|operation| match operation {
                    StreamOperation::EventWait { event_id } => Some(*event_id),
                    _ => None,
                })
                .collect();
            for event_id in pending {
                self.wait_event(event_id, Some(EVENT_WAIT_TIMEOUT))?;
            }
        }
        Ok(())
    }

    /// Returns `true` if `event_id` exists and has been recorded; never blocks on the event.
    pub fn query_event(&self, event_id: u64) -> bool {
        self.events
            .get(&event_id)
            .map_or(false, |event| *event.completed.lock().unwrap())
    }
}
impl Default for StreamManager {
fn default() -> Self {
Self::new()
}
}