use super::pools::{MemoryPool, MemoryPoolStats};
use crate::{Result, TensorError};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
/// Owns one [`MemoryPool`] per stream and tracks which stream each operation
/// has been assigned to.
///
/// A stream ID is simply an index into the internal pool list.
pub struct MultiStreamMemoryManager {
    /// One pool per stream; `pools[stream_id]` is the pool of that stream.
    pools: Vec<MemoryPool>,
    /// Explicit operation ID -> stream ID assignments.
    stream_assignment: Arc<Mutex<HashMap<usize, usize>>>,
    /// Stream handed out next, round-robin, to operations that have no
    /// explicit assignment.
    current_stream: Arc<Mutex<usize>>,
}
impl MultiStreamMemoryManager {
#[cfg(feature = "gpu")]
pub fn new(device_id: usize, num_streams: usize, pool_size_per_stream: usize) -> Result<Self> {
let mut pools = Vec::new();
for _ in 0..num_streams {
pools.push(MemoryPool::new(device_id, pool_size_per_stream)?);
}
Ok(Self {
pools,
stream_assignment: Arc::new(Mutex::new(HashMap::new())),
current_stream: Arc::new(Mutex::new(0)),
})
}
pub fn get_pool(&self, operation_id: usize) -> Result<&MemoryPool> {
let stream_assignment = self
.stream_assignment
.lock()
.expect("lock should not be poisoned");
let stream_id = if let Some(&stream_id) = stream_assignment.get(&operation_id) {
stream_id
} else {
let mut current_stream = self
.current_stream
.lock()
.expect("lock should not be poisoned");
let stream_id = *current_stream;
*current_stream = (*current_stream + 1) % self.pools.len();
stream_id
};
self.pools
.get(stream_id)
.ok_or_else(|| TensorError::invalid_argument(format!("Invalid stream ID: {stream_id}")))
}
pub fn assign_operation_to_stream(&self, operation_id: usize, stream_id: usize) -> Result<()> {
if stream_id >= self.pools.len() {
return Err(TensorError::invalid_argument(format!(
"Stream ID {} out of range. Available streams: {}",
stream_id,
self.pools.len()
)));
}
let mut stream_assignment = self
.stream_assignment
.lock()
.expect("lock should not be poisoned");
stream_assignment.insert(operation_id, stream_id);
Ok(())
}
pub fn unassign_operation(&self, operation_id: usize) {
let mut stream_assignment = self
.stream_assignment
.lock()
.expect("lock should not be poisoned");
stream_assignment.remove(&operation_id);
}
pub fn get_operation_stream(&self, operation_id: usize) -> Option<usize> {
let stream_assignment = self
.stream_assignment
.lock()
.expect("lock should not be poisoned");
stream_assignment.get(&operation_id).copied()
}
pub fn num_streams(&self) -> usize {
self.pools.len()
}
pub fn get_pool_by_stream(&self, stream_id: usize) -> Result<&MemoryPool> {
self.pools
.get(stream_id)
.ok_or_else(|| TensorError::invalid_argument(format!("Invalid stream ID: {stream_id}")))
}
pub fn stats(&self) -> Vec<MemoryPoolStats> {
self.pools.iter().map(|pool| pool.stats()).collect()
}
pub fn stream_stats(&self, stream_id: usize) -> Result<MemoryPoolStats> {
self.pools
.get(stream_id)
.map(|pool| pool.stats())
.ok_or_else(|| TensorError::invalid_argument(format!("Invalid stream ID: {stream_id}")))
}
pub fn total_memory_usage(&self) -> (usize, usize) {
let mut total_allocated = 0;
let mut total_free = 0;
for pool in &self.pools {
let stats = pool.stats();
total_allocated += stats.total_allocated;
total_free += stats.total_free;
}
(total_allocated, total_free)
}
pub fn get_least_loaded_stream(&self) -> usize {
let mut min_load = usize::MAX;
let mut best_stream = 0;
for (i, pool) in self.pools.iter().enumerate() {
let stats = pool.stats();
if stats.total_allocated < min_load {
min_load = stats.total_allocated;
best_stream = i;
}
}
best_stream
}
pub fn get_stream_with_most_free_memory(&self) -> usize {
let mut max_free = 0;
let mut best_stream = 0;
for (i, pool) in self.pools.iter().enumerate() {
let stats = pool.stats();
if stats.total_free > max_free {
max_free = stats.total_free;
best_stream = i;
}
}
best_stream
}
pub fn balance_streams(&self) -> Result<usize> {
let mut reassignments = 0;
let target_load = {
let (total_allocated, _) = self.total_memory_usage();
total_allocated / self.pools.len()
};
let mut stream_assignment = self
.stream_assignment
.lock()
.expect("lock should not be poisoned");
let mut overloaded_streams = Vec::new();
let mut underloaded_streams = Vec::new();
for (i, pool) in self.pools.iter().enumerate() {
let stats = pool.stats();
if stats.total_allocated > target_load * 11 / 10 {
overloaded_streams.push(i);
} else if stats.total_allocated < target_load * 9 / 10 {
underloaded_streams.push(i);
}
}
let operations_to_reassign: Vec<_> = stream_assignment
.iter()
.filter(|(_, &stream_id)| overloaded_streams.contains(&stream_id))
.map(|(&op_id, &stream_id)| (op_id, stream_id))
.collect();
for (op_id, _old_stream) in operations_to_reassign {
if let Some(&new_stream) = underloaded_streams.first() {
stream_assignment.insert(op_id, new_stream);
reassignments += 1;
underloaded_streams.rotate_left(1);
}
}
Ok(reassignments)
}
pub fn generate_streams_report(&self) -> String {
let mut report = String::new();
report.push_str("=== Multi-Stream Memory Manager Report ===\n\n");
let (total_allocated, total_free) = self.total_memory_usage();
report.push_str(&format!(
"Total Memory - Allocated: {} bytes, Free: {} bytes\n",
total_allocated, total_free
));
report.push_str(&format!("Number of Streams: {}\n\n", self.pools.len()));
for (i, pool) in self.pools.iter().enumerate() {
let stats = pool.stats();
report.push_str(&format!("Stream {}:\n", i));
report.push_str(&format!(" Allocated: {} bytes\n", stats.total_allocated));
report.push_str(&format!(" Free: {} bytes\n", stats.total_free));
report.push_str(&format!(" Blocks Allocated: {}\n", stats.blocks_allocated));
report.push_str(&format!(" Blocks Free: {}\n", stats.blocks_free));
report.push_str(&format!(
" Fragmentation Ratio: {:.2}\n",
stats.fragmentation_ratio
));
report.push_str(&format!(
" Memory Pressure: {:.2}%\n",
stats.memory_pressure * 100.0
));
report.push('\n');
}
let stream_assignment = self
.stream_assignment
.lock()
.expect("lock should not be poisoned");
if !stream_assignment.is_empty() {
report.push_str("Operation Assignments:\n");
for (op_id, stream_id) in stream_assignment.iter() {
report.push_str(&format!(" Operation {}: Stream {}\n", op_id, stream_id));
}
}
report
}
pub fn clear_assignments(&self) {
let mut stream_assignment = self
.stream_assignment
.lock()
.expect("lock should not be poisoned");
stream_assignment.clear();
}
pub fn get_operation_counts(&self) -> Vec<usize> {
let stream_assignment = self
.stream_assignment
.lock()
.expect("lock should not be poisoned");
let mut counts = vec![0; self.pools.len()];
for &stream_id in stream_assignment.values() {
if stream_id < counts.len() {
counts[stream_id] += 1;
}
}
counts
}
pub fn are_streams_balanced(&self, tolerance_percent: f32) -> bool {
let (total_allocated, _) = self.total_memory_usage();
if total_allocated == 0 {
return true; }
let target_load = total_allocated / self.pools.len();
let tolerance = (target_load as f32 * tolerance_percent / 100.0) as usize;
for pool in &self.pools {
let stats = pool.stats();
let deviation = stats.total_allocated.abs_diff(target_load);
if deviation > tolerance {
return false;
}
}
true
}
}
#[cfg(test)]
mod tests {
    //! These tests re-create the manager's arithmetic on plain values; they do
    //! not construct a `MultiStreamMemoryManager`.
    use super::*;

