use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
/// Identifier assigned to each recorded allocation.
pub type AllocationId = u64;

/// Metadata describing a single tracked allocation.
///
/// When the `serialize` feature is enabled the timing fields are skipped by
/// serde; on deserialization `allocated_at` is re-initialised with
/// `Instant::now()` and the other skipped fields fall back to their defaults.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
pub struct AllocationInfo {
    /// Identifier of this allocation.
    pub id: AllocationId,
    /// Size of the allocation in bytes.
    pub size: usize,
    /// Device the allocation belongs to.
    pub device_id: usize,
    /// Moment this record was created.
    #[cfg_attr(feature = "serialize", serde(skip, default = "Instant::now"))]
    pub allocated_at: Instant,
    /// Optional stack trace text; `new` leaves it unset.
    pub stack_trace: Option<String>,
    /// Name of the operation that requested the allocation.
    pub operation: String,
    /// Optional shape associated with the allocation.
    pub shape: Option<Vec<usize>>,
    /// Optional data-type name associated with the allocation.
    pub dtype: Option<String>,
    /// `true` until `mark_freed` is called.
    pub is_active: bool,
    /// Moment the allocation was marked as freed, if it has been.
    #[cfg_attr(feature = "serialize", serde(skip))]
    pub freed_at: Option<Instant>,
    /// Time between `allocated_at` and `freed_at`, set by `mark_freed`.
    #[cfg_attr(feature = "serialize", serde(skip))]
    pub lifetime: Option<Duration>,
    /// Free-form key/value labels.
    pub tags: HashMap<String, String>,
}

impl AllocationInfo {
    /// Creates an active allocation record stamped with the current instant.
    ///
    /// All optional metadata (`stack_trace`, `shape`, `dtype`, `tags`) starts
    /// out empty.
    pub fn new(id: AllocationId, size: usize, device_id: usize, operation: String) -> Self {
        Self {
            id,
            size,
            device_id,
            allocated_at: Instant::now(),
            stack_trace: None,
            operation,
            shape: None,
            dtype: None,
            is_active: true,
            freed_at: None,
            lifetime: None,
            tags: HashMap::new(),
        }
    }

    /// Returns the record with `shape` set.
    #[must_use]
    pub fn with_shape(mut self, shape: Vec<usize>) -> Self {
        self.shape = Some(shape);
        self
    }

    /// Returns the record with `dtype` set.
    #[must_use]
    pub fn with_dtype(mut self, dtype: String) -> Self {
        self.dtype = Some(dtype);
        self
    }

    /// Returns the record with the tag `key` set to `value`, replacing any
    /// previous value stored under the same key.
    #[must_use]
    pub fn with_tag(mut self, key: String, value: String) -> Self {
        self.tags.insert(key, value);
        self
    }

    /// Marks the allocation as freed and records when that happened.
    ///
    /// `freed_at` and `lifetime` are derived from a single clock reading so
    /// that `lifetime` always equals `freed_at - allocated_at`, which is also
    /// what `age` reports for a freed allocation.
    pub fn mark_freed(&mut self) {
        let now = Instant::now();
        self.is_active = false;
        self.freed_at = Some(now);
        // Saturating: `allocated_at` is a public field and may have been set
        // to a later instant by the caller.
        self.lifetime = Some(now.saturating_duration_since(self.allocated_at));
    }

