use crate::latex::embedding::LaTeXEmbedder;
use crate::latex::tokenizer::{LaTeXToken, LaTeXTokenKind};
use std::collections::HashMap;
use std::sync::Arc;
/// Tunable settings shared by `LaTeXRescorer` and `BatchRescorer`.
#[derive(Debug, Clone)]
pub struct RescorerConfig {
    /// Weight given to the neural component when component scores are averaged.
    pub neural_weight: f64,
    /// Weight given to the n-gram component when component scores are averaged.
    pub ngram_weight: f64,
    /// Weight given to the embedding component when component scores are averaged.
    pub embedding_weight: f64,
    /// Chunk size `BatchRescorer` uses while walking a candidate list.
    pub batch_size: usize,
    /// Not read anywhere in this file; presumably a sequence-length cap — TODO confirm.
    pub max_length: usize,
    /// Not read anywhere in this file; presumably selects GPU execution — TODO confirm.
    pub use_gpu: bool,
    /// Not read anywhere in this file; presumably names the model to load — TODO confirm.
    pub model_name: String,
}

impl Default for RescorerConfig {
    /// Defaults: weights 0.4 / 0.4 / 0.2 (neural / n-gram / embedding),
    /// batches of 16, max length 512, GPU off, model `"modernbert-latex"`.
    fn default() -> Self {
        let model_name = String::from("modernbert-latex");
        RescorerConfig {
            neural_weight: 0.4,
            ngram_weight: 0.4,
            embedding_weight: 0.2,
            batch_size: 16,
            max_length: 512,
            use_gpu: false,
            model_name,
        }
    }
}
/// Outcome of rescoring one token sequence.
#[derive(Debug, Clone)]
pub struct RescoreResult {
    /// Text of the sequence that was scored.
    pub sequence: String,
    /// Final (combined) score.
    pub score: f64,
    /// Neural component, when one was recorded.
    pub neural_score: Option<f64>,
    /// N-gram component, when one was recorded.
    pub ngram_score: Option<f64>,
    /// Embedding component, when one was recorded.
    pub embedding_score: Option<f64>,
    /// Confidence attached to the result; `new` sets it to 1.0.
    pub confidence: f64,
}

impl RescoreResult {
    /// Builds a result with the given text and score, no component scores,
    /// and a confidence of 1.0.
    pub fn new(sequence: String, score: f64) -> Self {
        RescoreResult {
            sequence,
            score,
            confidence: 1.0,
            neural_score: None,
            ngram_score: None,
            embedding_score: None,
        }
    }

    /// Returns the result with all three component scores replaced.
    pub fn with_components(
        self,
        neural: Option<f64>,
        ngram: Option<f64>,
        embedding: Option<f64>,
    ) -> Self {
        Self {
            neural_score: neural,
            ngram_score: ngram,
            embedding_score: embedding,
            ..self
        }
    }

