use crate::classifier::ErrorCategory;
use crate::moe_oracle::ExpertDomain;
use crate::training::{TrainingDataset, TrainingSample};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::Path;
/// Defect category label attached to each [`OipTrainingExample`].
///
/// The label is translated into this crate's own taxonomies by
/// [`OipDefectCategory::to_error_category`] and
/// [`OipDefectCategory::to_expert_domain`].
///
/// The serde impls are derived without any rename attribute, so each variant
/// is (de)serialized under its exact identifier; renaming a variant changes
/// the accepted JSON.
// NOTE(review): "OIP" presumably refers to the organizational-intelligence-plugin
// project whose training-data export this module reads — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum OipDefectCategory {
// NOTE(review): the first ten variants look like general-purpose defect
// classes — confirm against the OIP schema.
MemorySafety,
ConcurrencyBugs,
LogicErrors,
ApiMisuse,
ResourceLeaks,
TypeErrors,
ConfigurationErrors,
SecurityVulnerabilities,
PerformanceIssues,
IntegrationFailures,
// NOTE(review): the remaining variants look specific to code translation and
// Rust compilation issues — confirm against the OIP schema.
OperatorPrecedence,
TypeAnnotationGaps,
StdlibMapping,
ASTTransform,
ComprehensionBugs,
IteratorChain,
OwnershipBorrow,
TraitBounds,
}
impl OipDefectCategory {
    /// Maps this defect label onto the classifier's [`ErrorCategory`].
    ///
    /// Labels without a dedicated classifier category map to
    /// [`ErrorCategory::Other`].
    #[must_use]
    pub fn to_error_category(self) -> ErrorCategory {
        match self {
            // Labels with no matching classifier category.
            Self::ApiMisuse
            | Self::ComprehensionBugs
            | Self::ConcurrencyBugs
            | Self::IntegrationFailures
            | Self::IteratorChain
            | Self::LogicErrors
            | Self::OperatorPrecedence
            | Self::PerformanceIssues
            | Self::SecurityVulnerabilities => ErrorCategory::Other,

            Self::MemorySafety | Self::OwnershipBorrow => ErrorCategory::BorrowChecker,
            Self::TypeAnnotationGaps | Self::TypeErrors => ErrorCategory::TypeMismatch,
            Self::ASTTransform | Self::ConfigurationErrors | Self::StdlibMapping => {
                ErrorCategory::MissingImport
            }
            Self::ResourceLeaks => ErrorCategory::LifetimeError,
            Self::TraitBounds => ErrorCategory::TraitBound,
        }
    }