    /// Returns how long the allocation has existed.
    ///
    /// For a freed allocation this is the fixed span between allocation and
    /// free; otherwise it is the time elapsed since allocation.
    pub fn age(&self) -> Duration {
        match self.freed_at {
            Some(freed_at) => freed_at.saturating_duration_since(self.allocated_at),
            None => self.allocated_at.elapsed(),
        }
    }
}
/// Aggregate memory counters; used for the tracker's global statistics and
/// embedded in each `DeviceMemoryStats`.
///
/// Every field starts at zero (`None` for `smallest_allocation`) via `Default`.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
pub struct MemoryStats {
/// Bytes currently allocated (allocated minus freed).
pub total_allocated: usize,
/// Cumulative bytes ever allocated.
pub total_allocated_lifetime: usize,
/// Cumulative bytes ever freed.
pub total_freed_lifetime: usize,
/// Number of allocations that are currently live.
pub active_allocations: usize,
/// Cumulative number of allocations recorded.
pub total_allocations_lifetime: u64,
/// Cumulative number of frees recorded.
pub total_frees_lifetime: u64,
/// Highest value `total_allocated` has reached, in bytes.
pub peak_usage: usize,
/// Average size of an allocation, in bytes.
pub average_allocation_size: usize,
/// Size in bytes of the largest single allocation.
pub largest_allocation: usize,
/// Size in bytes of the smallest single allocation; `None` by default.
pub smallest_allocation: Option<usize>,
}
/// Memory statistics for a single device, plus per-operation breakdowns.
///
/// The tracker adds to the per-operation maps when an allocation is recorded;
/// the deallocation path visible in this file does not subtract from them.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
pub struct DeviceMemoryStats {
/// Device these statistics describe.
pub device_id: usize,
/// Aggregate counters for this device.
pub stats: MemoryStats,
/// Number of allocations recorded per operation name.
pub allocations_by_operation: HashMap<String, usize>,
/// Bytes allocated per operation name.
pub memory_by_operation: HashMap<String, usize>,
}
/// An entry in the tracker's event log.
#[derive(Debug, Clone)]
pub enum MemoryEvent {
/// An allocation was recorded.
Allocated {
// Identifier assigned to the allocation.
id: AllocationId,
// Size of the allocation in bytes.
size: usize,
// Device the allocation was made on.
device_id: usize,
// Name of the operation that requested it.
operation: String,
},
/// A previously recorded allocation was freed.
Freed {
// Identifier of the allocation that was freed.
id: AllocationId,
// Size of the freed allocation in bytes.
size: usize,
// Device the allocation lived on.
device_id: usize,
// How long the allocation was live before being freed.
lifetime: Duration,
},
/// An out-of-memory condition was reported to the tracker.
OutOfMemory {
// Device on which the request failed.
device_id: usize,
// Size that was requested.
requested_size: usize,
// Size reported as available at the time.
available_size: usize,
},
}
/// Settings that control what the GPU memory tracker records.
#[derive(Debug, Clone)]
pub struct MemoryTracingConfig {
    /// Master switch: when `false`, `record_allocation` returns `0` without
    /// recording anything and `record_deallocation` does nothing.
    pub enabled: bool,
    /// Whether stack traces should be captured.
    /// NOTE(review): no code visible in this file reads this flag — confirm
    /// whether it is consumed elsewhere.
    pub capture_stack_traces: bool,
    /// Threshold on the number of freed-allocation records kept in history;
    /// exceeding it triggers removal of the oldest entries.
    pub max_tracked_allocations: usize,
    /// Whether freed allocations are appended to the history.
    pub record_history: bool,
    /// Whether allocation, free and out-of-memory events are appended to the
    /// event log.
    pub log_events: bool,
}

