organizational_intelligence_plugin/
ml_trainer.rs

1//! ML model training module for defect classification.
2//!
3//! This module implements Phase 2 ML classifier training:
4//! - Load training data from JSON
5//! - Extract TF-IDF features from commit messages
6//! - Train RandomForestClassifier
7//! - Evaluate on validation and test sets
8//! - Save trained model to disk
9//!
10//! Implements Section 3 ML Classification from nlp-models-techniques-spec.md
11
12use crate::classifier::DefectCategory;
13use crate::nlp::TfidfFeatureExtractor;
14use crate::training::{TrainingDataset, TrainingExample};
15use anyhow::{anyhow, Result};
16use aprender::primitives::Matrix;
17use aprender::tree::RandomForestClassifier;
18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20use std::fs;
21use std::path::Path;
22
23/// Trained ML classifier model with metadata
24#[derive(Serialize, Deserialize)]
25pub struct TrainedModel {
26    /// Random Forest classifier
27    #[serde(skip)]
28    pub classifier: Option<RandomForestClassifier>,
29    /// TF-IDF feature extractor
30    #[serde(skip)]
31    pub tfidf_extractor: Option<TfidfFeatureExtractor>,
32    /// Mapping from category to label index
33    pub category_to_label: HashMap<String, usize>,
34    /// Mapping from label index to category
35    pub label_to_category: HashMap<usize, String>,
36    /// Training metadata
37    pub metadata: TrainingMetadata,
38    /// TF-IDF vocabulary (for reconstruction)
39    pub tfidf_vocabulary: Vec<String>,
40    /// Max features for TF-IDF
41    pub max_features: usize,
42}
43
44/// Metadata about model training
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct TrainingMetadata {
47    /// Number of training examples
48    pub n_train: usize,
49    /// Number of validation examples
50    pub n_validation: usize,
51    /// Number of test examples
52    pub n_test: usize,
53    /// Number of trees in Random Forest
54    pub n_estimators: usize,
55    /// Maximum tree depth
56    pub max_depth: Option<usize>,
57    /// Number of TF-IDF features
58    pub n_features: usize,
59    /// Number of classes
60    pub n_classes: usize,
61    /// Training accuracy
62    pub train_accuracy: f32,
63    /// Validation accuracy
64    pub validation_accuracy: f32,
65    /// Test accuracy (optional - set after evaluation)
66    pub test_accuracy: Option<f32>,
67}
68
/// Configures and runs training of the Random Forest defect classifier.
///
/// Holds only plain hyperparameters, so it is cheap to clone and safe to
/// print with `{:?}`.
#[derive(Debug, Clone)]
pub struct MLTrainer {
    /// Number of trees in the Random Forest.
    n_estimators: usize,
    /// Maximum depth of each tree; `None` leaves the depth unset.
    max_depth: Option<usize>,
    /// Maximum number of TF-IDF features to extract.
    max_features: usize,
    /// Seed handed to the Random Forest for reproducible training.
    random_state: u64,
}
76
77impl MLTrainer {
78    /// Create a new ML trainer
79    ///
80    /// # Arguments
81    ///
82    /// * `n_estimators` - Number of trees in Random Forest
83    /// * `max_depth` - Maximum depth of each tree (None for unlimited)
84    /// * `max_features` - Maximum TF-IDF features
85    ///
86    /// # Examples
87    ///
88    /// ```rust
89    /// use organizational_intelligence_plugin::ml_trainer::MLTrainer;
90    ///
91    /// let trainer = MLTrainer::new(100, Some(20), 1500);
92    /// ```
93    pub fn new(n_estimators: usize, max_depth: Option<usize>, max_features: usize) -> Self {
94        Self {
95            n_estimators,
96            max_depth,
97            max_features,
98            random_state: 42,
99        }
100    }
101
102    /// Load training dataset from JSON file
103    ///
104    /// # Arguments
105    ///
106    /// * `path` - Path to training data JSON file
107    ///
108    /// # Returns
109    ///
110    /// * `Ok(TrainingDataset)` if successful
111    /// * `Err` if file not found or invalid format
112    pub fn load_dataset<P: AsRef<Path>>(path: P) -> Result<TrainingDataset> {
113        let content = fs::read_to_string(path.as_ref())
114            .map_err(|e| anyhow!("Failed to read training data: {}", e))?;
115
116        serde_json::from_str(&content)
117            .map_err(|e| anyhow!("Failed to parse training data JSON: {}", e))
118    }
119
120    /// Train ML classifier on training dataset
121    ///
122    /// # Arguments
123    ///
124    /// * `dataset` - Training dataset with splits
125    ///
126    /// # Returns
127    ///
128    /// * `Ok(TrainedModel)` with trained classifier
129    /// * `Err` if training fails
130    ///
131    /// # Examples
132    ///
133    /// ```no_run
134    /// use organizational_intelligence_plugin::ml_trainer::MLTrainer;
135    /// use std::path::PathBuf;
136    ///
137    /// # async fn example() -> Result<(), anyhow::Error> {
138    /// let trainer = MLTrainer::new(100, Some(20), 1500);
139    /// let dataset = MLTrainer::load_dataset("training-data.json")?;
140    /// let model = trainer.train(&dataset)?;
141    /// # Ok(())
142    /// # }
143    /// ```
144    pub fn train(&self, dataset: &TrainingDataset) -> Result<TrainedModel> {
145        if dataset.train.is_empty() {
146            return Err(anyhow!("Training dataset is empty"));
147        }
148
149        // Extract messages and labels
150        let train_messages: Vec<String> =
151            dataset.train.iter().map(|ex| ex.message.clone()).collect();
152
153        let validation_messages: Vec<String> = dataset
154            .validation
155            .iter()
156            .map(|ex| ex.message.clone())
157            .collect();
158
159        // Build category-to-label mapping
160        let mut unique_categories: Vec<String> = dataset
161            .train
162            .iter()
163            .map(|ex| format!("{}", ex.label))
164            .collect();
165        unique_categories.sort();
166        unique_categories.dedup();
167
168        let category_to_label: HashMap<String, usize> = unique_categories
169            .iter()
170            .enumerate()
171            .map(|(i, cat)| (cat.clone(), i))
172            .collect();
173
174        let label_to_category: HashMap<usize, String> = unique_categories
175            .iter()
176            .enumerate()
177            .map(|(i, cat)| (i, cat.clone()))
178            .collect();
179
180        // Convert labels to indices
181        let train_labels: Vec<usize> = dataset
182            .train
183            .iter()
184            .map(|ex| {
185                *category_to_label
186                    .get(&format!("{}", ex.label))
187                    .unwrap_or(&0)
188            })
189            .collect();
190
191        let validation_labels: Vec<usize> = dataset
192            .validation
193            .iter()
194            .map(|ex| {
195                *category_to_label
196                    .get(&format!("{}", ex.label))
197                    .unwrap_or(&0)
198            })
199            .collect();
200
201        // Extract TF-IDF features
202        let mut tfidf_extractor = TfidfFeatureExtractor::new(self.max_features);
203        let train_features = tfidf_extractor.fit_transform(&train_messages)?;
204        let validation_features = tfidf_extractor.transform(&validation_messages)?;
205
206        // Convert Matrix<f64> to Matrix<f32> for RandomForestClassifier
207        let train_features_f32 = Self::convert_f64_to_f32(&train_features)?;
208        let validation_features_f32 = Self::convert_f64_to_f32(&validation_features)?;
209
210        // Train Random Forest
211        let mut classifier = RandomForestClassifier::new(self.n_estimators);
212        if let Some(depth) = self.max_depth {
213            classifier = classifier.with_max_depth(depth);
214        }
215        classifier = classifier.with_random_state(self.random_state);
216
217        classifier
218            .fit(&train_features_f32, &train_labels)
219            .map_err(|e| anyhow!("Random Forest training failed: {}", e))?;
220
221        // Evaluate on training set
222        let train_predictions = classifier.predict(&train_features_f32);
223        let train_accuracy = Self::calculate_accuracy(&train_predictions, &train_labels);
224
225        // Evaluate on validation set
226        let validation_predictions = classifier.predict(&validation_features_f32);
227        let validation_accuracy =
228            Self::calculate_accuracy(&validation_predictions, &validation_labels);
229
230        let metadata = TrainingMetadata {
231            n_train: dataset.train.len(),
232            n_validation: dataset.validation.len(),
233            n_test: dataset.test.len(),
234            n_estimators: self.n_estimators,
235            max_depth: self.max_depth,
236            n_features: tfidf_extractor.vocabulary_size(),
237            n_classes: unique_categories.len(),
238            train_accuracy,
239            validation_accuracy,
240            test_accuracy: None,
241        };
242
243        // Extract vocabulary for serialization (simplified - just store metadata)
244        // NOTE: Vocabulary extraction deferred - TfidfVectorizer stores features internally
245        let tfidf_vocabulary: Vec<String> = vec![];
246        let max_features = self.max_features;
247
248        Ok(TrainedModel {
249            classifier: Some(classifier),
250            tfidf_extractor: Some(tfidf_extractor),
251            category_to_label,
252            label_to_category,
253            metadata,
254            tfidf_vocabulary,
255            max_features,
256        })
257    }
258
259    /// Convert Matrix<f64> to Matrix<f32>
260    fn convert_f64_to_f32(matrix: &Matrix<f64>) -> Result<Matrix<f32>> {
261        let (n_rows, n_cols) = (matrix.n_rows(), matrix.n_cols());
262        let data_f32: Vec<f32> = (0..n_rows * n_cols)
263            .map(|i| {
264                let row = i / n_cols;
265                let col = i % n_cols;
266                matrix.get(row, col) as f32
267            })
268            .collect();
269
270        Matrix::from_vec(n_rows, n_cols, data_f32)
271            .map_err(|e| anyhow!("Failed to convert matrix: {}", e))
272    }
273
274    /// Calculate classification accuracy
275    fn calculate_accuracy(predictions: &[usize], labels: &[usize]) -> f32 {
276        if predictions.is_empty() || predictions.len() != labels.len() {
277            return 0.0;
278        }
279
280        let correct = predictions
281            .iter()
282            .zip(labels.iter())
283            .filter(|(pred, label)| pred == label)
284            .count();
285
286        correct as f32 / predictions.len() as f32
287    }
288
289    /// Evaluate model on test set
290    ///
291    /// # Arguments
292    ///
293    /// * `model` - Trained model
294    /// * `test_examples` - Test examples
295    ///
296    /// # Returns
297    ///
298    /// * Test accuracy (0.0-1.0)
299    pub fn evaluate(model: &TrainedModel, test_examples: &[TrainingExample]) -> Result<f32> {
300        if test_examples.is_empty() {
301            return Ok(0.0);
302        }
303
304        let classifier = model
305            .classifier
306            .as_ref()
307            .ok_or_else(|| anyhow!("Model has no classifier"))?;
308
309        let tfidf_extractor = model
310            .tfidf_extractor
311            .as_ref()
312            .ok_or_else(|| anyhow!("Model has no TF-IDF extractor"))?;
313
314        let test_messages: Vec<String> =
315            test_examples.iter().map(|ex| ex.message.clone()).collect();
316
317        let test_labels: Vec<usize> = test_examples
318            .iter()
319            .map(|ex| {
320                *model
321                    .category_to_label
322                    .get(&format!("{}", ex.label))
323                    .unwrap_or(&0)
324            })
325            .collect();
326
327        // Extract features and convert to f32
328        let test_features = tfidf_extractor.transform(&test_messages)?;
329        let test_features_f32 = Self::convert_f64_to_f32(&test_features)?;
330
331        // Predict and calculate accuracy
332        let test_predictions = classifier.predict(&test_features_f32);
333        let test_accuracy = Self::calculate_accuracy(&test_predictions, &test_labels);
334
335        Ok(test_accuracy)
336    }
337
338    /// Save trained model to disk
339    ///
340    /// # Arguments
341    ///
342    /// * `model` - Trained model
343    /// * `path` - Path to save model JSON
344    ///
345    /// # Returns
346    ///
347    /// * `Ok(())` if successful
348    /// * `Err` if save fails
349    pub fn save_model<P: AsRef<Path>>(model: &TrainedModel, path: P) -> Result<()> {
350        let json = serde_json::to_string_pretty(model)
351            .map_err(|e| anyhow!("Failed to serialize model: {}", e))?;
352
353        fs::write(path.as_ref(), json).map_err(|e| anyhow!("Failed to write model file: {}", e))?;
354
355        Ok(())
356    }
357
358    /// Load trained model from disk
359    ///
360    /// # Arguments
361    ///
362    /// * `path` - Path to model JSON file
363    ///
364    /// # Returns
365    ///
366    /// * `Ok(TrainedModel)` if successful
367    /// * `Err` if load fails
368    pub fn load_model<P: AsRef<Path>>(path: P) -> Result<TrainedModel> {
369        let content = fs::read_to_string(path.as_ref())
370            .map_err(|e| anyhow!("Failed to read model file: {}", e))?;
371
372        serde_json::from_str(&content).map_err(|e| anyhow!("Failed to parse model JSON: {}", e))
373    }
374}
375
376impl Default for MLTrainer {
377    fn default() -> Self {
378        Self::new(100, Some(20), 1500)
379    }
380}
381
382impl TrainedModel {
383    /// Predict defect category for a single commit message
384    ///
385    /// # Arguments
386    /// * `message` - Commit message to classify
387    ///
388    /// # Returns
389    /// * `Ok(Some((DefectCategory, f32)))` - Predicted category and confidence
390    /// * `Ok(None)` - Model components not available (deserialized model)
391    /// * `Err` - Prediction error
392    ///
393    /// # Examples
394    /// ```no_run
395    /// # use organizational_intelligence_plugin::ml_trainer::TrainedModel;
396    /// # fn example(model: &TrainedModel) -> anyhow::Result<()> {
397    /// if let Some((category, confidence)) = model.predict("fix: null pointer in parser")? {
398    ///     println!("Predicted: {:?} ({:.2})", category, confidence);
399    /// }
400    /// # Ok(())
401    /// # }
402    /// ```
403    pub fn predict(&self, message: &str) -> Result<Option<(DefectCategory, f32)>> {
404        // Check if model components are available
405        let tfidf = self
406            .tfidf_extractor
407            .as_ref()
408            .ok_or_else(|| anyhow!("TF-IDF extractor not available"))?;
409        let classifier = self
410            .classifier
411            .as_ref()
412            .ok_or_else(|| anyhow!("Classifier not available"))?;
413
414        // Extract TF-IDF features
415        let features = tfidf.transform(&[message.to_string()])?;
416
417        // Convert to f32 for Random Forest
418        let (n_rows, n_cols) = (features.n_rows(), features.n_cols());
419        let data_f32: Vec<f32> = (0..n_rows * n_cols)
420            .map(|i| {
421                let row = i / n_cols;
422                let col = i % n_cols;
423                features.get(row, col) as f32
424            })
425            .collect();
426
427        let features_f32 = Matrix::from_vec(n_rows, n_cols, data_f32)
428            .map_err(|e| anyhow!("Failed to create feature matrix: {}", e))?;
429
430        // Predict
431        let predictions = classifier.predict(&features_f32);
432
433        if predictions.is_empty() {
434            return Ok(None);
435        }
436
437        // Get predicted label index
438        let label_idx = predictions[0];
439
440        // Map label index back to DefectCategory
441        let category_name = self
442            .label_to_category
443            .get(&label_idx)
444            .ok_or_else(|| anyhow!("Unknown label index: {}", label_idx))?;
445
446        // Parse category name to DefectCategory enum
447        let category = Self::parse_category(category_name)?;
448
449        // NOTE: aprender RandomForest returns class labels, not probabilities.
450        // Using default confidence; upgrade when aprender adds predict_proba().
451        let confidence = 0.75f32;
452
453        Ok(Some((category, confidence)))
454    }
455
456    /// Predict top-N defect categories for a commit message
457    ///
458    /// # Arguments
459    /// * `message` - Commit message to classify
460    /// * `top_n` - Number of top categories to return
461    ///
462    /// # Returns
463    /// * `Ok(Vec<(DefectCategory, f32)>)` - Top-N categories with confidences
464    ///
465    /// # Examples
466    /// ```no_run
467    /// # use organizational_intelligence_plugin::ml_trainer::TrainedModel;
468    /// # fn example(model: &TrainedModel) -> anyhow::Result<()> {
469    /// let predictions = model.predict_top_n("fix: null pointer in parser", 3)?;
470    /// for (category, confidence) in predictions {
471    ///     println!("{:?}: {:.2}", category, confidence);
472    /// }
473    /// # Ok(())
474    /// # }
475    /// ```
476    pub fn predict_top_n(
477        &self,
478        message: &str,
479        _top_n: usize,
480    ) -> Result<Vec<(DefectCategory, f32)>> {
481        // NOTE: Returns single prediction; multi-label requires predict_proba() in aprender
482        if let Some((category, confidence)) = self.predict(message)? {
483            Ok(vec![(category, confidence)])
484        } else {
485            Ok(vec![])
486        }
487    }
488
489    /// Parse category name string to DefectCategory enum
490    fn parse_category(name: &str) -> Result<DefectCategory> {
491        match name {
492            "MemorySafety" => Ok(DefectCategory::MemorySafety),
493            "ConcurrencyBugs" => Ok(DefectCategory::ConcurrencyBugs),
494            "LogicErrors" => Ok(DefectCategory::LogicErrors),
495            "ApiMisuse" => Ok(DefectCategory::ApiMisuse),
496            "ResourceLeaks" => Ok(DefectCategory::ResourceLeaks),
497            "TypeErrors" => Ok(DefectCategory::TypeErrors),
498            "ConfigurationErrors" => Ok(DefectCategory::ConfigurationErrors),
499            "SecurityVulnerabilities" => Ok(DefectCategory::SecurityVulnerabilities),
500            "PerformanceIssues" => Ok(DefectCategory::PerformanceIssues),
501            "IntegrationFailures" => Ok(DefectCategory::IntegrationFailures),
502            "OperatorPrecedence" => Ok(DefectCategory::OperatorPrecedence),
503            "TypeAnnotationGaps" => Ok(DefectCategory::TypeAnnotationGaps),
504            "StdlibMapping" => Ok(DefectCategory::StdlibMapping),
505            "ASTTransform" => Ok(DefectCategory::ASTTransform),
506            "ComprehensionBugs" => Ok(DefectCategory::ComprehensionBugs),
507            "IteratorChain" => Ok(DefectCategory::IteratorChain),
508            "OwnershipBorrow" => Ok(DefectCategory::OwnershipBorrow),
509            "TraitBounds" => Ok(DefectCategory::TraitBounds),
510            _ => Err(anyhow!("Unknown category: {}", name)),
511        }
512    }
513}
514
#[cfg(test)]
mod tests {
    use super::*;
    use crate::git::CommitInfo;
    use crate::training::TrainingDataExtractor;

