use crate::code::constrained_decoding::{ConstrainedDecodingConfig, GrammarConstraint};
use crate::code::correction::{CodeCorrector, Correction, CorrectionKind, CorrectionSource};
use crate::code::language::{CodeLanguage, TokenContext, TokenType};
use crate::code::pcfg::{Symbol, WeightedCFG};
use crate::code::tokenizer::CodeToken;
use std::collections::HashSet;
use std::sync::Arc;
/// Tuning knobs for `GrammarCorrector`.
#[derive(Debug, Clone)]
pub struct GrammarCorrectorConfig {
    /// Upper bound on the insertion and replacement candidates produced per query.
    pub max_candidates: usize,
    /// NOTE(review): not read by `GrammarCorrector` in this file — presumably meant as a
    /// floor on rule probability; confirm whether anything consumes it.
    pub min_rule_probability: f64,
    /// Whether insertion suggestions are generated at all.
    pub suggest_insertions: bool,
    /// Whether deletion suggestions are generated at all.
    pub suggest_deletions: bool,
    /// NOTE(review): not read by `GrammarCorrector` in this file; confirm intended use.
    pub max_lookahead: usize,
    /// Confidence every grammar suggestion is scaled from (multiplied by per-kind factors).
    pub base_confidence: f64,
}

impl Default for GrammarCorrectorConfig {
    /// Five candidates, insertions and deletions enabled, lookahead 3,
    /// minimum rule probability 0.01 and base confidence 0.8.
    fn default() -> Self {
        let (max_candidates, max_lookahead) = (5, 3);
        let (suggest_insertions, suggest_deletions) = (true, true);
        GrammarCorrectorConfig {
            base_confidence: 0.8,
            min_rule_probability: 0.01,
            max_candidates,
            max_lookahead,
            suggest_insertions,
            suggest_deletions,
        }
    }
}
/// Corrector that proposes insertions, deletions and replacements by consulting a
/// weighted context-free grammar through a [`GrammarConstraint`].
pub struct GrammarCorrector<L: CodeLanguage> {
    /// Language handle; in this file it is only exposed through [`GrammarCorrector::language`].
    language: Arc<L>,
    /// Tuning knobs for candidate generation and confidence scoring.
    config: GrammarCorrectorConfig,
    /// Grammar used both for constraint checking and for token probabilities.
    grammar: WeightedCFG,
}

impl<L: CodeLanguage> GrammarCorrector<L> {
    /// Minimum bigram similarity a grammar-valid token needs before it is offered
    /// as a replacement for the observed token.
    const MIN_REPLACEMENT_SIMILARITY: f64 = 0.3;

    /// Factor applied to `base_confidence` for deletion suggestions.
    const DELETION_CONFIDENCE_FACTOR: f64 = 0.7;

    /// Creates a corrector from a language handle, a grammar and an explicit config.
    pub fn new(language: Arc<L>, grammar: WeightedCFG, config: GrammarCorrectorConfig) -> Self {
        Self {
            language,
            config,
            grammar,
        }
    }

    /// Creates a corrector using [`GrammarCorrectorConfig::default`].
    pub fn with_defaults(language: Arc<L>, grammar: WeightedCFG) -> Self {
        Self::new(language, grammar, GrammarCorrectorConfig::default())
    }

    /// Builds a fresh [`GrammarConstraint`] over a clone of this corrector's grammar,
    /// using the default [`ConstrainedDecodingConfig`].
    ///
    /// Every call clones the grammar; the cost depends on `WeightedCFG`'s `Clone`.
    pub fn create_constraint(&self) -> GrammarConstraint {
        GrammarConstraint::new(self.grammar.clone(), ConstrainedDecodingConfig::default())
    }

    /// Returns the set of tokens the grammar accepts after `token_history`.
    ///
    /// The history is replayed through a fresh constraint. If any token in it is
    /// rejected, an empty set is returned.
    pub fn valid_next_tokens(&self, token_history: &[&str]) -> HashSet<String> {
        let mut constraint = self.create_constraint();
        // `all` stops at the first token the constraint refuses to advance over.
        let accepted = token_history.iter().all(|token| constraint.advance(token));
        if accepted {
            constraint.valid_tokens()
        } else {
            HashSet::new()
        }
    }

