use serde::{Deserialize, Serialize};
use std::fmt;
use tracing::debug;
/// Category a defect-related message can be classified into.
///
/// The first ten variants are general-purpose defect classes; the remaining
/// eight (from `OperatorPrecedence` onward) are transpiler-oriented classes.
/// `as_str` yields a spaced, human-readable label, whereas the `Display`
/// implementation prints the bare variant identifier.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DefectCategory {
/// Memory-safety defects; rule keywords include "use after free", "null pointer", "buffer overflow".
MemorySafety,
/// Concurrency defects; rule keywords include "race condition", "deadlock", "mutex".
ConcurrencyBugs,
/// Logic errors; rule keywords include "off by one", "wrong condition", "infinite loop".
LogicErrors,
/// Incorrect API usage; rule keywords include "api misuse", "wrong parameter", "unchecked error".
ApiMisuse,
/// Leaked resources; rule keywords include "resource leak", "connection leak", "not closed".
ResourceLeaks,
/// Type-related errors; rule keywords include "type mismatch", "casting error", "serialization".
TypeErrors,
/// Configuration problems; rule keywords include "config", "environment variable", "settings".
ConfigurationErrors,
/// Security issues; rule keywords include "sql injection", "xss", "cve-".
SecurityVulnerabilities,
/// Performance problems; rule keywords include "slow", "inefficient", "n+1 query".
PerformanceIssues,
/// Integration problems; rule keywords include "version mismatch", "breaking change", "api change".
IntegrationFailures,
/// Operator-precedence defects; rule keywords include "operator precedence", "parentheses", "order of operations".
OperatorPrecedence,
/// Type-annotation gaps; rule keywords include "type annotation", "type hint", "unsupported type".
TypeAnnotationGaps,
/// Standard-library mapping defects; rule keywords include "stdlib", "standard library", "python to rust".
StdlibMapping,
/// AST transformation defects; rule keywords include "ast", "hir", "codegen", "syntax tree".
ASTTransform,
/// Comprehension defects; rule keywords include "list comprehension", "dict comprehension", "generator expression".
ComprehensionBugs,
/// Iterator-chain defects; rule keywords include "into_iter", ".map(", ".filter(".
IteratorChain,
/// Ownership and borrowing defects; rule keywords include "ownership", "borrow checker", "lifetime".
OwnershipBorrow,
/// Trait-bound defects; rule keywords include "trait bound", "where clause", "impl trait".
TraitBounds,
}
impl DefectCategory {
pub fn as_str(&self) -> &'static str {
match self {
Self::MemorySafety => "Memory Safety",
Self::ConcurrencyBugs => "Concurrency Bugs",
Self::LogicErrors => "Logic Errors",
Self::ApiMisuse => "API Misuse",
Self::ResourceLeaks => "Resource Leaks",
Self::TypeErrors => "Type Errors",
Self::ConfigurationErrors => "Configuration Errors",
Self::SecurityVulnerabilities => "Security Vulnerabilities",
Self::PerformanceIssues => "Performance Issues",
Self::IntegrationFailures => "Integration Failures",
Self::OperatorPrecedence => "Operator Precedence",
Self::TypeAnnotationGaps => "Type Annotation Gaps",
Self::StdlibMapping => "Stdlib Mapping",
Self::ASTTransform => "AST Transform",
Self::ComprehensionBugs => "Comprehension Bugs",
Self::IteratorChain => "Iterator Chain",
Self::OwnershipBorrow => "Ownership/Borrow",
Self::TraitBounds => "Trait Bounds",
}
}
}
impl fmt::Display for DefectCategory {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::MemorySafety => write!(f, "MemorySafety"),
Self::ConcurrencyBugs => write!(f, "ConcurrencyBugs"),
Self::LogicErrors => write!(f, "LogicErrors"),
Self::ApiMisuse => write!(f, "ApiMisuse"),
Self::ResourceLeaks => write!(f, "ResourceLeaks"),
Self::TypeErrors => write!(f, "TypeErrors"),
Self::ConfigurationErrors => write!(f, "ConfigurationErrors"),
Self::SecurityVulnerabilities => write!(f, "SecurityVulnerabilities"),
Self::PerformanceIssues => write!(f, "PerformanceIssues"),
Self::IntegrationFailures => write!(f, "IntegrationFailures"),
Self::OperatorPrecedence => write!(f, "OperatorPrecedence"),
Self::TypeAnnotationGaps => write!(f, "TypeAnnotationGaps"),
Self::StdlibMapping => write!(f, "StdlibMapping"),
Self::ASTTransform => write!(f, "ASTTransform"),
Self::ComprehensionBugs => write!(f, "ComprehensionBugs"),
Self::IteratorChain => write!(f, "IteratorChain"),
Self::OwnershipBorrow => write!(f, "OwnershipBorrow"),
Self::TraitBounds => write!(f, "TraitBounds"),
}
}
}
/// Result of classifying a message into a single `DefectCategory`.
///
/// Fields:
/// - `category`: the winning category.
/// - `confidence`: score for that category. For rule-based results this is the
///   rule's base confidence plus 0.05 per additional matched pattern, capped at
///   0.95; for ML results it is the value returned by the model.
/// - `explanation`: human-readable summary of why the category was chosen.
/// - `matched_patterns`: the rule patterns found in the message, or the single
///   placeholder `"ML-based classification"` when the ML model decided.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Classification {
pub category: DefectCategory,
pub confidence: f32, pub explanation: String,
pub matched_patterns: Vec<String>,
}
/// Result of classifying a message into several `DefectCategory` values at once.
///
/// Fields:
/// - `categories`: `(category, confidence)` pairs that passed the caller's
///   confidence threshold; rule-based results are ordered by descending
///   confidence and truncated to the requested `top_n`.
/// - `primary_category` / `primary_confidence`: copy of the first entry of
///   `categories`.
/// - `matched_patterns`: de-duplicated rule patterns of all retained categories,
///   or the single placeholder `"ML-based classification"` for ML results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiLabelClassification {
pub categories: Vec<(DefectCategory, f32)>, pub primary_category: DefectCategory,
pub primary_confidence: f32,
pub matched_patterns: Vec<String>,
}
/// One keyword rule of the rule-based classifier: a message that contains any
/// of `patterns` becomes a candidate for `category`.
#[derive(Debug, Clone)]
struct Rule {
/// Category this rule votes for.
category: DefectCategory,
/// Lower-case substrings that are searched for in the lower-cased message.
patterns: Vec<&'static str>,
/// Base confidence used when exactly one pattern matches.
confidence: f32,
}
/// Keyword-driven classifier that maps free-text messages to a
/// `DefectCategory` by searching for the substrings listed in its rules.
pub struct RuleBasedClassifier {
/// Rule table built by `new`; evaluated in order for every message.
rules: Vec<Rule>,
}
impl RuleBasedClassifier {
pub fn new() -> Self {
let rules = vec![
Rule {
category: DefectCategory::MemorySafety,
patterns: vec![
"use after free",
"use-after-free",
"null pointer",
"nullptr",
"buffer overflow",
"memory leak",
"dangling pointer",
"double free",
"heap corruption",
],
confidence: 0.85,
},
Rule {
category: DefectCategory::ConcurrencyBugs,
patterns: vec![
"race condition",
"data race",
"deadlock",
"atomicity",
"thread safety",
"concurrent",
"synchronization",
"mutex",
"lock contention",
],
confidence: 0.80,
},
Rule {
category: DefectCategory::SecurityVulnerabilities,
patterns: vec![
"sql injection",
"xss",
"cross-site scripting",
"authentication",
"authorization",
"security",
"vulnerability",
"exploit",
"cve-",
],
confidence: 0.90,
},
Rule {
category: DefectCategory::LogicErrors,
patterns: vec![
"off by one",
"off-by-one",
"boundary",
"incorrect logic",
"wrong condition",
"infinite loop",
],
confidence: 0.70,
},
Rule {
category: DefectCategory::ApiMisuse,
patterns: vec![
"api misuse",
"wrong parameter",
"incorrect usage",
"missing error handling",
"unchecked error",
],
confidence: 0.75,
},
Rule {
category: DefectCategory::ResourceLeaks,
patterns: vec![
"resource leak",
"file handle leak",
"connection leak",
"not closed",
"forgot to close",
],
confidence: 0.80,
},
Rule {
category: DefectCategory::TypeErrors,
patterns: vec![
"type error",
"type mismatch",
"casting error",
"serialization",
"deserialization",
],
confidence: 0.75,
},
Rule {
category: DefectCategory::ConfigurationErrors,
patterns: vec![
"configuration",
"config",
"environment variable",
"missing env",
"settings",
],
confidence: 0.70,
},
Rule {
category: DefectCategory::PerformanceIssues,
patterns: vec![
"performance",
"slow",
"inefficient",
"n+1 query",
"optimization",
],
confidence: 0.65,
},
Rule {
category: DefectCategory::IntegrationFailures,
patterns: vec![
"integration",
"compatibility",
"version mismatch",
"breaking change",
"api change",
],
confidence: 0.70,
},
Rule {
category: DefectCategory::OperatorPrecedence,
patterns: vec![
"operator precedence",
"parentheses",
"parse expression",
"order of operations",
"precedence",
"expression parsing",
"operator order",
],
confidence: 0.80,
},
Rule {
category: DefectCategory::TypeAnnotationGaps,
patterns: vec![
"type annotation",
"type hint",
"unsupported type",
"generic type",
"type parameter",
"annotation",
"typing",
],
confidence: 0.75,
},
Rule {
category: DefectCategory::StdlibMapping,
patterns: vec![
"stdlib",
"standard library",
"python to rust",
"library mapping",
"std::",
"builtin",
"library conversion",
],
confidence: 0.80,
},
Rule {
category: DefectCategory::ASTTransform,
patterns: vec![
"ast",
"hir",
"codegen",
"transform",
"syntax tree",
"ast node",
"tree traversal",
],
confidence: 0.85,
},
Rule {
category: DefectCategory::ComprehensionBugs,
patterns: vec![
"comprehension",
"list comprehension",
"dict comprehension",
"set comprehension",
"generator",
"generator expression",
],
confidence: 0.80,
},
Rule {
category: DefectCategory::IteratorChain,
patterns: vec![
"iterator",
"into_iter",
".map(",
".filter(",
".chain(",
"iterator chain",
"iter method",
],
confidence: 0.80,
},
Rule {
category: DefectCategory::OwnershipBorrow,
patterns: vec![
"ownership",
"borrow",
"lifetime",
"borrow checker",
"move",
"borrowed value",
"lifetime parameter",
],
confidence: 0.85,
},
Rule {
category: DefectCategory::TraitBounds,
patterns: vec![
"trait bound",
"generic constraint",
"where clause",
"impl trait",
"trait constraint",
"bound",
],
confidence: 0.80,
},
];
Self { rules }
}
pub fn classify_from_message(&self, message: &str) -> Option<Classification> {
let message_lower = message.to_lowercase();
debug!("Classifying message: {}", message);
let mut matches: Vec<(DefectCategory, f32, Vec<String>)> = Vec::new();
for rule in &self.rules {
let mut matched_patterns = Vec::new();
for pattern in &rule.patterns {
if message_lower.contains(pattern) {
matched_patterns.push(pattern.to_string());
}
}
if !matched_patterns.is_empty() {
let confidence_boost = (matched_patterns.len() - 1) as f32 * 0.05;
let adjusted_confidence = (rule.confidence + confidence_boost).min(0.95);
matches.push((rule.category, adjusted_confidence, matched_patterns));
}
}
if matches.is_empty() {
debug!("No patterns matched for message");
return None;
}
matches.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
let (category, confidence, matched_patterns) = matches.into_iter().next().unwrap();
let explanation = format!(
"Classified as '{}' based on patterns: {}. Confidence: {:.0}%",
category.as_str(),
matched_patterns.join(", "),
confidence * 100.0
);
debug!(
"Classification: {:?} with confidence {}",
category, confidence
);
Some(Classification {
category,
confidence,
explanation,
matched_patterns,
})
}
pub fn classify_multi_label(
&self,
message: &str,
top_n: usize,
min_confidence: f32,
) -> Option<MultiLabelClassification> {
let message_lower = message.to_lowercase();
debug!(
"Multi-label classifying message: {} (top_n={}, min_confidence={})",
message, top_n, min_confidence
);
let mut matches: Vec<(DefectCategory, f32, Vec<String>)> = Vec::new();
for rule in &self.rules {
let mut matched_patterns = Vec::new();
for pattern in &rule.patterns {
if message_lower.contains(pattern) {
matched_patterns.push(pattern.to_string());
}
}
if !matched_patterns.is_empty() {
let confidence_boost = (matched_patterns.len() - 1) as f32 * 0.05;
let adjusted_confidence = (rule.confidence + confidence_boost).min(0.95);
matches.push((rule.category, adjusted_confidence, matched_patterns));
}
}
if matches.is_empty() {
debug!("No patterns matched for multi-label classification");
return None;
}
matches.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
let filtered_matches: Vec<(DefectCategory, f32, Vec<String>)> = matches
.into_iter()
.filter(|(_, confidence, _)| *confidence >= min_confidence)
.take(top_n)
.collect();
if filtered_matches.is_empty() {
debug!("No matches above confidence threshold {}", min_confidence);
return None;
}
let categories: Vec<(DefectCategory, f32)> = filtered_matches
.iter()
.map(|(cat, conf, _)| (*cat, *conf))
.collect();
let (primary_category, primary_confidence) = categories[0];
let mut all_matched_patterns: Vec<String> = Vec::new();
for (_, _, patterns) in &filtered_matches {
for pattern in patterns {
if !all_matched_patterns.contains(pattern) {
all_matched_patterns.push(pattern.clone());
}
}
}
debug!(
"Multi-label classification: {} categories, primary: {:?} ({})",
categories.len(),
primary_category,
primary_confidence
);
Some(MultiLabelClassification {
categories,
primary_category,
primary_confidence,
matched_patterns: all_matched_patterns,
})
}
}
impl Default for RuleBasedClassifier {
fn default() -> Self {
Self::new()
}
}
/// Classifier front-end that is either purely rule-based or tries a trained
/// ML model first and falls back to the rules.
pub enum HybridClassifier {
/// Rule-based classification only.
RuleBased(RuleBasedClassifier),
/// ML model with a rule-based fallback.
Hybrid {
/// Trained model that is consulted first (boxed).
ml_model: Box<crate::ml_trainer::TrainedModel>,
/// Rule-based classifier used when the model does not produce a usable result.
fallback: RuleBasedClassifier,
/// Minimum ML confidence accepted by `classify_from_message`; not
/// consulted by `classify_multi_label`, which uses its own
/// `min_confidence` argument.
confidence_threshold: f32,
},
}
impl HybridClassifier {
    /// Creates a classifier that only uses the built-in keyword rules.
    pub fn new_rule_based() -> Self {
        Self::RuleBased(RuleBasedClassifier::new())
    }