    /// Five bug-fix commits by the same author, each describing a different
    /// kind of defect.
    fn create_test_commits() -> Vec<CommitInfo> {
        // (hash, message, timestamp, files_changed, lines_added, lines_removed)
        let rows = vec![
            ("abc1", "fix: null pointer dereference in parser", 1234567890, 2, 10, 5),
            ("abc2", "fix: race condition in mutex lock", 1234567891, 1, 5, 3),
            ("abc3", "fix: memory leak in allocator", 1234567892, 1, 8, 2),
            ("abc4", "fix: configuration error in yaml parser", 1234567893, 1, 3, 1),
            ("abc5", "fix: type error in generic bounds", 1234567894, 2, 15, 8),
        ];

        rows.into_iter()
            .map(
                |(hash, message, timestamp, files_changed, lines_added, lines_removed)| {
                    CommitInfo {
                        hash: hash.to_string(),
                        message: message.to_string(),
                        author: "dev@example.com".to_string(),
                        timestamp,
                        files_changed,
                        lines_added,
                        lines_removed,
                    }
                },
            )
            .collect()
    }

    /// `new` stores the hyperparameters it is given.
    #[test]
    fn test_ml_trainer_creation() {
        let MLTrainer {
            n_estimators,
            max_depth,
            max_features,
            ..
        } = MLTrainer::new(100, Some(20), 1500);

        assert_eq!(n_estimators, 100);
        assert_eq!(max_depth, Some(20));
        assert_eq!(max_features, 1500);
    }

