use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};
use anyhow::{Result, anyhow};
use candle_core::{Device, Tensor as CandleTensor};
use ronn_core::{DataType, MemoryType, TensorBuffer};
use tracing::{debug, info, warn};
/// Coordinates per-device memory pools, peer-to-peer (P2P) connectivity data,
/// aggregate statistics and synchronization bookkeeping for a set of devices.
#[derive(Debug)]
pub struct MultiGpuMemoryManager {
/// One mutex-guarded memory pool per managed device, keyed by device id.
device_pools: HashMap<usize, Arc<Mutex<DeviceMemoryPool>>>,
/// Pairwise P2P capabilities between the managed devices.
p2p_matrix: Arc<RwLock<P2PConnectivityMatrix>>,
/// Statistics aggregated over all device pools.
global_stats: Arc<Mutex<GlobalMemoryStats>>,
/// Records synchronization events and performs stream synchronization.
sync_manager: Arc<SyncManager>,
/// Configuration supplied at construction time.
config: MultiGpuMemoryConfig,
}
/// Configuration for [`MultiGpuMemoryManager`].
#[derive(Debug, Clone)]
pub struct MultiGpuMemoryConfig {
/// When `false`, inter-device transfers always go through the host path.
pub enable_p2p: bool,
/// Capacity, in bytes, given to each device's memory pool.
pub pool_size_per_device: usize,
/// NOTE(review): not read anywhere in this file — confirm intended use.
pub enable_unified_memory: bool,
/// Strategy used for event recording and for `synchronize_all`.
pub sync_strategy: SyncStrategy,
/// NOTE(review): not read anywhere in this file; presumably an upper bound
/// (in bytes) for a single P2P transfer — confirm.
pub max_p2p_transfer_size: usize,
/// NOTE(review): not read anywhere in this file — confirm intended use.
pub enable_compression: bool,
}
impl Default for MultiGpuMemoryConfig {
    /// Default configuration: P2P and unified memory enabled, a 2 GiB pool per
    /// device, explicit synchronization, a 256 MiB P2P transfer limit and
    /// compression disabled.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;
        const GIB: usize = 1024 * MIB;

