use std::sync::Arc;
use crate::backend::LatticeBackend;
use crate::cfg::EarleyParser;
use crate::lattice::{Lattice, LatticeBuilder};
use crate::semiring::Semiring;
use super::grammar::LatexGrammar;
use super::repair::{CompositeRepairStrategy, RepairStrategy, RepairSuggestion};
use super::validator::{LatexValidator, ValidationResult};
use crate::layers::traits::{CorrectionLayer, LayerError, LayerResult};
/// Tuning knobs for [`LatexSyntaxLayer`].
///
/// Start from [`Default::default`] for balanced settings, or use one of the
/// presets [`LatexSyntaxConfig::strict`], [`LatexSyntaxConfig::lenient`] or
/// [`LatexSyntaxConfig::minimal`].
#[derive(Debug, Clone, PartialEq)]
pub struct LatexSyntaxConfig {
    /// Keep only lattice edges used by a successful grammar parse. When set,
    /// a failed parse or a validation result with errors makes `apply`
    /// return a `LayerError`; when unset, the lattice passes through unpruned.
    pub prune_ungrammatical: bool,
    /// Run the structural validator over the tokens of the (possibly pruned)
    /// lattice.
    pub validate_structure: bool,
    /// Produce repair suggestions when validation reports the input invalid.
    pub generate_repairs: bool,
    /// Maximum number of suggestions kept for each validation issue.
    pub max_repairs_per_issue: usize,
    /// Request automatic application of high-confidence repairs.
    /// NOTE(review): not read anywhere in this module — confirm whether
    /// another component consumes it.
    pub auto_repair: bool,
    /// Confidence a repair would need before being applied automatically.
    /// NOTE(review): not read anywhere in this module (see `auto_repair`).
    pub auto_repair_threshold: f32,
}

impl Default for LatexSyntaxConfig {
    /// Balanced settings: prune, validate and suggest up to 3 repairs per
    /// issue. Auto-repair is off, with a threshold of 0.9.
    fn default() -> Self {
        Self {
            prune_ungrammatical: true,
            validate_structure: true,
            generate_repairs: true,
            max_repairs_per_issue: 3,
            auto_repair: false,
            auto_repair_threshold: 0.9,
        }
    }
}

impl LatexSyntaxConfig {
    /// Like the default, but keeps up to 5 repairs per issue and raises the
    /// auto-repair threshold to 0.95.
    pub fn strict() -> Self {
        Self {
            prune_ungrammatical: true,
            validate_structure: true,
            generate_repairs: true,
            max_repairs_per_issue: 5,
            auto_repair: false,
            auto_repair_threshold: 0.95,
        }
    }

    /// Never prunes, so parse and validation failures are not fatal.
    /// Still validates and suggests up to 3 repairs per issue, and turns
    /// auto-repair on at a 0.85 threshold.
    pub fn lenient() -> Self {
        Self {
            prune_ungrammatical: false,
            validate_structure: true,
            generate_repairs: true,
            max_repairs_per_issue: 3,
            auto_repair: true,
            auto_repair_threshold: 0.85,
        }
    }

    /// Grammar pruning only: no structural validation and no repairs.
    pub fn minimal() -> Self {
        Self {
            prune_ungrammatical: true,
            validate_structure: false,
            generate_repairs: false,
            max_repairs_per_issue: 0,
            auto_repair: false,
            auto_repair_threshold: 1.0,
        }
    }
}
/// Correction layer that prunes a lattice to the edges used by a LaTeX
/// grammar parse and can validate the result and suggest repairs.
pub struct LatexSyntaxLayer {
// Grammar whose Earley parse decides which lattice edges survive pruning.
grammar: LatexGrammar,
// Structural checker run over the lattice's token labels.
validator: LatexValidator,
// Source of repair suggestions; `None` means no repairs are generated.
repair_strategy: Option<Arc<dyn RepairStrategy>>,
// Controls pruning, validation and repair behaviour.
config: LatexSyntaxConfig,
// Suggestions from the most recent `apply`; behind a Mutex because `apply`
// only has `&self`.
last_repairs: std::sync::Mutex<Vec<RepairSuggestion>>,
}
impl LatexSyntaxLayer {
    /// Creates a layer with [`LatexSyntaxConfig::default`] settings.
    ///
    /// Equivalent to `with_config(grammar, LatexSyntaxConfig::default())`.
    /// The default config enables repair generation, so the full
    /// [`CompositeRepairStrategy`] is installed.
    pub fn new(grammar: LatexGrammar) -> Self {
        Self::with_config(grammar, LatexSyntaxConfig::default())
    }