    /// Returns the result with its confidence replaced.
    pub fn with_confidence(self, confidence: f64) -> Self {
        Self { confidence, ..self }
    }
}
/// Scores LaTeX token sequences by combining neural, n-gram and embedding
/// components, caching the combined score per sequence text.
pub struct LaTeXRescorer {
// Component weights and batch settings.
config: RescorerConfig,
// Optional shared embedder; when absent the embedding component falls back
// to the validity ratio.
embedder: Option<Arc<LaTeXEmbedder>>,
// Combined score keyed by the concatenated token text.
cache: HashMap<String, f64>,
// Set to true by `load_model`; gates the neural component.
model_loaded: bool,
// Only present when the crate is built with the `neural-rescore` feature.
#[cfg(feature = "neural-rescore")]
neural_model: Option<crate::neural::ModernBertRescorer>,
}
impl LaTeXRescorer {
/// Creates a rescorer that uses `RescorerConfig::default()`.
pub fn new() -> Self {
    let config = RescorerConfig::default();
    Self::with_config(config)
}
pub fn with_config(config: RescorerConfig) -> Self {
Self {
config,
embedder: None,
cache: HashMap::new(),
model_loaded: false,
#[cfg(feature = "neural-rescore")]
neural_model: None,
}
}
/// Attaches a shared embedder and returns the rescorer (builder style).
///
/// The score cache is emptied as well: cached values were produced with the
/// previous embedder (or none) and `rescore` returns cache hits without
/// recomputing, so keeping them would serve stale scores.
pub fn with_embedder(mut self, embedder: Arc<LaTeXEmbedder>) -> Self {
    self.embedder = Some(embedder);
    self.cache.clear();
    self
}
/// Builds the neural rescoring model from `config` and marks it as loaded.
///
/// # Errors
/// Propagates any error returned by `ModernBertRescorer::new`; in that case
/// the rescorer is left unchanged.
#[cfg(feature = "neural-rescore")]
pub fn load_model(&mut self, config: crate::neural::RescoringConfig) -> crate::Result<()> {
    let model = crate::neural::ModernBertRescorer::new(config)?;
    self.neural_model = Some(model);
    self.model_loaded = true;
    // Scores cached before the model was loaded lack the neural component;
    // drop them so later `rescore` calls do not return stale values.
    self.cache.clear();
    Ok(())
}
/// Returns `true` once a model has been flagged as loaded.
pub fn is_model_loaded(&self) -> bool {
self.model_loaded
}
/// Returns `true` when an embedder has been attached.
pub fn has_embedder(&self) -> bool {
self.embedder.is_some()
}
/// Scores one token sequence, consulting the cache first.
///
/// The cache key is the concatenated token text. On a cache hit the result
/// carries only the combined score (component scores are `None`); on a miss
/// the components are computed, combined, cached and reported.
pub fn rescore(&mut self, tokens: &[LaTeXToken]) -> RescoreResult {
    let key = tokens_to_string(tokens);
    match self.cache.get(&key).copied() {
        Some(hit) => RescoreResult::new(key, hit),
        None => {
            let neural = self.compute_neural_score(tokens);
            // No n-gram scorer is wired in; the component is always absent.
            let ngram: Option<f64> = None;
            let embedding = self.compute_embedding_score(tokens);
            let combined = self.combine_scores(neural, ngram, embedding);
            self.cache.insert(key.clone(), combined);
            RescoreResult::new(key, combined).with_components(neural, ngram, embedding)
        }
    }
}
/// Scores every sequence in order and returns the results in the same order.
pub fn rescore_batch(&mut self, sequences: &[&[LaTeXToken]]) -> Vec<RescoreResult> {
    let mut results = Vec::with_capacity(sequences.len());
    for &tokens in sequences {
        results.push(self.rescore(tokens));
    }
    results
}
/// Neural component of the score.
///
/// Returns `None` while no model is flagged as loaded. Otherwise returns the
/// structural heuristic score; the loaded model itself is not consulted here.
fn compute_neural_score(&self, tokens: &[LaTeXToken]) -> Option<f64> {
    if !self.model_loaded {
        return None;
    }
    // The model handle is only touched, not used for inference. The previous
    // per-call rendering of the tokens to a throwaway string is gone.
    #[cfg(feature = "neural-rescore")]
    let _model = self.neural_model.as_ref();
    Some(self.heuristic_neural_score(tokens))
}
/// Embedding component of the score (always `Some`).
///
/// With an embedder and at least two command tokens, this is the mean cosine
/// similarity between the sequence embedding and the vector of each command
/// the embedder knows. In every other case (no embedder, fewer than two
/// commands, no known command) it falls back to the validity ratio.
///
/// NOTE(review): cosine values can be negative, so the mean is not confined
/// to [0, 1] the way the validity fallback is — confirm that is intended.
fn compute_embedding_score(&self, tokens: &[LaTeXToken]) -> Option<f64> {
    let embedder = match self.embedder.as_ref() {
        Some(e) => e,
        None => return Some(self.compute_validity_score(tokens)),
    };

    // Collect the command names in sequence order.
    let mut commands: Vec<&str> = Vec::new();
    for token in tokens {
        if let LaTeXTokenKind::Command(name) = &token.kind {
            commands.push(name.as_str());
        }
    }
    if commands.len() < 2 {
        return Some(self.compute_validity_score(tokens));
    }

    let seq_embedding = embedder.sequence_embedding(&commands);
    let mut similarity_sum = 0.0f32;
    let mut matched = 0;
    for cmd in &commands {
        if !embedder.contains_command(cmd) {
            continue;
        }
        let cmd_vec = embedder.command_vector(cmd);
        similarity_sum += cosine_similarity_f32(&seq_embedding, cmd_vec);
        matched += 1;
    }

    if matched == 0 {
        return Some(self.compute_validity_score(tokens));
    }
    Some((similarity_sum / matched as f32) as f64)
}
/// Fraction of tokens whose kind is not `Unknown`; 0.0 for an empty slice.
fn compute_validity_score(&self, tokens: &[LaTeXToken]) -> f64 {
    let unknown = tokens
        .iter()
        .filter(|t| matches!(&t.kind, LaTeXTokenKind::Unknown(_)))
        .count();
    let valid = tokens.len() - unknown;
    // Guard the division for an empty slice.
    let denominator = std::cmp::max(tokens.len(), 1);
    valid as f64 / denominator as f64
}
/// Weighted mean of the component scores that are present.
///
/// Absent components add nothing to either the weighted sum or the weight
/// total. Returns 0.0 when the accumulated weight is not positive.
fn combine_scores(
    &self,
    neural: Option<f64>,
    ngram: Option<f64>,
    embedding: Option<f64>,
) -> f64 {
    let parts = [
        (neural, self.config.neural_weight),
        (ngram, self.config.ngram_weight),
        (embedding, self.config.embedding_weight),
    ];
    let mut numerator = 0.0f64;
    let mut denominator = 0.0f64;
    for &(score, weight) in &parts {
        if let Some(value) = score {
            numerator += value * weight;
            denominator += weight;
        }
    }
    if denominator > 0.0 {
        numerator / denominator
    } else {
        0.0
    }
}
/// Structural plausibility score in [0, 1]; 0.0 for an empty slice.
///
/// Each token adds a small bonus or penalty, unmatched closers are penalised
/// as they occur, and any remaining brace/math imbalance is penalised at the
/// end. The per-token average is clamped to [-1, 1] and mapped onto [0, 1].
///
/// NOTE(review): depth counters are allowed to go negative, so a stray closer
/// followed by an opener ends at depth 0 and escapes the final imbalance
/// penalty — confirm whether that is intended.
fn heuristic_neural_score(&self, tokens: &[LaTeXToken]) -> f64 {
    if tokens.is_empty() {
        return 0.0;
    }
    let mut open_braces = 0i32;
    let mut open_math = 0i32;
    let mut raw = 0.0f64;
    for token in tokens {
        let delta: f64 = match &token.kind {
            LaTeXTokenKind::Command(_) => 0.1,
            LaTeXTokenKind::OpenBrace(_) => {
                open_braces += 1;
                0.05
            }
            LaTeXTokenKind::CloseBrace(_) => {
                open_braces -= 1;
                // A closer with nothing open is penalised instead of rewarded.
                if open_braces < 0 {
                    -0.5
                } else {
                    0.05
                }
            }
            LaTeXTokenKind::MathOpen(_) => {
                open_math += 1;
                0.1
            }
            LaTeXTokenKind::MathClose(_) => {
                open_math -= 1;
                if open_math < 0 {
                    -0.5
                } else {
                    0.1
                }
            }
            LaTeXTokenKind::Unknown(_) => -0.2,
            _ => 0.02,
        };
        raw += delta;
    }
    // Whatever imbalance is left at the end costs 0.3 per level.
    raw -= open_braces.abs() as f64 * 0.3;
    raw -= open_math.abs() as f64 * 0.3;
    let per_token = (raw / tokens.len() as f64).clamp(-1.0, 1.0);
    (per_token + 1.0) / 2.0
}
/// Removes every cached score.
pub fn clear_cache(&mut self) {
self.cache.clear();
}
/// Returns `(entries, capacity)` of the score cache.
pub fn cache_stats(&self) -> (usize, usize) {
    let entries = self.cache.len();
    let capacity = self.cache.capacity();
    (entries, capacity)
}
/// Borrows the configuration this rescorer was built with.
pub fn config(&self) -> &RescorerConfig {
&self.config
}
}
impl Default for LaTeXRescorer {
fn default() -> Self {
Self::new()
}
}
/// Concatenates the text of every token, in order, with no separator.
fn tokens_to_string(tokens: &[LaTeXToken]) -> String {
    // Collect straight into the output string; no intermediate Vec is needed
    // for an empty-separator join.
    tokens.iter().map(|t| t.text()).collect::<String>()
}
/// Cosine similarity of two equal-length vectors, in [-1, 1].
///
/// Returns 0.0 when the lengths differ, when the slices are empty, when
/// either vector has zero norm, or when the norms are not finite (an input
/// contains NaN or an infinity).
fn cosine_similarity_f32(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    // Accumulate in f64: squares of large or tiny f32 values overflow or
    // underflow in f32, which used to yield NaN (inf / inf) or a false 0.0.
    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, na, nb), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, na + x * x, nb + y * y)
        },
    );
    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom > 0.0 && denom.is_finite() {
        // Clamp away rounding that would push the ratio just past +/-1.
        (dot / denom).clamp(-1.0, 1.0) as f32
    } else {
        0.0
    }
}
/// A candidate token sequence together with the score it arrived with.
#[derive(Debug, Clone)]
pub struct RescoreCandidate {
    /// Tokens making up the candidate.
    pub tokens: Vec<LaTeXToken>,
    /// Score assigned before rescoring; added to the rescore score when ranking.
    pub prior_score: f64,
    /// Free-form label for where the candidate came from.
    pub source: String,
}

