use crate::error::{RusTorchError, RusTorchResult};
use std::collections::{BTreeMap, HashMap};
use std::sync::{Arc, Mutex, RwLock};
use std::thread;
use std::time::{Duration, Instant, SystemTime};
/// One tracked allocation as stored by `MemoryAnalytics`.
#[derive(Debug, Clone)]
pub struct AllocationRecord {
    /// Identifier handed out by `MemoryAnalytics::record_allocation` (IDs start at 1).
    pub id: u64,
    /// Caller-supplied size; reports label it as bytes.
    pub size: usize,
    /// Wall-clock time at which the allocation was recorded.
    pub allocated_at: SystemTime,
    /// Caller-supplied origin of the allocation (e.g. `"test.rs:10"`); used to group hotspots.
    pub source_location: String,
    /// Call-stack frames. NOTE(review): currently fixed placeholder frames, not a real backtrace.
    pub call_stack: Vec<String>,
    /// Time the allocation was reported freed; `None` while it is still live.
    pub deallocated_at: Option<SystemTime>,
    /// `deallocated_at - allocated_at`, or `None` if still live (or the clock went backwards).
    pub lifetime: Option<Duration>,
    /// Lifetime category: provisionally `MediumLived`, reclassified on deallocation.
    pub pattern: AllocationPattern,
}
/// Coarse lifetime category of an allocation record.
///
/// Used as the key of `MemoryReport::pattern_distribution`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AllocationPattern {
    /// Freed less than 1 s after allocation.
    ShortLived,
    /// Freed after 1 s but before 60 s; also the provisional category of a live allocation.
    MediumLived,
    /// Freed 60 s or more after allocation.
    LongLived,
    /// Still live long enough to be treated as a potential leak.
    Leaked,
    /// NOTE(review): not produced anywhere in this file; presumably reserved for
    /// reference-cycle detection — confirm before relying on it.
    Cyclic,
}
impl AllocationPattern {
pub fn classify(lifetime: Option<Duration>, allocated_at: SystemTime) -> Self {
match lifetime {
Some(duration) => {
if duration < Duration::from_secs(1) {
AllocationPattern::ShortLived
} else if duration < Duration::from_secs(60) {
AllocationPattern::MediumLived
} else {
AllocationPattern::LongLived
}
}
None => {
let age = SystemTime::now()
.duration_since(allocated_at)
.unwrap_or(Duration::from_secs(0));
if age > Duration::from_secs(300) {
AllocationPattern::Leaked
} else {
AllocationPattern::MediumLived
}
}
}
}
}
/// Aggregated allocation statistics for one `source_location`.
#[derive(Debug, Clone)]
pub struct MemoryHotspot {
    /// The `source_location` string shared by the aggregated records.
    pub location: String,
    /// Number of retained records from this location.
    pub total_allocations: usize,
    /// Sum of the sizes of those records.
    pub total_memory: usize,
    /// `total_memory / total_allocations`.
    pub avg_size: f64,
    /// NOTE(review): `generate_report` currently fills this with `total_allocations`;
    /// true concurrency is not tracked.
    pub peak_concurrent: usize,
    /// Records from this location that are live and older than the leak threshold.
    pub leak_count: usize,
    /// NOTE(review): currently `total_allocations / 3600.0`, not derived from real timestamps.
    pub frequency: f64,
}
/// Snapshot of allocation analytics produced by `MemoryAnalytics::generate_report`.
#[derive(Debug, Clone)]
pub struct MemoryReport {
    /// When the report was built.
    pub generated_at: SystemTime,
    /// All-time number of recorded allocations.
    pub total_allocations: usize,
    /// All-time number of recorded deallocations.
    pub total_deallocations: usize,
    /// Retained records that have not been deallocated.
    pub active_allocations: usize,
    /// Sum of the sizes of all retained records (pruned records are not included).
    pub total_allocated_bytes: usize,
    /// Bytes currently allocated and not yet freed.
    pub current_memory_usage: usize,
    /// Highest value `current_memory_usage` has reached.
    pub peak_memory_usage: usize,
    /// Mean allocation size in bytes.
    pub avg_allocation_size: f64,
    /// Potential-leak summary.
    pub leak_stats: LeakStats,
    /// Number of retained records per lifetime category.
    pub pattern_distribution: HashMap<AllocationPattern, usize>,
    /// Busiest source locations, largest `total_memory` first.
    pub hotspots: Vec<MemoryHotspot>,
    /// Fragmentation estimate (currently placeholder values).
    pub fragmentation_analysis: FragmentationAnalysis,
}
/// Summary of live allocations older than the configured leak threshold.
#[derive(Debug, Clone)]
pub struct LeakStats {
    /// Number of live records older than `AnalyticsConfig::leak_threshold`.
    pub potential_leaks: usize,
    /// Sum of the sizes of those records.
    pub leaked_bytes: usize,
    /// Age of the oldest such record, if any.
    pub oldest_leak_age: Option<Duration>,
    /// NOTE(review): currently `potential_leaks / 3600.0` and displayed as "/hour";
    /// not computed from an actual observation window.
    pub leak_rate: f64,
}
/// Memory fragmentation estimate.
///
/// NOTE(review): `generate_report` fills this with fixed placeholder values
/// (ratio 0.1, 8 pools, 0.75 utilization, wasted = total_allocated / 20).
#[derive(Debug, Clone)]
pub struct FragmentationAnalysis {
    /// Fraction (0.0–1.0) of memory considered fragmented; displayed as a percentage.
    pub fragmentation_ratio: f64,
    /// Number of memory pools considered.
    pub pool_count: usize,
    /// Mean pool utilization as a fraction (0.0–1.0); displayed as a percentage.
    pub avg_pool_utilization: f64,
    /// Bytes estimated to be wasted.
    pub wasted_memory: usize,
}
/// Tuning parameters for `MemoryAnalytics`.
#[derive(Clone, Debug)]
pub struct AnalyticsConfig {
    /// Record count above which deallocated records are pruned down toward half this value.
    pub max_records: usize,
    /// Whether to attach a call stack to each record (currently placeholder frames).
    pub enable_stack_trace: bool,
    /// Intended maximum number of frames. NOTE(review): not read by the code in this file.
    pub stack_trace_depth: usize,
    /// Live allocations older than this are reported as potential leaks.
    pub leak_threshold: Duration,
    /// Intended reporting period. NOTE(review): not read by the code in this file.
    pub report_interval: Duration,
    /// Whether `generate_report` computes hotspots.
    pub enable_hotspot_analysis: bool,
    /// Minimum number of records at one location for it to be reported as a hotspot.
    pub hotspot_threshold: usize,
}
impl Default for AnalyticsConfig {
fn default() -> Self {
Self {
max_records: 100_000,
enable_stack_trace: true,
stack_trace_depth: 10,
leak_threshold: Duration::from_secs(300), report_interval: Duration::from_secs(60), enable_hotspot_analysis: true,
hotspot_threshold: 10,
}
}
}
/// Tracker of allocation/deallocation events that can build a `MemoryReport`.
///
/// All mutable state sits behind its own `Mutex`/`RwLock`, so every method takes `&self`.
pub struct MemoryAnalytics {
    /// Parameters supplied at construction.
    config: AnalyticsConfig,
    /// Allocation records keyed by allocation ID.
    records: RwLock<HashMap<u64, AllocationRecord>>,
    /// Next allocation ID to hand out (starts at 1).
    next_id: Mutex<u64>,
    /// Sum of sizes of allocations not yet deallocated.
    current_usage: RwLock<usize>,
    /// Highest value `current_usage` has reached.
    peak_usage: RwLock<usize>,
    /// Running counters and timing overhead.
    stats: RwLock<AnalyticsStats>,
    /// NOTE(review): not read or written by the code in this file; presumably
    /// reserved for a background analysis worker.
    analysis_thread: Mutex<Option<thread::JoinHandle<()>>>,
    /// Flag toggled by `start_analysis` / `stop_analysis`.
    running: Arc<RwLock<bool>>,
}
/// Running counters maintained by `MemoryAnalytics`.
#[derive(Debug, Clone)]
pub struct AnalyticsStats {
    /// Successful `record_allocation` calls.
    pub total_allocations: usize,
    /// Successful `record_deallocation` calls.
    pub total_deallocations: usize,
    /// Successful `generate_report` calls.
    pub reports_generated: usize,
    /// When the most recent report finished, if any.
    pub last_report_time: Option<SystemTime>,
    /// Cumulative time spent inside the recording and reporting methods.
    pub analysis_overhead: Duration,
}
impl Default for AnalyticsStats {
fn default() -> Self {
Self {
total_allocations: 0,
total_deallocations: 0,
reports_generated: 0,
last_report_time: None,
analysis_overhead: Duration::from_millis(0),
}
}
}
impl MemoryAnalytics {
pub fn new(config: AnalyticsConfig) -> Self {
Self {
config,
records: RwLock::new(HashMap::new()),
next_id: Mutex::new(1),
current_usage: RwLock::new(0),
peak_usage: RwLock::new(0),
stats: RwLock::new(AnalyticsStats::default()),
analysis_thread: Mutex::new(None),
running: Arc::new(RwLock::new(false)),
}
}
pub fn record_allocation(&self, size: usize, source_location: String) -> RusTorchResult<u64> {
let start_time = Instant::now();
let id = {
let mut next_id = self
.next_id
.lock()
.map_err(|_| RusTorchError::MemoryError("Failed to acquire ID lock".to_string()))?;
let id = *next_id;
*next_id += 1;
id
};
let mut call_stack = Vec::new();
if self.config.enable_stack_trace {
call_stack.push("tensor::core::Tensor::new".to_string());
call_stack.push("main".to_string());
}
let record = AllocationRecord {
id,
size,
allocated_at: SystemTime::now(),
source_location,
call_stack,
deallocated_at: None,
lifetime: None,
pattern: AllocationPattern::MediumLived, };
{
let mut current = self.current_usage.write().map_err(|_| {
RusTorchError::MemoryError("Failed to acquire usage write lock".to_string())
})?;
*current += size;
let mut peak = self.peak_usage.write().map_err(|_| {
RusTorchError::MemoryError("Failed to acquire peak write lock".to_string())
})?;
if *current > *peak {
*peak = *current;
}
}
{
let mut records = self.records.write().map_err(|_| {
RusTorchError::MemoryError("Failed to acquire records write lock".to_string())
})?;
records.insert(id, record);
if records.len() > self.config.max_records {
Self::cleanup_old_records(&mut records, self.config.max_records / 2);
}
}
{
let mut stats = self.stats.write().map_err(|_| {
RusTorchError::MemoryError("Failed to acquire stats write lock".to_string())
})?;
stats.total_allocations += 1;
stats.analysis_overhead += start_time.elapsed();
}
Ok(id)
}
pub fn record_deallocation(&self, id: u64) -> RusTorchResult<()> {
let start_time = Instant::now();
let dealloc_time = SystemTime::now();
let size = {
let mut records = self.records.write().map_err(|_| {
RusTorchError::MemoryError("Failed to acquire records write lock".to_string())
})?;
if let Some(record) = records.get_mut(&id) {
record.deallocated_at = Some(dealloc_time);
record.lifetime = dealloc_time.duration_since(record.allocated_at).ok();
record.pattern = AllocationPattern::classify(record.lifetime, record.allocated_at);
record.size
} else {
return Err(RusTorchError::MemoryError(format!(
"Allocation ID {} not found",
id
)));
}
};
{
let mut current = self.current_usage.write().map_err(|_| {
RusTorchError::MemoryError("Failed to acquire usage write lock".to_string())
})?;
*current = current.saturating_sub(size);
}
{
let mut stats = self.stats.write().map_err(|_| {
RusTorchError::MemoryError("Failed to acquire stats write lock".to_string())
})?;
stats.total_deallocations += 1;
stats.analysis_overhead += start_time.elapsed();
}
Ok(())
}
pub fn generate_report(&self) -> RusTorchResult<MemoryReport> {
let start_time = Instant::now();
let records = self.records.read().map_err(|_| {
RusTorchError::MemoryError("Failed to acquire records read lock".to_string())
})?;
let current_usage = *self.current_usage.read().map_err(|_| {
RusTorchError::MemoryError("Failed to acquire usage read lock".to_string())
})?;
let peak_usage = *self.peak_usage.read().map_err(|_| {
RusTorchError::MemoryError("Failed to acquire peak read lock".to_string())
})?;
let stats = self.stats.read().map_err(|_| {
RusTorchError::MemoryError("Failed to acquire stats read lock".to_string())
})?;
let mut pattern_distribution = HashMap::new();
let mut location_stats = HashMap::new();
let mut total_allocated = 0;
let mut active_count = 0;
let mut leaked_count = 0;
let mut leaked_bytes = 0;
let mut oldest_leak_age = None;
for record in records.values() {
*pattern_distribution
.entry(record.pattern.clone())
.or_insert(0) += 1;
let location_stat = location_stats
.entry(record.source_location.clone())
.or_insert((0, 0, 0)); location_stat.0 += 1;
location_stat.1 += record.size;
total_allocated += record.size;
if record.deallocated_at.is_none() {
active_count += 1;
let age = SystemTime::now()
.duration_since(record.allocated_at)
.unwrap_or(Duration::from_secs(0));
if age > self.config.leak_threshold {
leaked_count += 1;
leaked_bytes += record.size;
location_stat.2 += 1;
if oldest_leak_age.is_none() || age > oldest_leak_age.unwrap() {
oldest_leak_age = Some(age);
}
}
}
}
let mut hotspots = Vec::new();
if self.config.enable_hotspot_analysis {
for (location, (count, total_size, leak_count)) in location_stats {
if count >= self.config.hotspot_threshold {
hotspots.push(MemoryHotspot {
location,
total_allocations: count,
total_memory: total_size,
avg_size: total_size as f64 / count as f64,
peak_concurrent: count, leak_count,
frequency: count as f64 / 3600.0, });
}
}
hotspots.sort_by(|a, b| b.total_memory.cmp(&a.total_memory));
}
let leak_stats = LeakStats {
potential_leaks: leaked_count,
leaked_bytes,
oldest_leak_age,
leak_rate: leaked_count as f64 / 3600.0, };
let fragmentation_analysis = FragmentationAnalysis {
fragmentation_ratio: 0.1, pool_count: 8, avg_pool_utilization: 0.75, wasted_memory: total_allocated / 20, };
let avg_allocation_size = if stats.total_allocations > 0 {
total_allocated as f64 / stats.total_allocations as f64
} else {
0.0
};
let report = MemoryReport {
generated_at: SystemTime::now(),
total_allocations: stats.total_allocations,
total_deallocations: stats.total_deallocations,
active_allocations: active_count,
total_allocated_bytes: total_allocated,
current_memory_usage: current_usage,
peak_memory_usage: peak_usage,
avg_allocation_size,
leak_stats,
pattern_distribution,
hotspots,
fragmentation_analysis,
};
drop(stats);
{
let mut stats = self.stats.write().map_err(|_| {
RusTorchError::MemoryError("Failed to acquire stats write lock".to_string())
})?;
stats.reports_generated += 1;
stats.last_report_time = Some(SystemTime::now());
stats.analysis_overhead += start_time.elapsed();
}
Ok(report)
}
pub fn start_analysis(&self) -> RusTorchResult<()> {
let mut running = self.running.write().map_err(|_| {
RusTorchError::MemoryError("Failed to acquire running write lock".to_string())
})?;
if *running {
return Err(RusTorchError::MemoryError(
"Analysis already running".to_string(),
));
}
*running = true;
Ok(())
}
pub fn stop_analysis(&self) -> RusTorchResult<()> {
let mut running = self.running.write().map_err(|_| {
RusTorchError::MemoryError("Failed to acquire running write lock".to_string())
})?;
*running = false;
Ok(())
}
pub fn get_stats(&self) -> RusTorchResult<AnalyticsStats> {
let stats = self.stats.read().map_err(|_| {
RusTorchError::MemoryError("Failed to acquire stats read lock".to_string())
})?;
Ok(stats.clone())
}
fn cleanup_old_records(records: &mut HashMap<u64, AllocationRecord>, target_size: usize) {
if records.len() <= target_size {
return;
}
let mut to_remove = Vec::new();
let mut deallocated_records: Vec<_> = records
.iter()
.filter(|(_, record)| record.deallocated_at.is_some())
.collect();
deallocated_records.sort_by_key(|(_, record)| record.allocated_at);
let remove_count = records.len() - target_size;
for (id, _) in deallocated_records.into_iter().take(remove_count) {
to_remove.push(*id);
}
for id in to_remove {
records.remove(&id);
}
}
}
impl std::fmt::Display for MemoryReport {
    /// Renders the report as a plain-text, multi-section summary.
    ///
    /// Sections, in order:
    /// - allocation summary;
    /// - leak detection;
    /// - the first five entries of `hotspots`;
    /// - fragmentation.
    ///
    /// Ratios are printed as percentages with two decimals. `oldest_leak_age` and
    /// `pattern_distribution` are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Memory Analytics Report")?;
        writeln!(f, "======================")?;
        writeln!(f, "Generated: {:?}", self.generated_at)?;
        writeln!(f)?;
        writeln!(f, "Allocation Summary:")?;
        writeln!(f, " Total Allocations: {}", self.total_allocations)?;
        writeln!(f, " Total Deallocations: {}", self.total_deallocations)?;
        writeln!(f, " Active Allocations: {}", self.active_allocations)?;
        writeln!(f, " Current Usage: {} bytes", self.current_memory_usage)?;
        writeln!(f, " Peak Usage: {} bytes", self.peak_memory_usage)?;
        writeln!(f, " Average Size: {:.2} bytes", self.avg_allocation_size)?;
        writeln!(f)?;
        writeln!(f, "Leak Detection:")?;
        writeln!(f, " Potential Leaks: {}", self.leak_stats.potential_leaks)?;
        writeln!(f, " Leaked Memory: {} bytes", self.leak_stats.leaked_bytes)?;
        writeln!(f, " Leak Rate: {:.2}/hour", self.leak_stats.leak_rate)?;
        writeln!(f)?;
        writeln!(f, "Memory Hotspots:")?;
        for (i, hotspot) in self.hotspots.iter().take(5).enumerate() {
            writeln!(
                f,
                " {}. {} ({} allocs, {} bytes)",
                i + 1,
                hotspot.location,
                hotspot.total_allocations,
                hotspot.total_memory
            )?;
        }
        writeln!(f)?;
        writeln!(f, "Fragmentation:")?;
        writeln!(
            f,
            " Fragmentation Ratio: {:.2}%",
            self.fragmentation_analysis.fragmentation_ratio * 100.0
        )?;
        writeln!(
            f,
            " Pool Utilization: {:.2}%",
            self.fragmentation_analysis.avg_pool_utilization * 100.0
        )?;
        writeln!(
            f,
            " Wasted Memory: {} bytes",
            self.fragmentation_analysis.wasted_memory
        )?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Freed allocations are bucketed by lifetime.
    #[test]
    fn test_allocation_pattern_classification() {
        let now = SystemTime::now();
        assert_eq!(
            AllocationPattern::classify(Some(Duration::from_millis(500)), now),
            AllocationPattern::ShortLived
        );
        assert_eq!(
            AllocationPattern::classify(Some(Duration::from_secs(30)), now),
            AllocationPattern::MediumLived
        );
        assert_eq!(
            AllocationPattern::classify(Some(Duration::from_secs(120)), now),
            AllocationPattern::LongLived
        );
    }

