/// How an array's elements are arranged in memory relative to its shape,
/// as classified by `detect_layout`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum MemoryLayout {
    /// Densely packed with the last axis varying fastest (C / row-major order).
    #[default]
    CContiguous,
    /// Densely packed with the first axis varying fastest (Fortran / column-major order).
    FContiguous,
    /// Neither of the two dense orderings.
    Strided,
    /// An empty shape, or a shape holding at most one element.
    Scalar,
}

impl MemoryLayout {
    /// Returns `true` for the two densely packed layouts
    /// (`CContiguous` and `FContiguous`); `Scalar` is not counted.
    pub fn is_contiguous(&self) -> bool {
        match self {
            Self::CContiguous | Self::FContiguous => true,
            Self::Strided | Self::Scalar => false,
        }
    }

    /// Returns `true` when the layout is `CContiguous` or `Scalar`.
    pub fn is_c_optimal(&self) -> bool {
        match self {
            Self::CContiguous | Self::Scalar => true,
            Self::FContiguous | Self::Strided => false,
        }
    }

    /// Returns `true` when the layout is `FContiguous` or `Scalar`.
    pub fn is_f_optimal(&self) -> bool {
        match self {
            Self::FContiguous | Self::Scalar => true,
            Self::CContiguous | Self::Strided => false,
        }
    }
}
/// Checks whether a sequence of `(extent, stride)` pairs, visited from the
/// fastest-varying axis outwards, describes a densely packed array.
///
/// Axes of extent 1 are skipped: their index is always 0, so their stride
/// never contributes to an element's address. The running expected stride is
/// tracked with checked multiplication, so an overflowing extent product
/// makes any further (non-unit) axis fail the check instead of panicking.
fn strides_are_packed(axes: impl Iterator<Item = (usize, usize)>) -> bool {
    let mut expected = Some(1usize);
    for (extent, stride) in axes {
        if extent == 1 {
            continue;
        }
        match expected {
            Some(want) if want == stride => expected = want.checked_mul(extent),
            _ => return false,
        }
    }
    true
}

/// Classifies the memory layout described by `shape` and element `strides`.
///
/// * Returns [`MemoryLayout::Scalar`] when the shape is empty or holds at most
///   one element (any extent is 0, or every extent is 1).
/// * Returns [`MemoryLayout::CContiguous`] when the strides are those of a
///   dense array whose last axis varies fastest. This is checked first, so a
///   dense one-dimensional array is reported as C-contiguous.
/// * Returns [`MemoryLayout::FContiguous`] when the strides are those of a
///   dense array whose first axis varies fastest.
/// * Returns [`MemoryLayout::Strided`] otherwise, including when fewer strides
///   than axes are supplied (contiguity cannot be established).
///
/// Strides of extent-1 axes are ignored, and surplus trailing strides beyond
/// `shape.len()` are ignored. This function never panics.
pub fn detect_layout(shape: &[usize], strides: &[usize]) -> MemoryLayout {
    // Overflow-free equivalent of "element count <= 1": the count is 0 iff some
    // extent is 0, and 1 iff every extent is 1 (which also covers an empty shape).
    if shape.contains(&0) || shape.iter().all(|&extent| extent == 1) {
        return MemoryLayout::Scalar;
    }

    if strides.len() < shape.len() {
        return MemoryLayout::Strided;
    }
    let strides = &strides[..shape.len()];

    let axes = shape.iter().copied().zip(strides.iter().copied());
    if strides_are_packed(axes.clone().rev()) {
        MemoryLayout::CContiguous
    } else if strides_are_packed(axes) {
        MemoryLayout::FContiguous
    } else {
        MemoryLayout::Strided
    }
}
/// Identifies a level of the CPU cache hierarchy; stored as the `level` tag
/// of a `CacheConfig`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheLevel {
/// First-level cache (tag used by `CacheConfig::l1_default`).
L1,
/// Second-level cache (tag used by `CacheConfig::l2_default`).
L2,
/// Third-level cache (tag used by `CacheConfig::l3_default`).
L3,
}
/// Size and geometry of one cache level, used to derive block and tile sizes.
#[derive(Debug, Clone, Copy)]
pub struct CacheConfig {
    /// Which cache level this configuration describes.
    pub level: CacheLevel,
    /// Total capacity in bytes.
    pub size_bytes: usize,
    /// Cache-line size in bytes.
    pub line_size: usize,
    /// Associativity (number of ways); stored but not used by the sizing helpers below.
    pub associativity: usize,
}