    /// Ranks the tokens valid after `context` by grammar probability.
    ///
    /// Returns at most `max_suggestions` `(token, probability)` pairs, highest
    /// probability first. Ties are broken by token text, so the result (and which
    /// tokens survive truncation) does not depend on `HashSet` iteration order.
    pub fn suggest_completions(
        &self,
        context: &[&str],
        max_suggestions: usize,
    ) -> Vec<(String, f64)> {
        let mut suggestions: Vec<(String, f64)> = self
            .valid_next_tokens(context)
            .into_iter()
            .map(|token| {
                let prob = self.token_probability(&token, context);
                (token, prob)
            })
            .collect();
        suggestions.sort_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
        suggestions.truncate(max_suggestions);
        suggestions
    }

    /// Highest `WeightedCFG::probability` among the rules whose right-hand side
    /// contains `token` as a terminal, or `0.0` if no rule emits it.
    ///
    /// `_context` is currently ignored.
    fn token_probability(&self, token: &str, _context: &[&str]) -> f64 {
        let mut best: f64 = 0.0;
        for (production, _) in self.grammar.iter_rules() {
            let emits_token = production
                .rhs
                .iter()
                .any(|symbol| matches!(symbol, Symbol::Terminal(t) if t == token));
            if emits_token {
                best = best.max(self.grammar.probability(production));
            }
        }
        best
    }

    /// Scans `tokens` left to right against the grammar and reports rejected tokens.
    ///
    /// A rejected token yields a [`SyntaxError`] carrying the tokens that were valid at
    /// that position. The error type is `UnexpectedToken` when nothing was valid there,
    /// and `InvalidToken` otherwise. Scanning stops at the first token the constraint
    /// cannot advance over.
    pub fn find_syntax_errors(&self, tokens: &[&str]) -> Vec<SyntaxError> {
        let mut constraint = self.create_constraint();
        let mut errors = Vec::new();
        for (position, token) in tokens.iter().enumerate() {
            if !constraint.is_valid_token(token) {
                let expected = constraint.valid_tokens();
                let error_type = if expected.is_empty() {
                    SyntaxErrorType::UnexpectedToken
                } else {
                    SyntaxErrorType::InvalidToken
                };
                errors.push(SyntaxError {
                    position,
                    token: token.to_string(),
                    expected,
                    error_type,
                });
            }
            // NOTE(review): if `advance` always fails for tokens `is_valid_token`
            // rejects, at most one error is ever reported — confirm intended.
            if !constraint.advance(token) {
                break;
            }
        }
        errors
    }

    /// Proposes zero-width insertions at `byte_position` of the tokens the grammar
    /// accepts after `context`, up to `config.max_candidates`.
    ///
    /// Confidence is `base_confidence` scaled by the token's grammar probability.
    /// `position` only feeds the human-readable context string; `_source` is unused.
    fn suggest_insertions(
        &self,
        position: usize,
        context: &[&str],
        _source: &str,
        byte_position: usize,
    ) -> Vec<Correction> {
        if !self.config.suggest_insertions {
            return Vec::new();
        }
        self.suggest_completions(context, self.config.max_candidates)
            .into_iter()
            .map(|(token, prob)| {
                Correction::new(
                    CorrectionKind::Insertion,
                    byte_position,
                    byte_position,
                    "",
                    &token,
                )
                .with_confidence(self.config.base_confidence * prob)
                .with_source(CorrectionSource::Grammar)
                .with_context(format!("Expected token at position {}", position))
            })
            .collect()
    }

