use super::memory_tracing::{
AllocationInfo, GpuMemoryTracker, MemoryEvent, MemoryReport, MemoryStats, MemoryTracingConfig,
};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
/// Summary statistics describing how fragmented a set of free memory
/// blocks is.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
pub struct FragmentationAnalysis {
    /// Sum of all free block sizes, in bytes.
    pub total_free: usize,
    /// Size of the single largest free block, in bytes.
    pub largest_free_block: usize,
    /// Number of free blocks analyzed.
    pub free_block_count: usize,
    /// Heuristic score in `[0.0, 1.0]`; higher means more fragmented.
    pub fragmentation_ratio: f64,
    /// Mean free block size (0 when there are no blocks).
    pub average_free_block_size: usize,
    /// Free bytes outside the largest block — memory that cannot be served
    /// by a single maximal contiguous allocation.
    pub wasted_memory: usize,
}

impl FragmentationAnalysis {
    /// Computes fragmentation statistics from a slice of free-block sizes.
    ///
    /// An empty slice yields an all-zero analysis.
    pub fn analyze(free_blocks: &[usize]) -> Self {
        let free_block_count = free_blocks.len();
        if free_block_count == 0 {
            return Self {
                total_free: 0,
                largest_free_block: 0,
                free_block_count: 0,
                fragmentation_ratio: 0.0,
                average_free_block_size: 0,
                wasted_memory: 0,
            };
        }
        let total_free: usize = free_blocks.iter().copied().sum();
        let largest_free_block = free_blocks.iter().copied().max().unwrap_or(0);
        // Count is non-zero here, so plain division is safe.
        let average_free_block_size = total_free / free_block_count;
        let fragmentation_ratio = if total_free == 0 {
            // Every block is zero-sized; nothing meaningful to score.
            0.0
        } else {
            // Share of the free space that lies outside the largest block.
            let scattered_ratio = 1.0 - largest_free_block as f64 / total_free as f64;
            // Share of blocks smaller than 5% of the largest block.
            let tiny_cutoff = largest_free_block / 20;
            let tiny_blocks = free_blocks.iter().filter(|&&size| size < tiny_cutoff).count();
            let block_factor = (tiny_blocks as f64 / free_block_count as f64).max(0.0);
            // Weighted blend, clamped to 1.0; many tiny blocks dominate the score.
            (scattered_ratio * 0.3 + block_factor * 0.7).min(1.0)
        };
        Self {
            total_free,
            largest_free_block,
            free_block_count,
            fragmentation_ratio,
            average_free_block_size,
            wasted_memory: total_free.saturating_sub(largest_free_block),
        }
    }

    /// True when the fragmentation ratio crosses the "Severe" threshold (0.5).
    pub fn is_severe(&self) -> bool {
        self.fragmentation_ratio > 0.5
    }

    /// Coarse human-readable label for the fragmentation ratio.
    pub fn severity_level(&self) -> &'static str {
        match self.fragmentation_ratio {
            r if r < 0.3 => "Low",
            r if r < 0.5 => "Moderate",
            _ => "Severe",
        }
    }
}
/// Result of scanning active allocations for suspected memory leaks.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
pub struct LeakDetectionResult {
    /// Allocations flagged as potential leaks.
    pub suspected_leaks: Vec<AllocationInfo>,
    /// Total bytes held by the suspected leaks.
    pub total_leaked_bytes: usize,
    /// Number of suspected leaks (equals `suspected_leaks.len()`).
    pub leak_count: usize,
    /// Suspected leaks grouped by the operation that allocated them.
    pub leaks_by_operation: HashMap<String, Vec<AllocationInfo>>,
    /// Heuristic confidence in `[0.0, 1.0]` that the suspects are real leaks.
    pub confidence: f64,
}