impl CacheConfig {
    /// Builds a configuration from its four components, stored verbatim.
    pub fn new(
        level: CacheLevel,
        size_bytes: usize,
        line_size: usize,
        associativity: usize,
    ) -> Self {
        CacheConfig {
            level,
            size_bytes,
            line_size,
            associativity,
        }
    }

    /// Default L1 configuration: 32 KiB, 64-byte lines, 8-way.
    pub fn l1_default() -> Self {
        CacheConfig {
            level: CacheLevel::L1,
            size_bytes: 32 * 1024,
            line_size: 64,
            associativity: 8,
        }
    }

    /// Default L2 configuration: 256 KiB, 64-byte lines, 8-way.
    pub fn l2_default() -> Self {
        CacheConfig {
            level: CacheLevel::L2,
            size_bytes: 256 * 1024,
            line_size: 64,
            associativity: 8,
        }
    }

    /// Default L3 configuration: 8 MiB, 64-byte lines, 16-way.
    pub fn l3_default() -> Self {
        CacheConfig {
            level: CacheLevel::L3,
            size_bytes: 8 * 1024 * 1024,
            line_size: 64,
            associativity: 16,
        }
    }

    /// Number of whole `T` values that fit in one cache line.
    ///
    /// Returns 0 for zero-sized types.
    pub fn elements_per_line<T>(&self) -> usize {
        match std::mem::size_of::<T>() {
            0 => 0,
            width => self.line_size / width,
        }
    }

    /// Number of whole `T` values that fit in three quarters of the cache.
    ///
    /// Returns 0 for zero-sized types.
    pub fn elements_per_block<T>(&self) -> usize {
        match std::mem::size_of::<T>() {
            0 => 0,
            width => self.size_bytes * 3 / 4 / width,
        }
    }

    /// Side lengths `(rows, cols)` of a square tile of `T` values.
    ///
    /// The side is the integer square root of `elements_per_block`, at least
    /// 1, rounded *up* to a whole number of cache lines' worth of elements.
    /// Because of the upward rounding the tile can hold slightly more than
    /// `elements_per_block` elements.
    pub fn tile_size_2d<T>(&self) -> (usize, usize) {
        let line_elems = self.elements_per_line::<T>().max(1);
        let raw_side = ((self.elements_per_block::<T>() as f64).sqrt() as usize).max(1);
        let side = raw_side.div_ceil(line_elems) * line_elems;
        (side, side)
    }
}

impl Default for CacheConfig {
    /// The default configuration is the L2 default.
    fn default() -> Self {
        CacheConfig::l2_default()
    }
}
/// A half-open index range `start..end` produced by blocked iteration.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Block {
    /// First index covered by the block.
    pub start: usize,
    /// One past the last index covered by the block.
    pub end: usize,
}

impl Block {
    /// Creates a block covering `start..end`; the bounds are not validated.
    pub fn new(start: usize, end: usize) -> Self {
        Block { start, end }
    }

    /// Number of indices in the block; 0 when `end <= start`.
    pub fn len(&self) -> usize {
        // A `Range<usize>` with `start >= end` reports a length of 0.
        self.iter().len()
    }

    /// Returns `true` when the block covers no indices.
    pub fn is_empty(&self) -> bool {
        self.start >= self.end
    }

    /// The block as a standard range, suitable for slicing or iterating.
    pub fn iter(&self) -> std::ops::Range<usize> {
        std::ops::Range {
            start: self.start,
            end: self.end,
        }
    }
}
/// Iterator that splits `0..total` into consecutive [`Block`]s of at most
/// `block_size` indices each.
pub struct BlockedIterator {
    total: usize,
    block_size: usize,
    current: usize,
}