    /// Proposes deleting `token` when it is not among `valid_tokens`.
    ///
    /// Returns at most one correction, spanning the token's bytes, with confidence
    /// `base_confidence * DELETION_CONFIDENCE_FACTOR`.
    fn suggest_deletions(
        &self,
        token: &CodeToken,
        valid_tokens: &HashSet<String>,
    ) -> Vec<Correction> {
        // An empty set never contains the token, so one membership test covers both cases.
        if !self.config.suggest_deletions || valid_tokens.contains(&token.text) {
            return Vec::new();
        }
        let end_byte = token.byte_offset + token.text.len();
        vec![Correction::new(
            CorrectionKind::Deletion,
            token.byte_offset,
            end_byte,
            &token.text,
            "",
        )
        .with_confidence(self.config.base_confidence * Self::DELETION_CONFIDENCE_FACTOR)
        .with_source(CorrectionSource::Grammar)
        .with_context("Unexpected token")]
    }

    /// Proposes replacing `token` with grammar-valid tokens that look similar to it.
    ///
    /// Candidates must reach `MIN_REPLACEMENT_SIMILARITY` bigram similarity. Confidence
    /// is `base_confidence * similarity * (0.5 + 0.5 * probability)`. At most
    /// `config.max_candidates` are returned, highest confidence first, ties broken by
    /// replacement text for deterministic output.
    fn suggest_replacements(
        &self,
        token: &CodeToken,
        valid_tokens: &HashSet<String>,
    ) -> Vec<Correction> {
        let mut candidates: Vec<(&str, f64)> = valid_tokens
            .iter()
            .filter(|&valid| valid != &token.text)
            .filter_map(|valid| {
                let similarity = self.string_similarity(&token.text, valid);
                if similarity < Self::MIN_REPLACEMENT_SIMILARITY {
                    return None;
                }
                let prob = self.token_probability(valid, &[]);
                let confidence = self.config.base_confidence * similarity * (0.5 + 0.5 * prob);
                Some((valid.as_str(), confidence))
            })
            .collect();
        // Rank and truncate before building corrections so only survivors are allocated.
        candidates.sort_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(b.0)));
        candidates.truncate(self.config.max_candidates);

        let end_byte = token.byte_offset + token.text.len();
        candidates
            .into_iter()
            .map(|(valid, confidence)| {
                Correction::new(
                    CorrectionKind::Replacement,
                    token.byte_offset,
                    end_byte,
                    &token.text,
                    valid,
                )
                .with_confidence(confidence)
                .with_source(CorrectionSource::Grammar)
                .with_context(format!("Grammar suggests '{}'", valid))
            })
            .collect()
    }

    /// Jaccard similarity of the character-bigram sets of `a` and `b`, in `[0.0, 1.0]`.
    ///
    /// Identical strings score `1.0`. A string shorter than two characters has no
    /// bigrams, so it scores `0.0` against any different string. As a result,
    /// single-character tokens never pass the replacement similarity threshold.
    fn string_similarity(&self, a: &str, b: &str) -> f64 {
        if a == b {
            return 1.0;
        }
        let bigrams_a = Self::char_bigrams(a);
        let bigrams_b = Self::char_bigrams(b);
        if bigrams_a.is_empty() || bigrams_b.is_empty() {
            return 0.0;
        }
        let shared = bigrams_a.intersection(&bigrams_b).count();
        // |A union B| = |A| + |B| - |A intersect B|; non-zero since both sets are non-empty.
        let total = bigrams_a.len() + bigrams_b.len() - shared;
        shared as f64 / total as f64
    }

    /// Set of adjacent character pairs in `s`; empty when `s` has fewer than two chars.
    fn char_bigrams(s: &str) -> HashSet<(char, char)> {
        s.chars().zip(s.chars().skip(1)).collect()
    }

    /// Grammar this corrector checks against.
    pub fn grammar(&self) -> &WeightedCFG {
        &self.grammar
    }

    /// Language handle this corrector was built with.
    pub fn language(&self) -> &L {
        &self.language
    }
}
impl<L: CodeLanguage + Send + Sync> CodeCorrector for GrammarCorrector<L> {
    /// Suggests grammar-based fixes for a single token.
    ///
    /// The token is checked only against the tokens the grammar accepts at the start
    /// of input (an empty history); `_context` is not consulted. If the token is not
    /// among them, replacement, deletion and insertion candidates are returned in that
    /// order. Otherwise the result is empty.
    fn correct_token(&self, token: &CodeToken, _context: &TokenContext) -> Vec<Correction> {
        // NOTE(review): ignoring surrounding tokens means any token that cannot *start*
        // the grammar is flagged — confirm this is intended.
        let valid_tokens = self.valid_next_tokens(&[]);
        if valid_tokens.contains(&token.text) {
            return Vec::new();
        }
        let mut corrections = self.suggest_replacements(token, &valid_tokens);
        corrections.extend(self.suggest_deletions(token, &valid_tokens));
        corrections.extend(self.suggest_insertions(0, &[], "", token.byte_offset));
        corrections
    }

