Skip to main content

cbtop/context_regression/
predictor.rs

1//! Context-aware regression predictor with adaptive thresholds.
2
3use std::collections::HashMap;
4
5use super::{
6    BaselineEntry, RegressionCheck, RegressionThreshold, SystemContext, Trend,
7    DEFAULT_COLD_START_MARGIN, DEFAULT_STALENESS_SEC, MIN_SAMPLES_FOR_CONTEXT,
8};
9
/// Context-aware regression predictor.
///
/// Keeps a bounded history of baseline measurements per metric, plus the
/// tuning values used to turn that history and the current system context
/// into a regression threshold (see `compute_threshold`).
#[derive(Debug)]
pub struct ContextRegressionPredictor {
    /// Historical baselines keyed by metric name; new entries are appended,
    /// and the oldest entries (front of the vector) are dropped first.
    baselines: HashMap<String, Vec<BaselineEntry>>,
    /// Maximum number of entries retained per metric.
    max_history: usize,
    /// Threshold (%) used while a metric has fewer than
    /// `MIN_SAMPLES_FOR_CONTEXT` baselines.
    cold_start_margin: f64,
    /// Lower bound (%) applied to the learned threshold.
    min_margin: f64,
    /// Context staleness threshold (seconds).
    staleness_sec: u64,
    /// Temperature variance factor (% per 10°C above/below the historical average).
    temp_factor: f64,
    /// Memory variance factor (% per 10 points of memory utilization above
    /// the historical average).
    memory_factor: f64,
    /// Frequency variance factor (% per 10% reduction in frequency utilization).
    freq_factor: f64,
    /// Flat penalty (%) added when the current context reports a cold cache.
    cache_cold_penalty: f64,
}
32
33impl Default for ContextRegressionPredictor {
34    fn default() -> Self {
35        Self {
36            baselines: HashMap::new(),
37            max_history: 100,
38            cold_start_margin: DEFAULT_COLD_START_MARGIN,
39            min_margin: 3.0,
40            staleness_sec: DEFAULT_STALENESS_SEC,
41            temp_factor: 2.0,
42            memory_factor: 1.0,
43            freq_factor: 5.0,
44            cache_cold_penalty: 10.0,
45        }
46    }
47}
48
49/// Compute the mean of values extracted from baseline entries via a field accessor.
50fn entries_mean(entries: &[BaselineEntry], field: impl Fn(&BaselineEntry) -> f64) -> f64 {
51    let sum: f64 = entries.iter().map(&field).sum();
52    sum / entries.len() as f64
53}
54
/// Turn the gap between a current reading and its historical average into a
/// threshold adjustment.
///
/// The gap (`current - historical_avg`) is multiplied by `scale` to express it
/// in adjustment units, then by `factor`. When `clamp_positive` is set, a
/// negative scaled gap is treated as zero before `factor` is applied, so the
/// adjustment can only widen the threshold.
fn context_adjustment(
    current: f64,
    historical_avg: f64,
    scale: f64,
    factor: f64,
    clamp_positive: bool,
) -> f64 {
    let scaled_gap = (current - historical_avg) * scale;
    let effective_gap = if clamp_positive {
        scaled_gap.max(0.0)
    } else {
        scaled_gap
    };
    effective_gap * factor
}
72
/// Result of a simple (one-variable, least-squares) linear regression.
struct LinearFit {
    /// Fitted slope: change in y per unit of x.
    slope: f64,
    /// Coefficient of determination of the fit; reported as 0.0 when the
    /// y values have no variance.
    r_squared: f64,
}
78
79/// Compute simple linear regression (slope and R-squared) for (x, y) pairs
80/// extracted from baseline entries. Returns `None` if the data is degenerate.
81fn linear_regression(
82    entries: &[BaselineEntry],
83    x_fn: impl Fn(&BaselineEntry) -> f64,
84    y_fn: impl Fn(&BaselineEntry) -> f64,
85) -> Option<LinearFit> {
86    let n = entries.len() as f64;
87    let mut sum_x = 0.0_f64;
88    let mut sum_y = 0.0_f64;
89    let mut sum_xy = 0.0_f64;
90    let mut sum_xx = 0.0_f64;
91
92    for entry in entries {
93        let x = x_fn(entry);
94        let y = y_fn(entry);
95        sum_x += x;
96        sum_y += y;
97        sum_xy += x * y;
98        sum_xx += x * x;
99    }
100
101    let denom = n * sum_xx - sum_x * sum_x;
102    if denom.abs() < 1e-10 {
103        return None;
104    }
105
106    let slope = (n * sum_xy - sum_x * sum_y) / denom;
107    let intercept = (sum_y - slope * sum_x) / n;
108
109    // Compute R-squared
110    let mean_y = sum_y / n;
111    let mut ss_res = 0.0;
112    let mut ss_tot = 0.0;
113    for entry in entries {
114        let x = x_fn(entry);
115        let y_pred = slope * x + intercept;
116        let y = y_fn(entry);
117        ss_res += (y - y_pred).powi(2);
118        ss_tot += (y - mean_y).powi(2);
119    }
120    let r_squared = if ss_tot > 0.0 {
121        1.0 - ss_res / ss_tot
122    } else {
123        0.0
124    };
125
126    Some(LinearFit { slope, r_squared })
127}
128
/// Coefficient of variation of `values`, expressed as a percentage
/// (sample standard deviation divided by the absolute mean, times 100).
///
/// Falls back to a fixed 5% when fewer than two values are supplied or when
/// the mean is too close to zero for the ratio to be meaningful.
fn coefficient_of_variation(values: &[f64]) -> f64 {
    const FALLBACK_CV_PERCENT: f64 = 5.0;

    let count = values.len();
    if count < 2 {
        return FALLBACK_CV_PERCENT;
    }

    let mean = values.iter().sum::<f64>() / count as f64;
    if mean.abs() <= 1e-10 {
        return FALLBACK_CV_PERCENT;
    }

    // Sample variance (n - 1 denominator); count >= 2 here, so the divisor is at least 1.
    let squared_deviations: f64 = values.iter().map(|v| (v - mean).powi(2)).sum();
    let sample_variance = squared_deviations / (count - 1) as f64;
    let std_dev = sample_variance.sqrt();

    (std_dev / mean.abs()) * 100.0
}
143
144impl ContextRegressionPredictor {
145    /// Create new predictor
146    pub fn new() -> Self {
147        Self::default()
148    }
149
150    /// Set cold start margin
151    pub fn with_cold_start_margin(mut self, margin: f64) -> Self {
152        self.cold_start_margin = margin.max(5.0);
153        self
154    }
155
156    /// Set minimum margin
157    pub fn with_min_margin(mut self, margin: f64) -> Self {
158        self.min_margin = margin.max(1.0);
159        self
160    }
161
162    /// Set temperature factor
163    pub fn with_temp_factor(mut self, factor: f64) -> Self {
164        self.temp_factor = factor.max(0.0);
165        self
166    }
167
168    /// Set staleness threshold
169    pub fn with_staleness(mut self, sec: u64) -> Self {
170        self.staleness_sec = sec;
171        self
172    }
173
174    /// Add baseline entry
175    pub fn add_baseline(&mut self, metric: &str, value: f64, context: SystemContext) {
176        let entry = BaselineEntry::new(metric, value, context);
177
178        self.baselines
179            .entry(metric.to_string())
180            .or_default()
181            .push(entry);
182
183        // Trim old entries
184        if let Some(entries) = self.baselines.get_mut(metric) {
185            while entries.len() > self.max_history {
186                entries.remove(0);
187            }
188        }
189    }
190
191    /// Get baseline count for metric
192    pub fn baseline_count(&self, metric: &str) -> usize {
193        self.baselines.get(metric).map(|e| e.len()).unwrap_or(0)
194    }
195
196    /// Check if sufficient history
197    pub fn has_sufficient_history(&self, metric: &str) -> bool {
198        self.baseline_count(metric) >= MIN_SAMPLES_FOR_CONTEXT
199    }
200
201    /// Compute context-aware threshold
202    pub fn compute_threshold(
203        &self,
204        metric: &str,
205        current_context: &SystemContext,
206    ) -> RegressionThreshold {
207        let sample_count = self.baseline_count(metric);
208
209        // Cold start: use conservative margin
210        if sample_count < MIN_SAMPLES_FOR_CONTEXT {
211            return RegressionThreshold {
212                base_percent: self.cold_start_margin,
213                temp_adjustment: 0.0,
214                memory_adjustment: 0.0,
215                freq_adjustment: 0.0,
216                cache_adjustment: 0.0,
217                final_percent: self.cold_start_margin,
218                confidence: 0.1,
219                sample_count,
220            };
221        }
222
223        let entries = self
224            .baselines
225            .get(metric)
226            .expect("metric should exist in baselines after sufficient history check");
227
228        // Compute base threshold from historical variance (coefficient of variation)
229        let values: Vec<f64> = entries.iter().map(|e| e.value).collect();
230        let cv = coefficient_of_variation(&values);
231
232        // Base threshold: 2*CV + min margin
233        let base_percent = (cv * 2.0).max(self.min_margin);
234
235        // Temperature adjustment: warmer = more variance expected
236        let avg_temp = entries_mean(entries, |e| e.context.cpu_temp_c);
237        let temp_adjustment = context_adjustment(
238            current_context.cpu_temp_c,
239            avg_temp,
240            1.0 / 10.0,
241            self.temp_factor,
242            false,
243        );
244
245        // Memory adjustment: higher pressure = more variance
246        let avg_mem = entries_mean(entries, |e| e.context.memory_percent);
247        let memory_adjustment = context_adjustment(
248            current_context.memory_percent,
249            avg_mem,
250            1.0 / 10.0,
251            self.memory_factor,
252            true,
253        );
254
255        // Frequency adjustment: lower frequency = expect slower
256        // Direction is reversed (historical - current), so we swap current/historical
257        let avg_freq_util = entries_mean(entries, |e| e.context.freq_utilization());
258        let freq_adjustment = context_adjustment(
259            avg_freq_util,
260            current_context.freq_utilization(),
261            10.0,
262            self.freq_factor,
263            true,
264        );
265
266        // Cache adjustment: cold cache = expect slower
267        let cache_adjustment = if !current_context.cache_warm {
268            self.cache_cold_penalty
269        } else {
270            0.0
271        };
272
273        // Final threshold
274        let final_percent = (base_percent
275            + temp_adjustment
276            + memory_adjustment
277            + freq_adjustment
278            + cache_adjustment)
279            .max(self.min_margin);
280
281        // Confidence increases with more samples
282        let confidence = (sample_count as f64 / 50.0).min(1.0);
283
284        RegressionThreshold {
285            base_percent,
286            temp_adjustment,
287            memory_adjustment,
288            freq_adjustment,
289            cache_adjustment,
290            final_percent,
291            confidence,
292            sample_count,
293        }
294    }
295
296    /// Detect trend in baselines
297    pub fn detect_trend(&self, metric: &str) -> Option<Trend> {
298        let entries = self.baselines.get(metric)?;
299        if entries.len() < MIN_SAMPLES_FOR_CONTEXT {
300            return None;
301        }
302
303        let base_time = entries.first()?.context.timestamp;
304        let fit = linear_regression(
305            entries,
306            |e| (e.context.timestamp - base_time) as f64 / 86400.0,
307            |e| e.value,
308        )?;
309
310        let direction = if fit.slope > 0.1 {
311            "increasing"
312        } else if fit.slope < -0.1 {
313            "decreasing"
314        } else {
315            "stable"
316        };
317
318        Some(Trend {
319            slope_per_day: fit.slope,
320            r_squared: fit.r_squared,
321            direction,
322        })
323    }
324
325    /// Check for regression
326    pub fn check_regression(
327        &self,
328        metric: &str,
329        current_value: f64,
330        context: &SystemContext,
331    ) -> RegressionCheck {
332        let threshold = self.compute_threshold(metric, context);
333
334        let entries = self.baselines.get(metric);
335        let baseline_mean = entries
336            .map(|e| entries_mean(e, |x| x.value))
337            .unwrap_or(current_value);
338
339        let percent_change = if baseline_mean.abs() > 1e-10 {
340            ((current_value - baseline_mean) / baseline_mean) * 100.0
341        } else {
342            0.0
343        };
344
345        let is_regression = threshold.is_regression(percent_change);
346        let trend = self.detect_trend(metric);
347
348        RegressionCheck {
349            metric: metric.to_string(),
350            current_value,
351            baseline_mean,
352            percent_change,
353            threshold,
354            is_regression,
355            trend,
356        }
357    }
358
359    /// Clear history for metric
360    pub fn clear(&mut self, metric: &str) {
361        self.baselines.remove(metric);
362    }
363
364    /// Clear all history
365    pub fn clear_all(&mut self) {
366        self.baselines.clear();
367    }
368
369    /// Export baselines to JSON
370    pub fn export_json(&self, metric: &str) -> Option<String> {
371        let entries = self.baselines.get(metric)?;
372        let entries_json: Vec<String> = entries
373            .iter()
374            .map(|e| {
375                format!(
376                    r#"{{"value":{},"context":{}}}"#,
377                    e.value,
378                    e.context.to_json()
379                )
380            })
381            .collect();
382        Some(format!(
383            r#"{{"metric":"{}","entries":[{}]}}"#,
384            metric,
385            entries_json.join(",")
386        ))
387    }
388}