    /// Creates a layer with an explicit configuration.
    ///
    /// The full [`CompositeRepairStrategy`] is installed only when
    /// `config.generate_repairs` is `true`; otherwise no strategy is set.
    pub fn with_config(grammar: LatexGrammar, config: LatexSyntaxConfig) -> Self {
        let repair_strategy = config
            .generate_repairs
            .then(|| Arc::new(CompositeRepairStrategy::all()) as Arc<dyn RepairStrategy>);
        Self {
            grammar,
            validator: LatexValidator::new(),
            repair_strategy,
            config,
            last_repairs: std::sync::Mutex::new(Vec::new()),
        }
    }

    /// Replaces the structural validator.
    pub fn with_validator(mut self, validator: LatexValidator) -> Self {
        self.validator = validator;
        self
    }

    /// Replaces the repair strategy.
    ///
    /// This does not touch `config.generate_repairs` or
    /// `config.max_repairs_per_issue`. If the config disables repairs (as
    /// [`LatexSyntaxConfig::minimal`] does), `apply` still produces no
    /// suggestions.
    pub fn with_repair_strategy<S: RepairStrategy + 'static>(mut self, strategy: S) -> Self {
        self.repair_strategy = Some(Arc::new(strategy));
        self
    }

    /// Disables repair generation entirely. It drops the strategy and clears
    /// `config.generate_repairs`.
    pub fn without_repairs(mut self) -> Self {
        self.repair_strategy = None;
        self.config.generate_repairs = false;
        self
    }

    /// The grammar used for pruning.
    pub fn grammar(&self) -> &LatexGrammar {
        &self.grammar
    }

    /// The active configuration.
    pub fn config(&self) -> &LatexSyntaxConfig {
        &self.config
    }

    /// Returns a copy of the repair suggestions recorded by the most recent
    /// `apply` call. The list is empty before the first call.
    ///
    /// The stored list is only a cache of the last result. A poisoned lock
    /// (another thread panicked while holding it) is therefore recovered
    /// from instead of panicking.
    pub fn last_repairs(&self) -> Vec<RepairSuggestion> {
        self.last_repairs
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clone()
    }

    /// Runs the configured validator directly on `tokens`, independent of
    /// any lattice.
    pub fn validate_tokens(&self, tokens: &[&str]) -> ValidationResult {
        self.validator.validate(tokens)
    }

    /// Collects repair suggestions for every issue in `validation`.
    ///
    /// Each issue contributes at most `config.max_repairs_per_issue`
    /// suggestions, namely the first ones the strategy returned. The combined
    /// list is then stably sorted by descending confidence, and incomparable
    /// confidences are treated as equal. Returns an empty list when no repair
    /// strategy is installed.
    fn generate_repairs(
        &self,
        validation: &ValidationResult,
        context: &[&str],
    ) -> Vec<RepairSuggestion> {
        let Some(strategy) = &self.repair_strategy else {
            return Vec::new();
        };
        let limit = self.config.max_repairs_per_issue;
        let mut all_repairs = Vec::new();
        for issue in &validation.issues {
            let mut repairs = strategy.suggest(issue, context);
            repairs.truncate(limit);
            all_repairs.extend(repairs);
        }
        // `sort_by` is stable, so equally confident repairs keep issue order.
        all_repairs.sort_by(|a, b| {
            b.confidence
                .partial_cmp(&a.confidence)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        all_repairs
    }
}
// SAFETY: NOTE(review) — these impls are asserted, not proven.
// They are presumably needed because `Arc<dyn RepairStrategy>` carries no
// `Send + Sync` bound; `LatexGrammar` or `LatexValidator` may also be
// involved, which is not visible here. Soundness relies on every installed
// strategy (and the other fields) being thread-safe. However,
// `with_repair_strategy` accepts any `S: RepairStrategy + 'static`, so a
// strategy holding e.g. an `Rc` or `Cell` would make sharing this layer
// across threads undefined behaviour.
// TODO: store `Arc<dyn RepairStrategy + Send + Sync>` (or make `Send + Sync`
// supertraits of `RepairStrategy`), confirm the remaining fields are
// `Send + Sync`, and delete these impls.
unsafe impl Send for LatexSyntaxLayer {}
unsafe impl Sync for LatexSyntaxLayer {}
impl<W: Semiring, B: LatticeBackend> CorrectionLayer<W, B> for LatexSyntaxLayer {
    /// Stable identifier of this layer.
    fn name(&self) -> &str {
        "latex-syntax"
    }

    /// Prunes `lattice` to grammatical edges and optionally validates it.
    ///
    /// - Pruning is controlled by `config.prune_ungrammatical`. It keeps only
    ///   the edges used by the Earley parse forest and rebuilds the lattice
    ///   with the original end position. When pruning is off, the lattice is
    ///   returned unchanged and no parse is performed.
    /// - Validation is controlled by `config.validate_structure`. If the
    ///   result is invalid and `config.generate_repairs` is set, suggestions
    ///   are stored and exposed through `last_repairs()`. They are stored
    ///   even when an error is returned afterwards.
    ///
    /// # Errors
    ///
    /// Only with `prune_ungrammatical` enabled: returns
    /// `LayerError::ParseError` if the grammar parse fails, or if validation
    /// reports errors.
    fn apply(&self, lattice: &Lattice<W, B>) -> LayerResult<Lattice<W, B>> {
        // Forget suggestions from the previous call before doing any work.
        // The vector is only a cache, so a poisoned lock is not fatal.
        self.last_repairs
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clear();

        if lattice.is_empty() {
            return Ok(lattice.clone());
        }

        let filtered_lattice = if self.config.prune_ungrammatical {
            let parser = EarleyParser::new(self.grammar.grammar());
            let forest = parser.parse_lattice(lattice).map_err(|e| {
                LayerError::ParseError(format!("LaTeX parse failed: {:?}", e))
            })?;
            let used_edges = forest.collect_used_edges();
            let mut builder = LatticeBuilder::new(lattice.backend().clone());
            for edge in lattice.edges() {
                if used_edges.contains(&edge.id) {
                    builder.add_correction_by_id(
                        edge.source.0 as usize,
                        edge.target.0 as usize,
                        edge.label,
                        edge.weight,
                        edge.metadata.clone(),
                    );
                }
            }
            builder.build(lattice.end().0 as usize)
        } else {
            // Without pruning the parse result would be discarded whether the
            // parse succeeds or fails, so the Earley parse is skipped.
            lattice.clone()
        };

        if self.config.validate_structure {
            // NOTE(review): tokens are gathered from *every* edge in storage
            // order. When alternative edges span the same positions, the
            // validator sees all alternatives concatenated rather than a
            // single path — confirm this is the intended input.
            let tokens: Vec<String> = filtered_lattice
                .edges()
                .iter()
                .filter_map(|e| {
                    filtered_lattice
                        .backend()
                        .lookup(e.label)
                        .map(|s| s.to_string())
                })
                .collect();
            let token_refs: Vec<&str> = tokens.iter().map(String::as_str).collect();
            let validation = self.validator.validate(&token_refs);

            if !validation.is_valid {
                if self.config.generate_repairs {
                    let repairs = self.generate_repairs(&validation, &token_refs);
                    *self
                        .last_repairs
                        .lock()
                        .unwrap_or_else(std::sync::PoisonError::into_inner) = repairs;
                }
                if self.config.prune_ungrammatical && validation.has_errors() {
                    let error_msg = validation
                        .errors()
                        .map(|e| e.message.as_str())
                        .collect::<Vec<_>>()
                        .join("; ");
                    return Err(LayerError::ParseError(format!(
                        "LaTeX validation failed: {}",
                        error_msg
                    )));
                }
            }
        }

        Ok(filtered_lattice)
    }

    /// Applicable to any non-empty lattice, and to an empty one whose start
    /// and end states coincide.
    fn can_apply(&self, lattice: &Lattice<W, B>) -> bool {
        !lattice.is_empty() || lattice.start() == lattice.end()
    }

    /// Returns 0.15 when pruning is enabled and 1.0 otherwise.
    /// NOTE(review): presumably the expected remaining-size ratio (1.0 =
    /// unchanged) — confirm against the `CorrectionLayer` trait contract.
    fn estimated_reduction(&self) -> f64 {
        if self.config.prune_ungrammatical {
            0.15
        } else {
            1.0
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::HashMapBackend;
    use crate::semiring::TropicalWeight;

    // Concrete instantiation used to call the generic trait methods.
    type W = TropicalWeight;
    type B = HashMapBackend;
    type L = LatexSyntaxLayer;

    /// The smallest grammar; every test except `test_layer_creation` uses it.
    fn minimal_grammar() -> LatexGrammar {
        LatexGrammar::minimal().expect("grammar should build")
    }

    /// A layer with default configuration over the minimal grammar.
    fn default_layer() -> L {
        LatexSyntaxLayer::new(minimal_grammar())
    }

    /// A lattice with no edges, ending at position 0.
    fn empty_lattice() -> Lattice<W, B> {
        let builder: LatticeBuilder<W, B> = LatticeBuilder::new(HashMapBackend::new());
        builder.build(0)
    }

    fn layer_name(layer: &L) -> &str {
        <L as CorrectionLayer<W, B>>::name(layer)
    }

    fn reduction_of(layer: &L) -> f64 {
        <L as CorrectionLayer<W, B>>::estimated_reduction(layer)
    }

    #[test]
    fn test_layer_name() {
        assert_eq!(layer_name(&default_layer()), "latex-syntax");
    }

    #[test]
    fn test_layer_creation() {
        let standard = LatexGrammar::standard().expect("grammar should build");
        let cfg = LatexSyntaxLayer::new(standard).config;
        assert!(cfg.prune_ungrammatical);
        assert!(cfg.validate_structure);
        assert!(cfg.generate_repairs);
    }

    #[test]
    fn test_config_presets() {
        let strict = LatexSyntaxConfig::strict();
        assert!(strict.prune_ungrammatical);
        assert!(!strict.auto_repair);

        let lenient = LatexSyntaxConfig::lenient();
        assert!(!lenient.prune_ungrammatical);
        assert!(lenient.auto_repair);

        let minimal = LatexSyntaxConfig::minimal();
        assert!(minimal.prune_ungrammatical);
        assert!(!minimal.validate_structure);
        assert!(!minimal.generate_repairs);
    }

    #[test]
    fn test_with_custom_validator() {
        let custom = LatexValidator::new()
            .with_environment_validation(false)
            .with_nested_math(true);
        let layer = default_layer().with_validator(custom);
        assert_eq!(layer_name(&layer), "latex-syntax");
    }

    #[test]
    fn test_without_repairs() {
        let layer = default_layer().without_repairs();
        assert!(layer.repair_strategy.is_none());
        assert!(!layer.config.generate_repairs);
    }

    #[test]
    fn test_estimated_reduction_prune_mode() {
        let delta = (reduction_of(&default_layer()) - 0.15).abs();
        assert!(delta < 0.01);
    }

    #[test]
    fn test_estimated_reduction_no_prune_mode() {
        let layer = LatexSyntaxLayer::with_config(minimal_grammar(), LatexSyntaxConfig::lenient());
        let delta = (reduction_of(&layer) - 1.0).abs();
        assert!(delta < 0.01);
    }

    #[test]
    fn test_can_apply_empty_lattice() {
        let layer = default_layer();
        let lattice = empty_lattice();
        assert!(<L as CorrectionLayer<W, B>>::can_apply(&layer, &lattice));
    }

    #[test]
    fn test_apply_empty_lattice() {
        let layer = default_layer();
        let lattice = empty_lattice();
        let outcome = <L as CorrectionLayer<W, B>>::apply(&layer, &lattice);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_tokens() {
        let layer = default_layer();
        assert!(layer.validate_tokens(&["{", "content", "}"]).is_valid);
        assert!(!layer.validate_tokens(&["{", "content"]).is_valid);
    }

    #[test]
    fn test_last_repairs_initially_empty() {
        assert!(default_layer().last_repairs().is_empty());
    }

    #[test]
    fn test_config_access() {
        let layer = LatexSyntaxLayer::with_config(minimal_grammar(), LatexSyntaxConfig::strict());
        let cfg = layer.config();
        assert!(cfg.prune_ungrammatical);
        assert!(cfg.validate_structure);
    }

    #[test]
    fn test_grammar_access() {
        let layer = default_layer();
        let productions = layer.grammar().grammar().num_productions();
        assert!(productions > 0);
    }
}