use super::encoder::ErrorEmbedding;
use super::error::{CITLError, CITLResult};
use super::{Difficulty, ErrorCategory, ErrorCode};
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, BufWriter, Read as IoRead, Write as IoWrite};
use std::path::Path;
/// Magic bytes at the start of every serialized pattern library file.
const MAGIC: &[u8; 4] = b"CITL";
/// On-disk format version written by `PatternLibrary::save` and the only
/// version accepted by `PatternLibrary::load`.
const FORMAT_VERSION: u8 = 1;

/// A searchable collection of error→fix patterns.
///
/// Each pattern is stored alongside the embedding vector of the error it was
/// learned from. `patterns[i]` and `embeddings[i]` always describe the same
/// entry, which is what lets `search` map a scored embedding back to its
/// pattern by index.
#[derive(Debug)]
pub struct PatternLibrary {
    /// Learned error/fix patterns, in insertion (or file) order.
    patterns: Vec<ErrorFixPattern>,
    /// Embedding vector for each pattern, index-aligned with `patterns`.
    embeddings: Vec<Vec<f32>>,
    /// In-memory outcome statistics keyed by pattern index. These are not
    /// written by `save`, so a freshly loaded library starts with empty stats.
    stats: PatternStats,
}

impl PatternLibrary {
    /// Creates an empty library.
    #[must_use]
    pub fn new() -> Self {
        Self {
            patterns: Vec::new(),
            embeddings: Vec::new(),
            stats: PatternStats::new(),
        }
    }

    /// Loads a library previously written by [`PatternLibrary::save`].
    ///
    /// Expected layout: the 4-byte `MAGIC` header, a one-byte
    /// `FORMAT_VERSION`, a little-endian `u32` pattern count, then that many
    /// records decoded by `read_pattern`. Outcome statistics are not part of
    /// the file, so the returned library starts with empty stats.
    ///
    /// # Errors
    ///
    /// Returns `CITLError::PatternLibraryError` if the file does not exist,
    /// has a bad magic header, or uses an unsupported format version. I/O and
    /// record-decoding errors (including truncated files) are propagated.
    pub fn load(path: &str) -> CITLResult<Self> {
        // Upper bound on up-front allocation. The count field comes straight
        // from the file, so a corrupt header must not be able to request a
        // multi-gigabyte allocation before a single record has been read.
        const MAX_PREALLOC_PATTERNS: usize = 1024;

        let path = Path::new(path);
        if !path.exists() {
            return Err(CITLError::PatternLibraryError {
                message: format!("Pattern library not found: {}", path.display()),
            });
        }

        let mut reader = BufReader::new(File::open(path)?);

        let mut magic = [0u8; 4];
        reader.read_exact(&mut magic)?;
        if &magic != MAGIC {
            return Err(CITLError::PatternLibraryError {
                message: "Invalid pattern library file: bad magic header".to_string(),
            });
        }

        let mut version = [0u8; 1];
        reader.read_exact(&mut version)?;
        if version[0] != FORMAT_VERSION {
            return Err(CITLError::PatternLibraryError {
                message: format!("Unsupported format version: {}", version[0]),
            });
        }

        let mut count_bytes = [0u8; 4];
        reader.read_exact(&mut count_bytes)?;
        let count = u32::from_le_bytes(count_bytes) as usize;

        // The vectors still grow as needed if the file really holds more
        // records than the preallocation cap.
        let initial_capacity = count.min(MAX_PREALLOC_PATTERNS);
        let mut patterns = Vec::with_capacity(initial_capacity);
        let mut embeddings = Vec::with_capacity(initial_capacity);
        for _ in 0..count {
            let (pattern, embedding) = read_pattern(&mut reader)?;
            patterns.push(pattern);
            embeddings.push(embedding);
        }

        Ok(Self {
            patterns,
            embeddings,
            stats: PatternStats::new(),
        })
    }

