use super::language::TokenContext;
use super::tokenizer::CodeToken;
/// A single proposed edit to a source buffer.
///
/// Applying it replaces the bytes `start_byte..end_byte` of the source with `replacement`.
#[derive(Clone, Debug)]
pub struct Correction {
    /// Category of the fix.
    pub kind: CorrectionKind,
    /// Byte offset where the replaced span begins (inclusive).
    pub start_byte: usize,
    /// Byte offset where the replaced span ends (exclusive).
    pub end_byte: usize,
    /// Text the correction expects to replace; not verified against the source when applied.
    pub original: String,
    /// Text substituted for the span.
    pub replacement: String,
    /// Confidence score; `1.0` by default, clamped to `0.0..=1.0` by the builder.
    pub confidence: f64,
    /// Which analysis produced this correction.
    pub source: CorrectionSource,
    /// Optional free-form context attached by the producer.
    pub context: Option<String>,
}
impl Correction {
    /// Creates a correction that replaces bytes `start_byte..end_byte` with `replacement`.
    ///
    /// The result has confidence `1.0`, source [`CorrectionSource::Unknown`], and no context;
    /// use the `with_*` builders to change them.
    pub fn new(
        kind: CorrectionKind,
        start_byte: usize,
        end_byte: usize,
        original: impl Into<String>,
        replacement: impl Into<String>,
    ) -> Self {
        Self {
            kind,
            start_byte,
            end_byte,
            original: original.into(),
            replacement: replacement.into(),
            confidence: 1.0,
            source: CorrectionSource::Unknown,
            context: None,
        }
    }

    /// Sets the confidence, clamped to `0.0..=1.0`.
    ///
    /// `NaN` is stored as `0.0`: `f64::clamp` passes `NaN` through unchanged, and a `NaN`
    /// confidence cannot be meaningfully ranked against other candidates.
    pub fn with_confidence(mut self, confidence: f64) -> Self {
        self.confidence = if confidence.is_nan() {
            0.0
        } else {
            confidence.clamp(0.0, 1.0)
        };
        self
    }

    /// Sets which analysis produced this correction.
    pub fn with_source(mut self, source: CorrectionSource) -> Self {
        self.source = source;
        self
    }

    /// Attaches free-form context text.
    pub fn with_context(mut self, context: impl Into<String>) -> Self {
        self.context = Some(context.into());
        self
    }

    /// Edit distance between `original` and `replacement`.
    ///
    /// This is the optimal string alignment (restricted Damerau–Levenshtein) distance over
    /// Unicode scalar values. Inserting, deleting, or substituting one character costs 1, and
    /// so does swapping two adjacent characters. For example, `"pritn"` → `"print"` is 1 and
    /// `"abc"` → `"xyz"` is 3.
    pub fn edit_distance(&self) -> usize {
        if self.is_noop() {
            return 0;
        }
        osa_distance(&self.original, &self.replacement)
    }

    /// Returns `true` when `replacement` is identical to `original`.
    pub fn is_noop(&self) -> bool {
        self.original == self.replacement
    }

    /// Returns a copy of `source` with bytes `start_byte..end_byte` replaced by `replacement`.
    ///
    /// The `original` field is not compared against the text being replaced.
    ///
    /// # Panics
    ///
    /// Panics if any of the following holds:
    /// - `start_byte > end_byte`;
    /// - either offset is past the end of `source`;
    /// - either offset is not on a UTF-8 character boundary.
    pub fn apply(&self, source: &str) -> String {
        // Without this check an inverted span would silently duplicate the bytes between the
        // two offsets (both slices below would still be valid).
        assert!(
            self.start_byte <= self.end_byte,
            "correction span is inverted: start_byte {} > end_byte {}",
            self.start_byte,
            self.end_byte
        );
        let head = &source[..self.start_byte];
        let tail = &source[self.end_byte..];
        let mut result = String::with_capacity(head.len() + self.replacement.len() + tail.len());
        result.push_str(head);
        result.push_str(&self.replacement);
        result.push_str(tail);
        result
    }
}

