//! trueno/tuner/models/throughput.rs — throughput prediction models.

use serde::{Deserialize, Serialize};

5#[cfg(feature = "ml-tuner")]
6use aprender::{tree::RandomForestRegressor, Matrix, Vector};
7
8use super::super::error::TunerError;
9use super::super::features::TunerFeatures;
10use super::ThroughputPrediction;
11
/// Throughput model: a serializable linear regressor with an optional
/// random-forest backend (feature `ml-tuner`) that is runtime-only.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThroughputRegressor {
    /// Linear-model weights: index 0 is the bias; index `i + 1` pairs with
    /// feature `i` (see `predict_raw`).
    pub(crate) weights: Vec<f32>,
    /// `(feature name, importance)` pairs; the first five are reported as
    /// `top_features` in predictions.
    pub(crate) feature_importance: Vec<(String, f32)>,
    /// Number of samples in the most recent training set.
    pub(crate) sample_count: usize,
    /// Mean absolute percentage error measured on the last training set;
    /// used to derive prediction confidence.
    pub(crate) mape: f32,
    /// Trained random forest; never serialized — must be re-trained via
    /// `train_random_forest` after deserialization.
    #[cfg(feature = "ml-tuner")]
    #[serde(skip)]
    rf_model: Option<RandomForestRegressor>,
}
31
32impl Default for ThroughputRegressor {
33 fn default() -> Self {
34 Self::new()
35 }
36}
37
38impl ThroughputRegressor {
39 pub fn new() -> Self {
41 let mut weights = vec![0.0; TunerFeatures::DIM + 1]; weights[0] = 0.4;
47
48 weights[7] = 0.3; weights[9] = 0.1; weights[36] = 0.15; weights[38] = 0.1; weights[1] = -0.15; weights[8] = -0.05; Self {
67 weights,
68 feature_importance: Self::default_feature_importance(),
69 sample_count: 0,
70 mape: 0.15, #[cfg(feature = "ml-tuner")]
72 rf_model: None,
73 }
74 }
75
76 #[cfg(feature = "ml-tuner")]
78 pub fn with_random_forest(n_estimators: usize) -> Self {
79 let mut instance = Self::new();
80 instance.rf_model = Some(RandomForestRegressor::new(n_estimators));
81 instance
82 }
83
84 fn default_feature_importance() -> Vec<(String, f32)> {
85 vec![
86 ("batch_size".into(), 0.25),
87 ("gpu_mem_bw".into(), 0.20),
88 ("model_params".into(), 0.15),
89 ("cuda_graphs".into(), 0.10),
90 ("gpu_sm_count".into(), 0.10),
91 ("hidden_dim".into(), 0.08),
92 ("quant_type".into(), 0.07),
93 ("seq_len".into(), 0.05),
94 ]
95 }
96
97 pub fn train(&mut self, data: &[(TunerFeatures, f32)]) -> Result<(), TunerError> {
99 if data.len() < 10 {
100 return Err(TunerError::InsufficientData(data.len()));
101 }
102
103 let learning_rate = 0.01;
105 let epochs = 100;
106
107 for _ in 0..epochs {
108 let mut gradients = vec![0.0; self.weights.len()];
109
110 for (features, target) in data {
111 let x = features.to_vector();
112 let predicted = self.predict_raw(&x);
113 let error = predicted - target;
114
115 gradients[0] += error;
117
118 for (i, xi) in x.iter().enumerate() {
120 gradients[i + 1] += error * xi;
121 }
122 }
123
124 let n = data.len().max(1) as f32;
126 for (i, g) in gradients.iter().enumerate() {
127 self.weights[i] -= learning_rate * g / n;
128 }
129 }
130
131 let mut total_ape = 0.0;
133 for (features, target) in data {
134 let predicted = self.predict_raw(&features.to_vector());
135 total_ape += ((predicted - target) / target.max(1.0)).abs();
136 }
137 self.mape = total_ape / data.len().max(1) as f32;
138 self.sample_count = data.len();
139
140 Ok(())
141 }
142
    /// Train the random-forest backend on `data` (feature `ml-tuner`).
    ///
    /// Lazily creates a 100-tree forest when none was configured via
    /// `with_random_forest`, fits it, and records the training-set MAPE and
    /// sample count.
    ///
    /// # Errors
    /// `TunerError::InsufficientData` for fewer than 10 samples;
    /// `TunerError::TrainingFailed` when matrix construction or fitting fails.
    #[cfg(feature = "ml-tuner")]
    pub fn train_random_forest(&mut self, data: &[(TunerFeatures, f32)]) -> Result<(), TunerError> {
        if data.len() < 10 {
            return Err(TunerError::InsufficientData(data.len()));
        }

        // Flatten the samples into a row-major feature matrix for aprender.
        let n_samples = data.len();
        let n_features = TunerFeatures::DIM;
        let mut x_data = Vec::with_capacity(n_samples * n_features);
        let mut y_data = Vec::with_capacity(n_samples);

        for (features, target) in data {
            x_data.extend(features.to_vector());
            y_data.push(*target);
        }

        let x_matrix = Matrix::from_vec(n_samples, n_features, x_data)
            .map_err(|e| TunerError::TrainingFailed(e.to_string()))?;
        let y_vector = Vector::from_vec(y_data);

        // Reuse an already-configured forest, or fall back to 100 trees.
        let rf = self.rf_model.get_or_insert_with(|| RandomForestRegressor::new(100));
        rf.fit(&x_matrix, &y_vector).map_err(|e| TunerError::TrainingFailed(e.to_string()))?;

        // Training-set MAPE; denominator clamped to 1.0 to avoid blow-ups on
        // tiny targets.
        let predictions = rf.predict(&x_matrix);
        let mut total_ape = 0.0;
        for (i, (_, target)) in data.iter().enumerate() {
            let pred = predictions.as_slice()[i];
            total_ape += ((pred - target) / target.max(1.0)).abs();
        }
        self.mape = total_ape / data.len().max(1) as f32;
        self.sample_count = data.len();

        Ok(())
    }
184
185 pub(crate) fn predict_raw(&self, x: &[f32]) -> f32 {
186 let mut result = self.weights[0]; for (i, xi) in x.iter().enumerate() {
188 if i + 1 < self.weights.len() {
189 result += self.weights[i + 1] * xi;
190 }
191 }
192 (result * 1000.0).max(1.0)
194 }
195
    /// Predict throughput for `features`, capped at the roofline bound.
    ///
    /// With the `ml-tuner` feature and a trained forest, the forest provides
    /// the raw prediction; otherwise the linear model does. Confidence is
    /// `(1 - mape)` floored at 0.5, reduced by 10% when the raw prediction
    /// exceeded the theoretical bound.
    pub fn predict(&self, features: &TunerFeatures) -> ThroughputPrediction {
        let x = features.to_vector();

        // Forest path: fall back to the linear model when no forest is
        // trained or the 1-row feature matrix cannot be built.
        #[cfg(feature = "ml-tuner")]
        let raw_predicted_tps = if let Some(ref rf) = self.rf_model {
            if let Ok(x_matrix) = Matrix::from_vec(1, TunerFeatures::DIM, x.to_vec()) {
                let predictions = rf.predict(&x_matrix);
                predictions.as_slice().first().copied().unwrap_or(0.0)
            } else {
                self.predict_raw(&x)
            }
        } else {
            self.predict_raw(&x)
        };

        #[cfg(not(feature = "ml-tuner"))]
        let raw_predicted_tps = self.predict_raw(&x);

        // Never report more than the memory-bandwidth roofline allows.
        let theoretical_max_tps = Self::compute_roofline_bound(features);
        let predicted_tps = raw_predicted_tps.min(theoretical_max_tps);

        // Penalize confidence when the model over-shot the physical bound.
        let roofline_penalty = if raw_predicted_tps > theoretical_max_tps {
            0.9
        } else {
            1.0
        };
        let confidence = (1.0 - self.mape).max(0.5) * roofline_penalty;

        ThroughputPrediction {
            predicted_tps,
            confidence,
            // Report only the five most important features.
            top_features: self.feature_importance.iter().take(5).cloned().collect(),
        }
    }
239
240 pub fn compute_roofline_bound(features: &TunerFeatures) -> f32 {
246 let model_params_b = 10.0_f32.powf(features.model_params_b * 3.0 - 1.0);
250
251 let bytes_per_param = Self::bytes_per_param_from_onehot(&features.quant_type_onehot);
253
254 let gpu_mem_bw_gbs = features.gpu_mem_bw_norm * 3000.0;
256
257 let batch_size = (features.batch_size_norm * 64.0).max(1.0);
259
260 let theoretical_max = (gpu_mem_bw_gbs * batch_size) / (model_params_b * bytes_per_param);
267
268 theoretical_max.clamp(1.0, 10000.0)
270 }
271
272 pub fn bytes_per_param_from_onehot(onehot: &[f32; 8]) -> f32 {
274 let bytes_per_param = [0.5625, 0.5625, 0.5625, 0.6875, 0.8125, 1.0, 2.0, 4.0];
277
278 let idx = onehot
280 .iter()
281 .enumerate()
282 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
283 .map(|(i, _)| i)
284 .unwrap_or(0);
286
287 bytes_per_param[idx]
288 }
289}