use std::collections::HashMap;
use std::sync::Mutex;
use super::GpuError;
/// Identifier of a compute stream, wrapping a raw `u64`.
///
/// [`StreamId::default_stream`] yields `StreamId(0)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StreamId(pub u64);

impl StreamId {
    /// Returns the identifier of the default stream, which is `StreamId(0)`.
    pub fn default_stream() -> Self {
        Self::new(0)
    }

    /// Wraps the raw numeric identifier `id`.
    pub fn new(id: u64) -> Self {
        Self(id)
    }
}

impl std::fmt::Display for StreamId {
    /// Renders the id as `Stream(<n>)`, for example `Stream(3)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let StreamId(raw) = *self;
        write!(f, "Stream({raw})")
    }
}
/// Per-stream allocation state: a byte budget plus the host buffers that back
/// every pointer handed out for one stream.
struct StreamArena {
    /// Stream this arena serves; used to label error messages.
    stream_id: StreamId,
    /// Total size in bytes of all buffers currently held in `allocations`.
    allocated_bytes: usize,
    /// Upper bound that `allocated_bytes` is never allowed to exceed.
    max_bytes: usize,
    /// Owned backing storage; each pointer returned by `allocate` points into one of these.
    allocations: Vec<Vec<u8>>,
}

impl StreamArena {
    /// Creates an empty arena for `stream_id` that may hold at most `max_bytes` bytes.
    fn new(stream_id: StreamId, max_bytes: usize) -> Self {
        Self {
            stream_id,
            allocated_bytes: 0,
            max_bytes,
            allocations: Vec::new(),
        }
    }

    /// Allocates a zero-initialised buffer of `bytes` bytes and returns a pointer
    /// to its first byte.
    ///
    /// The buffer is owned by the arena, so the pointer stays valid until
    /// [`StreamArena::reset`] is called or the arena is dropped. The storage is a
    /// `Vec<u8>`, so only 1-byte alignment is guaranteed.
    ///
    /// # Errors
    ///
    /// * `GpuError::InvalidParameter` if `bytes == 0`.
    /// * `GpuError::OutOfMemory` if the request does not fit in the remaining
    ///   budget, including when `allocated_bytes + bytes` would overflow `usize`.
    fn allocate(&mut self, bytes: usize) -> Result<*mut u8, GpuError> {
        if bytes == 0 {
            return Err(GpuError::InvalidParameter(
                "allocation size must be > 0".to_string(),
            ));
        }
        // A plain `+` would panic in debug builds, or wrap past `max_bytes` in
        // release builds, for very large requests; treat overflow as "does not fit".
        let fits = matches!(
            self.allocated_bytes.checked_add(bytes),
            Some(new_total) if new_total <= self.max_bytes
        );
        if !fits {
            return Err(GpuError::OutOfMemory(format!(
                "{}: {} bytes requested, {} bytes available",
                self.stream_id,
                bytes,
                self.max_bytes.saturating_sub(self.allocated_bytes),
            )));
        }
        let mut buf = vec![0u8; bytes];
        // Pushing `buf` moves only the Vec header. The heap buffer, and therefore
        // `ptr`, stays in place for as long as `buf` lives in `allocations`.
        let ptr = buf.as_mut_ptr();
        self.allocations.push(buf);
        // Cannot overflow: `fits` established `allocated_bytes + bytes <= max_bytes`.
        self.allocated_bytes += bytes;
        Ok(ptr)
    }

    /// Drops every buffer owned by this arena and resets the byte count to zero.
    ///
    /// Any pointer previously returned by `allocate` dangles after this call.
    fn reset(&mut self) {
        self.allocations.clear();
        self.allocated_bytes = 0;
    }

    /// Number of bytes currently allocated from this arena.
    fn allocated_bytes(&self) -> usize {
        self.allocated_bytes
    }
}
/// Hands out zero-initialised host buffers per registered stream.
///
/// It enforces a per-stream byte limit and a global byte limit across all streams.
///
/// All state sits behind a single `Mutex`, so every method takes `&self`. Fallible
/// methods report a poisoned mutex as `GpuError::Other`. The infallible methods
/// (`unregister_stream`, `reset_stream`, `total_allocated_bytes`,
/// `registered_streams`) do nothing, or return `0` or an empty `Vec`, in that case.
pub struct StreamAllocator {
    /// One arena per registered stream.
    arenas: Mutex<HashMap<StreamId, StreamArena>>,
    /// Byte cap given to each newly registered stream's arena.
    per_stream_max_bytes: usize,
    /// Byte cap on the sum of live allocations across all arenas.
    global_max_bytes: usize,
}

impl StreamAllocator {
    /// Creates an allocator with no registered streams.
    pub fn new(per_stream_max_bytes: usize, global_max_bytes: usize) -> Self {
        Self {
            arenas: Mutex::new(HashMap::new()),
            per_stream_max_bytes,
            global_max_bytes,
        }
    }