    /// Live allocations are `MediumLived` while young and `Leaked` once old.
    #[test]
    fn test_allocation_pattern_classification_live() {
        let now = SystemTime::now();
        assert_eq!(
            AllocationPattern::classify(None, now),
            AllocationPattern::MediumLived
        );
        let old = now
            .checked_sub(Duration::from_secs(600))
            .expect("system clock is far enough past the epoch");
        assert_eq!(
            AllocationPattern::classify(None, old),
            AllocationPattern::Leaked
        );
    }

    #[test]
    fn test_analytics_creation() {
        let config = AnalyticsConfig::default();
        let analytics = MemoryAnalytics::new(config);
        let stats = analytics.get_stats().unwrap();
        assert_eq!(stats.total_allocations, 0);
    }

    #[test]
    fn test_allocation_tracking() {
        let config = AnalyticsConfig::default();
        let analytics = MemoryAnalytics::new(config);
        let id = analytics
            .record_allocation(1024, "test.rs:10".to_string())
            .unwrap();
        assert!(id > 0);
        analytics.record_deallocation(id).unwrap();
        let stats = analytics.get_stats().unwrap();
        assert_eq!(stats.total_allocations, 1);
        assert_eq!(stats.total_deallocations, 1);
    }

    /// Deallocating an ID that was never handed out is an error and changes no counters.
    #[test]
    fn test_deallocating_unknown_id_fails() {
        let analytics = MemoryAnalytics::new(AnalyticsConfig::default());
        assert!(analytics.record_deallocation(42).is_err());
        let stats = analytics.get_stats().unwrap();
        assert_eq!(stats.total_deallocations, 0);
    }