    /// Creates a classifier that consults `ml_model` first and falls back to
    /// the keyword rules.
    ///
    /// `confidence_threshold` is the minimum ML confidence accepted by
    /// [`HybridClassifier::classify_from_message`].
    pub fn new_hybrid(
        ml_model: crate::ml_trainer::TrainedModel,
        confidence_threshold: f32,
    ) -> Self {
        Self::Hybrid {
            ml_model: Box::new(ml_model),
            fallback: RuleBasedClassifier::new(),
            confidence_threshold,
        }
    }

    /// Classifies `message` into a single category.
    ///
    /// In hybrid mode the ML prediction is used when the model returns one
    /// with confidence at or above the configured threshold. A model error, an
    /// empty prediction or a low-confidence prediction all fall through to the
    /// rule-based classifier. Returns `None` when nothing classifies the
    /// message.
    pub fn classify_from_message(&self, message: &str) -> Option<Classification> {
        match self {
            Self::RuleBased(classifier) => classifier.classify_from_message(message),
            Self::Hybrid {
                ml_model,
                fallback,
                confidence_threshold,
            } => match ml_model.predict(message) {
                Ok(Some((category, confidence))) if confidence >= *confidence_threshold => {
                    Some(Classification {
                        category,
                        confidence,
                        explanation: format!("ML prediction (confidence: {:.2})", confidence),
                        matched_patterns: vec!["ML-based classification".to_string()],
                    })
                }
                // Deliberate best-effort: any unusable ML outcome defers to the rules.
                _ => fallback.classify_from_message(message),
            },
        }
    }

