use super::grammar::GrammarCorrector;
use super::lexical::LexicalCorrector;
use super::semantic::SemanticCorrector;
use crate::code::ast::ParsedCode;
use crate::code::correction::{CodeCorrector, Correction, CorrectionSource};
use crate::code::cpg::CodePropertyGraph;
use crate::code::language::{CodeLanguage, TokenContext};
use crate::code::pcfg::WeightedCFG;
use crate::code::tokenizer::CodeToken;
use std::collections::HashMap;
use std::sync::Arc;
/// Tuning parameters for `EnsembleCorrector`.
///
/// The three `*_weight` fields multiply the confidence of suggestions coming from the
/// corresponding sub-corrector; the remaining fields control merging and filtering.
#[derive(Debug, Clone, PartialEq)]
pub struct EnsembleCorrectorConfig {
    /// Multiplier applied to the confidence of lexical-corrector suggestions.
    pub lexical_weight: f64,
    /// Multiplier applied to the confidence of grammar-corrector suggestions.
    pub grammar_weight: f64,
    /// Multiplier applied to the confidence of semantic-corrector suggestions.
    pub semantic_weight: f64,
    /// Corrections whose final confidence is below this value are discarded.
    pub min_confidence: f64,
    /// Maximum number of corrections returned per query.
    pub max_candidates: usize,
    /// Whether suggestions with the same replacement and byte span are merged.
    pub deduplicate: bool,
    /// Similarity threshold intended for deduplication.
    ///
    /// NOTE(review): the ensemble's merge step groups on exact matches and does not read
    /// this value — confirm whether anything else consumes it.
    pub dedup_threshold: f64,
    /// Enables the agreement boost applied when matching suggestions are merged.
    pub agreement_boost: bool,
    /// Factor applied to a merged confidence when `agreement_boost` is on (capped at 1.0).
    pub agreement_boost_factor: f64,
}
impl Default for EnsembleCorrectorConfig {
    /// Default tuning: source weights 0.4 / 0.35 / 0.25 (which sum to 1), a 0.3
    /// confidence floor, at most 10 candidates, exact-match deduplication on, and a 1.3x
    /// agreement boost on.
    fn default() -> Self {
        let (lexical_weight, grammar_weight, semantic_weight) = (0.4, 0.35, 0.25);
        let (deduplicate, dedup_threshold) = (true, 0.9);
        let (agreement_boost, agreement_boost_factor) = (true, 1.3);
        Self {
            lexical_weight,
            grammar_weight,
            semantic_weight,
            min_confidence: 0.3,
            max_candidates: 10,
            deduplicate,
            dedup_threshold,
            agreement_boost,
            agreement_boost_factor,
        }
    }
}
/// Weighted ensemble over up to three sub-correctors: lexical, grammar and semantic.
///
/// A sub-corrector that is `None` is simply skipped and contributes no suggestions.
pub struct EnsembleCorrector<L: CodeLanguage> {
    /// Weights plus merge / filter settings.
    config: EnsembleCorrectorConfig,
    /// Language handle, shared via `Arc` with the sub-correctors.
    language: Arc<L>,
    /// Lexical sub-corrector, when enabled.
    lexical: Option<LexicalCorrector<L>>,
    /// Grammar sub-corrector; only built when a grammar is supplied.
    grammar: Option<GrammarCorrector<L>>,
    /// Semantic sub-corrector, when enabled.
    semantic: Option<SemanticCorrector<L>>,
}
impl<L: CodeLanguage + Clone> EnsembleCorrector<L> {
    /// Creates an ensemble with the lexical and semantic correctors enabled and, when
    /// `grammar` is `Some`, a grammar corrector built from it.
    ///
    /// Each sub-corrector is created through its `with_defaults` constructor and shares
    /// `language`.
    pub fn new(
        language: Arc<L>,
        grammar: Option<WeightedCFG>,
        config: EnsembleCorrectorConfig,
    ) -> Self {
        let lexical = Some(LexicalCorrector::with_defaults(Arc::clone(&language)));
        let grammar_corrector =
            grammar.map(|g| GrammarCorrector::with_defaults(Arc::clone(&language), g));
        let semantic = Some(SemanticCorrector::with_defaults(Arc::clone(&language)));
        Self {
            language,
            config,
            lexical,
            grammar: grammar_corrector,
            semantic,
        }
    }

