#[cfg(not(feature = "std"))]
extern crate alloc;
#[cfg(not(feature = "std"))]
use alloc::collections::BTreeMap as HashMap;
#[cfg(not(feature = "std"))]
use alloc::sync::Arc;
#[cfg(not(feature = "std"))]
use alloc::{
    format,
    string::{String, ToString},
    vec,
    vec::Vec,
};
#[cfg(feature = "std")]
use std::{collections::HashMap, sync::Arc};

use crate::{
    error::{Result, TorshError},
    shape::Shape,
    MemoryFormat,
};
/// Classification of an observed memory access pattern.
///
/// `PartialOrd`/`Ord` are derived because the `no_std` build aliases
/// `HashMap` to `BTreeMap`, and `BTreeMap` keys must implement `Ord`;
/// without these derives `AccessStatistics::pattern_distribution` does
/// not compile under `not(feature = "std")`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum AccessPattern {
    /// Consecutive linear indices (stride ≈ 1).
    Sequential,
    /// Constant stride larger than 1.
    Strided { stride: usize },
    /// No discernible regularity.
    Random,
    /// Stride matching the innermost (last) dimension length.
    RowMajor,
    /// Stride matching the outermost (first) dimension length.
    ColumnMajor,
    /// Accesses confined to fixed-size blocks. NOTE(review): never
    /// produced by the built-in detector; presumably set externally.
    BlockWise { block_size: usize },
    /// Diagonal traversal. NOTE(review): never produced by the built-in
    /// detector; presumably set externally.
    Diagonal,
    /// Small, stable, non-unit strides (broadcast-like re-reads).
    Broadcast,
}
/// Aggregated access statistics for a single tracked tensor.
#[derive(Debug, Clone)]
pub struct AccessStatistics {
    /// Lifetime count of recorded accesses.
    pub total_accesses: u64,
    /// Accesses whose stride stayed within one cache line
    /// (heuristic assumes 4-byte elements; see `AccessTracker::record_access`).
    pub cache_hits: u64,
    /// Accesses whose stride jumped past the configured cache line size.
    pub cache_misses: u64,
    /// Mean absolute stride over the recent-access window.
    pub average_stride: f64,
    /// Variance of the absolute strides over the recent-access window.
    pub stride_variance: f64,
    /// Most frequently detected pattern across analysis rounds.
    pub dominant_pattern: AccessPattern,
    /// How many analysis rounds classified each pattern.
    pub pattern_distribution: HashMap<AccessPattern, u64>,
}
/// Records recent linear-index accesses for one tensor and derives
/// pattern statistics from them.
#[derive(Debug, Clone)]
pub struct AccessTracker {
    /// Logical shape of the tracked tensor (used to recognize
    /// row-/column-major strides).
    shape: Shape,
    /// Layout the tensor currently uses.
    memory_format: MemoryFormat,
    /// Sliding FIFO window of the most recent linear indices.
    recent_accesses: Vec<usize>,
    /// Maximum length of `recent_accesses`.
    max_history: usize,
    /// Statistics accumulated from the recorded accesses.
    stats: AccessStatistics,
    /// Cache line size in bytes used by the hit/miss heuristic.
    cache_line_size: usize,
}
impl AccessTracker {
    /// Creates a tracker for a tensor with the given shape and layout.
    ///
    /// The access-history window is capped at 1000 entries and the cache
    /// line size defaults to 64 bytes (typical for x86/ARM).
    pub fn new(shape: Shape, memory_format: MemoryFormat) -> Self {
        Self {
            shape,
            memory_format,
            recent_accesses: Vec::with_capacity(1000),
            max_history: 1000,
            stats: AccessStatistics {
                total_accesses: 0,
                cache_hits: 0,
                cache_misses: 0,
                average_stride: 0.0,
                stride_variance: 0.0,
                dominant_pattern: AccessPattern::Random,
                pattern_distribution: HashMap::new(),
            },
            cache_line_size: 64,
        }
    }

    /// Overrides the cache line size (in bytes) used by the hit/miss
    /// heuristic. Builder-style; consumes and returns `self`.
    pub fn with_cache_line_size(mut self, cache_line_size: usize) -> Self {
        self.cache_line_size = cache_line_size;
        self
    }