        Self {
            enable_p2p: true,
            pool_size_per_device: 2 * GIB,
            enable_unified_memory: true,
            sync_strategy: SyncStrategy::Explicit,
            max_p2p_transfer_size: 256 * MIB,
            enable_compression: false,
        }
    }
}
/// Selects how the manager records events and what `synchronize_all` does.
#[derive(Debug, Clone, Copy)]
pub enum SyncStrategy {
/// `synchronize_all` synchronizes every managed device one by one;
/// allocations do not record sync events.
Explicit,
/// `synchronize_all` performs no work; allocations record sync events.
Automatic,
/// `synchronize_all` delegates to the sync manager's stream
/// synchronization; allocations record sync events.
StreamBased,
}
/// Bookkeeping for the memory pool of a single device.
#[derive(Debug)]
pub struct DeviceMemoryPool {
/// Id of the device this pool belongs to.
device_id: usize,
/// Initialized empty; not otherwise used in this file.
/// NOTE(review): presumably intended as a free list — confirm.
available_blocks: Vec<MemoryBlock>,
/// Live allocations keyed by allocation id.
allocated_blocks: HashMap<usize, MemoryBlock>,
/// Usage counters for this pool.
stats: MemoryPoolStats,
/// Id handed to the next allocation; starts at 1 and grows by one each time.
next_alloc_id: usize,
}
/// Descriptor of one allocation made from a [`DeviceMemoryPool`].
#[derive(Debug, Clone)]
pub struct MemoryBlock {
/// Pool-local allocation id; used as the lookup key on deallocation.
alloc_id: usize,
/// Device whose pool produced this block.
device_id: usize,
/// Requested size in bytes.
size: usize,
/// Requested alignment, stored as given (not validated in this file).
alignment: usize,
/// Synthetic address: the pool sets this to `alloc_id * 0x1000`.
virtual_address: usize,
/// Always set to `true` by the pool's allocator in this file.
p2p_accessible: bool,
/// Element type the block was requested for.
data_type: DataType,
/// Initialized to 1 on allocation; not modified elsewhere in this file.
ref_count: usize,
}
/// Usage counters for a single device memory pool.
#[derive(Debug, Default, Clone)]
pub struct MemoryPoolStats {
/// Pool capacity in bytes.
pub total_size: usize,
/// Bytes currently allocated from the pool.
pub allocated_bytes: usize,
/// Bytes still available for allocation.
pub available_bytes: usize,
/// Number of live allocations.
pub active_allocations: usize,
/// Highest value `allocated_bytes` has reached.
pub peak_usage: usize,
/// Number of P2P transfers for which this device was the source.
pub p2p_transfers_out: u64,
/// Number of P2P transfers for which this device was the destination.
pub p2p_transfers_in: u64,
/// Bytes sent via P2P; only incremented on the source device's pool.
pub total_p2p_bytes: u64,
}
/// Pairwise peer-to-peer information between devices, keyed by the ordered
/// pair `(source device id, destination device id)`.
#[derive(Debug)]
pub struct P2PConnectivityMatrix {
/// Full capability record for each ordered device pair.
connectivity: HashMap<(usize, usize), P2PCapability>,
/// Copy of each pair's `bandwidth_gbps` value.
bandwidth_matrix: HashMap<(usize, usize), f64>,
/// Copy of each pair's `latency_us` value.
latency_matrix: HashMap<(usize, usize), f64>,
}
/// Peer-to-peer capability of one ordered device pair.
#[derive(Debug, Clone, Copy)]
pub struct P2PCapability {
/// Whether a direct P2P transfer is possible for this pair.
pub supported: bool,
/// Link bandwidth. The transfer-time estimate in this file multiplies it by
/// 1e9 and divides a byte count by the result.
/// NOTE(review): unit (gigabytes vs gigabits per second) — confirm.
pub bandwidth_gbps: f64,
/// Link latency in microseconds.
pub latency_us: f64,
/// Whether the link is classified as NVLink.
pub is_nvlink: bool,
}
/// Statistics aggregated over all device pools of a manager.
#[derive(Debug, Default, Clone)]
pub struct GlobalMemoryStats {
/// Sum of `total_size` over all pools.
pub total_memory: usize,
/// Sum of `allocated_bytes` over all pools.
pub allocated_memory: usize,
/// NOTE(review): despite the name, this file computes it as
/// `allocated_memory / total_memory * 100` (capped at 100), i.e. a
/// utilization percentage — confirm intent.
pub fragmentation_percent: f32,
/// Cross-device transfer count, derived from the pools' `p2p_transfers_out`.
pub cross_device_transfers: u64,
/// Transferred byte count, derived from the pools' `total_p2p_bytes`.
pub total_transfer_bytes: u64,
/// NOTE(review): never written in this file; stays at its default of 0.0.
pub avg_transfer_bandwidth_gbps: f64,
}
/// Stores synchronization events per device and performs stream
/// synchronization over them.
#[derive(Debug)]
pub struct SyncManager {
/// Recorded events, grouped by the device id carried in each event.
sync_events: Arc<Mutex<HashMap<usize, Vec<SyncEvent>>>>,
/// Initialized empty; not otherwise used in this file.
/// NOTE(review): presumably stream dependency edges — confirm.
stream_deps: Arc<Mutex<HashMap<usize, Vec<usize>>>>,
/// Strategy passed at construction; stored but not read in this file.
strategy: SyncStrategy,
}
/// A single recorded synchronization event.
#[derive(Debug, Clone)]
pub struct SyncEvent {
/// Identifier of the event; the start and completion records of one P2P
/// transfer share the same id.
pub event_id: usize,
/// Device the event is filed under.
pub device_id: usize,
/// What kind of operation the event describes.
pub event_type: SyncEventType,
/// Time at which the event record was created.
pub timestamp: std::time::Instant,
/// Whether the operation had finished when the record was created.
pub completed: bool,
}
/// Kind of operation described by a [`SyncEvent`].
///
/// Only `Allocation`, `TransferStart` and `TransferComplete` are constructed
/// in this file; the remaining variants are not used here.
#[derive(Debug, Clone, Copy)]
pub enum SyncEventType {
/// A memory allocation on a device.
Allocation,
/// Start of a P2P transfer, filed under the source device.
TransferStart,
/// Completion of a P2P transfer, filed under the destination device.
TransferComplete,
KernelStart,
KernelComplete,
DeviceSync,
}
impl MultiGpuMemoryManager {
    /// Builds a manager with one memory pool per entry in `device_ids`.
    ///
    /// Every pool gets `config.pool_size_per_device` bytes, the P2P
    /// connectivity matrix is discovered for the given devices and the global
    /// statistics are populated once so they are meaningful before the first
    /// allocation.
    ///
    /// # Errors
    ///
    /// Propagates any error from `P2PConnectivityMatrix::discover_connectivity`.
    pub fn new(device_ids: Vec<usize>, config: MultiGpuMemoryConfig) -> Result<Self> {
        let mut device_pools = HashMap::with_capacity(device_ids.len());
        for &device_id in &device_ids {
            let pool = DeviceMemoryPool::new(device_id, config.pool_size_per_device);
            device_pools.insert(device_id, Arc::new(Mutex::new(pool)));
            info!("Initialized memory pool for device {}", device_id);
        }

        let p2p_matrix = Arc::new(RwLock::new(P2PConnectivityMatrix::discover_connectivity(
            &device_ids,
        )?));
        let sync_manager = Arc::new(SyncManager::new(config.sync_strategy));
        let global_stats = Arc::new(Mutex::new(GlobalMemoryStats::default()));

        info!(
            "Created multi-GPU memory manager for {} devices",
            device_ids.len()
        );

        let manager = Self {
            device_pools,
            p2p_matrix,
            global_stats,
            sync_manager,
            config,
        };
        // Fill in the totals right away instead of waiting for the first allocation.
        manager.update_global_stats();
        Ok(manager)
    }

