organizational_intelligence_plugin/
classifier.rs

1// Rule-based defect classifier
2// Phase 1: Heuristic-based classification with confidence scores and explanations
3// Phase 2: ML classifier integration with hybrid approach
4// Toyota Way: Start simple, collect data for Phase 2 ML
5
6use serde::{Deserialize, Serialize};
7use std::fmt;
8use tracing::debug;
9
10/// Defect categories based on research literature
11/// See specification Section 2.2.3 and Section 5.2 (Expanded Taxonomy)
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
13pub enum DefectCategory {
14    // General defect categories (10)
15    MemorySafety,
16    ConcurrencyBugs,
17    LogicErrors,
18    ApiMisuse,
19    ResourceLeaks,
20    TypeErrors,
21    ConfigurationErrors,
22    SecurityVulnerabilities,
23    PerformanceIssues,
24    IntegrationFailures,
25    // Transpiler-specific categories (8)
26    OperatorPrecedence,
27    TypeAnnotationGaps,
28    StdlibMapping,
29    ASTTransform,
30    ComprehensionBugs,
31    IteratorChain,
32    OwnershipBorrow,
33    TraitBounds,
34}
35
36impl DefectCategory {
37    /// Get human-readable name for the category
38    pub fn as_str(&self) -> &'static str {
39        match self {
40            Self::MemorySafety => "Memory Safety",
41            Self::ConcurrencyBugs => "Concurrency Bugs",
42            Self::LogicErrors => "Logic Errors",
43            Self::ApiMisuse => "API Misuse",
44            Self::ResourceLeaks => "Resource Leaks",
45            Self::TypeErrors => "Type Errors",
46            Self::ConfigurationErrors => "Configuration Errors",
47            Self::SecurityVulnerabilities => "Security Vulnerabilities",
48            Self::PerformanceIssues => "Performance Issues",
49            Self::IntegrationFailures => "Integration Failures",
50            Self::OperatorPrecedence => "Operator Precedence",
51            Self::TypeAnnotationGaps => "Type Annotation Gaps",
52            Self::StdlibMapping => "Stdlib Mapping",
53            Self::ASTTransform => "AST Transform",
54            Self::ComprehensionBugs => "Comprehension Bugs",
55            Self::IteratorChain => "Iterator Chain",
56            Self::OwnershipBorrow => "Ownership/Borrow",
57            Self::TraitBounds => "Trait Bounds",
58        }
59    }
60}
61
62impl fmt::Display for DefectCategory {
63    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64        // Use the enum variant name for serialization compatibility
65        match self {
66            Self::MemorySafety => write!(f, "MemorySafety"),
67            Self::ConcurrencyBugs => write!(f, "ConcurrencyBugs"),
68            Self::LogicErrors => write!(f, "LogicErrors"),
69            Self::ApiMisuse => write!(f, "ApiMisuse"),
70            Self::ResourceLeaks => write!(f, "ResourceLeaks"),
71            Self::TypeErrors => write!(f, "TypeErrors"),
72            Self::ConfigurationErrors => write!(f, "ConfigurationErrors"),
73            Self::SecurityVulnerabilities => write!(f, "SecurityVulnerabilities"),
74            Self::PerformanceIssues => write!(f, "PerformanceIssues"),
75            Self::IntegrationFailures => write!(f, "IntegrationFailures"),
76            Self::OperatorPrecedence => write!(f, "OperatorPrecedence"),
77            Self::TypeAnnotationGaps => write!(f, "TypeAnnotationGaps"),
78            Self::StdlibMapping => write!(f, "StdlibMapping"),
79            Self::ASTTransform => write!(f, "ASTTransform"),
80            Self::ComprehensionBugs => write!(f, "ComprehensionBugs"),
81            Self::IteratorChain => write!(f, "IteratorChain"),
82            Self::OwnershipBorrow => write!(f, "OwnershipBorrow"),
83            Self::TraitBounds => write!(f, "TraitBounds"),
84        }
85    }
86}
87
88/// Classification result with confidence and explanation
89/// Following Toyota Way: Respect for People - provide explanations for learning
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct Classification {
92    pub category: DefectCategory,
93    pub confidence: f32, // 0.0 to 1.0
94    pub explanation: String,
95    pub matched_patterns: Vec<String>,
96}
97
98/// Multi-label classification result with top-N categories
99/// Implements Section 5.3 of nlp-models-techniques-spec.md
100#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct MultiLabelClassification {
102    pub categories: Vec<(DefectCategory, f32)>, // (category, confidence) sorted by confidence
103    pub primary_category: DefectCategory,
104    pub primary_confidence: f32,
105    pub matched_patterns: Vec<String>,
106}
107
108/// Pattern matching rule
109#[derive(Debug, Clone)]
110struct Rule {
111    category: DefectCategory,
112    patterns: Vec<&'static str>,
113    confidence: f32,
114}
115
116/// Rule-based classifier
117/// Phase 1: Simple pattern matching on commit messages
118/// Phase 2: Will evolve to ML-based with user feedback
119pub struct RuleBasedClassifier {
120    rules: Vec<Rule>,
121}
122
123impl RuleBasedClassifier {
124    /// Create a new rule-based classifier with predefined patterns
125    ///
126    /// # Examples
127    /// ```
128    /// use organizational_intelligence_plugin::classifier::RuleBasedClassifier;
129    ///
130    /// let classifier = RuleBasedClassifier::new();
131    /// ```
132    pub fn new() -> Self {
133        let rules = vec![
134            // Memory Safety patterns
135            Rule {
136                category: DefectCategory::MemorySafety,
137                patterns: vec![
138                    "use after free",
139                    "use-after-free",
140                    "null pointer",
141                    "nullptr",
142                    "buffer overflow",
143                    "memory leak",
144                    "dangling pointer",
145                    "double free",
146                    "heap corruption",
147                ],
148                confidence: 0.85,
149            },
150            // Concurrency patterns
151            Rule {
152                category: DefectCategory::ConcurrencyBugs,
153                patterns: vec![
154                    "race condition",
155                    "data race",
156                    "deadlock",
157                    "atomicity",
158                    "thread safety",
159                    "concurrent",
160                    "synchronization",
161                    "mutex",
162                    "lock contention",
163                ],
164                confidence: 0.80,
165            },
166            // Security patterns
167            Rule {
168                category: DefectCategory::SecurityVulnerabilities,
169                patterns: vec![
170                    "sql injection",
171                    "xss",
172                    "cross-site scripting",
173                    "authentication",
174                    "authorization",
175                    "security",
176                    "vulnerability",
177                    "exploit",
178                    "cve-",
179                ],
180                confidence: 0.90,
181            },
182            // Logic Error patterns
183            Rule {
184                category: DefectCategory::LogicErrors,
185                patterns: vec![
186                    "off by one",
187                    "off-by-one",
188                    "boundary",
189                    "incorrect logic",
190                    "wrong condition",
191                    "infinite loop",
192                ],
193                confidence: 0.70,
194            },
195            // API Misuse patterns
196            Rule {
197                category: DefectCategory::ApiMisuse,
198                patterns: vec![
199                    "api misuse",
200                    "wrong parameter",
201                    "incorrect usage",
202                    "missing error handling",
203                    "unchecked error",
204                ],
205                confidence: 0.75,
206            },
207            // Resource Leak patterns
208            Rule {
209                category: DefectCategory::ResourceLeaks,
210                patterns: vec![
211                    "resource leak",
212                    "file handle leak",
213                    "connection leak",
214                    "not closed",
215                    "forgot to close",
216                ],
217                confidence: 0.80,
218            },
219            // Type Error patterns
220            Rule {
221                category: DefectCategory::TypeErrors,
222                patterns: vec![
223                    "type error",
224                    "type mismatch",
225                    "casting error",
226                    "serialization",
227                    "deserialization",
228                ],
229                confidence: 0.75,
230            },
231            // Configuration patterns
232            Rule {
233                category: DefectCategory::ConfigurationErrors,
234                patterns: vec![
235                    "configuration",
236                    "config",
237                    "environment variable",
238                    "missing env",
239                    "settings",
240                ],
241                confidence: 0.70,
242            },
243            // Performance patterns
244            Rule {
245                category: DefectCategory::PerformanceIssues,
246                patterns: vec![
247                    "performance",
248                    "slow",
249                    "inefficient",
250                    "n+1 query",
251                    "optimization",
252                ],
253                confidence: 0.65,
254            },
255            // Integration patterns
256            Rule {
257                category: DefectCategory::IntegrationFailures,
258                patterns: vec![
259                    "integration",
260                    "compatibility",
261                    "version mismatch",
262                    "breaking change",
263                    "api change",
264                ],
265                confidence: 0.70,
266            },
267            // Transpiler-specific patterns
268            // Operator Precedence patterns
269            Rule {
270                category: DefectCategory::OperatorPrecedence,
271                patterns: vec![
272                    "operator precedence",
273                    "parentheses",
274                    "parse expression",
275                    "order of operations",
276                    "precedence",
277                    "expression parsing",
278                    "operator order",
279                ],
280                confidence: 0.80,
281            },
282            // Type Annotation Gaps patterns
283            Rule {
284                category: DefectCategory::TypeAnnotationGaps,
285                patterns: vec![
286                    "type annotation",
287                    "type hint",
288                    "unsupported type",
289                    "generic type",
290                    "type parameter",
291                    "annotation",
292                    "typing",
293                ],
294                confidence: 0.75,
295            },
296            // Stdlib Mapping patterns
297            Rule {
298                category: DefectCategory::StdlibMapping,
299                patterns: vec![
300                    "stdlib",
301                    "standard library",
302                    "python to rust",
303                    "library mapping",
304                    "std::",
305                    "builtin",
306                    "library conversion",
307                ],
308                confidence: 0.80,
309            },
310            // AST Transform patterns
311            Rule {
312                category: DefectCategory::ASTTransform,
313                patterns: vec![
314                    "ast",
315                    "hir",
316                    "codegen",
317                    "transform",
318                    "syntax tree",
319                    "ast node",
320                    "tree traversal",
321                ],
322                confidence: 0.85,
323            },
324            // Comprehension Bugs patterns
325            Rule {
326                category: DefectCategory::ComprehensionBugs,
327                patterns: vec![
328                    "comprehension",
329                    "list comprehension",
330                    "dict comprehension",
331                    "set comprehension",
332                    "generator",
333                    "generator expression",
334                ],
335                confidence: 0.80,
336            },
337            // Iterator Chain patterns
338            Rule {
339                category: DefectCategory::IteratorChain,
340                patterns: vec![
341                    "iterator",
342                    "into_iter",
343                    ".map(",
344                    ".filter(",
345                    ".chain(",
346                    "iterator chain",
347                    "iter method",
348                ],
349                confidence: 0.80,
350            },
351            // Ownership/Borrow patterns
352            Rule {
353                category: DefectCategory::OwnershipBorrow,
354                patterns: vec![
355                    "ownership",
356                    "borrow",
357                    "lifetime",
358                    "borrow checker",
359                    "move",
360                    "borrowed value",
361                    "lifetime parameter",
362                ],
363                confidence: 0.85,
364            },
365            // Trait Bounds patterns
366            Rule {
367                category: DefectCategory::TraitBounds,
368                patterns: vec![
369                    "trait bound",
370                    "generic constraint",
371                    "where clause",
372                    "impl trait",
373                    "trait constraint",
374                    "bound",
375                ],
376                confidence: 0.80,
377            },
378        ];
379
380        Self { rules }
381    }
382
383    /// Classify a defect based on commit message
384    ///
385    /// # Arguments
386    /// * `message` - Commit message text
387    ///
388    /// # Returns
389    /// * `Some(Classification)` if patterns match
390    /// * `None` if no patterns match (not a defect fix)
391    ///
392    /// # Examples
393    /// ```
394    /// use organizational_intelligence_plugin::classifier::RuleBasedClassifier;
395    ///
396    /// let classifier = RuleBasedClassifier::new();
397    /// let result = classifier.classify_from_message("fix: null pointer dereference");
398    ///
399    /// assert!(result.is_some());
400    /// ```
401    pub fn classify_from_message(&self, message: &str) -> Option<Classification> {
402        let message_lower = message.to_lowercase();
403
404        debug!("Classifying message: {}", message);
405
406        let mut matches: Vec<(DefectCategory, f32, Vec<String>)> = Vec::new();
407
408        // Check each rule
409        for rule in &self.rules {
410            let mut matched_patterns = Vec::new();
411
412            for pattern in &rule.patterns {
413                if message_lower.contains(pattern) {
414                    matched_patterns.push(pattern.to_string());
415                }
416            }
417
418            if !matched_patterns.is_empty() {
419                // Boost confidence if multiple patterns match
420                let confidence_boost = (matched_patterns.len() - 1) as f32 * 0.05;
421                let adjusted_confidence = (rule.confidence + confidence_boost).min(0.95);
422
423                matches.push((rule.category, adjusted_confidence, matched_patterns));
424            }
425        }
426
427        if matches.is_empty() {
428            debug!("No patterns matched for message");
429            return None;
430        }
431
432        // Sort by confidence (highest first)
433        matches.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
434
435        // Take the highest confidence match
436        let (category, confidence, matched_patterns) = matches.into_iter().next().unwrap();
437
438        let explanation = format!(
439            "Classified as '{}' based on patterns: {}. Confidence: {:.0}%",
440            category.as_str(),
441            matched_patterns.join(", "),
442            confidence * 100.0
443        );
444
445        debug!(
446            "Classification: {:?} with confidence {}",
447            category, confidence
448        );
449
450        Some(Classification {
451            category,
452            confidence,
453            explanation,
454            matched_patterns,
455        })
456    }
457
458    /// Classify a defect with multi-label support (top-N categories)
459    ///
460    /// Returns top-N categories that match patterns above the confidence threshold.
461    /// Implements Section 5.3 Multi-Label Classification from nlp-models-techniques-spec.md
462    ///
463    /// # Arguments
464    /// * `message` - Commit message text
465    /// * `top_n` - Maximum number of categories to return (default 3)
466    /// * `min_confidence` - Minimum confidence threshold (default 0.60)
467    ///
468    /// # Returns
469    /// * `Some(MultiLabelClassification)` if patterns match
470    /// * `None` if no patterns match
471    ///
472    /// # Examples
473    /// ```
474    /// use organizational_intelligence_plugin::classifier::RuleBasedClassifier;
475    ///
476    /// let classifier = RuleBasedClassifier::new();
477    /// let result = classifier.classify_multi_label(
478    ///     "fix: null pointer in ast transform",
479    ///     3,
480    ///     0.60
481    /// );
482    ///
483    /// assert!(result.is_some());
484    /// let classification = result.unwrap();
485    /// assert!(classification.categories.len() >= 1);
486    /// assert!(classification.categories.len() <= 3);
487    /// ```
488    pub fn classify_multi_label(
489        &self,
490        message: &str,
491        top_n: usize,
492        min_confidence: f32,
493    ) -> Option<MultiLabelClassification> {
494        let message_lower = message.to_lowercase();
495
496        debug!(
497            "Multi-label classifying message: {} (top_n={}, min_confidence={})",
498            message, top_n, min_confidence
499        );
500
501        let mut matches: Vec<(DefectCategory, f32, Vec<String>)> = Vec::new();
502
503        // Check each rule
504        for rule in &self.rules {
505            let mut matched_patterns = Vec::new();
506
507            for pattern in &rule.patterns {
508                if message_lower.contains(pattern) {
509                    matched_patterns.push(pattern.to_string());
510                }
511            }
512
513            if !matched_patterns.is_empty() {
514                // Boost confidence if multiple patterns match
515                let confidence_boost = (matched_patterns.len() - 1) as f32 * 0.05;
516                let adjusted_confidence = (rule.confidence + confidence_boost).min(0.95);
517
518                matches.push((rule.category, adjusted_confidence, matched_patterns));
519            }
520        }
521
522        if matches.is_empty() {
523            debug!("No patterns matched for multi-label classification");
524            return None;
525        }
526
527        // Sort by confidence (highest first)
528        matches.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
529
530        // Filter by minimum confidence and take top-N
531        let filtered_matches: Vec<(DefectCategory, f32, Vec<String>)> = matches
532            .into_iter()
533            .filter(|(_, confidence, _)| *confidence >= min_confidence)
534            .take(top_n)
535            .collect();
536
537        if filtered_matches.is_empty() {
538            debug!("No matches above confidence threshold {}", min_confidence);
539            return None;
540        }
541
542        // Extract categories and confidence scores
543        let categories: Vec<(DefectCategory, f32)> = filtered_matches
544            .iter()
545            .map(|(cat, conf, _)| (*cat, *conf))
546            .collect();
547
548        // Primary category is the highest confidence
549        let (primary_category, primary_confidence) = categories[0];
550
551        // Collect all unique matched patterns
552        let mut all_matched_patterns: Vec<String> = Vec::new();
553        for (_, _, patterns) in &filtered_matches {
554            for pattern in patterns {
555                if !all_matched_patterns.contains(pattern) {
556                    all_matched_patterns.push(pattern.clone());
557                }
558            }
559        }
560
561        debug!(
562            "Multi-label classification: {} categories, primary: {:?} ({})",
563            categories.len(),
564            primary_category,
565            primary_confidence
566        );
567
568        Some(MultiLabelClassification {
569            categories,
570            primary_category,
571            primary_confidence,
572            matched_patterns: all_matched_patterns,
573        })
574    }
575}
576
577impl Default for RuleBasedClassifier {
578    fn default() -> Self {
579        Self::new()
580    }
581}
582
583/// Hybrid classifier combining rule-based and ML approaches
584///
585/// NLP-010: Integrates trained ML models with fallback to rule-based classification
586/// Implements three-tier architecture from nlp-models-techniques-spec.md:
587/// - Tier 1: Rule-based (fast, <10ms)
588/// - Tier 2: TF-IDF + Random Forest (medium, <100ms)
589/// - Tier 3: Transformer models (future work)
590pub enum HybridClassifier {
591    /// Rule-based only (Tier 1)
592    RuleBased(RuleBasedClassifier),
593    /// ML model with rule-based fallback (Tier 2 + Tier 1)
594    Hybrid {
595        ml_model: Box<crate::ml_trainer::TrainedModel>,
596        fallback: RuleBasedClassifier,
597        confidence_threshold: f32,
598    },
599}
600
601impl HybridClassifier {
602    /// Create a new rule-based classifier (Tier 1 only)
603    ///
604    /// # Examples
605    /// ```
606    /// use organizational_intelligence_plugin::classifier::HybridClassifier;
607    ///
608    /// let classifier = HybridClassifier::new_rule_based();
609    /// ```
610    pub fn new_rule_based() -> Self {
611        Self::RuleBased(RuleBasedClassifier::new())
612    }
613
614    /// Load a trained ML model with rule-based fallback (Tier 2 + Tier 1)
615    ///
616    /// # Arguments
617    /// * `model` - Trained ML model
618    /// * `confidence_threshold` - Minimum confidence to use ML prediction (default: 0.60)
619    ///
620    /// # Examples
621    /// ```no_run
622    /// # use organizational_intelligence_plugin::classifier::HybridClassifier;
623    /// # use organizational_intelligence_plugin::ml_trainer::MLTrainer;
624    /// # fn example(model: organizational_intelligence_plugin::ml_trainer::TrainedModel) {
625    /// let classifier = HybridClassifier::new_hybrid(model, 0.65);
626    /// # }
627    /// ```
628    pub fn new_hybrid(
629        ml_model: crate::ml_trainer::TrainedModel,
630        confidence_threshold: f32,
631    ) -> Self {
632        Self::Hybrid {
633            ml_model: Box::new(ml_model),
634            fallback: RuleBasedClassifier::new(),
635            confidence_threshold,
636        }
637    }
638
639    /// Classify a commit message
640    ///
641    /// Uses ML model if available and confident, otherwise falls back to rule-based.
642    ///
643    /// # Arguments
644    /// * `message` - Commit message to classify
645    ///
646    /// # Returns
647    /// * `Some(Classification)` if a category is detected
648    /// * `None` if no patterns match
649    ///
650    /// # Examples
651    /// ```
652    /// use organizational_intelligence_plugin::classifier::HybridClassifier;
653    ///
654    /// let classifier = HybridClassifier::new_rule_based();
655    /// if let Some(result) = classifier.classify_from_message("fix: null pointer") {
656    ///     println!("Category: {:?}, Confidence: {:.2}", result.category, result.confidence);
657    /// }
658    /// ```
659    pub fn classify_from_message(&self, message: &str) -> Option<Classification> {
660        match self {
661            Self::RuleBased(classifier) => classifier.classify_from_message(message),
662            Self::Hybrid {
663                ml_model,
664                fallback,
665                confidence_threshold,
666            } => {
667                // Try ML model first
668                if let Ok(Some((category, confidence))) = ml_model.predict(message) {
669                    if confidence >= *confidence_threshold {
670                        return Some(Classification {
671                            category,
672                            confidence,
673                            explanation: format!("ML prediction (confidence: {:.2})", confidence),
674                            matched_patterns: vec!["ML-based classification".to_string()],
675                        });
676                    }
677                }
678
679                // Fall back to rule-based
680                fallback.classify_from_message(message)
681            }
682        }
683    }
684
685    /// Classify with multiple labels
686    ///
687    /// # Arguments
688    /// * `message` - Commit message to classify
689    /// * `top_n` - Maximum number of categories to return
690    /// * `min_confidence` - Minimum confidence threshold (0.0 to 1.0)
691    ///
692    /// # Returns
693    /// * `Ok(MultiLabelClassification)` with top-N categories
694    ///
695    /// # Examples
696    /// ```
697    /// use organizational_intelligence_plugin::classifier::HybridClassifier;
698    ///
699    /// let classifier = HybridClassifier::new_rule_based();
700    /// let result = classifier.classify_multi_label("fix: null pointer in parser", 3, 0.60).unwrap();
701    /// println!("Primary: {:?} ({:.2})", result.primary_category, result.primary_confidence);
702    /// ```
703    pub fn classify_multi_label(
704        &self,
705        message: &str,
706        top_n: usize,
707        min_confidence: f32,
708    ) -> anyhow::Result<MultiLabelClassification> {
709        match self {
710            Self::RuleBased(classifier) => classifier
711                .classify_multi_label(message, top_n, min_confidence)
712                .ok_or_else(|| anyhow::anyhow!("No classification found")),
713            Self::Hybrid {
714                ml_model, fallback, ..
715            } => {
716                // Try ML model for multi-label
717                if let Ok(predictions) = ml_model.predict_top_n(message, top_n) {
718                    if !predictions.is_empty() {
719                        let filtered: Vec<(DefectCategory, f32)> = predictions
720                            .into_iter()
721                            .filter(|(_, conf)| *conf >= min_confidence)
722                            .collect();
723
724                        if !filtered.is_empty() {
725                            return Ok(MultiLabelClassification {
726                                primary_category: filtered[0].0,
727                                primary_confidence: filtered[0].1,
728                                categories: filtered.clone(),
729                                matched_patterns: vec!["ML-based classification".to_string()],
730                            });
731                        }
732                    }
733                }
734
735                // Fall back to rule-based
736                fallback
737                    .classify_multi_label(message, top_n, min_confidence)
738                    .ok_or_else(|| anyhow::anyhow!("No classification found"))
739            }
740        }
741    }
742}
743
744impl Default for HybridClassifier {
745    fn default() -> Self {
746        Self::new_rule_based()
747    }
748}
749
750#[cfg(test)]
751mod tests {
752    use super::*;
753
754    #[test]
755    fn test_classifier_creation() {
756        let _classifier = RuleBasedClassifier::new();
757    }
758
759    #[test]
760    fn test_all_categories_covered() {
761        let classifier = RuleBasedClassifier::new();
762
763        // Verify we have rules for all 18 categories (10 general + 8 transpiler)
764        let mut categories_covered = std::collections::HashSet::new();
765        for rule in &classifier.rules {
766            categories_covered.insert(rule.category);
767        }
768
769        assert_eq!(
770            categories_covered.len(),
771            18,
772            "Should have rules for all 18 categories (10 general + 8 transpiler)"
773        );
774    }
775
776    #[test]
777    fn test_pattern_matching() {
778        let classifier = RuleBasedClassifier::new();
779
780        let test_cases = vec![
781            ("fix: use-after-free bug", DefectCategory::MemorySafety),
782            ("fix: race condition", DefectCategory::ConcurrencyBugs),
783            (
784                "security: prevent SQL injection",
785                DefectCategory::SecurityVulnerabilities,
786            ),
787        ];
788
789        for (message, expected_category) in test_cases {
790            let result = classifier.classify_from_message(message);
791            assert!(result.is_some(), "Should classify: {}", message);
792            assert_eq!(result.unwrap().category, expected_category);
793        }
794    }
795
796    #[test]
797    fn test_non_defect_returns_none() {
798        let classifier = RuleBasedClassifier::new();
799
800        let non_defect_messages = vec![
801            "docs: update README",
802            "chore: bump version",
803            "feat: add new feature",
804            "refactor: simplify code",
805        ];
806
807        for message in non_defect_messages {
808            let result = classifier.classify_from_message(message);
809            assert!(
810                result.is_none(),
811                "Should not classify as defect: {}",
812                message
813            );
814        }
815    }
816
817    #[test]
818    fn test_defect_category_as_str() {
819        // General categories
820        assert_eq!(DefectCategory::MemorySafety.as_str(), "Memory Safety");
821        assert_eq!(DefectCategory::ConcurrencyBugs.as_str(), "Concurrency Bugs");
822        assert_eq!(DefectCategory::LogicErrors.as_str(), "Logic Errors");
823        assert_eq!(DefectCategory::ApiMisuse.as_str(), "API Misuse");
824        assert_eq!(DefectCategory::ResourceLeaks.as_str(), "Resource Leaks");
825        assert_eq!(DefectCategory::TypeErrors.as_str(), "Type Errors");
826        assert_eq!(
827            DefectCategory::ConfigurationErrors.as_str(),
828            "Configuration Errors"
829        );
830        assert_eq!(
831            DefectCategory::SecurityVulnerabilities.as_str(),
832            "Security Vulnerabilities"
833        );
834        assert_eq!(
835            DefectCategory::PerformanceIssues.as_str(),
836            "Performance Issues"
837        );
838        assert_eq!(
839            DefectCategory::IntegrationFailures.as_str(),
840            "Integration Failures"
841        );
842        // Transpiler categories
843        assert_eq!(
844            DefectCategory::OperatorPrecedence.as_str(),
845            "Operator Precedence"
846        );
847        assert_eq!(
848            DefectCategory::TypeAnnotationGaps.as_str(),
849            "Type Annotation Gaps"
850        );
851        assert_eq!(DefectCategory::StdlibMapping.as_str(), "Stdlib Mapping");
852        assert_eq!(DefectCategory::ASTTransform.as_str(), "AST Transform");
853        assert_eq!(
854            DefectCategory::ComprehensionBugs.as_str(),
855            "Comprehension Bugs"
856        );
857        assert_eq!(DefectCategory::IteratorChain.as_str(), "Iterator Chain");
858        assert_eq!(DefectCategory::OwnershipBorrow.as_str(), "Ownership/Borrow");
859        assert_eq!(DefectCategory::TraitBounds.as_str(), "Trait Bounds");
860    }
861
862    #[test]
863    fn test_defect_category_display() {
864        // General categories
865        assert_eq!(format!("{}", DefectCategory::MemorySafety), "MemorySafety");
866        assert_eq!(
867            format!("{}", DefectCategory::ConcurrencyBugs),
868            "ConcurrencyBugs"
869        );
870        assert_eq!(format!("{}", DefectCategory::LogicErrors), "LogicErrors");
871        assert_eq!(format!("{}", DefectCategory::ApiMisuse), "ApiMisuse");
872        assert_eq!(
873            format!("{}", DefectCategory::ResourceLeaks),
874            "ResourceLeaks"
875        );
876        assert_eq!(format!("{}", DefectCategory::TypeErrors), "TypeErrors");
877        assert_eq!(
878            format!("{}", DefectCategory::ConfigurationErrors),
879            "ConfigurationErrors"
880        );
881        assert_eq!(
882            format!("{}", DefectCategory::SecurityVulnerabilities),
883            "SecurityVulnerabilities"
884        );
885        assert_eq!(
886            format!("{}", DefectCategory::PerformanceIssues),
887            "PerformanceIssues"
888        );
889        assert_eq!(
890            format!("{}", DefectCategory::IntegrationFailures),
891            "IntegrationFailures"
892        );
893        // Transpiler categories
894        assert_eq!(
895            format!("{}", DefectCategory::OperatorPrecedence),
896            "OperatorPrecedence"
897        );
898        assert_eq!(
899            format!("{}", DefectCategory::TypeAnnotationGaps),
900            "TypeAnnotationGaps"
901        );
902        assert_eq!(
903            format!("{}", DefectCategory::StdlibMapping),
904            "StdlibMapping"
905        );
906        assert_eq!(format!("{}", DefectCategory::ASTTransform), "ASTTransform");
907        assert_eq!(
908            format!("{}", DefectCategory::ComprehensionBugs),
909            "ComprehensionBugs"
910        );
911        assert_eq!(
912            format!("{}", DefectCategory::IteratorChain),
913            "IteratorChain"
914        );
915        assert_eq!(
916            format!("{}", DefectCategory::OwnershipBorrow),
917            "OwnershipBorrow"
918        );
919        assert_eq!(format!("{}", DefectCategory::TraitBounds), "TraitBounds");
920    }
921
922    #[test]
923    fn test_default_constructor() {
924        let classifier = RuleBasedClassifier::default();
925        assert_eq!(classifier.rules.len(), 18);
926    }
927
928    #[test]
929    fn test_empty_message() {
930        let classifier = RuleBasedClassifier::new();
931        let result = classifier.classify_from_message("");
932        assert!(result.is_none());
933    }
934
935    #[test]
936    fn test_case_insensitive_matching() {
937        let classifier = RuleBasedClassifier::new();
938
939        let result = classifier.classify_from_message("Fix: NULL POINTER dereference");
940        assert!(result.is_some());
941        assert_eq!(result.unwrap().category, DefectCategory::MemorySafety);
942    }
943
944    #[test]
945    fn test_multiple_patterns_boost_confidence() {
946        let classifier = RuleBasedClassifier::new();
947
948        // Message with multiple memory safety patterns
949        let result = classifier
950            .classify_from_message("fix: null pointer and buffer overflow")
951            .unwrap();
952
953        assert_eq!(result.category, DefectCategory::MemorySafety);
954        // Base confidence 0.85 + 0.05 boost for 2nd pattern = 0.90
955        assert!(result.confidence >= 0.85);
956        assert_eq!(result.matched_patterns.len(), 2);
957    }
958
959    #[test]
960    fn test_confidence_capped_at_95_percent() {
961        let classifier = RuleBasedClassifier::new();
962
963        // Message with many security patterns to exceed 0.95 cap
964        let result = classifier
965            .classify_from_message(
966                "security vulnerability exploit with sql injection and xss and cve-2024-1234",
967            )
968            .unwrap();
969
970        assert_eq!(result.category, DefectCategory::SecurityVulnerabilities);
971        assert!(result.confidence <= 0.95);
972    }
973
974    #[test]
975    fn test_highest_confidence_wins() {
976        let classifier = RuleBasedClassifier::new();
977
978        // "security" has higher confidence (0.90) than "performance" (0.65)
979        let result = classifier
980            .classify_from_message("fix security and performance issues")
981            .unwrap();
982
983        assert_eq!(result.category, DefectCategory::SecurityVulnerabilities);
984    }
985
986    #[test]
987    fn test_all_categories_classifiable() {
988        let classifier = RuleBasedClassifier::new();
989
990        let test_cases = vec![
991            // General categories
992            ("null pointer bug", DefectCategory::MemorySafety),
993            ("race condition fix", DefectCategory::ConcurrencyBugs),
994            ("off by one error", DefectCategory::LogicErrors),
995            ("api misuse fix", DefectCategory::ApiMisuse),
996            ("resource leak fix", DefectCategory::ResourceLeaks),
997            ("type error fix", DefectCategory::TypeErrors),
998            ("configuration bug", DefectCategory::ConfigurationErrors),
999            ("security fix", DefectCategory::SecurityVulnerabilities),
1000            ("performance fix", DefectCategory::PerformanceIssues),
1001            ("integration failure", DefectCategory::IntegrationFailures),
1002            // Transpiler categories
1003            (
1004                "fix operator precedence issue",
1005                DefectCategory::OperatorPrecedence,
1006            ),
1007            (
1008                "type annotation not supported",
1009                DefectCategory::TypeAnnotationGaps,
1010            ),
1011            ("stdlib mapping bug", DefectCategory::StdlibMapping),
1012            ("ast transform error", DefectCategory::ASTTransform),
1013            ("list comprehension bug", DefectCategory::ComprehensionBugs),
1014            ("iterator chain issue", DefectCategory::IteratorChain),
1015            ("ownership error", DefectCategory::OwnershipBorrow),
1016            ("trait bound issue", DefectCategory::TraitBounds),
1017        ];
1018
1019        for (message, expected_category) in test_cases {
1020            let result = classifier.classify_from_message(message);
1021            assert!(result.is_some(), "Should classify: {}", message);
1022            assert_eq!(
1023                result.unwrap().category,
1024                expected_category,
1025                "Failed for: {}",
1026                message
1027            );
1028        }
1029    }
1030
1031    #[test]
1032    fn test_classification_struct_fields() {
1033        let classifier = RuleBasedClassifier::new();
1034        let result = classifier
1035            .classify_from_message("fix: deadlock in mutex")
1036            .unwrap();
1037
1038        assert_eq!(result.category, DefectCategory::ConcurrencyBugs);
1039        assert!(result.confidence > 0.0 && result.confidence <= 1.0);
1040        assert!(!result.explanation.is_empty());
1041        assert!(!result.matched_patterns.is_empty());
1042    }
1043
1044    #[test]
1045    fn test_explanation_format() {
1046        let classifier = RuleBasedClassifier::new();
1047        let result = classifier
1048            .classify_from_message("fix: sql injection vulnerability")
1049            .unwrap();
1050
1051        assert!(result.explanation.contains("Security Vulnerabilities"));
1052        assert!(result.explanation.contains("sql injection"));
1053        assert!(result.explanation.contains("Confidence:"));
1054        assert!(result.explanation.contains("%"));
1055    }
1056
1057    #[test]
1058    fn test_matched_patterns_populated() {
1059        let classifier = RuleBasedClassifier::new();
1060        let result = classifier
1061            .classify_from_message("fix: double free and memory leak")
1062            .unwrap();
1063
1064        assert_eq!(result.matched_patterns.len(), 2);
1065        assert!(result.matched_patterns.contains(&"double free".to_string()));
1066        assert!(result.matched_patterns.contains(&"memory leak".to_string()));
1067    }
1068
1069    #[test]
1070    fn test_transpiler_operator_precedence_classification() {
1071        let classifier = RuleBasedClassifier::new();
1072
1073        let test_cases = vec![
1074            "fix: operator precedence bug in expression parser",
1075            "fix: incorrect parentheses handling",
1076            "fix: parse expression order of operations",
1077        ];
1078
1079        for message in test_cases {
1080            let result = classifier.classify_from_message(message);
1081            assert!(result.is_some(), "Should classify: {}", message);
1082            assert_eq!(
1083                result.unwrap().category,
1084                DefectCategory::OperatorPrecedence,
1085                "Failed for: {}",
1086                message
1087            );
1088        }
1089    }
1090
1091    #[test]
1092    fn test_transpiler_type_annotation_classification() {
1093        let classifier = RuleBasedClassifier::new();
1094
1095        let result = classifier
1096            .classify_from_message("fix: type annotation gap in generic type")
1097            .unwrap();
1098
1099        assert_eq!(result.category, DefectCategory::TypeAnnotationGaps);
1100        assert!(result.matched_patterns.len() >= 2);
1101    }
1102
1103    #[test]
1104    fn test_transpiler_ownership_classification() {
1105        let classifier = RuleBasedClassifier::new();
1106
1107        let test_cases = vec![
1108            "fix: borrow checker error in iterator",
1109            "fix: lifetime parameter issue",
1110            "fix: ownership move bug",
1111        ];
1112
1113        for message in test_cases {
1114            let result = classifier.classify_from_message(message);
1115            assert!(result.is_some(), "Should classify: {}", message);
1116            assert_eq!(
1117                result.unwrap().category,
1118                DefectCategory::OwnershipBorrow,
1119                "Failed for: {}",
1120                message
1121            );
1122        }
1123    }
1124
1125    #[test]
1126    fn test_transpiler_comprehension_classification() {
1127        let classifier = RuleBasedClassifier::new();
1128
1129        let result = classifier
1130            .classify_from_message("fix: dict comprehension generation bug")
1131            .unwrap();
1132
1133        assert_eq!(result.category, DefectCategory::ComprehensionBugs);
1134        assert!(result.confidence >= 0.80);
1135    }
1136
1137    #[test]
1138    fn test_transpiler_iterator_chain_classification() {
1139        let classifier = RuleBasedClassifier::new();
1140
1141        let result = classifier
1142            .classify_from_message("fix: .map( and .filter( iterator chain issue")
1143            .unwrap();
1144
1145        assert_eq!(result.category, DefectCategory::IteratorChain);
1146        assert!(result.matched_patterns.len() >= 2);
1147    }
1148
1149    #[test]
1150    fn test_transpiler_ast_transform_classification() {
1151        let classifier = RuleBasedClassifier::new();
1152
1153        let result = classifier
1154            .classify_from_message("fix: ast node transform in codegen")
1155            .unwrap();
1156
1157        assert_eq!(result.category, DefectCategory::ASTTransform);
1158        assert!(result.confidence >= 0.85);
1159    }
1160
1161    #[test]
1162    fn test_transpiler_stdlib_mapping_classification() {
1163        let classifier = RuleBasedClassifier::new();
1164
1165        let result = classifier
1166            .classify_from_message("fix: stdlib mapping from python to rust")
1167            .unwrap();
1168
1169        assert_eq!(result.category, DefectCategory::StdlibMapping);
1170    }
1171
1172    #[test]
1173    fn test_transpiler_trait_bounds_classification() {
1174        let classifier = RuleBasedClassifier::new();
1175
1176        let result = classifier
1177            .classify_from_message("fix: trait bound issue in where clause")
1178            .unwrap();
1179
1180        assert_eq!(result.category, DefectCategory::TraitBounds);
1181        assert!(result.matched_patterns.len() >= 2);
1182    }
1183
1184    // Multi-label classification tests
1185
1186    #[test]
1187    fn test_multi_label_basic() {
1188        let classifier = RuleBasedClassifier::new();
1189
1190        // Message that matches multiple categories
1191        let result = classifier
1192            .classify_multi_label("fix: null pointer in ast transform", 3, 0.60)
1193            .unwrap();
1194
1195        assert!(!result.categories.is_empty());
1196        assert!(result.categories.len() <= 3);
1197        assert_eq!(result.primary_category, result.categories[0].0);
1198        assert_eq!(result.primary_confidence, result.categories[0].1);
1199    }
1200
1201    #[test]
1202    fn test_multi_label_multiple_categories() {
1203        let classifier = RuleBasedClassifier::new();
1204
1205        // Message with patterns from multiple categories
1206        let result = classifier
1207            .classify_multi_label(
1208                "fix: memory leak and security vulnerability in ast transform",
1209                3,
1210                0.60,
1211            )
1212            .unwrap();
1213
1214        // Should detect at least 2 categories (MemorySafety, SecurityVulnerabilities)
1215        assert!(result.categories.len() >= 2);
1216
1217        // Verify categories are sorted by confidence
1218        for i in 0..result.categories.len() - 1 {
1219            assert!(result.categories[i].1 >= result.categories[i + 1].1);
1220        }
1221    }
1222
1223    #[test]
1224    fn test_multi_label_confidence_threshold() {
1225        let classifier = RuleBasedClassifier::new();
1226
1227        let message = "fix: memory leak";
1228
1229        // High threshold should return fewer results
1230        let result_high = classifier.classify_multi_label(message, 5, 0.90);
1231
1232        // Low threshold should return more results
1233        let result_low = classifier.classify_multi_label(message, 5, 0.60).unwrap();
1234
1235        if let Some(high) = result_high {
1236            assert!(high.categories.len() <= result_low.categories.len());
1237        }
1238
1239        // All returned categories should meet minimum confidence
1240        for (_, confidence) in &result_low.categories {
1241            assert!(*confidence >= 0.60);
1242        }
1243    }
1244
1245    #[test]
1246    fn test_multi_label_top_n_limiting() {
1247        let classifier = RuleBasedClassifier::new();
1248
1249        // Message that matches many categories
1250        let message = "fix: security memory performance integration";
1251
1252        let result_top_1 = classifier.classify_multi_label(message, 1, 0.60).unwrap();
1253        let result_top_3 = classifier.classify_multi_label(message, 3, 0.60).unwrap();
1254
1255        assert_eq!(result_top_1.categories.len(), 1);
1256        assert!(result_top_3.categories.len() <= 3);
1257        assert!(result_top_3.categories.len() >= result_top_1.categories.len());
1258    }
1259
1260    #[test]
1261    fn test_multi_label_single_category() {
1262        let classifier = RuleBasedClassifier::new();
1263
1264        // Message that only matches one category clearly
1265        let result = classifier
1266            .classify_multi_label("fix: deadlock in mutex", 3, 0.60)
1267            .unwrap();
1268
1269        assert_eq!(result.categories.len(), 1);
1270        assert_eq!(result.primary_category, DefectCategory::ConcurrencyBugs);
1271    }
1272
1273    #[test]
1274    fn test_multi_label_no_match() {
1275        let classifier = RuleBasedClassifier::new();
1276
1277        let result = classifier.classify_multi_label("docs: update README", 3, 0.60);
1278
1279        assert!(result.is_none());
1280    }
1281
1282    #[test]
1283    fn test_multi_label_all_patterns_collected() {
1284        let classifier = RuleBasedClassifier::new();
1285
1286        let result = classifier
1287            .classify_multi_label("fix: memory leak and buffer overflow", 3, 0.60)
1288            .unwrap();
1289
1290        // Should collect patterns from MemorySafety category
1291        assert!(result.matched_patterns.contains(&"memory leak".to_string()));
1292        assert!(result
1293            .matched_patterns
1294            .contains(&"buffer overflow".to_string()));
1295    }
1296
1297    #[test]
1298    fn test_multi_label_primary_is_highest_confidence() {
1299        let classifier = RuleBasedClassifier::new();
1300
1301        let result = classifier
1302            .classify_multi_label("fix: security and performance", 3, 0.60)
1303            .unwrap();
1304
1305        // Primary should be the first (highest confidence) category
1306        assert_eq!(result.primary_category, result.categories[0].0);
1307        assert_eq!(result.primary_confidence, result.categories[0].1);
1308
1309        // Security has higher confidence (0.90) than Performance (0.65)
1310        assert_eq!(
1311            result.primary_category,
1312            DefectCategory::SecurityVulnerabilities
1313        );
1314    }
1315
1316    #[test]
1317    fn test_multi_label_confidence_boost() {
1318        let classifier = RuleBasedClassifier::new();
1319
1320        // Multiple patterns should boost confidence
1321        let result = classifier
1322            .classify_multi_label("fix: null pointer and buffer overflow", 3, 0.60)
1323            .unwrap();
1324
1325        // Should detect MemorySafety with confidence boost (2 patterns)
1326        assert_eq!(result.primary_category, DefectCategory::MemorySafety);
1327        assert!(result.primary_confidence > 0.85); // Base confidence + boost
1328    }
1329
1330    #[test]
1331    fn test_multi_label_struct_serialization() {
1332        let classification = MultiLabelClassification {
1333            categories: vec![
1334                (DefectCategory::MemorySafety, 0.90),
1335                (DefectCategory::ConcurrencyBugs, 0.75),
1336            ],
1337            primary_category: DefectCategory::MemorySafety,
1338            primary_confidence: 0.90,
1339            matched_patterns: vec!["memory leak".to_string()],
1340        };
1341
1342        let json = serde_json::to_string(&classification).unwrap();
1343        let deserialized: MultiLabelClassification = serde_json::from_str(&json).unwrap();
1344
1345        assert_eq!(
1346            classification.categories.len(),
1347            deserialized.categories.len()
1348        );
1349        assert_eq!(
1350            classification.primary_category,
1351            deserialized.primary_category
1352        );
1353    }
1354
1355    #[test]
1356    fn test_multi_label_zero_top_n() {
1357        let classifier = RuleBasedClassifier::new();
1358
1359        // top_n=0 should return None (no results)
1360        let result = classifier.classify_multi_label("fix: memory leak", 0, 0.60);
1361
1362        assert!(result.is_none());
1363    }
1364
1365    #[test]
1366    fn test_multi_label_very_high_threshold() {
1367        let classifier = RuleBasedClassifier::new();
1368
1369        // Threshold above all confidences should return None
1370        let result = classifier.classify_multi_label("fix: memory leak", 3, 0.99);
1371
1372        assert!(result.is_none());
1373    }
1374
1375    // ===== HybridClassifier Tests =====
1376
1377    #[test]
1378    fn test_hybrid_classifier_rule_based_variant() {
1379        let classifier = HybridClassifier::new_rule_based();
1380
1381        // Should work exactly like RuleBasedClassifier
1382        let result = classifier.classify_from_message("fix: null pointer dereference");
1383        assert!(result.is_some());
1384
1385        let classification = result.unwrap();
1386        assert_eq!(classification.category, DefectCategory::MemorySafety);
1387    }
1388
1389    #[test]
1390    fn test_hybrid_classifier_default() {
1391        let classifier = HybridClassifier::default();
1392
1393        // Default should be rule-based
1394        let result = classifier.classify_from_message("fix: race condition");
1395        assert!(result.is_some());
1396    }
1397
1398    #[test]
1399    fn test_hybrid_classifier_multi_label_rule_based() {
1400        let classifier = HybridClassifier::new_rule_based();
1401
1402        let result = classifier
1403            .classify_multi_label("fix: memory leak and null pointer", 3, 0.60)
1404            .unwrap();
1405
1406        assert!(!result.categories.is_empty());
1407        assert_eq!(result.primary_category, result.categories[0].0);
1408    }
1409
1410    #[test]
1411    fn test_hybrid_classifier_no_match() {
1412        let classifier = HybridClassifier::new_rule_based();
1413
1414        // Non-defect message
1415        let result = classifier.classify_from_message("docs: update README");
1416        assert!(result.is_none());
1417    }
1418
1419    #[test]
1420    fn test_hybrid_classifier_multi_label_no_match() {
1421        let classifier = HybridClassifier::new_rule_based();
1422
1423        // Should return error when no classification found
1424        let result = classifier.classify_multi_label("docs: update README", 3, 0.60);
1425        assert!(result.is_err());
1426    }
1427
1428    #[test]
1429    fn test_hybrid_classifier_various_categories() {
1430        let classifier = HybridClassifier::new_rule_based();
1431
1432        // Test multiple categories
1433        let test_cases = vec![
1434            (
1435                "fix: operator precedence bug",
1436                DefectCategory::OperatorPrecedence,
1437            ),
1438            (
1439                "fix: type annotation missing",
1440                DefectCategory::TypeAnnotationGaps,
1441            ),
1442            ("fix: stdlib mapping error", DefectCategory::StdlibMapping),
1443            ("fix: ast transform issue", DefectCategory::ASTTransform),
1444            ("fix: comprehension bug", DefectCategory::ComprehensionBugs),
1445            ("fix: iterator chain error", DefectCategory::IteratorChain),
1446            ("fix: ownership violation", DefectCategory::OwnershipBorrow),
1447            ("fix: trait bound issue", DefectCategory::TraitBounds),
1448        ];
1449
1450        for (message, expected_category) in test_cases {
1451            let result = classifier.classify_from_message(message);
1452            assert!(result.is_some(), "Failed to classify: {}", message);
1453            assert_eq!(
1454                result.unwrap().category,
1455                expected_category,
1456                "Wrong category for: {}",
1457                message
1458            );
1459        }
1460    }
1461}