use std::collections::HashMap;
/// Number of consecutive tokens in each n-gram window (3 = trigram model).
const DEFAULT_ORDER: usize = 3;
/// Minimum number of training tokens before the model reports scores; until
/// this many tokens have been seen, the surprisal queries return 0.0.
const MIN_TOKENS_FOR_CONFIDENCE: usize = 5_000;
/// Count-based n-gram language model over abstracted code tokens, used to
/// score how surprising (average bits per token) a token sequence is.
#[derive(Debug, Clone)]
pub struct NgramModel {
/// N-gram length: every window of `order` consecutive tokens is counted.
order: usize,
/// Occurrences of each full n-gram, keyed by its tokens joined with a single space.
counts: HashMap<String, u32>,
/// Occurrences of each n-gram prefix (the first `order - 1` tokens), same key format.
context_counts: HashMap<String, u32>,
/// Occurrences of each individual token; used for backoff when a context was never seen.
unigram_counts: HashMap<String, u32>,
/// Total tokens across all training sequences that were long enough to be counted.
total_tokens: usize,
/// Number of distinct tokens seen so far (`unigram_counts.len()`).
vocab_size: usize,
/// Whether `total_tokens` has reached `MIN_TOKENS_FOR_CONFIDENCE`; the
/// surprisal queries return 0.0 until it has.
confident: bool,
}
impl NgramModel {
    /// Additive (add-k) smoothing constant used by every probability estimate.
    const SMOOTHING: f64 = 0.1;

    /// Creates an empty, untrained model of order `DEFAULT_ORDER`.
    pub fn new() -> Self {
        Self {
            order: DEFAULT_ORDER,
            counts: HashMap::new(),
            context_counts: HashMap::new(),
            unigram_counts: HashMap::new(),
            total_tokens: 0,
            vocab_size: 0,
            confident: false,
        }
    }

    /// Accumulates unigram, n-gram and context counts from one token sequence.
    ///
    /// Sequences shorter than the model order are ignored entirely: they add
    /// neither n-grams nor unigram counts. Once the running token total reaches
    /// `MIN_TOKENS_FOR_CONFIDENCE`, the model is marked confident and the
    /// surprisal queries start returning non-zero scores.
    pub fn train_on_tokens(&mut self, tokens: &[String]) {
        if tokens.len() < self.order {
            return;
        }
        for token in tokens {
            *self.unigram_counts.entry(token.clone()).or_default() += 1;
        }
        for window in tokens.windows(self.order) {
            let context = window[..self.order - 1].join(" ");
            *self.counts.entry(window.join(" ")).or_default() += 1;
            *self.context_counts.entry(context).or_default() += 1;
        }
        self.total_tokens += tokens.len();
        self.vocab_size = self.unigram_counts.len();
        self.confident = self.total_tokens >= MIN_TOKENS_FOR_CONFIDENCE;
    }

    /// Tokenizes one line of source code in isolation.
    ///
    /// Blank lines, and lines whose trimmed text starts with `//` or `#`, yield no
    /// tokens. NOTE(review): the `#` rule also drops Rust attribute lines such as
    /// `#[derive(...)]`. A string literal left open runs to the end of the line,
    /// and no `<EOL>` token is emitted. Both differ from `tokenize_file`.
    pub fn tokenize_line(line: &str) -> Vec<String> {
        let trimmed = line.trim();
        if trimmed.is_empty() || trimmed.starts_with("//") || trimmed.starts_with('#') {
            return Vec::new();
        }
        let mut tokens = Vec::new();
        let mut chars = trimmed.chars().peekable();
        while let Some(&ch) = chars.peek() {
            tokenize_next_char(&mut chars, ch, &mut tokens);
        }
        tokens
    }

    /// Tokenizes a whole file with the multi-line-aware scanner and discards the
    /// per-token line numbers. `<EOL>` is emitted after every line that had code.
    pub fn tokenize_file(content: &str) -> Vec<String> {
        tokenize_with_line_attribution(content)
            .into_iter()
            .map(|(tok, _)| tok)
            .collect()
    }

