use crate::{Result, Shape};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};
/// Timing and hardware statistics recorded for one executed operation.
#[derive(Debug, Clone)]
// Use the names imported at the top of the file (also gated on the `serde`
// feature); the fully-qualified paths left that import unused.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct OperationMetrics {
    /// Name of the profiled operation.
    pub op_name: String,
    /// Shapes of the operation's inputs.
    pub input_shapes: Vec<Shape>,
    /// Wall-clock duration of the operation in nanoseconds.
    pub duration_ns: u64,
    /// Estimated memory bandwidth in bytes per second.
    pub memory_bandwidth: f64,
    /// CPU utilization in [0, 1] — currently a fixed estimate in this file.
    pub cpu_utilization: f32,
    /// Cache hit rate in [0, 1] — currently a fixed estimate in this file.
    pub cache_hit_rate: f32,
    /// Hardware features active during execution (e.g. "avx2", "gpu").
    pub hardware_features: Vec<String>,
}
/// How an operation should be executed.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
// Use the names imported at the top of the file (also gated on the `serde`
// feature); the fully-qualified paths left that import unused.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ExecutionStrategy {
    /// Single-threaded scalar execution.
    Sequential,
    /// Thread-pool execution with a fixed thread count.
    Parallel { num_threads: usize },
    /// Vectorized execution using the named instruction set (e.g. "avx2").
    Simd { instruction_set: String },
    /// Offload to the GPU with the given device index.
    Gpu { device_id: u32 },
    /// Split work between CPU and an accelerator; `cpu_ratio_percent` is the
    /// CPU's share of the work (0-100).
    Hybrid { cpu_ratio_percent: u8 },
    /// User-provided algorithm selected by name.
    Custom { algorithm: String },
}
/// Maps an (operation name, input shapes, strategy) key to a smoothed
/// duration estimate in nanoseconds — lower is better.
type PerformanceMap = HashMap<(String, Vec<Shape>, ExecutionStrategy), f64>;
/// Learns which [`ExecutionStrategy`] performs best for each operation/shape
/// combination by smoothing observed durations with an exponential moving
/// average.
#[derive(Debug, Clone)]
pub struct PerformancePredictor {
    // Raw metrics, oldest first; trimmed once it grows past 10k entries.
    metrics_history: Arc<RwLock<Vec<OperationMetrics>>>,
    // Smoothed duration (ns) per (op, shapes, strategy) key.
    strategy_performance: Arc<RwLock<PerformanceMap>>,
    // EMA weight given to each new observation (0.1 by default).
    learning_rate: f64,
}
impl Default for PerformancePredictor {
    /// Equivalent to [`PerformancePredictor::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl PerformancePredictor {
    /// Creates a predictor with empty history and a learning rate of 0.1.
    pub fn new() -> Self {
        Self {
            metrics_history: Arc::new(RwLock::new(Vec::new())),
            strategy_performance: Arc::new(RwLock::new(HashMap::new())),
            learning_rate: 0.1,
        }
    }

    /// Returns the historically fastest strategy recorded for `op_name` on
    /// inputs comparable to `shapes`, falling back to a size-based heuristic
    /// when no matching history exists.
    pub fn predict_best_strategy(&self, op_name: &str, shapes: &[Shape]) -> ExecutionStrategy {
        let performance_map = self
            .strategy_performance
            .read()
            .expect("read lock should not be poisoned");
        // Track the lowest observed duration; `None` means nothing matched.
        let mut best: Option<(&ExecutionStrategy, f64)> = None;
        for ((stored_op, stored_shapes, strategy), &performance) in performance_map.iter() {
            if stored_op != op_name || !self.shapes_match(stored_shapes, shapes) {
                continue;
            }
            if best.map_or(true, |(_, best_perf)| performance < best_perf) {
                best = Some((strategy, performance));
            }
        }
        match best {
            Some((strategy, _)) => strategy.clone(),
            None => self.heuristic_strategy_selection(shapes),
        }
    }

    /// Records `metrics` for `strategy` and folds the observed duration into
    /// the smoothed per-key performance estimate.
    pub fn update_performance(&self, metrics: &OperationMetrics, strategy: ExecutionStrategy) {
        {
            let mut history = self
                .metrics_history
                .write()
                .expect("write lock should not be poisoned");
            history.push(metrics.clone());
            // Bound memory: drop the oldest 1000 entries past 10k total.
            if history.len() > 10000 {
                history.drain(..1000);
            }
        } // Release the history lock before taking the map lock.
        let key = (
            metrics.op_name.clone(),
            metrics.input_shapes.clone(),
            strategy,
        );
        let observed = metrics.duration_ns as f64;
        let mut performance_map = self
            .strategy_performance
            .write()
            .expect("write lock should not be poisoned");
        performance_map
            .entry(key)
            // Exponential moving average; a fresh key starts at the raw
            // observation (identical to averaging it with itself).
            .and_modify(|perf| {
                *perf = (1.0 - self.learning_rate) * *perf + self.learning_rate * observed
            })
            .or_insert(observed);
    }

