use crate::Device;
use std::time::Instant;
#[cfg(not(feature = "std"))]
use alloc::{string::String, vec::Vec};
/// Bookkeeping record for a single tracked memory allocation: where it came
/// from, how it is used, and what has happened to it over its lifetime.
#[derive(Debug, Clone)]
pub struct MemoryAllocation {
/// Raw address of the allocation (kept as an integer; never dereferenced here).
pub ptr: usize,
/// Size of the allocation in bytes.
pub size: usize,
/// When the allocation was created.
pub allocated_at: Instant,
/// Call-site and semantic context that requested the allocation.
pub source: AllocationSource,
/// Kind of memory backing the allocation (device, host, unified, ...).
pub memory_type: MemoryType,
/// Owning device, if any; `None` for host-only allocations.
pub device: Option<Device>,
/// Running access/bandwidth statistics, updated by `record_access`.
pub usage_stats: AllocationUsageStats,
/// Chronological log of events affecting this allocation.
pub lifetime_events: Vec<LifetimeEvent>,
/// Optimization hints attached to this allocation (deduplicated by variant).
pub performance_hints: Vec<PerformanceHint>,
}
impl MemoryAllocation {
    /// Creates a tracking record for a region that was just allocated.
    ///
    /// `ptr` and `size` describe the raw region; `source` records who
    /// allocated it and why; `device` is `None` for host-only allocations.
    /// Usage statistics, lifetime events, and hints start empty.
    pub fn new(
        ptr: usize,
        size: usize,
        source: AllocationSource,
        memory_type: MemoryType,
        device: Option<Device>,
    ) -> Self {
        Self {
            ptr,
            size,
            allocated_at: Instant::now(),
            source,
            memory_type,
            device,
            usage_stats: AllocationUsageStats::default(),
            lifetime_events: Vec::new(),
            performance_hints: Vec::new(),
        }
    }

    /// Records an access of `bytes` bytes; `read`/`write` select which byte
    /// counters are bumped. Also appends an `Accessed` lifetime event and
    /// refreshes the derived access frequency.
    pub fn record_access(&mut self, read: bool, write: bool, bytes: u64) {
        // Capture a single timestamp so `last_accessed` and the logged event
        // agree exactly (the original called `Instant::now()` twice, so the
        // two timestamps could differ).
        let now = Instant::now();
        self.usage_stats.access_count += 1;
        self.usage_stats.last_accessed = Some(now);
        if read {
            self.usage_stats.bytes_read += bytes;
        }
        if write {
            self.usage_stats.bytes_written += bytes;
        }
        self.lifetime_events.push(LifetimeEvent {
            timestamp: now,
            event_type: LifetimeEventType::Accessed { read, write },
            details: format!("Access: {} bytes, read: {}, write: {}", bytes, read, write),
        });
        self.update_access_frequency();
    }

    /// Recomputes accesses-per-second since allocation. Left untouched while
    /// elapsed time is still zero (avoids division by zero right after `new`).
    fn update_access_frequency(&mut self) {
        let elapsed = self.allocated_at.elapsed().as_secs_f64();
        if elapsed > 0.0 {
            self.usage_stats.access_frequency = self.usage_stats.access_count as f64 / elapsed;
        }
    }

    /// Adds `hint` unless a hint of the same variant is already present.
    /// Only the enum variant is compared (via `mem::discriminant`); any
    /// payload fields are ignored for deduplication.
    pub fn add_performance_hint(&mut self, hint: PerformanceHint) {
        let already_present = self
            .performance_hints
            .iter()
            .any(|h| std::mem::discriminant(&h.hint_type) == std::mem::discriminant(&hint.hint_type));
        if !already_present {
            self.performance_hints.push(hint);
        }
    }

    /// Appends a timestamped event to this allocation's lifetime history.
    pub fn record_lifetime_event(&mut self, event_type: LifetimeEventType, details: String) {
        self.lifetime_events.push(LifetimeEvent {
            timestamp: Instant::now(),
            event_type,
            details,
        });
    }