    /// Round-robin hand-out cycles through the stream IDs in order.
    #[test]
    fn test_stream_assignment() {
        let num_streams = 3;
        let mut assignments: HashMap<usize, usize> = HashMap::new();
        let mut next_stream = 0;
        for op_id in 0..6 {
            assignments.insert(op_id, next_stream);
            next_stream = (next_stream + 1) % num_streams;
        }
        for (op_id, stream_id) in assignments {
            assert_eq!(stream_id, op_id % num_streams);
        }
    }

    /// Loads outside a 10% band around the target count as over/underloaded.
    #[test]
    fn test_load_balancing_logic() {
        let target_load = 1000;
        let tolerance = target_load / 10;
        let (lower, upper) = (target_load - tolerance, target_load + tolerance);

        let overloaded = 1200;
        assert!(overloaded > upper);

        let underloaded = 800;
        assert!(underloaded < lower);

        let balanced = 950;
        assert!((lower..=upper).contains(&balanced));
    }

    /// A deviation larger than the percentage tolerance marks a stream as unbalanced.
    #[test]
    fn test_stream_balancing_calculation() {
        let (total_allocated, num_streams) = (3000usize, 3usize);
        let target_load = total_allocated / num_streams;
        assert_eq!(target_load, 1000);

        let tolerance_percent = 10.0_f32;
        let tolerance = (target_load as f32 * tolerance_percent / 100.0) as usize;
        assert_eq!(tolerance, 100);

        let stream_load: usize = 1150;
        let deviation = stream_load.abs_diff(target_load);
        assert_eq!(deviation, 150);
        assert!(deviation > tolerance);
    }

