use crate::code::correction::{CodeCorrector, Correction, CorrectionKind, CorrectionSource};
use crate::code::language::{CodeLanguage, TokenContext, TokenType};
use crate::code::tokenizer::CodeToken;
use std::collections::HashSet;
use std::sync::Arc;
#[derive(Debug, Clone)]
pub struct LexicalCorrectorConfig {
pub max_edit_distance: usize,
pub min_token_length: usize,
pub max_candidates: usize,
pub edit_penalty: f64,
}
impl Default for LexicalCorrectorConfig {
fn default() -> Self {
Self {
max_edit_distance: 2,
min_token_length: 2,
max_candidates: 5,
edit_penalty: 0.15,
}
}
}
#[derive(Debug, Clone)]
struct FuzzyCandidate {
term: String,
distance: usize,
}
pub struct LexicalCorrector<L: CodeLanguage> {
language: Arc<L>,
config: LexicalCorrectorConfig,
keywords: HashSet<String>,
identifiers: HashSet<String>,
types: HashSet<String>,
stdlib: HashSet<String>,
}
impl<L: CodeLanguage> LexicalCorrector<L> {
pub fn new(language: Arc<L>, config: LexicalCorrectorConfig) -> Self {
let mut keywords = HashSet::new();
let mut types = HashSet::new();
let mut stdlib = HashSet::new();
for keyword in language.keywords() {
keywords.insert(keyword.to_string());
}
for typ in language.builtin_types() {
types.insert(typ.to_string());
}
for func in language.stdlib_functions() {
stdlib.insert(func.to_string());
}
Self {
language,
config,
keywords,
identifiers: HashSet::new(),
types,
stdlib,
}
}
pub fn with_defaults(language: Arc<L>) -> Self {
Self::new(language, LexicalCorrectorConfig::default())
}
pub fn add_identifier(&mut self, identifier: &str) {
if self.language.is_valid_identifier(identifier) {
self.identifiers.insert(identifier.to_string());
}
}
pub fn add_identifiers_from_source(&mut self, source: &str) {
for word in source.split(|c: char| !c.is_alphanumeric() && c != '_') {
if !word.is_empty() && self.language.is_valid_identifier(word) {
self.identifiers.insert(word.to_string());
}
}
}
pub fn add_identifiers_from_tokens(&mut self, tokens: &[CodeToken]) {
for token in tokens {
if token.token_type == TokenType::Identifier {
self.add_identifier(&token.text);
}
}
}
fn levenshtein_distance(a: &str, b: &str) -> usize {
let a_chars: Vec<char> = a.chars().collect();
let b_chars: Vec<char> = b.chars().collect();
let m = a_chars.len();
let n = b_chars.len();
if m == 0 {
return n;
}
if n == 0 {
return m;
}
let mut dp = vec![vec![0usize; n + 1]; m + 1];
for i in 0..=m {
dp[i][0] = i;
}
for j in 0..=n {
dp[0][j] = j;
}
for i in 1..=m {
for j in 1..=n {
let cost = if a_chars[i - 1] == b_chars[j - 1] {
0
} else {
1
};
dp[i][j] = (dp[i - 1][j] + 1)
.min(dp[i][j - 1] + 1)
.min(dp[i - 1][j - 1] + cost);
}
}
dp[m][n]
}
fn fuzzy_search(&self, query: &str, dictionary: &HashSet<String>) -> Vec<FuzzyCandidate> {
let max_dist = self.config.max_edit_distance;
let mut candidates = Vec::new();
for term in dictionary {
let len_diff = (query.len() as isize - term.len() as isize).unsigned_abs();
if len_diff > max_dist {
continue;
}
let distance = Self::levenshtein_distance(query, term);
if distance > 0 && distance <= max_dist {
candidates.push(FuzzyCandidate {
term: term.clone(),
distance,
});
}
}
candidates.sort_by_key(|c| c.distance);
candidates
}
fn get_candidates(&self, token: &str, token_type: TokenType) -> Vec<FuzzyCandidate> {
match token_type {
TokenType::Keyword => self.fuzzy_search(token, &self.keywords),
TokenType::TypeName => self.fuzzy_search(token, &self.types),
TokenType::Identifier => {
let mut candidates = self.fuzzy_search(token, &self.identifiers);
candidates.extend(self.fuzzy_search(token, &self.stdlib));
let mut seen = HashSet::new();
candidates.retain(|c| seen.insert(c.term.clone()));
candidates.sort_by_key(|c| c.distance);
candidates
}
_ => {
let mut candidates = Vec::new();
candidates.extend(self.fuzzy_search(token, &self.keywords));
candidates.extend(self.fuzzy_search(token, &self.identifiers));
candidates.extend(self.fuzzy_search(token, &self.types));
candidates.extend(self.fuzzy_search(token, &self.stdlib));
let mut seen = HashSet::new();
candidates.retain(|c| seen.insert(c.term.clone()));
candidates.sort_by_key(|c| c.distance);
candidates
}
}
}
fn candidate_to_correction(&self, candidate: &FuzzyCandidate, token: &CodeToken) -> Correction {
let distance = candidate.distance as f64;
let confidence = 1.0 - (distance * self.config.edit_penalty).min(0.9);
let end_byte = token.byte_offset + token.text.len();
Correction::new(
CorrectionKind::Spelling,
token.byte_offset,
end_byte,
&token.text,
&candidate.term,
)
.with_confidence(confidence)
.with_source(CorrectionSource::Lexical)
.with_context(format!("Edit distance: {}", candidate.distance))
}
pub fn language(&self) -> &L {
&self.language
}
pub fn config(&self) -> &LexicalCorrectorConfig {
&self.config
}
pub fn keyword_count(&self) -> usize {
self.keywords.len()
}
pub fn identifier_count(&self) -> usize {
self.identifiers.len()
}
}
impl<L: CodeLanguage + Send + Sync> CodeCorrector for LexicalCorrector<L> {
fn correct_token(&self, token: &CodeToken, _context: &TokenContext) -> Vec<Correction> {
if token.text.len() < self.config.min_token_length {
return vec![];
}
let token_type = token.token_type;
let candidates = self.get_candidates(&token.text, token_type);
candidates
.into_iter()
.take(self.config.max_candidates)
.map(|c| self.candidate_to_correction(&c, token))
.collect()
}
fn correct_range(&self, source: &str, start_byte: usize, end_byte: usize) -> Vec<Correction> {
let text = &source[start_byte..end_byte];
let token = CodeToken::new(text, start_byte, 0, 0, TokenType::Unknown, "unknown");
let context = TokenContext::new(TokenType::Unknown);
self.correct_token(&token, &context)
}
fn max_edit_distance(&self) -> usize {
self.config.max_edit_distance
}
fn name(&self) -> &str {
"LexicalCorrector"
}
}
#[cfg(test)]
mod tests {
use super::*;
#[derive(Debug, Clone, Default)]
struct MockLanguage;
impl CodeLanguage for MockLanguage {
fn name(&self) -> &str {
"mock"
}
fn display_name(&self) -> &str {
"Mock"
}
fn tree_sitter_language(&self) -> tree_sitter::Language {
tree_sitter_rust::LANGUAGE.into()
}
fn keywords(&self) -> &[&str] {
&[
"if", "else", "while", "for", "return", "function", "let", "const", "var",
]
}
fn special_tokens(&self) -> &[&str] {
&[]
}
fn file_extensions(&self) -> &[&str] {
&["mock"]
}
fn classify_token(&self, _token: &str, _node_kind: &str) -> TokenType {
TokenType::Unknown
}
fn is_valid_identifier(&self, s: &str) -> bool {
!s.is_empty() && s.chars().next().map(|c| c.is_alphabetic()).unwrap_or(false)
}
fn builtin_types(&self) -> &[&str] {
&["int", "string", "bool", "float"]
}
fn stdlib_functions(&self) -> &[&str] {
&["print", "println", "read", "write"]
}
fn comment_syntax(&self) -> crate::code::language::CommentSyntax {
crate::code::language::CommentSyntax::default()
}
fn is_whitespace_significant(&self) -> bool {
false
}
}
#[test]
fn test_levenshtein_distance() {
assert_eq!(
LexicalCorrector::<MockLanguage>::levenshtein_distance("", ""),
0
);
assert_eq!(
LexicalCorrector::<MockLanguage>::levenshtein_distance("abc", ""),
3
);
assert_eq!(
LexicalCorrector::<MockLanguage>::levenshtein_distance("", "abc"),
3
);
assert_eq!(
LexicalCorrector::<MockLanguage>::levenshtein_distance("abc", "abc"),
0
);
assert_eq!(
LexicalCorrector::<MockLanguage>::levenshtein_distance("abc", "abd"),
1
);
assert_eq!(
LexicalCorrector::<MockLanguage>::levenshtein_distance("function", "funtion"),
1
);
}
#[test]
fn test_lexical_corrector_keywords() {
let lang = Arc::new(MockLanguage);
let corrector = LexicalCorrector::with_defaults(lang);
let token = CodeToken::new(
"funtion", 0,
1,
0,
TokenType::Keyword,
"keyword",
);
let context = TokenContext::new(TokenType::Keyword);
let corrections = corrector.correct_token(&token, &context);
assert!(!corrections.is_empty());
assert!(corrections.iter().any(|c| c.replacement == "function"));
}
#[test]
fn test_lexical_corrector_identifiers() {
let lang = Arc::new(MockLanguage);
let mut corrector = LexicalCorrector::with_defaults(lang);
corrector.add_identifier("calculateTotal");
corrector.add_identifier("processData");
corrector.add_identifier("handleError");
let token = CodeToken::new(
"calulateTotal", 0,
1,
0,
TokenType::Identifier,
"identifier",
);
let context = TokenContext::new(TokenType::Identifier);
let corrections = corrector.correct_token(&token, &context);
assert!(!corrections.is_empty());
assert!(corrections
.iter()
.any(|c| c.replacement == "calculateTotal"));
}
#[test]
fn test_lexical_corrector_exact_match() {
let lang = Arc::new(MockLanguage);
let corrector = LexicalCorrector::with_defaults(lang);
let token = CodeToken::new(
"function", 0,
1,
0,
TokenType::Keyword,
"keyword",
);
let context = TokenContext::new(TokenType::Keyword);
let corrections = corrector.correct_token(&token, &context);
assert!(corrections.is_empty() || corrections.iter().all(|c| c.replacement != "function"));
}
}