impl BlockedIterator {
    /// Creates an iterator over `0..total` in blocks of `block_size`.
    ///
    /// A `block_size` of 0 is treated as 1.
    pub fn new(total: usize, block_size: usize) -> Self {
        let block_size = if block_size == 0 { 1 } else { block_size };
        BlockedIterator {
            total,
            block_size,
            current: 0,
        }
    }

    /// Creates an iterator whose block size is `cache.elements_per_block::<T>()`.
    pub fn for_type<T>(total: usize, cache: CacheConfig) -> Self {
        let per_block = cache.elements_per_block::<T>();
        Self::new(total, per_block)
    }
}

impl Iterator for BlockedIterator {
    type Item = Block;

    /// Yields the next block, or `None` once `total` has been reached.
    /// The final block is shortened to end exactly at `total`.
    fn next(&mut self) -> Option<Self::Item> {
        let start = self.current;
        if start >= self.total {
            return None;
        }
        self.current = self.total.min(start + self.block_size);
        Some(Block::new(start, self.current))
    }

    /// Exact number of blocks still to be yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let blocks_left = self
            .total
            .saturating_sub(self.current)
            .div_ceil(self.block_size);
        (blocks_left, Some(blocks_left))
    }
}

impl ExactSizeIterator for BlockedIterator {}
/// A rectangular tile covering rows `row_start..row_end` and columns
/// `col_start..col_end` (both half-open).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Tile2D {
    /// First row of the tile.
    pub row_start: usize,
    /// One past the last row of the tile.
    pub row_end: usize,
    /// First column of the tile.
    pub col_start: usize,
    /// One past the last column of the tile.
    pub col_end: usize,
}

impl Tile2D {
    /// Creates a tile from its bounds; the bounds are not validated.
    pub fn new(row_start: usize, row_end: usize, col_start: usize, col_end: usize) -> Self {
        Tile2D {
            row_start,
            row_end,
            col_start,
            col_end,
        }
    }

    /// Length of the half-open span `lo..hi`, or 0 when `hi <= lo`.
    fn extent(lo: usize, hi: usize) -> usize {
        hi.saturating_sub(lo)
    }

    /// Number of rows in the tile (0 when `row_end <= row_start`).
    pub fn rows(&self) -> usize {
        Self::extent(self.row_start, self.row_end)
    }

    /// Number of columns in the tile (0 when `col_end <= col_start`).
    pub fn cols(&self) -> usize {
        Self::extent(self.col_start, self.col_end)
    }

    /// Total number of elements in the tile (`rows * cols`).
    pub fn len(&self) -> usize {
        let (height, width) = (self.rows(), self.cols());
        height * width
    }

    /// Returns `true` when the tile contains no elements.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Iterator that covers a `total_rows x total_cols` area with [`Tile2D`]s of
/// at most `tile_rows x tile_cols`, visiting all tiles of one tile-row (left
/// to right) before moving down to the next tile-row.
pub struct TiledIterator2D {
    total_rows: usize,
    total_cols: usize,
    tile_rows: usize,
    tile_cols: usize,
    current_row: usize,
    current_col: usize,
}

impl TiledIterator2D {
    /// Creates a tiled iterator over the given area.
    ///
    /// Tile dimensions of 0 are treated as 1.
    pub fn new(total_rows: usize, total_cols: usize, tile_rows: usize, tile_cols: usize) -> Self {
        Self {
            total_rows,
            total_cols,
            tile_rows: tile_rows.max(1),
            tile_cols: tile_cols.max(1),
            current_row: 0,
            current_col: 0,
        }
    }

    /// Creates a tiled iterator whose tile dimensions are
    /// `cache.tile_size_2d::<T>()`.
    pub fn for_type<T>(total_rows: usize, total_cols: usize, cache: CacheConfig) -> Self {
        let (tile_rows, tile_cols) = cache.tile_size_2d::<T>();
        Self::new(total_rows, total_cols, tile_rows, tile_cols)
    }

    /// `true` once no further tile will be produced. An area with zero rows
    /// or zero columns is exhausted from the start.
    fn is_exhausted(&self) -> bool {
        self.total_cols == 0 || self.current_row >= self.total_rows
    }
}

impl Iterator for TiledIterator2D {
    type Item = Tile2D;