    /// Counting assignments per stream yields one counter per stream ID.
    #[test]
    fn test_operation_count_tracking() {
        let assignments: [(usize, usize); 5] = [(1, 0), (2, 1), (3, 0), (4, 2), (5, 1)];
        let mut counts = vec![0; 3];
        for &(_, stream_id) in &assignments {
            if let Some(count) = counts.get_mut(stream_id) {
                *count += 1;
            }
        }
        assert_eq!(counts, [2, 2, 1]);
    }

    /// Allocated and free amounts are summed independently across streams.
    #[test]
    fn test_memory_usage_aggregation() {
        let stream_stats: [(usize, usize); 3] = [(500, 1500), (800, 1200), (300, 1700)];
        let (total_allocated, total_free) = stream_stats
            .iter()
            .fold((0usize, 0usize), |(allocated, free), &(a, f)| {
                (allocated + a, free + f)
            });
        assert_eq!(total_allocated, 1600);
        assert_eq!(total_free, 4400);
    }

    /// The stream with the smallest load is selected.
    #[test]
    fn test_least_loaded_stream_selection() {
        let stream_loads = [1200usize, 800, 1000];
        let (best_stream, min_load) = stream_loads
            .iter()
            .copied()
            .enumerate()
            .min_by_key(|&(_, load)| load)
            .unwrap_or((0, usize::MAX));
        assert_eq!(best_stream, 1);
        assert_eq!(min_load, 800);
    }

    /// The stream with the most free memory is selected.
    #[test]
    fn test_most_free_memory_selection() {
        let stream_free_memory = [500usize, 1200, 800];
        let (best_stream, max_free) = stream_free_memory.iter().copied().enumerate().fold(
            (0usize, 0usize),
            |(best, max), (stream_id, free)| {
                if free > max {
                    (stream_id, free)
                } else {
                    (best, max)
                }
            },
        );
        assert_eq!(best_stream, 1);
        assert_eq!(max_free, 1200);
    }
}