tokitai-operator 0.1.0

Verified DL kernel compiler: formally-checked GEMM, p-adic, sheaf, contract-carrying ops. Paper-artifact grade.
Documentation
//! Tuning record: per-plan benchmark-backed optimization policy.
//!
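//! `TuningRecord` is the per-plan optimization record. It carries
//! a declared `selected_strategy` and a list of `BenchmarkRecord`s
//! (typically one per strategy). `fastest_strategy()` picks the
//! quickest benchmark recorded for the same plan key and backend,
//! and `optimization_policy()` maps that strategy (or the declared
//! one, when no matching benchmark exists) to an `OptimizationPolicy`.
//! The planner uses the policy to bias the heuristic cost model.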
//!
use super::{OptimizationPolicy, PlanCacheKey};

/// Flattened snapshot of a [`TuningRecord`]'s outcome, as produced by
/// `TuningRecord::trace_summary`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TuningTraceSummary {
    /// Backend the tuning record targets (copied from the record).
    pub backend: String,
    /// Effective strategy: the fastest matching benchmark's strategy, or the
    /// record's declared strategy when no matching benchmark exists.
    pub selected_strategy: String,
    /// Number of benchmarks whose plan key and backend match the record.
    pub benchmark_count: usize,
    /// Pointwise-fusion flag of the policy derived from `selected_strategy`.
    pub enable_pointwise_fusion: bool,
    /// p-adic matmul valuation-skip flag of the policy derived from
    /// `selected_strategy`.
    pub enable_padic_matmul_valuation_skip: bool,
}

/// A single timing measurement of one strategy for one plan on one backend.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BenchmarkRecord {
    /// Plan this measurement belongs to; a `TuningRecord` only considers
    /// benchmarks whose key equals its own.
    pub plan_key: PlanCacheKey,
    /// Backend the measurement was taken on; must equal the owning
    /// `TuningRecord`'s backend to be considered.
    pub backend: String,
    /// Strategy name that was benchmarked (fed to
    /// `OptimizationPolicy::from_strategy` when it wins).
    pub strategy: String,
    /// Free-form description of the benchmark input.
    /// NOTE(review): not read in this module; format is caller-defined.
    pub input_description: String,
    /// Measured duration; lower wins. Presumably nanoseconds, per the
    /// field name; the unit is not enforced here.
    pub duration_ns: u64,
}

/// Per-plan tuning state: a declared fallback strategy plus the benchmarks
/// collected for it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TuningRecord {
    /// Plan this record tunes; benchmarks with a different key are ignored.
    pub plan_key: PlanCacheKey,
    /// Backend this record tunes; benchmarks for other backends are ignored.
    pub backend: String,
    /// Declared strategy, used only when no matching benchmark exists.
    pub selected_strategy: String,
    /// All benchmarks added so far, including non-matching ones (which
    /// are filtered out when selecting or counting).
    pub benchmark_records: Vec<BenchmarkRecord>,
    /// Free-form notes. NOTE(review): not read in this module.
    pub notes: Vec<String>,
}

impl TuningRecord {
    /// Creates a record for `plan_key` on `backend` with a declared fallback
    /// `strategy`.
    ///
    /// The record starts with no benchmarks and no notes. Until a matching
    /// benchmark is added, the policy is derived from the declared strategy.
    pub fn new(
        plan_key: PlanCacheKey,
        backend: impl Into<String>,
        strategy: impl Into<String>,
    ) -> Self {
        Self {
            plan_key,
            backend: backend.into(),
            selected_strategy: strategy.into(),
            benchmark_records: Vec::new(),
            notes: Vec::new(),
        }
    }

    /// Appends a benchmark measurement.
    ///
    /// Every record is stored, but only those whose `plan_key` and `backend`
    /// match this record take part in strategy selection and counting.
    pub fn add_benchmark(&mut self, record: BenchmarkRecord) {
        self.benchmark_records.push(record);
    }

    /// Returns the strategy of the fastest matching benchmark, or `None` when
    /// no benchmark matches this record's plan key and backend.
    ///
    /// When several matching benchmarks share the minimum duration, the one
    /// added first wins (`Iterator::min_by_key` keeps the first minimum).
    pub fn fastest_strategy(&self) -> Option<&str> {
        self.matching_benchmarks()
            .min_by_key(|record| record.duration_ns)
            .map(|record| record.strategy.as_str())
    }

    /// Derives the optimization policy from the effective strategy.
    ///
    /// The effective strategy is the fastest benchmarked strategy, or the
    /// declared strategy when there is no matching benchmark.
    pub fn optimization_policy(&self) -> OptimizationPolicy {
        OptimizationPolicy::from_strategy(self.effective_strategy())
    }

    /// Human-readable one-line description of how the strategy was chosen.
    pub fn selection_summary(&self) -> String {
        match self.fastest_strategy() {
            Some(strategy) => format!(
                "tuning selected strategy {strategy} for backend {} after {} benchmarks",
                self.backend,
                self.matching_benchmarks().count()
            ),
            None => format!(
                "tuning fell back to declared strategy {} for backend {}",
                self.selected_strategy, self.backend
            ),
        }
    }

    /// Structured snapshot of the tuning outcome for tracing.
    pub fn trace_summary(&self) -> TuningTraceSummary {
        // Resolve the effective strategy once. The policy is derived from the
        // same value, exactly as `optimization_policy` would do.
        let strategy = self.effective_strategy();
        let policy = OptimizationPolicy::from_strategy(strategy);
        TuningTraceSummary {
            backend: self.backend.clone(),
            selected_strategy: strategy.to_string(),
            benchmark_count: self.matching_benchmarks().count(),
            enable_pointwise_fusion: policy.enable_pointwise_fusion,
            enable_padic_matmul_valuation_skip: policy.enable_padic_matmul_valuation_skip,
        }
    }

    /// Benchmarks recorded for this record's plan key and backend.
    ///
    /// This is the single definition of "matching" shared by selection and
    /// counting.
    fn matching_benchmarks(&self) -> impl Iterator<Item = &BenchmarkRecord> + '_ {
        self.benchmark_records.iter().filter(move |record| {
            record.plan_key == self.plan_key && record.backend == self.backend
        })
    }

    /// Fastest matching strategy, falling back to the declared strategy.
    fn effective_strategy(&self) -> &str {
        self.fastest_strategy()
            .unwrap_or(self.selected_strategy.as_str())
    }
}

impl OptimizationPolicy {
    /// Maps a strategy name to the optimization flags it implies.
    ///
    /// - `dense-baseline` and `unfused-baseline` disable both optimizations.
    /// - `pointwise-fusion` enables only pointwise fusion.
    /// - `padic-valuation-skip` enables only the p-adic matmul valuation skip.
    /// - `all-optimizations`, `heuristic-baseline` and any unrecognised name
    ///   yield `OptimizationPolicy::default()`.
    pub fn from_strategy(strategy: &str) -> Self {
        let flags = |enable_pointwise_fusion, enable_padic_matmul_valuation_skip| Self {
            enable_pointwise_fusion,
            enable_padic_matmul_valuation_skip,
        };
        match strategy {
            "pointwise-fusion" => flags(true, false),
            "padic-valuation-skip" => flags(false, true),
            "dense-baseline" | "unfused-baseline" => flags(false, false),
            "all-optimizations" | "heuristic-baseline" => Self::default(),
            _ => Self::default(),
        }
    }
}