    /// Yields the next tile, clipped to the area bounds, or `None` when the
    /// area has been fully covered.
    ///
    /// An area with no columns yields nothing; this keeps `next` consistent
    /// with `size_hint`, as `ExactSizeIterator` requires.
    fn next(&mut self) -> Option<Self::Item> {
        if self.is_exhausted() {
            return None;
        }

        let row_start = self.current_row;
        let row_end = row_start.saturating_add(self.tile_rows).min(self.total_rows);
        let col_start = self.current_col;
        let col_end = col_start.saturating_add(self.tile_cols).min(self.total_cols);

        if col_end >= self.total_cols {
            // End of this tile-row: wrap to the first column of the next one.
            // The row cursor advances by a full tile (not the clipped height)
            // so it stays a multiple of `tile_rows`; saturation simply ends
            // the iteration.
            self.current_col = 0;
            self.current_row = row_start.saturating_add(self.tile_rows);
        } else {
            // Not clipped, so `col_end == col_start + tile_cols`.
            self.current_col = col_end;
        }

        Some(Tile2D::new(row_start, row_end, col_start, col_end))
    }

    /// Exact number of tiles still to be yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.is_exhausted() {
            return (0, Some(0));
        }
        let col_tiles = self.total_cols.div_ceil(self.tile_cols);
        let tile_rows_left = (self.total_rows - self.current_row).div_ceil(self.tile_rows);
        // Tiles already produced in the current tile-row.
        let done_in_row = self.current_col / self.tile_cols;
        let remaining = tile_rows_left
            .saturating_mul(col_tiles)
            .saturating_sub(done_in_row);
        (remaining, Some(remaining))
    }
}

impl ExactSizeIterator for TiledIterator2D {}
/// The order in which memory is expected to be traversed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccessPattern {
    /// Consecutive elements in increasing order.
    Sequential,
    /// Consecutive elements in decreasing order.
    Reverse,
    /// No predictable order.
    Random,
    /// Elements separated by a fixed stride (the payload).
    Strided(usize),
    /// Traversal in blocks.
    Blocked,
}

impl AccessPattern {
    /// Suggested prefetch distance for this pattern.
    ///
    /// `Sequential` gives 4; `Reverse`, `Blocked` and strides of at most 8
    /// give 2; larger strides give 1; `Random` gives 0.
    pub fn prefetch_distance(&self) -> usize {
        match *self {
            AccessPattern::Random => 0,
            AccessPattern::Sequential => 4,
            AccessPattern::Strided(step) if step > 8 => 1,
            AccessPattern::Strided(_) | AccessPattern::Reverse | AccessPattern::Blocked => 2,
        }
    }

    /// Returns `true` for every pattern except `Random`.
    pub fn benefits_from_prefetch(&self) -> bool {
        *self != AccessPattern::Random
    }
}
/// Tuning hints derived from an array's shape, strides and element type.
#[derive(Debug, Clone)]
pub struct OptimizationHints {
    /// Detected memory layout.
    pub layout: MemoryLayout,
    /// Expected traversal pattern.
    pub access_pattern: AccessPattern,
    /// Suggested number of elements per block for 1-D blocking.
    pub block_size: usize,
    /// Suggested `(rows, cols)` tile size, when 2-D tiling applies.
    pub tile_size: Option<(usize, usize)>,
    /// Whether the element count exceeds the parallelism threshold.
    pub use_parallel: bool,
    /// Rough cache-efficiency estimate for the layout (fixed per layout kind).
    pub cache_efficiency: f64,
}

impl OptimizationHints {
    /// Element count above which `use_parallel` is set.
    const PARALLEL_THRESHOLD: usize = 10_000;