    /// Records a single access at `linear_index`, updating the hit/miss
    /// counters and (every 100 accesses) re-running pattern analysis.
    pub fn record_access(&mut self, linear_index: usize) {
        // Bounded FIFO window. `remove(0)` is O(window); acceptable at
        // 1000 entries, but a ring buffer would make eviction O(1).
        if self.recent_accesses.len() >= self.max_history {
            self.recent_accesses.remove(0);
        }
        self.recent_accesses.push(linear_index);
        self.stats.total_accesses += 1;
        // Heuristic: a jump that stays within one cache line counts as a
        // hit. Assumes 4-byte (f32) elements — TODO confirm element width
        // against the tensor's actual dtype.
        if self.recent_accesses.len() >= 2 {
            let prev_index = self.recent_accesses[self.recent_accesses.len() - 2];
            let stride = linear_index.abs_diff(prev_index);
            if stride * core::mem::size_of::<f32>() <= self.cache_line_size {
                self.stats.cache_hits += 1;
            } else {
                self.stats.cache_misses += 1;
            }
        }
        // Amortize the O(window) analysis over every 100th access.
        if self.stats.total_accesses % 100 == 0 {
            self.analyze_pattern();
        }
    }

    /// Recomputes stride statistics over the current window and updates
    /// the pattern distribution and dominant pattern.
    fn analyze_pattern(&mut self) {
        // Too few samples to say anything meaningful.
        if self.recent_accesses.len() < 10 {
            return;
        }
        // Absolute stride between each consecutive pair of accesses.
        let strides: Vec<f64> = self
            .recent_accesses
            .windows(2)
            .map(|pair| pair[0].abs_diff(pair[1]) as f64)
            .collect();
        let avg = strides.iter().sum::<f64>() / strides.len() as f64;
        self.stats.average_stride = avg;
        self.stats.stride_variance =
            strides.iter().map(|&s| (s - avg).powi(2)).sum::<f64>() / strides.len() as f64;
        let pattern = self.detect_pattern(&strides);
        *self.stats.pattern_distribution.entry(pattern).or_insert(0) += 1;
        // The dominant pattern is the one seen in the most analysis rounds.
        if let Some((&dominant, _)) = self
            .stats
            .pattern_distribution
            .iter()
            .max_by_key(|(_, &count)| count)
        {
            self.stats.dominant_pattern = dominant;
        }
    }

    /// Classifies the current window into an [`AccessPattern`].
    ///
    /// Reads `stats.average_stride` / `stats.stride_variance`, which
    /// `analyze_pattern` refreshes immediately before calling this.
    fn detect_pattern(&self, strides: &[f64]) -> AccessPattern {
        if strides.is_empty() {
            return AccessPattern::Random;
        }
        let avg = self.stats.average_stride;
        let variance = self.stats.stride_variance;
        // Unit stride with almost no jitter: plain sequential scan.
        if (avg - 1.0).abs() < 0.1 && variance < 0.5 {
            return AccessPattern::Sequential;
        }
        // Large but consistent stride (variance small relative to mean).
        if variance < avg * 0.2 && avg > 1.5 {
            return AccessPattern::Strided {
                stride: avg.round() as usize,
            };
        }
        // Stride matching the innermost dimension: row-major traversal.
        if let Some(row_len) = self.shape.dims().last() {
            if (avg - *row_len as f64).abs() < 0.5 {
                return AccessPattern::RowMajor;
            }
        }
        // Stride matching the outermost dimension: column-major traversal.
        if let Some(&first_dim) = self.shape.dims().first() {
            if (avg - first_dim as f64).abs() < 0.5 {
                return AccessPattern::ColumnMajor;
            }
        }
        // Small, stable strides that are not ~1: broadcast-like re-reads.
        if variance < 1.0 && avg < 2.0 {
            return AccessPattern::Broadcast;
        }
        AccessPattern::Random
    }

    /// Returns a view of the accumulated statistics.
    pub fn statistics(&self) -> &AccessStatistics {
        &self.stats
    }

