use crate::calibrate::NgramModel;
use std::collections::HashMap;
/// Scores code by token-level surprisal, keeping one [`NgramModel`] per language.
///
/// Languages are identified by normalized file extension, so related extensions
/// (for example `ts` and `tsx`) share a single model.
pub struct TokenLevelScorer {
    /// Trained models keyed by normalized extension (e.g. `"ts"` for `.ts` and `.tsx` files).
    pub models: HashMap<String, NgramModel>,
}
impl TokenLevelScorer {
pub fn new() -> Self {
Self {
models: HashMap::new(),
}
}
pub fn train_file(&mut self, content: &str, extension: &str) {
let lang = normalize_extension(extension);
let model = self.models.entry(lang).or_default();
let tokens = NgramModel::tokenize_file(content);
model.train_on_tokens(&tokens);
}
pub fn score_function(&self, lines: &[&str], extension: &str) -> f64 {
let lang = normalize_extension(extension);
let Some(model) = self.models.get(&lang) else {
return 0.0;
};
if !model.is_confident() {
return 0.0;
}
let (avg, _, _) = model.function_surprisal(lines);
avg
}
pub fn is_confident(&self, extension: &str) -> bool {
let lang = normalize_extension(extension);
self.models.get(&lang).is_some_and(|m| m.is_confident())
}
}
impl Default for TokenLevelScorer {
fn default() -> Self {
Self::new()
}
}
/// Maps a file extension to the key of the language model it should share.
///
/// Extensions that denote the same language collapse onto one canonical key, so
/// their tokens train a single model instead of several small ones that may never
/// become confident:
/// - `ts`, `tsx`, `mts`, `cts` → `ts`
/// - `js`, `jsx`, `mjs`, `cjs` → `js`
/// - C++ sources and headers (`cc`, `cpp`, `cxx`, `c++`, `hh`, `hpp`, `hxx`, `h++`) → `cpp`
/// - `h` → `c` (a bare `.h` is ambiguous between C and C++; it is grouped with C here)
///
/// Any other extension is returned unchanged and gets its own model. Matching is
/// exact and case-sensitive (e.g. `.C` is commonly C++ on Unix, so it is not folded to `c`).
fn normalize_extension(ext: &str) -> String {
    let canonical = match ext {
        "ts" | "tsx" | "mts" | "cts" => "ts",
        "js" | "jsx" | "mjs" | "cjs" => "js",
        "cc" | "cpp" | "cxx" | "c++" | "hh" | "hpp" | "hxx" | "h++" => "cpp",
        "h" => "c",
        other => other,
    };
    canonical.to_string()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Synthetic Rust-like corpus: 800 distinct `let mut` bindings, one per line.
    fn rust_training_source() -> String {
        (0..800)
            .map(|i| format!("let mut count_{i} = {i};"))
            .collect::<Vec<_>>()
            .join("\n")
    }

    /// Synthetic Python-like corpus: 800 distinct call assignments, one per line.
    fn python_training_source() -> String {
        (0..800)
            .map(|i| format!("result_{i} = process(value_{i})"))
            .collect::<Vec<_>>()
            .join("\n")
    }

    #[test]
    fn test_per_language_models_trained_separately() {
        let mut scorer = TokenLevelScorer::new();
        scorer.train_file(&rust_training_source(), "rs");
        scorer.train_file(&python_training_source(), "py");
        assert_eq!(
            scorer.models.len(),
            2,
            "Expected 2 separate models (rs + py)"
        );
        assert!(scorer.models.contains_key("rs"), "Missing Rust model");
        assert!(scorer.models.contains_key("py"), "Missing Python model");
        assert!(scorer.models["rs"].total_tokens() > 0);
        assert!(scorer.models["py"].total_tokens() > 0);
    }

    #[test]
    fn test_score_function_returns_zero_without_confidence() {
        let scorer = TokenLevelScorer::new();
        let lines = ["let x = 42;", "println!(x);"];
        let score = scorer.score_function(&lines, "rs");
        assert_eq!(score, 0.0, "Scorer with no model should return 0.0");
    }

    #[test]
    fn test_score_function_returns_zero_for_low_confidence() {
        let mut scorer = TokenLevelScorer::new();
        scorer.train_file("let x = 1;\nlet y = 2;\n", "rs");
        assert!(
            !scorer.is_confident("rs"),
            "Model should not be confident with so few tokens"
        );
        let lines = ["let x = 42;"];
        let score = scorer.score_function(&lines, "rs");
        assert_eq!(score, 0.0, "Under-trained model should return 0.0");
    }

    #[test]
    fn test_normalize_extensions() {
        assert_eq!(normalize_extension("tsx"), "ts");
        assert_eq!(normalize_extension("ts"), "ts");
        assert_eq!(normalize_extension("jsx"), "js");
        assert_eq!(normalize_extension("js"), "js");
        assert_eq!(normalize_extension("cc"), "cpp");
        assert_eq!(normalize_extension("cpp"), "cpp");
        assert_eq!(normalize_extension("cxx"), "cpp");
        assert_eq!(normalize_extension("hpp"), "cpp");
        assert_eq!(normalize_extension("h"), "c");
        assert_eq!(normalize_extension("rs"), "rs");
        assert_eq!(normalize_extension("py"), "py");
        assert_eq!(normalize_extension("go"), "go");
    }

    #[test]
    fn test_is_confident_after_sufficient_training() {
        let mut scorer = TokenLevelScorer::new();
        assert!(
            !scorer.is_confident("rs"),
            "Should not be confident before training"
        );
        scorer.train_file(&rust_training_source(), "rs");
        assert!(
            scorer.is_confident("rs"),
            "Should be confident after training with {} tokens",
            scorer.models["rs"].total_tokens()
        );
    }

    #[test]
    fn test_extension_variants_share_model() {
        let mut scorer = TokenLevelScorer::new();
        scorer.train_file("const x: number = 1;\n", "ts");
        scorer.train_file("const y: string = 'hi';\n", "tsx");
        assert_eq!(
            scorer.models.len(),
            1,
            "ts and tsx should share a single model"
        );
        assert!(scorer.models.contains_key("ts"));
    }

    #[test]
    fn test_score_function_produces_nonzero_for_confident_model() {
        let mut scorer = TokenLevelScorer::new();
        scorer.train_file(&rust_training_source(), "rs");
        assert!(scorer.is_confident("rs"));
        let lines = [
            "unsafe { std::ptr::write(addr, value) }",
            "extern \"C\" fn callback(ptr: *mut u8) -> i32 {",
        ];
        let score = scorer.score_function(&lines, "rs");
        assert!(
            score > 0.0,
            "Confident model should produce non-zero surprisal for unusual code, got {}",
            score
        );
    }

    #[test]
    fn test_unknown_language_returns_zero() {
        let scorer = TokenLevelScorer::new();
        let lines = ["some code here"];
        assert_eq!(
            scorer.score_function(&lines, "zig"),
            0.0,
            "Unknown language should return 0.0"
        );
        assert!(!scorer.is_confident("zig"));
    }

    #[test]
    fn test_confidence_does_not_leak_across_languages() {
        let mut scorer = TokenLevelScorer::new();
        scorer.train_file(&rust_training_source(), "rs");
        assert!(scorer.is_confident("rs"));
        assert!(
            !scorer.is_confident("py"),
            "Training the Rust model must not make Python confident"
        );
        let lines = ["result = process(value)"];
        assert_eq!(
            scorer.score_function(&lines, "py"),
            0.0,
            "A language with no trained model should score 0.0 even if others are confident"
        );
    }
}