#![allow(dead_code)]
/// Classification of how a memory region is accessed.
///
/// Each variant has a short text label (`label`) and a heuristic score
/// (`cache_efficiency`) defined on this type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AccessPattern {
/// Sequential access; scored `1.0` by `cache_efficiency`.
Sequential,
/// Access with a fixed stride; scored `1 / sqrt(stride)`, clamped to `[0.1, 1.0]`.
Strided {
/// Distance between consecutive accesses.
/// NOTE(review): the unit (elements vs. bytes) is not established in this file — confirm.
stride: usize,
},
/// Random access; scored `0.1` by `cache_efficiency`.
Random,
/// Gather access; scored `0.2` by `cache_efficiency`.
Gather,
}
impl AccessPattern {
#[allow(clippy::cast_precision_loss)]
#[must_use]
pub fn cache_efficiency(&self) -> f32 {
match self {
Self::Sequential => 1.0,
Self::Strided { stride } => {
let s = *stride as f32;
(1.0 / s.sqrt()).clamp(0.1, 1.0)
}
Self::Random => 0.1,
Self::Gather => 0.2,
}
}
#[must_use]
pub fn label(&self) -> &'static str {
match self {
Self::Sequential => "sequential",
Self::Strided { .. } => "strided",
Self::Random => "random",
Self::Gather => "gather",
}
}
}
/// One recorded observation: a named region, how it was accessed, and how
/// many bytes were accessed.
#[derive(Debug, Clone)]
pub struct RegionStat {
/// Name identifying the region.
pub region: String,
/// Access pattern attributed to this region.
pub pattern: AccessPattern,
/// Number of bytes accessed; used as the weight when a
/// `MemoryAccessProfile` aggregates its regions.
pub bytes_accessed: u64,
}
impl RegionStat {
    /// Builds a region record.
    ///
    /// `region` may be anything convertible into a `String` (for example a
    /// `&str`); `pattern` and `bytes_accessed` are stored unchanged.
    #[must_use]
    pub fn new(region: impl Into<String>, pattern: AccessPattern, bytes_accessed: u64) -> Self {
        let region = region.into();
        Self {
            region,
            pattern,
            bytes_accessed,
        }
    }
}
/// Collection of `RegionStat` records from which aggregate statistics
/// (dominant pattern, byte-weighted cache efficiency) are computed.
#[derive(Debug, Default)]
pub struct MemoryAccessProfile {
/// Recorded regions, in the order they were passed to `record`.
regions: Vec<RegionStat>,
}
impl MemoryAccessProfile {
    /// Creates an empty profile with no recorded regions.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends one region record to the profile.
    pub fn record(&mut self, stat: RegionStat) {
        self.regions.push(stat);
    }

    /// Returns the access pattern that accounts for the most accessed bytes.
    ///
    /// Bytes are summed per pattern kind (all `Strided` regions count
    /// together regardless of stride). Ties are resolved in the order
    /// `Sequential`, `Strided`, `Random`, `Gather`, so an empty profile, or
    /// one whose regions all report zero bytes, yields `Sequential`.
    ///
    /// When `Strided` dominates, the returned stride is that of the strided
    /// region with the most bytes (the earliest recorded one on a tie), so
    /// callers can distinguish small strides from large ones.
    #[must_use]
    pub fn dominant_pattern(&self) -> AccessPattern {
        if self.regions.is_empty() {
            return AccessPattern::Sequential;
        }

        // Totals are kept in u128 so that summing many u64 byte counts can
        // neither panic (debug builds) nor wrap (release builds).
        let mut seq: u128 = 0;
        let mut strided: u128 = 0;
        let mut random: u128 = 0;
        let mut gather: u128 = 0;
        // (bytes, stride) of the heaviest strided region seen so far.
        let mut heaviest_strided: Option<(u64, usize)> = None;

        for r in &self.regions {
            let bytes = u128::from(r.bytes_accessed);
            match &r.pattern {
                AccessPattern::Sequential => seq += bytes,
                AccessPattern::Strided { stride } => {
                    strided += bytes;
                    let is_heavier = match heaviest_strided {
                        Some((best_bytes, _)) => r.bytes_accessed > best_bytes,
                        None => true,
                    };
                    if is_heavier {
                        heaviest_strided = Some((r.bytes_accessed, *stride));
                    }
                }
                AccessPattern::Random => random += bytes,
                AccessPattern::Gather => gather += bytes,
            }
        }

        let max = seq.max(strided).max(random).max(gather);
        if max == seq {
            AccessPattern::Sequential
        } else if max == strided {
            // `strided` can only exceed `seq` if at least one strided region
            // was seen, so the fallback of 1 is purely defensive.
            let stride = match heaviest_strided {
                Some((_, stride)) => stride,
                None => 1,
            };
            AccessPattern::Strided { stride }
        } else if max == random {
            AccessPattern::Random
        } else {
            AccessPattern::Gather
        }
    }

