// cbtop/adaptive_ml/types.rs

//! Core types for adaptive ML thresholds.

/// Convenience alias: `Result` specialized to [`MlThresholdError`].
pub type MlThresholdResult<T> = Result<T, MlThresholdError>;
5
/// Errors that ML threshold operations can report.
#[derive(Debug, Clone, PartialEq)]
pub enum MlThresholdError {
    /// Not enough training samples: `have` collected, `need` required.
    InsufficientData { have: usize, need: usize },
    /// The named workload is not recognized.
    UnknownWorkload { name: String },
    /// The model has not been trained yet.
    ModelNotTrained,
    /// Feature extraction failed for the given reason.
    FeatureExtractionFailed { reason: String },
    /// Prediction confidence fell below the required threshold.
    LowConfidence { confidence: f64, threshold: f64 },
    /// Drift detected in a metric; re-calibration needed.
    DriftDetected { metric: String, drift_score: f64 },
}

impl std::fmt::Display for MlThresholdError {
    /// Render a short, human-readable description of the error.
    ///
    /// Message text is part of the observable contract and is kept verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InsufficientData { have, need } => write!(f, "Insufficient data: have {}, need {}", have, need),
            Self::UnknownWorkload { name } => write!(f, "Unknown workload: {}", name),
            Self::ModelNotTrained => write!(f, "Model not trained"),
            Self::FeatureExtractionFailed { reason } => write!(f, "Feature extraction failed: {}", reason),
            Self::LowConfidence { confidence, threshold } => write!(f, "Low confidence {} < {}", confidence, threshold),
            Self::DriftDetected { metric, drift_score } => write!(f, "Drift detected in {}: score {:.2}", metric, drift_score),
        }
    }
}

impl std::error::Error for MlThresholdError {}
51
/// Workload type for classification
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WorkloadClass {
    /// FFN/MLP operations
    Ffn,
    /// Matrix multiplication
    Matmul,
    /// Attention operations
    Attention,
    /// Quantization/dequantization
    Quantize,
    /// Memory-bound operations
    MemoryBound,
    /// Compute-bound operations
    ComputeBound,
    /// Unknown workload
    Unknown,
}

impl WorkloadClass {
    /// Default CV (coefficient of variation, percent) threshold for this
    /// workload class — how much variance is expected before samples are
    /// considered anomalous.
    pub fn default_cv_threshold(&self) -> f64 {
        match self {
            Self::Ffn => 18.0,         // FFN naturally has higher variance
            Self::Matmul => 10.0,      // Matmul is very consistent
            Self::Attention => 15.0,   // Attention has moderate variance
            Self::Quantize => 12.0,    // Quantize is fairly consistent
            Self::MemoryBound => 20.0, // Memory-bound is highly variable
            Self::ComputeBound => 8.0, // Compute-bound is very consistent
            Self::Unknown => 15.0,     // Conservative default
        }
    }

    /// Stable display name; exactly the strings `from_name` accepts.
    pub fn name(&self) -> &'static str {
        match self {
            Self::Ffn => "FFN",
            Self::Matmul => "Matmul",
            Self::Attention => "Attention",
            Self::Quantize => "Quantize",
            Self::MemoryBound => "MemoryBound",
            Self::ComputeBound => "ComputeBound",
            Self::Unknown => "Unknown",
        }
    }

    /// Parse a name string produced by [`WorkloadClass::name`] (used by
    /// `import_state`); returns `None` for unrecognized names.
    ///
    /// Fix: `"Unknown"` is now accepted so every `name()` output round-trips
    /// (previously a state exported with an `Unknown` class failed to import).
    /// Visibility widened from `pub(super)` to `pub(crate)` (backward
    /// compatible) so crate-level import/export code can call it directly.
    pub(crate) fn from_name(name: &str) -> Option<Self> {
        match name {
            "FFN" => Some(Self::Ffn),
            "Matmul" => Some(Self::Matmul),
            "Attention" => Some(Self::Attention),
            "Quantize" => Some(Self::Quantize),
            "MemoryBound" => Some(Self::MemoryBound),
            "ComputeBound" => Some(Self::ComputeBound),
            "Unknown" => Some(Self::Unknown),
            _ => None,
        }
    }
}
111
/// Statistical features summarizing a time series of samples.
#[derive(Debug, Clone)]
pub struct TimeSeriesFeatures {
    /// Arithmetic mean of the samples.
    pub mean: f64,
    /// Population standard deviation (divides by n, not n - 1).
    pub std_dev: f64,
    /// Coefficient of variation, in percent (std_dev / mean * 100).
    pub cv: f64,
    /// Standardized third moment (skewness).
    pub skewness: f64,
    /// Excess kurtosis (standardized fourth moment minus 3).
    pub kurtosis: f64,
    /// Autocorrelation at lag 1.
    pub autocorr_lag1: f64,
    /// Slope of an ordinary least-squares linear fit over sample index.
    pub trend_slope: f64,
    /// Number of samples the features were computed from.
    pub sample_count: usize,
}