    /// Equivalent to [`new`](Self::new) with `EnsembleCorrectorConfig::default()`.
    pub fn with_defaults(language: Arc<L>, grammar: Option<WeightedCFG>) -> Self {
        Self::new(language, grammar, EnsembleCorrectorConfig::default())
    }

    /// Creates an ensemble that runs only the lexical corrector, with the default config.
    pub fn lexical_only(language: Arc<L>) -> Self {
        Self {
            language: Arc::clone(&language),
            config: EnsembleCorrectorConfig::default(),
            lexical: Some(LexicalCorrector::with_defaults(language)),
            grammar: None,
            semantic: None,
        }
    }

    /// Mutable access to the lexical corrector, or `None` if it is disabled.
    pub fn lexical_mut(&mut self) -> Option<&mut LexicalCorrector<L>> {
        self.lexical.as_mut()
    }

    /// Mutable access to the semantic corrector, or `None` if it is disabled.
    pub fn semantic_mut(&mut self) -> Option<&mut SemanticCorrector<L>> {
        self.semantic.as_mut()
    }

    /// Adds known identifiers to the lexical corrector; does nothing if it is disabled.
    pub fn add_identifiers(&mut self, identifiers: &[&str]) {
        if let Some(ref mut lexical) = self.lexical {
            for id in identifiers {
                lexical.add_identifier(id);
            }
        }
    }

    /// Registers `(name, type_name)` pairs with the semantic corrector, passing `0` as the
    /// third `register_variable` argument; does nothing if the semantic corrector is
    /// disabled.
    pub fn register_variables(&mut self, variables: &[(String, Option<String>)]) {
        if let Some(ref mut semantic) = self.semantic {
            for (name, type_name) in variables {
                semantic.register_variable(name.clone(), type_name.clone(), 0);
            }
        }
    }

    /// Asks every enabled sub-corrector for suggestions and pairs each suggestion with
    /// that corrector's weight from the config.
    fn collect_corrections(
        &self,
        token: &CodeToken,
        context: &TokenContext,
    ) -> Vec<(Correction, f64)> {
        let mut corrections = Vec::new();
        if let Some(ref lexical) = self.lexical {
            let weight = self.config.lexical_weight;
            corrections.extend(
                lexical
                    .correct_token(token, context)
                    .into_iter()
                    .map(|c| (c, weight)),
            );
        }
        if let Some(ref grammar) = self.grammar {
            let weight = self.config.grammar_weight;
            corrections.extend(
                grammar
                    .correct_token(token, context)
                    .into_iter()
                    .map(|c| (c, weight)),
            );
        }
        if let Some(ref semantic) = self.semantic {
            let weight = self.config.semantic_weight;
            corrections.extend(
                semantic
                    .correct_token(token, context)
                    .into_iter()
                    .map(|c| (c, weight)),
            );
        }
        corrections
    }

    /// Returns a copy of `correction` with its confidence multiplied by `weight`.
    fn apply_weight(correction: &Correction, weight: f64) -> Correction {
        let mut c = correction.clone();
        c.confidence *= weight;
        c
    }

    /// Merges raw `(suggestion, source_weight)` pairs into a list of weighted corrections.
    ///
    /// With `deduplicate` off, every suggestion is returned with its confidence scaled by
    /// its source weight. Otherwise suggestions are grouped by exact
    /// `(replacement, start_byte, end_byte)` and each group is collapsed by
    /// [`merge_group`](Self::merge_group). Groups are emitted in the order their first
    /// member appeared, so the result is deterministic for a given input.
    fn merge_corrections(&self, corrections: Vec<(Correction, f64)>) -> Vec<Correction> {
        if corrections.is_empty() {
            return vec![];
        }
        if !self.config.deduplicate {
            return corrections
                .into_iter()
                .map(|(c, weight)| Self::apply_weight(&c, weight))
                .collect();
        }
        // Group identical suggestions while remembering first-seen order. Iterating a
        // HashMap directly would make the order of equal-confidence results (and therefore
        // which of them survive `max_candidates` truncation) vary from run to run.
        let mut slot_of: HashMap<(String, usize, usize), usize> =
            HashMap::with_capacity(corrections.len());
        let mut groups: Vec<Vec<(Correction, f64)>> = Vec::new();
        for (c, weight) in corrections {
            let key = (c.replacement.clone(), c.start_byte, c.end_byte);
            let slot = *slot_of.entry(key).or_insert_with(|| {
                groups.push(Vec::new());
                groups.len() - 1
            });
            groups[slot].push((c, weight));
        }
        groups
            .into_iter()
            .map(|group| self.merge_group(group))
            .collect()
    }

