1use crate::classifier::DefectCategory;
13use crate::nlp::TfidfFeatureExtractor;
14use crate::training::{TrainingDataset, TrainingExample};
15use anyhow::{anyhow, Result};
16use aprender::primitives::Matrix;
17use aprender::tree::RandomForestClassifier;
18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20use std::fs;
21use std::path::Path;
22
/// A trained defect classifier together with the label mappings and training
/// statistics produced by [`MLTrainer::train`].
///
/// Only the label mappings, metadata, vocabulary slot and `max_features` are
/// serialized. The classifier and TF-IDF extractor are `#[serde(skip)]`, so a
/// model restored from JSON (e.g. via [`MLTrainer::load_model`]) has both set
/// to `None` and cannot predict until they are supplied again.
#[derive(Serialize, Deserialize)]
pub struct TrainedModel {
    /// Fitted Random Forest; `None` after deserialization (skipped by serde).
    #[serde(skip)]
    pub classifier: Option<RandomForestClassifier>,
    /// Fitted TF-IDF extractor; `None` after deserialization (skipped by serde).
    #[serde(skip)]
    pub tfidf_extractor: Option<TfidfFeatureExtractor>,
    /// Category name (the `Display` form of a training label) -> class index.
    pub category_to_label: HashMap<String, usize>,
    /// Inverse of `category_to_label`: class index -> category name.
    pub label_to_category: HashMap<usize, String>,
    /// Dataset sizes, hyperparameters and accuracies recorded during training.
    pub metadata: TrainingMetadata,
    /// Vocabulary slot. NOTE(review): `MLTrainer::train` currently always stores
    /// an empty vector here, so it cannot be used to rebuild the extractor.
    pub tfidf_vocabulary: Vec<String>,
    /// The `max_features` value the TF-IDF extractor was created with.
    pub max_features: usize,
}
43
/// Summary statistics recorded by [`MLTrainer::train`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMetadata {
    /// Number of examples in the training split.
    pub n_train: usize,
    /// Number of examples in the validation split.
    pub n_validation: usize,
    /// Number of examples in the test split (the split itself is not used by `train`).
    pub n_test: usize,
    /// Number of trees in the Random Forest.
    pub n_estimators: usize,
    /// Maximum tree depth, or `None` when the classifier's default was kept.
    pub max_depth: Option<usize>,
    /// Vocabulary size of the fitted TF-IDF extractor.
    pub n_features: usize,
    /// Number of distinct categories in the training split.
    pub n_classes: usize,
    /// Fraction of training examples classified correctly.
    pub train_accuracy: f32,
    /// Fraction of validation examples classified correctly.
    pub validation_accuracy: f32,
    /// Fraction of test examples classified correctly; `train` leaves this `None`.
    pub test_accuracy: Option<f32>,
}
68
/// Trains a TF-IDF + Random Forest defect classifier from labelled commit
/// messages.
///
/// All fields are plain configuration values, so the trainer is cheap to
/// clone, compare and debug-print.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MLTrainer {
    /// Number of trees in the forest.
    n_estimators: usize,
    /// Maximum depth of each tree; `None` leaves the classifier default.
    max_depth: Option<usize>,
    /// Maximum TF-IDF vocabulary size.
    max_features: usize,
    /// Seed passed to the Random Forest so training is reproducible.
    random_state: u64,
}
76
77impl MLTrainer {
78 pub fn new(n_estimators: usize, max_depth: Option<usize>, max_features: usize) -> Self {
94 Self {
95 n_estimators,
96 max_depth,
97 max_features,
98 random_state: 42,
99 }
100 }
101
102 pub fn load_dataset<P: AsRef<Path>>(path: P) -> Result<TrainingDataset> {
113 let content = fs::read_to_string(path.as_ref())
114 .map_err(|e| anyhow!("Failed to read training data: {}", e))?;
115
116 serde_json::from_str(&content)
117 .map_err(|e| anyhow!("Failed to parse training data JSON: {}", e))
118 }
119
120 pub fn train(&self, dataset: &TrainingDataset) -> Result<TrainedModel> {
145 if dataset.train.is_empty() {
146 return Err(anyhow!("Training dataset is empty"));
147 }
148
149 let train_messages: Vec<String> =
151 dataset.train.iter().map(|ex| ex.message.clone()).collect();
152
153 let validation_messages: Vec<String> = dataset
154 .validation
155 .iter()
156 .map(|ex| ex.message.clone())
157 .collect();
158
159 let mut unique_categories: Vec<String> = dataset
161 .train
162 .iter()
163 .map(|ex| format!("{}", ex.label))
164 .collect();
165 unique_categories.sort();
166 unique_categories.dedup();
167
168 let category_to_label: HashMap<String, usize> = unique_categories
169 .iter()
170 .enumerate()
171 .map(|(i, cat)| (cat.clone(), i))
172 .collect();
173
174 let label_to_category: HashMap<usize, String> = unique_categories
175 .iter()
176 .enumerate()
177 .map(|(i, cat)| (i, cat.clone()))
178 .collect();
179
180 let train_labels: Vec<usize> = dataset
182 .train
183 .iter()
184 .map(|ex| {
185 *category_to_label
186 .get(&format!("{}", ex.label))
187 .unwrap_or(&0)
188 })
189 .collect();
190
191 let validation_labels: Vec<usize> = dataset
192 .validation
193 .iter()
194 .map(|ex| {
195 *category_to_label
196 .get(&format!("{}", ex.label))
197 .unwrap_or(&0)
198 })
199 .collect();
200
201 let mut tfidf_extractor = TfidfFeatureExtractor::new(self.max_features);
203 let train_features = tfidf_extractor.fit_transform(&train_messages)?;
204 let validation_features = tfidf_extractor.transform(&validation_messages)?;
205
206 let train_features_f32 = Self::convert_f64_to_f32(&train_features)?;
208 let validation_features_f32 = Self::convert_f64_to_f32(&validation_features)?;
209
210 let mut classifier = RandomForestClassifier::new(self.n_estimators);
212 if let Some(depth) = self.max_depth {
213 classifier = classifier.with_max_depth(depth);
214 }
215 classifier = classifier.with_random_state(self.random_state);
216
217 classifier
218 .fit(&train_features_f32, &train_labels)
219 .map_err(|e| anyhow!("Random Forest training failed: {}", e))?;
220
221 let train_predictions = classifier.predict(&train_features_f32);
223 let train_accuracy = Self::calculate_accuracy(&train_predictions, &train_labels);
224
225 let validation_predictions = classifier.predict(&validation_features_f32);
227 let validation_accuracy =
228 Self::calculate_accuracy(&validation_predictions, &validation_labels);
229
230 let metadata = TrainingMetadata {
231 n_train: dataset.train.len(),
232 n_validation: dataset.validation.len(),
233 n_test: dataset.test.len(),
234 n_estimators: self.n_estimators,
235 max_depth: self.max_depth,
236 n_features: tfidf_extractor.vocabulary_size(),
237 n_classes: unique_categories.len(),
238 train_accuracy,
239 validation_accuracy,
240 test_accuracy: None,
241 };
242
243 let tfidf_vocabulary: Vec<String> = vec![];
246 let max_features = self.max_features;
247
248 Ok(TrainedModel {
249 classifier: Some(classifier),
250 tfidf_extractor: Some(tfidf_extractor),
251 category_to_label,
252 label_to_category,
253 metadata,
254 tfidf_vocabulary,
255 max_features,
256 })
257 }
258
259 fn convert_f64_to_f32(matrix: &Matrix<f64>) -> Result<Matrix<f32>> {
261 let (n_rows, n_cols) = (matrix.n_rows(), matrix.n_cols());
262 let data_f32: Vec<f32> = (0..n_rows * n_cols)
263 .map(|i| {
264 let row = i / n_cols;
265 let col = i % n_cols;
266 matrix.get(row, col) as f32
267 })
268 .collect();
269
270 Matrix::from_vec(n_rows, n_cols, data_f32)
271 .map_err(|e| anyhow!("Failed to convert matrix: {}", e))
272 }
273
274 fn calculate_accuracy(predictions: &[usize], labels: &[usize]) -> f32 {
276 if predictions.is_empty() || predictions.len() != labels.len() {
277 return 0.0;
278 }
279
280 let correct = predictions
281 .iter()
282 .zip(labels.iter())
283 .filter(|(pred, label)| pred == label)
284 .count();
285
286 correct as f32 / predictions.len() as f32
287 }
288
289 pub fn evaluate(model: &TrainedModel, test_examples: &[TrainingExample]) -> Result<f32> {
300 if test_examples.is_empty() {
301 return Ok(0.0);
302 }
303
304 let classifier = model
305 .classifier
306 .as_ref()
307 .ok_or_else(|| anyhow!("Model has no classifier"))?;
308
309 let tfidf_extractor = model
310 .tfidf_extractor
311 .as_ref()
312 .ok_or_else(|| anyhow!("Model has no TF-IDF extractor"))?;
313
314 let test_messages: Vec<String> =
315 test_examples.iter().map(|ex| ex.message.clone()).collect();
316
317 let test_labels: Vec<usize> = test_examples
318 .iter()
319 .map(|ex| {
320 *model
321 .category_to_label
322 .get(&format!("{}", ex.label))
323 .unwrap_or(&0)
324 })
325 .collect();
326
327 let test_features = tfidf_extractor.transform(&test_messages)?;
329 let test_features_f32 = Self::convert_f64_to_f32(&test_features)?;
330
331 let test_predictions = classifier.predict(&test_features_f32);
333 let test_accuracy = Self::calculate_accuracy(&test_predictions, &test_labels);
334
335 Ok(test_accuracy)
336 }
337
338 pub fn save_model<P: AsRef<Path>>(model: &TrainedModel, path: P) -> Result<()> {
350 let json = serde_json::to_string_pretty(model)
351 .map_err(|e| anyhow!("Failed to serialize model: {}", e))?;
352
353 fs::write(path.as_ref(), json).map_err(|e| anyhow!("Failed to write model file: {}", e))?;
354
355 Ok(())
356 }
357
358 pub fn load_model<P: AsRef<Path>>(path: P) -> Result<TrainedModel> {
369 let content = fs::read_to_string(path.as_ref())
370 .map_err(|e| anyhow!("Failed to read model file: {}", e))?;
371
372 serde_json::from_str(&content).map_err(|e| anyhow!("Failed to parse model JSON: {}", e))
373 }
374}
375
376impl Default for MLTrainer {
377 fn default() -> Self {
378 Self::new(100, Some(20), 1500)
379 }
380}
381
382impl TrainedModel {
383 pub fn predict(&self, message: &str) -> Result<Option<(DefectCategory, f32)>> {
404 let tfidf = self
406 .tfidf_extractor
407 .as_ref()
408 .ok_or_else(|| anyhow!("TF-IDF extractor not available"))?;
409 let classifier = self
410 .classifier
411 .as_ref()
412 .ok_or_else(|| anyhow!("Classifier not available"))?;
413
414 let features = tfidf.transform(&[message.to_string()])?;
416
417 let (n_rows, n_cols) = (features.n_rows(), features.n_cols());
419 let data_f32: Vec<f32> = (0..n_rows * n_cols)
420 .map(|i| {
421 let row = i / n_cols;
422 let col = i % n_cols;
423 features.get(row, col) as f32
424 })
425 .collect();
426
427 let features_f32 = Matrix::from_vec(n_rows, n_cols, data_f32)
428 .map_err(|e| anyhow!("Failed to create feature matrix: {}", e))?;
429
430 let predictions = classifier.predict(&features_f32);
432
433 if predictions.is_empty() {
434 return Ok(None);
435 }
436
437 let label_idx = predictions[0];
439
440 let category_name = self
442 .label_to_category
443 .get(&label_idx)
444 .ok_or_else(|| anyhow!("Unknown label index: {}", label_idx))?;
445
446 let category = Self::parse_category(category_name)?;
448
449 let confidence = 0.75f32;
452
453 Ok(Some((category, confidence)))
454 }
455
456 pub fn predict_top_n(
477 &self,
478 message: &str,
479 _top_n: usize,
480 ) -> Result<Vec<(DefectCategory, f32)>> {
481 if let Some((category, confidence)) = self.predict(message)? {
483 Ok(vec![(category, confidence)])
484 } else {
485 Ok(vec![])
486 }
487 }
488
489 fn parse_category(name: &str) -> Result<DefectCategory> {
491 match name {
492 "MemorySafety" => Ok(DefectCategory::MemorySafety),
493 "ConcurrencyBugs" => Ok(DefectCategory::ConcurrencyBugs),
494 "LogicErrors" => Ok(DefectCategory::LogicErrors),
495 "ApiMisuse" => Ok(DefectCategory::ApiMisuse),
496 "ResourceLeaks" => Ok(DefectCategory::ResourceLeaks),
497 "TypeErrors" => Ok(DefectCategory::TypeErrors),
498 "ConfigurationErrors" => Ok(DefectCategory::ConfigurationErrors),
499 "SecurityVulnerabilities" => Ok(DefectCategory::SecurityVulnerabilities),
500 "PerformanceIssues" => Ok(DefectCategory::PerformanceIssues),
501 "IntegrationFailures" => Ok(DefectCategory::IntegrationFailures),
502 "OperatorPrecedence" => Ok(DefectCategory::OperatorPrecedence),
503 "TypeAnnotationGaps" => Ok(DefectCategory::TypeAnnotationGaps),
504 "StdlibMapping" => Ok(DefectCategory::StdlibMapping),
505 "ASTTransform" => Ok(DefectCategory::ASTTransform),
506 "ComprehensionBugs" => Ok(DefectCategory::ComprehensionBugs),
507 "IteratorChain" => Ok(DefectCategory::IteratorChain),
508 "OwnershipBorrow" => Ok(DefectCategory::OwnershipBorrow),
509 "TraitBounds" => Ok(DefectCategory::TraitBounds),
510 _ => Err(anyhow!("Unknown category: {}", name)),
511 }
512 }
513}
514
#[cfg(test)]
mod tests {
    use super::*;
    use crate::git::CommitInfo;
    use crate::training::TrainingDataExtractor;