    /// Allocates `size` bytes from the pool of `device_id`.
    ///
    /// With the `Automatic` or `StreamBased` strategy an already-completed
    /// `Allocation` event is recorded as well.
    ///
    /// # Errors
    ///
    /// Fails when `device_id` is not managed or the pool cannot satisfy the
    /// request.
    pub fn allocate_on_device(
        &self,
        device_id: usize,
        size: usize,
        alignment: usize,
        data_type: DataType,
    ) -> Result<MemoryBlock> {
        let pool = self
            .device_pools
            .get(&device_id)
            .ok_or_else(|| anyhow!("Device {} not found", device_id))?;

        // The pool guard must be released before `update_global_stats`, which
        // locks every pool again: `std::sync::Mutex` is not re-entrant, so
        // holding the guard here would deadlock (or panic) on this very pool.
        let mut guard = pool.lock().unwrap();
        let block = guard.allocate(size, alignment, data_type)?;
        drop(guard);

        self.update_global_stats();

        if matches!(
            self.config.sync_strategy,
            SyncStrategy::Automatic | SyncStrategy::StreamBased
        ) {
            let event = SyncEvent {
                event_id: self.generate_event_id(),
                device_id,
                event_type: SyncEventType::Allocation,
                timestamp: std::time::Instant::now(),
                completed: true,
            };
            self.sync_manager.record_event(event);
        }

        debug!("Allocated {} bytes on device {}", size, device_id);
        Ok(block)
    }

    /// Returns `block` to the pool of the device it was allocated on.
    ///
    /// # Errors
    ///
    /// Fails when the block's device is not managed or the pool does not know
    /// the block's allocation id.
    pub fn deallocate(&self, block: MemoryBlock) -> Result<()> {
        let pool = self
            .device_pools
            .get(&block.device_id)
            .ok_or_else(|| anyhow!("Device {} not found", block.device_id))?;

        // Same lock discipline as in `allocate_on_device`: release the pool
        // before the global refresh re-locks it.
        let mut guard = pool.lock().unwrap();
        guard.deallocate(block)?;
        drop(guard);

        self.update_global_stats();
        Ok(())
    }

