// verificar/ml/trainer.rs
//! Bug prediction model training pipeline
//!
//! This module provides a complete ML training pipeline:
//! - Train/test split with stratification
//! - Cross-validation (k-fold)
//! - Metrics tracking (precision, recall, F1, AUC)
//! - Model serialization

use crate::data::CodeFeatures;
use crate::Result;
use rand::seq::SliceRandom;
use rand::SeedableRng;
use serde::{Deserialize, Serialize};

/// Metrics from model evaluation
///
/// Raw confusion-matrix counts plus the rates derived from them. Produced by
/// [`ModelMetrics::compute`]; per-fold values are combined with
/// [`ModelMetrics::average`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ModelMetrics {
    /// True positives
    pub true_positives: usize,
    /// True negatives
    pub true_negatives: usize,
    /// False positives
    pub false_positives: usize,
    /// False negatives
    pub false_negatives: usize,
    /// Precision = TP / (TP + FP); 0.0 when the denominator is zero
    pub precision: f64,
    /// Recall = TP / (TP + FN); 0.0 when the denominator is zero
    pub recall: f64,
    /// F1 score = 2 * (precision * recall) / (precision + recall); 0.0 when both are zero
    pub f1_score: f64,
    /// Accuracy = (TP + TN) / total; 0.0 for empty input
    pub accuracy: f64,
    /// Area under ROC curve, approximated as balanced accuracy (TPR + TNR) / 2
    pub auc: f64,
}
37
38impl ModelMetrics {
39    /// Compute metrics from predictions and ground truth
40    #[must_use]
41    pub fn compute(predictions: &[bool], ground_truth: &[bool]) -> Self {
42        let mut tp = 0;
43        let mut tn = 0;
44        let mut fp = 0;
45        let mut r#fn = 0;
46
47        for (pred, truth) in predictions.iter().zip(ground_truth.iter()) {
48            match (pred, truth) {
49                (true, true) => tp += 1,
50                (false, false) => tn += 1,
51                (true, false) => fp += 1,
52                (false, true) => r#fn += 1,
53            }
54        }
55
56        let precision = if tp + fp > 0 {
57            tp as f64 / (tp + fp) as f64
58        } else {
59            0.0
60        };
61
62        let recall = if tp + r#fn > 0 {
63            tp as f64 / (tp + r#fn) as f64
64        } else {
65            0.0
66        };
67
68        let f1_score = if precision + recall > 0.0 {
69            2.0 * (precision * recall) / (precision + recall)
70        } else {
71            0.0
72        };
73
74        let total = tp + tn + fp + r#fn;
75        let accuracy = if total > 0 {
76            (tp + tn) as f64 / total as f64
77        } else {
78            0.0
79        };
80
81        // Approximate AUC using balanced accuracy
82        let tpr = recall;
83        let tnr = if tn + fp > 0 {
84            tn as f64 / (tn + fp) as f64
85        } else {
86            0.0
87        };
88        let auc = (tpr + tnr) / 2.0;
89
90        Self {
91            true_positives: tp,
92            true_negatives: tn,
93            false_positives: fp,
94            false_negatives: r#fn,
95            precision,
96            recall,
97            f1_score,
98            accuracy,
99            auc,
100        }
101    }
102
103    /// Average metrics from multiple folds
104    #[must_use]
105    pub fn average(metrics: &[ModelMetrics]) -> Self {
106        if metrics.is_empty() {
107            return Self::default();
108        }
109
110        let n = metrics.len() as f64;
111        Self {
112            true_positives: metrics.iter().map(|m| m.true_positives).sum::<usize>() / metrics.len(),
113            true_negatives: metrics.iter().map(|m| m.true_negatives).sum::<usize>() / metrics.len(),
114            false_positives: metrics.iter().map(|m| m.false_positives).sum::<usize>()
115                / metrics.len(),
116            false_negatives: metrics.iter().map(|m| m.false_negatives).sum::<usize>()
117                / metrics.len(),
118            precision: metrics.iter().map(|m| m.precision).sum::<f64>() / n,
119            recall: metrics.iter().map(|m| m.recall).sum::<f64>() / n,
120            f1_score: metrics.iter().map(|m| m.f1_score).sum::<f64>() / n,
121            accuracy: metrics.iter().map(|m| m.accuracy).sum::<f64>() / n,
122            auc: metrics.iter().map(|m| m.auc).sum::<f64>() / n,
123        }
124    }
125}
126
/// Configuration for model training
///
/// Serialize/Deserialize are implemented manually at the bottom of this file.
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Train/test split ratio (e.g., 0.8 for 80% train)
    pub train_ratio: f64,
    /// Number of cross-validation folds
    pub cv_folds: usize,
    /// Random seed for reproducibility
    pub seed: u64,
    /// Number of trees in random forest (not yet consumed by the heuristic
    /// pipeline in this module)
    pub n_trees: usize,
    /// Maximum depth of trees (not yet consumed by the heuristic pipeline)
    pub max_depth: usize,
}
141
142impl Default for TrainingConfig {
143    fn default() -> Self {
144        Self {
145            train_ratio: 0.8,
146            cv_folds: 5,
147            seed: 42,
148            n_trees: 100,
149            max_depth: 10,
150        }
151    }
152}
153
/// Results from model training
///
/// Bundles everything produced by [`ModelTrainer::train`]: resubstitution
/// metrics, held-out test metrics, and the per-fold CV metrics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingResult {
    /// Metrics on training set (train evaluated against itself)
    pub train_metrics: ModelMetrics,
    /// Metrics on test set
    pub test_metrics: ModelMetrics,
    /// Cross-validation metrics (one per fold)
    pub cv_metrics: Vec<ModelMetrics>,
    /// Average cross-validation metrics
    pub cv_average: ModelMetrics,
    /// Number of training samples
    pub train_samples: usize,
    /// Number of test samples
    pub test_samples: usize,
}
170
/// Model trainer for bug prediction
///
/// Builder-style wrapper around a [`TrainingConfig`]; see
/// [`ModelTrainer::train`] for the end-to-end pipeline.
#[derive(Debug)]
pub struct ModelTrainer {
    // Configuration used by split/CV/train; set via builder methods or with_config.
    config: TrainingConfig,
}
176
177impl ModelTrainer {
178    /// Create a new trainer with default configuration
179    #[must_use]
180    pub fn new() -> Self {
181        Self {
182            config: TrainingConfig::default(),
183        }
184    }
185
186    /// Create trainer with custom configuration
187    #[must_use]
188    pub fn with_config(config: TrainingConfig) -> Self {
189        Self { config }
190    }
191
192    /// Set train/test split ratio
193    #[must_use]
194    pub fn train_ratio(mut self, ratio: f64) -> Self {
195        self.config.train_ratio = ratio.clamp(0.1, 0.99);
196        self
197    }
198
199    /// Set number of cross-validation folds
200    #[must_use]
201    pub fn cv_folds(mut self, folds: usize) -> Self {
202        self.config.cv_folds = folds.max(2);
203        self
204    }
205
206    /// Set random seed
207    #[must_use]
208    pub fn seed(mut self, seed: u64) -> Self {
209        self.config.seed = seed;
210        self
211    }
212
213    /// Split data into train and test sets with stratification
214    pub fn train_test_split(
215        &self,
216        features: &[CodeFeatures],
217        labels: &[bool],
218    ) -> (Vec<CodeFeatures>, Vec<bool>, Vec<CodeFeatures>, Vec<bool>) {
219        let mut rng = rand::rngs::StdRng::seed_from_u64(self.config.seed);
220
221        // Separate positive and negative samples for stratification
222        let positives: Vec<usize> = labels
223            .iter()
224            .enumerate()
225            .filter(|(_, &l)| l)
226            .map(|(i, _)| i)
227            .collect();
228        let negatives: Vec<usize> = labels
229            .iter()
230            .enumerate()
231            .filter(|(_, &l)| !l)
232            .map(|(i, _)| i)
233            .collect();
234
235        // Shuffle each class
236        let mut pos_shuffled = positives.clone();
237        let mut neg_shuffled = negatives.clone();
238        pos_shuffled.shuffle(&mut rng);
239        neg_shuffled.shuffle(&mut rng);
240
241        // Split each class by ratio (values are always non-negative)
242        #[allow(clippy::cast_sign_loss)]
243        let pos_split = (pos_shuffled.len() as f64 * self.config.train_ratio) as usize;
244        #[allow(clippy::cast_sign_loss)]
245        let neg_split = (neg_shuffled.len() as f64 * self.config.train_ratio) as usize;
246
247        let train_indices: Vec<usize> = pos_shuffled[..pos_split]
248            .iter()
249            .chain(neg_shuffled[..neg_split].iter())
250            .copied()
251            .collect();
252
253        let test_indices: Vec<usize> = pos_shuffled[pos_split..]
254            .iter()
255            .chain(neg_shuffled[neg_split..].iter())
256            .copied()
257            .collect();
258
259        let train_features: Vec<CodeFeatures> =
260            train_indices.iter().map(|&i| features[i].clone()).collect();
261        let train_labels: Vec<bool> = train_indices.iter().map(|&i| labels[i]).collect();
262        let test_features: Vec<CodeFeatures> =
263            test_indices.iter().map(|&i| features[i].clone()).collect();
264        let test_labels: Vec<bool> = test_indices.iter().map(|&i| labels[i]).collect();
265
266        (train_features, train_labels, test_features, test_labels)
267    }
268
269    /// Perform k-fold cross-validation
270    ///
271    /// # Errors
272    ///
273    /// Returns error if evaluation fails on any fold.
274    pub fn cross_validate(
275        &self,
276        features: &[CodeFeatures],
277        labels: &[bool],
278    ) -> Result<Vec<ModelMetrics>> {
279        let mut rng = rand::rngs::StdRng::seed_from_u64(self.config.seed);
280        let n = features.len();
281        let fold_size = n / self.config.cv_folds;
282
283        // Shuffle indices
284        let mut indices: Vec<usize> = (0..n).collect();
285        indices.shuffle(&mut rng);
286
287        let mut metrics = Vec::with_capacity(self.config.cv_folds);
288
289        for fold in 0..self.config.cv_folds {
290            let start = fold * fold_size;
291            let end = if fold == self.config.cv_folds - 1 {
292                n
293            } else {
294                start + fold_size
295            };
296
297            // Test fold
298            let test_indices: Vec<usize> = indices[start..end].to_vec();
299
300            // Train on remaining folds
301            let train_indices: Vec<usize> = indices[..start]
302                .iter()
303                .chain(indices[end..].iter())
304                .copied()
305                .collect();
306
307            let train_features: Vec<CodeFeatures> =
308                train_indices.iter().map(|&i| features[i].clone()).collect();
309            let train_labels: Vec<bool> = train_indices.iter().map(|&i| labels[i]).collect();
310            let test_features: Vec<CodeFeatures> =
311                test_indices.iter().map(|&i| features[i].clone()).collect();
312            let test_labels: Vec<bool> = test_indices.iter().map(|&i| labels[i]).collect();
313
314            // Train and evaluate on this fold
315            let fold_metrics = self.train_and_evaluate(
316                &train_features,
317                &train_labels,
318                &test_features,
319                &test_labels,
320            )?;
321            metrics.push(fold_metrics);
322        }
323
324        Ok(metrics)
325    }
326
327    /// Train model and evaluate on test set
328    fn train_and_evaluate(
329        &self,
330        _train_features: &[CodeFeatures],
331        _train_labels: &[bool],
332        test_features: &[CodeFeatures],
333        test_labels: &[bool],
334    ) -> Result<ModelMetrics> {
335        // Use heuristic predictor for now (aprender training requires 'ml' feature)
336        // This demonstrates the training pipeline structure
337        let predictor = super::BugPredictor::new();
338
339        let predictions: Vec<bool> = test_features
340            .iter()
341            .map(|f| predictor.predict(f) > 0.5)
342            .collect();
343
344        Ok(ModelMetrics::compute(&predictions, test_labels))
345    }
346
347    /// Full training pipeline: split, train, cross-validate, evaluate
348    ///
349    /// # Errors
350    ///
351    /// Returns error if training or evaluation fails.
352    pub fn train(&self, features: &[CodeFeatures], labels: &[bool]) -> Result<TrainingResult> {
353        // Train/test split
354        let (train_features, train_labels, test_features, test_labels) =
355            self.train_test_split(features, labels);
356
357        // Cross-validation on training set
358        let cv_metrics = self.cross_validate(&train_features, &train_labels)?;
359        let cv_average = ModelMetrics::average(&cv_metrics);
360
361        // Final evaluation on test set
362        let train_metrics = self.train_and_evaluate(
363            &train_features,
364            &train_labels,
365            &train_features,
366            &train_labels,
367        )?;
368        let test_metrics =
369            self.train_and_evaluate(&train_features, &train_labels, &test_features, &test_labels)?;
370
371        Ok(TrainingResult {
372            train_metrics,
373            test_metrics,
374            cv_metrics,
375            cv_average,
376            train_samples: train_features.len(),
377            test_samples: test_features.len(),
378        })
379    }
380}
381
382impl Default for ModelTrainer {
383    fn default() -> Self {
384        Self::new()
385    }
386}
387
/// Serializable model state for persistence
///
/// Persisted as pretty-printed JSON via [`SerializedModel::save`] and
/// restored via [`SerializedModel::load`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedModel {
    /// Model version
    pub version: String,
    /// Training configuration
    pub config: TrainingConfig,
    /// Training metrics
    pub metrics: ModelMetrics,
    /// Feature weights (for simple models)
    pub weights: Vec<f64>,
}
400
401impl SerializedModel {
402    /// Save model to JSON file
403    ///
404    /// # Errors
405    ///
406    /// Returns error if file writing fails
407    pub fn save(&self, path: &str) -> Result<()> {
408        let json = serde_json::to_string_pretty(self)
409            .map_err(|e| crate::Error::Data(format!("Serialization failed: {e}")))?;
410        std::fs::write(path, json)
411            .map_err(|e| crate::Error::Data(format!("Failed to write file: {e}")))?;
412        Ok(())
413    }
414
415    /// Load model from JSON file
416    ///
417    /// # Errors
418    ///
419    /// Returns error if file reading or parsing fails
420    pub fn load(path: &str) -> Result<Self> {
421        let json = std::fs::read_to_string(path)
422            .map_err(|e| crate::Error::Data(format!("Failed to read file: {e}")))?;
423        let model: Self = serde_json::from_str(&json)
424            .map_err(|e| crate::Error::Data(format!("Deserialization failed: {e}")))?;
425        Ok(model)
426    }
427}
428
// Implement Serialize/Deserialize for TrainingConfig
//
// Written by hand rather than derived; the field names emitted here must stay
// in sync with the `Helper` struct in the Deserialize impl below.
impl Serialize for TrainingConfig {
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct;
        // 5 = number of fields declared on TrainingConfig.
        let mut state = serializer.serialize_struct("TrainingConfig", 5)?;
        state.serialize_field("train_ratio", &self.train_ratio)?;
        state.serialize_field("cv_folds", &self.cv_folds)?;
        state.serialize_field("seed", &self.seed)?;
        state.serialize_field("n_trees", &self.n_trees)?;
        state.serialize_field("max_depth", &self.max_depth)?;
        state.end()
    }
}
445
impl<'de> Deserialize<'de> for TrainingConfig {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // Mirror struct: lets a derived Deserialize handle field parsing;
        // field names must match those written by the Serialize impl above.
        #[derive(Deserialize)]
        struct Helper {
            train_ratio: f64,
            cv_folds: usize,
            seed: u64,
            n_trees: usize,
            max_depth: usize,
        }

