use crate::code::ast::{byte_offset_to_position, CodeParser, ParsedCode};
use crate::code::correction::{CodeCorrector, Correction, CorrectionCandidates};
use crate::code::correctors::EnsembleCorrector;
use crate::code::cpg::CodePropertyGraph;
use crate::code::language::{CodeLanguage, TokenContext};
use crate::code::pcfg::WeightedCFG;
use crate::code::tokenizer::{CodeToken, CodeTokenizer};
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashSet};
use std::sync::Arc;
/// Heap wrapper that orders a [`Correction`] solely by its confidence score.
///
/// The ordering is intentionally *reversed*: an entry with a lower confidence
/// compares as greater. Stored in a `BinaryHeap` (a max-heap), this puts the
/// weakest correction at the top, where it can be inspected or evicted
/// cheaply.
struct CorrectionEntry {
    correction: Correction,
}

impl PartialEq for CorrectionEntry {
    /// Two entries are equal exactly when [`Ord::cmp`] says so, which keeps
    /// `PartialEq`/`Eq` consistent with the ordering (including for NaN).
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for CorrectionEntry {}

impl PartialOrd for CorrectionEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for CorrectionEntry {
    /// Compares by confidence with the operands swapped (lower confidence is
    /// "greater").
    ///
    /// `total_cmp` is used instead of `partial_cmp(..).unwrap_or(Equal)` so
    /// that the relation is a genuine total order even if a NaN slips in;
    /// treating NaN as equal to everything would make the order
    /// non-transitive and could corrupt the heap invariant.
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .correction
            .confidence
            .total_cmp(&self.correction.confidence)
    }
}
/// Bounded, de-duplicating collector that keeps the highest-confidence
/// corrections seen so far.
struct StreamingCorrectionCollector {
    /// Retained entries; the top of the heap is the weakest one kept.
    heap: BinaryHeap<CorrectionEntry>,
    /// Maximum number of corrections retained at any time.
    max_size: usize,
    /// Corrections scoring below this value are discarded outright.
    min_confidence: f64,
    /// `(start_byte, end_byte, replacement)` of every correction currently in
    /// `heap`, used to reject duplicates.
    seen: HashSet<(usize, usize, String)>,
}

impl StreamingCorrectionCollector {
    /// Upper bound on the capacity reserved up front. The heap still grows on
    /// demand up to `max_size`; this only avoids a huge (or overflowing)
    /// allocation request when the caller passes a very large limit.
    const PREALLOC_LIMIT: usize = 1024;

    /// Creates an empty collector that keeps at most `max_size` corrections
    /// whose confidence is at least `min_confidence`.
    fn new(max_size: usize, min_confidence: f64) -> Self {
        Self {
            // The heap never holds more than `max_size` entries because the
            // weakest one is popped before a replacement is pushed.
            heap: BinaryHeap::with_capacity(max_size.min(Self::PREALLOC_LIMIT)),
            max_size,
            min_confidence,
            seen: HashSet::new(),
        }
    }

    /// Offers one correction to the collector.
    ///
    /// Returns `true` when the correction was stored and `false` when it was
    /// rejected because it is below the confidence threshold (or NaN), is a
    /// duplicate of a stored correction, or the collector is full and the
    /// correction does not beat the weakest stored entry.
    fn add(&mut self, correction: Correction) -> bool {
        // NaN compares false against everything, so it must be rejected
        // explicitly or it would slip past the threshold check.
        if correction.confidence.is_nan() || correction.confidence < self.min_confidence {
            return false;
        }

        let is_full = self.heap.len() >= self.max_size;
        if is_full {
            // Cheap rejection before the replacement string is cloned for the
            // de-duplication key. An empty heap here means `max_size == 0`.
            let beats_weakest = self
                .heap
                .peek()
                .map_or(false, |weakest| {
                    correction.confidence > weakest.correction.confidence
                });
            if !beats_weakest {
                return false;
            }
        }

        let key = (
            correction.start_byte,
            correction.end_byte,
            correction.replacement.clone(),
        );
        if self.seen.contains(&key) {
            return false;
        }

        if is_full {
            // Evict the weakest entry and forget its key so an identical
            // correction may be offered again later.
            if let Some(evicted) = self.heap.pop() {
                let mut dropped = evicted.correction;
                let dropped_key = (
                    dropped.start_byte,
                    dropped.end_byte,
                    std::mem::take(&mut dropped.replacement),
                );
                self.seen.remove(&dropped_key);
            }
        }

        self.seen.insert(key);
        self.heap.push(CorrectionEntry { correction });
        true
    }