    #[test]
    fn test_report_generation() {
        let config = AnalyticsConfig::default();
        let analytics = MemoryAnalytics::new(config);
        let id1 = analytics
            .record_allocation(1024, "test.rs:10".to_string())
            .unwrap();
        let _id2 = analytics
            .record_allocation(2048, "test.rs:20".to_string())
            .unwrap();
        analytics.record_deallocation(id1).unwrap();
        let report = analytics.generate_report().unwrap();
        assert_eq!(report.total_allocations, 2);
        assert_eq!(report.total_deallocations, 1);
        assert_eq!(report.active_allocations, 1);
        assert!(report.current_memory_usage > 0);
    }

    /// Usage, peak, byte totals, average size and report counters are all tracked.
    #[test]
    fn test_report_tracks_usage_and_counts_reports() {
        let analytics = MemoryAnalytics::new(AnalyticsConfig::default());
        let id1 = analytics
            .record_allocation(1024, "a.rs:1".to_string())
            .unwrap();
        let _id2 = analytics
            .record_allocation(2048, "b.rs:2".to_string())
            .unwrap();
        analytics.record_deallocation(id1).unwrap();
        let report = analytics.generate_report().unwrap();
        assert_eq!(report.current_memory_usage, 2048);
        assert_eq!(report.peak_memory_usage, 3072);
        assert_eq!(report.total_allocated_bytes, 3072);
        assert!((report.avg_allocation_size - 1536.0).abs() < 1e-9);
        assert_eq!(report.leak_stats.potential_leaks, 0);
        let stats = analytics.get_stats().unwrap();
        assert_eq!(stats.reports_generated, 1);
        assert!(stats.last_report_time.is_some());
    }

    /// Only locations reaching `hotspot_threshold` allocations are reported.
    #[test]
    fn test_hotspot_detection() {
        let analytics = MemoryAnalytics::new(AnalyticsConfig::default());
        for _ in 0..10 {
            analytics
                .record_allocation(100, "hot.rs:1".to_string())
                .unwrap();
        }
        analytics
            .record_allocation(5000, "cold.rs:1".to_string())
            .unwrap();
        let report = analytics.generate_report().unwrap();
        assert_eq!(report.hotspots.len(), 1);
        assert_eq!(report.hotspots[0].location, "hot.rs:1");
        assert_eq!(report.hotspots[0].total_allocations, 10);
        assert_eq!(report.hotspots[0].total_memory, 1000);
    }

    /// Starting twice is rejected; stopping re-enables starting.
    #[test]
    fn test_start_stop_analysis() {
        let analytics = MemoryAnalytics::new(AnalyticsConfig::default());
        analytics.start_analysis().unwrap();
        assert!(analytics.start_analysis().is_err());
        analytics.stop_analysis().unwrap();
        analytics.start_analysis().unwrap();
    }
}