use std::fmt;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use crate::gpu_backend::{ComputeBackend, DeviceStream, GpuError};
/// A labeled handle to a single backend device stream.
///
/// Cloning is cheap: the `DeviceStream` handle is duplicated and the label
/// is shared via `Arc` rather than reallocated.
#[derive(Clone)]
pub struct GpuStream {
    // Backend-specific stream handle; identity comes from `DeviceStream::id()`.
    pub(crate) inner: DeviceStream,
    // Human-readable label for logs/debugging; Arc keeps clones allocation-free.
    pub(crate) label: Arc<String>,
}
impl GpuStream {
    /// Creates a stream on `backend` and tags it with `label`.
    ///
    /// # Errors
    /// Propagates any `GpuError` raised by `ComputeBackend::create_stream`.
    pub fn new(backend: &Arc<dyn ComputeBackend>, label: impl Into<String>) -> Result<Self, GpuError> {
        let stream = backend.create_stream()?;
        let label = Arc::new(label.into());
        Ok(Self { inner: stream, label })
    }

    /// Backend-assigned identifier of the underlying device stream.
    pub fn id(&self) -> u64 {
        self.inner.id()
    }

    /// Human-readable label supplied at construction.
    pub fn label(&self) -> &str {
        &self.label
    }

    /// Blocks until all work queued on this stream has completed.
    pub fn synchronize(&self) -> Result<(), GpuError> {
        self.inner.synchronize()
    }
}
impl fmt::Debug for GpuStream {
    /// Compact single-line representation: `GpuStream(id=…, label=…)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let id = self.id();
        let label = self.label.as_str();
        write!(f, "GpuStream(id={}, label={})", id, label)
    }
}
/// A fixed-size pool of `GpuStream`s with a pluggable assignment strategy.
pub struct StreamPool {
    // Slots are Arc-shared so a `StreamGuard` can hold one without borrowing
    // the pool's Vec.
    streams: Vec<Arc<StreamSlot>>,
    // How `acquire` picks the next slot.
    strategy: StreamAssignment,
    // Monotone counter for RoundRobin; taken modulo the stream count.
    next_robin: AtomicUsize,
}
/// Strategy used by `StreamPool::acquire` to pick a stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamAssignment {
    /// Cycle through the streams in order.
    RoundRobin,
    /// Pick the stream with the fewest pending operations.
    LeastPending,
}
/// A pool slot pairing a stream with its in-flight operation count.
struct StreamSlot {
    stream: GpuStream,
    // Approximate count of outstanding ops; drives `LeastPending` selection.
    pending_ops: AtomicU64,
}
impl fmt::Debug for StreamPool {
    /// Summarizes the pool as `StreamPool(N streams, Strategy)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let count = self.streams.len();
        write!(f, "StreamPool({} streams, {:?})", count, self.strategy)
    }
}
impl StreamPool {
    /// Builds a pool of `count` streams on `backend`, each labeled
    /// `"{kind}-stream-{i}"`.
    ///
    /// # Errors
    /// Returns an error if `count` is zero or if any stream fails to create.
    pub fn new(
        backend: &Arc<dyn ComputeBackend>,
        count: usize,
        strategy: StreamAssignment,
    ) -> Result<Self, GpuError> {
        if count == 0 {
            return Err(GpuError::new(backend.kind(), "StreamPool count must be > 0"));
        }
        let mut streams = Vec::with_capacity(count);
        for i in 0..count {
            let label = format!("{}-stream-{}", backend.kind(), i);
            let stream = GpuStream::new(backend, label)?;
            streams.push(Arc::new(StreamSlot {
                stream,
                pending_ops: AtomicU64::new(0),
            }));
        }
        Ok(Self {
            streams,
            strategy,
            next_robin: AtomicUsize::new(0),
        })
    }

    /// Number of streams in the pool (always > 0; see `new`).
    pub fn len(&self) -> usize {
        self.streams.len()
    }

    /// Always false in practice — construction rejects empty pools.
    /// Kept for API symmetry with `len`.
    pub fn is_empty(&self) -> bool {
        self.streams.is_empty()
    }

    /// Picks a stream according to the pool's strategy and returns a guard.
    ///
    /// The selected slot's pending-op count is incremented here and released
    /// when the guard is dropped.
    pub fn acquire(&self) -> Result<StreamGuard<'_>, GpuError> {
        let slot = match self.strategy {
            StreamAssignment::RoundRobin => {
                // Relaxed suffices: only fairness, not correctness, depends
                // on this counter. Wrapping on overflow is harmless.
                let idx = self.next_robin.fetch_add(1, Ordering::Relaxed) % self.streams.len();
                Arc::clone(&self.streams[idx])
            }
            StreamAssignment::LeastPending => self
                .streams
                .iter()
                .min_by_key(|s| s.pending_ops.load(Ordering::Relaxed))
                .map(Arc::clone)
                // Invariant: `new` rejects count == 0, so the pool is never empty.
                .expect("StreamPool is never empty"),
        };
        slot.pending_ops.fetch_add(1, Ordering::Relaxed);
        Ok(StreamGuard {
            slot,
            _marker: std::marker::PhantomData,
        })
    }

    /// Synchronizes every stream, stopping at the first error.
    pub fn synchronize_all(&self) -> Result<(), GpuError> {
        for slot in &self.streams {
            slot.stream.synchronize()?;
        }
        Ok(())
    }

    /// Iterates over the pooled streams without touching pending-op counts.
    pub fn iter(&self) -> impl Iterator<Item = &GpuStream> {
        self.streams.iter().map(|s| &s.stream)
    }
}
/// RAII handle to a pooled stream; releases its pending-op count on drop.
pub struct StreamGuard<'pool> {
    slot: Arc<StreamSlot>,
    // Ties the guard's lifetime to the pool it came from, even though the
    // Arc alone would let it outlive the pool borrow.
    _marker: std::marker::PhantomData<&'pool ()>,
}
impl<'pool> StreamGuard<'pool> {
    /// The stream this guard refers to.
    pub fn stream(&self) -> &GpuStream {
        &self.slot.stream
    }