    /// Classifies `message` into up to `top_n` categories whose confidence is
    /// at least `min_confidence`.
    ///
    /// In hybrid mode the ML model's top-`top_n` predictions are filtered by
    /// `min_confidence` and used if any remain; otherwise the rule-based
    /// classifier is consulted. The first entry of the returned `categories`
    /// is always the highest-confidence one and is repeated as the primary
    /// category.
    ///
    /// # Errors
    ///
    /// Returns an error when neither the ML model (if any) nor the rules
    /// produce a classification.
    pub fn classify_multi_label(
        &self,
        message: &str,
        top_n: usize,
        min_confidence: f32,
    ) -> anyhow::Result<MultiLabelClassification> {
        let rule_based = match self {
            Self::RuleBased(classifier) => classifier,
            Self::Hybrid {
                ml_model, fallback, ..
            } => {
                if let Ok(predictions) = ml_model.predict_top_n(message, top_n) {
                    let mut categories: Vec<(DefectCategory, f32)> = predictions
                        .into_iter()
                        .filter(|(_, conf)| *conf >= min_confidence)
                        .collect();
                    // NOTE(review): the ordering contract of `predict_top_n` is not
                    // visible from this module. A stable descending sort is a no-op
                    // if it already returns best-first, and otherwise guarantees
                    // that the primary category is the highest-confidence one.
                    categories.sort_by(|a, b| b.1.total_cmp(&a.1));
                    let primary = categories.first().copied();
                    if let Some((primary_category, primary_confidence)) = primary {
                        return Ok(MultiLabelClassification {
                            categories,
                            primary_category,
                            primary_confidence,
                            matched_patterns: vec!["ML-based classification".to_string()],
                        });
                    }
                }
                fallback
            }
        };
        rule_based
            .classify_multi_label(message, top_n, min_confidence)
            .ok_or_else(|| anyhow::anyhow!("No classification found"))
    }
}
impl Default for HybridClassifier {
fn default() -> Self {
Self::new_rule_based()
}
}
// Unit tests for the category labels, the rule-based classifier (single- and
// multi-label) and the rule-based variant of the hybrid classifier.
#[cfg(test)]
mod tests {
use super::*;
// Smoke test: constructing the classifier must not panic.
#[test]
fn test_classifier_creation() {
let _classifier = RuleBasedClassifier::new();
}
// The rule table must reference 18 distinct categories.
#[test]
fn test_all_categories_covered() {
let classifier = RuleBasedClassifier::new();
let mut categories_covered = std::collections::HashSet::new();
for rule in &classifier.rules {
categories_covered.insert(rule.category);
}
assert_eq!(
categories_covered.len(),
18,
"Should have rules for all 18 categories (10 general + 8 transpiler)"
);
}
// Representative messages map to the expected category.
#[test]
fn test_pattern_matching() {
let classifier = RuleBasedClassifier::new();
let test_cases = vec![
("fix: use-after-free bug", DefectCategory::MemorySafety),
("fix: race condition", DefectCategory::ConcurrencyBugs),
(
"security: prevent SQL injection",
DefectCategory::SecurityVulnerabilities,
),
];
for (message, expected_category) in test_cases {
let result = classifier.classify_from_message(message);
assert!(result.is_some(), "Should classify: {}", message);
assert_eq!(result.unwrap().category, expected_category);
}
}
// Messages that contain none of the rule patterns yield no classification.
#[test]
fn test_non_defect_returns_none() {
let classifier = RuleBasedClassifier::new();
let non_defect_messages = vec![
"docs: update README",
"chore: bump version",
"feat: add new feature",
"refactor: simplify code",
];
for message in non_defect_messages {
let result = classifier.classify_from_message(message);
assert!(
result.is_none(),
"Should not classify as defect: {}",
message
);
}
}
// `as_str` returns the spaced, human-readable label for every variant.
#[test]
fn test_defect_category_as_str() {
assert_eq!(DefectCategory::MemorySafety.as_str(), "Memory Safety");
assert_eq!(DefectCategory::ConcurrencyBugs.as_str(), "Concurrency Bugs");
assert_eq!(DefectCategory::LogicErrors.as_str(), "Logic Errors");
assert_eq!(DefectCategory::ApiMisuse.as_str(), "API Misuse");
assert_eq!(DefectCategory::ResourceLeaks.as_str(), "Resource Leaks");
assert_eq!(DefectCategory::TypeErrors.as_str(), "Type Errors");
assert_eq!(
DefectCategory::ConfigurationErrors.as_str(),
"Configuration Errors"
);
assert_eq!(
DefectCategory::SecurityVulnerabilities.as_str(),
"Security Vulnerabilities"
);
assert_eq!(
DefectCategory::PerformanceIssues.as_str(),
"Performance Issues"
);
assert_eq!(
DefectCategory::IntegrationFailures.as_str(),
"Integration Failures"
);
assert_eq!(
DefectCategory::OperatorPrecedence.as_str(),
"Operator Precedence"
);
assert_eq!(
DefectCategory::TypeAnnotationGaps.as_str(),
"Type Annotation Gaps"
);
assert_eq!(DefectCategory::StdlibMapping.as_str(), "Stdlib Mapping");
assert_eq!(DefectCategory::ASTTransform.as_str(), "AST Transform");
assert_eq!(
DefectCategory::ComprehensionBugs.as_str(),
"Comprehension Bugs"
);
assert_eq!(DefectCategory::IteratorChain.as_str(), "Iterator Chain");
assert_eq!(DefectCategory::OwnershipBorrow.as_str(), "Ownership/Borrow");
assert_eq!(DefectCategory::TraitBounds.as_str(), "Trait Bounds");
}
// `Display` prints the bare variant identifier for every variant.
#[test]
fn test_defect_category_display() {
assert_eq!(format!("{}", DefectCategory::MemorySafety), "MemorySafety");
assert_eq!(
format!("{}", DefectCategory::ConcurrencyBugs),
"ConcurrencyBugs"
);
assert_eq!(format!("{}", DefectCategory::LogicErrors), "LogicErrors");
assert_eq!(format!("{}", DefectCategory::ApiMisuse), "ApiMisuse");
assert_eq!(
format!("{}", DefectCategory::ResourceLeaks),
"ResourceLeaks"
);
assert_eq!(format!("{}", DefectCategory::TypeErrors), "TypeErrors");
assert_eq!(
format!("{}", DefectCategory::ConfigurationErrors),
"ConfigurationErrors"
);
assert_eq!(
format!("{}", DefectCategory::SecurityVulnerabilities),
"SecurityVulnerabilities"
);
assert_eq!(
format!("{}", DefectCategory::PerformanceIssues),
"PerformanceIssues"
);
assert_eq!(
format!("{}", DefectCategory::IntegrationFailures),
"IntegrationFailures"
);
assert_eq!(
format!("{}", DefectCategory::OperatorPrecedence),
"OperatorPrecedence"
);
assert_eq!(
format!("{}", DefectCategory::TypeAnnotationGaps),
"TypeAnnotationGaps"
);
assert_eq!(
format!("{}", DefectCategory::StdlibMapping),
"StdlibMapping"
);
assert_eq!(format!("{}", DefectCategory::ASTTransform), "ASTTransform");
assert_eq!(
format!("{}", DefectCategory::ComprehensionBugs),
"ComprehensionBugs"
);
assert_eq!(
format!("{}", DefectCategory::IteratorChain),
"IteratorChain"
);
assert_eq!(
format!("{}", DefectCategory::OwnershipBorrow),
"OwnershipBorrow"
);
assert_eq!(format!("{}", DefectCategory::TraitBounds), "TraitBounds");
}
// `Default` yields a classifier with the full 18-rule table.
#[test]
fn test_default_constructor() {
let classifier = RuleBasedClassifier::default();
assert_eq!(classifier.rules.len(), 18);
}
// An empty message matches nothing.
#[test]
fn test_empty_message() {
let classifier = RuleBasedClassifier::new();
let result = classifier.classify_from_message("");
assert!(result.is_none());
}
// Matching is case-insensitive: upper-case input still hits lower-case patterns.
#[test]
fn test_case_insensitive_matching() {
let classifier = RuleBasedClassifier::new();
let result = classifier.classify_from_message("Fix: NULL POINTER dereference");
assert!(result.is_some());
assert_eq!(result.unwrap().category, DefectCategory::MemorySafety);
}
// Two patterns of the same rule are both reported and never lower the base confidence.
#[test]
fn test_multiple_patterns_boost_confidence() {
let classifier = RuleBasedClassifier::new();
let result = classifier
.classify_from_message("fix: null pointer and buffer overflow")
.unwrap();
assert_eq!(result.category, DefectCategory::MemorySafety);
assert!(result.confidence >= 0.85);
assert_eq!(result.matched_patterns.len(), 2);
}
// Many matched patterns must not push the confidence above the 0.95 cap.
#[test]
fn test_confidence_capped_at_95_percent() {
let classifier = RuleBasedClassifier::new();
let result = classifier
.classify_from_message(
"security vulnerability exploit with sql injection and xss and cve-2024-1234",
)
.unwrap();
assert_eq!(result.category, DefectCategory::SecurityVulnerabilities);
assert!(result.confidence <= 0.95);
}
// When several rules match, the higher-confidence rule wins (security over performance).
#[test]
fn test_highest_confidence_wins() {
let classifier = RuleBasedClassifier::new();
let result = classifier
.classify_from_message("fix security and performance issues")
.unwrap();
assert_eq!(result.category, DefectCategory::SecurityVulnerabilities);
}
// One representative message per category, each expected to win its own category.
#[test]
fn test_all_categories_classifiable() {
let classifier = RuleBasedClassifier::new();
let test_cases = vec![
("null pointer bug", DefectCategory::MemorySafety),
("race condition fix", DefectCategory::ConcurrencyBugs),
("off by one error", DefectCategory::LogicErrors),
("api misuse fix", DefectCategory::ApiMisuse),
("resource leak fix", DefectCategory::ResourceLeaks),
("type error fix", DefectCategory::TypeErrors),
("configuration bug", DefectCategory::ConfigurationErrors),
("security fix", DefectCategory::SecurityVulnerabilities),
("performance fix", DefectCategory::PerformanceIssues),
("integration failure", DefectCategory::IntegrationFailures),
(
"fix operator precedence issue",
DefectCategory::OperatorPrecedence,
),
(
"type annotation not supported",
DefectCategory::TypeAnnotationGaps,
),
("stdlib mapping bug", DefectCategory::StdlibMapping),
("ast transform error", DefectCategory::ASTTransform),
("list comprehension bug", DefectCategory::ComprehensionBugs),
("iterator chain issue", DefectCategory::IteratorChain),
("ownership error", DefectCategory::OwnershipBorrow),
("trait bound issue", DefectCategory::TraitBounds),
];
for (message, expected_category) in test_cases {
let result = classifier.classify_from_message(message);
assert!(result.is_some(), "Should classify: {}", message);
assert_eq!(
result.unwrap().category,
expected_category,
"Failed for: {}",
message
);
}
}
// All fields of the returned `Classification` are populated and in range.
#[test]
fn test_classification_struct_fields() {
let classifier = RuleBasedClassifier::new();
let result = classifier
.classify_from_message("fix: deadlock in mutex")
.unwrap();
assert_eq!(result.category, DefectCategory::ConcurrencyBugs);
assert!(result.confidence > 0.0 && result.confidence <= 1.0);
assert!(!result.explanation.is_empty());
assert!(!result.matched_patterns.is_empty());
}
// The explanation names the category label, a matched pattern and a percentage.
#[test]
fn test_explanation_format() {
let classifier = RuleBasedClassifier::new();
let result = classifier
.classify_from_message("fix: sql injection vulnerability")
.unwrap();
assert!(result.explanation.contains("Security Vulnerabilities"));
assert!(result.explanation.contains("sql injection"));
assert!(result.explanation.contains("Confidence:"));
assert!(result.explanation.contains("%"));
}
// `matched_patterns` lists exactly the patterns found in the message.
#[test]
fn test_matched_patterns_populated() {
let classifier = RuleBasedClassifier::new();
let result = classifier
.classify_from_message("fix: double free and memory leak")
.unwrap();
assert_eq!(result.matched_patterns.len(), 2);
assert!(result.matched_patterns.contains(&"double free".to_string()));
assert!(result.matched_patterns.contains(&"memory leak".to_string()));
}
// Transpiler category: operator-precedence wording.
#[test]
fn test_transpiler_operator_precedence_classification() {
let classifier = RuleBasedClassifier::new();
let test_cases = vec![
"fix: operator precedence bug in expression parser",
"fix: incorrect parentheses handling",
"fix: parse expression order of operations",
];
for message in test_cases {
let result = classifier.classify_from_message(message);
assert!(result.is_some(), "Should classify: {}", message);
assert_eq!(
result.unwrap().category,
DefectCategory::OperatorPrecedence,
"Failed for: {}",
message
);
}
}
// Transpiler category: type-annotation wording with at least two pattern hits.
#[test]
fn test_transpiler_type_annotation_classification() {
let classifier = RuleBasedClassifier::new();
let result = classifier
.classify_from_message("fix: type annotation gap in generic type")
.unwrap();
assert_eq!(result.category, DefectCategory::TypeAnnotationGaps);
assert!(result.matched_patterns.len() >= 2);
}
// Transpiler category: ownership/borrow wording.
#[test]
fn test_transpiler_ownership_classification() {
let classifier = RuleBasedClassifier::new();
let test_cases = vec![
"fix: borrow checker error in iterator",
"fix: lifetime parameter issue",
"fix: ownership move bug",
];
for message in test_cases {
let result = classifier.classify_from_message(message);
assert!(result.is_some(), "Should classify: {}", message);
assert_eq!(
result.unwrap().category,
DefectCategory::OwnershipBorrow,
"Failed for: {}",
message
);
}
}
// Transpiler category: comprehension wording.
#[test]
fn test_transpiler_comprehension_classification() {
let classifier = RuleBasedClassifier::new();
let result = classifier
.classify_from_message("fix: dict comprehension generation bug")
.unwrap();
assert_eq!(result.category, DefectCategory::ComprehensionBugs);
assert!(result.confidence >= 0.80);
}
// Transpiler category: iterator-chain wording, including method-call fragments.
#[test]
fn test_transpiler_iterator_chain_classification() {
let classifier = RuleBasedClassifier::new();
let result = classifier
.classify_from_message("fix: .map( and .filter( iterator chain issue")
.unwrap();
assert_eq!(result.category, DefectCategory::IteratorChain);
assert!(result.matched_patterns.len() >= 2);
}
// Transpiler category: AST-transform wording.
#[test]
fn test_transpiler_ast_transform_classification() {
let classifier = RuleBasedClassifier::new();
let result = classifier
.classify_from_message("fix: ast node transform in codegen")
.unwrap();
assert_eq!(result.category, DefectCategory::ASTTransform);
assert!(result.confidence >= 0.85);
}
// Transpiler category: stdlib-mapping wording.
#[test]
fn test_transpiler_stdlib_mapping_classification() {
let classifier = RuleBasedClassifier::new();
let result = classifier
.classify_from_message("fix: stdlib mapping from python to rust")
.unwrap();
assert_eq!(result.category, DefectCategory::StdlibMapping);
}
// Transpiler category: trait-bound wording with at least two pattern hits.
#[test]
fn test_transpiler_trait_bounds_classification() {
let classifier = RuleBasedClassifier::new();
let result = classifier
.classify_from_message("fix: trait bound issue in where clause")
.unwrap();
assert_eq!(result.category, DefectCategory::TraitBounds);
assert!(result.matched_patterns.len() >= 2);
}
// Multi-label: result is non-empty, bounded by `top_n`, and the primary mirrors the first entry.
#[test]
fn test_multi_label_basic() {
let classifier = RuleBasedClassifier::new();
let result = classifier
.classify_multi_label("fix: null pointer in ast transform", 3, 0.60)
.unwrap();
assert!(!result.categories.is_empty());
assert!(result.categories.len() <= 3);
assert_eq!(result.primary_category, result.categories[0].0);
assert_eq!(result.primary_confidence, result.categories[0].1);
}
// Multi-label: several categories are returned in descending confidence order.
#[test]
fn test_multi_label_multiple_categories() {
let classifier = RuleBasedClassifier::new();
let result = classifier
.classify_multi_label(
"fix: memory leak and security vulnerability in ast transform",
3,
0.60,
)
.unwrap();
assert!(result.categories.len() >= 2);
for i in 0..result.categories.len() - 1 {
assert!(result.categories[i].1 >= result.categories[i + 1].1);
}
}
// Multi-label: a higher threshold never returns more categories than a lower one.
#[test]
fn test_multi_label_confidence_threshold() {
let classifier = RuleBasedClassifier::new();
let message = "fix: memory leak";
let result_high = classifier.classify_multi_label(message, 5, 0.90);
let result_low = classifier.classify_multi_label(message, 5, 0.60).unwrap();
if let Some(high) = result_high {
assert!(high.categories.len() <= result_low.categories.len());
}
for (_, confidence) in &result_low.categories {
assert!(*confidence >= 0.60);
}
}
// Multi-label: `top_n` limits the number of returned categories.
#[test]
fn test_multi_label_top_n_limiting() {
let classifier = RuleBasedClassifier::new();
let message = "fix: security memory performance integration";
let result_top_1 = classifier.classify_multi_label(message, 1, 0.60).unwrap();
let result_top_3 = classifier.classify_multi_label(message, 3, 0.60).unwrap();
assert_eq!(result_top_1.categories.len(), 1);
assert!(result_top_3.categories.len() <= 3);
assert!(result_top_3.categories.len() >= result_top_1.categories.len());
}
// Multi-label: a message hitting only one rule yields exactly one category.
#[test]
fn test_multi_label_single_category() {
let classifier = RuleBasedClassifier::new();
let result = classifier
.classify_multi_label("fix: deadlock in mutex", 3, 0.60)
.unwrap();
assert_eq!(result.categories.len(), 1);
assert_eq!(result.primary_category, DefectCategory::ConcurrencyBugs);
}
// Multi-label: no pattern hit means `None`.
#[test]
fn test_multi_label_no_match() {
let classifier = RuleBasedClassifier::new();
let result = classifier.classify_multi_label("docs: update README", 3, 0.60);
assert!(result.is_none());
}
// Multi-label: every matched pattern is reported in `matched_patterns`.
#[test]
fn test_multi_label_all_patterns_collected() {
let classifier = RuleBasedClassifier::new();
let result = classifier
.classify_multi_label("fix: memory leak and buffer overflow", 3, 0.60)
.unwrap();
assert!(result.matched_patterns.contains(&"memory leak".to_string()));
assert!(result
.matched_patterns
.contains(&"buffer overflow".to_string()));
}
// Multi-label: the primary category is the highest-confidence one.
#[test]
fn test_multi_label_primary_is_highest_confidence() {
let classifier = RuleBasedClassifier::new();
let result = classifier
.classify_multi_label("fix: security and performance", 3, 0.60)
.unwrap();
assert_eq!(result.primary_category, result.categories[0].0);
assert_eq!(result.primary_confidence, result.categories[0].1);
assert_eq!(
result.primary_category,
DefectCategory::SecurityVulnerabilities
);
}
// Multi-label: a second matched pattern raises the confidence above the 0.85 base.
#[test]
fn test_multi_label_confidence_boost() {
let classifier = RuleBasedClassifier::new();
let result = classifier
.classify_multi_label("fix: null pointer and buffer overflow", 3, 0.60)
.unwrap();
assert_eq!(result.primary_category, DefectCategory::MemorySafety);
assert!(result.primary_confidence > 0.85); }
// `MultiLabelClassification` survives a JSON round trip.
#[test]
fn test_multi_label_struct_serialization() {
let classification = MultiLabelClassification {
categories: vec![
(DefectCategory::MemorySafety, 0.90),
(DefectCategory::ConcurrencyBugs, 0.75),
],
primary_category: DefectCategory::MemorySafety,
primary_confidence: 0.90,
matched_patterns: vec!["memory leak".to_string()],
};
let json = serde_json::to_string(&classification).unwrap();
let deserialized: MultiLabelClassification = serde_json::from_str(&json).unwrap();
assert_eq!(
classification.categories.len(),
deserialized.categories.len()
);
assert_eq!(
classification.primary_category,
deserialized.primary_category
);
}
// Multi-label: `top_n == 0` keeps nothing and therefore returns `None`.
#[test]
fn test_multi_label_zero_top_n() {
let classifier = RuleBasedClassifier::new();
let result = classifier.classify_multi_label("fix: memory leak", 0, 0.60);
assert!(result.is_none());
}
// Multi-label: a threshold above the 0.95 cap can never be met.
#[test]
fn test_multi_label_very_high_threshold() {
let classifier = RuleBasedClassifier::new();
let result = classifier.classify_multi_label("fix: memory leak", 3, 0.99);
assert!(result.is_none());
}
// Hybrid (rule-based variant): delegates single-label classification to the rules.
#[test]
fn test_hybrid_classifier_rule_based_variant() {
let classifier = HybridClassifier::new_rule_based();
let result = classifier.classify_from_message("fix: null pointer dereference");
assert!(result.is_some());
let classification = result.unwrap();
assert_eq!(classification.category, DefectCategory::MemorySafety);
}
// Hybrid: the default instance is able to classify a known pattern.
#[test]
fn test_hybrid_classifier_default() {
let classifier = HybridClassifier::default();
let result = classifier.classify_from_message("fix: race condition");
assert!(result.is_some());
}
// Hybrid (rule-based variant): multi-label classification succeeds with `Ok`.
#[test]
fn test_hybrid_classifier_multi_label_rule_based() {
let classifier = HybridClassifier::new_rule_based();
let result = classifier
.classify_multi_label("fix: memory leak and null pointer", 3, 0.60)
.unwrap();
assert!(!result.categories.is_empty());
assert_eq!(result.primary_category, result.categories[0].0);
}
// Hybrid (rule-based variant): no pattern hit means `None`.
#[test]
fn test_hybrid_classifier_no_match() {
let classifier = HybridClassifier::new_rule_based();
let result = classifier.classify_from_message("docs: update README");
assert!(result.is_none());
}
// Hybrid (rule-based variant): multi-label without a match is reported as an error.
#[test]
fn test_hybrid_classifier_multi_label_no_match() {
let classifier = HybridClassifier::new_rule_based();
let result = classifier.classify_multi_label("docs: update README", 3, 0.60);
assert!(result.is_err());
}
// Hybrid (rule-based variant): each transpiler category is reachable.
#[test]
fn test_hybrid_classifier_various_categories() {
let classifier = HybridClassifier::new_rule_based();
let test_cases = vec![
(
"fix: operator precedence bug",
DefectCategory::OperatorPrecedence,
),
(
"fix: type annotation missing",
DefectCategory::TypeAnnotationGaps,
),
("fix: stdlib mapping error", DefectCategory::StdlibMapping),
("fix: ast transform issue", DefectCategory::ASTTransform),
("fix: comprehension bug", DefectCategory::ComprehensionBugs),
("fix: iterator chain error", DefectCategory::IteratorChain),
("fix: ownership violation", DefectCategory::OwnershipBorrow),
("fix: trait bound issue", DefectCategory::TraitBounds),
];
for (message, expected_category) in test_cases {
let result = classifier.classify_from_message(message);
assert!(result.is_some(), "Failed to classify: {}", message);
assert_eq!(
result.unwrap().category,
expected_category,
"Wrong category for: {}",
message
);
}
}
}