impl LeakDetectionResult {
    /// Aggregates a list of suspect allocations into a detection result.
    pub fn from_allocations(allocations: Vec<AllocationInfo>) -> Self {
        let total_leaked_bytes = allocations.iter().map(|a| a.size).sum();
        let leak_count = allocations.len();
        let mut leaks_by_operation: HashMap<String, Vec<AllocationInfo>> = HashMap::new();
        for alloc in &allocations {
            leaks_by_operation
                .entry(alloc.operation.clone())
                .or_default()
                .push(alloc.clone());
        }
        // Crude heuristic: many independent suspects make coincidence less
        // likely, so confidence jumps from 0.5 to 0.9 past 10 suspects.
        let confidence = if leak_count > 10 { 0.9 } else { 0.5 };
        Self {
            suspected_leaks: allocations,
            total_leaked_bytes,
            leak_count,
            leaks_by_operation,
            confidence,
        }
    }

    /// True when at least one suspected leak was found.
    pub fn has_leaks(&self) -> bool {
        !self.suspected_leaks.is_empty()
    }
}
/// Memory profile for a single named operation, built from its currently
/// active allocations.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
pub struct OperationProfile {
    /// Operation name the allocations were grouped by.
    pub operation: String,
    /// Total bytes across the operation's active allocations.
    pub total_allocated: usize,
    /// Approximated by `total_allocated`; no allocation timeline is kept.
    pub peak_usage: usize,
    /// Number of active allocations for this operation.
    pub allocation_count: usize,
    /// Mean allocation size (0 when there are no allocations).
    pub average_size: usize,
    /// Reserved for timing data; never populated by [`OperationProfile::new`].
    #[cfg_attr(feature = "serialize", serde(skip))]
    pub total_time: Option<Duration>,
    /// 1.0 = perfectly uniform allocation sizes, approaching 0.0 = highly varied.
    pub efficiency_score: f64,
}