    /// Bumps the slot's pending-op counter by one.
    ///
    /// NOTE(review): drop releases only the guard's own acquire count, so
    /// ops recorded here remain pending after the guard is dropped —
    /// presumably they are retired elsewhere; confirm against callers.
    pub fn record_op(&self) {
        self.slot.pending_ops.fetch_add(1, Ordering::Relaxed);
    }

    /// Current pending-op count for the slot (a relaxed snapshot).
    pub fn pending_ops(&self) -> u64 {
        self.slot.pending_ops.load(Ordering::Relaxed)
    }
}
impl Drop for StreamGuard<'_> {
    /// Releases the pending-op count taken in `StreamPool::acquire`.
    fn drop(&mut self) {
        // The original `load` + `if prev > 0` + `fetch_sub` was a
        // check-then-act race: the test and the decrement were two separate
        // atomic ops, so concurrent decrements could slip between them and
        // push the counter below the intended floor (wrapping the u64).
        // `fetch_update` with `checked_sub` performs a saturating decrement
        // as one atomic CAS loop: it stops at zero instead of wrapping.
        let _ = self
            .slot
            .pending_ops
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| n.checked_sub(1));
    }
}
impl fmt::Debug for StreamGuard<'_> {
    /// Delegates to the held stream's `Debug` output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let stream = &self.slot.stream;
        write!(f, "StreamGuard({:?})", stream)
    }
}
/// A small, fixed set of streams addressed by index (wrapping modulo `depth`).
#[derive(Debug, Clone)]
pub struct StreamSet {
    streams: Vec<GpuStream>,
    // Cached stream count; always equals `streams.len()`.
    depth: usize,
}
impl StreamSet {
pub fn new(
backend: &Arc<dyn ComputeBackend>,
depth: usize,
prefix: &str,
) -> Result<Self, GpuError> {
let mut streams = Vec::with_capacity(depth);
for i in 0..depth {
streams.push(GpuStream::new(backend, format!("{}-{}", prefix, i))?);
}
Ok(Self { streams, depth })
}
pub fn get(&self, index: usize) -> &GpuStream {
&self.streams[index % self.depth]
}
pub fn depth(&self) -> usize { self.depth }
pub fn synchronize_all(&self) -> Result<(), GpuError> {
for s in &self.streams {
s.synchronize()?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gpu_backend::stub::StubBackend;

    // Shared fixture: an in-process backend that needs no real device.
    fn stub_backend() -> Arc<dyn ComputeBackend> {
        Arc::new(StubBackend::new(0))
    }

    // Two consecutive RoundRobin acquisitions must land on different streams.
    #[test]
    fn test_stream_pool_round_robin() {
        let backend = stub_backend();
        let pool = StreamPool::new(&backend, 4, StreamAssignment::RoundRobin).unwrap();
        assert_eq!(pool.len(), 4);
        let g0 = pool.acquire().unwrap();
        let g1 = pool.acquire().unwrap();
        assert_ne!(g0.stream().id(), g1.stream().id());
    }

    // After loading one stream with extra ops, LeastPending should pick a
    // less-loaded slot for the next acquisition.
    #[test]
    fn test_stream_pool_least_pending() {
        let backend = stub_backend();
        let pool = StreamPool::new(&backend, 3, StreamAssignment::LeastPending).unwrap();
        let g = pool.acquire().unwrap();
        g.record_op();
        g.record_op();
        let g2 = pool.acquire().unwrap();
        assert!(g2.pending_ops() < g.pending_ops());
    }

    // Dropping a guard releases its pending-op count, so the next LeastPending
    // acquisition sees a count of exactly 1 (its own increment).
    #[test]
    fn test_stream_guard_decrements_on_drop() {
        let backend = stub_backend();
        let pool = StreamPool::new(&backend, 2, StreamAssignment::LeastPending).unwrap();
        {
            let g = pool.acquire().unwrap();
            assert_eq!(g.pending_ops(), 1);
        }
        let g2 = pool.acquire().unwrap();
        assert_eq!(g2.pending_ops(), 1);
    }

    // Index wraps modulo depth: get(0) and get(3) alias in a depth-3 set.
    #[test]
    fn test_stream_set() {
        let backend = stub_backend();
        let set = StreamSet::new(&backend, 3, "compute").unwrap();
        assert_eq!(set.depth(), 3);
        assert_eq!(set.get(0).id(), set.get(3).id());
        assert_ne!(set.get(0).id(), set.get(1).id());
        set.synchronize_all().unwrap();
    }

    // Smoke test: synchronizing an idle pool succeeds.
    #[test]
    fn test_pool_synchronize_all() {
        let backend = stub_backend();
        let pool = StreamPool::new(&backend, 4, StreamAssignment::RoundRobin).unwrap();
        pool.synchronize_all().unwrap();
    }
}