    /// Fraction of recorded accesses classified as cache hits.
    ///
    /// Note: hits + misses total one less than `total_accesses` per
    /// window start (the first access has no predecessor), so the rate
    /// slightly understates the true hit ratio.
    pub fn cache_hit_rate(&self) -> f64 {
        if self.stats.total_accesses == 0 {
            return 0.0;
        }
        self.stats.cache_hits as f64 / self.stats.total_accesses as f64
    }
}
/// Outcome of a layout analysis: what to switch to and why.
#[derive(Debug, Clone)]
pub struct LayoutRecommendation {
    /// Layout the tensor currently uses.
    pub current_format: MemoryFormat,
    /// Layout the optimizer suggests (equals `current_format` when no
    /// change is recommended).
    pub recommended_format: MemoryFormat,
    /// Estimated relative improvement; 0.0 means "keep as is".
    pub expected_improvement: f64,
    /// Human-readable justification for the recommendation.
    pub reason: String,
    /// Estimated one-time cost of performing the transformation.
    pub transformation_cost: TransformationCost,
}
/// Estimated one-time cost of converting a tensor to a new layout.
#[derive(Debug, Clone)]
pub struct TransformationCost {
    /// Number of full-buffer copies the transformation requires.
    pub memory_copies: usize,
    /// Rough copy time in microseconds.
    pub estimated_time_us: f64,
    /// Temporary extra memory needed, in bytes.
    pub memory_overhead_bytes: usize,
}
/// Tracks access patterns for registered tensors and recommends
/// memory-layout changes when the expected benefit crosses a threshold.
#[derive(Debug)]
pub struct LayoutOptimizer {
    /// Per-tensor trackers keyed by caller-supplied tensor id.
    trackers: HashMap<usize, Arc<AccessTracker>>,
    /// Minimum `expected_improvement` required to surface a recommendation.
    optimization_threshold: f64,
    /// When true, also recommends transformations for patterns (e.g.
    /// block-wise) that need more specialized handling.
    aggressive: bool,
}
impl Default for LayoutOptimizer {
    /// Equivalent to [`LayoutOptimizer::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl LayoutOptimizer {
    /// Creates an optimizer with a 10% improvement threshold and
    /// conservative (non-aggressive) recommendations.
    pub fn new() -> Self {
        Self {
            trackers: HashMap::new(),
            optimization_threshold: 0.1,
            aggressive: false,
        }
    }

    /// Sets the minimum expected improvement required before
    /// [`Self::recommend_layout`] surfaces a recommendation.
    pub fn with_threshold(mut self, threshold: f64) -> Self {
        self.optimization_threshold = threshold;
        self
    }

    /// Enables recommendations for patterns (e.g. block-wise) that need
    /// more invasive transformations.
    pub fn aggressive(mut self, enabled: bool) -> Self {
        self.aggressive = enabled;
        self
    }

    /// Starts tracking accesses for `tensor_id`, replacing any existing
    /// tracker for the same id.
    pub fn register_tensor(&mut self, tensor_id: usize, shape: Shape, format: MemoryFormat) {
        self.trackers
            .insert(tensor_id, Arc::new(AccessTracker::new(shape, format)));
    }

    /// Records a single element access for a registered tensor.
    ///
    /// # Errors
    /// Returns `InvalidArgument` if `tensor_id` was never registered.
    pub fn record_access(&mut self, tensor_id: usize, linear_index: usize) -> Result<()> {
        if let Some(tracker) = self.trackers.get_mut(&tensor_id) {
            // `Arc::make_mut` mutates in place when the Arc is uniquely
            // owned (the common case here) and clones only when shared.
            // The previous code deep-cloned the tracker — including its
            // 1000-entry history — and allocated a fresh Arc on EVERY
            // access: O(history) per access for identical semantics.
            Arc::make_mut(tracker).record_access(linear_index);
            Ok(())
        } else {
            Err(TorshError::InvalidArgument(format!(
                "Tensor {} not registered for tracking",
                tensor_id
            )))
        }
    }

