mod config;
mod pattern;
mod syntax;
pub use config::{CodeCorrectionConfig, CodeCorrectionLanguage};
pub use pattern::{PatternAwareConfig, PatternAwareLayer, PatternBoost};
pub use syntax::{RecoveryStrategy, SyntaxRecoveryConfig, SyntaxRecoveryLayer};
use std::marker::PhantomData;
use crate::backend::LatticeBackend;
use crate::lattice::Lattice;
use crate::semiring::{Semiring, TropicalWeight};
use super::{CorrectionLayer, LayerResult, LayerStats};
/// Correction layer for source code that combines a syntax-recovery
/// sub-layer with an optional pattern-aware sub-layer.
///
/// `W` is the lattice weight type and `B` the lattice backend; neither is
/// stored directly, they only select which lattices the layer operates on.
pub struct CodeCorrectionLayer<W, B>
where
    W: Semiring,
    B: LatticeBackend,
{
    /// Configuration the layer was built from.
    config: CodeCorrectionConfig,
    /// Syntax-recovery sub-layer; always present.
    syntax_layer: SyntaxRecoveryLayer,
    /// Pattern-aware sub-layer; `None` when no pattern configuration was given.
    pattern_layer: Option<PatternAwareLayer>,
    /// Marker tying the otherwise unused type parameters to the struct.
    _phantom: PhantomData<(W, B)>,
}
impl<W: Semiring, B: LatticeBackend> CodeCorrectionLayer<W, B> {
    /// Builds a layer from `config`.
    ///
    /// The syntax sub-layer is created from `config.syntax_config`, falling
    /// back to that type's `Default` when it is `None`. The pattern sub-layer
    /// is created only when `config.pattern_config` is `Some`.
    pub fn new(config: CodeCorrectionConfig) -> Self {
        let syntax_settings = config.syntax_config.clone().unwrap_or_default();
        let syntax_layer = SyntaxRecoveryLayer::new(syntax_settings);

        let pattern_layer = config
            .pattern_config
            .as_ref()
            .map(|settings| PatternAwareLayer::new(settings.clone()));

        Self {
            config,
            syntax_layer,
            pattern_layer,
            _phantom: PhantomData,
        }
    }

    /// Convenience constructor: builds the configuration via
    /// `CodeCorrectionConfig::new(language)` and forwards it to [`Self::new`].
    pub fn for_language(language: &str) -> Self {
        let config = CodeCorrectionConfig::new(language);
        Self::new(config)
    }

    /// Returns the configuration this layer was built from.
    pub fn config(&self) -> &CodeCorrectionConfig {
        &self.config
    }

    /// Returns the syntax-recovery sub-layer.
    pub fn syntax_layer(&self) -> &SyntaxRecoveryLayer {
        &self.syntax_layer
    }

    /// Returns the pattern-aware sub-layer, if one was configured.
    pub fn pattern_layer(&self) -> Option<&PatternAwareLayer> {
        self.pattern_layer.as_ref()
    }