    /// Derives hints for an array of `T` with the given `shape` and `strides`.
    ///
    /// * The cache level used for `block_size` / `tile_size` is the smallest
    ///   default configuration (L1, then L2) whose capacity holds the whole
    ///   array, falling back to L3.
    /// * `tile_size` is only provided for shapes with at least two axes.
    /// * `access_pattern` is `Sequential` for contiguous layouts, otherwise
    ///   `Strided` with the smallest stride, or `Random` when no strides are
    ///   given.
    ///
    /// Element and byte counts saturate at `usize::MAX` rather than
    /// overflowing, so oversized shapes are treated as "very large" instead
    /// of panicking (debug) or wrapping to a small value (release).
    pub fn analyze<T>(shape: &[usize], strides: &[usize]) -> Self {
        let layout = detect_layout(shape, strides);

        let total_elements = shape
            .iter()
            .fold(1usize, |count, &extent| count.saturating_mul(extent));
        let total_bytes = total_elements.saturating_mul(std::mem::size_of::<T>());

        // Thresholds come from the default configurations themselves so they
        // cannot drift apart from `CacheConfig`.
        let l1 = CacheConfig::l1_default();
        let l2 = CacheConfig::l2_default();
        let cache = if total_bytes <= l1.size_bytes {
            l1
        } else if total_bytes <= l2.size_bytes {
            l2
        } else {
            CacheConfig::l3_default()
        };

        let tile_size = if shape.len() >= 2 {
            Some(cache.tile_size_2d::<T>())
        } else {
            None
        };

        let cache_efficiency = match layout {
            MemoryLayout::CContiguous | MemoryLayout::FContiguous => 0.95,
            MemoryLayout::Strided => 0.5,
            MemoryLayout::Scalar => 1.0,
        };

        let access_pattern = if layout.is_contiguous() {
            AccessPattern::Sequential
        } else {
            // No strides at all -> nothing to base a stride estimate on.
            strides
                .iter()
                .min()
                .map_or(AccessPattern::Random, |&stride| {
                    AccessPattern::Strided(stride)
                })
        };

        Self {
            layout,
            access_pattern,
            block_size: cache.elements_per_block::<T>(),
            tile_size,
            use_parallel: total_elements > Self::PARALLEL_THRESHOLD,
            cache_efficiency,
        }
    }

    /// Hints for a C-contiguous, sequentially accessed array of
    /// `total_elements` values of `T`, sized against the default L2 cache.
    pub fn default_for<T>(total_elements: usize) -> Self {
        let cache = CacheConfig::l2_default();
        Self {
            layout: MemoryLayout::CContiguous,
            access_pattern: AccessPattern::Sequential,
            block_size: cache.elements_per_block::<T>(),
            tile_size: Some(cache.tile_size_2d::<T>()),
            use_parallel: total_elements > Self::PARALLEL_THRESHOLD,
            cache_efficiency: 0.95,
        }
    }
}