    /// Time elapsed since the allocation was created.
    pub fn age(&self) -> std::time::Duration {
        self.allocated_at.elapsed()
    }

    /// Average bytes transferred (reads + writes) per second over the
    /// allocation's lifetime; 0.0 when no time has elapsed yet.
    pub fn throughput(&self) -> f64 {
        let elapsed = self.allocated_at.elapsed().as_secs_f64();
        if elapsed > 0.0 {
            (self.usage_stats.bytes_read + self.usage_stats.bytes_written) as f64 / elapsed
        } else {
            0.0
        }
    }

    /// Heuristic: true when the allocation is older than 10 s and either has
    /// never been accessed or was last accessed more than 60 s ago.
    pub fn is_likely_unused(&self) -> bool {
        if let Some(last_access) = self.usage_stats.last_accessed {
            last_access.elapsed().as_secs() > 60 && self.age().as_secs() > 10
        } else {
            self.age().as_secs() > 10
        }
    }
}
/// Describes who requested an allocation: function, optional file/line,
/// thread, and the semantic context of the request.
#[derive(Debug, Clone)]
pub struct AllocationSource {
/// Name of the function that performed the allocation.
pub function: String,
/// Optional (file, line) of the allocation site.
pub location: Option<(String, u32)>,
/// Call-stack depth at allocation time (`new` leaves this at 0).
pub stack_depth: usize,
/// Numeric identifier of the allocating thread.
pub thread_id: u64,
/// Semantic reason for the allocation (tensor op, scratch, cache, ...).
pub context: AllocationContext,
}
impl AllocationSource {
pub fn new(
function: String,
location: Option<(String, u32)>,
context: AllocationContext,
) -> Self {
Self {
function,
location,
stack_depth: 0, thread_id: std::thread::current().id().as_u64().get(),
context,
}
}
pub fn description(&self) -> String {
let location_str = if let Some((file, line)) = &self.location {
format!(" ({}:{})", file, line)
} else {
String::new()
};
format!("{}{} - {}", self.function, location_str, self.context.description())
}
}
/// Semantic reason an allocation was made; drives reporting (`description`)
/// and lifetime expectations (`lifetime_category`).
#[derive(Debug, Clone)]
pub enum AllocationContext {
/// Backing storage for a tensor operation's data.
TensorOperation {
operation_name: String,
tensor_shape: Vec<usize>,
data_type: String,
},
/// Temporary workspace used by a single kernel.
KernelScratch {
kernel_name: String,
scratch_type: String,
},
/// Intermediate result buffer within a computation graph.
IntermediateBuffer {
computation_graph_id: String,
buffer_purpose: String,
},
/// Storage for a named model parameter.
ModelParameters {
model_name: String,
parameter_name: String,
},
/// Allocation explicitly requested by user code.
UserAllocation { request_id: String },
/// Allocation made internally by the library for `purpose`.
InternalAllocation { purpose: String },
/// Entry in a cache at the given level.
CacheAllocation {
cache_type: String,
cache_level: usize,
},
/// Block carved from a named memory pool.
PoolAllocation {
pool_name: String,
pool_type: String,
},
}
impl AllocationContext {
    /// Short human-readable summary of this context for logs and reports.
    pub fn description(&self) -> String {
        match self {
            Self::TensorOperation { operation_name, tensor_shape, data_type } =>
                format!("Tensor {} {:?} ({})", operation_name, tensor_shape, data_type),
            Self::KernelScratch { kernel_name, scratch_type } =>
                format!("Kernel {} scratch ({})", kernel_name, scratch_type),
            Self::IntermediateBuffer { computation_graph_id, buffer_purpose } =>
                format!("Buffer {} ({})", computation_graph_id, buffer_purpose),
            Self::ModelParameters { model_name, parameter_name } =>
                format!("Model {} parameter {}", model_name, parameter_name),
            Self::UserAllocation { request_id } =>
                format!("User allocation {}", request_id),
            Self::InternalAllocation { purpose } =>
                format!("Internal: {}", purpose),
            Self::CacheAllocation { cache_type, cache_level } =>
                format!("Cache L{} ({})", cache_level, cache_type),
            Self::PoolAllocation { pool_name, pool_type } =>
                format!("Pool {} ({})", pool_name, pool_type),
        }
    }