    /// Collapses one non-empty group of identical suggestions into a single correction.
    ///
    /// * If every entry comes from the same `CorrectionSource` (including a lone entry),
    ///   the group is not an agreement between correctors: the highest-scoring entry is
    ///   returned, weighted exactly like a single suggestion.
    /// * If two or more distinct sources agree, the result is a
    ///   `CorrectionSource::Combined` correction whose confidence is the weight-averaged
    ///   confidence of all entries, multiplied by `agreement_boost_factor` when
    ///   `agreement_boost` is on, and capped at 1.0.
    fn merge_group(&self, group: Vec<(Correction, f64)>) -> Correction {
        // Agreement means *different* correctors proposing the same edit, so count
        // distinct sources rather than group members.
        let mut distinct_sources = Vec::new();
        for (c, _) in &group {
            let kind = std::mem::discriminant(&c.source);
            if !distinct_sources.contains(&kind) {
                distinct_sources.push(kind);
            }
        }
        let total_weight: f64 = group.iter().map(|(_, w)| w).sum();
        let weighted_sum: f64 = group.iter().map(|(c, w)| c.confidence * w).sum();
        let (best, best_weight) = group
            .into_iter()
            .max_by(|a, b| {
                (a.0.confidence * a.1)
                    .partial_cmp(&(b.0.confidence * b.1))
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .expect("ensemble group is non-empty by construction");
        if distinct_sources.len() < 2 {
            // Duplicates from one corrector must not bypass that corrector's weight.
            return Self::apply_weight(&best, best_weight);
        }
        // A zero total weight would give 0/0 = NaN, and `NaN.min(1.0)` is 1.0, which
        // would rank a fully down-weighted suggestion at the top.
        let avg_confidence = if total_weight > 0.0 {
            weighted_sum / total_weight
        } else {
            0.0
        };
        let boost = if self.config.agreement_boost {
            self.config.agreement_boost_factor
        } else {
            1.0
        };
        let boosted = avg_confidence * boost;
        let mut merged = best;
        // Cap at 1.0 but let NaN through unchanged so `finalize_corrections` drops it.
        merged.confidence = if boosted > 1.0 { 1.0 } else { boosted };
        merged.source = CorrectionSource::Combined;
        merged.context = Some(format!("Suggested by {} sources", distinct_sources.len()));
        merged
    }

    /// Drops corrections below `min_confidence`, orders the rest by descending confidence
    /// (ties keep their input order), and keeps at most `max_candidates`.
    fn finalize_corrections(&self, mut corrections: Vec<Correction>) -> Vec<Correction> {
        corrections.retain(|c| c.confidence >= self.config.min_confidence);
        corrections.sort_by(|a, b| {
            b.confidence
                .partial_cmp(&a.confidence)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        corrections.truncate(self.config.max_candidates);
        corrections
    }

    /// Runs `SemanticCorrector::analyze_parsed` on `parsed` / `cpg` and post-processes the
    /// results like a token query (weight, merge, filter, rank, truncate). Returns nothing
    /// when the semantic corrector is disabled.
    ///
    /// NOTE(review): only the semantic corrector contributes here, and each suggestion is
    /// scaled by `semantic_weight` before the `min_confidence` filter. With the default
    /// config (0.25 vs 0.3), a lone suggestion needs a confidence of at least 1.2 to
    /// survive — confirm this is intended.
    pub fn analyze_full(&self, parsed: &ParsedCode, cpg: &CodePropertyGraph) -> Vec<Correction> {
        let mut all_corrections = Vec::new();
        if let Some(ref semantic) = self.semantic {
            let weight = self.config.semantic_weight;
            all_corrections.extend(
                semantic
                    .analyze_parsed(parsed, cpg)
                    .into_iter()
                    .map(|c| (c, weight)),
            );
        }
        let merged = self.merge_corrections(all_corrections);
        self.finalize_corrections(merged)
    }

    /// The configuration this ensemble was built with.
    pub fn config(&self) -> &EnsembleCorrectorConfig {
        &self.config
    }

    /// The language shared by all sub-correctors.
    pub fn language(&self) -> &L {
        &self.language
    }
}
impl<L: CodeLanguage + Clone + Send + Sync> CodeCorrector for EnsembleCorrector<L> {
    /// Collects suggestions from every enabled sub-corrector, merges them, then filters,
    /// ranks and truncates the result according to the config.
    fn correct_token(&self, token: &CodeToken, context: &TokenContext) -> Vec<Correction> {
        let corrections = self.collect_corrections(token, context);
        let merged = self.merge_corrections(corrections);
        self.finalize_corrections(merged)
    }

    /// Corrects `source[start_byte..end_byte]`, treated as a single token of unknown type.
    ///
    /// Returns no corrections, instead of panicking, when the range is reversed, out of
    /// bounds, or does not lie on UTF-8 character boundaries.
    fn correct_range(&self, source: &str, start_byte: usize, end_byte: usize) -> Vec<Correction> {
        let text = match source.get(start_byte..end_byte) {
            Some(text) => text,
            None => return Vec::new(),
        };
        let token = CodeToken::new(
            text,
            start_byte,
            0,
            0,
            crate::code::language::TokenType::Unknown,
            "unknown",
        );
        let context = TokenContext::new(crate::code::language::TokenType::Unknown);
        self.correct_token(&token, &context)
    }

    /// Largest edit distance used by any enabled sub-corrector, and never less than 2.
    fn max_edit_distance(&self) -> usize {
        let mut max = 2;
        if let Some(ref lexical) = self.lexical {
            max = max.max(lexical.max_edit_distance());
        }
        if let Some(ref grammar) = self.grammar {
            max = max.max(grammar.max_edit_distance());
        }
        if let Some(ref semantic) = self.semantic {
            max = max.max(semantic.max_edit_distance());
        }
        max
    }

    fn name(&self) -> &str {
        "EnsembleCorrector"
    }
}
/// Fluent builder for `EnsembleCorrector` that can switch individual sub-correctors off
/// and override the configuration.
pub struct EnsembleCorrectorBuilder<L: CodeLanguage> {
    /// Language handed to every sub-corrector that gets built.
    language: Arc<L>,
    /// Grammar for the grammar corrector; without one, no grammar corrector is built.
    grammar: Option<WeightedCFG>,
    /// Configuration passed through unchanged to the built corrector.
    config: EnsembleCorrectorConfig,
    /// Build the lexical sub-corrector.
    enable_lexical: bool,
    /// Build the grammar sub-corrector (still requires `grammar`).
    enable_grammar: bool,
    /// Build the semantic sub-corrector.
    enable_semantic: bool,
}
impl<L: CodeLanguage + Clone> EnsembleCorrectorBuilder<L> {
    /// Starts a builder for `language` with the default config, all three sub-correctors
    /// switched on, and no grammar attached yet.
    pub fn new(language: Arc<L>) -> Self {
        Self {
            language,
            grammar: None,
            config: EnsembleCorrectorConfig::default(),
            enable_lexical: true,
            enable_grammar: true,
            enable_semantic: true,
        }
    }

    /// Supplies the grammar used to build the grammar corrector (unused if
    /// `without_grammar` is also called).
    pub fn with_grammar(self, grammar: WeightedCFG) -> Self {
        Self {
            grammar: Some(grammar),
            ..self
        }
    }

    /// Replaces the whole configuration, including any weights set earlier.
    pub fn with_config(self, config: EnsembleCorrectorConfig) -> Self {
        Self { config, ..self }
    }

    /// Skips building the lexical corrector.
    pub fn without_lexical(self) -> Self {
        Self {
            enable_lexical: false,
            ..self
        }
    }

    /// Skips building the grammar corrector, even if a grammar was supplied.
    pub fn without_grammar(self) -> Self {
        Self {
            enable_grammar: false,
            ..self
        }
    }

    /// Skips building the semantic corrector.
    pub fn without_semantic(self) -> Self {
        Self {
            enable_semantic: false,
            ..self
        }
    }

    /// Sets the weight applied to lexical suggestions.
    pub fn lexical_weight(self, weight: f64) -> Self {
        self.tweak_config(|config| config.lexical_weight = weight)
    }

    /// Sets the weight applied to grammar suggestions.
    pub fn grammar_weight(self, weight: f64) -> Self {
        self.tweak_config(|config| config.grammar_weight = weight)
    }

    /// Sets the weight applied to semantic suggestions.
    pub fn semantic_weight(self, weight: f64) -> Self {
        self.tweak_config(|config| config.semantic_weight = weight)
    }

    /// Applies `edit` to the pending configuration and hands the builder back.
    fn tweak_config(mut self, edit: impl FnOnce(&mut EnsembleCorrectorConfig)) -> Self {
        edit(&mut self.config);
        self
    }

    /// Builds the ensemble, creating each enabled sub-corrector via its `with_defaults`
    /// constructor. The grammar corrector exists only if it is enabled *and* a grammar
    /// was supplied.
    pub fn build(self) -> EnsembleCorrector<L> {
        let Self {
            language,
            grammar,
            config,
            enable_lexical,
            enable_grammar,
            enable_semantic,
        } = self;
        let lexical =
            enable_lexical.then(|| LexicalCorrector::with_defaults(Arc::clone(&language)));
        let grammar = grammar
            .filter(|_| enable_grammar)
            .map(|g| GrammarCorrector::with_defaults(Arc::clone(&language), g));
        let semantic =
            enable_semantic.then(|| SemanticCorrector::with_defaults(Arc::clone(&language)));
        EnsembleCorrector {
            language,
            config,
            lexical,
            grammar,
            semantic,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::code::language::{CommentSyntax, TokenType};

    /// Minimal stand-in language for these tests; it reuses the Rust tree-sitter grammar.
    #[derive(Debug, Clone, Default)]
    struct MockLanguage;

    impl CodeLanguage for MockLanguage {
        fn name(&self) -> &str {
            "mock"
        }

        fn display_name(&self) -> &str {
            "Mock"
        }

        fn tree_sitter_language(&self) -> tree_sitter::Language {
            tree_sitter_rust::LANGUAGE.into()
        }

        fn keywords(&self) -> &[&str] {
            &["if", "else", "while", "for", "return", "function"]
        }

        fn special_tokens(&self) -> &[&str] {
            &[]
        }

        fn file_extensions(&self) -> &[&str] {
            &["mock"]
        }

        fn classify_token(&self, _token: &str, _node_kind: &str) -> TokenType {
            // Every token is classified as Unknown.
            TokenType::Unknown
        }

        fn is_valid_identifier(&self, s: &str) -> bool {
            // Non-empty and starting with an alphabetic character.
            matches!(s.chars().next(), Some(first) if first.is_alphabetic())
        }

        fn builtin_types(&self) -> &[&str] {
            &["int", "string", "bool"]
        }

        fn stdlib_functions(&self) -> &[&str] {
            &["print", "read"]
        }

        fn comment_syntax(&self) -> CommentSyntax {
            CommentSyntax::default()
        }

        fn is_whitespace_significant(&self) -> bool {
            false
        }
    }

    #[test]
    fn test_ensemble_corrector_creation() {
        let corrector = EnsembleCorrector::with_defaults(Arc::new(MockLanguage), None);
        assert!(corrector.lexical.is_some());
        // No grammar was supplied, so no grammar corrector is built.
        assert!(corrector.grammar.is_none());
        assert!(corrector.semantic.is_some());
    }

    #[test]
    fn test_ensemble_builder() {
        let corrector = EnsembleCorrectorBuilder::new(Arc::new(MockLanguage))
            .without_semantic()
            .lexical_weight(0.6)
            .build();
        assert!(corrector.lexical.is_some());
        assert!(corrector.semantic.is_none());
        assert!((corrector.config.lexical_weight - 0.6).abs() < 0.01);
    }

    #[test]
    fn test_ensemble_correction() {
        let mut corrector = EnsembleCorrector::lexical_only(Arc::new(MockLanguage));
        corrector.add_identifiers(&["calculateTotal", "processData"]);
        // "funtion" is one insertion away from the mock keyword "function".
        let typo = CodeToken::new("funtion", 0, 1, 0, TokenType::Keyword, "keyword");
        let context = TokenContext::new(TokenType::Keyword);
        let suggestions = corrector.correct_token(&typo, &context);
        assert!(!suggestions.is_empty());
    }

    /// Only checks that the agreement-boost settings can be overridden through
    /// struct-update syntax; the boost itself is not exercised here.
    #[test]
    fn test_agreement_boost() {
        let config = EnsembleCorrectorConfig {
            agreement_boost: true,
            agreement_boost_factor: 1.5,
            ..Default::default()
        };
        assert!(config.agreement_boost);
        assert!((config.agreement_boost_factor - 1.5).abs() < 0.01);
    }
}