impl Default for MemoryTracingConfig {
    /// Tracking and history on; stack traces and event logging off; history
    /// threshold of 100 000 records.
    fn default() -> Self {
        const DEFAULT_MAX_TRACKED: usize = 100_000;

        MemoryTracingConfig {
            max_tracked_allocations: DEFAULT_MAX_TRACKED,
            enabled: true,
            record_history: true,
            capture_stack_traces: false,
            log_events: false,
        }
    }
}
pub struct GpuMemoryTracker {
config: MemoryTracingConfig,
next_id: AllocationId,
active_allocations: HashMap<AllocationId, AllocationInfo>,
historical_allocations: Vec<AllocationInfo>,
events: Vec<MemoryEvent>,
device_stats: HashMap<usize, DeviceMemoryStats>,
global_stats: MemoryStats,
}
impl GpuMemoryTracker {
pub fn new() -> Self {
Self::with_config(MemoryTracingConfig::default())
}
pub fn with_config(config: MemoryTracingConfig) -> Self {
Self {
config,
next_id: 0,
active_allocations: HashMap::new(),
historical_allocations: Vec::new(),
events: Vec::new(),
device_stats: HashMap::new(),
global_stats: MemoryStats::default(),
}
}
pub fn record_allocation(
&mut self,
size: usize,
device_id: usize,
operation: String,
) -> AllocationId {
if !self.config.enabled {
return 0;
}
let id = self.next_id;
self.next_id += 1;
let info = AllocationInfo::new(id, size, device_id, operation.clone());
self.global_stats.total_allocated += size;
self.global_stats.total_allocated_lifetime += size;
self.global_stats.active_allocations += 1;
self.global_stats.total_allocations_lifetime += 1;
if self.global_stats.total_allocated > self.global_stats.peak_usage {
self.global_stats.peak_usage = self.global_stats.total_allocated;
}
if size > self.global_stats.largest_allocation {
self.global_stats.largest_allocation = size;
}
if let Some(smallest) = self.global_stats.smallest_allocation {
if size < smallest {
self.global_stats.smallest_allocation = Some(size);
}
} else {
self.global_stats.smallest_allocation = Some(size);
}
let device_stats =
self.device_stats
.entry(device_id)
.or_insert_with(|| DeviceMemoryStats {
device_id,
stats: MemoryStats::default(),
allocations_by_operation: HashMap::new(),
memory_by_operation: HashMap::new(),
});
device_stats.stats.total_allocated += size;
device_stats.stats.total_allocated_lifetime += size;
device_stats.stats.active_allocations += 1;
device_stats.stats.total_allocations_lifetime += 1;
*device_stats
.allocations_by_operation
.entry(operation.clone())
.or_insert(0) += 1;
*device_stats
.memory_by_operation
.entry(operation.clone())
.or_insert(0) += size;
if self.config.log_events {
self.events.push(MemoryEvent::Allocated {
id,
size,
device_id,
operation,
});
}
self.active_allocations.insert(id, info);
id
}
pub fn record_deallocation(&mut self, id: AllocationId) {
if !self.config.enabled {
return;
}
if let Some(mut info) = self.active_allocations.remove(&id) {
let size = info.size;
let device_id = info.device_id;
info.mark_freed();
self.global_stats.total_allocated -= size;
self.global_stats.total_freed_lifetime += size;
self.global_stats.active_allocations -= 1;
self.global_stats.total_frees_lifetime += 1;
if let Some(device_stats) = self.device_stats.get_mut(&device_id) {
device_stats.stats.total_allocated -= size;
device_stats.stats.total_freed_lifetime += size;
device_stats.stats.active_allocations -= 1;
device_stats.stats.total_frees_lifetime += 1;
}
if self.config.log_events {
self.events.push(MemoryEvent::Freed {
id,
size,
device_id,
lifetime: info.lifetime.unwrap_or(Duration::ZERO),
});
}
if self.config.record_history {
self.historical_allocations.push(info);
if self.historical_allocations.len() > self.config.max_tracked_allocations {
self.historical_allocations
.drain(0..self.config.max_tracked_allocations / 10);
}
}
}
}
pub fn record_oom(&mut self, device_id: usize, requested_size: usize, available_size: usize) {
if self.config.log_events {
self.events.push(MemoryEvent::OutOfMemory {
device_id,
requested_size,
available_size,
});
}
}
pub fn track_allocation(
&mut self,
size: usize,
device_id: usize,
operation: String,
shape: Option<Vec<usize>>,
dtype: Option<String>,
) -> AllocationId {
let id = self.record_allocation(size, device_id, operation);
if let Some(info) = self.active_allocations.get_mut(&id) {
if let Some(s) = shape {
info.shape = Some(s);
}
if let Some(dt) = dtype {
info.dtype = Some(dt);
}
}
id
}
pub fn track_free(&mut self, id: AllocationId) {
self.record_deallocation(id);
}
pub fn current_usage(&self) -> usize {
self.global_stats.total_allocated
}
pub fn peak_usage(&self) -> usize {
self.global_stats.peak_usage
}
pub fn global_stats(&self) -> &MemoryStats {
&self.global_stats
}
pub fn device_stats(&self, device_id: usize) -> Option<&DeviceMemoryStats> {
self.device_stats.get(&device_id)
}
pub fn all_device_stats(&self) -> &HashMap<usize, DeviceMemoryStats> {
&self.device_stats
}
pub fn active_allocations(&self) -> &HashMap<AllocationId, AllocationInfo> {
&self.active_allocations
}
pub fn events(&self) -> &[MemoryEvent] {
&self.events
}
pub fn find_potential_leaks(&self, age_threshold: Duration) -> Vec<&AllocationInfo> {
self.active_allocations
.values()
.filter(|info| info.age() > age_threshold)
.collect()
}
pub fn usage_by_operation(&self) -> HashMap<String, usize> {
let mut result = HashMap::new();
for info in self.active_allocations.values() {
*result.entry(info.operation.clone()).or_insert(0) += info.size;
}
result
}
pub fn generate_report(&self) -> MemoryReport {
let mut allocations_by_size: Vec<_> = self.active_allocations.values().collect();
allocations_by_size.sort_by_key(|item| std::cmp::Reverse(item.size));
let top_allocations: Vec<_> = allocations_by_size.into_iter().take(10).cloned().collect();
MemoryReport {
global_stats: self.global_stats.clone(),
device_stats: self.device_stats.clone(),
top_allocations,
usage_by_operation: self.usage_by_operation(),
potential_leaks: self
.find_potential_leaks(Duration::from_secs(300))
.into_iter()
.cloned()
.collect(),
}
}
pub fn reset(&mut self) {
self.active_allocations.clear();
self.historical_allocations.clear();
self.events.clear();
self.device_stats.clear();
self.global_stats = MemoryStats::default();
self.next_id = 0;
}
}
/// Snapshot of tracker state produced by `GpuMemoryTracker::generate_report`.
#[derive(Debug, Clone)]
pub struct MemoryReport {
    /// Copy of the tracker's global statistics.
    pub global_stats: MemoryStats,
    /// Copy of the per-device statistics, keyed by device id.
    pub device_stats: HashMap<usize, DeviceMemoryStats>,
    /// Largest live allocations, biggest first.
    pub top_allocations: Vec<AllocationInfo>,
    /// Bytes held by live allocations, per operation name.
    pub usage_by_operation: HashMap<String, usize>,
    /// Live allocations flagged as potential leaks.
    pub potential_leaks: Vec<AllocationInfo>,
}