impl Default for OptimizationHints {
    /// Type-agnostic defaults: C-contiguous, sequential, 4096-element blocks,
    /// 64x64 tiles, no parallelism.
    fn default() -> Self {
        Self {
            layout: MemoryLayout::CContiguous,
            access_pattern: AccessPattern::Sequential,
            block_size: 4096,
            tile_size: Some((64, 64)),
            use_parallel: false,
            cache_efficiency: 0.95,
        }
    }
}
pub struct StrideOptimizer {
strides: Vec<usize>,
shape: Vec<usize>,
iteration_order: Vec<usize>,
}
impl StrideOptimizer {
pub fn new(shape: &[usize], strides: &[usize]) -> Self {
let mut iteration_order: Vec<usize> = (0..shape.len()).collect();
iteration_order.sort_by_key(|&i| strides.get(i).copied().unwrap_or(0));
Self {
strides: strides.to_vec(),
shape: shape.to_vec(),
iteration_order,
}
}
pub fn optimal_iteration_order(&self) -> &[usize] {
&self.iteration_order
}
pub fn should_copy(&self) -> bool {
let layout = detect_layout(&self.shape, &self.strides);
if layout.is_contiguous() {
return false;
}
let min_stride = self.strides.iter().min().copied().unwrap_or(1);
min_stride > 4
}
pub fn bandwidth_efficiency(&self) -> f64 {
if self.strides.is_empty() {
return 1.0;
}
let min_stride = self.strides.iter().min().copied().unwrap_or(1) as f64;
(1.0 / min_stride).min(1.0)
}
}
/// Copies the common prefix of `src` into `dst`, one L1-sized block at a time.
///
/// Only the first `min(src.len(), dst.len())` elements are copied; any
/// remaining elements of `dst` are left untouched.
pub fn cache_aware_copy<T: Copy>(src: &[T], dst: &mut [T]) {
    let count = src.len().min(dst.len());
    if count == 0 {
        return;
    }
    for block in BlockedIterator::for_type::<T>(count, CacheConfig::l1_default()) {
        let span = block.iter();
        dst[span.clone()].copy_from_slice(&src[span]);
    }
}
/// Writes `f(src[i])` into `dst[i]` for the common prefix of both slices,
/// processing the elements in L1-sized blocks.
///
/// Only the first `min(src.len(), dst.len())` elements are processed, in
/// increasing index order. The block length is three quarters of the default
/// L1 capacity divided by the larger of the two element sizes; when both
/// types are zero-sized the whole range forms a single block.
pub fn cache_aware_transform<T, U, F>(src: &[T], dst: &mut [U], f: F)
where
    T: Copy,
    F: Fn(T) -> U,
{
    let count = src.len().min(dst.len());
    if count == 0 {
        return;
    }

    let widest = std::mem::size_of::<T>().max(std::mem::size_of::<U>());
    let budget = CacheConfig::l1_default().size_bytes * 3 / 4;
    let block_size = if widest == 0 { count } else { budget / widest };

    for block in BlockedIterator::new(count, block_size) {
        let span = block.iter();
        for (out, &value) in dst[span.clone()].iter_mut().zip(&src[span]) {
            *out = f(value);
        }
    }
}
/// Writes `f(a[i], b[i])` into `result[i]` for the common prefix of all three
/// slices, processing the elements in L1-sized blocks.
///
/// Only the first `min(a.len(), b.len(), result.len())` elements are
/// processed, in increasing index order. The block length is three quarters
/// of the default L1 capacity divided by three times the largest of the three
/// element sizes; when all types are zero-sized the whole range forms a
/// single block.
pub fn cache_aware_binary_op<T, U, V, F>(a: &[T], b: &[U], result: &mut [V], f: F)
where
    T: Copy,
    U: Copy,
    F: Fn(T, U) -> V,
{
    let count = a.len().min(b.len()).min(result.len());
    if count == 0 {
        return;
    }

    let widest = std::mem::size_of::<T>()
        .max(std::mem::size_of::<U>())
        .max(std::mem::size_of::<V>());
    let budget = CacheConfig::l1_default().size_bytes * 3 / 4;
    let block_size = match widest {
        0 => count,
        width => budget / (width * 3),
    };

    for block in BlockedIterator::new(count, block_size) {
        let span = block.iter();
        let inputs = a[span.clone()].iter().zip(&b[span.clone()]);
        for (out, (&lhs, &rhs)) in result[span].iter_mut().zip(inputs) {
            *out = f(lhs, rhs);
        }
    }
}
/// Counters for recorded memory accesses, grouped by access kind, plus a
/// derived miss-rate estimate.
#[derive(Debug, Clone, Default)]
pub struct AccessStats {
    /// Number of accesses recorded, of any kind.
    pub total_accesses: u64,
    /// Number of accesses recorded as sequential.
    pub sequential_accesses: u64,
    /// Number of accesses recorded as strided.
    pub strided_accesses: u64,
    /// Number of accesses recorded as random.
    pub random_accesses: u64,
    /// Miss-rate estimate; only refreshed by `update_miss_rate`.
    pub estimated_miss_rate: f64,
}

impl AccessStats {
    /// Creates statistics with every counter and the miss rate at zero.
    pub fn new() -> Self {
        AccessStats {
            total_accesses: 0,
            sequential_accesses: 0,
            strided_accesses: 0,
            random_accesses: 0,
            estimated_miss_rate: 0.0,
        }
    }

    /// Records one sequential access.
    pub fn record_sequential(&mut self) {
        self.total_accesses += 1;
        self.sequential_accesses += 1;
    }

    /// Records one strided access.
    pub fn record_strided(&mut self) {
        self.total_accesses += 1;
        self.strided_accesses += 1;
    }