    /// Transfers `size` bytes of `src_block` to `dst_device_id` and returns the
    /// newly allocated destination block.
    ///
    /// Uses the P2P path when P2P is enabled and the connectivity matrix marks
    /// the pair as supported; otherwise goes through the host path. The matrix
    /// holds no entry for a device paired with itself, so same-device
    /// transfers take the host path.
    ///
    /// # Errors
    ///
    /// Propagates allocation failures on the destination device.
    pub fn transfer_between_devices(
        &self,
        src_block: &MemoryBlock,
        dst_device_id: usize,
        size: usize,
    ) -> Result<MemoryBlock> {
        let src_device_id = src_block.device_id;

        // Copy the capability out and drop the read guard immediately: the P2P
        // path reads the matrix again, and a recursive `RwLock::read` on the
        // same thread may panic or deadlock.
        let p2p_capability = self
            .p2p_matrix
            .read()
            .unwrap()
            .get_capability(src_device_id, dst_device_id);

        if self.config.enable_p2p && p2p_capability.supported {
            self.transfer_p2p(src_block, dst_device_id, size)
        } else {
            self.transfer_via_host(src_block, dst_device_id, size)
        }
    }

    /// Synchronizes according to the configured strategy: every device in turn
    /// for `Explicit`, nothing for `Automatic`, stream synchronization through
    /// the sync manager for `StreamBased`.
    pub fn synchronize_all(&self) -> Result<()> {
        debug!("Synchronizing all devices");
        match self.config.sync_strategy {
            SyncStrategy::Explicit => {
                for &device_id in self.device_pools.keys() {
                    self.synchronize_device(device_id)?;
                }
            }
            // Nothing to do here for the automatic strategy.
            SyncStrategy::Automatic => {}
            SyncStrategy::StreamBased => {
                self.sync_manager.synchronize_streams()?;
            }
        }
        info!("All devices synchronized");
        Ok(())
    }

    /// Returns a snapshot of every pool's statistics, keyed by device id.
    pub fn get_memory_stats(&self) -> HashMap<usize, MemoryPoolStats> {
        let mut stats = HashMap::with_capacity(self.device_pools.len());
        for (&device_id, pool) in &self.device_pools {
            let pool = pool.lock().unwrap();
            stats.insert(device_id, pool.stats.clone());
        }
        stats
    }

    /// Returns a copy of the aggregated statistics as of the last refresh.
    pub fn get_global_stats(&self) -> GlobalMemoryStats {
        self.global_stats.lock().unwrap().clone()
    }

    /// Returns a copy of the P2P capability table keyed by `(src, dst)`.
    pub fn get_p2p_info(&self) -> HashMap<(usize, usize), P2PCapability> {
        self.p2p_matrix.read().unwrap().connectivity.clone()
    }

    /// Assigns each allocation request in `access_pattern` to a device chosen
    /// by `select_optimal_device`. Nothing is allocated.
    pub fn optimize_memory_layout(&self, access_pattern: &AccessPattern) -> Result<MemoryLayout> {
        let mut layout = MemoryLayout::new();
        for allocation in &access_pattern.allocations {
            let optimal_device = self.select_optimal_device(allocation)?;
            layout.assignments.insert(allocation.id, optimal_device);
        }
        Ok(layout)
    }

    /// Simulated P2P copy: estimates the transfer time, allocates the
    /// destination block, sleeps for the estimate and records start/completion
    /// events plus per-pool transfer counters.
    fn transfer_p2p(
        &self,
        src_block: &MemoryBlock,
        dst_device_id: usize,
        size: usize,
    ) -> Result<MemoryBlock> {
        debug!(
            "P2P transfer from device {} to device {}",
            src_block.device_id, dst_device_id
        );

        // Do everything that can fail before recording the start event, so a
        // failed transfer never leaves a permanently incomplete event behind.
        let transfer_time = self.simulate_p2p_transfer(src_block.device_id, dst_device_id, size)?;
        let dst_block = self.allocate_on_device(
            dst_device_id,
            size,
            src_block.alignment,
            src_block.data_type,
        )?;

        let event_id = self.generate_event_id();
        self.sync_manager.record_event(SyncEvent {
            event_id,
            device_id: src_block.device_id,
            event_type: SyncEventType::TransferStart,
            timestamp: std::time::Instant::now(),
            completed: false,
        });

        std::thread::sleep(transfer_time);

        // The completion record reuses the start event's id and is filed under
        // the destination device.
        self.sync_manager.record_event(SyncEvent {
            event_id,
            device_id: dst_device_id,
            event_type: SyncEventType::TransferComplete,
            timestamp: std::time::Instant::now(),
            completed: true,
        });

        self.update_p2p_stats(src_block.device_id, dst_device_id, size);
        Ok(dst_block)
    }