    /// True for contexts whose allocations are expected to be short-lived:
    /// kernel scratch, intermediate buffers, and cache entries.
    pub fn is_temporary(&self) -> bool {
        match self {
            Self::KernelScratch { .. }
            | Self::IntermediateBuffer { .. }
            | Self::CacheAllocation { .. } => true,
            _ => false,
        }
    }

    /// Expected-lifetime bucket for allocations made in this context.
    pub fn lifetime_category(&self) -> LifetimeCategory {
        match self {
            Self::KernelScratch { .. } => LifetimeCategory::VeryShort,
            Self::TensorOperation { .. } | Self::IntermediateBuffer { .. } => {
                LifetimeCategory::Short
            }
            Self::UserAllocation { .. }
            | Self::InternalAllocation { .. }
            | Self::CacheAllocation { .. } => LifetimeCategory::Medium,
            Self::ModelParameters { .. } | Self::PoolAllocation { .. } => LifetimeCategory::Long,
        }
    }
}
/// Coarse expected-lifetime bucket for an allocation, shortest to longest.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LifetimeCategory {
    /// Freed almost immediately (e.g. kernel scratch).
    VeryShort,
    /// Lives for roughly one operation (e.g. tensor ops, intermediate buffers).
    Short,
    /// Lives for a phase of execution (e.g. user/internal/cache allocations).
    Medium,
    /// Lives for most of the process (e.g. model parameters, pools).
    Long,
}
/// Class of memory backing an allocation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemoryType {
/// Device-local memory (e.g. GPU VRAM).
Device,
/// Ordinary pageable host memory.
Host,
/// Memory addressable by both host and device.
Unified,
/// Page-locked host memory, also device-accessible.
Pinned,
/// Device texture memory.
Texture,
/// Device constant memory.
Constant,
/// On-chip shared memory.
Shared,
/// Memory-mapped file backing.
MemoryMapped,
}
impl MemoryType {
    /// Representative bandwidth/latency/granularity profile for this memory
    /// type. Figures are built-in nominal values, not measured hardware data.
    pub fn bandwidth_characteristics(&self) -> BandwidthCharacteristics {
        // Local shorthand so each arm is a single readable line.
        fn profile(
            peak_bandwidth_gbps: f64,
            typical_latency_ns: f64,
            access_granularity: usize,
        ) -> BandwidthCharacteristics {
            BandwidthCharacteristics {
                peak_bandwidth_gbps,
                typical_latency_ns,
                access_granularity,
            }
        }
        match self {
            Self::Device => profile(1000.0, 100.0, 128),
            Self::Host => profile(100.0, 50.0, 64),
            Self::Unified => profile(200.0, 150.0, 128),
            Self::Pinned => profile(100.0, 40.0, 64),
            Self::Texture => profile(800.0, 200.0, 16),
            Self::Constant => profile(200.0, 80.0, 4),
            Self::Shared => profile(2000.0, 20.0, 4),
            Self::MemoryMapped => profile(10.0, 1000.0, 4096),
        }
    }

    /// True when a device can address this memory directly — every variant
    /// except plain host memory and memory-mapped files.
    pub fn is_device_accessible(&self) -> bool {
        !matches!(self, Self::Host | Self::MemoryMapped)
    }