    /// Returns a layout recommendation for `tensor_id`, or `None` when
    /// there is too little data (< 100 accesses) or the expected
    /// improvement is below the configured threshold.
    ///
    /// # Errors
    /// Returns `InvalidArgument` if `tensor_id` was never registered.
    pub fn recommend_layout(&self, tensor_id: usize) -> Result<Option<LayoutRecommendation>> {
        let tracker = self.trackers.get(&tensor_id).ok_or_else(|| {
            TorshError::InvalidArgument(format!("Tensor {} not registered", tensor_id))
        })?;
        let stats = tracker.statistics();
        // Not enough samples for the pattern statistics to be meaningful.
        if stats.total_accesses < 100 {
            return Ok(None);
        }
        let recommendation = self.analyze_and_recommend(tracker)?;
        if recommendation.expected_improvement >= self.optimization_threshold {
            Ok(Some(recommendation))
        } else {
            Ok(None)
        }
    }

    /// Builds a "keep the current layout" recommendation with zero cost.
    fn keep_current(current_format: MemoryFormat, reason: &str) -> LayoutRecommendation {
        LayoutRecommendation {
            current_format,
            recommended_format: current_format,
            expected_improvement: 0.0,
            reason: reason.to_string(),
            transformation_cost: TransformationCost {
                memory_copies: 0,
                estimated_time_us: 0.0,
                memory_overhead_bytes: 0,
            },
        }
    }

    /// Maps the tracker's dominant pattern to a concrete recommendation.
    fn analyze_and_recommend(&self, tracker: &AccessTracker) -> Result<LayoutRecommendation> {
        let stats = tracker.statistics();
        let current_format = tracker.memory_format;
        let cache_hit_rate = tracker.cache_hit_rate();
        match stats.dominant_pattern {
            AccessPattern::Sequential | AccessPattern::RowMajor => {
                if current_format != MemoryFormat::Contiguous {
                    Ok(LayoutRecommendation {
                        current_format,
                        recommended_format: MemoryFormat::Contiguous,
                        expected_improvement: 0.3,
                        reason: "Sequential/row-major access pattern detected. Contiguous layout will improve cache locality.".to_string(),
                        transformation_cost: self.estimate_cost(&tracker.shape),
                    })
                } else {
                    Ok(Self::keep_current(
                        current_format,
                        "Already using optimal layout",
                    ))
                }
            }
            AccessPattern::ColumnMajor => {
                if current_format != MemoryFormat::ChannelsLast {
                    Ok(LayoutRecommendation {
                        current_format,
                        recommended_format: MemoryFormat::ChannelsLast,
                        expected_improvement: 0.25,
                        reason: "Column-major access detected. ChannelsLast layout will improve stride patterns.".to_string(),
                        transformation_cost: self.estimate_cost(&tracker.shape),
                    })
                } else {
                    Ok(Self::keep_current(
                        current_format,
                        "Already using optimal layout",
                    ))
                }
            }
            AccessPattern::Strided { stride } => {
                // NOTE(review): this arm recommends Contiguous even when the
                // tensor is already Contiguous, and the reason text says
                // "low cache hit rate" even on the high-hit-rate branch —
                // behavior preserved; confirm intent with the author.
                let improvement = if cache_hit_rate < 0.5 { 0.4 } else { 0.15 };
                Ok(LayoutRecommendation {
                    current_format,
                    recommended_format: MemoryFormat::Contiguous,
                    expected_improvement: improvement,
                    reason: format!(
                        "Strided access (stride={}) with low cache hit rate ({}%). Contiguous layout recommended.",
                        stride,
                        (cache_hit_rate * 100.0) as u32
                    ),
                    transformation_cost: self.estimate_cost(&tracker.shape),
                })
            }
            AccessPattern::BlockWise { block_size } => {
                if self.aggressive {
                    Ok(LayoutRecommendation {
                        current_format,
                        recommended_format: MemoryFormat::Contiguous,
                        expected_improvement: 0.2,
                        reason: format!(
                            "Block-wise access (block_size={}) detected. Consider cache-friendly blocking.",
                            block_size
                        ),
                        transformation_cost: self.estimate_cost(&tracker.shape),
                    })
                } else {
                    Ok(Self::keep_current(
                        current_format,
                        "Block-wise access requires specialized optimization",
                    ))
                }
            }
            AccessPattern::Random => Ok(Self::keep_current(
                current_format,
                "Random access pattern - layout optimization unlikely to help",
            )),
            AccessPattern::Broadcast => Ok(Self::keep_current(
                current_format,
                "Broadcast-like access - current layout is fine",
            )),
            AccessPattern::Diagonal => Ok(Self::keep_current(
                current_format,
                "Diagonal access - specialized algorithm recommended",
            )),
        }
    }