        let helper = Helper::deserialize(deserializer)?;
        Ok(Self {
            train_ratio: helper.train_ratio,
            cv_folds: helper.cv_folds,
            seed: helper.seed,
            n_trees: helper.n_trees,
            max_depth: helper.max_depth,
        })
    }
}
470
#[cfg(test)]
mod tests {
    // Unit tests for metrics computation, the trainer builder, splitting,
    // cross-validation, the full pipeline, and model (de)serialization.
    use super::*;

    // Deterministic synthetic dataset: 100 samples with cyclic feature
    // values and a 25% positive rate (every 4th label true).
    fn sample_data() -> (Vec<CodeFeatures>, Vec<bool>) {
        let features: Vec<CodeFeatures> = (0..100)
            .map(|i| CodeFeatures {
                ast_depth: (i % 10) as u32,
                num_operators: (i % 20) as u32,
                num_control_flow: (i % 5) as u32,
                cyclomatic_complexity: (i % 15) as f32,
                uses_edge_values: i % 3 == 0,
                ..Default::default()
            })
            .collect();
        let labels: Vec<bool> = (0..100).map(|i| i % 4 == 0).collect();
        (features, labels)
    }

    #[test]
    fn test_model_metrics_compute() {
        let predictions = vec![true, true, false, false, true];
        let ground_truth = vec![true, false, false, true, true];

        let metrics = ModelMetrics::compute(&predictions, &ground_truth);

        // Hand-checked confusion matrix for the five pairs above.
        assert_eq!(metrics.true_positives, 2);
        assert_eq!(metrics.true_negatives, 1);
        assert_eq!(metrics.false_positives, 1);
        assert_eq!(metrics.false_negatives, 1);
        assert!((metrics.precision - 0.666).abs() < 0.01);
        assert!((metrics.recall - 0.666).abs() < 0.01);
    }