    /// Simulated copy through host memory: allocates the destination block and
    /// then sleeps for `size / 1000` microseconds.
    fn transfer_via_host(
        &self,
        src_block: &MemoryBlock,
        dst_device_id: usize,
        size: usize,
    ) -> Result<MemoryBlock> {
        debug!(
            "Host transfer from device {} to device {}",
            src_block.device_id, dst_device_id
        );

        // Allocate first so an impossible transfer fails without paying the
        // simulated copy time.
        let dst_block = self.allocate_on_device(
            dst_device_id,
            size,
            src_block.alignment,
            src_block.data_type,
        )?;

        let transfer_time = std::time::Duration::from_micros(size as u64 / 1000);
        std::thread::sleep(transfer_time);

        Ok(dst_block)
    }

    /// Simulated per-device synchronization (a fixed 10 µs sleep).
    fn synchronize_device(&self, device_id: usize) -> Result<()> {
        debug!("Synchronizing device {}", device_id);
        std::thread::sleep(std::time::Duration::from_micros(10));
        Ok(())
    }

    /// Estimates the duration of a P2P transfer of `size` bytes as
    /// `size / (bandwidth_gbps * 1e9) + latency_us / 1e6` seconds.
    ///
    /// # Errors
    ///
    /// Fails when the estimate is not a valid duration, e.g. for an
    /// unsupported link whose bandwidth is 0 and latency is infinite.
    fn simulate_p2p_transfer(
        &self,
        src_device: usize,
        dst_device: usize,
        size: usize,
    ) -> Result<std::time::Duration> {
        let capability = self
            .p2p_matrix
            .read()
            .unwrap()
            .get_capability(src_device, dst_device);

        let bandwidth_bps = capability.bandwidth_gbps * 1_000_000_000.0;
        let transfer_time_s = size as f64 / bandwidth_bps;
        let latency_s = capability.latency_us / 1_000_000.0;
        let total_time_s = transfer_time_s + latency_s;

        // `Duration::from_secs_f64` panics on NaN, infinite, negative or
        // overflowing input; report those cases as errors instead.
        std::time::Duration::try_from_secs_f64(total_time_s).map_err(|err| {
            anyhow!(
                "Cannot estimate P2P transfer time from device {} to device {}: {}",
                src_device,
                dst_device,
                err
            )
        })
    }

    /// Recomputes the global statistics from scratch out of the pool
    /// statistics.
    ///
    /// Locks the global stats and then each pool in turn, so the caller must
    /// not hold any pool lock.
    fn update_global_stats(&self) {
        let mut global = self.global_stats.lock().unwrap();

        // Every aggregate is rebuilt from the pools, so each one has to start
        // from zero; otherwise the transfer counters would grow on every call.
        global.total_memory = 0;
        global.allocated_memory = 0;
        global.cross_device_transfers = 0;
        global.total_transfer_bytes = 0;

        for pool in self.device_pools.values() {
            let pool = pool.lock().unwrap();
            let pool_stats = &pool.stats;
            global.total_memory += pool_stats.total_size;
            global.allocated_memory += pool_stats.allocated_bytes;
            global.cross_device_transfers += pool_stats.p2p_transfers_out;
            global.total_transfer_bytes += pool_stats.total_p2p_bytes;
        }

        if global.total_memory > 0 {
            // NOTE(review): this is the used fraction of total memory, stored
            // in a field named `fragmentation_percent` — confirm intent.
            let used_percent = global.allocated_memory as f32 / global.total_memory as f32;
            global.fragmentation_percent = (used_percent * 100.0).min(100.0);
        }
    }