    /// Writes the library to `path` in the format read by
    /// [`PatternLibrary::load`], creating or truncating the file.
    ///
    /// Outcome statistics held in memory are not written.
    ///
    /// # Errors
    ///
    /// Returns `CITLError::PatternLibraryError` if there are more patterns
    /// than the `u32` count field can represent. This is checked before the
    /// file is touched, so an existing file is not truncated. I/O and
    /// record-encoding errors are propagated.
    pub fn save(&self, path: &str) -> CITLResult<()> {
        // A silent `as u32` truncation would produce a file whose header
        // disagrees with its body; refuse instead.
        let count = u32::try_from(self.patterns.len()).map_err(|_| {
            CITLError::PatternLibraryError {
                message: format!(
                    "Too many patterns to save: {} (format limit is {})",
                    self.patterns.len(),
                    u32::MAX
                ),
            }
        })?;

        let mut writer = BufWriter::new(File::create(path)?);
        writer.write_all(MAGIC)?;
        writer.write_all(&[FORMAT_VERSION])?;
        writer.write_all(&count.to_le_bytes())?;
        for (pattern, embedding) in self.patterns.iter().zip(self.embeddings.iter()) {
            write_pattern(&mut writer, pattern, embedding)?;
        }
        // Flush explicitly: dropping a BufWriter swallows write errors.
        writer.flush()?;
        Ok(())
    }

    /// Returns up to `k` patterns whose embeddings are most cosine-similar to
    /// `query`, best match first.
    ///
    /// Results are always sorted by descending similarity. NaN similarities
    /// rank last, and equal scores are ordered by insertion index, so the
    /// output is deterministic. Returns an empty vector when `k` is zero or
    /// the library is empty.
    ///
    /// Each match's `success_rate` comes from the in-memory outcome
    /// statistics, which report 0.5 for patterns with no recorded outcomes.
    #[must_use]
    pub fn search(&self, query: &ErrorEmbedding, k: usize) -> Vec<PatternMatch> {
        // Total order: non-NaN by descending similarity, NaN last, then by
        // ascending index. Total ordering keeps the selection/sort well-defined
        // and makes boundary ties reproducible.
        fn by_similarity_desc(a: &(usize, f32), b: &(usize, f32)) -> std::cmp::Ordering {
            use std::cmp::Ordering;
            let primary = match (a.1.is_nan(), b.1.is_nan()) {
                (true, true) => Ordering::Equal,
                (true, false) => Ordering::Greater,
                (false, true) => Ordering::Less,
                (false, false) => b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal),
            };
            primary.then_with(|| a.0.cmp(&b.0))
        }

        let k = k.min(self.embeddings.len());
        // Guards both the empty library and k == 0; the latter previously
        // underflowed in `k - 1` below.
        if k == 0 {
            return Vec::new();
        }

        let mut scored: Vec<(usize, f32)> = self
            .embeddings
            .iter()
            .enumerate()
            .map(|(idx, embedding)| (idx, cosine_similarity(&query.vector, embedding)))
            .collect();

        if k < scored.len() {
            // Partition so the best `k` entries occupy the front, then drop the rest.
            scored.select_nth_unstable_by(k - 1, by_similarity_desc);
            scored.truncate(k);
        }
        // Sort in every case, including k >= n, so callers always get a ranking.
        scored.sort_by(by_similarity_desc);

        scored
            .into_iter()
            .map(|(idx, similarity)| PatternMatch {
                pattern: self.patterns[idx].clone(),
                similarity,
                success_rate: self.stats.success_rate(idx),
            })
            .collect()
    }

    /// Appends a new pattern learned from `error` with the given `fix`.
    ///
    /// The new pattern starts with `success_count = 1` and
    /// `failure_count = 0`, and its embedding is stored at the same index.
    pub fn add_pattern(&mut self, error: ErrorEmbedding, fix: FixTemplate) {
        let pattern = ErrorFixPattern {
            error_code: error.error_code.clone(),
            context_hash: error.context_hash,
            fix_template: fix,
            success_count: 1,
            failure_count: 0,
        };
        self.embeddings.push(error.vector);
        self.patterns.push(pattern);
    }

    /// Records whether applying the pattern at `pattern_idx` succeeded.
    ///
    /// Updates both the pattern's own counters and the in-memory statistics.
    /// Out-of-range indices are ignored.
    pub fn record_outcome(&mut self, pattern_idx: usize, success: bool) {
        if let Some(pattern) = self.patterns.get_mut(pattern_idx) {
            if success {
                pattern.success_count += 1;
            } else {
                pattern.failure_count += 1;
            }
            self.stats.record(pattern_idx, success);
        }
    }

    /// Number of patterns in the library.
    #[must_use]
    pub fn len(&self) -> usize {
        self.patterns.len()
    }

    /// Returns `true` if the library holds no patterns.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.patterns.is_empty()
    }

    /// Returns the pattern at `idx`, or `None` if out of range.
    #[must_use]
    pub fn get(&self, idx: usize) -> Option<&ErrorFixPattern> {
        self.patterns.get(idx)
    }

    /// Returns every pattern whose error code's `code` equals `code`, in
    /// library order.
    #[must_use]
    pub fn get_by_code(&self, code: &str) -> Vec<&ErrorFixPattern> {
        self.patterns
            .iter()
            .filter(|p| p.error_code.code == code)
            .collect()
    }
}