    #[test]
    fn test_model_metrics_perfect() {
        // Predictions identical to ground truth => all rates exactly 1.0.
        let predictions = vec![true, false, true, false];
        let ground_truth = vec![true, false, true, false];

        let metrics = ModelMetrics::compute(&predictions, &ground_truth);

        assert!((metrics.precision - 1.0).abs() < f64::EPSILON);
        assert!((metrics.recall - 1.0).abs() < f64::EPSILON);
        assert!((metrics.f1_score - 1.0).abs() < f64::EPSILON);
        assert!((metrics.accuracy - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_model_metrics_average() {
        let metrics = vec![
            ModelMetrics {
                precision: 0.8,
                recall: 0.7,
                f1_score: 0.75,
                accuracy: 0.85,
                auc: 0.9,
                ..Default::default()
            },
            ModelMetrics {
                precision: 0.6,
                recall: 0.9,
                f1_score: 0.72,
                accuracy: 0.75,
                auc: 0.8,
                ..Default::default()
            },
        ];

        let avg = ModelMetrics::average(&metrics);

        assert!((avg.precision - 0.7).abs() < f64::EPSILON);
        assert!((avg.recall - 0.8).abs() < f64::EPSILON);
    }

    #[test]
    fn test_training_config_default() {
        let config = TrainingConfig::default();
        assert!((config.train_ratio - 0.8).abs() < f64::EPSILON);
        assert_eq!(config.cv_folds, 5);
        assert_eq!(config.seed, 42);
    }

    #[test]
    fn test_trainer_new() {
        let trainer = ModelTrainer::new();
        assert!((trainer.config.train_ratio - 0.8).abs() < f64::EPSILON);
    }

    #[test]
    fn test_trainer_builder() {
        let trainer = ModelTrainer::new().train_ratio(0.7).cv_folds(10).seed(123);

        assert!((trainer.config.train_ratio - 0.7).abs() < f64::EPSILON);
        assert_eq!(trainer.config.cv_folds, 10);
        assert_eq!(trainer.config.seed, 123);
    }

    #[test]
    fn test_train_test_split() {
        let (features, labels) = sample_data();
        let trainer = ModelTrainer::new();

        let (train_f, train_l, test_f, test_l) = trainer.train_test_split(&features, &labels);

        // Check split ratio approximately (stratified rounding per class
        // means the exact count can drift a little from 80%).
        let total = features.len();
        let train_expected = (total as f64 * 0.8) as usize;
        assert!(train_f.len() >= train_expected - 5 && train_f.len() <= train_expected + 5);
        assert_eq!(train_f.len(), train_l.len());
        assert_eq!(test_f.len(), test_l.len());
    }

    #[test]
    fn test_cross_validate() {
        let (features, labels) = sample_data();
        let trainer = ModelTrainer::new().cv_folds(5);

        let cv_metrics = trainer.cross_validate(&features, &labels).unwrap();

        // One metrics struct per fold, each with a sane accuracy.
        assert_eq!(cv_metrics.len(), 5);
        for m in &cv_metrics {
            assert!((0.0..=1.0).contains(&m.accuracy));
        }
    }

    #[test]
    fn test_train_full_pipeline() {
        let (features, labels) = sample_data();
        let trainer = ModelTrainer::new();

        let result = trainer.train(&features, &labels).unwrap();

        assert!(result.train_samples > 0);
        assert!(result.test_samples > 0);
        assert_eq!(result.cv_metrics.len(), 5);
        assert!((0.0..=1.0).contains(&result.test_metrics.accuracy));
    }

    #[test]
    fn test_serialized_model() {
        // JSON round-trip through serde (no filesystem involved).
        let model = SerializedModel {
            version: "0.1.0".to_string(),
            config: TrainingConfig::default(),
            metrics: ModelMetrics::default(),
            weights: vec![0.1, 0.2, 0.3],
        };

        let json = serde_json::to_string(&model).unwrap();
        let loaded: SerializedModel = serde_json::from_str(&json).unwrap();

        assert_eq!(loaded.version, "0.1.0");
        assert_eq!(loaded.weights.len(), 3);
    }

    #[test]
    fn test_training_result_serialize() {
        let result = TrainingResult {
            train_metrics: ModelMetrics::default(),
            test_metrics: ModelMetrics::default(),
            cv_metrics: vec![ModelMetrics::default()],
            cv_average: ModelMetrics::default(),
            train_samples: 80,
            test_samples: 20,
        };

        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("train_samples"));
    }

    #[test]
    fn test_model_metrics_empty() {
        // Empty inputs: all counts zero, rates fall back to 0.0 (no NaN).
        let metrics = ModelMetrics::compute(&[], &[]);
        assert_eq!(metrics.true_positives, 0);
        assert!((metrics.accuracy - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_model_metrics_all_negative() {
        let predictions = vec![false, false, false];
        let ground_truth = vec![false, false, false];

        let metrics = ModelMetrics::compute(&predictions, &ground_truth);

        assert_eq!(metrics.true_negatives, 3);
        assert!((metrics.accuracy - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_trainer_ratio_clamp() {
        let trainer = ModelTrainer::new().train_ratio(0.05);
        assert!((trainer.config.train_ratio - 0.1).abs() < f64::EPSILON);

        let trainer = ModelTrainer::new().train_ratio(1.5);
        assert!((trainer.config.train_ratio - 0.99).abs() < f64::EPSILON);
    }

    #[test]
    fn test_trainer_cv_folds_min() {
        let trainer = ModelTrainer::new().cv_folds(1);
        assert_eq!(trainer.config.cv_folds, 2);
    }

    #[test]
    fn test_model_metrics_auc_calculation() {
        // Test that AUC is computed correctly when tn + fp > 0
        // Mutant: replace > with == would break this
        //
        // Create data to produce: tp=10, tn=20, fp=5, fn=3
        let mut predictions = Vec::new();
        let mut ground_truth = Vec::new();

        // 10 true positives (pred=true, truth=true)
        for _ in 0..10 {
            predictions.push(true);
            ground_truth.push(true);
        }
        // 20 true negatives (pred=false, truth=false)
        for _ in 0..20 {
            predictions.push(false);
            ground_truth.push(false);
        }
        // 5 false positives (pred=true, truth=false)
        for _ in 0..5 {
            predictions.push(true);
            ground_truth.push(false);
        }
        // 3 false negatives (pred=false, truth=true)
        for _ in 0..3 {
            predictions.push(false);
            ground_truth.push(true);
        }

        let metrics = ModelMetrics::compute(&predictions, &ground_truth);

        // tnr = tn / (tn + fp) = 20 / (20 + 5) = 0.8
        // tpr = recall = tp / (tp + fn) = 10 / (10 + 3) = 0.769...
        // auc = (tpr + tnr) / 2 = (0.769 + 0.8) / 2 = 0.784...
        assert!(
            metrics.auc > 0.7,
            "AUC should be > 0.7, got {}",
            metrics.auc
        );
        assert!(
            metrics.auc < 0.85,
            "AUC should be < 0.85, got {}",
            metrics.auc
        );

        // If mutant changed > to ==, tnr would be 0.0 when tn+fp > 0
        // Then auc would be ~0.38, which would fail the > 0.7 check
    }

    #[test]
    fn test_model_metrics_average_fp_fn() {
        // Test that false_positives and false_negatives are averaged (divided) correctly
        // Mutant: replace / with * would break this
        let metrics = vec![
            ModelMetrics {
                false_positives: 10,
                false_negatives: 20,
                ..Default::default()
            },
            ModelMetrics {
                false_positives: 30,
                false_negatives: 40,
                ..Default::default()
            },
        ];

        let avg = ModelMetrics::average(&metrics);

        // Average of [10, 30] = 20, not 80 (if * instead of /)
        assert_eq!(avg.false_positives, 20);
        // Average of [20, 40] = 30, not 120 (if * instead of /)
        assert_eq!(avg.false_negatives, 30);
    }

    #[test]
    fn test_trainer_with_config() {
        // Test that with_config actually uses the provided config
        // Mutant: replace with Default::default() would break this
        let config = TrainingConfig {
            train_ratio: 0.6,
            cv_folds: 10,
            seed: 12345,
            n_trees: 50,
            max_depth: 5,
        };
        let trainer = ModelTrainer::with_config(config);

        // These values are different from defaults, so mutant would fail
        assert!((trainer.config.train_ratio - 0.6).abs() < f64::EPSILON);
        assert_eq!(trainer.config.cv_folds, 10);
        assert_eq!(trainer.config.seed, 12345);
    }
}