/// Describes how the elements of an array are arranged in memory.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum MemoryLayout {
    /// Row-major ("C") ordering: the last axis has unit stride. This is the default.
    #[default]
    CContiguous,
    /// Column-major ("Fortran") ordering: the first axis has unit stride.
    FContiguous,
    /// Strides that match neither of the two contiguous orderings.
    Strided,
    /// A zero-dimensional value, or an array holding at most one element.
    Scalar,
}

impl MemoryLayout {
    /// Returns `true` for [`MemoryLayout::CContiguous`] and [`MemoryLayout::FContiguous`].
    ///
    /// Note that [`MemoryLayout::Scalar`] is *not* reported as contiguous.
    pub fn is_contiguous(&self) -> bool {
        match self {
            Self::CContiguous | Self::FContiguous => true,
            Self::Strided | Self::Scalar => false,
        }
    }

    /// Returns `true` when the layout is C-contiguous or scalar.
    pub fn is_c_optimal(&self) -> bool {
        match self {
            Self::CContiguous | Self::Scalar => true,
            Self::FContiguous | Self::Strided => false,
        }
    }

    /// Returns `true` when the layout is Fortran-contiguous or scalar.
    pub fn is_f_optimal(&self) -> bool {
        match self {
            Self::FContiguous | Self::Scalar => true,
            Self::CContiguous | Self::Strided => false,
        }
    }
}
70
71pub fn detect_layout(shape: &[usize], strides: &[usize]) -> MemoryLayout {
96 if shape.is_empty() || shape.iter().product::<usize>() <= 1 {
97 return MemoryLayout::Scalar;
98 }
99
100 let mut expected_c_stride = 1;
102 let mut is_c_contiguous = true;
103 for i in (0..shape.len()).rev() {
104 if strides[i] != expected_c_stride {
105 is_c_contiguous = false;
106 break;
107 }
108 expected_c_stride *= shape[i];
109 }
110
111 if is_c_contiguous {
112 return MemoryLayout::CContiguous;
113 }
114
115 let mut expected_f_stride = 1;
117 let mut is_f_contiguous = true;
118 for i in 0..shape.len() {
119 if strides[i] != expected_f_stride {
120 is_f_contiguous = false;
121 break;
122 }
123 expected_f_stride *= shape[i];
124 }
125
126 if is_f_contiguous {
127 return MemoryLayout::FContiguous;
128 }
129
130 MemoryLayout::Strided
131}
132
/// Identifies a level of the CPU cache hierarchy.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CacheLevel {
    /// Level-1 cache.
    L1,
    /// Level-2 cache.
    L2,
    /// Level-3 cache.
    L3,
}
143
144#[derive(Debug, Clone, Copy)]
151pub struct CacheConfig {
152 pub level: CacheLevel,
154 pub size_bytes: usize,
156 pub line_size: usize,
158 pub associativity: usize,
160}
161
162impl CacheConfig {
163 pub fn new(
165 level: CacheLevel,
166 size_bytes: usize,
167 line_size: usize,
168 associativity: usize,
169 ) -> Self {
170 Self {
171 level,
172 size_bytes,
173 line_size,
174 associativity,
175 }
176 }
177
178 pub fn l1_default() -> Self {
180 Self::new(CacheLevel::L1, 32 * 1024, 64, 8)
181 }
182
183 pub fn l2_default() -> Self {
185 Self::new(CacheLevel::L2, 256 * 1024, 64, 8)
186 }
187
188 pub fn l3_default() -> Self {
190 Self::new(CacheLevel::L3, 8 * 1024 * 1024, 64, 16)
191 }
192
193 pub fn elements_per_line<T>(&self) -> usize {
195 let elem_size = std::mem::size_of::<T>();
196 self.line_size.checked_div(elem_size).unwrap_or(0)
197 }
198
199 pub fn elements_per_block<T>(&self) -> usize {
203 let elem_size = std::mem::size_of::<T>();
204 if elem_size == 0 {
205 return 0;
206 }
207
208 let usable_bytes = (self.size_bytes * 3) / 4;
210 usable_bytes / elem_size
211 }
212
213 pub fn tile_size_2d<T>(&self) -> (usize, usize) {
217 let block_elements = self.elements_per_block::<T>();
218 let tile_dim = (block_elements as f64).sqrt() as usize;
219 let tile_dim = tile_dim.max(1);
220
221 let elements_per_line = self.elements_per_line::<T>().max(1);
223 let aligned_dim = tile_dim.div_ceil(elements_per_line) * elements_per_line;
224
225 (aligned_dim, aligned_dim)
226 }
227}
228
229impl Default for CacheConfig {
230 fn default() -> Self {
231 Self::l2_default()
232 }
233}
234
/// A half-open index range `start..end`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Block {
    /// First index of the block (inclusive).
    pub start: usize,
    /// One past the last index of the block (exclusive).
    pub end: usize,
}