/// Optimal string alignment distance between `a` and `b`, counted in `char`s.
///
/// Keeps three rolling rows of the dynamic-programming table, so extra memory is
/// proportional to the length of `b`.
fn osa_distance(a: &str, b: &str) -> usize {
    let a: Vec<char> = a.chars().collect();
    let b: Vec<char> = b.chars().collect();
    if a.is_empty() {
        return b.len();
    }
    if b.is_empty() {
        return a.len();
    }
    let width = b.len() + 1;
    // Row i-2 (needed for transpositions), row i-1, and the row being filled.
    let mut two_back: Vec<usize> = vec![0; width];
    let mut prev: Vec<usize> = (0..width).collect();
    let mut curr: Vec<usize> = vec![0; width];
    for i in 1..=a.len() {
        curr[0] = i;
        for j in 1..=b.len() {
            let cost = usize::from(a[i - 1] != b[j - 1]);
            let mut best = (prev[j] + 1)
                .min(curr[j - 1] + 1)
                .min(prev[j - 1] + cost);
            if i > 1 && j > 1 && a[i - 1] == b[j - 2] && a[i - 2] == b[j - 1] {
                best = best.min(two_back[j - 2] + 1);
            }
            curr[j] = best;
        }
        // Rotate: two_back <- prev, prev <- curr, curr <- old two_back (scratch).
        std::mem::swap(&mut two_back, &mut prev);
        std::mem::swap(&mut prev, &mut curr);
    }
    prev[b.len()]
}
/// Category of a [`Correction`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CorrectionKind {
    /// Spelling correction.
    Spelling,
    /// Insert a missing token.
    Insertion,
    /// Remove an extra token.
    Deletion,
    /// Replace a token.
    Replacement,
    /// Wrong variable name (a semantic kind).
    VariableMisuse,
    /// Type error (a semantic kind).
    TypeError,
    /// Missing import (a semantic kind).
    MissingImport,
    /// Syntax error.
    SyntaxError,
    /// Formatting.
    Formatting,
    /// Any other correction.
    Other,
}
impl CorrectionKind {
pub fn description(&self) -> &'static str {
match self {
CorrectionKind::Spelling => "Spelling correction",
CorrectionKind::Insertion => "Insert missing token",
CorrectionKind::Deletion => "Remove extra token",
CorrectionKind::Replacement => "Replace token",
CorrectionKind::VariableMisuse => "Wrong variable name",
CorrectionKind::TypeError => "Type error",
CorrectionKind::MissingImport => "Missing import",
CorrectionKind::SyntaxError => "Syntax error",
CorrectionKind::Formatting => "Formatting",
CorrectionKind::Other => "Other correction",
}
}
pub fn is_semantic(&self) -> bool {
matches!(
self,
CorrectionKind::VariableMisuse
| CorrectionKind::TypeError
| CorrectionKind::MissingImport
)
}
}
/// Identifies which analysis produced a [`Correction`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CorrectionSource {
    Lexical,
    Grammar,
    Neural,
    TypeInference,
    ControlFlow,
    DataFlow,
    Combined,
    /// Used by `Correction::new` until a source is set explicitly.
    Unknown,
}
/// A pluggable producer of [`Correction`]s.
///
/// The `Send + Sync` bound allows one corrector instance to be shared across threads.
pub trait CodeCorrector: Send + Sync {
    /// Name identifying this corrector.
    fn name(&self) -> &str;

    /// Proposes corrections for a single `token`, given its surrounding `context`.
    fn correct_token(&self, token: &CodeToken, context: &TokenContext) -> Vec<Correction>;

    /// Proposes corrections for the byte range `start_byte..end_byte` of `source`.
    ///
    /// NOTE(review): presumably half-open like `Correction`'s span — confirm with implementors.
    fn correct_range(&self, source: &str, start_byte: usize, end_byte: usize) -> Vec<Correction>;

    /// Largest edit distance this corrector is configured for; defaults to `2`.
    fn max_edit_distance(&self) -> usize {
        const DEFAULT_MAX_EDIT_DISTANCE: usize = 2;
        DEFAULT_MAX_EDIT_DISTANCE
    }
}
/// A bounded list of [`Correction`]s ordered from highest to lowest confidence.
#[derive(Clone, Debug)]
pub struct CorrectionCandidates {
    /// Kept sorted by descending confidence.
    corrections: Vec<Correction>,
    /// Upper bound on `corrections.len()`.
    max_candidates: usize,
}
impl CorrectionCandidates {
    /// Creates an empty list that retains at most `max_candidates` corrections.
    pub fn new(max_candidates: usize) -> Self {
        Self {
            corrections: Vec::new(),
            max_candidates,
        }
    }