impl MemoryReport {
    /// Writes a human-readable summary of the report to standard output.
    ///
    /// Sizes are shown in units of 1 048 576 bytes, labelled "MB". The leak
    /// section is printed only when there is at least one potential leak.
    pub fn print(&self) {
        const BYTES_PER_UNIT: f64 = 1_048_576.0;
        let to_mb = |bytes: usize| bytes as f64 / BYTES_PER_UNIT;
        let totals = &self.global_stats;

        println!("=== GPU Memory Usage Report ===");

        println!("\nGlobal Statistics:");
        println!(" Current Allocation: {:.2} MB", to_mb(totals.total_allocated));
        println!(" Peak Usage: {:.2} MB", to_mb(totals.peak_usage));
        println!(" Active Allocations: {}", totals.active_allocations);
        println!(" Total Allocations: {}", totals.total_allocations_lifetime);
        println!(" Total Frees: {}", totals.total_frees_lifetime);

        println!("\nTop 10 Allocations:");
        for (index, entry) in self.top_allocations.iter().enumerate() {
            let rank = index + 1;
            println!(
                " {}: {:.2} MB - {} (age: {:.2}s)",
                rank,
                to_mb(entry.size),
                entry.operation,
                entry.age().as_secs_f64()
            );
        }

        println!("\nMemory by Operation:");
        let mut per_operation: Vec<(&String, &usize)> = self.usage_by_operation.iter().collect();
        // Stable sort, largest byte count first.
        per_operation.sort_by_key(|&(_, bytes)| std::cmp::Reverse(*bytes));
        for &(name, bytes) in per_operation.iter().take(10) {
            println!(" {}: {:.2} MB", name, to_mb(*bytes));
        }

        if !self.potential_leaks.is_empty() {
            println!("\n⚠️ Potential Memory Leaks:");
            for suspect in self.potential_leaks.iter() {
                println!(
                    " {} - {:.2} MB (age: {:.2}s)",
                    suspect.operation,
                    to_mb(suspect.size),
                    suspect.age().as_secs_f64()
                );
            }
        }

        println!("\n=============================");
    }
}
lazy_static::lazy_static! {
/// Process-wide tracker shared by the free functions in this file.
///
/// Built with `GpuMemoryTracker::new()`, i.e. the default configuration,
/// and guarded by a mutex.
pub static ref GLOBAL_GPU_MEMORY_TRACKER: Arc<Mutex<GpuMemoryTracker>> = {
Arc::new(Mutex::new(GpuMemoryTracker::new()))
};
}
/// Records an allocation on the global tracker and returns its identifier.
///
/// # Panics
///
/// Panics if the global tracker's mutex is poisoned.
pub fn record_gpu_allocation(size: usize, device_id: usize, operation: String) -> AllocationId {
    let mut tracker = GLOBAL_GPU_MEMORY_TRACKER
        .lock()
        .expect("GPU memory tracker mutex poisoned");
    tracker.record_allocation(size, device_id, operation)
}
/// Records on the global tracker that allocation `id` was freed.
///
/// # Panics
///
/// Panics if the global tracker's mutex is poisoned.
pub fn record_gpu_deallocation(id: AllocationId) {
    let mut tracker = GLOBAL_GPU_MEMORY_TRACKER
        .lock()
        .expect("GPU memory tracker mutex poisoned");
    tracker.record_deallocation(id);
}
/// Returns the global tracker's current usage in bytes.
///
/// # Panics
///
/// Panics if the global tracker's mutex is poisoned.
pub fn current_gpu_memory_usage() -> usize {
    let tracker = GLOBAL_GPU_MEMORY_TRACKER
        .lock()
        .expect("lock should not be poisoned");
    tracker.current_usage()
}
/// Returns the global tracker's peak usage in bytes.
///
/// # Panics
///
/// Panics if the global tracker's mutex is poisoned.
pub fn peak_gpu_memory_usage() -> usize {
    let tracker = GLOBAL_GPU_MEMORY_TRACKER
        .lock()
        .expect("lock should not be poisoned");
    tracker.peak_usage()
}
/// Builds a report from the global tracker's current state.
///
/// # Panics
///
/// Panics if the global tracker's mutex is poisoned.
pub fn generate_gpu_memory_report() -> MemoryReport {
    let tracker = GLOBAL_GPU_MEMORY_TRACKER
        .lock()
        .expect("lock should not be poisoned");
    tracker.generate_report()
}
/// Generates a report from the global tracker and prints it to standard
/// output. The report is built first, then printed.
pub fn print_gpu_memory_report() {
    let report = generate_gpu_memory_report();
    report.print();
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Usage and the live-allocation count follow every allocation and free.
    #[test]
    fn test_allocation_tracking() {
        let mut gpu = GpuMemoryTracker::new();

        let first = gpu.record_allocation(1024, 0, "test_op1".to_string());
        assert_eq!(gpu.current_usage(), 1024);
        assert_eq!(gpu.global_stats().active_allocations, 1);

        let second = gpu.record_allocation(2048, 0, "test_op2".to_string());
        assert_eq!(gpu.current_usage(), 3072);
        assert_eq!(gpu.global_stats().active_allocations, 2);

        gpu.record_deallocation(first);
        assert_eq!(gpu.current_usage(), 2048);
        assert_eq!(gpu.global_stats().active_allocations, 1);

        gpu.record_deallocation(second);
        assert_eq!(gpu.current_usage(), 0);
        assert_eq!(gpu.global_stats().active_allocations, 0);
    }

    /// Peak usage is retained after everything has been freed.
    #[test]
    fn test_peak_tracking() {
        let mut gpu = GpuMemoryTracker::new();
        let small = gpu.record_allocation(1024, 0, "test".to_string());
        let large = gpu.record_allocation(2048, 0, "test".to_string());
        assert_eq!(gpu.peak_usage(), 3072);

        gpu.record_deallocation(small);
        gpu.record_deallocation(large);
        assert_eq!(gpu.peak_usage(), 3072);
        assert_eq!(gpu.current_usage(), 0);
    }

    /// Live bytes are summed per operation name.
    #[test]
    fn test_usage_by_operation() {
        let mut gpu = GpuMemoryTracker::new();
        gpu.record_allocation(1024, 0, "op_a".to_string());
        gpu.record_allocation(2048, 0, "op_a".to_string());
        gpu.record_allocation(512, 0, "op_b".to_string());

        let per_op = gpu.usage_by_operation();
        assert_eq!(per_op.get("op_a"), Some(&3072));
        assert_eq!(per_op.get("op_b"), Some(&512));
    }
}