impl OperationProfile {
    /// Builds a profile for `operation` from its active allocations.
    ///
    /// The efficiency score is `1 - mean_absolute_deviation / mean`, clamped
    /// to `[0, 1]`: uniform sizes score 1.0, widely varying sizes score low.
    pub fn new(operation: String, allocations: &[&AllocationInfo]) -> Self {
        let total_allocated: usize = allocations.iter().map(|a| a.size).sum();
        let allocation_count = allocations.len();
        let average_size = total_allocated.checked_div(allocation_count).unwrap_or(0);
        // Without a timeline of alloc/free events, the current total is the
        // best available estimate of peak usage.
        let peak_usage = total_allocated;
        // Guard on `average_size > 0` as well: previously, all-zero-sized
        // allocations produced `0.0 / 0.0 = NaN`, and `NaN.min(1.0)` returns
        // 1.0, silently scoring 0.0 instead of the correct "uniform" 1.0.
        let efficiency_score = if allocation_count > 0 && average_size > 0 {
            let mean_abs_deviation = allocations
                .iter()
                .map(|a| (a.size as f64 - average_size as f64).abs())
                .sum::<f64>()
                / allocation_count as f64;
            let normalized_deviation = (mean_abs_deviation / average_size as f64).min(1.0);
            1.0 - normalized_deviation
        } else {
            // No allocations, or all zero-sized: sizes are trivially uniform.
            1.0
        };
        Self {
            operation,
            total_allocated,
            peak_usage,
            allocation_count,
            average_size,
            total_time: None,
            efficiency_score,
        }
    }
}
/// Snapshot of GPU memory health produced by a diagnostics run.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
pub struct DiagnosticReport {
    // `Instant` is not serializable, so it is skipped when serializing.
    #[cfg_attr(feature = "serialize", serde(skip))]
    pub timestamp: Instant,
    /// Global allocation counters captured from the tracker.
    pub memory_stats: MemoryStats,
    /// Present only when fragmentation analysis is enabled and there were
    /// active allocations at the time of the run.
    pub fragmentation: Option<FragmentationAnalysis>,
    /// Suspected-leak scan results (may be empty).
    pub leak_detection: LeakDetectionResult,
    /// Per-operation profiles; empty when profiling is disabled.
    pub operation_profiles: Vec<OperationProfile>,
    /// `(operation name, bytes)` pairs, largest first, at most 10 entries.
    pub top_consumers: Vec<(String, usize)>,
    /// Human-readable tuning suggestions; never empty after a run.
    pub recommendations: Vec<String>,
}
impl DiagnosticReport {
    /// Pretty-prints the report to stdout in a human-readable layout.
    ///
    /// Sections with no data (no fragmentation analysis, no leaks, no
    /// profiles, no recommendations) are omitted entirely. All byte counts
    /// are displayed in MiB (divisor 1_048_576).
    pub fn print(&self) {
        println!("\n╔══════════════════════════════════════════════════════╗");
        println!("║ GPU Memory Diagnostic Report ║");
        println!("╚══════════════════════════════════════════════════════╝");
        // Global counters are always shown.
        println!("\n📊 Memory Statistics:");
        println!(
            " Current Usage: {:.2} MB",
            self.memory_stats.total_allocated as f64 / 1_048_576.0
        );
        println!(
            " Peak Usage: {:.2} MB",
            self.memory_stats.peak_usage as f64 / 1_048_576.0
        );
        println!(
            " Active Allocs: {}",
            self.memory_stats.active_allocations
        );
        println!(
            " Total Allocs: {}",
            self.memory_stats.total_allocations_lifetime
        );
        println!(
            " Total Frees: {}",
            self.memory_stats.total_frees_lifetime
        );
        // Fragmentation section: only present when analysis ran.
        if let Some(ref frag) = self.fragmentation {
            println!("\n🔍 Fragmentation Analysis:");
            println!(
                " Severity: {} ({:.1}%)",
                frag.severity_level(),
                frag.fragmentation_ratio * 100.0
            );
            println!(" Free Blocks: {}", frag.free_block_count);
            println!(
                " Largest Block: {:.2} MB",
                frag.largest_free_block as f64 / 1_048_576.0
            );
            println!(
                " Wasted Memory: {:.2} MB",
                frag.wasted_memory as f64 / 1_048_576.0
            );
        }
        // Leak section: only printed when at least one suspect exists.
        if self.leak_detection.has_leaks() {
            println!("\n⚠️ Memory Leak Detection:");
            println!(" Suspected Leaks: {}", self.leak_detection.leak_count);
            println!(
                " Leaked Memory: {:.2} MB",
                self.leak_detection.total_leaked_bytes as f64 / 1_048_576.0
            );
            println!(
                " Confidence: {:.0}%",
                self.leak_detection.confidence * 100.0
            );
            println!("\n Leaks by Operation:");
            // HashMap iteration: the per-operation lines print in arbitrary order.
            for (op, leaks) in &self.leak_detection.leaks_by_operation {
                let total: usize = leaks.iter().map(|l| l.size).sum();
                println!(
                    " {} - {:.2} MB ({} allocations)",
                    op,
                    total as f64 / 1_048_576.0,
                    leaks.len()
                );
            }
        }
        // Top consumers are pre-sorted (largest first); show at most five.
        println!("\n🎯 Top Memory Consumers:");
        for (i, (op, size)) in self.top_consumers.iter().enumerate().take(5) {
            println!(
                " {}. {} - {:.2} MB",
                i + 1,
                op,
                *size as f64 / 1_048_576.0
            );
        }
        if !self.operation_profiles.is_empty() {
            println!("\n📈 Operation Profiles:");
            for profile in self.operation_profiles.iter().take(5) {
                println!(
                    " {} - {:.2} MB (efficiency: {:.0}%)",
                    profile.operation,
                    profile.total_allocated as f64 / 1_048_576.0,
                    profile.efficiency_score * 100.0
                );
            }
        }
        if !self.recommendations.is_empty() {
            println!("\n💡 Recommendations:");
            for (i, rec) in self.recommendations.iter().enumerate() {
                println!(" {}. {}", i + 1, rec);
            }
        }
        println!("\n═══════════════════════════════════════════════════════\n");
    }