    /// True when the host CPU can address this memory directly.
    pub fn is_host_accessible(&self) -> bool {
        match self {
            Self::Host | Self::Unified | Self::Pinned | Self::MemoryMapped => true,
            _ => false,
        }
    }
}
/// Nominal performance profile associated with a `MemoryType`.
#[derive(Debug, Clone)]
pub struct BandwidthCharacteristics {
/// Peak bandwidth in GB/s.
pub peak_bandwidth_gbps: f64,
/// Typical access latency in nanoseconds.
pub typical_latency_ns: f64,
/// Natural access size in bytes (e.g. cache line or page).
pub access_granularity: usize,
}
/// Running usage counters for one allocation; all fields start at
/// zero/`None` via `Default`.
#[derive(Debug, Clone, Default)]
pub struct AllocationUsageStats {
/// Number of recorded accesses.
pub access_count: u64,
/// Total bytes read from the allocation.
pub bytes_read: u64,
/// Total bytes written to the allocation.
pub bytes_written: u64,
/// Time of the most recent access, if any.
pub last_accessed: Option<Instant>,
/// Accesses per second since allocation (derived).
pub access_frequency: f64,
// NOTE(review): never written within this file; presumably a 0.0-1.0
// fraction maintained by external profiling code — confirm with callers.
pub bandwidth_utilization: f64,
/// Cache/TLB hit-miss counters.
pub cache_stats: CacheStats,
/// Detected access-pattern classification.
pub access_pattern: AccessPatternType,
}
impl AllocationUsageStats {
    /// Total bytes moved (reads plus writes) over the allocation's lifetime.
    pub fn total_throughput(&self) -> u64 {
        self.bytes_read + self.bytes_written
    }

    /// Bytes read per byte written. Returns infinity for read-only usage and
    /// 0.0 when nothing has been transferred at all.
    pub fn read_write_ratio(&self) -> f64 {
        match (self.bytes_read, self.bytes_written) {
            (_, written) if written > 0 => self.bytes_read as f64 / written as f64,
            (read, _) if read > 0 => f64::INFINITY,
            _ => 0.0,
        }
    }

    /// Reads outnumber writes by more than 3:1.
    pub fn is_read_heavy(&self) -> bool {
        self.read_write_ratio() > 3.0
    }

    /// Writes outnumber reads by more than roughly 3:1 (ratio below 0.3).
    pub fn is_write_heavy(&self) -> bool {
        self.read_write_ratio() < 0.3
    }
}
/// Per-level cache and TLB hit/miss counters for one allocation.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    pub l1_hits: u64,
    pub l1_misses: u64,
    pub l2_hits: u64,
    pub l2_misses: u64,
    pub tlb_hits: u64,
    pub tlb_misses: u64,
}

impl CacheStats {
    /// hits / (hits + misses), or 0.0 when there were no lookups at all.
    fn rate(hits: u64, misses: u64) -> f64 {
        let total = hits + misses;
        if total == 0 {
            0.0
        } else {
            hits as f64 / total as f64
        }
    }

    /// L1 cache hit rate in [0.0, 1.0].
    pub fn l1_hit_rate(&self) -> f64 {
        Self::rate(self.l1_hits, self.l1_misses)
    }

    /// L2 cache hit rate in [0.0, 1.0].
    pub fn l2_hit_rate(&self) -> f64 {
        Self::rate(self.l2_hits, self.l2_misses)
    }

    /// TLB hit rate in [0.0, 1.0].
    pub fn tlb_hit_rate(&self) -> f64 {
        Self::rate(self.tlb_hits, self.tlb_misses)
    }

    /// Weighted overall score: 50% L1, 30% L2, 20% TLB.
    pub fn efficiency_score(&self) -> f64 {
        (self.l1_hit_rate() * 0.5) + (self.l2_hit_rate() * 0.3) + (self.tlb_hit_rate() * 0.2)
    }
}
/// Classification of how an allocation's memory is accessed.
#[derive(Debug, Clone, Default)]
pub enum AccessPatternType {
    /// No pattern detected yet (the default).
    #[default]
    Unknown,
    /// Contiguous, in-order accesses.
    Sequential,
    /// No discernible ordering.
    Random,
    /// Regular accesses with a fixed `stride` between them
    /// (unit not specified in this file — presumably bytes or elements).
    Strided { stride: usize },
    /// Repeated accesses to the same locations over time.
    Temporal,
    /// Touches only a scattered subset of the region.
    Sparse,
}
/// One timestamped entry in an allocation's lifetime history.
#[derive(Debug, Clone)]
pub struct LifetimeEvent {
/// When the event occurred.
pub timestamp: Instant,
/// What happened.
pub event_type: LifetimeEventType,
/// Free-form human-readable details.
pub details: String,
}
impl LifetimeEvent {
    /// Creates an event stamped with the current time.
    pub fn new(event_type: LifetimeEventType, details: String) -> Self {
        let timestamp = Instant::now();
        Self { timestamp, event_type, details }
    }