impl Block {
    /// Creates a block covering `start..end`; the bounds are stored as given.
    pub fn new(start: usize, end: usize) -> Self {
        Block { start, end }
    }

    /// Number of indices in the block; `0` when `end` does not exceed `start`.
    pub fn len(&self) -> usize {
        let Block { start, end } = *self;
        end.saturating_sub(start)
    }

    /// Returns `true` when the block covers no indices.
    pub fn is_empty(&self) -> bool {
        self.end <= self.start
    }

    /// Returns the block as a standard range, suitable for iteration or slicing.
    pub fn iter(&self) -> std::ops::Range<usize> {
        std::ops::Range {
            start: self.start,
            end: self.end,
        }
    }
}
265
266pub struct BlockedIterator {
283 total: usize,
284 block_size: usize,
285 current: usize,
286}
287
288impl BlockedIterator {
289 pub fn new(total: usize, block_size: usize) -> Self {
296 Self {
297 total,
298 block_size: block_size.max(1),
299 current: 0,
300 }
301 }
302
303 pub fn for_type<T>(total: usize, cache: CacheConfig) -> Self {
305 Self::new(total, cache.elements_per_block::<T>())
306 }
307}
308
309impl Iterator for BlockedIterator {
310 type Item = Block;
311
312 fn next(&mut self) -> Option<Self::Item> {
313 if self.current >= self.total {
314 return None;
315 }
316
317 let start = self.current;
318 let end = (start + self.block_size).min(self.total);
319 self.current = end;
320
321 Some(Block::new(start, end))
322 }
323
324 fn size_hint(&self) -> (usize, Option<usize>) {
325 let remaining = self.total.saturating_sub(self.current);
326 let count = remaining.div_ceil(self.block_size);
327 (count, Some(count))
328 }
329}
330
331impl ExactSizeIterator for BlockedIterator {}
332
/// A rectangular tile described by half-open row and column ranges.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Tile2D {
    /// First row of the tile (inclusive).
    pub row_start: usize,
    /// One past the last row of the tile (exclusive).
    pub row_end: usize,
    /// First column of the tile (inclusive).
    pub col_start: usize,
    /// One past the last column of the tile (exclusive).
    pub col_end: usize,
}

impl Tile2D {
    /// Creates a tile covering rows `row_start..row_end` and columns `col_start..col_end`.
    pub fn new(row_start: usize, row_end: usize, col_start: usize, col_end: usize) -> Self {
        Tile2D {
            row_start,
            row_end,
            col_start,
            col_end,
        }
    }

    /// Number of rows in the tile; `0` when `row_end` does not exceed `row_start`.
    pub fn rows(&self) -> usize {
        let Tile2D {
            row_start, row_end, ..
        } = *self;
        row_end.saturating_sub(row_start)
    }

