trueno/tuner/evolution/
mod.rs1mod bandit;
6mod online;
7
8pub use bandit::{KernelArm, KernelBandit};
9pub use online::OnlineLearner;
10
11use serde::{Deserialize, Serialize};
12
13#[cfg(feature = "hardware-detect")]
14use crate::hardware::HardwareCapability;
15
16use super::brick_tuner::BrickTuner;
17use super::features::TunerFeatures;
18use super::models::KernelRecommendation;
19use super::pretrained;
20#[cfg(feature = "hardware-detect")]
21use super::types::QuantType;
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct CalibrationResult {
30 pub throughput_weights: Vec<f32>,
32 pub local_mape: f32,
34 pub improvement_pct: f32,
36 pub hardware_id: String,
38 pub duration_secs: f32,
40 pub num_benchmarks: usize,
42}
43
44impl BrickTuner {
49 pub fn with_pretrained() -> Self {
58 let mut tuner = Self::new();
59
60 tuner.throughput.weights = pretrained::THROUGHPUT_WEIGHTS.to_vec();
62 tuner.throughput.mape = 0.082; tuner.throughput.sample_count = 10_000;
64
65 tuner.throughput.feature_importance = pretrained::FEATURE_IMPORTANCE
67 .iter()
68 .map(|(_, name, importance)| (name.to_string(), *importance))
69 .collect();
70
71 tuner.version = format!("{}-pretrained", Self::VERSION);
72 tuner
73 }
74
75 #[cfg(feature = "hardware-detect")]
84 #[allow(clippy::cast_precision_loss)] pub fn calibrate(&mut self) -> Result<CalibrationResult, super::error::TunerError> {
86 use std::time::Instant;
87
88 let start = Instant::now();
89 let hw = HardwareCapability::detect();
90 let hardware_id = format!("{:?}", hw.gpu);
91
92 let mut samples = Vec::new();
94 let baseline_tps = self.estimate_baseline_tps(&hw);
95
96 for batch_size in [1, 2, 4, 8] {
98 for model_size in [1.5, 7.0, 13.0] {
99 for quant in [QuantType::Q4K, QuantType::Q8_0] {
100 let features = TunerFeatures::builder()
101 .model_params_b(model_size)
102 .hidden_dim(4096)
103 .num_layers(32)
104 .batch_size(batch_size)
105 .quant_type(quant)
106 .build();
107
108 let estimated_tps = baseline_tps * (batch_size as f32).sqrt()
110 / model_size.sqrt() as f32
111 * quant.bytes_per_param();
112
113 samples.push((features, estimated_tps.max(10.0)));
114 }
115 }
116 }
117
118 let num_benchmarks = samples.len();
119
120 let mut learner = OnlineLearner::new().with_learning_rate(0.01);
122
123 for _ in 0..10 {
125 for (features, target) in &samples {
126 learner.observe(&features.to_vector(), *target);
127 }
128 }
129
130 let pretrained_mape = self.throughput.mape;
132 self.throughput.weights = learner.weights().to_vec();
133
134 let mut total_error = 0.0;
136 for (features, target) in &samples {
137 let predicted = learner.predict(&features.to_vector());
138 total_error += ((predicted - target) / target).abs();
139 }
140 let local_mape = total_error / samples.len().max(1) as f32;
141 self.throughput.mape = local_mape;
142
143 let improvement_pct = ((pretrained_mape - local_mape) / pretrained_mape * 100.0).max(0.0);
144 let duration_secs = start.elapsed().as_secs_f32();
145
146 self.version = format!("{}-calibrated", Self::VERSION);
147
148 Ok(CalibrationResult {
149 throughput_weights: self.throughput.weights.clone(),
150 local_mape,
151 improvement_pct,
152 hardware_id,
153 duration_secs,
154 num_benchmarks,
155 })
156 }
157
158 #[cfg(feature = "hardware-detect")]
160 #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
161 fn estimate_baseline_tps(&self, hw: &HardwareCapability) -> f32 {
163 let mem_bw_factor = hw.gpu.as_ref().map(|g| g.memory_bw_gbps / 1000.0).unwrap_or(0.5);
168
169 100.0 * mem_bw_factor as f32
170 }
171
172 pub fn online_learner(&self) -> OnlineLearner {
178 let mut learner = OnlineLearner::new();
179 learner.weights = self.throughput.weights.clone();
180 learner
181 }
182
183 pub fn apply_online_updates(&mut self, learner: &OnlineLearner) {
185 if learner.num_updates() > 0 {
186 self.throughput.weights = learner.weights().to_vec();
187 self.throughput.sample_count += learner.num_updates();
188 self.version = format!("{}-online-{}", Self::VERSION, learner.num_updates());
189 }
190 }
191
192 pub fn kernel_bandit(&self) -> KernelBandit {
198 KernelBandit::new()
199 }
200
201 #[allow(clippy::cast_precision_loss)] pub fn recommend_kernel_with_exploration(
204 &self,
205 features: &TunerFeatures,
206 bandit: &KernelBandit,
207 explore_prob: f32,
208 ) -> KernelRecommendation {
209 let do_explore = {
211 use std::collections::hash_map::DefaultHasher;
212 use std::hash::{Hash, Hasher};
213 let mut hasher = DefaultHasher::new();
214 bandit.total_pulls.hash(&mut hasher);
215 features.batch_size_norm.to_bits().hash(&mut hasher);
216 (hasher.finish() % 1000) as f32 / 1000.0 < explore_prob
217 };
218
219 if do_explore {
220 let kernel = bandit.select();
222 KernelRecommendation {
223 top_kernel: kernel,
224 confidence: 0.5, alternatives: vec![],
226 }
227 } else {
228 self.kernel.predict(features)
230 }
231 }
232}