    /// Updates the P2P counters of the source and destination pools and then
    /// refreshes the global statistics so they include the transfer.
    fn update_p2p_stats(&self, src_device: usize, dst_device: usize, bytes: usize) {
        if let Some(src_pool) = self.device_pools.get(&src_device) {
            let mut pool = src_pool.lock().unwrap();
            pool.stats.p2p_transfers_out += 1;
            pool.stats.total_p2p_bytes += bytes as u64;
        }
        if let Some(dst_pool) = self.device_pools.get(&dst_device) {
            let mut pool = dst_pool.lock().unwrap();
            pool.stats.p2p_transfers_in += 1;
        }
        // Both pool guards are released at this point.
        self.update_global_stats();
    }

    /// Returns a process-wide unique, monotonically increasing event id.
    fn generate_event_id(&self) -> usize {
        static COUNTER: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
        COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
    }

    /// Picks, among the request's preferred devices that are managed here, the
    /// one with the most available bytes (first one wins on ties).
    ///
    /// Falls back to the first preferred device, or to device 0 when the list
    /// is empty, if no managed candidate has any available bytes.
    fn select_optimal_device(&self, allocation: &AllocationRequest) -> Result<usize> {
        let mut best_device = allocation.preferred_devices.first().copied().unwrap_or(0);
        let mut best_available = 0;
        for &device_id in &allocation.preferred_devices {
            if let Some(pool) = self.device_pools.get(&device_id) {
                let available = pool.lock().unwrap().stats.available_bytes;
                if available > best_available {
                    best_available = available;
                    best_device = device_id;
                }
            }
        }
        Ok(best_device)
    }
}
impl DeviceMemoryPool {
fn new(device_id: usize, total_size: usize) -> Self {
Self {
device_id,
available_blocks: vec![],
allocated_blocks: HashMap::new(),
stats: MemoryPoolStats {
total_size,
allocated_bytes: 0,
available_bytes: total_size,
active_allocations: 0,
peak_usage: 0,
p2p_transfers_out: 0,
p2p_transfers_in: 0,
total_p2p_bytes: 0,
},
next_alloc_id: 1,
}
}
fn allocate(
&mut self,
size: usize,
alignment: usize,
data_type: DataType,
) -> Result<MemoryBlock> {
if size > self.stats.available_bytes {
return Err(anyhow!(
"Insufficient memory on device {}: requested {}, available {}",
self.device_id,
size,
self.stats.available_bytes
));
}
let alloc_id = self.next_alloc_id;
self.next_alloc_id += 1;
let block = MemoryBlock {
alloc_id,
device_id: self.device_id,
size,
alignment,
virtual_address: alloc_id * 0x1000, p2p_accessible: true, data_type,
ref_count: 1,
};
self.allocated_blocks.insert(alloc_id, block.clone());
self.stats.allocated_bytes += size;
self.stats.available_bytes -= size;
self.stats.active_allocations += 1;
self.stats.peak_usage = self.stats.peak_usage.max(self.stats.allocated_bytes);
Ok(block)
}
fn deallocate(&mut self, block: MemoryBlock) -> Result<()> {
if let Some(stored_block) = self.allocated_blocks.remove(&block.alloc_id) {
self.stats.allocated_bytes -= stored_block.size;
self.stats.available_bytes += stored_block.size;
self.stats.active_allocations -= 1;
Ok(())
} else {
Err(anyhow!("Block not found for deallocation"))
}
}
}
impl P2PConnectivityMatrix {
    /// Builds the connectivity tables for every ordered pair of distinct
    /// devices in `device_ids`.
    ///
    /// No hardware is queried: pairs whose ids differ by exactly 1 are treated
    /// as NVLink (300 bandwidth, 1 µs latency), all other pairs get 16
    /// bandwidth and 5 µs latency. Every pair is marked as supported.
    fn discover_connectivity(device_ids: &[usize]) -> Result<Self> {
        let mut matrix = Self {
            connectivity: HashMap::new(),
            bandwidth_matrix: HashMap::new(),
            latency_matrix: HashMap::new(),
        };

        let ordered_pairs = device_ids
            .iter()
            .flat_map(|&src| device_ids.iter().map(move |&dst| (src, dst)))
            .filter(|(src, dst)| src != dst);

        for pair in ordered_pairs {
            let is_nvlink = pair.0.abs_diff(pair.1) == 1;
            let (bandwidth_gbps, latency_us) = if is_nvlink {
                (300.0, 1.0)
            } else {
                (16.0, 5.0)
            };
            let capability = P2PCapability {
                supported: true,
                bandwidth_gbps,
                latency_us,
                is_nvlink,
            };
            matrix.connectivity.insert(pair, capability);
            matrix.bandwidth_matrix.insert(pair, bandwidth_gbps);
            matrix.latency_matrix.insert(pair, latency_us);
        }

        Ok(matrix)
    }