    /// Number of columns in the tile; `0` when `col_end` does not exceed `col_start`.
    pub fn cols(&self) -> usize {
        let Tile2D {
            col_start, col_end, ..
        } = *self;
        col_end.saturating_sub(col_start)
    }

    /// Total number of elements covered by the tile (`rows * cols`).
    pub fn len(&self) -> usize {
        let (height, width) = (self.rows(), self.cols());
        height * width
    }

    /// Returns `true` when the tile covers no elements.
    pub fn is_empty(&self) -> bool {
        0 == self.len()
    }
}
377
378pub struct TiledIterator2D {
393 total_rows: usize,
394 total_cols: usize,
395 tile_rows: usize,
396 tile_cols: usize,
397 current_row: usize,
398 current_col: usize,
399}
400
401impl TiledIterator2D {
402 pub fn new(total_rows: usize, total_cols: usize, tile_rows: usize, tile_cols: usize) -> Self {
411 Self {
412 total_rows,
413 total_cols,
414 tile_rows: tile_rows.max(1),
415 tile_cols: tile_cols.max(1),
416 current_row: 0,
417 current_col: 0,
418 }
419 }
420
421 pub fn for_type<T>(total_rows: usize, total_cols: usize, cache: CacheConfig) -> Self {
423 let (tile_rows, tile_cols) = cache.tile_size_2d::<T>();
424 Self::new(total_rows, total_cols, tile_rows, tile_cols)
425 }
426}
427
428impl Iterator for TiledIterator2D {
429 type Item = Tile2D;
430
431 fn next(&mut self) -> Option<Self::Item> {
432 if self.current_row >= self.total_rows {
433 return None;
434 }
435
436 let row_start = self.current_row;
437 let row_end = (row_start + self.tile_rows).min(self.total_rows);
438 let col_start = self.current_col;
439 let col_end = (col_start + self.tile_cols).min(self.total_cols);
440
441 self.current_col += self.tile_cols;
443 if self.current_col >= self.total_cols {
444 self.current_col = 0;
445 self.current_row += self.tile_rows;
446 }
447
448 Some(Tile2D::new(row_start, row_end, col_start, col_end))
449 }
450
451 fn size_hint(&self) -> (usize, Option<usize>) {
452 let row_tiles = self.total_rows.div_ceil(self.tile_rows);
453 let col_tiles = self.total_cols.div_ceil(self.tile_cols);
454 let total_tiles = row_tiles * col_tiles;
455
456 let current_row_tile = self.current_row / self.tile_rows;
457 let current_col_tile = self.current_col / self.tile_cols;
458 let current_tile = current_row_tile * col_tiles + current_col_tile;
459
460 let remaining = total_tiles.saturating_sub(current_tile);
461 (remaining, Some(remaining))
462 }
463}
464
465impl ExactSizeIterator for TiledIterator2D {}
466
/// The order in which memory is expected to be visited.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccessPattern {
    /// Elements are visited in increasing address order.
    Sequential,
    /// Elements are visited in decreasing address order.
    Reverse,
    /// No predictable order.
    Random,
    /// Elements are visited at a fixed stride (the payload).
    Strided(usize),
    /// Elements are visited block by block.
    Blocked,
}

impl AccessPattern {
    /// Suggested prefetch distance for this pattern.
    ///
    /// The unit is not defined by this code (NOTE(review): presumably cache lines or
    /// loop iterations; confirm with callers). `Random` yields `0`, `Sequential` yields
    /// `4`, strides above 8 yield `1`, and every other pattern yields `2`.
    pub fn prefetch_distance(&self) -> usize {
        match *self {
            AccessPattern::Random => 0,
            AccessPattern::Sequential => 4,
            AccessPattern::Strided(step) if step > 8 => 1,
            AccessPattern::Strided(_) | AccessPattern::Reverse | AccessPattern::Blocked => 2,
        }
    }

