use super::{OptimizationPolicy, PlanCacheKey};
/// Flattened snapshot of a tuning decision, as built by `TuningRecord::trace_summary`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TuningTraceSummary {
    /// Backend the tuning decision applies to.
    pub backend: String,
    /// Strategy in effect: the fastest matching benchmark's strategy, or the declared
    /// strategy when no benchmark matched.
    pub selected_strategy: String,
    /// Number of benchmarks whose plan key and backend matched the tuning record.
    pub benchmark_count: usize,
    /// Pointwise-fusion flag of the `OptimizationPolicy` derived from `selected_strategy`.
    pub enable_pointwise_fusion: bool,
    /// P-adic matmul valuation-skip flag of the `OptimizationPolicy` derived from
    /// `selected_strategy`.
    pub enable_padic_matmul_valuation_skip: bool,
}
/// One timed run of a strategy for a given plan and backend.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct BenchmarkRecord {
    /// Plan this measurement belongs to; must equal the tuning record's key to be counted.
    pub plan_key: PlanCacheKey,
    /// Backend the run executed on; must equal the tuning record's backend to be counted.
    pub backend: String,
    /// Name of the strategy that was benchmarked.
    pub strategy: String,
    /// Description of the benchmarked input; not consulted when selecting a strategy.
    pub input_description: String,
    /// Measured duration (the `_ns` suffix suggests nanoseconds); the smallest value wins.
    pub duration_ns: u64,
}
/// Tuning state for one plan on one backend: the declared strategy plus collected benchmarks.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TuningRecord {
    /// Plan being tuned; benchmarks with a different key are ignored by the queries.
    pub plan_key: PlanCacheKey,
    /// Backend being tuned; benchmarks for other backends are ignored by the queries.
    pub backend: String,
    /// Declared strategy, used as the fallback when no matching benchmark exists.
    pub selected_strategy: String,
    /// Benchmarks in insertion order (earlier entries win ties on duration).
    pub benchmark_records: Vec<BenchmarkRecord>,
    /// Free-form notes; not read by the selection logic.
    pub notes: Vec<String>,
}
impl TuningRecord {
    /// Creates a record for `plan_key` on `backend` whose declared (fallback) strategy is
    /// `strategy`, with no benchmarks or notes yet.
    pub fn new(
        plan_key: PlanCacheKey,
        backend: impl Into<String>,
        strategy: impl Into<String>,
    ) -> Self {
        Self {
            plan_key,
            backend: backend.into(),
            selected_strategy: strategy.into(),
            benchmark_records: Vec::new(),
            notes: Vec::new(),
        }
    }

    /// Appends a benchmark result.
    ///
    /// Every record is stored, but records whose `plan_key` or `backend` differ from this
    /// record's are ignored by all of the query methods below.
    pub fn add_benchmark(&mut self, record: BenchmarkRecord) {
        self.benchmark_records.push(record);
    }

    /// Returns the strategy of the matching benchmark with the smallest `duration_ns`, or
    /// `None` when no benchmark matches this record's plan key and backend.
    ///
    /// On a tie the earliest-added benchmark wins, because `Iterator::min_by_key` returns
    /// the first minimal element.
    pub fn fastest_strategy(&self) -> Option<&str> {
        self.matching_benchmarks()
            .min_by_key(|record| record.duration_ns)
            .map(|record| record.strategy.as_str())
    }

    /// Returns the optimization policy for the effective strategy.
    ///
    /// The effective strategy is the fastest matching benchmark's strategy, or the declared
    /// `selected_strategy` when nothing matched.
    pub fn optimization_policy(&self) -> OptimizationPolicy {
        OptimizationPolicy::from_strategy(self.effective_strategy())
    }

    /// Returns a one-line, human-readable explanation of how the strategy was chosen.
    pub fn selection_summary(&self) -> String {
        match self.fastest_strategy() {
            Some(strategy) => format!(
                "tuning selected strategy {strategy} for backend {} after {} benchmarks",
                self.backend,
                self.matching_benchmarks().count()
            ),
            None => format!(
                "tuning fell back to declared strategy {} for backend {}",
                self.selected_strategy, self.backend
            ),
        }
    }

    /// Builds a flattened snapshot of the decision: backend, effective strategy, number of
    /// matching benchmarks, and the policy flags derived from the effective strategy.
    pub fn trace_summary(&self) -> TuningTraceSummary {
        // Resolve the strategy once so the reported name and the policy flags cannot diverge.
        let strategy = self.effective_strategy();
        let policy = OptimizationPolicy::from_strategy(strategy);
        TuningTraceSummary {
            backend: self.backend.clone(),
            selected_strategy: strategy.to_string(),
            benchmark_count: self.matching_benchmarks().count(),
            enable_pointwise_fusion: policy.enable_pointwise_fusion,
            enable_padic_matmul_valuation_skip: policy.enable_padic_matmul_valuation_skip,
        }
    }

    /// Iterates over the benchmarks recorded for this record's own plan key and backend.
    ///
    /// This is the single definition of "matching benchmark". `fastest_strategy`,
    /// `selection_summary` and `trace_summary` share it so they always count the same set.
    fn matching_benchmarks(&self) -> impl Iterator<Item = &BenchmarkRecord> + '_ {
        self.benchmark_records
            .iter()
            .filter(move |record| record.plan_key == self.plan_key && record.backend == self.backend)
    }

    /// Returns the strategy in effect: the fastest matching benchmark's strategy, falling back
    /// to the declared `selected_strategy`.
    fn effective_strategy(&self) -> &str {
        self.fastest_strategy()
            .unwrap_or(self.selected_strategy.as_str())
    }
}
impl OptimizationPolicy {
    /// Maps a tuning strategy name to the optimization flags it implies.
    ///
    /// - `"dense-baseline"`, `"unfused-baseline"`: both optimizations off.
    /// - `"pointwise-fusion"`: only pointwise fusion on.
    /// - `"padic-valuation-skip"`: only the p-adic matmul valuation skip on.
    /// - `"all-optimizations"`, `"heuristic-baseline"`, and any unrecognised name:
    ///   `OptimizationPolicy::default()`.
    pub fn from_strategy(strategy: &str) -> Self {
        // (pointwise fusion, p-adic matmul valuation skip) for strategies that pin both flags;
        // `None` means "use the default policy".
        let pinned_flags = match strategy {
            "dense-baseline" | "unfused-baseline" => Some((false, false)),
            "pointwise-fusion" => Some((true, false)),
            "padic-valuation-skip" => Some((false, true)),
            // Named strategies that deliberately take the default policy.
            "all-optimizations" | "heuristic-baseline" => None,
            // Unknown names also fall back to the default policy.
            _ => None,
        };
        pinned_flags.map_or_else(Self::default, |(fusion, valuation_skip)| Self {
            enable_pointwise_fusion: fusion,
            enable_padic_matmul_valuation_skip: valuation_skip,
        })
    }
}