    /// True when `historical` and `current` describe comparable inputs.
    fn shapes_match(&self, historical: &[Shape], current: &[Shape]) -> bool {
        if historical.len() != current.len() {
            return false;
        }
        historical.iter().zip(current.iter()).all(|(hist, curr)| {
            if hist.dims() != curr.dims() {
                return false;
            }
            // NOTE(review): with exact dims equality above, the element
            // counts are necessarily equal, so this ratio never rejects;
            // kept in case the dims comparison is relaxed to a rank check.
            let (hist_size, curr_size) = (hist.size(), curr.size());
            let ratio = (hist_size.max(curr_size) as f64) / (hist_size.min(curr_size) as f64);
            // Written as a negated comparison so a NaN ratio (both sizes
            // zero) still counts as a match, as before.
            !(ratio > 1.2)
        })
    }

    /// Size-based fallback used when no performance history is available.
    fn heuristic_strategy_selection(&self, shapes: &[Shape]) -> ExecutionStrategy {
        let total_elements: usize = shapes.iter().map(|s| s.size()).sum();
        match total_elements {
            // Tiny workloads: thread/SIMD setup costs would dominate.
            0..=1000 => ExecutionStrategy::Sequential,
            // Mid-size: prefer SIMD when the CPU supports it.
            1001..=100_000 => {
                if self.has_avx2() {
                    ExecutionStrategy::Simd {
                        instruction_set: "avx2".to_string(),
                    }
                } else if self.has_neon() {
                    ExecutionStrategy::Simd {
                        instruction_set: "neon".to_string(),
                    }
                } else {
                    ExecutionStrategy::Parallel { num_threads: 4 }
                }
            }
            // Large: parallelize across cores, capped at 16 threads.
            100_001..=10_000_000 => ExecutionStrategy::Parallel {
                num_threads: num_cpus::get().min(16),
            },
            // Very large: split work between CPU and accelerator.
            _ => ExecutionStrategy::Hybrid {
                cpu_ratio_percent: 30,
            },
        }
    }

    /// Runtime check for AVX2 support; always false off x86_64.
    fn has_avx2(&self) -> bool {
        // Both arms are braced blocks so exactly one well-formed tail
        // expression survives cfg expansion on every target.
        #[cfg(target_arch = "x86_64")]
        {
            is_x86_feature_detected!("avx2")
        }
        #[cfg(not(target_arch = "x86_64"))]
        {
            false
        }
    }

    /// Runtime check for NEON support; always false off aarch64.
    fn has_neon(&self) -> bool {
        #[cfg(target_arch = "aarch64")]
        {
            std::arch::is_aarch64_feature_detected!("neon")
        }
        #[cfg(not(target_arch = "aarch64"))]
        {
            false
        }
    }
}
/// Chooses and caches an execution strategy per operation, profiling each
/// run to feed the underlying [`PerformancePredictor`].
pub struct AdaptiveTuner {
    // Learns strategy performance from recorded metrics.
    predictor: PerformancePredictor,
    // Cache of the last strategy chosen per "op:shapes" key.
    active_strategies: Arc<Mutex<HashMap<String, ExecutionStrategy>>>,
    // When false, operations run without metrics collection or caching.
    profiling_enabled: bool,
}
impl AdaptiveTuner {
    /// Creates a tuner with an empty strategy cache and profiling enabled.
    pub fn new() -> Self {
        Self {
            predictor: PerformancePredictor::new(),
            active_strategies: Arc::new(Mutex::new(HashMap::new())),
            profiling_enabled: true,
        }
    }