    /// Offers every correction produced by `corrections`, ignoring whether
    /// each individual one was accepted.
    fn add_all<I: IntoIterator<Item = Correction>>(&mut self, corrections: I) {
        for correction in corrections {
            self.add(correction);
        }
    }

    /// Consumes the collector and returns the retained corrections, handed to
    /// `CorrectionCandidates` in descending confidence order.
    fn finalize(self) -> CorrectionCandidates {
        let Self { heap, max_size, .. } = self;
        let mut corrections: Vec<Correction> =
            heap.into_iter().map(|entry| entry.correction).collect();
        corrections.sort_by(|a, b| {
            b.confidence
                .partial_cmp(&a.confidence)
                .unwrap_or(Ordering::Equal)
        });
        let mut candidates = CorrectionCandidates::new(max_size);
        candidates.add_all(corrections);
        candidates
    }
}
/// Settings that control what [`CorrectionPipeline::analyze`] computes and
/// how many results it keeps.
#[derive(Debug, Clone)]
pub struct PipelineConfig {
/// Upper bound on the number of corrections retained per analysis.
pub max_corrections: usize,
/// Corrections whose confidence is below this value are dropped.
pub min_confidence: f64,
/// When `true`, `analyze` emits one `Hint` diagnostic per retained
/// correction in addition to the parse-error diagnostics.
pub include_diagnostics: bool,
/// NOTE(review): not read anywhere in the visible part of this file;
/// presumably a confidence above which corrections are applied
/// automatically - confirm against callers.
pub auto_apply_threshold: Option<f64>,
/// When `true`, `analyze` builds a code property graph and also runs the
/// corrector's full (graph-based) analysis.
pub full_semantic_analysis: bool,
}
impl Default for PipelineConfig {
fn default() -> Self {
Self {
max_corrections: 50,
min_confidence: 0.3,
include_diagnostics: true,
auto_apply_threshold: None,
full_semantic_analysis: true,
}
}
}
/// Everything produced by one call to [`CorrectionPipeline::analyze`].
#[derive(Debug, Clone)]
pub struct AnalysisResult {
/// Owned copy of the text that was analyzed.
pub source: String,
/// Whether the parser flagged the input as containing errors.
pub has_parse_errors: bool,
/// Number of errors reported by the parsed code.
pub error_count: usize,
/// Tokens extracted from the parse tree.
pub tokens: Vec<CodeToken>,
/// Corrections that passed the confidence threshold and size limit.
pub corrections: CorrectionCandidates,
/// Parse-error diagnostics, followed (when enabled in the config) by one
/// hint per retained correction.
pub diagnostics: Vec<Diagnostic>,
}
/// A single message attached to a byte range of the analyzed source.
#[derive(Debug, Clone)]
pub struct Diagnostic {
/// How serious the finding is.
pub severity: DiagnosticSeverity,
/// Human-readable description of the finding.
pub message: String,
/// Byte offset where the affected range starts.
pub start_byte: usize,
/// Byte offset where the affected range ends.
pub end_byte: usize,
/// Line of `start_byte`.
/// NOTE(review): whether this is 0- or 1-based is decided by the parser
/// and `byte_offset_to_position`, neither of which is visible here.
pub line: usize,
/// Column of `start_byte` (same base caveat as `line`).
pub column: usize,
}
/// Severity level attached to a [`Diagnostic`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiagnosticSeverity {
/// Used by the pipeline for syntax errors reported by the parser.
Error,
/// Not produced by the code in this file.
Warning,
/// Not produced by the code in this file.
Info,
/// Used by the pipeline for suggested corrections.
Hint,
}
/// Parses source text for language `L`, tokenizes it, and gathers ranked
/// correction suggestions plus diagnostics.
pub struct CorrectionPipeline<L: CodeLanguage> {
/// Shared language definition handed to the parser, tokenizer and corrector.
language: Arc<L>,
/// Produces per-token corrections and, optionally, whole-program ones.
corrector: EnsembleCorrector<L>,
/// Settings applied on every `analyze` call.
config: PipelineConfig,
/// Parser reused across `analyze` calls.
parser: CodeParser<L>,
}
impl<L: CodeLanguage + Clone + Send + Sync> CorrectionPipeline<L> {
pub fn new(
language: Arc<L>,
grammar: Option<WeightedCFG>,
config: PipelineConfig,
) -> Result<Self, PipelineError> {
let corrector = EnsembleCorrector::with_defaults(Arc::clone(&language), grammar);
let parser = CodeParser::new(Arc::clone(&language))
.map_err(|e| PipelineError::ParseError(format!("{}", e)))?;
Ok(Self {
language,
corrector,
config,
parser,
})
}
pub fn with_defaults(
language: Arc<L>,
grammar: Option<WeightedCFG>,
) -> Result<Self, PipelineError> {
Self::new(language, grammar, PipelineConfig::default())
}
pub fn minimal(language: Arc<L>) -> Result<Self, PipelineError> {
let corrector = EnsembleCorrector::lexical_only(Arc::clone(&language));
let parser = CodeParser::new(Arc::clone(&language))
.map_err(|e| PipelineError::ParseError(format!("{}", e)))?;
Ok(Self {
language,
corrector,
config: PipelineConfig {
full_semantic_analysis: false,
..Default::default()
},
parser,
})
}
pub fn analyze(&mut self, source: &str) -> Result<AnalysisResult, PipelineError> {
let parsed = self
.parser
.parse(source)
.map_err(|e| PipelineError::ParseError(format!("{}", e)))?;
let tokens = self.tokenize(&parsed);
let cpg = if self.config.full_semantic_analysis {
Some(CodePropertyGraph::from_parsed_code(&parsed))
} else {
None
};
let mut diagnostics = self.collect_parse_diagnostics(&parsed);
let mut collector = StreamingCorrectionCollector::new(
self.config.max_corrections,
self.config.min_confidence,
);
for token in &tokens {
let context = TokenContext::new(token.token_type);
let token_corrections = self.corrector.correct_token(token, &context);
collector.add_all(token_corrections);
}
if let Some(ref cpg) = cpg {
let semantic_corrections = self.corrector.analyze_full(&parsed, cpg);
collector.add_all(semantic_corrections);
}
let corrections = collector.finalize();
for correction in corrections.ranked() {
if self.config.include_diagnostics {
let (line, column) = byte_offset_to_position(source, correction.start_byte);
diagnostics.push(Diagnostic {
severity: DiagnosticSeverity::Hint,
message: correction.context.clone().unwrap_or_else(|| {
format!(
"Consider: {} -> {}",
correction.original, correction.replacement
)
}),
start_byte: correction.start_byte,
end_byte: correction.end_byte,
line,
column,
});
}
}
let has_parse_errors = parsed.has_errors;
let error_count = parsed.error_count();
Ok(AnalysisResult {
source: source.to_string(),
has_parse_errors,
error_count,
tokens,
corrections,
diagnostics,
})
}
fn tokenize(&self, parsed: &ParsedCode) -> Vec<CodeToken> {
let tokenizer = CodeTokenizer::new(&*self.language);
tokenizer.tokenize(&parsed.tree, &parsed.source)
}
fn collect_parse_diagnostics(&self, parsed: &ParsedCode) -> Vec<Diagnostic> {
let mut diagnostics = Vec::new();
for error in parsed.errors() {
diagnostics.push(Diagnostic {
severity: DiagnosticSeverity::Error,
message: format!("Syntax error: {} '{}'", error.kind, error.text),
start_byte: error.start_byte,
end_byte: error.end_byte,
line: error.start_position.0,
column: error.start_position.1,
});
}
diagnostics
}
pub fn apply_corrections(&self, source: &str, corrections: &[Correction]) -> String {
if corrections.is_empty() {
return source.to_string();
}
let mut sorted: Vec<_> = corrections.iter().collect();
sorted.sort_by(|a, b| b.start_byte.cmp(&a.start_byte));
let mut result = source.to_string();
for correction in sorted {
if correction.start_byte < result.len() && correction.end_byte <= result.len() {
result.replace_range(
correction.start_byte..correction.end_byte,
&correction.replacement,
);
}
}
result
}
pub fn corrector_mut(&mut self) -> &mut EnsembleCorrector<L> {
&mut self.corrector
}
pub fn language(&self) -> &L {
&self.language
}
pub fn config(&self) -> &PipelineConfig {
&self.config
}
}
/// Failures reported by the correction pipeline.
#[derive(Debug)]
pub enum PipelineError {
/// The parser could not be constructed or failed on the input; carries the
/// underlying error rendered as text.
ParseError(String),
/// Tokenization failure (not constructed by the code in this file).
TokenizeError(String),
/// Code-property-graph failure (not constructed by the code in this file).
CpgError(String),
/// Correction-stage failure (not constructed by the code in this file).
CorrectionError(String),
/// An I/O error; produced by the `From<std::io::Error>` conversion.
IoError(std::io::Error),
}
impl std::fmt::Display for PipelineError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PipelineError::ParseError(msg) => write!(f, "Parse error: {}", msg),
PipelineError::TokenizeError(msg) => write!(f, "Tokenize error: {}", msg),
PipelineError::CpgError(msg) => write!(f, "CPG error: {}", msg),
PipelineError::CorrectionError(msg) => write!(f, "Correction error: {}", msg),
PipelineError::IoError(e) => write!(f, "I/O error: {}", e),
}
}
}
impl std::error::Error for PipelineError {}
impl From<std::io::Error> for PipelineError {
fn from(e: std::io::Error) -> Self {
PipelineError::IoError(e)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration carries the expected baseline values.
    #[test]
    fn test_pipeline_config_default() {
        let PipelineConfig {
            max_corrections,
            min_confidence,
            include_diagnostics,
            auto_apply_threshold,
            full_semantic_analysis,
        } = PipelineConfig::default();

        assert_eq!(max_corrections, 50);
        assert!((min_confidence - 0.3).abs() < 0.01);
        assert!(include_diagnostics);
        assert!(auto_apply_threshold.is_none());
        assert!(full_semantic_analysis);
    }

    /// Replacing a single misspelled keyword yields the corrected text.
    ///
    /// NOTE: this exercises `String::replace_range` directly; it does not
    /// call `CorrectionPipeline::apply_corrections`.
    #[test]
    fn test_apply_corrections() {
        let mut patched = String::from("funtion foo() { return 42; }");
        patched.replace_range(0..7, "function");
        assert_eq!(patched, "function foo() { return 42; }");
    }

    /// Two replacements applied back-to-front keep earlier offsets valid.
    ///
    /// NOTE: this exercises `String::replace_range` directly; it does not
    /// call `CorrectionPipeline::apply_corrections`.
    #[test]
    fn test_apply_multiple_corrections() {
        let mut patched = String::from("funtion foo() { retrun 42; }");
        let edits = [(16, 22, "return"), (0, 7, "function")];
        for &(span_start, span_end, text) in &edits {
            patched.replace_range(span_start..span_end, text);
        }
        assert_eq!(patched, "function foo() { return 42; }");
    }
}