    /// Reports whether a pattern-aware sub-layer is present.
    pub fn has_patterns(&self) -> bool {
        matches!(self.pattern_layer, Some(_))
    }
}
impl<W, B> CorrectionLayer<W, B> for CodeCorrectionLayer<W, B>
where
    W: Semiring + From<TropicalWeight>,
    B: LatticeBackend + Clone,
{
    /// Fixed identifier reported for this layer.
    fn name(&self) -> &str {
        "code-correction"
    }

    /// Runs the syntax sub-layer and then, when configured, the pattern
    /// sub-layer on its output.
    ///
    /// An empty lattice is returned as an unchanged clone without invoking
    /// either sub-layer. Errors from either sub-layer are propagated.
    fn apply(&self, lattice: &Lattice<W, B>) -> LayerResult<Lattice<W, B>> {
        if lattice.is_empty() {
            return Ok(lattice.clone());
        }

        let recovered = self.syntax_layer.apply(lattice)?;

        // Without a pattern sub-layer the syntax output is the final result.
        let patterns = match self.pattern_layer.as_ref() {
            Some(patterns) => patterns,
            None => return Ok(recovered),
        };

        let refined = patterns.apply(&recovered)?;
        Ok(refined)
    }

    /// This layer never declines a lattice.
    fn can_apply(&self, _lattice: &Lattice<W, B>) -> bool {
        true
    }

    /// Product of the syntax sub-layer's estimated expansion and the pattern
    /// sub-layer's estimated reduction (taken as `1.0` when there is no
    /// pattern sub-layer).
    fn estimated_reduction(&self) -> f64 {
        let syntax_factor = self.syntax_layer.estimated_expansion();
        let pattern_factor = match &self.pattern_layer {
            Some(patterns) => patterns.estimated_reduction(),
            None => 1.0,
        };
        syntax_factor * pattern_factor
    }

    /// Same as [`apply`](CorrectionLayer::apply), additionally returning edge
    /// counts before and after plus the elapsed wall-clock time.
    ///
    /// Path counts are not computed here; both are reported as `0`.
    fn apply_with_stats(
        &self,
        lattice: &Lattice<W, B>,
    ) -> LayerResult<(Lattice<W, B>, LayerStats)> {
        let timer = std::time::Instant::now();
        let input_edges = lattice.num_edges();

        let corrected = self.apply(lattice)?;

        let output_edges = corrected.num_edges();
        // `as_micros` yields `u128`; the value is narrowed to `u64` here.
        let time_us = timer.elapsed().as_micros() as u64;

        Ok((
            corrected,
            LayerStats {
                input_paths: 0,
                output_paths: 0,
                input_edges,
                output_edges,
                time_us,
            },
        ))
    }
}
/// Manual `Clone` so that `W` and `B` themselves are not required to be
/// `Clone`; only the stored configuration and sub-layers are cloned.
impl<W: Semiring, B: LatticeBackend> Clone for CodeCorrectionLayer<W, B> {
    fn clone(&self) -> Self {
        // Exhaustive destructuring: adding a field without handling it here
        // becomes a compile error.
        let Self {
            config,
            syntax_layer,
            pattern_layer,
            _phantom: _,
        } = self;

        Self {
            config: config.clone(),
            syntax_layer: syntax_layer.clone(),
            pattern_layer: pattern_layer.clone(),
            _phantom: PhantomData,
        }
    }
}
/// Debug output shows the configuration and whether a pattern sub-layer is
/// present; the sub-layers themselves are not printed.
impl<W: Semiring, B: LatticeBackend> std::fmt::Debug for CodeCorrectionLayer<W, B> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("CodeCorrectionLayer");
        out.field("config", &self.config);
        out.field("has_patterns", &self.has_patterns());
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::HashMapBackend;
    use crate::lattice::{EdgeMetadata, LatticeBuilder};
    use crate::semiring::TropicalWeight;

    /// Concrete instantiation of the layer used by every test.
    type TestLayer = CodeCorrectionLayer<TropicalWeight, HashMapBackend>;

    /// Layer built with `for_language("python")`.
    fn python_layer() -> TestLayer {
        CodeCorrectionLayer::for_language("python")
    }

    /// Builds a linear five-edge lattice over the tokens
    /// `def`, `foo`, `(`, `)`, `:` connecting nodes 0 through 5.
    fn build_simple_lattice() -> Lattice<TropicalWeight, HashMapBackend> {
        let mut interner = HashMapBackend::new();
        let def = interner.intern("def");
        let foo = interner.intern("foo");
        let lparen = interner.intern("(");
        let rparen = interner.intern(")");
        let colon = interner.intern(":");

        let mut builder: LatticeBuilder<TropicalWeight, _> = LatticeBuilder::new(interner);
        let edges = vec![
            (0, 1, def),
            (1, 2, foo),
            (2, 3, lparen),
            (3, 4, rparen),
            (4, 5, colon),
        ];
        for (from, to, token) in edges {
            builder.add_correction_by_id(
                from,
                to,
                token,
                TropicalWeight::one(),
                EdgeMetadata::default(),
            );
        }
        builder.build(5)
    }

    #[test]
    fn test_code_correction_layer_creation() {
        let layer: TestLayer = CodeCorrectionLayer::new(CodeCorrectionConfig::new("python"));
        assert_eq!(layer.config().language.as_str(), "python");
    }

    #[test]
    fn test_code_correction_layer_for_language() {
        let layer: TestLayer = CodeCorrectionLayer::for_language("rust");
        assert_eq!(layer.config().language.as_str(), "rust");
    }

    #[test]
    fn test_code_correction_layer_name() {
        let layer = python_layer();
        let reported =
            <TestLayer as CorrectionLayer<TropicalWeight, HashMapBackend>>::name(&layer);
        assert_eq!(reported, "code-correction");
    }

    #[test]
    fn test_code_correction_layer_apply() {
        let layer = python_layer();
        let lattice = build_simple_lattice();

        let outcome = layer.apply(&lattice);
        assert!(outcome.is_ok());

        let corrected = outcome.expect("should apply");
        assert!(corrected.num_edges() >= lattice.num_edges());
    }

    #[test]
    fn test_code_correction_layer_empty_lattice() {
        let layer = python_layer();
        let empty_lattice =
            LatticeBuilder::<TropicalWeight, _>::new(HashMapBackend::new()).build(0);

        let outcome = layer.apply(&empty_lattice);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_code_correction_layer_with_stats() {
        let layer = python_layer();
        let lattice = build_simple_lattice();

        let outcome = layer.apply_with_stats(&lattice);
        assert!(outcome.is_ok());

        let (corrected, stats) = outcome.expect("should apply");
        assert_eq!(stats.input_edges, 5);
        assert!(stats.output_edges >= 5);
        assert!(corrected.num_edges() >= 5);
    }

    #[test]
    fn test_code_correction_layer_clone() {
        let original: TestLayer = CodeCorrectionLayer::new(
            CodeCorrectionConfig::new("python").with_max_corrections(10),
        );
        let cloned = original.clone();
        assert_eq!(cloned.config().max_corrections_per_token, 10);
    }

    #[test]
    fn test_code_correction_layer_debug() {
        let layer = python_layer();
        let debug_str = format!("{:?}", layer);
        for needle in ["CodeCorrectionLayer", "has_patterns"].iter().copied() {
            assert!(debug_str.contains(needle));
        }
    }
}