    /// Inserts `correction` while keeping the list sorted by descending confidence.
    ///
    /// If the list would then exceed `max_candidates`, the lowest-ranked entry is dropped.
    /// A new correction is placed after existing ones with equal confidence, so ties keep
    /// insertion order. A `NaN` confidence ranks below every other value.
    pub fn add(&mut self, correction: Correction) {
        let key = Self::rank_key(&correction);
        // The list is always sorted by descending rank, so `rank >= key` holds for a prefix;
        // inserting at the end of that prefix avoids re-sorting on every add.
        let pos = self
            .corrections
            .partition_point(|existing| Self::rank_key(existing) >= key);
        if pos >= self.max_candidates {
            // It would be truncated away immediately.
            return;
        }
        self.corrections.insert(pos, correction);
        self.corrections.truncate(self.max_candidates);
    }

    /// Ranking key for a correction.
    ///
    /// The key is the confidence, except that `NaN` maps to negative infinity. This gives
    /// every key a total order.
    fn rank_key(correction: &Correction) -> f64 {
        if correction.confidence.is_nan() {
            f64::NEG_INFINITY
        } else {
            correction.confidence
        }
    }

    /// Adds each correction in turn, as if by repeated [`add`](Self::add).
    pub fn add_all(&mut self, corrections: impl IntoIterator<Item = Correction>) {
        for c in corrections {
            self.add(c);
        }
    }

    /// Highest-confidence correction, if any.
    pub fn best(&self) -> Option<&Correction> {
        self.corrections.first()
    }

    /// All retained corrections, highest confidence first.
    pub fn ranked(&self) -> &[Correction] {
        &self.corrections
    }

    /// Number of retained corrections.
    pub fn len(&self) -> usize {
        self.corrections.len()
    }

    /// Returns `true` if no corrections are retained.
    pub fn is_empty(&self) -> bool {
        self.corrections.is_empty()
    }

    /// Drops corrections whose confidence is below `min_confidence` (or is `NaN`).
    pub fn filter_by_confidence(&mut self, min_confidence: f64) {
        self.corrections.retain(|c| c.confidence >= min_confidence);
    }

    /// Keeps only corrections of the given `kind`.
    pub fn filter_by_kind(&mut self, kind: CorrectionKind) {
        self.corrections.retain(|c| c.kind == kind);
    }

    /// Keeps only corrections produced by the given `source`.
    pub fn filter_by_source(&mut self, source: CorrectionSource) {
        self.corrections.retain(|c| c.source == source);
    }
}
impl Default for CorrectionCandidates {
fn default() -> Self {
Self::new(10)
}
}
impl IntoIterator for CorrectionCandidates {
type Item = Correction;
type IntoIter = std::vec::IntoIter<Correction>;
fn into_iter(self) -> Self::IntoIter {
self.corrections.into_iter()
}
}
/// Borrows the list, yielding corrections from highest to lowest confidence.
impl<'a> IntoIterator for &'a CorrectionCandidates {
    type Item = &'a Correction;
    type IntoIter = std::slice::Iter<'a, Correction>;

    fn into_iter(self) -> Self::IntoIter {
        self.corrections.as_slice().iter()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a spelling correction over bytes `0..5`.
    fn spelling(original: &str, replacement: impl Into<String>) -> Correction {
        Correction::new(CorrectionKind::Spelling, 0, 5, original, replacement)
    }

    #[test]
    fn test_correction_apply() {
        let fixed = spelling("pritn", "print").apply("pritn(\"hello\")");
        assert_eq!(fixed, "print(\"hello\")");
    }

    #[test]
    fn test_correction_candidates_ranking() {
        let mut pool = CorrectionCandidates::new(3);
        let inputs: &[(&str, f64)] = &[("print", 0.9), ("prion", 0.5), ("pint", 0.7)];
        for &(word, score) in inputs {
            pool.add(spelling("pritn", word).with_confidence(score));
        }

        let ranked = pool.ranked();
        let expected = ["print", "pint", "prion"];
        for (idx, want) in expected.iter().enumerate() {
            assert_eq!(ranked[idx].replacement, *want);
        }
    }

    #[test]
    fn test_correction_candidates_max() {
        let mut pool = CorrectionCandidates::new(2);
        pool.add_all(
            (0..5).map(|i| spelling("test", format!("test{}", i)).with_confidence(i as f64 / 10.0)),
        );
        assert_eq!(pool.len(), 2);
    }

    #[test]
    fn test_correction_kind_semantic() {
        let cases = [
            (CorrectionKind::Spelling, false),
            (CorrectionKind::SyntaxError, false),
            (CorrectionKind::VariableMisuse, true),
            (CorrectionKind::TypeError, true),
        ];
        for &(kind, semantic) in &cases {
            assert_eq!(kind.is_semantic(), semantic, "{:?}", kind);
        }
    }
}