    /// Locks the arena map. A poisoned mutex becomes a `GpuError::Other` that
    /// names the operation `op`.
    fn lock_arenas(
        &self,
        op: &str,
    ) -> Result<std::sync::MutexGuard<'_, HashMap<StreamId, StreamArena>>, GpuError> {
        self.arenas.lock().map_err(|_| {
            GpuError::Other(format!("StreamAllocator mutex poisoned during {op}"))
        })
    }

    /// Registers `stream_id` with a fresh, empty arena capped at `per_stream_max_bytes`.
    ///
    /// # Errors
    ///
    /// * `GpuError::InvalidParameter` if the stream is already registered.
    /// * `GpuError::Other` if the internal mutex is poisoned.
    pub fn register_stream(&self, stream_id: StreamId) -> Result<(), GpuError> {
        let mut arenas = self.lock_arenas("register_stream")?;
        if arenas.contains_key(&stream_id) {
            return Err(GpuError::InvalidParameter(format!(
                "stream {stream_id} is already registered",
            )));
        }
        arenas.insert(
            stream_id,
            StreamArena::new(stream_id, self.per_stream_max_bytes),
        );
        Ok(())
    }

    /// Removes `stream_id` and drops its arena, which frees all of its buffers.
    ///
    /// Pointers previously allocated on that stream dangle afterwards. Unknown
    /// streams are ignored, and so is a poisoned mutex.
    pub fn unregister_stream(&self, stream_id: StreamId) {
        if let Ok(mut arenas) = self.arenas.lock() {
            arenas.remove(&stream_id);
        }
    }

    /// Allocates `bytes` zero-initialised bytes on `stream_id`'s arena.
    ///
    /// The returned raw pointer stays valid until the stream is reset or
    /// unregistered, or until the allocator is dropped. Only 1-byte alignment is
    /// guaranteed. Dereferencing the pointer is the caller's (`unsafe`) responsibility.
    ///
    /// # Errors
    ///
    /// * `GpuError::InvalidParameter` if the stream is not registered or `bytes == 0`.
    /// * `GpuError::OutOfMemory` if the request would exceed the global limit or
    ///   the stream's own limit, or if the byte totals would overflow `usize`.
    /// * `GpuError::Other` if the internal mutex is poisoned.
    pub fn allocate(&self, stream_id: StreamId, bytes: usize) -> Result<*mut u8, GpuError> {
        let mut arenas = self.lock_arenas("allocate")?;
        // Sum first so the shared borrow of the map ends before `get_mut` below.
        let total: usize = arenas.values().map(StreamArena::allocated_bytes).sum();
        // Resolve the stream before checking capacity, so an unregistered stream is
        // reported as such instead of as an out-of-memory condition.
        let arena = arenas.get_mut(&stream_id).ok_or_else(|| {
            GpuError::InvalidParameter(format!(
                "stream {stream_id} is not registered; call register_stream first",
            ))
        })?;
        // An unchecked `total + bytes` panics in debug builds for huge requests.
        // That panic would happen while the guard is held and would poison the
        // mutex permanently.
        let fits_globally = matches!(
            total.checked_add(bytes),
            Some(new_total) if new_total <= self.global_max_bytes
        );
        if !fits_globally {
            return Err(GpuError::OutOfMemory(format!(
                "global limit: {} bytes requested, {} bytes available",
                bytes,
                self.global_max_bytes.saturating_sub(total),
            )));
        }
        arena.allocate(bytes)
    }

    /// Frees every buffer on `stream_id` while keeping the stream registered.
    ///
    /// Pointers previously allocated on that stream dangle afterwards. Unknown
    /// streams are ignored, and so is a poisoned mutex.
    pub fn reset_stream(&self, stream_id: StreamId) {
        if let Ok(mut arenas) = self.arenas.lock() {
            if let Some(arena) = arenas.get_mut(&stream_id) {
                arena.reset();
            }
        }
    }

    /// Sum of live allocation sizes across all streams.
    ///
    /// Returns `0` if the mutex is poisoned.
    pub fn total_allocated_bytes(&self) -> usize {
        self.arenas
            .lock()
            .map(|arenas| arenas.values().map(StreamArena::allocated_bytes).sum())
            .unwrap_or(0)
    }

    /// Currently registered stream ids, in unspecified (`HashMap`) order.
    ///
    /// Returns an empty `Vec` if the mutex is poisoned.
    pub fn registered_streams(&self) -> Vec<StreamId> {
        self.arenas
            .lock()
            .map(|arenas| arenas.keys().copied().collect())
            .unwrap_or_default()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two distinct streams can both be registered and are both reported back.
    #[test]
    fn test_stream_allocator_register() {
        let allocator = StreamAllocator::new(1024, 8192);
        let (first, second) = (StreamId::new(0), StreamId::new(1));
        allocator.register_stream(first).expect("register s0 failed");
        allocator.register_stream(second).expect("register s1 failed");

        let registered = allocator.registered_streams();
        assert!(registered.contains(&first), "s0 should be registered");
        assert!(registered.contains(&second), "s1 should be registered");
        assert_eq!(registered.len(), 2);
    }

    /// Registering the same stream twice is rejected.
    #[test]
    fn test_stream_allocator_register_duplicate() {
        let allocator = StreamAllocator::new(1024, 8192);
        let stream = StreamId::default_stream();
        allocator.register_stream(stream).expect("first register failed");

        let second_attempt = allocator.register_stream(stream);
        assert!(
            second_attempt.is_err(),
            "duplicate registration should return Err"
        );
    }

    /// Allocations on different streams yield distinct, non-null pointers and are counted.
    #[test]
    fn test_stream_allocator_allocate() {
        const MIB: usize = 1024 * 1024;
        let allocator = StreamAllocator::new(MIB, 8 * MIB);
        let (left, right) = (StreamId::new(10), StreamId::new(20));
        allocator.register_stream(left).expect("register s0");
        allocator.register_stream(right).expect("register s1");

        let ptr_left = allocator.allocate(left, 512).expect("allocate on s0");
        let ptr_right = allocator.allocate(right, 512).expect("allocate on s1");
        assert!(!ptr_left.is_null(), "s0 pointer should not be null");
        assert!(!ptr_right.is_null(), "s1 pointer should not be null");
        assert_ne!(
            ptr_left, ptr_right,
            "pointers from different streams should differ"
        );
        assert_eq!(allocator.total_allocated_bytes(), 512 + 512);
    }

    /// Exceeding the per-stream budget yields `OutOfMemory`.
    #[test]
    fn test_stream_allocator_overflow() {
        let allocator = StreamAllocator::new(256, 8192);
        let stream = StreamId::new(5);
        allocator.register_stream(stream).expect("register");
        allocator
            .allocate(stream, 200)
            .expect("first allocation should succeed");

        let result = allocator.allocate(stream, 200);
        let is_oom = matches!(result, Err(GpuError::OutOfMemory(_)));
        assert!(is_oom, "expected OutOfMemory, got {result:?}");
    }

    /// Exceeding the global budget across streams yields `OutOfMemory`.
    #[test]
    fn test_stream_allocator_global_overflow() {
        let allocator = StreamAllocator::new(200, 300);
        let (first, second) = (StreamId::new(0), StreamId::new(1));
        allocator.register_stream(first).expect("register s0");
        allocator.register_stream(second).expect("register s1");
        allocator.allocate(first, 200).expect("first allocation");

        let result = allocator.allocate(second, 200);
        let is_oom = matches!(result, Err(GpuError::OutOfMemory(_)));
        assert!(is_oom, "expected global OutOfMemory");
    }

    /// Resetting a stream returns its bytes to the pool.
    #[test]
    fn test_stream_allocator_reset() {
        let allocator = StreamAllocator::new(1024, 8192);
        let stream = StreamId::new(99);
        allocator.register_stream(stream).expect("register");
        allocator.allocate(stream, 512).expect("allocate");
        assert_eq!(allocator.total_allocated_bytes(), 512);

        allocator.reset_stream(stream);
        assert_eq!(
            allocator.total_allocated_bytes(),
            0,
            "reset should clear bytes"
        );
    }

    /// An unregistered stream disappears from the registry and cannot allocate.
    #[test]
    fn test_stream_allocator_unregister() {
        let allocator = StreamAllocator::new(1024, 8192);
        let stream = StreamId::new(7);
        allocator.register_stream(stream).expect("register");
        assert_eq!(allocator.registered_streams().len(), 1);

        allocator.unregister_stream(stream);
        let now_empty = allocator.registered_streams().is_empty();
        assert!(now_empty, "stream should be gone after unregister");

        let outcome = allocator.allocate(stream, 64);
        assert!(
            outcome.is_err(),
            "allocate on unregistered stream should fail"
        );
    }

    /// The default stream is id 0.
    #[test]
    fn test_stream_id_default() {
        let expected = StreamId(0);
        assert_eq!(StreamId::default_stream(), expected);
    }

    /// Zero-byte allocations are rejected.
    #[test]
    fn test_stream_allocator_zero_size_rejected() {
        let allocator = StreamAllocator::new(1024, 8192);
        let stream = StreamId::new(3);
        allocator.register_stream(stream).expect("register");

        let outcome = allocator.allocate(stream, 0);
        assert!(outcome.is_err(), "zero-size allocation should fail");
    }
}