impl Default for PatternLibrary {
    /// Equivalent to [`PatternLibrary::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// A learned error→fix pairing together with its application history.
#[derive(Debug, Clone)]
pub struct ErrorFixPattern {
    /// Error code of the error this pattern was learned from.
    pub error_code: ErrorCode,
    /// Context hash copied from the originating `ErrorEmbedding`.
    pub context_hash: u64,
    /// Fix template to apply when this pattern matches.
    pub fix_template: FixTemplate,
    /// Number of recorded successful applications.
    pub success_count: u64,
    /// Number of recorded failed applications.
    pub failure_count: u64,
}
impl ErrorFixPattern {
    /// Fraction of recorded applications that succeeded.
    ///
    /// A pattern with no recorded applications reports `0.0`.
    #[must_use]
    pub fn success_rate(&self) -> f64 {
        match self.total_applications() {
            0 => 0.0,
            applied => self.success_count as f64 / applied as f64,
        }
    }

    /// Total recorded applications (successes plus failures).
    #[must_use]
    pub fn total_applications(&self) -> u64 {
        let Self {
            success_count,
            failure_count,
            ..
        } = self;
        success_count + failure_count
    }
}
/// A textual fix containing `$name` placeholders that are filled in by
/// [`FixTemplate::apply`].
#[derive(Debug, Clone)]
pub struct FixTemplate {
    /// Template text; placeholders are written as `$name`.
    pub pattern: String,
    /// Declared placeholders. These are descriptive only: `apply` substitutes
    /// from its `bindings` argument.
    pub placeholders: Vec<Placeholder>,
    /// Error codes this template targets; an empty list means "any code".
    pub applicable_codes: Vec<String>,
    /// Confidence score for the fix (`new` starts it at 0.5).
    pub confidence: f32,
    /// Human-readable description of the fix.
    pub description: String,
}

impl FixTemplate {
    /// Creates a template with no placeholders, no code restrictions and a
    /// confidence of 0.5.
    #[must_use]
    pub fn new(pattern: &str, description: &str) -> Self {
        Self {
            pattern: pattern.to_string(),
            placeholders: Vec::new(),
            applicable_codes: Vec::new(),
            confidence: 0.5,
            description: description.to_string(),
        }
    }

    /// Builder: appends a placeholder declaration.
    #[must_use]
    pub fn with_placeholder(mut self, placeholder: Placeholder) -> Self {
        self.placeholders.push(placeholder);
        self
    }

    /// Builder: restricts the template to an additional error code.
    #[must_use]
    pub fn with_code(mut self, code: &str) -> Self {
        self.applicable_codes.push(code.to_string());
        self
    }

    /// Builder: sets the confidence score (stored as given, not clamped).
    #[must_use]
    pub fn with_confidence(mut self, confidence: f32) -> Self {
        self.confidence = confidence;
        self
    }

    /// Renders the template by replacing each `$name` with `bindings[name]`.
    ///
    /// The pattern is scanned once, left to right. At every `$`, the longest
    /// binding name that immediately follows it is substituted. A `$` that
    /// does not start any bound name is copied through unchanged.
    ///
    /// Substituted values are never rescanned. As a result, the output does
    /// not depend on `HashMap` iteration order, even when one name is a
    /// prefix of another (`$x` vs `$xs`) or a value itself contains `$`.
    #[must_use]
    pub fn apply(&self, bindings: &HashMap<String, String>) -> String {
        let mut result = String::with_capacity(self.pattern.len());
        let mut rest = self.pattern.as_str();

        while let Some(dollar) = rest.find('$') {
            result.push_str(&rest[..dollar]);
            // '$' is a single byte, so `dollar + 1` is a valid char boundary.
            let after = &rest[dollar + 1..];
            let longest = bindings
                .iter()
                .filter(|(name, _)| after.starts_with(name.as_str()))
                .max_by_key(|(name, _)| name.len());
            match longest {
                Some((name, value)) => {
                    result.push_str(value);
                    rest = &after[name.len()..];
                }
                None => {
                    result.push('$');
                    rest = after;
                }
            }
        }

        result.push_str(rest);
        result
    }