    /// Serializes the report as pretty-printed JSON.
    ///
    /// Fields marked `serde(skip)` (e.g. the timestamp) are omitted.
    ///
    /// # Errors
    /// Returns any error produced by `serde_json` during serialization.
    #[cfg(feature = "serialize")]
    pub fn to_json(&self) -> Result<String, serde_json::Error> {
        serde_json::to_string_pretty(self)
    }
}
/// Diagnostics engine that analyzes a shared GPU memory tracker for leaks,
/// fragmentation, and per-operation usage patterns.
pub struct GpuMemoryDiagnostics {
    // Shared with the code performing the actual allocations; locked per query.
    tracker: Arc<Mutex<GpuMemoryTracker>>,
    // Thresholds and feature toggles for diagnostic runs.
    config: DiagnosticsConfig,
}
/// Tuning knobs for [`GpuMemoryDiagnostics`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
pub struct DiagnosticsConfig {
    /// Minimum age before an active allocation is flagged as a suspected leak.
    #[cfg_attr(feature = "serialize", serde(skip))]
    pub leak_detection_threshold: Duration,
    // NOTE(review): not consulted in this file — presumably drives a
    // background diagnostics loop elsewhere; confirm against callers.
    pub auto_diagnostics: bool,
    /// Interval between automatic diagnostic runs, when enabled.
    #[cfg_attr(feature = "serialize", serde(skip))]
    pub diagnostics_interval: Duration,
    /// Whether to estimate fragmentation from active allocation sizes.
    pub analyze_fragmentation: bool,
    /// Whether to build per-operation memory profiles.
    pub enable_profiling: bool,
}

impl Default for DiagnosticsConfig {
    /// Defaults: 5-minute leak threshold, automatic diagnostics off,
    /// 60-second interval, fragmentation analysis and profiling enabled.
    fn default() -> Self {
        Self {
            leak_detection_threshold: Duration::from_secs(300),
            auto_diagnostics: false,
            diagnostics_interval: Duration::from_secs(60),
            analyze_fragmentation: true,
            enable_profiling: true,
        }
    }
}
impl GpuMemoryDiagnostics {
    /// Creates a diagnostics engine over `tracker` with the default config.
    pub fn new(tracker: Arc<Mutex<GpuMemoryTracker>>) -> Self {
        Self::with_config(tracker, DiagnosticsConfig::default())
    }

    /// Creates a diagnostics engine with an explicit configuration.
    pub fn with_config(tracker: Arc<Mutex<GpuMemoryTracker>>, config: DiagnosticsConfig) -> Self {
        Self { tracker, config }
    }

