// cbtop/performance_prediction/predictor.rs

//! Performance predictor with curve fitting and model selection.

use std::collections::HashMap;

use super::types::{DataPoint, FittedModel, ModelType, Prediction, MIN_SAMPLES_FOR_FIT};

7/// Performance predictor
8#[derive(Debug)]
9pub struct PerformancePredictor {
10    /// Data points
11    data_points: Vec<DataPoint>,
12    /// Fitted models
13    models: HashMap<ModelType, FittedModel>,
14    /// Best model
15    best_model: Option<ModelType>,
16    /// Confidence level for bounds
17    confidence_level: f64,
18}
19
20impl Default for PerformancePredictor {
21    fn default() -> Self {
22        Self {
23            data_points: Vec::new(),
24            models: HashMap::new(),
25            best_model: None,
26            confidence_level: 0.95,
27        }
28    }
29}
30
31impl PerformancePredictor {
32    /// Create new predictor
33    pub fn new() -> Self {
34        Self::default()
35    }
36
37    /// Set confidence level
38    pub fn with_confidence(mut self, level: f64) -> Self {
39        self.confidence_level = level.clamp(0.5, 0.99);
40        self
41    }
42
43    /// Add data point
44    pub fn add_point(&mut self, point: DataPoint) {
45        self.data_points.push(point);
46        // Invalidate cached models
47        self.models.clear();
48        self.best_model = None;
49    }
50
51    /// Add data from values
52    pub fn add(&mut self, size: usize, performance: f64, latency_us: f64) {
53        self.add_point(DataPoint::new(size, performance, latency_us));
54    }
55
56    /// Get data point count
57    pub fn point_count(&self) -> usize {
58        self.data_points.len()
59    }
60
61    /// Check if enough data for fitting
62    pub fn has_sufficient_data(&self) -> bool {
63        self.data_points.len() >= MIN_SAMPLES_FOR_FIT
64    }
65
66    /// Get size range of data
67    pub fn size_range(&self) -> Option<(usize, usize)> {
68        if self.data_points.is_empty() {
69            return None;
70        }
71
72        let min = self
73            .data_points
74            .iter()
75            .map(|p| p.size)
76            .min()
77            .expect("non-empty collection");
78        let max = self
79            .data_points
80            .iter()
81            .map(|p| p.size)
82            .max()
83            .expect("non-empty collection");
84        Some((min, max))
85    }
86
87    /// Fit linear model
88    pub fn fit_linear(&mut self) -> Option<FittedModel> {
89        if !self.has_sufficient_data() {
90            return None;
91        }
92
93        let n = self.data_points.len() as f64;
94        let mut sum_x = 0.0;
95        let mut sum_y = 0.0;
96        let mut sum_xy = 0.0;
97        let mut sum_xx = 0.0;
98
99        for p in &self.data_points {
100            let x = p.size as f64;
101            let y = p.performance;
102            sum_x += x;
103            sum_y += y;
104            sum_xy += x * y;
105            sum_xx += x * x;
106        }
107
108        let denom = n * sum_xx - sum_x * sum_x;
109        if denom.abs() < 1e-10 {
110            return None;
111        }
112
113        let a = (n * sum_xy - sum_x * sum_y) / denom;
114        let b = (sum_y - a * sum_x) / n;
115
116        let mean_y = sum_y / n;
117        let (ss_res, ss_tot) = self.compute_ss(|x| a * x + b, mean_y);
118        let r_squared = if ss_tot > 0.0 {
119            1.0 - ss_res / ss_tot
120        } else {
121            1.0
122        };
123
124        let model = FittedModel {
125            model_type: ModelType::Linear,
126            coefficients: vec![a, b],
127            r_squared,
128            rss: ss_res,
129            sample_count: self.data_points.len(),
130        };
131
132        self.models.insert(ModelType::Linear, model.clone());
133        Some(model)
134    }
135
136    /// Fit polynomial model (quadratic)
137    pub fn fit_polynomial(&mut self) -> Option<FittedModel> {
138        if !self.has_sufficient_data() {
139            return None;
140        }
141
142        // Use least squares for quadratic fit
143        // y = a*x^2 + b*x + c
144        let n = self.data_points.len() as f64;
145        let mut _sum_x = 0.0;
146        let mut _sum_x2 = 0.0;
147        let mut _sum_x3 = 0.0;
148        let mut _sum_x4 = 0.0;
149        let mut sum_y = 0.0;
150        let mut _sum_xy = 0.0;
151        let mut _sum_x2y = 0.0;
152
153        for p in &self.data_points {
154            let x = p.size as f64;
155            let y = p.performance;
156            _sum_x += x;
157            _sum_x2 += x * x;
158            _sum_x3 += x * x * x;
159            _sum_x4 += x * x * x * x;
160            sum_y += y;
161            _sum_xy += x * y;
162            _sum_x2y += x * x * y;
163        }
164
165        // Solve 3x3 system (simplified - use Cramer's rule)
166        // This is a simplified implementation
167        let _mean_y = sum_y / n;
168
169        // Fallback to linear for now if polynomial fails
170        if let Some(linear) = self.fit_linear() {
171            let a = 0.0; // No quadratic term
172            let b = linear.coefficients[0];
173            let c = linear.coefficients[1];
174
175            let model = FittedModel {
176                model_type: ModelType::Polynomial,
177                coefficients: vec![a, b, c],
178                r_squared: linear.r_squared,
179                rss: linear.rss,
180                sample_count: self.data_points.len(),
181            };
182
183            self.models.insert(ModelType::Polynomial, model.clone());
184            return Some(model);
185        }
186
187        None
188    }
189
190    /// Fit logarithmic model
191    pub fn fit_logarithmic(&mut self) -> Option<FittedModel> {
192        if !self.has_sufficient_data() {
193            return None;
194        }
195
196        // y = a * log(x) + b
197        let n = self.data_points.len() as f64;
198        let mut sum_lnx = 0.0;
199        let mut sum_y = 0.0;
200        let mut sum_lnx_y = 0.0;
201        let mut sum_lnx2 = 0.0;
202
203        for p in &self.data_points {
204            let x = p.size as f64;
205            if x <= 0.0 {
206                continue;
207            }
208            let lnx = x.ln();
209            let y = p.performance;
210            sum_lnx += lnx;
211            sum_y += y;
212            sum_lnx_y += lnx * y;
213            sum_lnx2 += lnx * lnx;
214        }
215
216        let denom = n * sum_lnx2 - sum_lnx * sum_lnx;
217        if denom.abs() < 1e-10 {
218            return None;
219        }
220
221        let a = (n * sum_lnx_y - sum_lnx * sum_y) / denom;
222        let b = (sum_y - a * sum_lnx) / n;
223
224        let mean_y = sum_y / n;
225        let (ss_res, ss_tot) = self.compute_ss(|x| a * x.ln() + b, mean_y);
226        let r_squared = if ss_tot > 0.0 {
227            1.0 - ss_res / ss_tot
228        } else {
229            1.0
230        };
231
232        let model = FittedModel {
233            model_type: ModelType::Logarithmic,
234            coefficients: vec![a, b],
235            r_squared,
236            rss: ss_res,
237            sample_count: self.data_points.len(),
238        };
239
240        self.models.insert(ModelType::Logarithmic, model.clone());
241        Some(model)
242    }
243
244    /// Compute SS_res and SS_tot
245    fn compute_ss<F: Fn(f64) -> f64>(&self, predict_fn: F, mean_y: f64) -> (f64, f64) {
246        let mut ss_res = 0.0;
247        let mut ss_tot = 0.0;
248
249        for p in &self.data_points {
250            let x = p.size as f64;
251            let y_pred = predict_fn(x);
252            ss_res += (p.performance - y_pred).powi(2);
253            ss_tot += (p.performance - mean_y).powi(2);
254        }
255
256        (ss_res, ss_tot)
257    }
258
259    /// Fit all models and select best
260    pub fn fit_all(&mut self) -> Option<ModelType> {
261        self.fit_linear();
262        self.fit_polynomial();
263        self.fit_logarithmic();
264
265        // Select best by R-squared
266        let best = self
267            .models
268            .iter()
269            .max_by(|a, b| {
270                a.1.r_squared
271                    .partial_cmp(&b.1.r_squared)
272                    .expect("values should be comparable")
273            })
274            .map(|(t, _)| *t);
275
276        self.best_model = best;
277        best
278    }
279
280    /// Get best model
281    pub fn best_model(&mut self) -> Option<&FittedModel> {
282        if self.best_model.is_none() {
283            self.fit_all();
284        }
285
286        self.best_model.and_then(|t| self.models.get(&t))
287    }
288
289    /// Predict at size
290    pub fn predict_at_size(&mut self, size: usize) -> Option<Prediction> {
291        let model = self.best_model()?.clone();
292        let predicted = model.predict(size);
293
294        let (min_size, max_size) = self.size_range()?;
295        let is_extrapolation = size < min_size || size > max_size;
296
297        // Compute confidence bounds based on R-squared and extrapolation
298        let base_uncertainty = 1.0 - model.r_squared;
299        let extrapolation_penalty = if is_extrapolation {
300            let distance = if size < min_size {
301                (min_size - size) as f64 / min_size as f64
302            } else {
303                (size - max_size) as f64 / max_size as f64
304            };
305            distance * 0.5 // 50% more uncertainty per distance ratio
306        } else {
307            0.0
308        };
309
310        let total_uncertainty = (base_uncertainty + extrapolation_penalty).min(1.0);
311        let z = 1.96; // 95% confidence
312
313        let half_width = predicted * total_uncertainty * z;
314        let lower_bound = (predicted - half_width).max(0.0);
315        let upper_bound = predicted + half_width;
316
317        Some(Prediction {
318            size,
319            predicted,
320            lower_bound,
321            upper_bound,
322            confidence_level: self.confidence_level,
323            model_type: model.model_type,
324            is_extrapolation,
325        })
326    }
327
328    /// Get model by type
329    pub fn get_model(&self, model_type: ModelType) -> Option<&FittedModel> {
330        self.models.get(&model_type)
331    }
332
333    /// Compare models
334    pub fn compare_models(&self) -> Vec<(&ModelType, f64)> {
335        let mut comparisons: Vec<_> = self.models.iter().map(|(t, m)| (t, m.r_squared)).collect();
336
337        comparisons.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("values should be comparable"));
338        comparisons
339    }
340
341    /// Export model (serialize to JSON-like format)
342    pub fn export_model(&self, model_type: ModelType) -> Option<String> {
343        let model = self.models.get(&model_type)?;
344        Some(format!(
345            "{{\"type\":\"{}\",\"coefficients\":{:?},\"r_squared\":{:.6}}}",
346            model.model_type.name(),
347            model.coefficients,
348            model.r_squared
349        ))
350    }
351
352    /// Clear all data
353    pub fn clear(&mut self) {
354        self.data_points.clear();
355        self.models.clear();
356        self.best_model = None;
357    }
358}