    /// Returns `true` for every pattern except [`AccessPattern::Random`].
    pub fn benefits_from_prefetch(&self) -> bool {
        *self != AccessPattern::Random
    }
}
507
508#[derive(Debug, Clone)]
510pub struct OptimizationHints {
511 pub layout: MemoryLayout,
513 pub access_pattern: AccessPattern,
515 pub block_size: usize,
517 pub tile_size: Option<(usize, usize)>,
519 pub use_parallel: bool,
521 pub cache_efficiency: f64,
523}
524
525impl OptimizationHints {
526 pub fn analyze<T>(shape: &[usize], strides: &[usize]) -> Self {
533 let layout = detect_layout(shape, strides);
534 let total_elements: usize = shape.iter().product();
535 let total_bytes = total_elements * std::mem::size_of::<T>();
536
537 let cache = if total_bytes <= 32 * 1024 {
539 CacheConfig::l1_default()
540 } else if total_bytes <= 256 * 1024 {
541 CacheConfig::l2_default()
542 } else {
543 CacheConfig::l3_default()
544 };
545
546 let block_size = cache.elements_per_block::<T>();
547
548 let tile_size = if shape.len() >= 2 {
549 Some(cache.tile_size_2d::<T>())
550 } else {
551 None
552 };
553
554 let cache_efficiency = match layout {
556 MemoryLayout::CContiguous | MemoryLayout::FContiguous => 0.95,
557 MemoryLayout::Strided => 0.5,
558 MemoryLayout::Scalar => 1.0,
559 };
560
561 let access_pattern = if layout.is_contiguous() {
563 AccessPattern::Sequential
564 } else if !strides.is_empty() {
565 AccessPattern::Strided(strides.iter().min().copied().unwrap_or(1))
566 } else {
567 AccessPattern::Random
568 };
569
570 let use_parallel = total_elements > 10_000;
572
573 Self {
574 layout,
575 access_pattern,
576 block_size,
577 tile_size,
578 use_parallel,
579 cache_efficiency,
580 }
581 }
582
583 pub fn default_for<T>(total_elements: usize) -> Self {
585 let cache = CacheConfig::l2_default();
586 Self {
587 layout: MemoryLayout::CContiguous,
588 access_pattern: AccessPattern::Sequential,
589 block_size: cache.elements_per_block::<T>(),
590 tile_size: Some(cache.tile_size_2d::<T>()),
591 use_parallel: total_elements > 10_000,
592 cache_efficiency: 0.95,
593 }
594 }
595}
596
597impl Default for OptimizationHints {
598 fn default() -> Self {
599 Self {
600 layout: MemoryLayout::CContiguous,
601 access_pattern: AccessPattern::Sequential,
602 block_size: 4096,
603 tile_size: Some((64, 64)),
604 use_parallel: false,
605 cache_efficiency: 0.95,
606 }
607 }
608}
609
610pub struct StrideOptimizer {
614 strides: Vec<usize>,
616 shape: Vec<usize>,
618 iteration_order: Vec<usize>,
620}
621
622impl StrideOptimizer {
623 pub fn new(shape: &[usize], strides: &[usize]) -> Self {
625 let mut iteration_order: Vec<usize> = (0..shape.len()).collect();
626
627 iteration_order.sort_by_key(|&i| strides.get(i).copied().unwrap_or(0));
630
631 Self {
632 strides: strides.to_vec(),
633 shape: shape.to_vec(),
634 iteration_order,
635 }
636 }
637
638 pub fn optimal_iteration_order(&self) -> &[usize] {
642 &self.iteration_order
643 }
644
645 pub fn should_copy(&self) -> bool {
649 let layout = detect_layout(&self.shape, &self.strides);
650 if layout.is_contiguous() {
651 return false;
652 }
653
654 let min_stride = self.strides.iter().min().copied().unwrap_or(1);
656 min_stride > 4
657 }
658
659 pub fn bandwidth_efficiency(&self) -> f64 {
663 if self.strides.is_empty() {
664 return 1.0;
665 }
666
667 let min_stride = self.strides.iter().min().copied().unwrap_or(1) as f64;
669 (1.0 / min_stride).min(1.0)
670 }
671}
672
673pub fn cache_aware_copy<T: Copy>(src: &[T], dst: &mut [T]) {
687 let len = src.len().min(dst.len());
688 if len == 0 {
689 return;
690 }
691
692 let cache = CacheConfig::l1_default();
694 let block_size = cache.elements_per_block::<T>();
695
696 for block in BlockedIterator::new(len, block_size) {
697 dst[block.start..block.end].copy_from_slice(&src[block.start..block.end]);
698 }
699}
700
701pub fn cache_aware_transform<T, U, F>(src: &[T], dst: &mut [U], f: F)
715where
716 T: Copy,
717 F: Fn(T) -> U,
718{
719 let len = src.len().min(dst.len());
720 if len == 0 {
721 return;
722 }
723
724 let cache = CacheConfig::l1_default();
726 let elem_size = std::mem::size_of::<T>().max(std::mem::size_of::<U>());
727 let block_size = (cache.size_bytes * 3 / 4)
728 .checked_div(elem_size)
729 .unwrap_or(len);
730
731 for block in BlockedIterator::new(len, block_size) {
732 for i in block.start..block.end {
733 dst[i] = f(src[i]);
734 }
735 }
736}
737
738pub fn cache_aware_binary_op<T, U, V, F>(a: &[T], b: &[U], result: &mut [V], f: F)
753where
754 T: Copy,
755 U: Copy,
756 F: Fn(T, U) -> V,
757{
758 let len = a.len().min(b.len()).min(result.len());
759 if len == 0 {
760 return;
761 }
762
763 let cache = CacheConfig::l1_default();
765 let elem_size = std::mem::size_of::<T>()
766 .max(std::mem::size_of::<U>())
767 .max(std::mem::size_of::<V>());
768 let block_size = if elem_size > 0 {
769 (cache.size_bytes * 3 / 4) / (elem_size * 3) } else {
771 len
772 };
773
774 for block in BlockedIterator::new(len, block_size) {
775 for i in block.start..block.end {
776 result[i] = f(a[i], b[i]);
777 }
778 }
779}
780
/// Running counters of memory accesses, grouped by access kind.
#[derive(Debug, Clone, Default)]
pub struct AccessStats {
    /// Total number of recorded accesses of any kind.
    pub total_accesses: u64,
    /// Number of accesses recorded as sequential.
    pub sequential_accesses: u64,
    /// Number of accesses recorded as strided.
    pub strided_accesses: u64,
    /// Number of accesses recorded as random.
    pub random_accesses: u64,
    /// Miss-rate estimate; only refreshed by [`AccessStats::update_miss_rate`].
    pub estimated_miss_rate: f64,
}