    /// Locks the shared tracker.
    ///
    /// # Panics
    /// Panics if the mutex is poisoned (a thread panicked while holding it).
    fn lock_tracker(&self) -> std::sync::MutexGuard<'_, GpuMemoryTracker> {
        self.tracker.lock().expect("lock should not be poisoned")
    }

    /// Collects suspected leaks from an already-locked tracker.
    /// Shared by `run_diagnostics` and `check_for_leaks`.
    fn detect_leaks(&self, tracker: &GpuMemoryTracker) -> LeakDetectionResult {
        let suspected_leaks: Vec<AllocationInfo> = tracker
            .find_potential_leaks(self.config.leak_detection_threshold)
            .into_iter()
            .cloned()
            .collect();
        LeakDetectionResult::from_allocations(suspected_leaks)
    }

    /// Runs a full diagnostic pass and returns the report.
    ///
    /// The tracker lock is held for the entire pass so all sections of the
    /// report describe one consistent snapshot.
    pub fn run_diagnostics(&self) -> DiagnosticReport {
        let tracker = self.lock_tracker();
        let memory_stats = tracker.global_stats().clone();
        let leak_detection = self.detect_leaks(&tracker);
        let operation_profiles = if self.config.enable_profiling {
            self.profile_operations(&tracker)
        } else {
            Vec::new()
        };
        // Rank operations by current usage; keep only the ten largest.
        let usage_by_op = tracker.usage_by_operation();
        let mut top_consumers: Vec<_> = usage_by_op.into_iter().collect();
        top_consumers.sort_by_key(|item| std::cmp::Reverse(item.1));
        top_consumers.truncate(10);
        let recommendations =
            self.generate_recommendations(&memory_stats, &leak_detection, &operation_profiles);
        // Fragmentation is approximated from active allocation sizes; the
        // tracker does not expose an actual free-block list here.
        let fragmentation = if self.config.analyze_fragmentation {
            let active = tracker.active_allocations();
            let sizes: Vec<usize> = active.values().map(|a| a.size).collect();
            if sizes.is_empty() {
                None
            } else {
                Some(FragmentationAnalysis::analyze(&sizes))
            }
        } else {
            None
        };
        DiagnosticReport {
            timestamp: Instant::now(),
            memory_stats,
            fragmentation,
            leak_detection,
            operation_profiles,
            top_consumers,
            recommendations,
        }
    }

    /// Groups active allocations by operation name and builds one profile per
    /// group, sorted by total bytes allocated (descending).
    fn profile_operations(&self, tracker: &GpuMemoryTracker) -> Vec<OperationProfile> {
        let active = tracker.active_allocations();
        let mut by_operation: HashMap<String, Vec<&AllocationInfo>> = HashMap::new();
        for alloc in active.values() {
            by_operation
                .entry(alloc.operation.clone())
                .or_default()
                .push(alloc);
        }
        let mut profiles: Vec<_> = by_operation
            .into_iter()
            .map(|(op, allocs)| OperationProfile::new(op, &allocs))
            .collect();
        profiles.sort_by_key(|item| std::cmp::Reverse(item.total_allocated));
        profiles
    }

    /// Produces human-readable tuning suggestions from the gathered data.
    /// Always returns at least one entry.
    fn generate_recommendations(
        &self,
        stats: &MemoryStats,
        leak_detection: &LeakDetectionResult,
        profiles: &[OperationProfile],
    ) -> Vec<String> {
        let mut recommendations = Vec::new();
        // More than 1 GiB currently allocated.
        if stats.total_allocated > 1_073_741_824 {
            recommendations.push(
                "High memory usage detected. Consider reducing batch sizes or using gradient checkpointing.".to_string()
            );
        }
        if leak_detection.has_leaks() && leak_detection.confidence > 0.7 {
            recommendations.push(format!(
                "Potential memory leaks detected in {} operations. Review deallocation logic.",
                leak_detection.leaks_by_operation.len()
            ));
        }
        // Only the three largest consumers are worth calling out.
        for profile in profiles.iter().take(3) {
            if profile.efficiency_score < 0.5 {
                recommendations.push(format!(
                    "Operation '{}' has low memory efficiency ({:.0}%). Consider batching or pooling.",
                    profile.operation,
                    profile.efficiency_score * 100.0
                ));
            }
        }
        // Many sub-1-KiB allocations suggest allocator churn.
        if stats.active_allocations > 1000 && stats.average_allocation_size < 1024 {
            recommendations.push(
                "High number of small allocations. Consider using memory pooling or buffer reuse."
                    .to_string(),
            );
        }
        if recommendations.is_empty() {
            recommendations
                .push("Memory usage looks healthy. No major issues detected.".to_string());
        }
        recommendations
    }

    /// Scans the tracker for suspected leaks without building a full report.
    pub fn check_for_leaks(&self) -> LeakDetectionResult {
        let tracker = self.lock_tracker();
        self.detect_leaks(&tracker)
    }

    /// Bytes currently allocated according to the tracker.
    pub fn current_usage(&self) -> usize {
        self.lock_tracker().current_usage()
    }

    /// Peak bytes allocated according to the tracker.
    pub fn peak_usage(&self) -> usize {
        self.lock_tracker().peak_usage()
    }

    /// Resets the underlying tracker's state.
    pub fn reset(&self) {
        self.lock_tracker().reset();
    }
}
lazy_static::lazy_static! {
    /// Process-wide diagnostics engine bound to the global GPU memory tracker.
    /// Initialized lazily on first access.
    pub static ref GLOBAL_GPU_DIAGNOSTICS: GpuMemoryDiagnostics = {
        GpuMemoryDiagnostics::new(
            Arc::clone(&super::memory_tracing::GLOBAL_GPU_MEMORY_TRACKER)
        )
    };
}
/// Runs a full diagnostic pass against the global GPU memory tracker.
pub fn run_gpu_diagnostics() -> DiagnosticReport {
    GLOBAL_GPU_DIAGNOSTICS.run_diagnostics()
}
/// Runs global diagnostics and pretty-prints the report to stdout.
pub fn print_gpu_diagnostics() {
    run_gpu_diagnostics().print();
}
/// Scans the global GPU memory tracker for suspected leaks.
pub fn check_gpu_memory_leaks() -> LeakDetectionResult {
    GLOBAL_GPU_DIAGNOSTICS.check_for_leaks()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fragmentation_analysis() {
        let free_blocks = vec![1000, 500, 250, 125];
        let analysis = FragmentationAnalysis::analyze(&free_blocks);
        assert_eq!(analysis.total_free, 1875);
        assert_eq!(analysis.largest_free_block, 1000);
        assert_eq!(analysis.free_block_count, 4);
        // One dominant block plus a few mid-sized ones: some fragmentation,
        // but not maximal.
        assert!(analysis.fragmentation_ratio > 0.0);
        assert!(analysis.fragmentation_ratio < 1.0);
    }

    #[test]
    fn test_fragmentation_severity() {
        // Similar-sized large blocks: low fragmentation.
        let low_frag = vec![1000, 900, 800];
        let analysis = FragmentationAnalysis::analyze(&low_frag);
        assert!(!analysis.is_severe());
        // One big block plus many tiny ones: severe fragmentation.
        let high_frag = vec![1000, 10, 10, 10, 10];
        let analysis = FragmentationAnalysis::analyze(&high_frag);
        assert!(analysis.is_severe());
    }

    #[test]
    fn test_leak_detection_result() {
        let alloc1 = AllocationInfo::new(1, 1024, 0, "op1".to_string());
        let alloc2 = AllocationInfo::new(2, 2048, 0, "op2".to_string());
        let result = LeakDetectionResult::from_allocations(vec![alloc1, alloc2]);
        assert_eq!(result.leak_count, 2);
        // 1024 + 2048 bytes.
        assert_eq!(result.total_leaked_bytes, 3072);
        assert!(result.has_leaks());
    }

    #[test]
    fn test_diagnostics_engine() {
        let tracker = Arc::new(Mutex::new(GpuMemoryTracker::new()));
        let diagnostics = GpuMemoryDiagnostics::new(tracker.clone());
        // Record allocations in a scoped block so the lock is released
        // before run_diagnostics() tries to acquire it (avoids deadlock).
        {
            let mut t = tracker.lock().expect("lock should not be poisoned");
            t.record_allocation(1024, 0, "test_op".to_string());
            t.record_allocation(2048, 0, "test_op".to_string());
        }
        let report = diagnostics.run_diagnostics();
        assert!(report.memory_stats.active_allocations > 0);
        // generate_recommendations always yields at least one entry.
        assert!(!report.recommendations.is_empty());
    }
}