    /// Returns `true` if the template has no code restrictions or lists `code`.
    #[must_use]
    pub fn applies_to(&self, code: &str) -> bool {
        self.applicable_codes.is_empty() || self.applicable_codes.iter().any(|c| c == code)
    }
}

/// Declaration of a named slot in a [`FixTemplate`].
#[derive(Debug, Clone)]
pub struct Placeholder {
    /// Name referenced in the template text as `$name`.
    pub name: String,
    /// Human-readable description of what fills the slot.
    pub description: String,
    /// Kind of value expected in the slot.
    pub constraint: PlaceholderConstraint,
}

impl Placeholder {
    /// Creates a placeholder with an explicit description and constraint.
    #[must_use]
    pub fn new(name: &str, description: &str, constraint: PlaceholderConstraint) -> Self {
        Self {
            name: name.to_string(),
            description: description.to_string(),
            constraint,
        }
    }

    /// Placeholder constrained to `PlaceholderConstraint::Expression`.
    #[must_use]
    pub fn expression(name: &str) -> Self {
        Self::new(name, "An expression", PlaceholderConstraint::Expression)
    }

    /// Placeholder constrained to `PlaceholderConstraint::Type`.
    #[must_use]
    pub fn type_name(name: &str) -> Self {
        Self::new(name, "A type name", PlaceholderConstraint::Type)
    }

    /// Placeholder constrained to `PlaceholderConstraint::Identifier`.
    #[must_use]
    pub fn identifier(name: &str) -> Self {
        Self::new(name, "An identifier", PlaceholderConstraint::Identifier)
    }
}

/// Kind of value a [`Placeholder`] is expected to hold.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PlaceholderConstraint {
    /// An expression.
    Expression,
    /// A type name.
    Type,
    /// An identifier.
    Identifier,
    /// A literal value.
    Literal,
    /// No constraint.
    Any,
}
/// A single hit returned by `PatternLibrary::search`.
#[derive(Debug, Clone)]
pub struct PatternMatch {
    /// The matched pattern (a clone of the library entry).
    pub pattern: ErrorFixPattern,
    /// Cosine similarity between the query embedding and the pattern's embedding.
    pub similarity: f32,
    /// Success rate taken from the library's in-memory outcome statistics.
    pub success_rate: f64,
}
impl PatternMatch {
    /// Ranking score: the similarity scaled by the success rate.
    #[must_use]
    pub fn combined_score(&self) -> f64 {
        let similarity = f64::from(self.similarity);
        self.success_rate * similarity
    }
}
/// Per-pattern success/failure tallies keyed by pattern index.
#[derive(Debug, Clone)]
pub(super) struct PatternStats {
    /// Number of recorded successes per pattern index.
    successes: HashMap<usize, u64>,
    /// Number of recorded failures per pattern index.
    failures: HashMap<usize, u64>,
}

impl PatternStats {
    /// Creates statistics with no recorded outcomes.
    #[must_use]
    pub(super) fn new() -> Self {
        let successes = HashMap::new();
        let failures = HashMap::new();
        Self { successes, failures }
    }

    /// Adds one success or one failure to the tally for `pattern_idx`.
    pub(super) fn record(&mut self, pattern_idx: usize, success: bool) {
        let tally = if success {
            &mut self.successes
        } else {
            &mut self.failures
        };
        *tally.entry(pattern_idx).or_default() += 1;
    }

    /// Observed success fraction for `pattern_idx`.
    ///
    /// Returns 0.5 when no outcome has been recorded for that index.
    #[must_use]
    pub(super) fn success_rate(&self, pattern_idx: usize) -> f64 {
        let lookup = |tally: &HashMap<usize, u64>| tally.get(&pattern_idx).copied().unwrap_or(0);
        let (wins, losses) = (lookup(&self.successes), lookup(&self.failures));
        match wins + losses {
            0 => 0.5,
            observed => wins as f64 / observed as f64,
        }
    }
}
// Textually include the rest of this module's implementation. The helpers
// used above but not defined in this file (`read_pattern`, `write_pattern`,
// `cosine_similarity`) presumably come from these files; confirm.
include!("read.rs");
include!("template.rs");