#[cfg(feature = "advanced_math")]
use std::sync::{Arc, Mutex};
#[cfg(feature = "advanced_math")]
use crate::error::Result;
#[cfg(feature = "advanced_math")]
use super::context::CudaEvent;
#[cfg(feature = "advanced_math")]
pub type CudaStreamHandle = usize;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamPriority {
Low,
Normal,
High,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamFlags {
Default,
NonBlocking,
}
#[cfg(feature = "advanced_math")]
pub struct CudaStream {
id: usize,
handle: Arc<Mutex<Option<CudaStreamHandle>>>,
priority: StreamPriority,
flags: StreamFlags,
}
#[cfg(feature = "advanced_math")]
impl CudaStream {
pub fn new() -> Result<Self> {
Self::with_priority_and_flags(StreamPriority::Normal, StreamFlags::Default)
}
pub fn with_priority_and_flags(priority: StreamPriority, flags: StreamFlags) -> Result<Self> {
let id = Self::allocate_stream_id();
let handle = Arc::new(Mutex::new(Self::create_cuda_stream(priority, flags)?));
Ok(Self {
id,
handle,
priority,
flags,
})
}
pub fn get_id(&self) -> usize {
self.id
}
pub fn get_priority(&self) -> StreamPriority {
self.priority
}
pub fn get_flags(&self) -> StreamFlags {
self.flags
}
pub fn synchronize(&self) -> Result<()> {
let handle = self.handle.lock().map_err(|e| {
crate::error::SimulatorError::InvalidOperation(format!("Lock poisoned: {e}"))
})?;
if let Some(cuda_handle) = *handle {
Self::cuda_stream_synchronize(cuda_handle)?;
}
Ok(())
}
pub fn query(&self) -> Result<bool> {
let handle = self.handle.lock().map_err(|e| {
crate::error::SimulatorError::InvalidOperation(format!("Lock poisoned: {e}"))
})?;
if let Some(cuda_handle) = *handle {
Self::cuda_stream_query(cuda_handle)
} else {
Ok(true) }
}
pub fn record_event(&self, event: &mut CudaEvent) -> Result<()> {
let handle = self.handle.lock().map_err(|e| {
crate::error::SimulatorError::InvalidOperation(format!("Lock poisoned: {e}"))
})?;
if let Some(cuda_handle) = *handle {
Self::cuda_event_record(*event, cuda_handle)?;
}
Ok(())
}
pub fn wait_event(&self, event: CudaEvent) -> Result<()> {
let handle = self.handle.lock().map_err(|e| {
crate::error::SimulatorError::InvalidOperation(format!("Lock poisoned: {e}"))
})?;
if let Some(cuda_handle) = *handle {
Self::cuda_stream_wait_event(cuda_handle, event)?;
}
Ok(())
}
pub fn is_ready(&self) -> Result<bool> {
self.query()
}
fn allocate_stream_id() -> usize {
use std::sync::atomic::{AtomicUsize, Ordering};
static STREAM_COUNTER: AtomicUsize = AtomicUsize::new(0);
STREAM_COUNTER.fetch_add(1, Ordering::SeqCst)
}
fn create_cuda_stream(
priority: StreamPriority,
flags: StreamFlags,
) -> Result<Option<CudaStreamHandle>> {
let _priority_val = match priority {
StreamPriority::Low => -1,
StreamPriority::Normal => 0,
StreamPriority::High => 1,
};
let _flags_val = match flags {
StreamFlags::Default => 0,
StreamFlags::NonBlocking => 1,
};
Ok(Some(Self::allocate_stream_id()))
}
fn cuda_stream_synchronize(_handle: CudaStreamHandle) -> Result<()> {
std::thread::sleep(std::time::Duration::from_micros(10));
Ok(())
}
fn cuda_stream_query(_handle: CudaStreamHandle) -> Result<bool> {
Ok(true)
}
fn cuda_event_record(_event: CudaEvent, _stream: CudaStreamHandle) -> Result<()> {
Ok(())
}
fn cuda_stream_wait_event(_stream: CudaStreamHandle, _event: CudaEvent) -> Result<()> {
Ok(())
}
pub fn get_handle(&self) -> Arc<Mutex<Option<CudaStreamHandle>>> {
Arc::clone(&self.handle)
}
pub fn get_handle_value(&self) -> Option<CudaStreamHandle> {
let handle = self.handle.lock().ok()?;
*handle
}
}
#[cfg(feature = "advanced_math")]
impl Drop for CudaStream {
fn drop(&mut self) {
if let Ok(handle) = self.handle.lock() {
if let Some(cuda_handle) = *handle {
let _ = Self::cuda_stream_destroy(cuda_handle);
}
}
}
}
#[cfg(feature = "advanced_math")]
impl CudaStream {
fn cuda_stream_destroy(_handle: CudaStreamHandle) -> Result<()> {
Ok(())
}
}
#[cfg(feature = "advanced_math")]
pub struct CudaStreamPool {
streams: Vec<CudaStream>,
current_index: std::sync::atomic::AtomicUsize,
}
#[cfg(feature = "advanced_math")]
impl CudaStreamPool {
pub fn new(num_streams: usize) -> Result<Self> {
let mut streams = Vec::new();
for _ in 0..num_streams {
streams.push(CudaStream::new()?);
}
Ok(Self {
streams,
current_index: std::sync::atomic::AtomicUsize::new(0),
})
}
pub fn get_next_stream(&self) -> &CudaStream {
let index = self
.current_index
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
&self.streams[index % self.streams.len()]
}
pub fn get_stream(&self, index: usize) -> Option<&CudaStream> {
self.streams.get(index)
}
pub fn synchronize_all(&self) -> Result<()> {
for stream in &self.streams {
stream.synchronize()?;
}
Ok(())
}
pub fn len(&self) -> usize {
self.streams.len()
}
pub fn is_empty(&self) -> bool {
self.streams.is_empty()
}
}