    /// Five bug-fix commits covering different defect categories.
    fn create_test_commits() -> Vec<CommitInfo> {
        vec![
            CommitInfo {
                hash: "abc1".to_string(),
                message: "fix: null pointer dereference in parser".to_string(),
                author: "dev@example.com".to_string(),
                timestamp: 1234567890,
                files_changed: 2,
                lines_added: 10,
                lines_removed: 5,
            },
            CommitInfo {
                hash: "abc2".to_string(),
                message: "fix: race condition in mutex lock".to_string(),
                author: "dev@example.com".to_string(),
                timestamp: 1234567891,
                files_changed: 1,
                lines_added: 5,
                lines_removed: 3,
            },
            CommitInfo {
                hash: "abc3".to_string(),
                message: "fix: memory leak in allocator".to_string(),
                author: "dev@example.com".to_string(),
                timestamp: 1234567892,
                files_changed: 1,
                lines_added: 8,
                lines_removed: 2,
            },
            CommitInfo {
                hash: "abc4".to_string(),
                message: "fix: configuration error in yaml parser".to_string(),
                author: "dev@example.com".to_string(),
                timestamp: 1234567893,
                files_changed: 1,
                lines_added: 3,
                lines_removed: 1,
            },
            CommitInfo {
                hash: "abc5".to_string(),
                message: "fix: type error in generic bounds".to_string(),
                author: "dev@example.com".to_string(),
                timestamp: 1234567894,
                files_changed: 2,
                lines_added: 15,
                lines_removed: 8,
            },
        ]
    }