    /// Looks up the capability for `(src, dst)`.
    ///
    /// Unknown pairs, including a device paired with itself, yield an
    /// unsupported capability with zero bandwidth and infinite latency.
    fn get_capability(&self, src: usize, dst: usize) -> P2PCapability {
        const UNSUPPORTED: P2PCapability = P2PCapability {
            supported: false,
            bandwidth_gbps: 0.0,
            latency_us: f64::INFINITY,
            is_nvlink: false,
        };

        match self.connectivity.get(&(src, dst)) {
            Some(capability) => *capability,
            None => UNSUPPORTED,
        }
    }
}
impl SyncManager {
    /// Creates a sync manager with empty event and dependency tables.
    fn new(strategy: SyncStrategy) -> Self {
        let sync_events = Arc::new(Mutex::new(HashMap::new()));
        let stream_deps = Arc::new(Mutex::new(HashMap::new()));
        Self {
            sync_events,
            stream_deps,
            strategy,
        }
    }

    /// Appends `event` to the list kept for its device.
    fn record_event(&self, event: SyncEvent) {
        let device_id = event.device_id;
        self.sync_events
            .lock()
            .unwrap()
            .entry(device_id)
            .or_default()
            .push(event);
    }

    /// Simulated stream synchronization: sleeps 1 µs for every recorded event
    /// that is not marked completed. The event table stays locked for the
    /// whole duration and the events themselves are not modified.
    fn synchronize_streams(&self) -> Result<()> {
        debug!("Synchronizing streams using events");
        let events = self.sync_events.lock().unwrap();
        let pending = events.values().flatten().filter(|event| !event.completed);
        for _ in pending {
            std::thread::sleep(std::time::Duration::from_micros(1));
        }
        Ok(())
    }
}
/// Input to `MultiGpuMemoryManager::optimize_memory_layout`.
#[derive(Debug)]
pub struct AccessPattern {
/// Allocation requests to place; each one is assigned a device.
pub allocations: Vec<AllocationRequest>,
/// NOTE(review): not read anywhere in this file — confirm intended use.
pub transfer_patterns: Vec<TransferPattern>,
}
/// A planned allocation to be placed on one of several candidate devices.
#[derive(Debug)]
pub struct AllocationRequest {
/// Identifier used as the key in the resulting `MemoryLayout`.
pub id: usize,
/// Requested size. NOTE(review): not consulted by device selection in this
/// file; presumably bytes — confirm.
pub size: usize,
/// Element type of the planned allocation; not read in this file.
pub data_type: DataType,
/// Candidate devices; selection picks the one with the most available
/// bytes and falls back to the first entry.
pub preferred_devices: Vec<usize>,
/// NOTE(review): not read anywhere in this file — confirm intended use.
pub access_frequency: f32,
}
/// Describes an expected transfer of an allocation to a device.
///
/// NOTE(review): no field of this type is read in this file; the field
/// descriptions below are inferred from the names — confirm.
#[derive(Debug)]
pub struct TransferPattern {
/// Presumably the `id` of the `AllocationRequest` being transferred.
pub src_allocation: usize,
/// Presumably the destination device id.
pub dst_device: usize,
/// Presumably how often the transfer occurs.
pub frequency: f32,
/// Presumably the transfer size in bytes.
pub size: usize,
}
/// Result of a layout optimization: which device each allocation is placed on.
#[derive(Debug)]
pub struct MemoryLayout {
    /// Maps an allocation request id to the device id chosen for it.
    pub assignments: HashMap<usize, usize>,
}

impl MemoryLayout {
    /// Creates a layout with no assignments.
    fn new() -> Self {
        let assignments = HashMap::new();
        Self { assignments }
    }
}