    /// How long ago this event was recorded.
    pub fn age(&self) -> std::time::Duration {
        self.timestamp.elapsed()
    }
}
/// Kinds of events recorded in an allocation's lifetime log.
#[derive(Debug, Clone)]
pub enum LifetimeEventType {
/// The allocation was created.
Allocated,
/// The allocation was read and/or written.
Accessed { read: bool, write: bool },
/// The allocation took part in a copy as source and/or destination.
Copied { source: bool, destination: bool },
/// The allocation changed size.
Resized { old_size: usize, new_size: usize },
/// The allocation was freed.
Deallocated,
/// Memory pressure changed while this allocation was live.
MemoryPressure { pressure_level: PressureLevel },
/// The allocation was involved in a defragmentation pass.
Defragmented,
/// Cached data related to this allocation was flushed.
CacheFlushed,
/// The allocation was swapped out to backing storage.
SwappedOut,
/// The allocation was swapped back in.
SwappedIn,
}
/// Discrete memory-pressure level, ordered from no pressure to critical.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum PressureLevel {
    None,
    Low,
    Medium,
    High,
    Critical,
}

impl PressureLevel {
    /// Maps each level onto an evenly spaced value in [0.0, 1.0].
    pub fn as_f64(&self) -> f64 {
        match self {
            Self::None => 0.0,
            Self::Low => 0.25,
            Self::Medium => 0.5,
            Self::High => 0.75,
            Self::Critical => 1.0,
        }
    }

    /// Buckets a normalized pressure value back into a level. Thresholds:
    /// >= 0.9 Critical, >= 0.7 High, >= 0.4 Medium, >= 0.1 Low, otherwise
    /// None (NaN falls through to None since all comparisons fail).
    pub fn from_f64(value: f64) -> Self {
        match value {
            v if v >= 0.9 => Self::Critical,
            v if v >= 0.7 => Self::High,
            v if v >= 0.4 => Self::Medium,
            v if v >= 0.1 => Self::Low,
            _ => Self::None,
        }
    }
}
/// An actionable optimization suggestion attached to an allocation.
#[derive(Debug, Clone)]
pub struct PerformanceHint {
/// Category of the detected issue.
pub hint_type: PerformanceHintType,
/// How serious the issue is.
pub severity: HintSeverity,
/// Human-readable explanation of the problem.
pub description: String,
/// Recommended remediation.
pub suggested_action: String,
/// Estimated impact of acting on the hint, clamped to [0.0, 1.0] by `new`.
pub impact_estimate: f64,
/// Confidence in the hint (`new` sets 0.8).
pub confidence: f64,
}
impl PerformanceHint {
pub fn new(
hint_type: PerformanceHintType,
severity: HintSeverity,
description: String,
suggested_action: String,
impact_estimate: f64,
) -> Self {
Self {
hint_type,
severity,
description,
suggested_action,
impact_estimate: impact_estimate.clamp(0.0, 1.0),
confidence: 0.8, }
}
pub fn priority_score(&self) -> f64 {
let severity_weight = match self.severity {
HintSeverity::Info => 0.3,
HintSeverity::Warning => 0.6,
HintSeverity::Critical => 1.0,
};
severity_weight * self.impact_estimate * self.confidence
}
}
/// Categories of performance issues a hint can flag.
#[derive(Debug, Clone)]
pub enum PerformanceHintType {
/// Access pattern is poorly suited to the memory backing it.
SuboptimalAccessPattern,
/// Allocation size is inefficient (e.g. too small/large for its use).
InefficientSize,
/// A different `MemoryType` would likely perform better.
SuboptimalMemoryType,
/// Too many separate allocations where fewer would do.
ExcessiveAllocations,
/// Memory is fragmented.
Fragmentation,
/// Allocation appears to go unused.
UnusedMemory,
/// Accesses exhibit poor cache locality.
PoorCacheLocality,
/// Available bandwidth is underused.
BandwidthUnderutilization,
/// Data is not aligned to the preferred boundary.
AlignmentIssues,
/// Accesses could benefit from prefetching.
PrefetchingOpportunity,
/// Accesses could be coalesced.
CoalescingOpportunity,
/// Allocation would benefit from a memory pool.
PoolOptimization,
}
/// Severity of a performance hint, ordered least to most serious.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum HintSeverity {
    /// Informational; no action strictly needed.
    Info,
    /// Worth addressing.
    Warning,
    /// Should be addressed promptly.
    Critical,
}