    /// A model with no classifier or extractor, like one restored from JSON.
    fn model_without_components() -> TrainedModel {
        TrainedModel {
            classifier: None,
            tfidf_extractor: None,
            category_to_label: HashMap::new(),
            label_to_category: HashMap::new(),
            metadata: TrainingMetadata {
                n_train: 0,
                n_validation: 0,
                n_test: 0,
                n_estimators: 1,
                max_depth: None,
                n_features: 0,
                n_classes: 0,
                train_accuracy: 0.0,
                validation_accuracy: 0.0,
                test_accuracy: None,
            },
            tfidf_vocabulary: vec![],
            max_features: 10,
        }
    }

    #[test]
    fn test_ml_trainer_creation() {
        let trainer = MLTrainer::new(100, Some(20), 1500);
        assert_eq!(trainer.n_estimators, 100);
        assert_eq!(trainer.max_depth, Some(20));
        assert_eq!(trainer.max_features, 1500);
    }

    #[test]
    fn test_ml_trainer_default() {
        let trainer = MLTrainer::default();
        assert_eq!(trainer.n_estimators, 100);
        assert_eq!(trainer.max_depth, Some(20));
        assert_eq!(trainer.max_features, 1500);
    }

    #[test]
    fn test_calculate_accuracy() {
        let predictions = vec![0, 1, 2, 0, 1];
        let labels = vec![0, 1, 2, 1, 1];
        let accuracy = MLTrainer::calculate_accuracy(&predictions, &labels);
        assert_eq!(accuracy, 0.8);
    }