    /// Returns the byte-weighted mean of each region's
    /// `AccessPattern::cache_efficiency`.
    ///
    /// Returns `1.0` when no bytes were recorded (empty profile or all
    /// regions report zero bytes), avoiding a division by zero.
    #[allow(clippy::cast_precision_loss)]
    #[must_use]
    pub fn weighted_cache_efficiency(&self) -> f32 {
        // Summed in u128 so the total cannot overflow for any number of regions.
        let total_bytes: u128 = self
            .regions
            .iter()
            .map(|r| u128::from(r.bytes_accessed))
            .sum();
        if total_bytes == 0 {
            return 1.0;
        }
        let weighted: f64 = self
            .regions
            .iter()
            .map(|r| f64::from(r.pattern.cache_efficiency()) * r.bytes_accessed as f64)
            .sum();
        (weighted / total_bytes as f64) as f32
    }

    /// Returns the number of recorded regions.
    #[must_use]
    pub fn region_count(&self) -> usize {
        self.regions.len()
    }
}
/// Layout recommendation returned by `MemoryOptimizer::suggest_layout`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LayoutSuggestion {
/// Keep the current layout. Not returned by `MemoryOptimizer::suggest_layout`
/// as written in this file.
NoChange,
/// Interleave the data; suggested when a strided pattern with stride `<= 4` dominates.
Interleave,
/// Use a tiled layout with the given tile dimensions.
/// NOTE(review): the unit of the tile dimensions is not established in this file — confirm.
TiledLayout {
/// Tile width.
tile_width: u32,
/// Tile height.
tile_height: u32,
},
/// Prefetch ahead; suggested when sequential access dominates.
Prefetch {
/// Prefetch distance.
/// NOTE(review): the unit is not established in this file — confirm.
distance: u32,
},
}
/// Field-less optimizer that maps a `MemoryAccessProfile` to a `LayoutSuggestion`.
#[derive(Debug, Default)]
pub struct MemoryOptimizer;
impl MemoryOptimizer {
    /// Creates an optimizer. The type has no fields, so this is equivalent
    /// to `MemoryOptimizer::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self
    }

    /// Suggests a layout based on the profile's dominant access pattern.
    ///
    /// * `Sequential` → `Prefetch { distance: 8 }`
    /// * `Strided` with stride `<= 4` → `Interleave`
    /// * `Strided` with a larger stride → `TiledLayout` of 32 x 32
    /// * `Random` or `Gather` → `TiledLayout` of 16 x 16
    #[must_use]
    pub fn suggest_layout(&self, profile: &MemoryAccessProfile) -> LayoutSuggestion {
        let dominant = profile.dominant_pattern();

        // Non-tiled outcomes return early; the remaining cases differ only
        // in the edge length of the square tile.
        let tile_edge: u32 = match dominant {
            AccessPattern::Sequential => return LayoutSuggestion::Prefetch { distance: 8 },
            AccessPattern::Strided { stride } => {
                if stride <= 4 {
                    return LayoutSuggestion::Interleave;
                }
                32
            }
            AccessPattern::Random | AccessPattern::Gather => 16,
        };

        LayoutSuggestion::TiledLayout {
            tile_width: tile_edge,
            tile_height: tile_edge,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing a single pattern's efficiency score.
    const SCORE_TOL: f32 = 1e-6;

    /// Returns `true` when `actual` differs from `expected` by less than `tol`.
    fn close(actual: f32, expected: f32, tol: f32) -> bool {
        (actual - expected).abs() < tol
    }

    // ---- AccessPattern ----------------------------------------------------

    #[test]
    fn test_sequential_cache_efficiency() {
        let score = AccessPattern::Sequential.cache_efficiency();
        assert!(close(score, 1.0, SCORE_TOL));
    }

    #[test]
    fn test_random_cache_efficiency() {
        let score = AccessPattern::Random.cache_efficiency();
        assert!(close(score, 0.1, SCORE_TOL));
    }

    #[test]
    fn test_gather_cache_efficiency() {
        let score = AccessPattern::Gather.cache_efficiency();
        assert!(close(score, 0.2, SCORE_TOL));
    }

    #[test]
    fn test_strided_efficiency_decreases_with_stride() {
        let narrow = AccessPattern::Strided { stride: 1 };
        let wide = AccessPattern::Strided { stride: 16 };
        assert!(wide.cache_efficiency() < narrow.cache_efficiency());
    }

    #[test]
    fn test_strided_efficiency_floor() {
        let very_wide = AccessPattern::Strided { stride: 10_000 };
        assert!(very_wide.cache_efficiency() >= 0.1);
    }

    #[test]
    fn test_access_pattern_labels() {
        let cases = [
            (AccessPattern::Sequential, "sequential"),
            (AccessPattern::Random, "random"),
            (AccessPattern::Gather, "gather"),
            (AccessPattern::Strided { stride: 4 }, "strided"),
        ];
        for (pattern, expected) in &cases {
            assert_eq!(pattern.label(), *expected);
        }
    }

    // ---- MemoryAccessProfile ----------------------------------------------

    #[test]
    fn test_profile_empty_dominant_pattern() {
        let dominant = MemoryAccessProfile::new().dominant_pattern();
        assert_eq!(dominant, AccessPattern::Sequential);
    }

    #[test]
    fn test_profile_dominant_random_when_heaviest() {
        let mut profile = MemoryAccessProfile::new();
        profile.record(RegionStat::new("lut", AccessPattern::Random, 1_000_000));
        profile.record(RegionStat::new("fb", AccessPattern::Sequential, 100));

        let dominant = profile.dominant_pattern();
        assert_eq!(dominant, AccessPattern::Random);
    }

    #[test]
    fn test_profile_weighted_efficiency_all_sequential() {
        let mut profile = MemoryAccessProfile::new();
        profile.record(RegionStat::new("fb", AccessPattern::Sequential, 4096));

        let efficiency = profile.weighted_cache_efficiency();
        assert!(close(efficiency, 1.0, 1e-5));
    }

    #[test]
    fn test_profile_region_count() {
        let mut profile = MemoryAccessProfile::new();
        profile.record(RegionStat::new("a", AccessPattern::Sequential, 100));
        profile.record(RegionStat::new("b", AccessPattern::Random, 200));

        assert_eq!(profile.region_count(), 2);
    }

    // ---- MemoryOptimizer --------------------------------------------------

    #[test]
    fn test_optimizer_sequential_suggests_prefetch() {
        let mut profile = MemoryAccessProfile::new();
        profile.record(RegionStat::new("fb", AccessPattern::Sequential, 1024));

        let suggestion = MemoryOptimizer::new().suggest_layout(&profile);
        assert!(matches!(suggestion, LayoutSuggestion::Prefetch { .. }));
    }

    #[test]
    fn test_optimizer_random_suggests_tiled() {
        let mut profile = MemoryAccessProfile::new();
        profile.record(RegionStat::new("lut", AccessPattern::Random, 1024));

        let suggestion = MemoryOptimizer::new().suggest_layout(&profile);
        assert!(matches!(suggestion, LayoutSuggestion::TiledLayout { .. }));
    }

    #[test]
    fn test_optimizer_small_stride_suggests_interleave() {
        let mut profile = MemoryAccessProfile::new();
        profile.record(RegionStat::new(
            "ch",
            AccessPattern::Strided { stride: 2 },
            1024,
        ));

        let suggestion = MemoryOptimizer::new().suggest_layout(&profile);
        assert_eq!(suggestion, LayoutSuggestion::Interleave);
    }

    #[test]
    fn test_optimizer_large_stride_suggests_tiled() {
        let mut profile = MemoryAccessProfile::new();
        profile.record(RegionStat::new(
            "row",
            AccessPattern::Strided { stride: 1920 },
            1024,
        ));

        let suggestion = MemoryOptimizer::new().suggest_layout(&profile);
        assert!(matches!(suggestion, LayoutSuggestion::TiledLayout { .. }));
    }
}