use std::collections::{HashMap, VecDeque};
use std::hash::{Hash, Hasher};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizationStrategy {
RowWiseCache,
ColumnWiseLocality,
BlockStructured,
DiagonalOptimized,
Hierarchical,
StreamingCompute,
SIMDVectorized,
ParallelWorkStealing,
AdaptiveHybrid,
}
#[derive(Debug)]
pub(crate) struct PatternMemory {
pub matrix_patterns: HashMap<MatrixFingerprint, OptimizationStrategy>,
#[allow(dead_code)]
pub access_patterns: VecDeque<AccessPattern>,
#[allow(dead_code)]
pub performance_cache: HashMap<String, f64>,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct MatrixFingerprint {
pub rows: usize,
pub cols: usize,
pub nnz: usize,
pub sparsity_pattern_hash: u64,
pub row_distribution_type: DistributionType,
pub column_distribution_type: DistributionType,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum DistributionType {
Uniform,
Clustered,
BandDiagonal,
#[allow(dead_code)]
BlockStructured,
Random,
PowerLaw,
}
#[derive(Debug, Clone)]
pub(crate) struct AccessPattern {
#[allow(dead_code)]
pub timestamp: u64,
#[allow(dead_code)]
pub row_sequence: Vec<usize>,
#[allow(dead_code)]
pub column_sequence: Vec<usize>,
#[allow(dead_code)]
pub cache_hits: usize,
#[allow(dead_code)]
pub cache_misses: usize,
}
impl PatternMemory {
pub fn new(capacity: usize) -> Self {
Self {
matrix_patterns: HashMap::new(),
access_patterns: VecDeque::new(),
performance_cache: HashMap::new(),
}
}
pub fn store_pattern(
&mut self,
fingerprint: MatrixFingerprint,
strategy: OptimizationStrategy,
) {
self.matrix_patterns.insert(fingerprint, strategy);
}
pub fn get_strategy(&self, fingerprint: &MatrixFingerprint) -> Option<OptimizationStrategy> {
self.matrix_patterns.get(fingerprint).copied()
}
pub fn find_similar_patterns(
&self,
fingerprint: &MatrixFingerprint,
similarity_threshold: f64,
) -> Vec<(MatrixFingerprint, OptimizationStrategy)> {
let mut similar_patterns = Vec::new();
for (stored_fingerprint, strategy) in &self.matrix_patterns {
let similarity = self.compute_similarity(fingerprint, stored_fingerprint);
if similarity >= similarity_threshold {
similar_patterns.push((stored_fingerprint.clone(), *strategy));
}
}
similar_patterns.sort_by(|a, b| {
let sim_a = self.compute_similarity(fingerprint, &a.0);
let sim_b = self.compute_similarity(fingerprint, &b.0);
sim_b.partial_cmp(&sim_a).expect("Operation failed")
});
similar_patterns
}
fn compute_similarity(&self, fp1: &MatrixFingerprint, fp2: &MatrixFingerprint) -> f64 {
let size_similarity = self.size_similarity(fp1, fp2);
let sparsity_similarity = self.sparsity_similarity(fp1, fp2);
let pattern_similarity = self.pattern_similarity(fp1, fp2);
let distribution_similarity = self.distribution_similarity(fp1, fp2);
0.3 * size_similarity
+ 0.3 * sparsity_similarity
+ 0.2 * pattern_similarity
+ 0.2 * distribution_similarity
}
fn size_similarity(&self, fp1: &MatrixFingerprint, fp2: &MatrixFingerprint) -> f64 {
let row_ratio = (fp1.rows.min(fp2.rows) as f64) / (fp1.rows.max(fp2.rows) as f64);
let col_ratio = (fp1.cols.min(fp2.cols) as f64) / (fp1.cols.max(fp2.cols) as f64);
(row_ratio + col_ratio) / 2.0
}
fn sparsity_similarity(&self, fp1: &MatrixFingerprint, fp2: &MatrixFingerprint) -> f64 {
let sparsity1 = fp1.nnz as f64 / (fp1.rows * fp1.cols) as f64;
let sparsity2 = fp2.nnz as f64 / (fp2.rows * fp2.cols) as f64;
1.0 - (sparsity1 - sparsity2).abs()
}
fn pattern_similarity(&self, fp1: &MatrixFingerprint, fp2: &MatrixFingerprint) -> f64 {
let hash_diff = (fp1.sparsity_pattern_hash ^ fp2.sparsity_pattern_hash).count_ones() as f64;
1.0 - (hash_diff / 64.0) }
fn distribution_similarity(&self, fp1: &MatrixFingerprint, fp2: &MatrixFingerprint) -> f64 {
let row_match = if fp1.row_distribution_type == fp2.row_distribution_type {
1.0
} else {
0.0
};
let col_match = if fp1.column_distribution_type == fp2.column_distribution_type {
1.0
} else {
0.0
};
(row_match + col_match) / 2.0
}
pub fn record_access_pattern(&mut self, pattern: AccessPattern) {
self.access_patterns.push_back(pattern);
if self.access_patterns.len() > 1000 {
self.access_patterns.pop_front();
}
}
pub fn cache_performance(&mut self, key: String, performance: f64) {
self.performance_cache.insert(key, performance);
if self.performance_cache.len() > 10000 {
let keys_to_remove: Vec<String> =
self.performance_cache.keys().take(1000).cloned().collect();
for key in keys_to_remove {
self.performance_cache.remove(&key);
}
}
}
pub fn get_cached_performance(&self, key: &str) -> Option<f64> {
self.performance_cache.get(key).copied()
}
pub fn analyze_access_patterns(&self) -> AccessPatternAnalysis {
if self.access_patterns.is_empty() {
return AccessPatternAnalysis::default();
}
let mut sequential_count = 0;
let mut random_count = 0;
let mut block_count = 0;
for pattern in &self.access_patterns {
let access_type = self.classify_access_pattern(pattern);
match access_type {
AccessType::Sequential => sequential_count += 1,
AccessType::Random => random_count += 1,
AccessType::Block => block_count += 1,
}
}
let total = self.access_patterns.len();
AccessPatternAnalysis {
sequential_ratio: sequential_count as f64 / total as f64,
random_ratio: random_count as f64 / total as f64,
block_ratio: block_count as f64 / total as f64,
cache_hit_rate: self.compute_average_cache_hit_rate(),
}
}
fn classify_access_pattern(&self, pattern: &AccessPattern) -> AccessType {
if pattern.row_sequence.is_empty() {
return AccessType::Random;
}
let mut sequential = true;
for i in 1..pattern.row_sequence.len() {
if pattern.row_sequence[i] != pattern.row_sequence[i - 1] + 1
&& pattern.row_sequence[i] != pattern.row_sequence[i - 1]
{
sequential = false;
break;
}
}
if sequential {
return AccessType::Sequential;
}
let unique_rows: std::collections::HashSet<_> = pattern.row_sequence.iter().collect();
if unique_rows.len() < pattern.row_sequence.len() / 2 {
return AccessType::Block;
}
AccessType::Random
}
fn compute_average_cache_hit_rate(&self) -> f64 {
if self.access_patterns.is_empty() {
return 0.0;
}
let total_hits: usize = self.access_patterns.iter().map(|p| p.cache_hits).sum();
let total_accesses: usize = self
.access_patterns
.iter()
.map(|p| p.cache_hits + p.cache_misses)
.sum();
if total_accesses == 0 {
0.0
} else {
total_hits as f64 / total_accesses as f64
}
}
pub fn suggest_strategy(&self, fingerprint: &MatrixFingerprint) -> OptimizationStrategy {
if let Some(strategy) = self.get_strategy(fingerprint) {
return strategy;
}
let similar = self.find_similar_patterns(fingerprint, 0.7);
if !similar.is_empty() {
return similar[0].1; }
self.heuristic_strategy_suggestion(fingerprint)
}
fn heuristic_strategy_suggestion(
&self,
fingerprint: &MatrixFingerprint,
) -> OptimizationStrategy {
let sparsity = fingerprint.nnz as f64 / (fingerprint.rows * fingerprint.cols) as f64;
let size = fingerprint.rows * fingerprint.cols;
match (
fingerprint.row_distribution_type,
fingerprint.column_distribution_type,
) {
(DistributionType::BandDiagonal, _) | (_, DistributionType::BandDiagonal) => {
OptimizationStrategy::DiagonalOptimized
}
(DistributionType::Clustered, DistributionType::Clustered) => {
OptimizationStrategy::BlockStructured
}
_ => {
if sparsity < 0.01 && size > 10000 {
OptimizationStrategy::StreamingCompute
} else if size > 100000 {
OptimizationStrategy::ParallelWorkStealing
} else if sparsity > 0.1 {
OptimizationStrategy::SIMDVectorized
} else {
OptimizationStrategy::AdaptiveHybrid
}
}
}
}
pub fn get_statistics(&self) -> PatternMemoryStats {
PatternMemoryStats {
stored_patterns: self.matrix_patterns.len(),
access_patterns_recorded: self.access_patterns.len(),
cached_performances: self.performance_cache.len(),
most_common_strategy: self.get_most_common_strategy(),
}
}
fn get_most_common_strategy(&self) -> Option<OptimizationStrategy> {
let mut strategy_counts = HashMap::new();
for strategy in self.matrix_patterns.values() {
*strategy_counts.entry(*strategy).or_insert(0) += 1;
}
strategy_counts
.into_iter()
.max_by(|(_, count1), (_, count2)| count1.cmp(count2))
.map(|(strategy, _)| strategy)
}
}
impl MatrixFingerprint {
pub fn new<T>(rows: Vec<usize>, cols: Vec<usize>, data: &[T], shape: (usize, usize)) -> Self
where
T: std::fmt::Debug + Copy + PartialEq,
{
let nnz = data.len();
let sparsity_pattern_hash = Self::compute_pattern_hash(&rows, &cols);
let row_distribution_type = Self::analyze_distribution(&rows, shape.0);
let column_distribution_type = Self::analyze_distribution(&cols, shape.1);
Self {
rows: shape.0,
cols: shape.1,
nnz,
sparsity_pattern_hash,
row_distribution_type,
column_distribution_type,
}
}
fn compute_pattern_hash(rows: &[usize], cols: &[usize]) -> u64 {
use std::collections::hash_map::DefaultHasher;
let mut hasher = DefaultHasher::new();
let step = (rows.len() / 100).max(1);
for i in (0..rows.len()).step_by(step) {
rows[i].hash(&mut hasher);
if i < cols.len() {
cols[i].hash(&mut hasher);
}
}
hasher.finish()
}
fn analyze_distribution(indices: &[usize], max_value: usize) -> DistributionType {
if indices.is_empty() {
return DistributionType::Uniform;
}
if Self::is_band_diagonal(indices) {
return DistributionType::BandDiagonal;
}
if Self::is_clustered(indices, max_value) {
return DistributionType::Clustered;
}
if Self::is_uniform(indices, max_value) {
return DistributionType::Uniform;
}
if Self::is_power_law(indices) {
return DistributionType::PowerLaw;
}
DistributionType::Random
}
fn is_band_diagonal(indices: &[usize]) -> bool {
if indices.len() < 2 {
return false;
}
let mut sorted_indices = indices.to_vec();
sorted_indices.sort_unstable();
let range = sorted_indices[sorted_indices.len() - 1] - sorted_indices[0];
let density = indices.len() as f64 / (range + 1) as f64;
density > 0.5 && range < indices.len() * 3
}
fn is_clustered(indices: &[usize], max_value: usize) -> bool {
if indices.is_empty() {
return false;
}
let mut histogram = vec![0; (max_value / 10).max(10)];
for &idx in indices {
let bucket = (idx * histogram.len()) / (max_value + 1);
if bucket < histogram.len() {
histogram[bucket] += 1;
}
}
histogram.sort_unstable();
let top_buckets = histogram.len() / 3;
let top_count: usize = histogram.iter().rev().take(top_buckets).sum();
let total_count: usize = histogram.iter().sum();
top_count as f64 / total_count as f64 > 0.7
}
fn is_uniform(indices: &[usize], max_value: usize) -> bool {
if indices.is_empty() {
return false;
}
let bucket_count = (max_value / 10).max(10);
let mut histogram = vec![0; bucket_count];
for &idx in indices {
let bucket = (idx * bucket_count) / (max_value + 1);
if bucket < histogram.len() {
histogram[bucket] += 1;
}
}
let mean = indices.len() as f64 / bucket_count as f64;
let variance: f64 = histogram
.iter()
.map(|&count| (count as f64 - mean).powi(2))
.sum::<f64>()
/ bucket_count as f64;
let std_dev = variance.sqrt();
std_dev / mean < 0.5 }
fn is_power_law(indices: &[usize]) -> bool {
if indices.is_empty() {
return false;
}
let mut sorted_indices = indices.to_vec();
sorted_indices.sort_unstable();
sorted_indices.dedup();
let mut frequency_map = HashMap::new();
for &idx in indices {
*frequency_map.entry(idx).or_insert(0) += 1;
}
let mut frequencies: Vec<usize> = frequency_map.values().copied().collect();
frequencies.sort_unstable();
frequencies.reverse();
if frequencies.len() < 3 {
return false;
}
let top_10_percent = (frequencies.len() / 10).max(1);
let top_sum: usize = frequencies.iter().take(top_10_percent).sum();
let total_sum: usize = frequencies.iter().sum();
top_sum as f64 / total_sum as f64 > 0.8
}
}
#[derive(Debug, Clone, Copy)]
enum AccessType {
Sequential,
Random,
Block,
}
#[derive(Debug, Clone)]
pub struct AccessPatternAnalysis {
pub sequential_ratio: f64,
pub random_ratio: f64,
pub block_ratio: f64,
pub cache_hit_rate: f64,
}
impl Default for AccessPatternAnalysis {
fn default() -> Self {
Self {
sequential_ratio: 0.0,
random_ratio: 0.0,
block_ratio: 0.0,
cache_hit_rate: 0.0,
}
}
}
#[derive(Debug, Clone)]
pub struct PatternMemoryStats {
pub stored_patterns: usize,
pub access_patterns_recorded: usize,
pub cached_performances: usize,
pub most_common_strategy: Option<OptimizationStrategy>,
}