    /// Maps this defect label onto the mixture-of-experts [`ExpertDomain`]
    /// that should handle it.
    ///
    /// Every label is assigned to exactly one of the four domains used here.
    #[must_use]
    pub fn to_expert_domain(self) -> ExpertDomain {
        match self {
            Self::ConcurrencyBugs
            | Self::LogicErrors
            | Self::MemorySafety
            | Self::OperatorPrecedence
            | Self::OwnershipBorrow
            | Self::PerformanceIssues
            | Self::ResourceLeaks
            | Self::SecurityVulnerabilities => ExpertDomain::SyntaxBorrowing,

            Self::ASTTransform
            | Self::ConfigurationErrors
            | Self::IntegrationFailures
            | Self::StdlibMapping => ExpertDomain::ScopeResolution,

            Self::ApiMisuse | Self::ComprehensionBugs | Self::IteratorChain => {
                ExpertDomain::MethodField
            }

            Self::TraitBounds | Self::TypeAnnotationGaps | Self::TypeErrors => {
                ExpertDomain::TypeSystem
            }
        }
    }
}
/// One labelled example from an OIP training-data file.
///
/// Within this file only `message`, `label` and `confidence` are read by the
/// conversion and analysis helpers; the other fields are deserialized and
/// carried along unchanged.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OipTrainingExample {
/// Free-form text that the extractors in this module treat as a commit
/// message (searched for `error[E…]` codes and `fix:`-style markers).
pub message: String,
/// Defect category assigned to this example.
pub label: OipDefectCategory,
/// Score associated with `label`; averaged by `analyze_corpus`.
// NOTE(review): presumably a probability in 0.0..=1.0 — confirm.
pub confidence: f32,
// NOTE(review): presumably the VCS hash of the originating commit — confirm.
pub commit_hash: String,
// NOTE(review): presumably the commit author's identity (the test fixture
// uses an e-mail address) — confirm.
pub author: String,
// NOTE(review): presumably a Unix timestamp; unit (s vs ms) not shown here — confirm.
pub timestamp: i64,
// NOTE(review): the three counters below presumably describe the commit's
// diff size — confirm.
pub lines_added: usize,
pub lines_removed: usize,
pub files_changed: usize,
}
/// Top-level shape of an OIP training-data file: three lists of examples
/// stored under the keys `train`, `validation` and `test`.
///
/// The helpers in this module do not treat the splits differently; they
/// process all three together, in the order train, validation, test.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OipTrainingDataset {
/// Examples from the `train` split.
pub train: Vec<OipTrainingExample>,
/// Examples from the `validation` split.
pub validation: Vec<OipTrainingExample>,
/// Examples from the `test` split.
pub test: Vec<OipTrainingExample>,
}
/// Reads the file at `path` and parses its contents as a JSON-encoded
/// [`OipTrainingDataset`].
///
/// # Errors
///
/// Returns the underlying I/O error if the file cannot be read as a string.
/// A JSON parse failure is reported as an [`std::io::Error`] of kind
/// [`std::io::ErrorKind::InvalidData`] wrapping the parser's error.
pub fn load_oip_training_data(path: &Path) -> Result<OipTrainingDataset, std::io::Error> {
    let raw_json = fs::read_to_string(path)?;
    match serde_json::from_str(&raw_json) {
        Ok(dataset) => Ok(dataset),
        Err(parse_err) => Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            parse_err,
        )),
    }
}
/// Converts every example in `oip_data` into a [`TrainingSample`] and returns
/// them as a new [`TrainingDataset`].
///
/// The three splits are visited in the order train, validation, test. For each
/// example the label is mapped with `to_error_category`, and the error pattern
/// and fix text are both extracted from the example's message.
#[must_use]
pub fn convert_oip_to_depyler(oip_data: &OipTrainingDataset) -> TrainingDataset {
    let examples = oip_data
        .train
        .iter()
        .chain(&oip_data.validation)
        .chain(&oip_data.test);

    let mut converted = TrainingDataset::new();
    for example in examples {
        let category = example.label.to_error_category();
        let pattern = extract_error_pattern(&example.message);
        let fix_text = extract_fix_from_commit(&example.message);
        converted.add(TrainingSample::with_fix(&pattern, category, &fix_text));
    }
    converted
}
/// Derives a short, single-line error pattern from a commit message.
///
/// The following sources are tried in order:
///
/// 1. A compiler-style code such as `error[E0308]`. The rest of that line is
///    appended as `"<code>: <description>"`; if the line has no description
///    the code is returned on its own.
/// 2. A `fix:` marker (ASCII case-insensitive). The trimmed remainder of that
///    line is returned.
/// 3. The first line of the message, unmodified.
fn extract_error_pattern(message: &str) -> String {
    if let Some(start) = message.find("error[E") {
        if let Some(bracket) = message[start..].find(']') {
            let code_end = start + bracket + 1;
            let error_code = &message[start..code_end];

            // The text after the code normally starts with ": "; strip that
            // separator so the formatted result does not contain ": : ".
            let line_rest = message[code_end..].lines().next().unwrap_or("").trim();
            let description = line_rest.strip_prefix(':').unwrap_or(line_rest).trim_start();

            if description.is_empty() {
                return error_code.to_string();
            }
            return format!("{}: {}", error_code, description);
        }
    }

    // ASCII lowercasing never changes byte lengths, so an offset found in the
    // lowered copy is a valid offset into `message`. Full Unicode lowercasing
    // can grow or shrink characters and would misalign (or panic on) the slice.
    if let Some(fix_start) = message.to_ascii_lowercase().find("fix:") {
        let after_marker = &message[fix_start + "fix:".len()..];
        return after_marker.lines().next().unwrap_or("").trim().to_string();
    }

    message.lines().next().unwrap_or(message).to_string()
}
/// Extracts a one-line fix description from a commit message.
///
/// The message is searched (ASCII case-insensitively) for the markers
/// `solution:`, `fixed by:`, `fix:` and `resolved:`, in that priority order
/// regardless of where they appear. For the first marker found, the trimmed
/// remainder of its line is returned.
///
/// If no marker is present, the trimmed first line of the message is returned;
/// an empty message yields `"See commit for fix details"`.
fn extract_fix_from_commit(message: &str) -> String {
    const MARKERS: [&str; 4] = ["solution:", "fixed by:", "fix:", "resolved:"];

    // ASCII lowercasing preserves byte lengths, so offsets found in `lower`
    // are valid offsets into `message`. Full Unicode lowercasing can change
    // lengths and would misalign (or panic on) the slice below.
    let lower = message.to_ascii_lowercase();
    let after_marker = MARKERS
        .iter()
        .find_map(|marker| lower.find(marker).map(|idx| idx + marker.len()));

    if let Some(offset) = after_marker {
        return message[offset..]
            .lines()
            .next()
            .unwrap_or("")
            .trim()
            .to_string();
    }

    message
        .lines()
        .next()
        .map(|line| line.trim().to_string())
        .unwrap_or_else(|| "See commit for fix details".to_string())
}
/// Loads the OIP training-data JSON at `oip_json_path` and converts it into a
/// [`TrainingDataset`].
///
/// # Errors
///
/// Propagates any error returned by `load_oip_training_data`.
pub fn build_github_corpus(oip_json_path: &Path) -> Result<TrainingDataset, std::io::Error> {
    load_oip_training_data(oip_json_path).map(|oip_data| convert_oip_to_depyler(&oip_data))
}
/// Builds one `(error_code, context, domain)` tuple per example for
/// mixture-of-experts training.
///
/// Examples are taken from the train, validation and test splits, in that
/// order. The error code is inferred from the example's label, the context is
/// a copy of the example's message, and the domain comes from
/// `to_expert_domain`.
#[must_use]
pub fn get_moe_samples_from_oip(
    oip_data: &OipTrainingDataset,
) -> Vec<(String, String, ExpertDomain)> {
    oip_data
        .train
        .iter()
        .chain(&oip_data.validation)
        .chain(&oip_data.test)
        .map(|example| {
            (
                infer_error_code_from_category(example.label),
                example.message.clone(),
                example.label.to_expert_domain(),
            )
        })
        .collect()
}
/// Returns the compiler-style error code string used to represent `category`
/// in mixture-of-experts samples.
///
/// Categories without a specific code yield the placeholder `"E0000"`.
fn infer_error_code_from_category(category: OipDefectCategory) -> String {
    let code = match category {
        OipDefectCategory::ResourceLeaks => "E0106",
        OipDefectCategory::TraitBounds => "E0277",
        OipDefectCategory::TypeAnnotationGaps | OipDefectCategory::TypeErrors => "E0308",
        OipDefectCategory::MemorySafety | OipDefectCategory::OwnershipBorrow => "E0382",
        OipDefectCategory::ConfigurationErrors | OipDefectCategory::IntegrationFailures => "E0425",
        OipDefectCategory::ASTTransform | OipDefectCategory::StdlibMapping => "E0433",
        OipDefectCategory::ApiMisuse | OipDefectCategory::IteratorChain => "E0599",
        OipDefectCategory::ComprehensionBugs => "E0609",
        // No specific code for these categories; use the placeholder.
        OipDefectCategory::ConcurrencyBugs
        | OipDefectCategory::LogicErrors
        | OipDefectCategory::OperatorPrecedence
        | OipDefectCategory::PerformanceIssues
        | OipDefectCategory::SecurityVulnerabilities => "E0000",
    };
    code.to_owned()
}
/// Summary statistics over an [`OipTrainingDataset`], as produced by
/// `analyze_corpus`.
#[derive(Debug, Default)]
pub struct CorpusStats {
/// Number of examples across the train, validation and test splits.
pub total_examples: usize,
/// Example count per defect category, keyed by the category's `Debug` name.
pub by_category: HashMap<String, usize>,
/// Example count per expert domain (via `to_expert_domain`).
pub by_expert: HashMap<ExpertDomain, usize>,
/// Arithmetic mean of the examples' `confidence`; stays at the default `0.0`
/// when there are no examples.
pub avg_confidence: f32,
}
#[must_use]
pub fn analyze_corpus(oip_data: &OipTrainingDataset) -> CorpusStats {
let mut stats = CorpusStats::default();
let all_examples: Vec<_> = oip_data
.train
.iter()
.chain(oip_data.validation.iter())
.chain(oip_data.test.iter())
.collect();
stats.total_examples = all_examples.len();
let mut total_confidence = 0.0f32;
for example in &all_examples {
let cat_name = format!("{:?}", example.label);
*stats.by_category.entry(cat_name).or_default() += 1;
let domain = example.label.to_expert_domain();
*stats.by_expert.entry(domain).or_default() += 1;
total_confidence += example.confidence;
}
if !all_examples.is_empty() {
stats.avg_confidence = total_confidence / all_examples.len() as f32;
}
stats
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks the defect-label to classifier-category mapping.
    #[test]
    fn test_oip_to_error_category_mapping() {
        assert_eq!(
            OipDefectCategory::OwnershipBorrow.to_error_category(),
            ErrorCategory::BorrowChecker
        );
        assert_eq!(
            OipDefectCategory::TypeErrors.to_error_category(),
            ErrorCategory::TypeMismatch
        );
        assert_eq!(
            OipDefectCategory::TraitBounds.to_error_category(),
            ErrorCategory::TraitBound
        );
        assert_eq!(
            OipDefectCategory::StdlibMapping.to_error_category(),
            ErrorCategory::MissingImport
        );
    }

    /// Spot-checks the defect-label to expert-domain mapping.
    #[test]
    fn test_oip_to_expert_domain_mapping() {
        assert_eq!(
            OipDefectCategory::TypeErrors.to_expert_domain(),
            ExpertDomain::TypeSystem
        );
        assert_eq!(
            OipDefectCategory::StdlibMapping.to_expert_domain(),
            ExpertDomain::ScopeResolution
        );
        assert_eq!(
            OipDefectCategory::OwnershipBorrow.to_expert_domain(),
            ExpertDomain::SyntaxBorrowing
        );
        assert_eq!(
            OipDefectCategory::ApiMisuse.to_expert_domain(),
            ExpertDomain::MethodField
        );
    }

    /// An embedded compiler error code takes priority over the `fix:` prefix.
    #[test]
    fn test_extract_error_pattern() {
        let msg = "fix: error[E0308]: mismatched types\n\ndetails here";
        let pattern = extract_error_pattern(msg);
        assert!(pattern.contains("E0308"));
    }

    /// A conventional-commit `fix:` subject yields the text after the marker.
    #[test]
    fn test_extract_error_pattern_conventional() {
        let msg = "fix: resolve borrow checker issue with lifetime";
        let pattern = extract_error_pattern(msg);
        assert_eq!(pattern, "resolve borrow checker issue with lifetime");
    }

    /// `Solution:` is preferred over `fix:` even when it appears later.
    #[test]
    fn test_extract_fix_from_commit() {
        let msg = "fix: type mismatch\n\nSolution: Use .into() for conversion";
        let fix = extract_fix_from_commit(msg);
        assert_eq!(fix, "Use .into() for conversion");
    }

    #[test]
    fn test_infer_error_code() {
        assert_eq!(
            infer_error_code_from_category(OipDefectCategory::TypeErrors),
            "E0308"
        );
        assert_eq!(
            infer_error_code_from_category(OipDefectCategory::TraitBounds),
            "E0277"
        );
        assert_eq!(
            infer_error_code_from_category(OipDefectCategory::OwnershipBorrow),
            "E0382"
        );
    }

    #[test]
    fn test_convert_empty_dataset() {
        let oip = OipTrainingDataset {
            train: vec![],
            validation: vec![],
            test: vec![],
        };
        let dataset = convert_oip_to_depyler(&oip);
        assert!(dataset.samples().is_empty());
    }

    #[test]
    fn test_analyze_corpus_empty() {
        let oip = OipTrainingDataset {
            train: vec![],
            validation: vec![],
            test: vec![],
        };
        let stats = analyze_corpus(&oip);
        assert_eq!(stats.total_examples, 0);
    }

    /// Smoke-tests loading, analysing and converting a real OIP export when
    /// one is available on this machine.
    ///
    /// The file location is read from the `OIP_TRAINING_DATA` environment
    /// variable so the test is usable outside the original developer's
    /// machine; when the variable is unset the historical default path is
    /// used. If the file does not exist the test prints a notice and passes.
    #[test]
    fn test_load_real_oip_data_if_exists() {
        let oip_path = std::env::var_os("OIP_TRAINING_DATA")
            .map(std::path::PathBuf::from)
            .unwrap_or_else(|| {
                std::path::PathBuf::from(
                    "/home/noah/src/organizational-intelligence-plugin/training-data.json",
                )
            });
        if oip_path.exists() {
            let oip_data = load_oip_training_data(&oip_path).expect("Should load OIP data");
            let stats = analyze_corpus(&oip_data);
            println!("OIP Corpus Statistics:");
            println!(" Total examples: {}", stats.total_examples);
            println!(" Avg confidence: {:.2}", stats.avg_confidence);
            println!(" By category:");
            for (cat, count) in &stats.by_category {
                println!(" {}: {}", cat, count);
            }
            println!(" By expert domain:");
            for (domain, count) in &stats.by_expert {
                println!(" {:?}: {}", domain, count);
            }
            let depyler_dataset = convert_oip_to_depyler(&oip_data);
            println!(
                " Converted to {} depyler samples",
                depyler_dataset.samples().len()
            );
            assert!(stats.total_examples > 0, "Should have training examples");
        } else {
            println!("OIP training data not found at {:?}, skipping", oip_path);
        }
    }

    /// End-to-end check on a single in-memory example.
    #[test]
    fn test_convert_with_sample_data() {
        let oip = OipTrainingDataset {
            train: vec![OipTrainingExample {
                message: "fix: error[E0308]: mismatched types\n\nUse .into()".to_string(),
                label: OipDefectCategory::TypeErrors,
                confidence: 0.85,
                commit_hash: "abc123".to_string(),
                author: "test@example.com".to_string(),
                timestamp: 1234567890,
                lines_added: 10,
                lines_removed: 5,
                files_changed: 2,
            }],
            validation: vec![],
            test: vec![],
        };
        let dataset = convert_oip_to_depyler(&oip);
        assert_eq!(dataset.samples().len(), 1);
        let moe_samples = get_moe_samples_from_oip(&oip);
        assert_eq!(moe_samples.len(), 1);
        assert_eq!(moe_samples[0].0, "E0308");
        assert_eq!(moe_samples[0].2, ExpertDomain::TypeSystem);
    }
}