    /// Runs `operation` with the cached (or freshly predicted) strategy for
    /// `op_name`/`shapes`; when profiling is enabled, records the observed
    /// timing so future predictions improve.
    pub fn execute_with_tuning<F, T>(
        &self,
        op_name: &str,
        shapes: &[Shape],
        operation: F,
    ) -> Result<T>
    where
        F: Fn(ExecutionStrategy) -> Result<T>,
    {
        let cache_key = self.create_cache_key(op_name, shapes);
        let cached = {
            let cache = self
                .active_strategies
                .lock()
                .expect("lock should not be poisoned");
            cache.get(&cache_key).cloned()
        };
        let was_cached = cached.is_some();
        let strategy =
            cached.unwrap_or_else(|| self.predictor.predict_best_strategy(op_name, shapes));
        let start_time = Instant::now();
        let result = operation(strategy.clone())?;
        let duration = start_time.elapsed();
        if self.profiling_enabled {
            let metrics = OperationMetrics {
                op_name: op_name.to_string(),
                input_shapes: shapes.to_vec(),
                duration_ns: duration.as_nanos() as u64,
                memory_bandwidth: self.estimate_memory_bandwidth(shapes, duration),
                cpu_utilization: self.get_cpu_utilization(),
                // NOTE(review): placeholder — no real cache counters yet.
                cache_hit_rate: 0.95,
                hardware_features: self.get_active_features(&strategy),
            };
            self.predictor.update_performance(&metrics, strategy.clone());
            // Only write the cache when the strategy was freshly predicted;
            // on a hit it already holds this exact value.
            if !was_cached {
                let mut cache = self
                    .active_strategies
                    .lock()
                    .expect("lock should not be poisoned");
                cache.insert(cache_key, strategy);
            }
        }
        Ok(result)
    }

    /// Enables or disables metrics collection and strategy caching.
    pub fn set_profiling_enabled(&mut self, enabled: bool) {
        self.profiling_enabled = enabled;
    }

    /// Builds the "op:shape,shape,…" key used by the strategy cache.
    fn create_cache_key(&self, op_name: &str, shapes: &[Shape]) -> String {
        let shapes_str = shapes
            .iter()
            .map(|shape| format!("{shape:?}"))
            .collect::<Vec<_>>()
            .join(",");
        format!("{op_name}:{shapes_str}")
    }

    /// Drops all cached strategies, forcing re-prediction on the next call.
    pub fn clear_strategy_cache(&self) {
        let mut cache = self
            .active_strategies
            .lock()
            .expect("lock should not be poisoned");
        cache.clear();
    }

    /// Renders a human-readable summary of every profiled operation.
    pub fn get_performance_stats(&self) -> Result<String> {
        let history = self
            .predictor
            .metrics_history
            .read()
            .expect("read lock should not be poisoned");
        if history.is_empty() {
            return Ok("No performance data collected yet.".to_string());
        }
        let mut stats = String::new();
        stats.push_str("Adaptive Tuning Performance Statistics\n");
        stats.push_str("======================================\n");
        stats.push_str(&format!("Total operations profiled: {}\n", history.len()));
        let mut op_stats: HashMap<String, Vec<&OperationMetrics>> = HashMap::new();
        for metrics in history.iter() {
            op_stats
                .entry(metrics.op_name.clone())
                .or_default()
                .push(metrics);
        }
        // Sort by operation name so repeated calls render identically
        // (HashMap iteration order is unspecified).
        let mut per_op: Vec<_> = op_stats.into_iter().collect();
        per_op.sort_unstable_by(|a, b| a.0.cmp(&b.0));
        for (op_name, metrics) in per_op {
            let count = metrics.len();
            // Average in f64 to avoid integer-division truncation.
            let avg_duration_ns =
                metrics.iter().map(|m| m.duration_ns as f64).sum::<f64>() / count as f64;
            let avg_bandwidth =
                metrics.iter().map(|m| m.memory_bandwidth).sum::<f64>() / count as f64;
            stats.push_str(&format!(
                "\n{}: {} executions, avg {:.2}ms, {:.2} GB/s\n",
                op_name,
                count,
                avg_duration_ns / 1_000_000.0,
                avg_bandwidth / 1e9
            ));
        }
        Ok(stats)
    }

    /// Rough bandwidth estimate: total elements moved over elapsed time.
    fn estimate_memory_bandwidth(&self, shapes: &[Shape], duration: Duration) -> f64 {
        let total_elements: usize = shapes.iter().map(|s| s.size()).sum();
        // NOTE(review): assumes 8 bytes per element (f64-sized) and counts
        // each input once — a coarse estimate, not a measurement.
        let estimated_bytes = total_elements * 8;
        if duration.as_nanos() == 0 {
            0.0
        } else {
            (estimated_bytes as f64) / (duration.as_secs_f64())
        }
    }

    /// Placeholder utilization figure — no real sampling is implemented.
    fn get_cpu_utilization(&self) -> f32 {
        0.8
    }