    /// Estimates the cost of materializing a layout transformation.
    ///
    /// Assumes 4-byte (f32) elements and a memory bandwidth of roughly
    /// 10 GB/s (i.e. 10,000 bytes per microsecond).
    fn estimate_cost(&self, shape: &Shape) -> TransformationCost {
        let numel = shape.numel();
        let element_size = 4; // assumes f32 — TODO confirm against dtype
        let total_bytes = numel * element_size;
        // bytes / (bytes per µs). The previous formula multiplied by a
        // stray 1e6, reporting ~100 µs PER BYTE (≈400 seconds for a
        // 4 MB copy) — clearly a unit-conversion bug.
        let copy_time_us = total_bytes as f64 / 10_000.0;
        TransformationCost {
            memory_copies: 1,
            estimated_time_us: copy_time_us,
            memory_overhead_bytes: total_bytes,
        }
    }

    /// Returns the ids of all currently tracked tensors (unordered).
    pub fn tracked_tensors(&self) -> Vec<usize> {
        self.trackers.keys().copied().collect()
    }

    /// Returns a snapshot of the statistics for `tensor_id`.
    ///
    /// # Errors
    /// Returns `InvalidArgument` if `tensor_id` was never registered.
    pub fn get_statistics(&self, tensor_id: usize) -> Result<AccessStatistics> {
        let tracker = self.trackers.get(&tensor_id).ok_or_else(|| {
            TorshError::InvalidArgument(format!("Tensor {} not registered", tensor_id))
        })?;
        Ok(tracker.statistics().clone())
    }

    /// Stops tracking `tensor_id`; a no-op if it was never registered.
    pub fn clear_tensor(&mut self, tensor_id: usize) {
        self.trackers.remove(&tensor_id);
    }