impl AccessStats {
    /// Creates an empty set of counters (all zero).
    pub fn new() -> Self {
        AccessStats::default()
    }

    /// Records one sequential access.
    pub fn record_sequential(&mut self) {
        self.total_accesses += 1;
        self.sequential_accesses += 1;
    }

    /// Records one strided access.
    pub fn record_strided(&mut self) {
        self.total_accesses += 1;
        self.strided_accesses += 1;
    }

    /// Records one random access.
    pub fn record_random(&mut self) {
        self.total_accesses += 1;
        self.random_accesses += 1;
    }

    /// Average of the per-kind weights, weighted by how often each kind was recorded.
    ///
    /// Callers must make sure `total_accesses` is non-zero.
    fn weighted_average(&self, sequential: f64, strided: f64, random: f64) -> f64 {
        let sequential_part = self.sequential_accesses as f64 * sequential;
        let strided_part = self.strided_accesses as f64 * strided;
        let random_part = self.random_accesses as f64 * random;

        (sequential_part + strided_part + random_part) / self.total_accesses as f64
    }

    /// Heuristic efficiency score: sequential accesses count as 1.0, strided as 0.5
    /// and random as 0.1. Returns `1.0` when nothing has been recorded.
    pub fn cache_efficiency(&self) -> f64 {
        if self.total_accesses == 0 {
            1.0
        } else {
            self.weighted_average(1.0, 0.5, 0.1)
        }
    }