    /// Names the hardware features implied by `strategy` for the metrics log.
    fn get_active_features(&self, strategy: &ExecutionStrategy) -> Vec<String> {
        // Exhaustive match so adding a variant forces a decision here.
        match strategy {
            ExecutionStrategy::Simd { instruction_set } => vec![instruction_set.clone()],
            ExecutionStrategy::Gpu { .. } => vec!["gpu".to_string()],
            ExecutionStrategy::Parallel { .. } => vec!["multi-thread".to_string()],
            ExecutionStrategy::Sequential
            | ExecutionStrategy::Hybrid { .. }
            | ExecutionStrategy::Custom { .. } => vec![],
        }
    }
}
impl Default for AdaptiveTuner {
    /// Equivalent to [`AdaptiveTuner::new`].
    fn default() -> Self {
        Self::new()
    }
}
lazy_static::lazy_static! {
    /// Process-wide tuner shared by [`execute_with_adaptive_tuning`].
    // NOTE(review): `std::sync::LazyLock` could replace lazy_static on
    // newer toolchains — confirm the crate's MSRV before switching.
    pub static ref GLOBAL_TUNER: Arc<Mutex<AdaptiveTuner>> =
        Arc::new(Mutex::new(AdaptiveTuner::new()));
}
/// Convenience wrapper that runs `operation` through the global tuner.
///
/// # Panics
/// Panics if the global tuner mutex is poisoned.
///
// NOTE(review): the global mutex is held for the entire duration of
// `operation`, serializing all tuned operations across threads; a
// re-entrant call from inside `operation` would deadlock. Consider
// narrowing the lock scope in `AdaptiveTuner` itself.
pub fn execute_with_adaptive_tuning<F, T>(
    op_name: &str,
    shapes: &[Shape],
    operation: F,
) -> Result<T>
where
    F: Fn(ExecutionStrategy) -> Result<T>,
{
    let tuner = GLOBAL_TUNER.lock().expect("lock should not be poisoned");
    tuner.execute_with_tuning(op_name, shapes, operation)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_performance_predictor_creation() {
        // A fresh predictor has no history, so prediction falls back to the
        // size heuristic; 10x10 = 100 elements lands on the small-input path.
        let predictor = PerformancePredictor::new();
        let shapes = [Shape::from_slice(&[10, 10])];
        let chosen = predictor.predict_best_strategy("test_op", &shapes);
        let acceptable = matches!(
            chosen,
            ExecutionStrategy::Sequential | ExecutionStrategy::Simd { .. }
        );
        assert!(acceptable);
    }

    #[test]
    fn test_adaptive_tuner_execution() {
        // The tuner should pick some strategy and return the closure's value.
        let tuner = AdaptiveTuner::new();
        let shapes = [Shape::from_slice(&[100])];
        let outcome = tuner.execute_with_tuning("test_add", &shapes, |strategy| {
            Ok(format!("Executed with {:?}", strategy))
        });
        match outcome {
            Ok(message) => assert!(message.contains("Executed with")),
            Err(_) => panic!("test: operation should succeed"),
        }
    }

    #[test]
    fn test_heuristic_strategy_selection() {
        let predictor = PerformancePredictor::new();
        // 10 elements -> tiny-workload path.
        let small = predictor.heuristic_strategy_selection(&[Shape::from_slice(&[10])]);
        assert!(matches!(small, ExecutionStrategy::Sequential));
        // 10_000 elements -> mid-size path (SIMD when available, else threads).
        let large = predictor.heuristic_strategy_selection(&[Shape::from_slice(&[10000])]);
        let mid_tier = matches!(
            large,
            ExecutionStrategy::Parallel { .. } | ExecutionStrategy::Simd { .. }
        );
        assert!(mid_tier);
    }

    #[test]
    fn test_performance_metrics_update() {
        // After recording one SIMD run, prediction for the same op and shape
        // should reproduce that strategy instead of the heuristic fallback.
        let predictor = PerformancePredictor::new();
        let recorded = OperationMetrics {
            op_name: "test_op".to_string(),
            input_shapes: vec![Shape::from_slice(&[100, 100])],
            duration_ns: 1000000,
            memory_bandwidth: 1e9,
            cpu_utilization: 0.8,
            cache_hit_rate: 0.95,
            hardware_features: vec!["avx2".to_string()],
        };
        let simd = ExecutionStrategy::Simd {
            instruction_set: "avx2".to_string(),
        };
        predictor.update_performance(&recorded, simd);
        let predicted =
            predictor.predict_best_strategy("test_op", &[Shape::from_slice(&[100, 100])]);
        assert!(matches!(predicted, ExecutionStrategy::Simd { .. }));
    }
}