impl TimeSeriesFeatures {
    /// Compute features over `values`.
    ///
    /// Returns `None` when fewer than 10 samples are available, since the
    /// higher moments are not meaningful on tiny windows.
    pub fn extract(values: &[f64]) -> Option<Self> {
        let count = values.len();
        if count < 10 {
            return None;
        }
        let n = count as f64;

        let mean = values.iter().sum::<f64>() / n;
        // Population variance (divide by n).
        let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n;
        let std_dev = variance.sqrt();

        // CV in percent, guarded against a near-zero mean.
        let cv = if mean.abs() > 1e-10 {
            (std_dev / mean) * 100.0
        } else {
            0.0
        };

        // Skewness and excess kurtosis from standardized samples, one pass.
        // Degenerate (constant) series report 0 for both.
        let (skewness, kurtosis) = if std_dev > 1e-10 {
            let (m3, m4) = values.iter().fold((0.0, 0.0), |(m3, m4), v| {
                let z = (v - mean) / std_dev;
                (m3 + z.powi(3), m4 + z.powi(4))
            });
            (m3 / n, m4 / n - 3.0)
        } else {
            (0.0, 0.0)
        };

        // Lag-1 autocorrelation from adjacent pairs.
        let autocorr_lag1 = if count > 1 && std_dev > 1e-10 {
            let cross: f64 = values
                .windows(2)
                .map(|pair| (pair[0] - mean) * (pair[1] - mean))
                .sum();
            cross / ((count - 1) as f64 * variance)
        } else {
            0.0
        };

        // OLS slope of value against sample index (x = 0, 1, 2, ...).
        let x_mean = (n - 1.0) / 2.0;
        let (num, den) = values
            .iter()
            .enumerate()
            .fold((0.0, 0.0), |(num, den), (i, &y)| {
                let dx = i as f64 - x_mean;
                (num + dx * (y - mean), den + dx * dx)
            });
        let trend_slope = if den > 1e-10 { num / den } else { 0.0 };

        Some(Self {
            mean,
            std_dev,
            cv,
            skewness,
            kurtosis,
            autocorr_lag1,
            trend_slope,
            sample_count: count,
        })
    }

    /// Flatten the shape-describing features into a model input vector:
    /// (cv, skewness, kurtosis, autocorr_lag1, trend_slope).
    pub fn to_vec(&self) -> Vec<f64> {
        Vec::from([
            self.cv,
            self.skewness,
            self.kurtosis,
            self.autocorr_lag1,
            self.trend_slope,
        ])
    }
}
227
/// Outcome of evaluating one observation against an adaptive threshold.
#[derive(Debug, Clone)]
pub struct AnomalyResult {
    /// Whether the observation was classified as anomalous.
    pub is_anomaly: bool,
    /// Anomaly score; larger values indicate stronger anomalies.
    pub score: f64,
    /// Threshold the score was compared against.
    pub threshold: f64,
    /// Confidence in the result (presumably in [0, 1] — not enforced here).
    pub confidence: f64,
    /// Workload class the observation was attributed to.
    pub workload_class: WorkloadClass,
    /// Human-readable reason for the classification.
    pub reason: String,
}
244
/// Confusion-matrix counts with derived precision/recall-style metrics.
#[derive(Debug, Clone, Default)]
pub struct ClassificationMetrics {
    /// True positives
    pub true_positives: usize,
    /// False positives
    pub false_positives: usize,
    /// True negatives
    pub true_negatives: usize,
    /// False negatives
    pub false_negatives: usize,
}

impl ClassificationMetrics {
    /// `num / denom` as f64, defined as 0.0 when `denom` is zero.
    fn safe_ratio(num: usize, denom: usize) -> f64 {
        if denom == 0 {
            0.0
        } else {
            num as f64 / denom as f64
        }
    }

    /// Precision: TP / (TP + FP); 0.0 when no positive predictions exist.
    pub fn precision(&self) -> f64 {
        Self::safe_ratio(self.true_positives, self.true_positives + self.false_positives)
    }

    /// Recall: TP / (TP + FN); 0.0 when there are no actual positives.
    pub fn recall(&self) -> f64 {
        Self::safe_ratio(self.true_positives, self.true_positives + self.false_negatives)
    }

    /// F1 score: harmonic mean of precision and recall; 0.0 when both are 0.
    pub fn f1(&self) -> f64 {
        let (p, r) = (self.precision(), self.recall());
        if p + r == 0.0 {
            0.0
        } else {
            2.0 * p * r / (p + r)
        }
    }

    /// False positive rate: FP / (FP + TN); 0.0 when there are no actual
    /// negatives.
    pub fn false_positive_rate(&self) -> f64 {
        Self::safe_ratio(self.false_positives, self.false_positives + self.true_negatives)
    }
}
299}