    /// Recomputes `estimated_miss_rate` from the counters, assuming miss rates of 0.05
    /// (sequential), 0.30 (strided) and 0.90 (random). Sets it to `0.0` when nothing
    /// has been recorded.
    pub fn update_miss_rate(&mut self) {
        self.estimated_miss_rate = if self.total_accesses == 0 {
            0.0
        } else {
            self.weighted_average(0.05, 0.30, 0.90)
        };
    }
}
856
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Layout detection -------------------------------------------------

    /// A row-major 3x4 array is reported as C-contiguous.
    #[test]
    fn test_detect_layout_c_contiguous() {
        assert_eq!(detect_layout(&[3, 4], &[4, 1]), MemoryLayout::CContiguous);
    }

    /// A column-major 3x4 array is reported as Fortran-contiguous.
    #[test]
    fn test_detect_layout_f_contiguous() {
        assert_eq!(detect_layout(&[3, 4], &[1, 3]), MemoryLayout::FContiguous);
    }

    /// Strides matching neither ordering are reported as strided.
    #[test]
    fn test_detect_layout_strided() {
        assert_eq!(detect_layout(&[3, 4], &[8, 2]), MemoryLayout::Strided);
    }

    /// Empty shapes and single-element shapes are reported as scalar.
    #[test]
    fn test_detect_layout_scalar() {
        let zero_dimensional = detect_layout(&[], &[]);
        assert_eq!(zero_dimensional, MemoryLayout::Scalar);

        let single_element = detect_layout(&[1], &[1]);
        assert_eq!(single_element, MemoryLayout::Scalar);
    }

    // ---- Cache configuration ----------------------------------------------

    /// A 64-byte line holds 8 `f64` values or 16 `f32` values.
    #[test]
    fn test_cache_config_elements() {
        let l1 = CacheConfig::l1_default();

        let doubles_per_line = l1.elements_per_line::<f64>();
        assert_eq!(doubles_per_line, 8);

        let floats_per_line = l1.elements_per_line::<f32>();
        assert_eq!(floats_per_line, 16);
    }

    // ---- Blocked iteration ------------------------------------------------

    /// The last block is shortened when the block size does not divide the total.
    #[test]
    fn test_blocked_iterator() {
        let blocks: Vec<Block> = BlockedIterator::new(100, 30).collect();

        assert_eq!(blocks.len(), 4);
        let expected = vec![
            Block::new(0, 30),
            Block::new(30, 60),
            Block::new(60, 90),
            Block::new(90, 100),
        ];
        assert_eq!(blocks, expected);
    }

    /// An exact division produces full-sized blocks only.
    #[test]
    fn test_blocked_iterator_exact_division() {
        let blocks: Vec<Block> = BlockedIterator::new(100, 25).collect();

        assert_eq!(blocks.len(), 4);
        assert_eq!(blocks[3], Block::new(75, 100));
    }

    // ---- Tiled iteration --------------------------------------------------

    /// A 10x10 grid in 4x4 tiles yields 3x3 tiles, with smaller tiles at the edges.
    #[test]
    fn test_tiled_iterator_2d() {
        let tiles: Vec<Tile2D> = TiledIterator2D::new(10, 10, 4, 4).collect();

        assert_eq!(tiles.len(), 9);

        let first = tiles[0];
        assert_eq!(first.row_start, 0);
        assert_eq!(first.row_end, 4);
        assert_eq!(first.col_start, 0);
        assert_eq!(first.col_end, 4);

        let last = tiles
            .last()
            .expect("tiles should have at least one element");
        assert_eq!(last.row_start, 8);
        assert_eq!(last.row_end, 10);
        assert_eq!(last.col_start, 8);
        assert_eq!(last.col_end, 10);
    }

    // ---- Block and tile geometry ------------------------------------------

    #[test]
    fn test_block_len() {
        let filled = Block::new(10, 25);
        assert_eq!(filled.len(), 15);
        assert!(!filled.is_empty());

        let empty = Block::new(10, 10);
        assert_eq!(empty.len(), 0);
        assert!(empty.is_empty());
    }

    #[test]
    fn test_tile_2d_len() {
        let tile = Tile2D::new(0, 4, 0, 5);
        assert_eq!(tile.rows(), 4);
        assert_eq!(tile.cols(), 5);
        assert_eq!(tile.len(), 20);
    }

    // ---- Hints and stride analysis ----------------------------------------

    /// A contiguous 100x100 `f64` array gets sequential access and high efficiency.
    #[test]
    fn test_optimization_hints() {
        let hints = OptimizationHints::analyze::<f64>(&[100, 100], &[100, 1]);

        assert_eq!(hints.layout, MemoryLayout::CContiguous);
        assert_eq!(hints.access_pattern, AccessPattern::Sequential);
        assert!(hints.cache_efficiency > 0.9);
    }

    /// For a row-major array the unit-stride axis (axis 1) is iterated first.
    #[test]
    fn test_stride_optimizer() {
        let optimizer = StrideOptimizer::new(&[3, 4], &[4, 1]);

        let order = optimizer.optimal_iteration_order();
        assert_eq!(order[0], 1);
        assert_eq!(order[1], 0);

        assert!(!optimizer.should_copy());
        assert!(optimizer.bandwidth_efficiency() > 0.9);
    }

    // ---- Cache-aware bulk operations --------------------------------------

    #[test]
    fn test_cache_aware_copy() {
        let source = vec![1.0f64; 1000];
        let mut destination = vec![0.0f64; 1000];

        cache_aware_copy(&source, &mut destination);
        assert_eq!(destination, source);
    }

    #[test]
    fn test_cache_aware_transform() {
        let input: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mut squared: Vec<f64> = vec![0.0; 5];

        cache_aware_transform(&input, &mut squared, |x| x * x);
        assert_eq!(squared, vec![1.0, 4.0, 9.0, 16.0, 25.0]);
    }

    #[test]
    fn test_cache_aware_binary_op() {
        let lhs: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
        let rhs: Vec<f64> = vec![10.0, 20.0, 30.0, 40.0];
        let mut sums: Vec<f64> = vec![0.0; 4];

        cache_aware_binary_op(&lhs, &rhs, &mut sums, |x, y| x + y);
        assert_eq!(sums, vec![11.0, 22.0, 33.0, 44.0]);
    }

    // ---- Access statistics ------------------------------------------------

    #[test]
    fn test_access_stats() {
        let mut stats = AccessStats::new();

        stats.record_sequential();
        stats.record_sequential();
        stats.record_strided();
        stats.record_random();

        assert_eq!(stats.total_accesses, 4);
        assert_eq!(stats.sequential_accesses, 2);
        assert_eq!(stats.strided_accesses, 1);
        assert_eq!(stats.random_accesses, 1);

        // Two of the four accesses were sequential, so the score is above one half.
        assert!(stats.cache_efficiency() > 0.5);

        stats.update_miss_rate();
        assert!(stats.estimated_miss_rate > 0.0);
        assert!(stats.estimated_miss_rate < 1.0);
    }

    #[test]
    fn test_access_pattern_prefetch() {
        let sequential = AccessPattern::Sequential;
        let random = AccessPattern::Random;

        assert_eq!(sequential.prefetch_distance(), 4);
        assert_eq!(random.prefetch_distance(), 0);
        assert!(sequential.benefits_from_prefetch());
        assert!(!random.benefits_from_prefetch());
    }
}