    /// Average surprisal, in bits per n-gram window, of `tokens`.
    ///
    /// Returns 0.0 when the model is not yet confident or when `tokens` is
    /// shorter than the model order (no complete window to score).
    pub fn surprisal(&self, tokens: &[String]) -> f64 {
        if !self.confident || tokens.len() < self.order {
            return 0.0;
        }
        let total_bits = tokens
            .windows(self.order)
            .fold(0.0_f64, |acc, window| acc + self.window_bits(window));
        // `windows` yields exactly `len - order + 1` windows, which is >= 1 here.
        let window_count = tokens.len() - self.order + 1;
        total_bits / window_count as f64
    }

    /// Average surprisal of a single line, tokenized with `tokenize_line`.
    ///
    /// Returns 0.0 for lines that produce fewer tokens than the model order, or
    /// when the model is not confident.
    pub fn line_surprisal(&self, line: &str) -> f64 {
        // `surprisal` already returns 0.0 for too-short input.
        self.surprisal(&Self::tokenize_line(line))
    }

    /// Scores a function body given as source lines.
    ///
    /// The lines are joined and tokenized with the multi-line-aware scanner.
    /// Each window's bits are charged to the line of its last token. For every
    /// line with at least one window and a positive average, the per-line mean is
    /// computed.
    ///
    /// Returns `(avg, max, max_line)`:
    /// - `avg`: the mean of those per-line means;
    /// - `max`: the largest per-line mean;
    /// - `max_line`: the 0-based index into `lines` of the line with that largest mean.
    ///
    /// Returns `(0.0, 0.0, 0)` when the model is not confident, `lines` is empty,
    /// or the body has fewer tokens than the model order.
    pub fn function_surprisal(&self, lines: &[&str]) -> (f64, f64, usize) {
        if !self.confident || lines.is_empty() {
            return (0.0, 0.0, 0);
        }
        let attributed = tokenize_with_line_attribution(&lines.join("\n"));
        if attributed.len() < self.order {
            return (0.0, 0.0, 0);
        }
        let last_line = lines.len() - 1;
        let mut per_line_total = vec![0.0f64; lines.len()];
        let mut per_line_count = vec![0usize; lines.len()];
        for window in attributed.windows(self.order) {
            // Clamp defensively: an entry of `lines` containing its own '\n'
            // would otherwise push line numbers past the end of the slice.
            let target_line = window[self.order - 1].1.min(last_line);
            let tokens: Vec<&str> = window.iter().map(|(tok, _)| tok.as_str()).collect();
            per_line_total[target_line] += self.window_bits(&tokens);
            per_line_count[target_line] += 1;
        }

        let mut total = 0.0;
        let mut max_surprisal = 0.0f64;
        let mut max_line = 0;
        let mut scored_lines = 0usize;
        for (line_idx, (&sum, &count)) in per_line_total.iter().zip(&per_line_count).enumerate() {
            if count == 0 {
                continue;
            }
            let line_avg = sum / count as f64;
            // A line whose windows were all perfectly predicted carries no signal.
            if line_avg <= 0.0 {
                continue;
            }
            total += line_avg;
            scored_lines += 1;
            if line_avg > max_surprisal {
                max_surprisal = line_avg;
                max_line = line_idx;
            }
        }
        let avg = if scored_lines > 0 {
            total / scored_lines as f64
        } else {
            0.0
        };
        (avg, max_surprisal, max_line)
    }

    /// Placeholder: always returns `(0.0, 0.0)`.
    ///
    /// NOTE(review): no baseline is computed yet, and the meaning of the two
    /// components is not defined anywhere visible. Confirm with callers.
    pub fn baseline_stats(&self) -> (f64, f64) {
        (0.0, 0.0)
    }

    /// Whether enough tokens have been seen for scores to be meaningful.
    pub fn is_confident(&self) -> bool {
        self.confident
    }

    /// Total number of tokens counted during training.
    pub fn total_tokens(&self) -> usize {
        self.total_tokens
    }

    /// Number of distinct tokens seen during training.
    pub fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    /// Summary of the model's size and state as a JSON object.
    pub fn stats_json(&self) -> serde_json::Value {
        serde_json::json!({
            "order": self.order,
            "total_tokens": self.total_tokens,
            "vocab_size": self.vocab_size,
            "ngram_count": self.counts.len(),
            "confident": self.confident,
        })
    }

