use super::config::PerformanceProfile;
use crate::{QuantConfig, TorshResult};
use std::collections::{HashMap, VecDeque};
use std::time::Instant;
/// Tracks a set of named workload patterns, classifies observed feature vectors
/// against them, and records transitions between successively classified patterns.
#[derive(Debug, Clone)]
pub struct WorkloadPatternAnalyzer {
// Known patterns, keyed by `WorkloadPattern::name`.
patterns: HashMap<String, WorkloadPattern>,
// Name of the most recently classified pattern; `None` until the first classification.
current_pattern: Option<String>,
// Recorded pattern transitions, oldest first; `analyze_pattern` bounds its length.
pattern_history: VecDeque<PatternTransition>,
// Distance-based classifier used by `analyze_pattern`.
classifier: PatternClassifier,
}
/// A named workload pattern: a reference feature vector together with the
/// quantization configuration and performance profile associated with it.
#[derive(Debug, Clone)]
pub struct WorkloadPattern {
/// Pattern name; also the key under which the analyzer stores the pattern.
pub name: String,
/// Reference feature vector that observed features are compared against.
pub features: Vec<f32>,
/// Quantization configuration associated with this pattern.
/// NOTE(review): the analyzer's constructors only ever set `QuantConfig::default()` here.
pub optimal_config: QuantConfig,
/// Performance characteristics recorded for this pattern.
pub performance_profile: PerformanceProfile,
/// Number of classifications that resolved to this pattern (`learn_pattern` starts it at 1).
pub frequency: usize,
}
/// Record of the analyzer switching from one classified pattern to another.
#[derive(Debug, Clone)]
pub struct PatternTransition {
/// Previously active pattern, or `None` for the very first classification.
pub from_pattern: Option<String>,
/// Newly active pattern name (may be `"unknown"` when no known pattern matched).
pub to_pattern: String,
/// Moment the transition was recorded.
pub timestamp: Instant,
/// Heuristic switching cost taken from a fixed lookup table (0.0 for the first transition).
pub transition_cost: f32,
}
/// Nearest-pattern classifier that compares feature vectors using a weighted
/// Euclidean distance.
#[derive(Debug, Clone)]
pub struct PatternClassifier {
// Copies of the known patterns' feature vectors, refreshed by `update_clusters`.
// Not read by `classify_pattern`.
cluster_centers: Vec<Vec<f32>>,
// Set to 0.8 in `new` but never read by the classification code, hence the
// `dead_code` allowance. NOTE(review): presumably reserved for a confidence cutoff.
#[allow(dead_code)]
confidence_threshold: f32,
// Per-feature distance weights; features beyond this vector's length use weight 1.0.
feature_weights: Vec<f32>,
}
impl WorkloadPatternAnalyzer {
pub fn new() -> Self {
let mut patterns = HashMap::new();
patterns.insert(
"compute_intensive".to_string(),
WorkloadPattern {
name: "compute_intensive".to_string(),
features: vec![
0.8, 0.2, 0.1, 0.9, 0.7, 0.3, 0.5, 0.8, 0.6, 0.4, 0.3, 0.7, 0.5, 0.2, 0.8, 0.6,
],
optimal_config: QuantConfig::default(),
performance_profile: PerformanceProfile {
avg_execution_time: 5.0,
memory_usage: 200.0,
energy_consumption: 25.0,
cache_efficiency: 0.6,
},
frequency: 0,
},
);
patterns.insert(
"memory_bound".to_string(),
WorkloadPattern {
name: "memory_bound".to_string(),
features: vec![
0.3, 0.8, 0.9, 0.2, 0.1, 0.7, 0.8, 0.3, 0.4, 0.9, 0.7, 0.2, 0.5, 0.8, 0.3, 0.6,
],
optimal_config: QuantConfig::default(),
performance_profile: PerformanceProfile {
avg_execution_time: 2.0,
memory_usage: 500.0,
energy_consumption: 15.0,
cache_efficiency: 0.9,
},
frequency: 0,
},
);
patterns.insert(
"balanced".to_string(),
WorkloadPattern {
name: "balanced".to_string(),
features: vec![0.5; 16],
optimal_config: QuantConfig::default(),
performance_profile: PerformanceProfile::default(),
frequency: 0,
},
);
Self {
patterns,
current_pattern: None,
pattern_history: VecDeque::new(),
classifier: PatternClassifier::new(),
}
}
pub fn analyze_pattern(&mut self, features: &[f32]) -> TorshResult<Option<String>> {
let classified_pattern = self.classifier.classify_pattern(features, &self.patterns)?;
if let Some(ref pattern_name) = classified_pattern {
if let Some(pattern) = self.patterns.get_mut(pattern_name) {
pattern.frequency += 1;
}
if self.current_pattern.as_ref() != Some(pattern_name) {
let transition = PatternTransition {
from_pattern: self.current_pattern.clone(),
to_pattern: pattern_name.clone(),
timestamp: Instant::now(),
transition_cost: self
.calculate_transition_cost(&self.current_pattern, pattern_name),
};
self.pattern_history.push_back(transition);
if self.pattern_history.len() > 100 {
self.pattern_history.pop_front();
}
self.current_pattern = Some(pattern_name.clone());
}
}
Ok(classified_pattern)
}
fn calculate_transition_cost(&self, from: &Option<String>, to: &String) -> f32 {
match from {
Some(from_pattern) if from_pattern == to => 0.0, Some(from_pattern) => {
match (from_pattern.as_str(), to.as_str()) {
("compute_intensive", "memory_bound") => 0.3,
("memory_bound", "compute_intensive") => 0.4,
("balanced", _) => 0.1,
(_, "balanced") => 0.1,
_ => 0.2,
}
}
None => 0.0, }
}
pub fn get_current_pattern(&self) -> &Option<String> {
&self.current_pattern
}
pub fn get_pattern(&self, name: &str) -> Option<&WorkloadPattern> {
self.patterns.get(name)
}
pub fn get_all_patterns(&self) -> &HashMap<String, WorkloadPattern> {
&self.patterns
}
pub fn add_pattern(&mut self, pattern: WorkloadPattern) {
self.patterns.insert(pattern.name.clone(), pattern);
}
pub fn get_pattern_history(&self) -> &VecDeque<PatternTransition> {
&self.pattern_history
}
pub fn get_pattern_statistics(&self) -> PatternStatistics {
let total_frequency: usize = self.patterns.values().map(|p| p.frequency).sum();
let most_common_pattern = self
.patterns
.iter()
.max_by_key(|(_, pattern)| pattern.frequency)
.map(|(name, _)| name.clone());
let transition_count = self.pattern_history.len();
let avg_transition_cost = if transition_count > 0 {
self.pattern_history
.iter()
.map(|t| t.transition_cost)
.sum::<f32>()
/ transition_count as f32
} else {
0.0
};
PatternStatistics {
total_patterns: self.patterns.len(),
total_frequency,
most_common_pattern,
current_pattern: self.current_pattern.clone(),
transition_count,
avg_transition_cost,
}
}
pub fn learn_pattern(
&mut self,
name: String,
features: Vec<f32>,
performance: PerformanceProfile,
) {
let pattern = WorkloadPattern {
name: name.clone(),
features,
optimal_config: QuantConfig::default(), performance_profile: performance,
frequency: 1,
};
self.patterns.insert(name, pattern);
self.classifier.update_clusters(&self.patterns);
}
}
impl Default for WorkloadPatternAnalyzer {
fn default() -> Self {
Self::new()
}
}
impl PatternClassifier {
fn new() -> Self {
Self {
cluster_centers: Vec::new(),
confidence_threshold: 0.8,
feature_weights: vec![1.0; 16], }
}
fn classify_pattern(
&self,
features: &[f32],
patterns: &HashMap<String, WorkloadPattern>,
) -> TorshResult<Option<String>> {
if patterns.is_empty() {
return Ok(None);
}
let mut best_match = None;
let mut best_distance = f32::INFINITY;
for (name, pattern) in patterns {
let distance = self.calculate_distance(features, &pattern.features);
if distance < best_distance {
best_distance = distance;
best_match = Some(name.clone());
}
}
if best_distance < 2.0 {
Ok(best_match)
} else {
Ok(Some("unknown".to_string()))
}
}
fn calculate_distance(&self, features1: &[f32], features2: &[f32]) -> f32 {
let min_len = features1.len().min(features2.len());
let mut distance = 0.0;
for i in 0..min_len {
let weight = if i < self.feature_weights.len() {
self.feature_weights[i]
} else {
1.0
};
let diff = features1[i] - features2[i];
distance += weight * diff * diff;
}
distance.sqrt()
}
fn update_clusters(&mut self, patterns: &HashMap<String, WorkloadPattern>) {
self.cluster_centers.clear();
for pattern in patterns.values() {
self.cluster_centers.push(pattern.features.clone());
}
}
pub fn update_feature_weights(&mut self, _patterns: &HashMap<String, WorkloadPattern>) {
for weight in &mut self.feature_weights {
*weight = (*weight * 0.9 + 0.1).max(0.1).min(2.0);
}
}
}
/// Snapshot of pattern usage produced by
/// [`WorkloadPatternAnalyzer::get_pattern_statistics`].
///
/// The derived `Default` is the empty snapshot: zero counts, no pattern names,
/// and an average transition cost of `0.0`.
#[derive(Debug, Clone, Default)]
pub struct PatternStatistics {
    /// Number of registered patterns.
    pub total_patterns: usize,
    /// Sum of the `frequency` counters of all registered patterns.
    pub total_frequency: usize,
    /// Pattern with the highest frequency; `None` only when no patterns are registered.
    pub most_common_pattern: Option<String>,
    /// Pattern most recently classified by the analyzer, if any.
    pub current_pattern: Option<String>,
    /// Number of transitions currently retained in the analyzer's history.
    pub transition_count: usize,
    /// Mean `transition_cost` over the retained transitions (`0.0` when there are none).
    pub avg_transition_cost: f32,
}