    /// Stops tracking all tensors.
    pub fn clear_all(&mut self) {
        self.trackers.clear();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh tracker starts with zero recorded accesses.
    #[test]
    fn test_access_tracker_creation() {
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");
        let tracker = AccessTracker::new(shape, MemoryFormat::Contiguous);
        assert_eq!(tracker.statistics().total_accesses, 0);
    }

    /// Consecutive indices (unit stride) should classify mostly as cache
    /// hits under the 64-byte-line / 4-byte-element heuristic.
    #[test]
    fn test_sequential_access_detection() {
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");
        let mut tracker = AccessTracker::new(shape, MemoryFormat::Contiguous);
        for i in 0..1000 {
            tracker.record_access(i);
        }
        let stats = tracker.statistics();
        assert!(stats.total_accesses == 1000);
        assert!(stats.cache_hits > stats.cache_misses);
    }

    /// Accesses at a constant stride of 10 should yield an average stride
    /// near 10 (100 accesses triggers one analysis round).
    #[test]
    fn test_strided_access_detection() {
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");
        let mut tracker = AccessTracker::new(shape, MemoryFormat::Contiguous);
        for i in 0..100 {
            tracker.record_access(i * 10);
        }
        let stats = tracker.statistics();
        assert!(stats.total_accesses == 100);
        assert!(stats.average_stride > 8.0);
    }

    /// Scattered indices are recorded without error; only the access
    /// count is asserted (pattern analysis needs >= 100 accesses).
    #[test]
    fn test_random_access_detection() {
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");
        let mut tracker = AccessTracker::new(shape, MemoryFormat::Contiguous);
        let indices = [42, 1000, 5, 9999, 50, 7500, 200];
        for &idx in &indices {
            tracker.record_access(idx);
        }
        let stats = tracker.statistics();
        assert!(stats.total_accesses == indices.len() as u64);
    }

    /// Sequential access should produce a high cache hit rate.
    #[test]
    fn test_cache_hit_rate() {
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");
        let mut tracker = AccessTracker::new(shape, MemoryFormat::Contiguous);
        for i in 0..100 {
            tracker.record_access(i);
        }
        let hit_rate = tracker.cache_hit_rate();
        assert!(hit_rate > 0.5);
    }

    /// A new optimizer tracks no tensors.
    #[test]
    fn test_layout_optimizer_creation() {
        let optimizer = LayoutOptimizer::new();
        assert!(optimizer.tracked_tensors().is_empty());
    }

    /// Registering a tensor makes its id visible via tracked_tensors().
    #[test]
    fn test_register_and_track_tensor() {
        let mut optimizer = LayoutOptimizer::new();
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");
        optimizer.register_tensor(1, shape, MemoryFormat::Contiguous);
        assert_eq!(optimizer.tracked_tensors().len(), 1);
        assert!(optimizer.tracked_tensors().contains(&1));
    }

    /// Accesses recorded through the optimizer reach the tracker's stats.
    #[test]
    fn test_record_access() {
        let mut optimizer = LayoutOptimizer::new();
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");
        optimizer.register_tensor(1, shape, MemoryFormat::Contiguous);
        for i in 0..50 {
            optimizer
                .record_access(1, i)
                .expect("record_access should succeed");
        }
        let stats = optimizer
            .get_statistics(1)
            .expect("get_statistics should succeed");
        assert_eq!(stats.total_accesses, 50);
    }

    /// Sequential access on a Strided tensor should recommend Contiguous
    /// once past the 100-access minimum and the (lowered) threshold.
    #[test]
    fn test_optimization_recommendation() {
        let mut optimizer = LayoutOptimizer::new().with_threshold(0.05);
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");
        optimizer.register_tensor(1, shape, MemoryFormat::Strided);
        for i in 0..200 {
            optimizer
                .record_access(1, i)
                .expect("record_access should succeed");
        }
        let recommendation = optimizer
            .recommend_layout(1)
            .expect("recommend_layout should succeed");
        assert!(recommendation.is_some());
        if let Some(rec) = recommendation {
            assert_eq!(rec.recommended_format, MemoryFormat::Contiguous);
            assert!(rec.expected_improvement > 0.0);
        }
    }

    /// Fewer than 100 accesses yields no recommendation at all.
    #[test]
    fn test_insufficient_data_no_recommendation() {
        let mut optimizer = LayoutOptimizer::new();
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");
        optimizer.register_tensor(1, shape, MemoryFormat::Contiguous);
        for i in 0..10 {
            optimizer
                .record_access(1, i)
                .expect("record_access should succeed");
        }
        let recommendation = optimizer
            .recommend_layout(1)
            .expect("recommend_layout should succeed");
        assert!(recommendation.is_none());
    }

    /// clear_tensor removes a tensor from tracking.
    #[test]
    fn test_clear_tensor() {
        let mut optimizer = LayoutOptimizer::new();
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");
        optimizer.register_tensor(1, shape, MemoryFormat::Contiguous);
        assert_eq!(optimizer.tracked_tensors().len(), 1);
        optimizer.clear_tensor(1);
        assert!(optimizer.tracked_tensors().is_empty());
    }

    /// The aggressive builder flag is stored as given.
    #[test]
    fn test_aggressive_optimization() {
        let optimizer = LayoutOptimizer::new().aggressive(true);
        assert!(optimizer.aggressive);
    }

    /// Cost estimation reports nonzero copies, time, and memory overhead.
    #[test]
    fn test_transformation_cost_estimation() {
        let optimizer = LayoutOptimizer::new();
        let shape = Shape::from_array([1000, 1000]).expect("shape creation should succeed");
        let cost = optimizer.estimate_cost(&shape);
        assert!(cost.memory_copies > 0);
        assert!(cost.estimated_time_us > 0.0);
        assert!(cost.memory_overhead_bytes > 0);
    }

    /// The builder-style cache line override is stored as given.
    #[test]
    fn test_custom_cache_line_size() {
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");
        let tracker = AccessTracker::new(shape, MemoryFormat::Contiguous).with_cache_line_size(128);
        assert_eq!(tracker.cache_line_size, 128);
    }
}