    /// Surprisal, in bits, of the last token of `window` given the tokens before it.
    ///
    /// `window` must hold exactly `self.order` tokens. When the preceding context
    /// was seen in training, the smoothed n-gram estimate
    /// `(c(ngram) + k) / (c(context) + k*V)` is used. Otherwise the model backs
    /// off to the smoothed unigram estimate `(c(token) + k) / (N + k*V)`.
    /// Here `k = SMOOTHING`, `V` is the vocabulary size and `N` the total token count.
    fn window_bits<S: std::borrow::Borrow<str>>(&self, window: &[S]) -> f64 {
        let split = self.order - 1;
        let vocab = self.vocab_size.max(1) as f64;
        let ngram_count = f64::from(*self.counts.get(&window.join(" ")).unwrap_or(&0));
        let context_count =
            f64::from(*self.context_counts.get(&window[..split].join(" ")).unwrap_or(&0));
        let prob = if context_count > 0.0 {
            (ngram_count + Self::SMOOTHING) / (context_count + Self::SMOOTHING * vocab)
        } else {
            // Unseen context: back off to the smoothed unigram frequency.
            let target: &str = <S as std::borrow::Borrow<str>>::borrow(&window[split]);
            let uni_count = f64::from(*self.unigram_counts.get(target).unwrap_or(&0));
            (uni_count + Self::SMOOTHING) / (self.total_tokens as f64 + Self::SMOOTHING * vocab)
        };
        -prob.log2()
    }
}
impl Default for NgramModel {
fn default() -> Self {
Self::new()
}
}
/// Advances `chars` past a numeric literal.
///
/// Accepted characters are ASCII letters and digits (which also covers hex
/// digits, `x`, exponents and suffixes such as `u32`), `.` and `_`.
/// NOTE(review): because `.` is accepted unconditionally, a range such as
/// `0..n` is swallowed whole into one number.
fn consume_number(chars: &mut std::iter::Peekable<std::str::Chars>) {
    while chars
        .next_if(|c| c.is_ascii_alphanumeric() || *c == '.' || *c == '_')
        .is_some()
    {}
}
/// Consumes the token starting at `ch` (the current peeked character) from
/// `chars`. Unless `ch` is a space or tab, it appends that token's abstracted
/// form to `tokens`.
///
/// Every call consumes at least one character, so a `while peek()` loop over
/// this function always terminates.
fn tokenize_next_char(
    chars: &mut std::iter::Peekable<std::str::Chars>,
    ch: char,
    tokens: &mut Vec<String>,
) {
    let token = match ch {
        ' ' | '\t' => {
            // Whitespace separates tokens but produces none.
            chars.next();
            return;
        }
        '"' | '\'' | '`' => {
            chars.next(); // opening quote
            consume_string_literal(chars, ch);
            String::from("<STR>")
        }
        '0'..='9' => {
            consume_number(chars);
            String::from("<NUM>")
        }
        'a'..='z' | 'A'..='Z' | '_' => classify_identifier(consume_identifier(chars)),
        _ => consume_operator(chars),
    };
    tokens.push(token);
}
/// Advances `chars` past the body of a string literal whose opening `quote`
/// has already been consumed.
///
/// Stops right after the matching closing quote, or at end of input if the
/// literal is unterminated. A backslash skips the following character, so
/// `\"` does not close a `"`-quoted literal.
fn consume_string_literal(chars: &mut std::iter::Peekable<std::str::Chars>, quote: char) {
    while let Some(c) = chars.next() {
        match c {
            _ if c == quote => break,
            '\\' => {
                chars.next();
            }
            _ => {}
        }
    }
}
/// Takes the longest run of ASCII letters, digits and `_` from the front of
/// `chars` and returns it (empty if the next character does not qualify).
fn consume_identifier(chars: &mut std::iter::Peekable<std::str::Chars>) -> String {
    std::iter::from_fn(|| chars.next_if(|c| c.is_ascii_alphanumeric() || *c == '_')).collect()
}
/// Maps an identifier to the token the n-gram model sees.
///
/// Keywords (as decided by `is_keyword`) are kept verbatim. Every other word
/// is abstracted into a coarse class:
/// - `<CONST>`: SCREAMING_SNAKE_CASE, i.e. at least one uppercase letter and no
///   lowercase letters; digits and underscores are allowed (`MAX`, `INT32_MAX`).
/// - `<TYPE>`: starts with an uppercase letter (`HashMap`).
/// - `<ID>`: everything else, including `_` and other underscore-only names.
fn classify_identifier(word: String) -> String {
    if is_keyword(&word) {
        return word;
    }
    // Requiring an uppercase letter keeps `_` / `__` out of `<CONST>`. Allowing
    // digits keeps `INT32_MAX`-style constants out of `<TYPE>`.
    let has_upper = word.chars().any(char::is_uppercase);
    let has_lower = word.chars().any(char::is_lowercase);
    if has_upper && !has_lower {
        "<CONST>".to_string()
    } else if word.starts_with(|c: char| c.is_uppercase()) {
        "<TYPE>".to_string()
    } else {
        "<ID>".to_string()
    }
}
/// Consumes one operator or punctuation token from `chars` and returns it.
///
/// The longest recognised operator wins: a known two-character operator
/// (`==`, `->`, `::`, `..`, `<<`, …) may extend to a known three-character one
/// (`===`, `!==`, `...`, `>>>`, `<<=`, `>>=`). Any other character is returned
/// on its own. Returns an empty string only when `chars` is exhausted.
fn consume_operator(chars: &mut std::iter::Peekable<std::str::Chars>) -> String {
    let Some(first) = chars.next() else {
        return String::new();
    };
    let mut op = String::from(first);

    let forms_pair = |second: &char| {
        matches!(
            (first, *second),
            ('=', '=')
                | ('!', '=')
                | ('>', '=')
                | ('<', '=')
                | ('&', '&')
                | ('|', '|')
                | ('-', '>')
                | ('=', '>')
                | (':', ':')
                | ('+', '=')
                | ('-', '=')
                | ('*', '=')
                | ('/', '=')
                | ('.', '.')
                | ('<', '<')
                | ('>', '>')
        )
    };
    let Some(second) = chars.next_if(forms_pair) else {
        return op;
    };
    op.push(second);

    let forms_triple = |third: &char| {
        matches!(
            (first, second, *third),
            ('=', '=', '=')
                | ('!', '=', '=')
                | ('.', '.', '.')
                | ('>', '>', '>')
                | ('<', '<', '=')
                | ('>', '>', '=')
        )
    };
    if let Some(third) = chars.next_if(forms_triple) {
        op.push(third);
    }
    op
}
/// Tokenizes a whole source file, tagging every token with the 0-based line
/// on which it starts.
///
/// Unlike the per-line tokenizer, this scanner keeps state across lines:
/// - `//` line comments and nestable `/* */` block comments are skipped.
/// - Rust raw strings (`r"…"`, `r#"…"#`, …) and triple-quoted strings
///   (`"""…"""`, `'''…'''`) may span lines and collapse to a single `<STR>`.
/// - Ordinary `"`, `'` and `` ` `` literals end at the matching quote or at the
///   end of the line, unless that line break is escaped with a backslash.
/// - An `<EOL>` token is emitted at the end of every line that produced at least
///   one code token, including a final line without a trailing `\n`.
///
/// Multi-line tokens are attributed to the line where they start. Lines are
/// counted by `\n`; `\r` is treated as whitespace.
fn tokenize_with_line_attribution(content: &str) -> Vec<(String, usize)> {
    let bytes = content.as_bytes();
    let mut tokens: Vec<(String, usize)> = Vec::new();
    let mut i = 0;
    let mut line = 0usize;
    let mut line_had_code_token = false;
    while i < bytes.len() {
        let b = bytes[i];
        if b == b'\n' {
            if line_had_code_token {
                tokens.push(("<EOL>".to_string(), line));
            }
            line += 1;
            line_had_code_token = false;
            i += 1;
            continue;
        }
        if b == b' ' || b == b'\t' || b == b'\r' {
            i += 1;
            continue;
        }
        // `//` line comment: stop *at* the newline so the branch above closes the line.
        if b == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'/' {
            while i < bytes.len() && bytes[i] != b'\n' {
                i += 1;
            }
            continue;
        }
        // `/* ... */` block comment; nested openers are tracked by depth.
        if b == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
            i += 2;
            let mut depth = 1usize;
            while i < bytes.len() && depth > 0 {
                if bytes[i] == b'\n' {
                    line += 1;
                    i += 1;
                } else if bytes[i] == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
                    depth += 1;
                    i += 2;
                } else if bytes[i] == b'*' && i + 1 < bytes.len() && bytes[i + 1] == b'/' {
                    depth -= 1;
                    i += 2;
                } else {
                    i += 1;
                }
            }
            continue;
        }
        // Rust raw string: `r`, N `#`s, then `"`. It is closed by `"` followed by
        // N `#`s and contains no escapes. If no `"` follows, fall through and
        // scan `r` as an ordinary identifier.
        if b == b'r' && i + 1 < bytes.len() {
            let mut k = i + 1;
            let mut hashes = 0usize;
            while k < bytes.len() && bytes[k] == b'#' {
                hashes += 1;
                k += 1;
            }
            if k < bytes.len() && bytes[k] == b'"' {
                let start_line = line;
                i = k + 1;
                while i < bytes.len() {
                    if bytes[i] == b'\n' {
                        line += 1;
                        i += 1;
                        continue;
                    }
                    if bytes[i] == b'"' {
                        let mut closing = 0usize;
                        let mut j = i + 1;
                        while closing < hashes && j < bytes.len() && bytes[j] == b'#' {
                            closing += 1;
                            j += 1;
                        }
                        if closing == hashes {
                            i = j;
                            break;
                        }
                    }
                    i += 1;
                }
                tokens.push(("<STR>".to_string(), start_line));
                line_had_code_token = true;
                continue;
            }
        }
        // Triple-quoted string (`"""` or `'''`); may span lines, backslash escapes honoured.
        if (b == b'"' || b == b'\'')
            && i + 2 < bytes.len()
            && bytes[i + 1] == b
            && bytes[i + 2] == b
        {
            let quote = b;
            let start_line = line;
            i += 3;
            while i < bytes.len() {
                if bytes[i] == b'\n' {
                    line += 1;
                    i += 1;
                    continue;
                }
                if bytes[i] == b'\\' && i + 1 < bytes.len() {
                    if bytes[i + 1] == b'\n' {
                        line += 1;
                    }
                    i += 2;
                    continue;
                }
                if bytes[i] == quote
                    && i + 2 < bytes.len()
                    && bytes[i + 1] == quote
                    && bytes[i + 2] == quote
                {
                    i += 3;
                    break;
                }
                i += 1;
            }
            tokens.push(("<STR>".to_string(), start_line));
            line_had_code_token = true;
            continue;
        }
        // Ordinary quoted literal: ends at the matching quote or an unescaped line break.
        if b == b'"' || b == b'\'' || b == b'`' {
            let quote = b;
            let start_line = line;
            i += 1;
            while i < bytes.len() && bytes[i] != quote && bytes[i] != b'\n' {
                if bytes[i] == b'\\' && i + 1 < bytes.len() {
                    // A backslash-escaped line break (LF or CRLF) continues the
                    // literal on the next line. Count it so every later token
                    // keeps its correct line number.
                    if bytes[i + 1] == b'\n' {
                        line += 1;
                        i += 2;
                    } else if bytes[i + 1] == b'\r' && bytes.get(i + 2) == Some(&b'\n') {
                        line += 1;
                        i += 3;
                    } else {
                        i += 2;
                    }
                } else {
                    i += 1;
                }
            }
            if i < bytes.len() && bytes[i] == quote {
                i += 1;
            }
            tokens.push(("<STR>".to_string(), start_line));
            line_had_code_token = true;
            continue;
        }
        // Number: digits plus ASCII letters, `.` and `_` (hex, exponents, suffixes).
        if b.is_ascii_digit() {
            let start_line = line;
            while i < bytes.len()
                && (bytes[i].is_ascii_alphanumeric() || bytes[i] == b'.' || bytes[i] == b'_')
            {
                i += 1;
            }
            tokens.push(("<NUM>".to_string(), start_line));
            line_had_code_token = true;
            continue;
        }
        if b.is_ascii_alphabetic() || b == b'_' {
            let start_line = line;
            let start = i;
            while i < bytes.len() && (bytes[i].is_ascii_alphanumeric() || bytes[i] == b'_') {
                i += 1;
            }
            // The run is pure ASCII, so both ends are char boundaries.
            let word = content[start..i].to_string();
            tokens.push((classify_identifier(word), start_line));
            line_had_code_token = true;
            continue;
        }
        // Anything else is an operator or punctuation. Every branch above leaves
        // `i` on a char boundary, so slicing `content` here cannot panic.
        let start_line = line;
        let mut chars = content[i..].chars().peekable();
        let op = consume_operator(&mut chars);
        if op.is_empty() {
            i += 1;
        } else {
            // `op.len()` is its UTF-8 byte length, so non-ASCII characters advance correctly.
            i += op.len();
            tokens.push((op, start_line));
            line_had_code_token = true;
        }
    }
    if line_had_code_token {
        tokens.push(("<EOL>".to_string(), line));
    }
    tokens
}
/// Returns `true` for words the tokenizer keeps verbatim instead of
/// abstracting them to `<ID>` / `<TYPE>` / `<CONST>`.
///
/// The set mixes reserved words and very common built-in names from several
/// languages (e.g. `elif`, `func`, `companion`, `constexpr`, `Vec`, `len`).
/// Matching is exact and case-sensitive.
fn is_keyword(word: &str) -> bool {
    matches!(
        word,
        // Control flow.
        "if" | "else" | "elif" | "for" | "while" | "do" | "loop" | "break" | "continue"
            | "return" | "yield" | "switch" | "case" | "default" | "match" | "when"
            | "select" | "range" | "goto"
            // Exceptions and assertions.
            | "try" | "catch" | "except" | "finally" | "throw" | "throws" | "raise"
            | "assert"
            // Declarations and definitions.
            | "fn" | "func" | "def" | "function" | "let" | "var" | "val" | "const"
            | "static" | "auto" | "type" | "typedef" | "class" | "struct" | "enum"
            | "trait" | "interface" | "impl" | "union" | "template" | "where" | "lambda"
            | "define"
            // Modifiers and qualifiers.
            | "extends" | "implements" | "abstract" | "sealed" | "final" | "override"
            | "virtual" | "explicit" | "friend" | "operator" | "object" | "companion"
            | "data" | "pub" | "private" | "protected" | "public" | "readonly" | "mut"
            | "ref" | "move" | "dyn" | "unsafe" | "extern" | "synchronized" | "volatile"
            | "transient" | "native" | "register" | "inline" | "restrict" | "noexcept"
            | "constexpr"
            // Modules and namespaces.
            | "use" | "mod" | "import" | "export" | "from" | "package" | "as" | "crate"
            | "super" | "namespace" | "include"
            // Concurrency.
            | "async" | "await" | "defer" | "go" | "chan"
            // Literal constants.
            | "true" | "false" | "True" | "False" | "null" | "nil" | "None" | "undefined"
            | "NaN" | "Infinity"
            // Receivers and allocation.
            | "self" | "Self" | "this" | "new" | "delete" | "del"
            // Common standard-library types and constructors.
            | "Box" | "Vec" | "Option" | "Result" | "Some" | "Ok" | "Err"
            // Word operators.
            | "and" | "or" | "not" | "is" | "in" | "typeof" | "instanceof" | "void"
            | "sizeof"
            // Python statements.
            | "pass" | "global" | "nonlocal" | "with"
            // Common built-in functions.
            | "map" | "make" | "append" | "len" | "cap"
            // Preprocessor directives.
            | "ifdef" | "ifndef" | "endif" | "pragma"
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string slices into the owned token vector the model API takes.
    fn toks(words: &[&str]) -> Vec<String> {
        words.iter().map(|word| (*word).to_string()).collect()
    }

    /// Returns a fresh model trained on `repeats` copies of `sequence`, one call per copy.
    fn trained(repeats: usize, sequence: &[&str]) -> NgramModel {
        let sequence = toks(sequence);
        let mut model = NgramModel::new();
        for _ in 0..repeats {
            model.train_on_tokens(&sequence);
        }
        model
    }

    /// Number of tokens exactly equal to `wanted`.
    fn count_of(tokens: &[String], wanted: &str) -> usize {
        tokens.iter().filter(|tok| tok.as_str() == wanted).count()
    }

    #[test]
    fn test_tokenize_line() {
        let tokens = NgramModel::tokenize_line("let mut count = 0;");
        assert_eq!(tokens, vec!["let", "mut", "<ID>", "=", "<NUM>", ";"]);
    }

    #[test]
    fn test_tokenize_string_literal() {
        let tokens = NgramModel::tokenize_line(r#"println!("hello world");"#);
        assert!(count_of(&tokens, "<STR>") > 0);
    }

    #[test]
    fn test_tokenize_type() {
        let tokens = NgramModel::tokenize_line("let x: HashMap<String, u32> = HashMap::new();");
        assert!(count_of(&tokens, "<TYPE>") > 0);
    }

    #[test]
    fn test_model_training() {
        let model = trained(800, &["let", "mut", "<ID>", "=", "<NUM>", ";", "<EOL>"]);
        assert!(model.total_tokens() > 1000);
        assert!(model.is_confident());
    }

    #[test]
    fn test_surprisal_familiar_vs_unusual() {
        let model = trained(
            500,
            &["let", "<ID>", "=", "<ID>", ".", "<ID>", "(", ")", ";", "<EOL>"],
        );
        let familiar = toks(&["let", "<ID>", "=", "<ID>", ".", "<ID>", "(", ")", ";"]);
        let unusual = toks(&["unsafe", "{", "<ID>", "::", "<ID>", "(", "&", "mut", "<ID>"]);
        let s_familiar = model.surprisal(&familiar);
        let s_unusual = model.surprisal(&unusual);
        assert!(
            s_unusual > s_familiar,
            "Unusual code ({:.2}) should be more surprising than familiar code ({:.2})",
            s_unusual,
            s_familiar
        );
    }

    #[test]
    fn test_not_confident_returns_zero() {
        let model = NgramModel::new();
        let tokens = toks(&["let", "<ID>", "="]);
        assert_eq!(model.surprisal(&tokens), 0.0);
    }

    #[test]
    fn test_tokenize_file_collapses_rust_raw_string_across_newlines() {
        let source = "let template = r#\"\ndef foo(x):\n return x + 1\n\"#;\n";
        let tokens = NgramModel::tokenize_file(source);
        assert!(
            !tokens.iter().any(|t| t == "def" || t == "return"),
            "Multi-line raw string contents leaked as tokens: {:?}",
            tokens
        );
        let str_count = count_of(&tokens, "<STR>");
        assert_eq!(
            str_count, 1,
            "Expected one <STR> token for the raw string, got {} in {:?}",
            str_count, tokens
        );
    }

    #[test]
    fn test_tokenize_file_collapses_python_triple_string_across_newlines() {
        let source = "msg = \"\"\"\nhello\nworld\n\"\"\"\n";
        let tokens = NgramModel::tokenize_file(source);
        let str_count = count_of(&tokens, "<STR>");
        assert_eq!(
            str_count, 1,
            "Expected one <STR> for the triple-quoted string, got {} in {:?}",
            str_count, tokens
        );
    }

    #[test]
    fn test_tokenize_file_skips_multiline_block_comment() {
        let source = "let x = 1;\n/* multi\n * line\n * comment */\nlet y = 2;\n";
        let tokens = NgramModel::tokenize_file(source);
        let lets = count_of(&tokens, "let");
        assert_eq!(lets, 2, "Expected 2 `let` tokens, got {:?}", tokens);
        let stars = count_of(&tokens, "*");
        assert_eq!(
            stars, 0,
            "Block-comment `*` leaked as operator: {:?}",
            tokens
        );
    }

    #[test]
    fn test_function_surprisal_treats_multiline_raw_string_as_one_token() {
        let model = trained(
            1000,
            &[
                "fn", "<ID>", "(", ")", "->", "<TYPE>", "{", "<EOL>", "<STR>", "<EOL>", "}",
                "<EOL>",
            ],
        );
        assert!(model.is_confident());
        let lines = vec![
            "fn template() -> String {",
            " r#\"",
            "def handler(req):",
            " if req.error:",
            " return 500",
            " return 200",
            "\"#",
            "}",
        ];
        let (avg, _max, _peak) = model.function_surprisal(&lines);
        assert!(
            avg < 5.0,
            "Function with raw-string Python body scored {:.2} bits — \
             stateful tokenizer should collapse the string to <STR>",
            avg
        );
    }
}