    /// Corrects `source[start_byte..end_byte]` as a single token of unknown type.
    ///
    /// Returns no corrections, rather than panicking, when the range is reversed,
    /// out of bounds, or does not fall on UTF-8 character boundaries.
    fn correct_range(&self, source: &str, start_byte: usize, end_byte: usize) -> Vec<Correction> {
        let text = match source.get(start_byte..end_byte) {
            Some(text) => text,
            None => return Vec::new(),
        };
        let token = CodeToken::new(text, start_byte, 0, 0, TokenType::Unknown, "unknown");
        let context = TokenContext::new(TokenType::Unknown);
        self.correct_token(&token, &context)
    }

    /// Maximum edit distance reported for this corrector (fixed at 2).
    fn max_edit_distance(&self) -> usize {
        2
    }

    /// Identifier of this corrector.
    fn name(&self) -> &str {
        "GrammarCorrector"
    }
}
/// Category of a [`SyntaxError`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SyntaxErrorType {
    /// The token was rejected while other tokens were acceptable at that position.
    InvalidToken,
    /// The token was rejected and no token was acceptable at that position.
    UnexpectedToken,
    /// A token appears to be missing before the reported token.
    /// (Not produced by `GrammarCorrector::find_syntax_errors`.)
    MissingToken,
    /// A delimiter appears to be left open before the reported token.
    /// (Not produced by `GrammarCorrector::find_syntax_errors`.)
    UnclosedDelimiter,
}

/// A grammar violation found at one position of a token stream.
#[derive(Debug, Clone)]
pub struct SyntaxError {
    /// Index of the offending token in the scanned token slice.
    pub position: usize,
    /// Text of the offending token.
    pub token: String,
    /// Tokens the grammar would have accepted at `position`.
    pub expected: HashSet<String>,
    /// Category of the error.
    pub error_type: SyntaxErrorType,
}

impl SyntaxError {
    /// Largest `expected` set whose members are listed individually in the message.
    const MAX_LISTED_EXPECTED: usize = 3;

    /// Human-readable description of the error.
    ///
    /// For `InvalidToken`, up to `MAX_LISTED_EXPECTED` expected tokens are listed in
    /// sorted order so the text is reproducible. Larger sets are summarized by count.
    pub fn message(&self) -> String {
        match self.error_type {
            SyntaxErrorType::InvalidToken => self.invalid_token_message(),
            SyntaxErrorType::UnexpectedToken => format!("Unexpected token '{}'", self.token),
            SyntaxErrorType::MissingToken => format!("Missing token before '{}'", self.token),
            SyntaxErrorType::UnclosedDelimiter => {
                format!("Unclosed delimiter before '{}'", self.token)
            }
        }
    }