    #[test]
    fn test_calculate_accuracy_perfect() {
        let predictions = vec![0, 1, 2];
        let labels = vec![0, 1, 2];
        let accuracy = MLTrainer::calculate_accuracy(&predictions, &labels);
        assert_eq!(accuracy, 1.0);
    }

    #[test]
    fn test_calculate_accuracy_empty() {
        let predictions: Vec<usize> = vec![];
        let labels: Vec<usize> = vec![];
        let accuracy = MLTrainer::calculate_accuracy(&predictions, &labels);
        assert_eq!(accuracy, 0.0);
    }

    #[test]
    fn test_calculate_accuracy_length_mismatch() {
        assert_eq!(MLTrainer::calculate_accuracy(&[0, 1], &[0]), 0.0);
    }

    #[test]
    fn test_train_with_small_dataset() {
        let trainer = MLTrainer::new(10, Some(5), 100);

        let extractor = TrainingDataExtractor::new(0.70);
        let commits = create_test_commits();
        let examples = extractor
            .extract_training_data(&commits, "test-repo")
            .unwrap();

        // Too few labelled examples to form meaningful splits.
        if examples.len() < 10 {
            return;
        }

        let dataset = extractor
            .create_splits(&examples, &["test-repo".to_string()])
            .unwrap();

        if dataset.train.is_empty() || dataset.validation.is_empty() {
            return;
        }

        let model = trainer
            .train(&dataset)
            .unwrap_or_else(|e| panic!("Training error: {}", e));

        assert!(model.classifier.is_some());
        assert!(model.metadata.train_accuracy > 0.0);
        assert!(model.metadata.n_classes > 0);
    }

