mod bandit;
mod online;
pub use bandit::{KernelArm, KernelBandit};
pub use online::OnlineLearner;
use serde::{Deserialize, Serialize};
#[cfg(feature = "hardware-detect")]
use crate::hardware::HardwareCapability;
use super::brick_tuner::BrickTuner;
use super::features::TunerFeatures;
use super::models::KernelRecommendation;
use super::pretrained;
#[cfg(feature = "hardware-detect")]
use super::types::QuantType;
/// Summary of one calibration run, as populated by `BrickTuner::calibrate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CalibrationResult {
/// Copy of the throughput-model weights installed on the tuner by calibration.
pub throughput_weights: Vec<f32>,
/// Mean absolute percentage error of the calibrated weights over the
/// calibration samples, stored as a fraction (it is not multiplied by 100).
pub local_mape: f32,
/// Relative reduction of the error versus the MAPE the tuner held before
/// calibration, in percent; never negative (clamped at 0).
pub improvement_pct: f32,
/// `Debug` rendering of the detected hardware's `gpu` field.
pub hardware_id: String,
/// Wall-clock time the calibration call took, in seconds.
pub duration_secs: f32,
/// Number of calibration samples that were fitted.
pub num_benchmarks: usize,
}
impl BrickTuner {
/// Creates a tuner whose throughput model is seeded from the bundled
/// `pretrained` constants instead of starting from `Self::new()` defaults.
///
/// The throughput model receives `pretrained::THROUGHPUT_WEIGHTS`, a recorded
/// MAPE of `0.082`, a sample count of `10_000`, and the `(name, importance)`
/// pairs taken from `pretrained::FEATURE_IMPORTANCE`. The version string
/// becomes `"<VERSION>-pretrained"`.
pub fn with_pretrained() -> Self {
    // Each table entry is a 3-tuple; its first element is ignored here.
    let importance_table = pretrained::FEATURE_IMPORTANCE
        .iter()
        .map(|(_, label, score)| (label.to_string(), *score))
        .collect();

    let mut pretrained_tuner = Self::new();

    let model = &mut pretrained_tuner.throughput;
    model.weights = pretrained::THROUGHPUT_WEIGHTS.to_vec();
    // NOTE(review): 0.082 / 10_000 presumably describe the offline training
    // run that produced the weights — confirm against the `pretrained` module.
    model.mape = 0.082;
    model.sample_count = 10_000;
    model.feature_importance = importance_table;

    pretrained_tuner.version = format!("{}-pretrained", Self::VERSION);
    pretrained_tuner
}
/// Re-fits the throughput model for the hardware detected on this machine.
///
/// Steps:
/// 1. Detect hardware and derive a baseline tokens-per-second estimate via
///    `estimate_baseline_tps`.
/// 2. Build a grid of 24 samples (4 batch sizes x 3 model sizes x 2
///    quantisation types). The target of each sample is computed from a
///    formula on the baseline — nothing is actually executed or timed here.
/// 3. Train a fresh `OnlineLearner` (learning rate `0.01`) for 10 passes over
///    the grid.
/// 4. Install the learned weights on `self`, record the MAPE over the same
///    grid, and tag the version as `"<VERSION>-calibrated"`.
///
/// # Returns
///
/// A [`CalibrationResult`] describing the run. `improvement_pct` compares the
/// new MAPE with the MAPE stored on the tuner before the call and is clamped
/// to be non-negative.
///
/// # Errors
///
/// The current implementation has no failing path and always returns `Ok`.
#[cfg(feature = "hardware-detect")]
// Small integers (batch sizes, sample count) are converted to `f32` on purpose.
#[allow(clippy::cast_precision_loss)]
pub fn calibrate(&mut self) -> Result<CalibrationResult, super::error::TunerError> {
    use std::time::Instant;

    let start = Instant::now();
    let hw = HardwareCapability::detect();
    let hardware_id = format!("{:?}", hw.gpu);
    let baseline_tps = self.estimate_baseline_tps(&hw);

    // Grid size is known up front: 4 batch sizes x 3 model sizes x 2 quant types.
    let mut samples = Vec::with_capacity(4 * 3 * 2);
    for batch_size in [1, 2, 4, 8] {
        for model_size in [1.5, 7.0, 13.0] {
            for quant in [QuantType::Q4K, QuantType::Q8_0] {
                let features = TunerFeatures::builder()
                    .model_params_b(model_size)
                    .hidden_dim(4096)
                    .num_layers(32)
                    .batch_size(batch_size)
                    .quant_type(quant)
                    .build();

                // NOTE(review): the estimate is *multiplied* by bytes-per-param,
                // so wider quantisation yields a higher target. Kept as in the
                // original; confirm this direction is intended.
                let estimated_tps = baseline_tps * (batch_size as f32).sqrt()
                    / model_size.sqrt() as f32
                    * quant.bytes_per_param();

                // Vectorise once here; the vector is reused by every training
                // pass and by the evaluation pass below. Targets are floored
                // at 10.0, which also keeps the MAPE divisor away from zero.
                samples.push((features.to_vector(), estimated_tps.max(10.0)));
            }
        }
    }
    let num_benchmarks = samples.len();

    // Fit from scratch: 10 passes over the grid.
    let mut learner = OnlineLearner::new().with_learning_rate(0.01);
    for _ in 0..10 {
        for (vector, target) in &samples {
            learner.observe(vector, *target);
        }
    }

    // Remember the previous error before overwriting the model.
    let pretrained_mape = self.throughput.mape;
    self.throughput.weights = learner.weights().to_vec();

    // MAPE of the fitted learner over the same grid (as a fraction).
    let mut total_error = 0.0;
    for (vector, target) in &samples {
        let predicted = learner.predict(vector);
        total_error += ((predicted - target) / target).abs();
    }
    let local_mape = total_error / samples.len().max(1) as f32;
    self.throughput.mape = local_mape;

    // `f32::max` discards NaN, so a zero previous MAPE still yields 0.0 here.
    let improvement_pct =
        ((pretrained_mape - local_mape) / pretrained_mape * 100.0).max(0.0);
    let duration_secs = start.elapsed().as_secs_f32();
    self.version = format!("{}-calibrated", Self::VERSION);

    Ok(CalibrationResult {
        throughput_weights: self.throughput.weights.clone(),
        local_mape,
        improvement_pct,
        hardware_id,
        duration_secs,
        num_benchmarks,
    })
}
/// Derives a rough baseline throughput figure from the detected GPU.
///
/// The GPU's `memory_bw_gbps` is divided by `1000.0` to form a scaling
/// factor; when no GPU is present the factor falls back to `0.5`. The result
/// is `100.0` times that factor.
#[cfg(feature = "hardware-detect")]
// The bandwidth factor is narrowed to `f32` deliberately.
#[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
fn estimate_baseline_tps(&self, hw: &HardwareCapability) -> f32 {
    // Field name suggests GB/s, i.e. a factor of 1.0 at 1000 GB/s — unit not
    // verifiable from this file.
    let bandwidth_ratio = match hw.gpu.as_ref() {
        Some(gpu) => gpu.memory_bw_gbps / 1000.0,
        None => 0.5,
    };
    100.0 * bandwidth_ratio as f32
}
/// Returns a new [`OnlineLearner`] whose weights start as a copy of this
/// tuner's current throughput weights.
///
/// The tuner itself is not modified; feed the learner's progress back with
/// `apply_online_updates`.
pub fn online_learner(&self) -> OnlineLearner {
    let mut seeded = OnlineLearner::new();
    seeded.weights.clone_from(&self.throughput.weights);
    seeded
}
/// Copies the state of an [`OnlineLearner`] back into the throughput model.
///
/// Does nothing when the learner reports zero updates. Otherwise the
/// learner's weights replace the tuner's, the learner's update count is added
/// to the model's sample count, and the version becomes
/// `"<VERSION>-online-<updates>"`.
pub fn apply_online_updates(&mut self, learner: &OnlineLearner) {
    let updates = learner.num_updates();
    if updates == 0 {
        // Nothing was learned; leave weights, sample count and version alone.
        return;
    }

    self.throughput.weights = learner.weights().to_vec();
    self.throughput.sample_count += updates;
    self.version = format!("{}-online-{}", Self::VERSION, updates);
}
/// Creates a fresh [`KernelBandit`] via `KernelBandit::new()`.
///
/// No state from this tuner is copied into the bandit; `self` is not read.
pub fn kernel_bandit(&self) -> KernelBandit {
KernelBandit::new()
}
/// Recommends a kernel, occasionally deferring to the bandit for exploration.
///
/// A pseudo-random "roll" in `[0.0, 0.999]` is derived by hashing
/// `bandit.total_pulls` together with the bit pattern of
/// `features.batch_size_norm`. When the roll is strictly below
/// `explore_prob`, the bandit's `select()` result is returned with a fixed
/// confidence of `0.5` and no alternatives; otherwise the tuner's own kernel
/// model prediction is returned.
///
/// The roll depends only on those two inputs, so the same inputs always give
/// the same decision within one build of the standard library.
// The hash remainder (< 1000) is converted to `f32` exactly; the lint is noise here.
#[allow(clippy::cast_precision_loss)]
pub fn recommend_kernel_with_exploration(
    &self,
    features: &TunerFeatures,
    bandit: &KernelBandit,
    explore_prob: f32,
) -> KernelRecommendation {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    // Hash order matters for the resulting value: pulls first, then the feature bits.
    let mut state = DefaultHasher::new();
    bandit.total_pulls.hash(&mut state);
    features.batch_size_norm.to_bits().hash(&mut state);
    let roll = (state.finish() % 1000) as f32 / 1000.0;

    // Written as `roll < p` (not a negated `>=`) so a NaN probability never explores.
    let explore = roll < explore_prob;
    if !explore {
        return self.kernel.predict(features);
    }

    KernelRecommendation {
        top_kernel: bandit.select(),
        confidence: 0.5,
        alternatives: Vec::new(),
    }
}
}