organizational_intelligence_plugin/
training.rs

1//! Training data extraction pipeline for ML defect classification.
2//!
3//! This module implements Phase 2 training data collection from Git history:
4//! - Extract commit messages from repositories
5//! - Filter relevant defect-fix commits
6//! - Auto-label using rule-based classifier
7//! - Create train/test/validation splits
8//! - Export to structured format for ML training
9//!
10//! Implements Section 5.4 Training Data Pipeline from nlp-models-techniques-spec.md
11
12use crate::citl::{SuggestionApplicability, TrainingSource};
13use crate::classifier::{DefectCategory, RuleBasedClassifier};
14use crate::git::CommitInfo;
15use anyhow::{anyhow, Result};
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18
19/// Training example with features and label
20///
21/// NLP-014: Extended with CITL fields for ground-truth training labels
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct TrainingExample {
24    /// Commit message text
25    pub message: String,
26    /// Defect category label
27    pub label: DefectCategory,
28    /// Classifier confidence (0.0-1.0)
29    pub confidence: f32,
30    /// Original commit hash
31    pub commit_hash: String,
32    /// Author name
33    pub author: String,
34    /// Unix timestamp
35    pub timestamp: i64,
36    /// Lines added in commit
37    pub lines_added: usize,
38    /// Lines removed in commit
39    pub lines_removed: usize,
40    /// Number of files changed
41    pub files_changed: usize,
42
43    // NLP-014: CITL fields
44    /// Rustc error code (e.g., "E0308")
45    #[serde(default)]
46    pub error_code: Option<String>,
47    /// Clippy lint name (e.g., "clippy::unwrap_used")
48    #[serde(default)]
49    pub clippy_lint: Option<String>,
50    /// Whether a suggestion was provided
51    #[serde(default)]
52    pub has_suggestion: bool,
53    /// Suggestion applicability level
54    #[serde(default)]
55    pub suggestion_applicability: Option<SuggestionApplicability>,
56    /// Source of the training example
57    #[serde(default)]
58    pub source: TrainingSource,
59}
60
61/// Training dataset with train/test/validation splits
62#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct TrainingDataset {
64    /// Training examples
65    pub train: Vec<TrainingExample>,
66    /// Validation examples
67    pub validation: Vec<TrainingExample>,
68    /// Test examples
69    pub test: Vec<TrainingExample>,
70    /// Dataset metadata
71    pub metadata: DatasetMetadata,
72}
73
74/// Metadata about the training dataset
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct DatasetMetadata {
77    /// Total number of examples
78    pub total_examples: usize,
79    /// Number of training examples
80    pub train_size: usize,
81    /// Number of validation examples
82    pub validation_size: usize,
83    /// Number of test examples
84    pub test_size: usize,
85    /// Class distribution (category -> count)
86    pub class_distribution: HashMap<String, usize>,
87    /// Average confidence score
88    pub avg_confidence: f32,
89    /// Minimum confidence threshold used
90    pub min_confidence: f32,
91    /// Repository names included
92    pub repositories: Vec<String>,
93}
94
95/// Training data extractor
96pub struct TrainingDataExtractor {
97    classifier: RuleBasedClassifier,
98    min_confidence: f32,
99}
100
101impl TrainingDataExtractor {
102    /// Create a new training data extractor
103    ///
104    /// # Arguments
105    ///
106    /// * `min_confidence` - Minimum confidence threshold for auto-labeling (0.6-0.9)
107    ///
108    /// # Examples
109    ///
110    /// ```rust
111    /// use organizational_intelligence_plugin::training::TrainingDataExtractor;
112    ///
113    /// let extractor = TrainingDataExtractor::new(0.75);
114    /// ```
115    pub fn new(min_confidence: f32) -> Self {
116        Self {
117            classifier: RuleBasedClassifier::new(),
118            min_confidence,
119        }
120    }
121
122    /// Extract training examples from commit history
123    ///
124    /// Filters commits and auto-labels using rule-based classifier.
125    ///
126    /// # Arguments
127    ///
128    /// * `commits` - Raw commit history
129    /// * `repository_name` - Name of the repository
130    ///
131    /// # Returns
132    ///
133    /// * `Ok(Vec<TrainingExample>)` - Labeled training examples
134    /// * `Err` - If extraction fails
135    ///
136    /// # Examples
137    ///
138    /// ```rust
139    /// use organizational_intelligence_plugin::training::TrainingDataExtractor;
140    /// use organizational_intelligence_plugin::git::CommitInfo;
141    ///
142    /// let extractor = TrainingDataExtractor::new(0.75);
143    /// let commits = vec![
144    ///     CommitInfo {
145    ///         hash: "abc123".to_string(),
146    ///         message: "fix: null pointer dereference".to_string(),
147    ///         author: "dev@example.com".to_string(),
148    ///         timestamp: 1234567890,
149    ///         files_changed: 2,
150    ///         lines_added: 10,
151    ///         lines_removed: 5,
152    ///     },
153    /// ];
154    ///
155    /// let examples = extractor.extract_training_data(&commits, "test-repo").unwrap();
156    /// assert_eq!(examples.len(), 1);
157    /// ```
158    pub fn extract_training_data(
159        &self,
160        commits: &[CommitInfo],
161        _repository_name: &str,
162    ) -> Result<Vec<TrainingExample>> {
163        let mut examples = Vec::new();
164
165        for commit in commits {
166            // Filter: Skip if not a defect-fix commit
167            if !self.is_defect_fix_commit(&commit.message) {
168                continue;
169            }
170
171            // Auto-label using rule-based classifier
172            if let Some(classification) = self.classifier.classify_from_message(&commit.message) {
173                // Only include if confidence meets threshold
174                if classification.confidence >= self.min_confidence {
175                    examples.push(TrainingExample {
176                        message: commit.message.clone(),
177                        label: classification.category,
178                        confidence: classification.confidence,
179                        commit_hash: commit.hash.clone(),
180                        author: commit.author.clone(),
181                        timestamp: commit.timestamp,
182                        lines_added: commit.lines_added,
183                        lines_removed: commit.lines_removed,
184                        files_changed: commit.files_changed,
185                        // NLP-014: Default CITL fields for commit message source
186                        error_code: None,
187                        clippy_lint: None,
188                        has_suggestion: false,
189                        suggestion_applicability: None,
190                        source: TrainingSource::CommitMessage,
191                    });
192                }
193            }
194        }
195
196        Ok(examples)
197    }
198
199    /// Check if a commit message is a defect fix
200    ///
201    /// Uses heuristics to identify defect-fix commits:
202    /// - Starts with "fix:", "bug:", "patch:"
203    /// - Contains keywords: "fix", "bug", "error", "crash", "issue"
204    /// - Excludes: merge commits, reverts, docs, tests (unless fixing a bug)
205    fn is_defect_fix_commit(&self, message: &str) -> bool {
206        let lower = message.to_lowercase();
207
208        // Skip obvious non-defect commits
209        if lower.starts_with("merge")
210            || lower.starts_with("revert")
211            || lower.contains("wip")
212            || lower.contains("work in progress")
213        {
214            return false;
215        }
216
217        // Check for defect-fix indicators
218        lower.starts_with("fix:")
219            || lower.starts_with("bug:")
220            || lower.starts_with("patch:")
221            || lower.contains("fix ")
222            || lower.contains("bug ")
223            || lower.contains("error")
224            || lower.contains("crash")
225            || lower.contains("issue")
226    }
227
228    /// Create train/test/validation splits
229    ///
230    /// Uses 70/15/15 split (train/validation/test) as recommended by the spec.
231    ///
232    /// # Arguments
233    ///
234    /// * `examples` - Labeled training examples
235    /// * `repositories` - List of repository names
236    ///
237    /// # Returns
238    ///
239    /// * `Ok(TrainingDataset)` - Dataset with splits
240    /// * `Err` - If split fails
241    ///
242    /// # Examples
243    ///
244    /// ```rust
245    /// use organizational_intelligence_plugin::training::TrainingDataExtractor;
246    /// use organizational_intelligence_plugin::training::TrainingExample;
247    /// use organizational_intelligence_plugin::classifier::DefectCategory;
248    ///
249    /// let extractor = TrainingDataExtractor::new(0.75);
250    /// let examples = vec![
251    ///     TrainingExample {
252    ///         message: "fix: bug".to_string(),
253    ///         label: DefectCategory::MemorySafety,
254    ///         confidence: 0.85,
255    ///         commit_hash: "abc".to_string(),
256    ///         author: "dev".to_string(),
257    ///         timestamp: 123,
258    ///         lines_added: 5,
259    ///         lines_removed: 2,
260    ///         files_changed: 1,
261    ///         error_code: None,
262    ///         clippy_lint: None,
263    ///         has_suggestion: false,
264    ///         suggestion_applicability: None,
265    ///         source: organizational_intelligence_plugin::citl::TrainingSource::CommitMessage,
266    ///     },
267    /// ];
268    ///
269    /// let dataset = extractor.create_splits(&examples, &["repo1".to_string()]).unwrap();
270    /// assert!(dataset.train.len() + dataset.validation.len() + dataset.test.len() == 1);
271    /// ```
272    pub fn create_splits(
273        &self,
274        examples: &[TrainingExample],
275        repositories: &[String],
276    ) -> Result<TrainingDataset> {
277        if examples.is_empty() {
278            return Err(anyhow!("Cannot create splits from empty dataset"));
279        }
280
281        let total = examples.len();
282
283        // Calculate split sizes (70/15/15)
284        let train_size = (total as f32 * 0.70) as usize;
285        let validation_size = (total as f32 * 0.15) as usize;
286        let test_size = total - train_size - validation_size;
287
288        // Split the data
289        let train = examples[0..train_size].to_vec();
290        let validation = examples[train_size..train_size + validation_size].to_vec();
291        let test = examples[train_size + validation_size..].to_vec();
292
293        // Calculate class distribution
294        let mut class_distribution = HashMap::new();
295        for example in examples {
296            let category_name = format!("{}", example.label);
297            *class_distribution.entry(category_name).or_insert(0) += 1;
298        }
299
300        // Calculate average confidence
301        let avg_confidence =
302            examples.iter().map(|e| e.confidence).sum::<f32>() / examples.len() as f32;
303
304        let metadata = DatasetMetadata {
305            total_examples: total,
306            train_size,
307            validation_size,
308            test_size,
309            class_distribution,
310            avg_confidence,
311            min_confidence: self.min_confidence,
312            repositories: repositories.to_vec(),
313        };
314
315        Ok(TrainingDataset {
316            train,
317            validation,
318            test,
319            metadata,
320        })
321    }
322
323    /// Get statistics about extracted training data
324    ///
325    /// # Arguments
326    ///
327    /// * `examples` - Training examples
328    ///
329    /// # Returns
330    ///
331    /// * Formatted statistics string
332    pub fn get_statistics(&self, examples: &[TrainingExample]) -> String {
333        if examples.is_empty() {
334            return "No examples extracted".to_string();
335        }
336
337        let mut category_counts: HashMap<DefectCategory, usize> = HashMap::new();
338        let mut confidence_sum = 0.0_f32;
339
340        for example in examples {
341            *category_counts.entry(example.label).or_insert(0) += 1;
342            confidence_sum += example.confidence;
343        }
344
345        let avg_confidence = confidence_sum / examples.len() as f32;
346
347        let mut stats = "Training Data Statistics:\n".to_string();
348        stats.push_str(&format!("  Total examples: {}\n", examples.len()));
349        stats.push_str(&format!("  Avg confidence: {:.2}\n", avg_confidence));
350        stats.push_str(&format!(
351            "  Min confidence threshold: {:.2}\n",
352            self.min_confidence
353        ));
354        stats.push_str("\nClass Distribution:\n");
355
356        let mut sorted_categories: Vec<_> = category_counts.iter().collect();
357        sorted_categories.sort_by_key(|(_, count)| std::cmp::Reverse(*count));
358
359        for (category, count) in sorted_categories {
360            let percentage = (*count as f32 / examples.len() as f32) * 100.0;
361            stats.push_str(&format!(
362                "  {:?}: {} ({:.1}%)\n",
363                category, count, percentage
364            ));
365        }
366
367        stats
368    }
369}
370
371impl Default for TrainingDataExtractor {
372    fn default() -> Self {
373        Self::new(0.75) // Default 75% confidence threshold
374    }
375}
376
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a commit-message training example authored by "dev" that touches one file.
    fn example(
        message: &str,
        label: DefectCategory,
        confidence: f32,
        commit_hash: &str,
        timestamp: i64,
        (lines_added, lines_removed): (usize, usize),
    ) -> TrainingExample {
        TrainingExample {
            message: message.to_string(),
            label,
            confidence,
            commit_hash: commit_hash.to_string(),
            author: "dev".to_string(),
            timestamp,
            lines_added,
            lines_removed,
            files_changed: 1,
            error_code: None,
            clippy_lint: None,
            has_suggestion: false,
            suggestion_applicability: None,
            source: TrainingSource::CommitMessage,
        }
    }

    /// Builds a commit record from its individual fields.
    fn commit(
        hash: &str,
        message: &str,
        author: &str,
        timestamp: i64,
        files_changed: usize,
        lines_added: usize,
        lines_removed: usize,
    ) -> CommitInfo {
        CommitInfo {
            hash: hash.to_string(),
            message: message.to_string(),
            author: author.to_string(),
            timestamp,
            files_changed,
            lines_added,
            lines_removed,
        }
    }

    #[test]
    fn test_extractor_creation() {
        let extractor = TrainingDataExtractor::new(0.80);
        assert_eq!(extractor.min_confidence, 0.80);
    }

    #[test]
    fn test_is_defect_fix_commit() {
        let extractor = TrainingDataExtractor::new(0.75);

        let defect_fixes = [
            "fix: null pointer",
            "bug: race condition",
            "patch: memory leak",
            "fix memory leak in parser",
        ];
        for message in defect_fixes.iter() {
            assert!(
                extractor.is_defect_fix_commit(message),
                "expected a defect fix: {}",
                message
            );
        }

        let non_defect_fixes = [
            "Merge branch 'main'",
            "Revert commit abc123",
            "feat: add new feature",
            "docs: update README",
            "WIP: working on feature",
        ];
        for message in non_defect_fixes.iter() {
            assert!(
                !extractor.is_defect_fix_commit(message),
                "expected not a defect fix: {}",
                message
            );
        }
    }

    #[test]
    fn test_extract_training_data() {
        let extractor = TrainingDataExtractor::new(0.70);

        let commits = vec![
            commit(
                "abc123",
                "fix: null pointer dereference in parser",
                "dev@example.com",
                1234567890,
                2,
                10,
                5,
            ),
            // Not a defect fix.
            commit(
                "def456",
                "feat: add new feature",
                "dev@example.com",
                1234567891,
                5,
                100,
                0,
            ),
            commit(
                "ghi789",
                "fix: race condition in mutex lock",
                "dev@example.com",
                1234567892,
                1,
                5,
                3,
            ),
        ];

        let examples = extractor
            .extract_training_data(&commits, "test-repo")
            .unwrap();

        // Exactly the two defect-fix commits survive, in their original order.
        let messages: Vec<&str> = examples.iter().map(|e| e.message.as_str()).collect();
        assert_eq!(
            messages,
            [
                "fix: null pointer dereference in parser",
                "fix: race condition in mutex lock",
            ]
        );
    }

    #[test]
    fn test_create_splits() {
        let extractor = TrainingDataExtractor::new(0.75);

        // 100 examples give a clean 70/15/15 split.
        let examples: Vec<TrainingExample> = (0..100)
            .map(|i| {
                example(
                    &format!("fix: bug {}", i),
                    DefectCategory::MemorySafety,
                    0.85,
                    &format!("hash{}", i),
                    123 + i as i64,
                    (5, 2),
                )
            })
            .collect();

        let dataset = extractor
            .create_splits(&examples, &["repo1".to_string()])
            .unwrap();

        assert_eq!(dataset.train.len(), 70);
        assert_eq!(dataset.validation.len(), 15);
        assert_eq!(dataset.test.len(), 15);
        assert_eq!(dataset.metadata.total_examples, 100);
        assert_eq!(dataset.metadata.train_size, 70);
    }

    #[test]
    fn test_empty_dataset_error() {
        let extractor = TrainingDataExtractor::new(0.75);
        let no_examples: Vec<TrainingExample> = Vec::new();

        assert!(extractor.create_splits(&no_examples, &[]).is_err());
    }

    #[test]
    fn test_get_statistics() {
        let extractor = TrainingDataExtractor::new(0.75);

        let examples = vec![
            example(
                "fix: bug 1",
                DefectCategory::MemorySafety,
                0.85,
                "a",
                123,
                (5, 2),
            ),
            example(
                "fix: bug 2",
                DefectCategory::ConcurrencyBugs,
                0.90,
                "b",
                124,
                (3, 1),
            ),
        ];

        let stats = extractor.get_statistics(&examples);
        for expected in ["Total examples: 2", "Avg confidence:", "Class Distribution:"].iter() {
            assert!(stats.contains(expected), "missing {:?} in {}", expected, stats);
        }
    }

    #[test]
    fn test_confidence_threshold_filtering() {
        let extractor = TrainingDataExtractor::new(0.90); // High threshold

        // The classifier is expected to give this message roughly 0.85 confidence.
        let commits = vec![commit("abc", "fix: memory leak", "dev", 123, 1, 5, 2)];

        let examples = extractor
            .extract_training_data(&commits, "test-repo")
            .unwrap();

        // The exact outcome depends on the classifier; at most the single input survives.
        assert!(examples.len() <= 1);
    }

    #[test]
    fn test_default_extractor() {
        let extractor = TrainingDataExtractor::default();
        assert_eq!(extractor.min_confidence, 0.75);
    }
}