    /// `default` uses 100 trees, depth 20 and 1500 features.
    #[test]
    fn test_ml_trainer_default() {
        let MLTrainer {
            n_estimators,
            max_depth,
            max_features,
            ..
        } = MLTrainer::default();

        assert_eq!(n_estimators, 100);
        assert_eq!(max_depth, Some(20));
        assert_eq!(max_features, 1500);
    }

    /// Four matches out of five give an accuracy of 0.8.
    #[test]
    fn test_calculate_accuracy() {
        let predicted = [0, 1, 2, 0, 1];
        let expected = [0, 1, 2, 1, 1];
        assert_eq!(MLTrainer::calculate_accuracy(&predicted, &expected), 0.8);
    }

    /// Identical predictions and labels give an accuracy of 1.0.
    #[test]
    fn test_calculate_accuracy_perfect() {
        let classes = [0, 1, 2];
        assert_eq!(MLTrainer::calculate_accuracy(&classes, &classes), 1.0);
    }

    /// Empty inputs give an accuracy of 0.0.
    #[test]
    fn test_calculate_accuracy_empty() {
        let nothing: Vec<usize> = Vec::new();
        assert_eq!(MLTrainer::calculate_accuracy(&nothing, &nothing), 0.0);
    }

    /// Training on a tiny dataset succeeds when the extractor yields enough
    /// examples; otherwise the test returns early without asserting anything.
    #[test]
    fn test_train_with_small_dataset() {
        let trainer = MLTrainer::new(10, Some(5), 100);

        // Create small training dataset
        let extractor = TrainingDataExtractor::new(0.70);
        let examples = extractor
            .extract_training_data(&create_test_commits(), "test-repo")
            .unwrap();

        // Not enough data for meaningful test (need at least 10 for proper splits)
        if examples.len() < 10 {
            return;
        }

        let dataset = extractor
            .create_splits(&examples, &["test-repo".to_string()])
            .unwrap();

        // Ensure splits are non-empty
        if dataset.train.is_empty() || dataset.validation.is_empty() {
            return;
        }

        // Train model
        let outcome = trainer.train(&dataset);
        if let Err(e) = &outcome {
            eprintln!("Training error: {}", e);
        }
        assert!(outcome.is_ok());

        let model = outcome.unwrap();
        assert!(model.classifier.is_some());
        assert!(model.metadata.train_accuracy > 0.0);
        assert!(model.metadata.n_classes > 0);
    }

    /// Training must fail when the training split is empty.
    #[test]
    fn test_train_empty_dataset_error() {
        let trainer = MLTrainer::new(10, Some(5), 100);

        // Create empty dataset
        let metadata = crate::training::DatasetMetadata {
            total_examples: 0,
            train_size: 0,
            validation_size: 0,
            test_size: 0,
            class_distribution: HashMap::new(),
            avg_confidence: 0.0,
            min_confidence: 0.75,
            repositories: vec![],
        };
        let dataset = TrainingDataset {
            train: vec![],
            validation: vec![],
            test: vec![],
            metadata,
        };

        assert!(trainer.train(&dataset).is_err());
    }

    /// Conversion keeps the dimensions and the corner values of a 2x3 matrix.
    #[test]
    fn test_convert_f64_to_f32() {
        let source = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();

        let outcome = MLTrainer::convert_f64_to_f32(&source);
        assert!(outcome.is_ok());

        let converted = outcome.unwrap();
        assert_eq!(converted.n_rows(), 2);
        assert_eq!(converted.n_cols(), 3);
        assert_eq!(converted.get(0, 0), 1.0f32);
        assert_eq!(converted.get(1, 2), 6.0f32);
    }
}
687}