    /// Builds the `InvalidToken` message from the `expected` set.
    fn invalid_token_message(&self) -> String {
        let count = self.expected.len();
        if count == 0 {
            // Nothing to list; avoid a dangling "expected one of: ".
            return format!("Invalid token '{}'", self.token);
        }
        if count > Self::MAX_LISTED_EXPECTED {
            return format!(
                "Invalid token '{}' (expected {} possible tokens)",
                self.token, count
            );
        }
        // HashSet iteration order is randomized per process; sort for stable output.
        let mut expected: Vec<&str> = self.expected.iter().map(String::as_str).collect();
        expected.sort_unstable();
        let listed = expected
            .iter()
            .map(|s| format!("'{}'", s))
            .collect::<Vec<_>>()
            .join(", ");
        format!("Invalid token '{}', expected one of: {}", self.token, listed)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::code::pcfg::Production;

    /// Shorthand for a terminal grammar symbol.
    fn term(text: &str) -> Symbol {
        Symbol::Terminal(text.to_string())
    }

    /// Shorthand for a non-terminal grammar symbol.
    fn nonterm(name: &str) -> Symbol {
        Symbol::NonTerminal(name.to_string())
    }

    /// Toy statement grammar: `if (expr) stmt`, `while (expr) stmt`, `return expr;`
    /// and `expr;`, where `expr` is either `x` or `y`.
    fn create_test_grammar() -> WeightedCFG {
        let rules = vec![
            (
                "stmt",
                vec![term("if"), term("("), nonterm("expr"), term(")"), nonterm("stmt")],
                0.3,
            ),
            (
                "stmt",
                vec![term("while"), term("("), nonterm("expr"), term(")"), nonterm("stmt")],
                0.2,
            ),
            ("stmt", vec![term("return"), nonterm("expr"), term(";")], 0.3),
            ("stmt", vec![nonterm("expr"), term(";")], 0.2),
            ("expr", vec![term("x")], 0.5),
            ("expr", vec![term("y")], 0.5),
        ];
        let mut cfg = WeightedCFG::new("stmt");
        for (lhs, rhs, weight) in rules {
            cfg.add_rule(Production::new(lhs, rhs), weight);
        }
        cfg
    }

    /// Minimal language stub; only the grammar matters for these tests.
    #[derive(Debug, Clone, Default)]
    struct MockLanguage;

    impl CodeLanguage for MockLanguage {
        fn name(&self) -> &str {
            "mock"
        }
        fn display_name(&self) -> &str {
            "Mock"
        }
        fn tree_sitter_language(&self) -> tree_sitter::Language {
            tree_sitter_rust::LANGUAGE.into()
        }
        fn keywords(&self) -> &[&str] {
            &["if", "else", "while", "return"]
        }
        fn special_tokens(&self) -> &[&str] {
            &[]
        }
        fn file_extensions(&self) -> &[&str] {
            &["mock"]
        }
        fn classify_token(&self, _token: &str, _node_kind: &str) -> TokenType {
            TokenType::Unknown
        }
        fn is_valid_identifier(&self, s: &str) -> bool {
            !s.is_empty()
        }
        fn builtin_types(&self) -> &[&str] {
            &[]
        }
        fn stdlib_functions(&self) -> &[&str] {
            &[]
        }
        fn comment_syntax(&self) -> crate::code::language::CommentSyntax {
            crate::code::language::CommentSyntax::default()
        }
        fn is_whitespace_significant(&self) -> bool {
            false
        }
    }

    /// Corrector over the toy grammar with default settings.
    fn make_corrector() -> GrammarCorrector<MockLanguage> {
        GrammarCorrector::with_defaults(Arc::new(MockLanguage), create_test_grammar())
    }

    #[test]
    fn test_grammar_corrector_valid_tokens() {
        let valid = make_corrector().valid_next_tokens(&[]);
        let starts_statement = ["if", "while", "return"]
            .iter()
            .any(|keyword| valid.contains(*keyword));
        assert!(starts_statement);
    }

    #[test]
    fn test_grammar_corrector_completions() {
        let completions = make_corrector().suggest_completions(&["if", "("], 5);
        assert!(!completions.is_empty());
    }

    #[test]
    fn test_syntax_error_message() {
        let expected: HashSet<String> = ["if", "while"].iter().map(|s| s.to_string()).collect();
        let error = SyntaxError {
            position: 0,
            token: "fi".to_string(),
            expected,
            error_type: SyntaxErrorType::InvalidToken,
        };
        let rendered = error.message();
        assert!(rendered.contains("Invalid token"));
        assert!(rendered.contains("fi"));
    }
}