impl RescoreCandidate {
    /// Builds a candidate, copying `source` into an owned string.
    pub fn new(tokens: Vec<LaTeXToken>, prior_score: f64, source: &str) -> Self {
        let source = String::from(source);
        RescoreCandidate {
            tokens,
            prior_score,
            source,
        }
    }

    /// Concatenated text of the candidate's tokens.
    pub fn text(&self) -> String {
        tokens_to_string(self.tokens.as_slice())
    }
}
/// Rescores lists of candidates and ranks them by prior score plus rescore score.
pub struct BatchRescorer {
    rescorer: LaTeXRescorer,
    // Chunk size taken from the rescorer's configuration at construction time.
    max_batch_size: usize,
}

impl BatchRescorer {
    /// Wraps `rescorer`, taking the chunk size from its `batch_size` setting.
    pub fn new(rescorer: LaTeXRescorer) -> Self {
        let max_batch_size = rescorer.config.batch_size;
        Self {
            rescorer,
            max_batch_size,
        }
    }

    /// Orders two combined scores from highest to lowest, with NaN ranked
    /// after every real number.
    ///
    /// Treating NaN as "equal to everything" is not a total order, which
    /// makes the sorted order unspecified and is allowed to make `sort_by`
    /// panic; this comparator is total.
    fn compare_descending(a: f64, b: f64) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
        }
    }

    /// Rescores every candidate and returns `(candidate, result)` pairs sorted
    /// by `prior_score + result.score`, best first.
    ///
    /// The sort is stable, so ties keep their input order. Pairs whose
    /// combined score is NaN are placed last.
    pub fn rescore_candidates(
        &mut self,
        candidates: &[RescoreCandidate],
    ) -> Vec<(RescoreCandidate, RescoreResult)> {
        let mut results = Vec::with_capacity(candidates.len());
        // A zero batch size would make `chunks` panic, hence the `.max(1)`.
        for chunk in candidates.chunks(self.max_batch_size.max(1)) {
            for candidate in chunk {
                let result = self.rescorer.rescore(&candidate.tokens);
                results.push((candidate.clone(), result));
            }
        }
        results.sort_by(|a, b| {
            let score_a = a.0.prior_score + a.1.score;
            let score_b = b.0.prior_score + b.1.score;
            Self::compare_descending(score_a, score_b)
        });
        results
    }

    /// Returns at most the `k` best-ranked pairs.
    pub fn top_k(
        &mut self,
        candidates: &[RescoreCandidate],
        k: usize,
    ) -> Vec<(RescoreCandidate, RescoreResult)> {
        let mut results = self.rescore_candidates(candidates);
        results.truncate(k);
        results
    }

    /// Returns the best-ranked pair, or `None` when `candidates` is empty.
    pub fn best(
        &mut self,
        candidates: &[RescoreCandidate],
    ) -> Option<(RescoreCandidate, RescoreResult)> {
        self.top_k(candidates, 1).into_iter().next()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::latex::tokenizer::LaTeXTokenizer;

    /// A simple expression yields a non-empty sequence and a score in [0, 1].
    #[test]
    fn test_rescore_basic() {
        let tokenizer = LaTeXTokenizer::new();
        let mut scorer = LaTeXRescorer::new();
        let input = tokenizer.tokenize(r"\alpha + \beta");
        let outcome = scorer.rescore(&input);
        assert!(!outcome.sequence.is_empty());
        assert!((0.0..=1.0).contains(&outcome.score));
    }

    /// A balanced expression never scores below its unbalanced variant.
    #[test]
    fn test_rescore_unbalanced() {
        let tokenizer = LaTeXTokenizer::new();
        let mut scorer = LaTeXRescorer::new();
        let closed = tokenizer.tokenize(r"\frac{a}{b}");
        let open = tokenizer.tokenize(r"\frac{a}{b");
        let closed_score = scorer.rescore(&closed).score;
        let open_score = scorer.rescore(&open).score;
        assert!(closed_score >= open_score);
    }

    /// Scoring fills the cache, repeat scoring agrees, clearing empties it.
    #[test]
    fn test_cache() {
        let tokenizer = LaTeXTokenizer::new();
        let mut scorer = LaTeXRescorer::new();
        let input = tokenizer.tokenize(r"\alpha");
        let first = scorer.rescore(&input);
        assert_eq!(scorer.cache_stats().0, 1);
        let second = scorer.rescore(&input);
        assert_eq!(first.score, second.score);
        scorer.clear_cache();
        assert_eq!(scorer.cache_stats().0, 0);
    }

    /// `top_k` returns exactly `k` pairs when more candidates are available.
    #[test]
    fn test_batch_rescorer() {
        let tokenizer = LaTeXTokenizer::new();
        let mut batch = BatchRescorer::new(LaTeXRescorer::new());
        let specs = [(r"\alpha", 0.5), (r"\beta", 0.6), (r"\gamma", 0.4)];
        let mut candidates = Vec::new();
        for &(text, prior) in specs.iter() {
            candidates.push(RescoreCandidate::new(tokenizer.tokenize(text), prior, "lexical"));
        }
        let ranked = batch.top_k(&candidates, 2);
        assert_eq!(ranked.len(), 2);
    }

    /// The weighted mean stays strictly inside (0, 1) for interior inputs,
    /// and a lone component is returned unchanged.
    #[test]
    fn test_combine_scores() {
        let scorer = LaTeXRescorer::new();
        let all_three = scorer.combine_scores(Some(0.8), Some(0.6), Some(0.7));
        assert!(all_three > 0.0 && all_three < 1.0);
        let only_neural = scorer.combine_scores(Some(0.9), None, None);
        assert!((only_neural - 0.9).abs() < 1e-6);
    }
}