impl HintSeverity {
    /// Numeric weight used when scoring hints: 0.3 / 0.6 / 1.0.
    pub fn as_f64(&self) -> f64 {
        match self {
            Self::Info => 0.3,
            Self::Warning => 0.6,
            Self::Critical => 1.0,
        }
    }
}
/// Unit tests for the allocation-tracking types in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_allocation_creation() {
        let ctx = AllocationContext::UserAllocation { request_id: "test".to_string() };
        let source = AllocationSource::new(
            "test_function".to_string(),
            Some(("test.rs".to_string(), 42)),
            ctx,
        );
        let allocation = MemoryAllocation::new(0x1000, 1024, source, MemoryType::Host, None);
        assert_eq!(allocation.ptr, 0x1000);
        assert_eq!(allocation.size, 1024);
        assert_eq!(allocation.memory_type, MemoryType::Host);
    }

    #[test]
    fn test_allocation_access_tracking() {
        let ctx = AllocationContext::InternalAllocation { purpose: "test".to_string() };
        let source = AllocationSource::new("test".to_string(), None, ctx);
        let mut allocation = MemoryAllocation::new(0x1000, 1024, source, MemoryType::Host, None);
        // One 512-byte read, no write.
        allocation.record_access(true, false, 512);
        assert_eq!(allocation.usage_stats.access_count, 1);
        assert_eq!(allocation.usage_stats.bytes_read, 512);
        assert_eq!(allocation.usage_stats.bytes_written, 0);
    }

    #[test]
    fn test_cache_stats() {
        let stats = CacheStats {
            l1_hits: 80,
            l1_misses: 20,
            ..Default::default()
        };
        assert_eq!(stats.l1_hit_rate(), 0.8);
    }

    #[test]
    fn test_pressure_level_conversion() {
        let expectations = [
            (0.95, PressureLevel::Critical),
            (0.75, PressureLevel::High),
            (0.5, PressureLevel::Medium),
            (0.2, PressureLevel::Low),
            (0.05, PressureLevel::None),
        ];
        for (value, expected) in expectations {
            assert_eq!(PressureLevel::from_f64(value), expected);
        }
    }

    #[test]
    fn test_performance_hint() {
        let hint = PerformanceHint::new(
            PerformanceHintType::UnusedMemory,
            HintSeverity::Warning,
            "Memory appears unused".to_string(),
            "Consider deallocating".to_string(),
            0.6,
        );
        assert!(hint.priority_score() > 0.0);
    }

    #[test]
    fn test_memory_type_characteristics() {
        let device = MemoryType::Device.bandwidth_characteristics();
        let host = MemoryType::Host.bandwidth_characteristics();
        assert!(device.peak_bandwidth_gbps > host.peak_bandwidth_gbps);
        assert!(MemoryType::Device.is_device_accessible());
        assert!(MemoryType::Host.is_host_accessible());
    }
}