    /// Records one random access.
    pub fn record_random(&mut self) {
        self.total_accesses += 1;
        self.random_accesses += 1;
    }

    /// Per-kind counters multiplied by the given weights, summed, and divided
    /// by `total_accesses`; `None` when nothing has been recorded.
    fn weighted_mean(&self, sequential: f64, strided: f64, random: f64) -> Option<f64> {
        if self.total_accesses == 0 {
            return None;
        }
        let weighted_sum = (self.sequential_accesses as f64 * sequential)
            + (self.strided_accesses as f64 * strided)
            + (self.random_accesses as f64 * random);
        Some(weighted_sum / self.total_accesses as f64)
    }

    /// Efficiency estimate: sequential accesses weigh 1.0, strided 0.5 and
    /// random 0.1. Returns 1.0 when nothing has been recorded.
    pub fn cache_efficiency(&self) -> f64 {
        self.weighted_mean(1.0, 0.5, 0.1).unwrap_or(1.0)
    }

    /// Recomputes `estimated_miss_rate` assuming miss rates of 0.05 for
    /// sequential, 0.30 for strided and 0.90 for random accesses. Sets it to
    /// 0.0 when nothing has been recorded.
    pub fn update_miss_rate(&mut self) {
        self.estimated_miss_rate = self.weighted_mean(0.05, 0.30, 0.90).unwrap_or(0.0);
    }
}
// Unit tests for the layout detection, cache sizing, blocked/tiled iteration,
// optimization hints and cache-aware helper functions defined above.
#[cfg(test)]
mod tests {
use super::*;
// Shape [3, 4] with strides [4, 1] is classified as C-contiguous.
#[test]
fn test_detect_layout_c_contiguous() {
let layout = detect_layout(&[3, 4], &[4, 1]);
assert_eq!(layout, MemoryLayout::CContiguous);
}
// Shape [3, 4] with strides [1, 3] is classified as F-contiguous.
#[test]
fn test_detect_layout_f_contiguous() {
let layout = detect_layout(&[3, 4], &[1, 3]);
assert_eq!(layout, MemoryLayout::FContiguous);
}
// Strides matching neither dense ordering are classified as strided.
#[test]
fn test_detect_layout_strided() {
let layout = detect_layout(&[3, 4], &[8, 2]);
assert_eq!(layout, MemoryLayout::Strided);
}
// An empty shape and a single-element shape are both classified as scalar.
#[test]
fn test_detect_layout_scalar() {
let layout = detect_layout(&[], &[]);
assert_eq!(layout, MemoryLayout::Scalar);
let layout = detect_layout(&[1], &[1]);
assert_eq!(layout, MemoryLayout::Scalar);
}
// A 64-byte line holds 8 f64 values or 16 f32 values.
#[test]
fn test_cache_config_elements() {
let cache = CacheConfig::l1_default();
assert_eq!(cache.elements_per_line::<f64>(), 8);
assert_eq!(cache.elements_per_line::<f32>(), 16);
}
// 100 elements in blocks of 30: three full blocks and a shortened final one.
#[test]
fn test_blocked_iterator() {
let iter = BlockedIterator::new(100, 30);
let blocks: Vec<_> = iter.collect();
assert_eq!(blocks.len(), 4);
assert_eq!(blocks[0], Block::new(0, 30));
assert_eq!(blocks[1], Block::new(30, 60));
assert_eq!(blocks[2], Block::new(60, 90));
assert_eq!(blocks[3], Block::new(90, 100));
}
// When the block size divides the total, no extra empty block is produced.
#[test]
fn test_blocked_iterator_exact_division() {
let iter = BlockedIterator::new(100, 25);
let blocks: Vec<_> = iter.collect();
assert_eq!(blocks.len(), 4);
assert_eq!(blocks[3], Block::new(75, 100));
}
// A 10x10 area with 4x4 tiles yields 3x3 = 9 tiles; the last is clipped to 2x2.
#[test]
fn test_tiled_iterator_2d() {
let iter = TiledIterator2D::new(10, 10, 4, 4);
let tiles: Vec<_> = iter.collect();
assert_eq!(tiles.len(), 9);
assert_eq!(tiles[0].row_start, 0);
assert_eq!(tiles[0].row_end, 4);
assert_eq!(tiles[0].col_start, 0);
assert_eq!(tiles[0].col_end, 4);
let last = tiles
.last()
.expect("tiles should have at least one element");
assert_eq!(last.row_start, 8);
assert_eq!(last.row_end, 10);
assert_eq!(last.col_start, 8);
assert_eq!(last.col_end, 10);
}
// Block length and emptiness for a non-empty and an empty range.
#[test]
fn test_block_len() {
let block = Block::new(10, 25);
assert_eq!(block.len(), 15);
assert!(!block.is_empty());
let empty = Block::new(10, 10);
assert_eq!(empty.len(), 0);
assert!(empty.is_empty());
}
// Tile dimensions and element count.
#[test]
fn test_tile_2d_len() {
let tile = Tile2D::new(0, 4, 0, 5);
assert_eq!(tile.rows(), 4);
assert_eq!(tile.cols(), 5);
assert_eq!(tile.len(), 20);
}
// A C-contiguous 100x100 f64 array is reported as sequential with high efficiency.
#[test]
fn test_optimization_hints() {
let hints = OptimizationHints::analyze::<f64>(&[100, 100], &[100, 1]);
assert_eq!(hints.layout, MemoryLayout::CContiguous);
assert_eq!(hints.access_pattern, AccessPattern::Sequential);
assert!(hints.cache_efficiency > 0.9);
}
// For row-major strides the smallest-stride axis (1) comes first; no copy is advised.
#[test]
fn test_stride_optimizer() {
let optimizer = StrideOptimizer::new(&[3, 4], &[4, 1]);
let order = optimizer.optimal_iteration_order();
assert_eq!(order[0], 1); assert_eq!(order[1], 0);
assert!(!optimizer.should_copy());
assert!(optimizer.bandwidth_efficiency() > 0.9);
}
// The blocked copy reproduces the source exactly.
#[test]
fn test_cache_aware_copy() {
let src = vec![1.0f64; 1000];
let mut dst = vec![0.0f64; 1000];
cache_aware_copy(&src, &mut dst);
assert_eq!(dst, src);
}
// The blocked transform applies the closure to every element.
#[test]
fn test_cache_aware_transform() {
let src = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let mut dst = vec![0.0; 5];
cache_aware_transform(&src, &mut dst, |x| x * x);
assert_eq!(dst, vec![1.0, 4.0, 9.0, 16.0, 25.0]);
}
// The blocked binary operation combines the inputs element-wise.
#[test]
fn test_cache_aware_binary_op() {
let a = vec![1.0, 2.0, 3.0, 4.0];
let b = vec![10.0, 20.0, 30.0, 40.0];
let mut result = vec![0.0; 4];
cache_aware_binary_op(&a, &b, &mut result, |x, y| x + y);
assert_eq!(result, vec![11.0, 22.0, 33.0, 44.0]);
}
// Counters track each recorded access kind; derived estimates stay within (0, 1).
#[test]
fn test_access_stats() {
let mut stats = AccessStats::new();
stats.record_sequential();
stats.record_sequential();
stats.record_strided();
stats.record_random();
assert_eq!(stats.total_accesses, 4);
assert_eq!(stats.sequential_accesses, 2);
assert_eq!(stats.strided_accesses, 1);
assert_eq!(stats.random_accesses, 1);
assert!(stats.cache_efficiency() > 0.5);
stats.update_miss_rate();
assert!(stats.estimated_miss_rate > 0.0);
assert!(stats.estimated_miss_rate < 1.0);
}
// Sequential access prefetches furthest; random access does not prefetch at all.
#[test]
fn test_access_pattern_prefetch() {
assert_eq!(AccessPattern::Sequential.prefetch_distance(), 4);
assert_eq!(AccessPattern::Random.prefetch_distance(), 0);
assert!(AccessPattern::Sequential.benefits_from_prefetch());
assert!(!AccessPattern::Random.benefits_from_prefetch());
}
}