organizational_intelligence_plugin/
ml.rs

1//! ML Model Integration for Defect Prediction
2//!
3//! Implements PHASE2-005: aprender ML model integration
//! Provides a defect classifier (Random Forest planned; currently a k-NN approximation) and K-means for clustering
5//!
6//! References:
7//! - Breiman (2001): Random Forests
8//! - MacQueen (1967): K-means Clustering
9
10use crate::features::CommitFeatures;
11use anyhow::Result;
12
/// Defect prediction model using Random Forest
///
/// Predicts defect category (0-9) from commit features. The Random Forest
/// hyperparameters are stored and reported back, while prediction currently
/// votes among stored training samples.
#[derive(Debug, Clone)]
pub struct DefectPredictor {
    /// Number of trees configured for the forest.
    n_trees: usize,
    /// Maximum tree depth configured for the forest.
    max_depth: usize,
    /// Set once training has succeeded.
    trained: bool,
    /// Training samples as `(feature vector, defect category)` pairs, kept
    /// for simple prediction (Phase 1) until the full aprender integration
    /// (Phase 2).
    training_data: Vec<(Vec<f32>, u8)>,
}
24
25impl DefectPredictor {
26    /// Create new predictor with default parameters
27    pub fn new() -> Self {
28        Self {
29            n_trees: 100,
30            max_depth: 10,
31            trained: false,
32            training_data: Vec::new(),
33        }
34    }
35
36    /// Create predictor with custom parameters
37    pub fn with_params(n_trees: usize, max_depth: usize) -> Self {
38        Self {
39            n_trees,
40            max_depth,
41            trained: false,
42            training_data: Vec::new(),
43        }
44    }
45
46    /// Train model on labeled features
47    ///
48    /// # Arguments
49    /// * `features` - Training features with defect_category labels
50    pub fn train(&mut self, features: &[CommitFeatures]) -> Result<()> {
51        if features.is_empty() {
52            anyhow::bail!("Cannot train on empty dataset");
53        }
54
55        // Store training data for k-NN based prediction
56        self.training_data = features
57            .iter()
58            .map(|f| (f.to_vector(), f.defect_category))
59            .collect();
60
61        self.trained = true;
62        Ok(())
63    }
64
65    /// Predict defect category for new features
66    ///
67    /// Uses k-NN approximation (k=5) for Phase 1
68    /// Full Random Forest in Phase 2 with aprender
69    pub fn predict(&self, features: &CommitFeatures) -> Result<u8> {
70        if !self.trained {
71            anyhow::bail!("Model not trained");
72        }
73
74        let query = features.to_vector();
75
76        // k-NN prediction (k=5)
77        let k = 5.min(self.training_data.len());
78        let mut distances: Vec<(f32, u8)> = self
79            .training_data
80            .iter()
81            .map(|(v, label)| (Self::euclidean_distance(&query, v), *label))
82            .collect();
83
84        distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
85
86        // Vote among k nearest neighbors
87        let mut votes = [0u32; 10];
88        for (_, label) in distances.iter().take(k) {
89            let idx = (*label as usize).min(9);
90            votes[idx] += 1;
91        }
92
93        // Return most common category
94        let prediction = votes
95            .iter()
96            .enumerate()
97            .max_by_key(|(_, &count)| count)
98            .map(|(idx, _)| idx as u8)
99            .unwrap_or(0);
100
101        Ok(prediction)
102    }
103
104    /// Predict probabilities for all categories
105    pub fn predict_proba(&self, features: &CommitFeatures) -> Result<Vec<f32>> {
106        if !self.trained {
107            anyhow::bail!("Model not trained");
108        }
109
110        let query = features.to_vector();
111        let k = 10.min(self.training_data.len());
112
113        let mut distances: Vec<(f32, u8)> = self
114            .training_data
115            .iter()
116            .map(|(v, label)| (Self::euclidean_distance(&query, v), *label))
117            .collect();
118
119        distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
120
121        // Compute probability as fraction of k neighbors
122        let mut counts = [0u32; 10];
123        for (_, label) in distances.iter().take(k) {
124            let idx = (*label as usize).min(9);
125            counts[idx] += 1;
126        }
127
128        let probs: Vec<f32> = counts.iter().map(|&c| c as f32 / k as f32).collect();
129
130        Ok(probs)
131    }
132
133    /// Get model parameters
134    pub fn params(&self) -> (usize, usize) {
135        (self.n_trees, self.max_depth)
136    }
137
138    /// Check if model is trained
139    pub fn is_trained(&self) -> bool {
140        self.trained
141    }
142
143    fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
144        a.iter()
145            .zip(b.iter())
146            .map(|(x, y)| (x - y).powi(2))
147            .sum::<f32>()
148            .sqrt()
149    }
150}
151
152impl Default for DefectPredictor {
153    fn default() -> Self {
154        Self::new()
155    }
156}
157
/// Pattern clustering using K-means
///
/// Groups similar commits into clusters for pattern discovery
#[derive(Debug, Clone)]
pub struct PatternClusterer {
    /// Number of clusters.
    k: usize,
    /// Upper bound on the number of K-means refinement rounds.
    max_iterations: usize,
    /// One centroid per cluster; empty until the clusterer has been fitted.
    centroids: Vec<Vec<f32>>,
    /// Set once fitting has succeeded.
    trained: bool,
}
167
168impl PatternClusterer {
169    /// Create clusterer with default k=5 clusters
170    pub fn new() -> Self {
171        Self {
172            k: 5,
173            max_iterations: 100,
174            centroids: Vec::new(),
175            trained: false,
176        }
177    }
178
179    /// Create clusterer with custom k
180    pub fn with_k(k: usize) -> Self {
181        Self {
182            k,
183            max_iterations: 100,
184            centroids: Vec::new(),
185            trained: false,
186        }
187    }
188
189    /// Fit clusters to data
190    pub fn fit(&mut self, features: &[CommitFeatures]) -> Result<()> {
191        if features.is_empty() {
192            anyhow::bail!("Cannot cluster empty dataset");
193        }
194
195        if features.len() < self.k {
196            anyhow::bail!("Need at least {} samples for {} clusters", self.k, self.k);
197        }
198
199        let vectors: Vec<Vec<f32>> = features.iter().map(|f| f.to_vector()).collect();
200        let n_dims = CommitFeatures::DIMENSION;
201
202        // Initialize centroids (first k points)
203        self.centroids = vectors.iter().take(self.k).cloned().collect();
204
205        // K-means iteration
206        for _ in 0..self.max_iterations {
207            // Assign points to clusters
208            let assignments: Vec<usize> =
209                vectors.iter().map(|v| self.nearest_centroid(v)).collect();
210
211            // Update centroids
212            let mut new_centroids = vec![vec![0.0; n_dims]; self.k];
213            let mut counts = vec![0usize; self.k];
214
215            for (vec, &cluster) in vectors.iter().zip(assignments.iter()) {
216                for (dim, &val) in vec.iter().enumerate() {
217                    new_centroids[cluster][dim] += val;
218                }
219                counts[cluster] += 1;
220            }
221
222            // Normalize centroids
223            for (centroid, &count) in new_centroids.iter_mut().zip(counts.iter()) {
224                if count > 0 {
225                    for val in centroid.iter_mut() {
226                        *val /= count as f32;
227                    }
228                }
229            }
230
231            // Check convergence
232            let converged = self
233                .centroids
234                .iter()
235                .zip(new_centroids.iter())
236                .all(|(old, new)| Self::euclidean_distance(old, new) < 1e-6);
237
238            self.centroids = new_centroids;
239
240            if converged {
241                break;
242            }
243        }
244
245        self.trained = true;
246        Ok(())
247    }
248
249    /// Predict cluster for new features
250    pub fn predict(&self, features: &CommitFeatures) -> Result<usize> {
251        if !self.trained {
252            anyhow::bail!("Clusterer not fitted");
253        }
254
255        let vec = features.to_vector();
256        Ok(self.nearest_centroid(&vec))
257    }
258
259    /// Predict clusters for multiple features
260    pub fn predict_batch(&self, features: &[CommitFeatures]) -> Result<Vec<usize>> {
261        if !self.trained {
262            anyhow::bail!("Clusterer not fitted");
263        }
264
265        Ok(features
266            .iter()
267            .map(|f| self.nearest_centroid(&f.to_vector()))
268            .collect())
269    }
270
271    /// Get cluster centroids
272    pub fn centroids(&self) -> &[Vec<f32>] {
273        &self.centroids
274    }
275
276    /// Compute inertia (sum of squared distances to centroids)
277    pub fn inertia(&self, features: &[CommitFeatures]) -> Result<f32> {
278        if !self.trained {
279            anyhow::bail!("Clusterer not fitted");
280        }
281
282        let total: f32 = features
283            .iter()
284            .map(|f| {
285                let vec = f.to_vector();
286                let cluster = self.nearest_centroid(&vec);
287                Self::euclidean_distance(&vec, &self.centroids[cluster]).powi(2)
288            })
289            .sum();
290
291        Ok(total)
292    }
293
294    fn nearest_centroid(&self, point: &[f32]) -> usize {
295        self.centroids
296            .iter()
297            .enumerate()
298            .map(|(i, c)| (i, Self::euclidean_distance(point, c)))
299            .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
300            .map(|(i, _)| i)
301            .unwrap_or(0)
302    }
303
304    fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
305        a.iter()
306            .zip(b.iter())
307            .map(|(x, y)| (x - y).powi(2))
308            .sum::<f32>()
309            .sqrt()
310    }
311}
312
313impl Default for PatternClusterer {
314    fn default() -> Self {
315        Self::new()
316    }
317}
318
/// Model evaluation metrics
///
/// A unit struct with no fields; it carries no state of its own.
pub struct ModelMetrics;
321
322impl ModelMetrics {
323    /// Compute accuracy
324    pub fn accuracy(predictions: &[u8], labels: &[u8]) -> f32 {
325        if predictions.len() != labels.len() || predictions.is_empty() {
326            return 0.0;
327        }
328
329        let correct = predictions
330            .iter()
331            .zip(labels.iter())
332            .filter(|(p, l)| p == l)
333            .count();
334
335        correct as f32 / predictions.len() as f32
336    }
337
338    /// Compute per-class precision
339    pub fn precision(predictions: &[u8], labels: &[u8], class: u8) -> f32 {
340        let true_positives = predictions
341            .iter()
342            .zip(labels.iter())
343            .filter(|(&p, &l)| p == class && l == class)
344            .count() as f32;
345
346        let predicted_positives = predictions.iter().filter(|&&p| p == class).count() as f32;
347
348        if predicted_positives > 0.0 {
349            true_positives / predicted_positives
350        } else {
351            0.0
352        }
353    }
354
355    /// Compute per-class recall
356    pub fn recall(predictions: &[u8], labels: &[u8], class: u8) -> f32 {
357        let true_positives = predictions
358            .iter()
359            .zip(labels.iter())
360            .filter(|(&p, &l)| p == class && l == class)
361            .count() as f32;
362
363        let actual_positives = labels.iter().filter(|&&l| l == class).count() as f32;
364
365        if actual_positives > 0.0 {
366            true_positives / actual_positives
367        } else {
368            0.0
369        }
370    }
371
372    /// Compute F1 score
373    pub fn f1_score(predictions: &[u8], labels: &[u8], class: u8) -> f32 {
374        let p = Self::precision(predictions, labels, class);
375        let r = Self::recall(predictions, labels, class);
376
377        if p + r > 0.0 {
378            2.0 * p * r / (p + r)
379        } else {
380            0.0
381        }
382    }
383
384    /// Compute silhouette score for clustering
385    pub fn silhouette_score(features: &[CommitFeatures], assignments: &[usize], k: usize) -> f32 {
386        if features.len() != assignments.len() || features.is_empty() {
387            return 0.0;
388        }
389
390        let vectors: Vec<Vec<f32>> = features.iter().map(|f| f.to_vector()).collect();
391
392        let mut total_score = 0.0;
393        let n = vectors.len();
394
395        for i in 0..n {
396            let cluster_i = assignments[i];
397
398            // a(i) = mean distance to same cluster
399            let same_cluster: Vec<_> = (0..n)
400                .filter(|&j| j != i && assignments[j] == cluster_i)
401                .collect();
402
403            let a = if same_cluster.is_empty() {
404                0.0
405            } else {
406                same_cluster
407                    .iter()
408                    .map(|&j| Self::euclidean_distance(&vectors[i], &vectors[j]))
409                    .sum::<f32>()
410                    / same_cluster.len() as f32
411            };
412
413            // b(i) = min mean distance to other clusters
414            let mut b = f32::INFINITY;
415            for c in 0..k {
416                if c == cluster_i {
417                    continue;
418                }
419                let other_cluster: Vec<_> = (0..n).filter(|&j| assignments[j] == c).collect();
420
421                if !other_cluster.is_empty() {
422                    let mean_dist = other_cluster
423                        .iter()
424                        .map(|&j| Self::euclidean_distance(&vectors[i], &vectors[j]))
425                        .sum::<f32>()
426                        / other_cluster.len() as f32;
427                    b = b.min(mean_dist);
428                }
429            }
430
431            if b.is_infinite() {
432                b = 0.0;
433            }
434
435            // s(i) = (b - a) / max(a, b)
436            let s = if a.max(b) > 0.0 {
437                (b - a) / a.max(b)
438            } else {
439                0.0
440            };
441
442            total_score += s;
443        }
444
445        total_score / n as f32
446    }
447
448    fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
449        a.iter()
450            .zip(b.iter())
451            .map(|(x, y)| (x - y).powi(2))
452            .sum::<f32>()
453            .sqrt()
454    }
455}
456
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a commit whose size-related features all scale with `files`.
    fn make_feature(category: u8, files: u32) -> CommitFeatures {
        let scale = files as f32;
        CommitFeatures {
            defect_category: category,
            files_changed: scale,
            lines_added: (files * 10) as f32,
            lines_deleted: (files * 5) as f32,
            complexity_delta: scale * 0.1,
            timestamp: 1700000000.0 + files as f64,
            hour_of_day: 10,
            day_of_week: 1,
            ..Default::default()
        }
    }

    /// Build one commit per `(category, files)` pair.
    fn features_from(specs: &[(u8, u32)]) -> Vec<CommitFeatures> {
        specs
            .iter()
            .map(|&(category, files)| make_feature(category, files))
            .collect()
    }

    /// Two small category-0 commits followed by two large category-1 commits.
    fn two_group_features() -> Vec<CommitFeatures> {
        features_from(&[(0, 1), (0, 2), (1, 100), (1, 101)])
    }

    /// Three small category-0 commits followed by three large category-1 commits.
    fn three_per_group_features() -> Vec<CommitFeatures> {
        features_from(&[(0, 1), (0, 2), (0, 3), (1, 100), (1, 101), (1, 102)])
    }

    /// Assert that `outcome` is an error whose message contains `fragment`.
    fn assert_error_contains<T: std::fmt::Debug>(outcome: anyhow::Result<T>, fragment: &str) {
        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().to_string().contains(fragment));
    }

    // ---- DefectPredictor ----

    #[test]
    fn test_predictor_creation() {
        let fresh = DefectPredictor::new();
        assert_eq!(fresh.params(), (100, 10));
        assert!(!fresh.is_trained());
    }

    #[test]
    fn test_predictor_train() {
        let samples = features_from(&[(0, 1), (0, 2), (1, 10), (1, 11)]);
        let mut model = DefectPredictor::new();

        model.train(&samples).unwrap();
        assert!(model.is_trained());
    }

    #[test]
    fn test_predictor_predict() {
        let samples = three_per_group_features();
        let mut model = DefectPredictor::new();
        model.train(&samples).unwrap();

        // A probe sitting among the category 0 samples
        let near_small = model.predict(&make_feature(0, 2)).unwrap();
        assert_eq!(near_small, 0);

        // A probe sitting among the category 1 samples
        let near_large = model.predict(&make_feature(1, 101)).unwrap();
        assert_eq!(near_large, 1);
    }

    #[test]
    fn test_predictor_proba() {
        let samples = features_from(&[(0, 1), (0, 2), (1, 100)]);
        let mut model = DefectPredictor::new();
        model.train(&samples).unwrap();

        let distribution = model.predict_proba(&make_feature(0, 1)).unwrap();
        assert_eq!(distribution.len(), 10);
        // Category 0 should have higher probability
        assert!(distribution[0] > distribution[1]);
    }

    #[test]
    fn test_predictor_with_params() {
        let custom = DefectPredictor::with_params(50, 5);
        assert_eq!(custom.params(), (50, 5));
        assert!(!custom.is_trained());
    }

    #[test]
    fn test_predictor_train_empty_error() {
        let mut model = DefectPredictor::new();
        assert_error_contains(model.train(&[]), "Cannot train on empty dataset");
    }

    #[test]
    fn test_predictor_predict_not_trained_error() {
        let model = DefectPredictor::new();
        assert_error_contains(model.predict(&make_feature(0, 1)), "Model not trained");
    }

    #[test]
    fn test_predictor_proba_not_trained_error() {
        let model = DefectPredictor::new();
        assert_error_contains(
            model.predict_proba(&make_feature(0, 1)),
            "Model not trained",
        );
    }

    #[test]
    fn test_predictor_default() {
        let model = DefectPredictor::default();
        assert_eq!(model.params(), (100, 10));
        assert!(!model.is_trained());
    }

    #[test]
    fn test_predictor_with_small_training_set() {
        let samples = features_from(&[(0, 1), (1, 10)]);
        let mut model = DefectPredictor::new();
        model.train(&samples).unwrap();

        // k is clamped to min(5, training_data.len()) = 2
        let category = model.predict(&make_feature(0, 2)).unwrap();
        assert!(category < 10);
    }

    #[test]
    fn test_predictor_proba_sums_to_one() {
        let samples = features_from(&[(0, 1), (1, 10), (2, 20)]);
        let mut model = DefectPredictor::new();
        model.train(&samples).unwrap();

        let distribution = model.predict_proba(&make_feature(0, 1)).unwrap();
        let total: f32 = distribution.iter().sum();
        assert!((total - 1.0).abs() < 0.01);
    }

    // ---- PatternClusterer ----

    #[test]
    fn test_clusterer_creation() {
        let fresh = PatternClusterer::new();
        assert_eq!(fresh.k, 5);
    }

    #[test]
    fn test_clusterer_fit() {
        let samples = three_per_group_features();
        let mut model = PatternClusterer::with_k(2);

        model.fit(&samples).unwrap();
        assert!(model.trained);
        assert_eq!(model.centroids().len(), 2);
    }

    #[test]
    fn test_clusterer_predict() {
        let samples = two_group_features();
        let mut model = PatternClusterer::with_k(2);
        model.fit(&samples).unwrap();

        let clusters = model.predict_batch(&samples).unwrap();

        // Similar features should be in same cluster
        assert_eq!(clusters[0], clusters[1]);
        assert_eq!(clusters[2], clusters[3]);
        // Different features should be in different clusters
        assert_ne!(clusters[0], clusters[2]);
    }

    #[test]
    fn test_clusterer_inertia() {
        let samples = two_group_features();
        let mut model = PatternClusterer::with_k(2);
        model.fit(&samples).unwrap();

        let spread = model.inertia(&samples).unwrap();
        assert!(spread >= 0.0);
    }

    #[test]
    fn test_clusterer_fit_empty_error() {
        let mut model = PatternClusterer::new();
        assert_error_contains(model.fit(&[]), "Cannot cluster empty dataset");
    }

    #[test]
    fn test_clusterer_fit_too_few_samples_error() {
        let samples = features_from(&[(0, 1), (0, 2)]);
        let mut model = PatternClusterer::with_k(5);
        assert_error_contains(model.fit(&samples), "Need at least 5 samples");
    }

    #[test]
    fn test_clusterer_predict_not_fitted_error() {
        let model = PatternClusterer::new();
        assert_error_contains(
            model.predict(&make_feature(0, 1)),
            "Clusterer not fitted",
        );
    }

    #[test]
    fn test_clusterer_predict_batch_not_fitted_error() {
        let model = PatternClusterer::new();
        assert_error_contains(
            model.predict_batch(&[make_feature(0, 1)]),
            "Clusterer not fitted",
        );
    }

    #[test]
    fn test_clusterer_inertia_not_fitted_error() {
        let model = PatternClusterer::new();
        assert_error_contains(
            model.inertia(&[make_feature(0, 1)]),
            "Clusterer not fitted",
        );
    }

    #[test]
    fn test_clusterer_default() {
        let model = PatternClusterer::default();
        assert_eq!(model.k, 5);
        assert!(!model.trained);
    }

    #[test]
    fn test_clusterer_predict_single_item() {
        let samples = two_group_features();
        let mut model = PatternClusterer::with_k(2);
        model.fit(&samples).unwrap();

        let cluster = model.predict(&make_feature(0, 1)).unwrap();
        assert!(cluster < 2);
    }

    // ---- ModelMetrics ----

    #[test]
    fn test_metrics_accuracy() {
        let predicted = [0u8, 0, 1, 1];
        let actual = [0u8, 0, 1, 0];

        let score = ModelMetrics::accuracy(&predicted, &actual);
        assert!((score - 0.75).abs() < 0.01);
    }

    #[test]
    fn test_metrics_precision_recall() {
        let predicted = [1u8, 1, 0, 0];
        let actual = [1u8, 0, 0, 0];

        // 1 TP, 1 FP
        let precision = ModelMetrics::precision(&predicted, &actual, 1);
        assert!((precision - 0.5).abs() < 0.01);

        // 1 TP, 0 FN
        let recall = ModelMetrics::recall(&predicted, &actual, 1);
        assert!((recall - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_metrics_f1() {
        let predicted = [1u8, 1, 0, 0];
        let actual = [1u8, 0, 0, 0];

        // F1 = 2 * 0.5 * 1.0 / (0.5 + 1.0) = 0.667
        let score = ModelMetrics::f1_score(&predicted, &actual, 1);
        assert!((score - 0.667).abs() < 0.01);
    }

    #[test]
    fn test_accuracy_empty_arrays() {
        assert_eq!(ModelMetrics::accuracy(&[], &[]), 0.0);
    }

    #[test]
    fn test_accuracy_mismatched_lengths() {
        let predicted = [0u8, 0, 1];
        let actual = [0u8, 0];
        assert_eq!(ModelMetrics::accuracy(&predicted, &actual), 0.0);
    }

    #[test]
    fn test_precision_no_predicted_positives() {
        let predicted = [0u8, 0, 0, 0];
        let actual = [1u8, 1, 0, 0];
        assert_eq!(ModelMetrics::precision(&predicted, &actual, 1), 0.0);
    }

    #[test]
    fn test_recall_no_actual_positives() {
        let predicted = [1u8, 1, 0, 0];
        let actual = [0u8, 0, 0, 0];
        assert_eq!(ModelMetrics::recall(&predicted, &actual, 1), 0.0);
    }

    #[test]
    fn test_f1_zero_precision_and_recall() {
        let predicted = [0u8, 0, 0, 0];
        let actual = [1u8, 1, 0, 0];
        assert_eq!(ModelMetrics::f1_score(&predicted, &actual, 1), 0.0);
    }

    #[test]
    fn test_silhouette_score() {
        let samples = two_group_features();
        let clusters = [0usize, 0, 1, 1];

        // Well-separated clusters should have positive silhouette
        let score = ModelMetrics::silhouette_score(&samples, &clusters, 2);
        assert!(score > 0.0);
    }

    #[test]
    fn test_silhouette_empty_features() {
        assert_eq!(ModelMetrics::silhouette_score(&[], &[], 2), 0.0);
    }

    #[test]
    fn test_silhouette_mismatched_lengths() {
        let samples = features_from(&[(0, 1), (0, 2)]);
        let clusters = [0usize];
        assert_eq!(
            ModelMetrics::silhouette_score(&samples, &clusters, 2),
            0.0
        );
    }
}