#[cfg(feature = "simd")]
use crate::storage::SimdStorage;
use crate::{Tensor, TensorStorage};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use torsh_core::{
dtype::TensorElement,
error::{Result, TorshError},
shape::Shape,
};
#[cfg(feature = "simd")]
use scirs2_core::simd_aligned::AlignedVec;
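/// Heuristic summary of how cache-friendly a tensor's current memory layout is,
/// produced by [`Tensor::analyze_cache_performance`]. Scores are estimates in
/// `[0.0, 1.0]`; `estimated_cache_misses` is a relative figure, not a hardware
/// counter reading.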
#[derive(Debug, Clone)]
pub struct CacheAnalysisReport {
pub cache_efficiency: f64,
pub estimated_cache_misses: usize,
pub spatial_locality_score: f64,
pub temporal_locality_score: f64,
pub memory_layout_optimal: bool,
pub recommended_optimizations: Vec<String>,
}
impl<T: TensorElement + Copy> Tensor<T> {
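    /// Reorders dimensions and adds cache-line padding in place when that is
    /// expected to improve locality. No-op for tensors under 1024 elements or
    /// when the dimension order is already optimal. Note that both the reorder
    /// and the padding can change the tensor's logical shape.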
pub fn optimize_cache_layout(&mut self) -> Result<()> {
        // Small tensors fit in cache already; reordering would not pay off.
        if self.numel() < 1024 {
            return Ok(());
        }
        let current_strides = self.compute_strides();
        let optimal_order = self.determine_optimal_dimension_order(&current_strides);
        // The identity permutation means the layout is already in preferred order.
        if optimal_order.iter().enumerate().all(|(i, &dim)| dim == i) {
            return Ok(());
        }
self.reorder_dimensions(&optimal_order)?;
self.add_cache_padding()?;
Ok(())
}
fn determine_optimal_dimension_order(&self, strides: &[usize]) -> Vec<usize> {
let shape_binding = self.shape();
let dims = shape_binding.dims();
        // Heuristic score: large extents with small current strides are treated
        // as the most cache-friendly. Dimensions are returned highest-score first.
        let mut dim_priorities: Vec<(usize, f64)> = (0..dims.len())
            .map(|i| {
                let size_factor = dims[i] as f64;
                // `+ 1.0` avoids division by zero for zero strides.
                let stride_factor = 1.0 / (strides[i] as f64 + 1.0);
                let cache_friendliness = size_factor * stride_factor;
                (i, cache_friendliness)
            })
            .collect();
        // Sort descending by score; `sort_by` is stable, so ties keep their order.
        dim_priorities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
dim_priorities.into_iter().map(|(dim, _)| dim).collect()
}
fn reorder_dimensions(&mut self, optimal_order: &[usize]) -> Result<()> {
if optimal_order.len() != self.ndim() {
return Err(TorshError::InvalidOperation(
"Dimension order length mismatch".to_string(),
));
}
let data = self.to_vec()?;
let old_dims = self.shape().dims().to_vec();
let old_strides = self.compute_strides();
let new_dims: Vec<usize> = optimal_order.iter().map(|&i| old_dims[i]).collect();
let new_numel = new_dims.iter().product::<usize>();
        // Seed with `data[0]` only as a placeholder; every slot is overwritten below.
        let mut new_data = vec![data[0]; new_numel];
        #[allow(clippy::needless_range_loop)]
        for i in 0..new_numel {
            // Decompose the new flat index into per-dimension indices (row-major
            // over `new_dims`), mapping new axis `j` back to old axis `optimal_order[j]`.
            let mut old_indices = vec![0; self.ndim()];
            let mut remaining = i;
            for (j, &dim_size) in new_dims.iter().enumerate().rev() {
                old_indices[optimal_order[j]] = remaining % dim_size;
                remaining /= dim_size;
            }
            // Recover the flat offset in the old layout from the old strides.
            let old_flat_index: usize = old_indices
                .iter()
                .zip(old_strides.iter())
                .map(|(&idx, &stride)| idx * stride)
                .sum();
            new_data[i] = data[old_flat_index];
        }
self.storage = TensorStorage::create_optimal(new_data)?;
self.shape = Shape::new(new_dims);
Ok(())
}
    fn add_cache_padding(&mut self) -> Result<()> {
        const CACHE_LINE_SIZE: usize = 64;
        let element_size = std::mem::size_of::<T>();
        let elements_per_cache_line = CACHE_LINE_SIZE / element_size;
        // Elements larger than a cache line cannot be padded to line boundaries;
        // bail out early (this also avoids a modulo-by-zero below).
        if elements_per_cache_line == 0 {
            return Ok(());
        }
        let shape_binding = self.shape();
        let dims = shape_binding.dims();
        if dims.is_empty() || dims[dims.len() - 1] % elements_per_cache_line == 0 {
            return Ok(());
        }
        let last_dim = dims[dims.len() - 1];
        let padded_last_dim = last_dim.div_ceil(elements_per_cache_line) * elements_per_cache_line;
        let padding_needed = padded_last_dim - last_dim;
        // Skip padding when it would inflate the row by more than 25%.
        if (padding_needed as f64 / last_dim as f64) > 0.25 {
            return Ok(());
        }
let data = self.to_vec()?;
let mut new_dims = dims.to_vec();
let last_idx = new_dims.len() - 1;
new_dims[last_idx] = padded_last_dim;
let new_numel = new_dims.iter().product::<usize>();
let mut padded_data = Vec::with_capacity(new_numel);
let outer_size = new_numel / padded_last_dim;
        for i in 0..outer_size {
            // Copy one logical row, then append the padding elements.
            let start_idx = i * last_dim;
            let end_idx = (i + 1) * last_dim;
            padded_data.extend_from_slice(&data[start_idx..end_idx]);
            // `T` has no `Default` bound in this impl, so reuse `data[0]` as the
            // fill value; the padded tail becomes part of the new logical shape.
            for _ in 0..padding_needed {
                padded_data.push(data[0]);
            }
        }
self.storage = TensorStorage::create_optimal(padded_data)?;
self.shape = Shape::new(new_dims);
Ok(())
}
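    /// Estimates cache behavior for the current layout without modifying the
    /// tensor. A minimal usage sketch, mirroring `test_cache_analysis` below:
    ///
    /// ```ignore
    /// let tensor = ones::<f32>(&[64, 64])?;
    /// let report = tensor.analyze_cache_performance();
    /// assert!(report.cache_efficiency >= 0.0 && report.cache_efficiency <= 1.0);
    /// for tip in &report.recommended_optimizations {
    ///     println!("{tip}");
    /// }
    /// ```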
pub fn analyze_cache_performance(&self) -> CacheAnalysisReport {
let shape_binding = self.shape();
let dims = shape_binding.dims();
let strides = self.compute_strides();
let numel = self.numel();
        let mut cache_misses_estimate = 0f64;
        // Heuristic: strides beyond ~64 elements are penalized proportionally,
        // on the assumption that each access then touches a fresh cache line.
        for (i, &stride) in strides.iter().enumerate() {
            let dimension_accesses = dims[i] as f64;
            let stride_penalty = if stride > 64 {
                stride as f64 / 64.0
            } else {
                1.0
            };
            cache_misses_estimate += dimension_accesses * stride_penalty;
        }
        // Unit stride on the innermost dimension means perfect spatial locality;
        // otherwise the score decays with the innermost stride.
        let spatial_locality_score = if strides.last().copied().unwrap_or(1) == 1usize {
            1.0
        } else {
            1.0 / strides.last().copied().unwrap_or(1) as f64
        };
        // Larger tensors are assumed to offer less temporal reuse.
        let temporal_locality_score = 1.0 / (numel as f64).log2().max(1.0);
CacheAnalysisReport {
cache_efficiency: (spatial_locality_score + temporal_locality_score) / 2.0,
estimated_cache_misses: cache_misses_estimate as usize,
spatial_locality_score,
temporal_locality_score,
memory_layout_optimal: strides.last().copied().unwrap_or(1) == 1usize,
recommended_optimizations: self.generate_optimization_recommendations(&strides),
}
}
fn generate_optimization_recommendations(&self, strides: &[usize]) -> Vec<String> {
let mut recommendations = Vec::new();
let shape_binding = self.shape();
let dims = shape_binding.dims();
if strides.last().copied().unwrap_or(1) != 1 {
recommendations
.push("Consider using .contiguous() to ensure row-major layout".to_string());
}
if self.numel() < 1024 {
recommendations.push("Tensor too small to benefit from cache optimization".to_string());
}
if dims.len() > 2 {
let largest_dim = dims.iter().enumerate().max_by_key(|(_, &size)| size);
if let Some((largest_idx, _)) = largest_dim {
if largest_idx != dims.len() - 1 {
recommendations.push(format!(
"Consider moving dimension {largest_idx} to the end for better cache locality"
));
}
}
}
        const CACHE_LINE_SIZE: usize = 64;
        let element_size = std::mem::size_of::<T>();
        let elements_per_cache_line = CACHE_LINE_SIZE / element_size;
        // Guard against elements larger than a cache line (modulo-by-zero below).
        if elements_per_cache_line > 0 && !dims.is_empty() {
            let last_dim = dims[dims.len() - 1];
            if last_dim % elements_per_cache_line != 0 {
                recommendations
                    .push("Consider adding cache-line padding for better alignment".to_string());
            }
        }
recommendations
}
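    /// Non-mutating variant of [`Self::optimize_cache_layout`]: clones the
    /// tensor and optimizes the clone.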
pub fn to_cache_optimized(&self) -> Result<Self> {
let mut optimized = self.clone();
optimized.optimize_cache_layout()?;
Ok(optimized)
}
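    /// Reports the tensor's memory footprint. `overhead_bytes` covers only the
    /// storage handle (pointer-sized for in-memory variants, a fixed 1 KiB
    /// estimate for memory-mapped storage), not allocator metadata.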
pub fn memory_stats(&self) -> MemoryStats {
let element_size = std::mem::size_of::<T>();
let total_elements = self.numel();
let total_bytes = total_elements * element_size;
        let overhead_bytes = match &self.storage {
            // Only the handle itself: an `Arc` is pointer-sized.
            TensorStorage::InMemory(_) => {
                std::mem::size_of::<std::sync::Arc<std::sync::RwLock<Vec<T>>>>()
            }
            TensorStorage::MemoryMapped(_) => {
                // Rough fixed estimate for mmap bookkeeping (page tables, file handle).
                1024
            }
            #[cfg(feature = "simd")]
            TensorStorage::Aligned(_) => {
                std::mem::size_of::<std::sync::Arc<std::sync::RwLock<AlignedVec<T>>>>()
            }
            #[cfg(feature = "simd")]
            TensorStorage::SimdOptimized(_) => {
                std::mem::size_of::<std::sync::Arc<SimdStorage<T>>>()
            }
        };
MemoryStats {
total_bytes,
element_size,
total_elements,
overhead_bytes,
is_memory_mapped: matches!(&self.storage, TensorStorage::MemoryMapped(_)),
}
}
}
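/// Memory footprint summary returned by [`Tensor::memory_stats`].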
#[derive(Debug, Clone)]
pub struct MemoryStats {
pub total_bytes: usize,
pub element_size: usize,
pub total_elements: usize,
pub overhead_bytes: usize,
pub is_memory_mapped: bool,
}
impl MemoryStats {
pub fn effective_bytes(&self) -> usize {
self.total_bytes + self.overhead_bytes
}
pub fn efficiency(&self) -> f64 {
self.total_bytes as f64 / self.effective_bytes() as f64
}
}
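/// A size-class pool that recycles byte buffers across tensor allocations.
/// Buffers are bucketed by power-of-two size; the pool caches at most
/// `max_pool_size` bytes and drops anything beyond that budget.
///
/// A minimal usage sketch, mirroring `test_memory_pool` below:
///
/// ```ignore
/// let pool = TensorMemoryPool::new(10); // 10 MiB cache budget
/// let buf = pool.allocate(1024);        // served from the pool when possible
/// pool.deallocate(buf);                 // zeroed and cached for reuse
/// ```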
pub struct TensorMemoryPool {
pool: Arc<Mutex<HashMap<usize, Vec<Vec<u8>>>>>,
stats: Arc<Mutex<PoolStatistics>>,
max_pool_size: usize,
current_pool_size: Arc<Mutex<usize>>,
}
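/// Counters describing pool behavior, retrieved via
/// [`TensorMemoryPool::get_statistics`].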
#[derive(Debug, Clone, Default)]
pub struct PoolStatistics {
pub allocations: usize,
pub deallocations: usize,
pub cache_hits: usize,
pub cache_misses: usize,
pub peak_memory_usage: usize,
pub total_memory_saved: usize,
}
impl TensorMemoryPool {
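    /// Creates a pool that caches at most `max_size_mb` MiB of freed buffers.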
pub fn new(max_size_mb: usize) -> Self {
Self {
pool: Arc::new(Mutex::new(HashMap::new())),
stats: Arc::new(Mutex::new(PoolStatistics::default())),
max_pool_size: max_size_mb * 1024 * 1024,
current_pool_size: Arc::new(Mutex::new(0)),
}
}
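    /// Hands out a zero-initialized buffer of at least `size_bytes` bytes.
    /// Requests are rounded up to the next power of two so freed buffers form a
    /// small number of reusable size classes; the returned length is therefore
    /// `size_bytes.next_power_of_two()`, not `size_bytes` exactly.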
pub fn allocate(&self, size_bytes: usize) -> Vec<u8> {
let mut pool = self.pool.lock().expect("lock should not be poisoned");
let mut stats = self.stats.lock().expect("lock should not be poisoned");
stats.allocations += 1;
let rounded_size = size_bytes.next_power_of_two();
if let Some(pool_vec) = pool.get_mut(&rounded_size) {
if let Some(memory) = pool_vec.pop() {
stats.cache_hits += 1;
let mut current_size = self
.current_pool_size
.lock()
.expect("lock should not be poisoned");
*current_size -= rounded_size;
return memory;
}
}
stats.cache_misses += 1;
vec![0u8; rounded_size]
}
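    /// Returns a buffer to the pool for later reuse. Buffers that would push
    /// the cached total past `max_pool_size` are dropped instead of cached.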
pub fn deallocate(&self, mut memory: Vec<u8>) {
let size = memory.len();
let mut pool = self.pool.lock().expect("lock should not be poisoned");
let mut stats = self.stats.lock().expect("lock should not be poisoned");
let mut current_size = self
.current_pool_size
.lock()
.expect("lock should not be poisoned");
stats.deallocations += 1;
        // Cache the buffer only if it fits in the budget; zero it first so a
        // reused buffer never leaks previous contents.
        if *current_size + size <= self.max_pool_size {
            memory.fill(0);
            pool.entry(size).or_default().push(memory);
            *current_size += size;
            stats.total_memory_saved += size;
        }
stats.peak_memory_usage = stats.peak_memory_usage.max(*current_size);
}
pub fn get_statistics(&self) -> PoolStatistics {
self.stats
.lock()
.expect("lock should not be poisoned")
.clone()
}
pub fn clear(&self) {
let mut pool = self.pool.lock().expect("lock should not be poisoned");
let mut current_size = self
.current_pool_size
.lock()
.expect("lock should not be poisoned");
pool.clear();
*current_size = 0;
}
}
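/// Tracks recent memory usage and derives a pressure level in `[0.0, 1.0]`:
/// the mean of all samples from the last 60 seconds, divided by the configured
/// limit and clamped to 1.0. [`Self::is_high_pressure`] fires above 0.8.
///
/// A minimal usage sketch, mirroring `test_memory_pressure_monitor` below:
///
/// ```ignore
/// let monitor = MemoryPressureMonitor::new(100); // 100 MiB limit
/// monitor.record_usage(50 * 1024 * 1024);        // pressure ~ 0.5
/// assert!(!monitor.is_high_pressure());
/// ```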
pub struct MemoryPressureMonitor {
samples: Arc<Mutex<Vec<(Instant, usize)>>>,
pressure_level: Arc<Mutex<f64>>,
high_pressure_threshold: usize,
}
impl MemoryPressureMonitor {
pub fn new(memory_limit_mb: usize) -> Self {
Self {
samples: Arc::new(Mutex::new(Vec::new())),
pressure_level: Arc::new(Mutex::new(0.0)),
high_pressure_threshold: memory_limit_mb * 1024 * 1024,
}
}
pub fn record_usage(&self, bytes_used: usize) {
let mut samples = self.samples.lock().expect("lock should not be poisoned");
let mut pressure = self
.pressure_level
.lock()
.expect("lock should not be poisoned");
let now = Instant::now();
samples.push((now, bytes_used));
samples.retain(|(time, _)| now.duration_since(*time) < Duration::from_secs(60));
let avg_usage = if samples.is_empty() {
0.0
} else {
samples.iter().map(|(_, usage)| *usage as f64).sum::<f64>() / samples.len() as f64
};
*pressure = (avg_usage / self.high_pressure_threshold as f64).min(1.0);
}
pub fn get_pressure_level(&self) -> f64 {
*self
.pressure_level
.lock()
.expect("lock should not be poisoned")
}
pub fn is_high_pressure(&self) -> bool {
self.get_pressure_level() > 0.8
}
}
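/// Placement preference for NUMA-aware allocation. Currently advisory only:
/// the optimization path accepts the hint but does not yet bind memory to a
/// specific node (see `apply_numa_optimization`).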
#[derive(Debug, Clone, Copy)]
pub enum NumaNode {
Local,
Node(u32),
Interleaved,
}
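/// Allocation hint bundling a preferred [`NumaNode`] with fallback and
/// thread-binding preferences.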
#[derive(Debug, Clone)]
pub struct NumaAllocationHint {
pub preferred_node: NumaNode,
pub allow_fallback: bool,
pub bind_threads: bool,
}
impl<T: TensorElement + Copy + Default> Tensor<T> {
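    /// Runs the full memory-layout pipeline: cache-layout optimization, the
    /// (currently advisory) NUMA step, then access-pattern tuning.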
pub fn optimize_memory_layout(&mut self, numa_hint: Option<NumaAllocationHint>) -> Result<()> {
self.optimize_cache_layout()?;
if let Some(hint) = numa_hint {
self.apply_numa_optimization(hint)?;
}
self.optimize_access_patterns()?;
Ok(())
}
    fn apply_numa_optimization(&mut self, _hint: NumaAllocationHint) -> Result<()> {
        // NUMA binding is not implemented yet; the hint is accepted for API
        // stability. For large tensors, at least make the data contiguous so a
        // future first-touch placement sees one linear allocation.
        if self.numel() > 1_000_000 && !self.is_contiguous() {
            let contiguous_tensor = self.contiguous()?;
            *self = contiguous_tensor;
        }
        Ok(())
    }
fn optimize_access_patterns(&mut self) -> Result<()> {
let shape_binding = self.shape();
let dims = shape_binding.dims();
        // Large 2-D tensors: pad short, misaligned rows up to cache-line multiples.
        if dims.len() == 2 && dims[0] > 64 && dims[1] > 64 {
            let row_size = dims[1] * std::mem::size_of::<T>();
            let cache_line_size = 64;
            if row_size % cache_line_size != 0 && row_size < cache_line_size * 4 {
                self.add_cache_padding()?;
            }
        }
        // 3-D and higher: keep the innermost run in a 32..=256 byte sweet spot.
        if dims.len() >= 3 {
            let innermost_size = dims[dims.len() - 1] * std::mem::size_of::<T>();
            if !(32..=256).contains(&innermost_size) {
                self.add_cache_padding()?;
            }
        }
Ok(())
}
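    /// Despite the name, this currently builds an in-memory CPU tensor from raw
    /// data and immediately applies [`Self::optimize_memory_layout`]. Padding
    /// may widen the trailing dimension, as `test_cache_optimized_creation`
    /// below demonstrates.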
pub fn create_memory_mapped_optimized(
data: Vec<T>,
shape: Vec<usize>,
numa_hint: Option<NumaAllocationHint>,
) -> Result<Self> {
let mut tensor = Self::from_data(data, shape, torsh_core::device::DeviceType::Cpu)?;
tensor.optimize_memory_layout(numa_hint)?;
Ok(tensor)
}
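    /// Best-effort cache warm-up: touches ~100 evenly spaced elements of large
    /// tensors so subsequent traversals start with warm cache lines. Purely a
    /// hint; tensors of 10,000 elements or fewer are skipped.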
    pub fn prefetch_data(&self) -> Result<()> {
        if self.numel() > 10_000 {
            let data = self.to_vec()?;
            // Sample ~100 evenly spaced elements. `black_box` keeps the
            // otherwise-dead reads from being optimized away in release builds.
            let stride = data.len() / 100;
            for i in (0..data.len()).step_by(stride.max(1)) {
                std::hint::black_box(data[i]);
            }
        }
        Ok(())
    }
}
static GLOBAL_MEMORY_POOL: std::sync::OnceLock<TensorMemoryPool> = std::sync::OnceLock::new();
static MEMORY_PRESSURE_MONITOR: std::sync::OnceLock<MemoryPressureMonitor> =
std::sync::OnceLock::new();
/// Returns the process-wide tensor memory pool (1 GiB cache budget, lazily created).
pub fn get_memory_pool() -> &'static TensorMemoryPool {
    GLOBAL_MEMORY_POOL.get_or_init(|| TensorMemoryPool::new(1024))
}

/// Returns the process-wide memory pressure monitor (8 GiB limit, lazily created).
pub fn get_memory_pressure_monitor() -> &'static MemoryPressureMonitor {
    MEMORY_PRESSURE_MONITOR.get_or_init(|| MemoryPressureMonitor::new(8192))
}
#[cfg(test)]
mod tests {
use crate::creation::*;
#[test]
fn test_cache_optimization() {
let mut tensor = ones::<f32>(&[100, 100]).expect("ones creation should succeed");
assert!(tensor.optimize_cache_layout().is_ok());
}
#[test]
fn test_cache_analysis() {
let tensor = ones::<f32>(&[64, 64]).expect("ones creation should succeed");
let report = tensor.analyze_cache_performance();
assert!(report.cache_efficiency >= 0.0 && report.cache_efficiency <= 1.0);
}
#[test]
fn test_contiguous_layout() {
let tensor = ones::<f32>(&[10, 10]).expect("ones creation should succeed");
assert!(tensor.is_contiguous());
let contiguous = tensor
.contiguous()
.expect("contiguous conversion should succeed");
assert!(contiguous.is_contiguous());
}
#[test]
fn test_memory_stats() {
let tensor = ones::<f32>(&[100, 100]).expect("ones creation should succeed");
let stats = tensor.memory_stats();
assert_eq!(stats.total_elements, 10000);
        assert_eq!(stats.element_size, 4);
        assert_eq!(stats.total_bytes, 40000);
}
#[test]
fn test_memory_pool() {
use super::*;
let pool = TensorMemoryPool::new(10);
let memory1 = pool.allocate(1024);
assert_eq!(memory1.len(), 1024);
let memory2 = pool.allocate(2048);
assert_eq!(memory2.len(), 2048);
pool.deallocate(memory1);
let memory3 = pool.allocate(1024);
assert_eq!(memory3.len(), 1024);
let stats = pool.get_statistics();
assert!(stats.allocations > 0);
assert!(stats.deallocations > 0);
pool.deallocate(memory2);
pool.deallocate(memory3);
}
#[test]
fn test_memory_pressure_monitor() {
use super::*;
let monitor = MemoryPressureMonitor::new(100);
        // Pressure = 60 s average usage / limit (100 MiB here), clamped to 1.0.
        monitor.record_usage(50 * 1024 * 1024); // running avg 50 MiB -> 0.50
        assert!(monitor.get_pressure_level() < 0.6);
        monitor.record_usage(90 * 1024 * 1024); // running avg 70 MiB -> 0.70
        assert!(monitor.get_pressure_level() > 0.6);
        assert!(monitor.get_pressure_level() < 0.8);
        assert!(!monitor.is_high_pressure());
        monitor.record_usage(95 * 1024 * 1024);
        monitor.record_usage(100 * 1024 * 1024); // running avg 83.75 MiB -> ~0.84
        assert!(monitor.is_high_pressure());
}
#[test]
fn test_advanced_memory_optimization() {
let mut tensor = ones::<f32>(&[64, 64]).expect("ones creation should succeed");
let numa_hint = super::NumaAllocationHint {
preferred_node: super::NumaNode::Local,
allow_fallback: true,
bind_threads: false,
};
assert!(tensor.optimize_memory_layout(Some(numa_hint)).is_ok());
assert!(tensor.is_contiguous());
}
#[test]
fn test_cache_optimized_creation() {
let data: Vec<f32> = (0..10000).map(|i| i as f32).collect();
let shape = vec![100, 100];
let numa_hint = super::NumaAllocationHint {
preferred_node: super::NumaNode::Interleaved,
allow_fallback: true,
bind_threads: false,
};
let tensor = super::Tensor::create_memory_mapped_optimized(data, shape, Some(numa_hint));
assert!(tensor.is_ok());
let tensor = tensor.expect("operation should succeed");
let shape = tensor.shape();
let dims = shape.dims();
        assert_eq!(dims[0], 100);
        // Cache padding may have widened the innermost dimension.
        assert!(dims[1] >= 100);
    }
#[test]
fn test_memory_prefetch() {
let tensor = ones::<f32>(&[200, 200]).expect("ones creation should succeed");
assert!(tensor.prefetch_data().is_ok());
}
#[test]
fn test_global_memory_pool_access() {
use super::*;
let pool = get_memory_pool();
let memory = pool.allocate(1024);
assert_eq!(memory.len(), 1024);
pool.deallocate(memory);
let monitor = get_memory_pressure_monitor();
        monitor.record_usage(1024 * 1024);
        assert!(monitor.get_pressure_level() >= 0.0);
}
#[test]
fn test_pool_statistics() {
use super::*;
let pool = TensorMemoryPool::new(5);
let mut memories = Vec::new();
for i in 0..10 {
let size = (i + 1) * 512;
memories.push(pool.allocate(size));
}
for memory in memories {
pool.deallocate(memory);
}
let stats = pool.get_statistics();
assert_eq!(stats.allocations, 10);
assert_eq!(stats.deallocations, 10);
assert!(stats.cache_hits + stats.cache_misses == 10);
pool.clear();
}
#[test]
fn test_memory_efficiency_calculation() {
let tensor = ones::<f32>(&[50, 50]).expect("ones creation should succeed");
let stats = tensor.memory_stats();
let efficiency = stats.efficiency();
assert!(efficiency > 0.0 && efficiency <= 1.0);
let effective = stats.effective_bytes();
assert!(effective >= stats.total_bytes);
}
}