    #[test]
    fn test_train_empty_dataset_error() {
        let trainer = MLTrainer::new(10, Some(5), 100);

        let dataset = TrainingDataset {
            train: vec![],
            validation: vec![],
            test: vec![],
            metadata: crate::training::DatasetMetadata {
                total_examples: 0,
                train_size: 0,
                validation_size: 0,
                test_size: 0,
                class_distribution: HashMap::new(),
                avg_confidence: 0.0,
                min_confidence: 0.75,
                repositories: vec![],
            },
        };

        let result = trainer.train(&dataset);
        assert!(result.is_err());
    }

    #[test]
    fn test_convert_f64_to_f32() {
        let matrix_f64 = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();

        let result = MLTrainer::convert_f64_to_f32(&matrix_f64);
        assert!(result.is_ok());

        let matrix_f32 = result.unwrap();
        assert_eq!(matrix_f32.n_rows(), 2);
        assert_eq!(matrix_f32.n_cols(), 3);
        assert_eq!(matrix_f32.get(0, 0), 1.0f32);
        assert_eq!(matrix_f32.get(1, 2), 6.0f32);
    }

    #[test]
    fn test_parse_category() {
        assert!(matches!(
            TrainedModel::parse_category("MemorySafety"),
            Ok(DefectCategory::MemorySafety)
        ));
        assert!(matches!(
            TrainedModel::parse_category("TraitBounds"),
            Ok(DefectCategory::TraitBounds)
        ));
        assert!(TrainedModel::parse_category("NotACategory").is_err());
    }

    #[test]
    fn test_predict_without_components_errors() {
        let model = model_without_components();
        assert!(model.predict("fix: null pointer dereference").is_err());
        assert!(model.predict_top_n("fix: null pointer dereference", 3).is_err());
    }

    #[test]
    fn test_evaluate_empty_examples_is_zero() {
        let model = model_without_components();
        let no_examples: Vec<TrainingExample> = Vec::new();
        assert_eq!(MLTrainer::evaluate(&model, &no_examples).unwrap(), 0.0);
    }

    #[test]
    fn test_model_serde_roundtrip_preserves_metadata() {
        let model = model_without_components();
        let json = serde_json::to_string(&model).unwrap();
        let restored: TrainedModel = serde_json::from_str(&json).unwrap();
        assert!(restored.classifier.is_none());
        assert!(restored.tfidf_extractor.is_none());
        assert_eq!(restored.max_features, 10);
        assert_eq!(restored.metadata.n_estimators, 1);
        assert_eq!(restored.metadata.test_accuracy, None);
    }
}