use std::collections::HashSet;
use std::io::Write;
use flatbuffers;
use rand::SeedableRng;
use rand::distr::Distribution;
use rand::distr::weighted::WeightedIndex;
use rand::rngs::StdRng;
use crate::trie::CountTrie;
#[cfg(feature = "pyo3")]
mod py;
#[cfg(feature = "pyo3")]
pub(crate) use py::register_module;
#[cfg(feature = "pyo3")]
pub use py::{PyLaplace, PyLidstone, PyMLE};
// FlatBuffers accessor code generated at build time from the LM schema.
// Lints are silenced because the generated file is not under our control.
#[allow(dead_code, unused_imports, clippy::all)]
mod generated {
    include!(concat!(env!("OUT_DIR"), "/lm/model_generated.rs"));
}
/// Probability-estimation strategy shared by the language models.
#[derive(Clone, Debug)]
pub enum Smoothing {
    /// Raw maximum-likelihood estimate; unseen n-grams score 0.
    Mle,
    /// Additive (Lidstone) smoothing with pseudo-count `gamma`.
    Lidstone { gamma: f64 },
}

impl Smoothing {
    /// The additive constant, when this variant carries one.
    pub fn gamma(&self) -> Option<f64> {
        if let Smoothing::Lidstone { gamma } = self {
            Some(*gamma)
        } else {
            None
        }
    }
}
/// The set of word types known to a fitted model.
#[derive(Clone, Debug, Default)]
pub struct Vocabulary {
    words: HashSet<String>,
}

/// Replacement label for out-of-vocabulary words.
pub const UNK_LABEL: &str = "<UNK>";
/// Sentence-start padding token.
pub const BOS_LABEL: &str = "<s>";
/// Sentence-end token.
pub const EOS_LABEL: &str = "</s>";

impl Vocabulary {
    /// An empty vocabulary.
    pub fn new() -> Self {
        Self::default()
    }

    /// Collects every distinct word in `sents`, plus the three special
    /// tokens (`<UNK>`, `<s>`, `</s>`).
    pub fn build(sents: &[Vec<String>]) -> Self {
        let mut words: HashSet<String> = sents.iter().flatten().cloned().collect();
        for special in [UNK_LABEL, BOS_LABEL, EOS_LABEL] {
            words.insert(special.to_string());
        }
        Self { words }
    }

    /// Maps `word` to itself when known, otherwise to `<UNK>`.
    pub fn lookup(&self, word: &str) -> String {
        if !self.words.contains(word) {
            return UNK_LABEL.to_string();
        }
        word.to_string()
    }

    /// Number of known word types (special tokens included).
    pub fn len(&self) -> usize {
        self.words.len()
    }

    pub fn is_empty(&self) -> bool {
        self.words.is_empty()
    }

    /// Read-only access to the underlying word set.
    pub fn words(&self) -> &HashSet<String> {
        &self.words
    }

    /// Wraps an already-built word set (used when deserializing).
    pub fn from_words(words: HashSet<String>) -> Self {
        Self { words }
    }
}
/// Common behavior for n-gram language models.
///
/// Implementors supply plain accessors to their stored state (order,
/// smoothing, vocabulary, counts, fitted flag); training, scoring,
/// generation and persistence are all provided as default methods.
pub trait BaseLanguageModel: Sized + Clone {
    /// Maximum n-gram length counted by `fit` and used for scoring.
    fn order(&self) -> usize;
    /// Smoothing strategy applied in `compute_score`.
    fn smoothing(&self) -> &Smoothing;
    fn vocabulary(&self) -> &Vocabulary;
    fn vocabulary_mut(&mut self) -> &mut Vocabulary;
    /// Trie of n-gram occurrence counts built by `fit`.
    fn counts(&self) -> &CountTrie<String>;
    fn counts_mut(&mut self) -> &mut CountTrie<String>;
    /// True once the model has been fitted or successfully loaded.
    fn fitted(&self) -> bool;
    fn set_fitted(&mut self, fitted: bool);
    /// Identifier written to / checked against saved model files
    /// (e.g. "mle", "lidstone", "laplace").
    fn smoothing_name(&self) -> &str;

    /// Trains the model on tokenized sentences, replacing any previous
    /// vocabulary and counts. Each sentence is padded with `order - 1`
    /// `<s>` tokens and a trailing `</s>`, and every n-gram of length
    /// 1..=order over the padded sentence is counted.
    fn fit(&mut self, sents: Vec<Vec<String>>) {
        *self.vocabulary_mut() = Vocabulary::build(&sents);
        *self.counts_mut() = CountTrie::new();
        for sent in &sents {
            // order - 1 BOS tokens + the sentence + one EOS token.
            let mut padded: Vec<String> = Vec::with_capacity(self.order() - 1 + sent.len() + 1);
            for _ in 0..self.order().saturating_sub(1) {
                padded.push(BOS_LABEL.to_string());
            }
            for word in sent {
                padded.push(word.clone());
            }
            padded.push(EOS_LABEL.to_string());
            // Count every n-gram of every order up to `order()`.
            for n in 1..=self.order() {
                for window in padded.windows(n) {
                    self.counts_mut().increment(window.iter().cloned());
                }
            }
        }
        self.set_fitted(true);
    }

    /// Conditional probability estimate P(word | context) under the
    /// configured smoothing. `context` is first trimmed to its last
    /// `order - 1` tokens.
    fn compute_score(&self, word: &str, context: &[String]) -> f64 {
        let ctx = if context.len() >= self.order() {
            &context[context.len() - (self.order() - 1)..]
        } else {
            context
        };
        let mut ngram: Vec<String> = ctx.to_vec();
        ngram.push(word.to_string());
        let word_count = self.counts().get_count(ngram.iter().cloned()) as f64;
        let context_count = self.counts().children_count_sum(ctx.iter().cloned()) as f64;
        match self.smoothing() {
            Smoothing::Mle => {
                // count(ctx + word) / count(ctx); 0 for an unseen context.
                if context_count == 0.0 {
                    0.0
                } else {
                    word_count / context_count
                }
            }
            Smoothing::Lidstone { gamma } => {
                // (count + gamma) / (count(ctx) + |V| * gamma).
                let vocab_size = self.vocabulary().len() as f64;
                let numerator = word_count + gamma;
                let denominator = context_count + vocab_size * gamma;
                if denominator == 0.0 {
                    0.0
                } else {
                    numerator / denominator
                }
            }
        }
    }

    /// Probability of `word` given `context`, with out-of-vocabulary
    /// tokens (in both) mapped to `<UNK>` first.
    ///
    /// # Errors
    /// Returns a `ValidationError` if the model has not been fitted.
    fn score(&self, word: String, context: Option<Vec<String>>) -> Result<f64, ModelError> {
        if !self.fitted() {
            return Err(ModelError::ValidationError(
                "Model has not been fitted yet.".to_string(),
            ));
        }
        let word = self.vocabulary().lookup(&word);
        let context: Vec<String> = context
            .unwrap_or_default()
            .iter()
            .map(|w| self.vocabulary().lookup(w))
            .collect();
        Ok(self.compute_score(&word, &context))
    }

    /// Like `score`, but without mapping unknown tokens to `<UNK>`.
    ///
    /// # Errors
    /// Returns a `ValidationError` if the model has not been fitted.
    fn unmasked_score(
        &self,
        word: String,
        context: Option<Vec<String>>,
    ) -> Result<f64, ModelError> {
        if !self.fitted() {
            return Err(ModelError::ValidationError(
                "Model has not been fitted yet.".to_string(),
            ));
        }
        let context = context.unwrap_or_default();
        Ok(self.compute_score(&word, &context))
    }

    /// Base-2 log of `score`; a zero probability maps to `-inf`.
    fn logscore(&self, word: String, context: Option<Vec<String>>) -> Result<f64, ModelError> {
        let s = self.score(word, context)?;
        if s == 0.0 {
            Ok(f64::NEG_INFINITY)
        } else {
            Ok(s.log2())
        }
    }

    /// Samples up to `num_words` words, starting from `text_seed` (or,
    /// when absent, `order - 1` `<s>` tokens). At each step the next word
    /// is drawn from the observed continuations of the current context,
    /// weighted by their counts; generation stops early at `</s>` or
    /// when the context has no recorded continuation.
    ///
    /// # Errors
    /// Returns a `ValidationError` if the model has not been fitted or
    /// if weighted sampling fails (e.g. invalid weights).
    fn generate(
        &self,
        num_words: usize,
        text_seed: Option<Vec<String>>,
        random_seed: Option<u64>,
    ) -> Result<Vec<String>, ModelError> {
        if !self.fitted() {
            return Err(ModelError::ValidationError(
                "Model has not been fitted yet.".to_string(),
            ));
        }
        // Seeded -> reproducible StdRng; otherwise the thread-local RNG.
        let mut rng: Box<dyn rand::Rng> = match random_seed {
            Some(seed) => Box::new(StdRng::seed_from_u64(seed)),
            None => Box::new(rand::rng()),
        };
        // NOTE(review): seed words are used verbatim (no `<UNK>` mapping)
        // — confirm out-of-vocabulary seeds are intended to pass through.
        let mut context: Vec<String> = text_seed.unwrap_or_else(|| {
            (0..self.order().saturating_sub(1))
                .map(|_| BOS_LABEL.to_string())
                .collect()
        });
        let mut generated = Vec::with_capacity(num_words);
        for _ in 0..num_words {
            // Only the trailing `order - 1` tokens condition the draw.
            let ctx_start = if context.len() >= self.order().saturating_sub(1) {
                context.len() - self.order().saturating_sub(1)
            } else {
                0
            };
            let ctx = &context[ctx_start..];
            let children = self.counts().children_with_counts(ctx.iter().cloned());
            if children.is_empty() {
                break;
            }
            let words: Vec<String> = children.iter().map(|(w, _)| w.clone()).collect();
            let weights: Vec<f64> = children.iter().map(|(_, c)| *c as f64).collect();
            let dist = WeightedIndex::new(&weights)
                .map_err(|e| ModelError::ValidationError(format!("Sampling error: {}", e)))?;
            let idx = dist.sample(&mut *rng);
            let word = words[idx].clone();
            if word == EOS_LABEL {
                break;
            }
            context.push(word.clone());
            generated.push(word);
        }
        Ok(generated)
    }

    /// Number of word types in the vocabulary (special tokens included).
    fn vocab_size(&self) -> usize {
        self.vocabulary().len()
    }

    /// Serializes the model to FlatBuffers and writes it zstd-compressed
    /// to `path`.
    fn save_to_path(&self, path: &str) -> Result<(), ModelError> {
        let mut buf = Vec::new();
        save_lm_flatbuffers(self, &mut buf)?;
        crate::persistence::save_zstd(path, &buf)
    }

    /// Replaces this model's state with the model stored at `path`;
    /// fails if the file's smoothing type, order or gamma do not match.
    fn load_from_path(&mut self, path: &str) -> Result<(), ModelError> {
        let bytes = crate::persistence::load_zstd(path, "language model")?;
        load_lm_flatbuffers(self, &bytes)
    }
}
/// Serializes `model` (order, smoothing name, gamma, vocabulary and all
/// n-gram counts) into a FlatBuffers `LmModel` and writes the raw bytes
/// to `writer`.
///
/// Vocabulary and n-grams are sorted first so output is deterministic
/// for a given model.
fn save_lm_flatbuffers<T: BaseLanguageModel, W: Write>(
    model: &T,
    writer: &mut W,
) -> Result<(), ModelError> {
    use generated::rustling::lm_fbs as fbs;
    let mut builder = flatbuffers::FlatBufferBuilder::with_capacity(1024 * 1024);
    // Sorted vocabulary -> FlatBuffers vector of strings.
    let mut vocab_words: Vec<&String> = model.vocabulary().words().iter().collect();
    vocab_words.sort();
    let vocab_strs: Vec<_> = vocab_words
        .iter()
        .map(|w| builder.create_string(w))
        .collect();
    let vocab_fb = builder.create_vector(&vocab_strs);
    // Every (ngram, count) pair from the trie, sorted by ngram.
    let mut all_counts = model.counts().all_counts();
    all_counts.sort_by(|a, b| a.0.cmp(&b.0));
    let fb_ngrams: Vec<_> = all_counts
        .iter()
        .map(|(ngram, count)| {
            let ngram_strs: Vec<_> = ngram.iter().map(|w| builder.create_string(w)).collect();
            let ngram_fb = builder.create_vector(&ngram_strs);
            fbs::NgramEntry::create(
                &mut builder,
                &fbs::NgramEntryArgs {
                    ngram: Some(ngram_fb),
                    count: *count,
                },
            )
        })
        .collect();
    let ngrams_fb = builder.create_vector(&fb_ngrams);
    let smoothing_name = builder.create_string(model.smoothing_name());
    // MLE stores a placeholder gamma of 0.0.
    let gamma = match model.smoothing() {
        Smoothing::Mle => 0.0,
        Smoothing::Lidstone { gamma } => *gamma,
    };
    let lm = fbs::LmModel::create(
        &mut builder,
        &fbs::LmModelArgs {
            order: model.order() as u32,
            smoothing: Some(smoothing_name),
            gamma,
            vocabulary: Some(vocab_fb),
            ngrams: Some(ngrams_fb),
        },
    );
    builder.finish(lm, None);
    writer
        .write_all(builder.finished_data())
        .map_err(|e| ModelError::Io(format!("Failed to write FlatBuffers data: {e}")))
}
/// Populates `lm` from FlatBuffers `bytes` produced by
/// `save_lm_flatbuffers`.
///
/// Before loading any data, the file's smoothing name, order and (for
/// Lidstone) gamma are checked against the receiving model; a mismatch
/// yields a `ParseError`. On success the vocabulary and count trie are
/// rebuilt and the model is marked fitted.
fn load_lm_flatbuffers<T: BaseLanguageModel>(lm: &mut T, bytes: &[u8]) -> Result<(), ModelError> {
    use generated::rustling::lm_fbs as fbs;
    let opts = crate::persistence::flatbuffers_verifier_opts();
    let model = flatbuffers::root_with_opts::<fbs::LmModel>(&opts, bytes)
        .map_err(|e| ModelError::ParseError(format!("Invalid FlatBuffers LM data: {e}")))?;
    let file_smoothing = model.smoothing();
    if file_smoothing != lm.smoothing_name() {
        return Err(ModelError::ParseError(format!(
            "Smoothing type mismatch: file has '{file_smoothing}' but this model is '{}'",
            lm.smoothing_name()
        )));
    }
    let file_order = model.order() as usize;
    if file_order != lm.order() {
        return Err(ModelError::ParseError(format!(
            "Order mismatch: file has {file_order} but this model has {}",
            lm.order()
        )));
    }
    if let Smoothing::Lidstone { gamma } = lm.smoothing() {
        // Tolerate representation noise up to 1e-15 when comparing gamma.
        let file_gamma = model.gamma();
        if (file_gamma - gamma).abs() > 1e-15 {
            return Err(ModelError::ParseError(format!(
                "Gamma mismatch: file has {file_gamma} but this model has {gamma}"
            )));
        }
    }
    let vocab_words: HashSet<String> = model.vocabulary().iter().map(|s| s.to_owned()).collect();
    *lm.vocabulary_mut() = Vocabulary::from_words(vocab_words);
    *lm.counts_mut() = CountTrie::new();
    for entry in model.ngrams().iter() {
        let ngram: Vec<String> = entry.ngram().iter().map(|s| s.to_owned()).collect();
        lm.counts_mut()
            .insert_count(ngram.into_iter(), entry.count());
    }
    lm.set_fitted(true);
    Ok(())
}
use crate::persistence::ModelError;
/// Maximum-likelihood n-gram model (no smoothing).
#[derive(Clone)]
pub struct MLE {
    // Highest n-gram order counted (>= 1, enforced by `new`).
    order: usize,
    // Always `Smoothing::Mle` for this type.
    smoothing: Smoothing,
    vocabulary: Vocabulary,
    counts: CountTrie<String>,
    // Set by `fit` / `load_from_path`.
    fitted: bool,
}
impl BaseLanguageModel for MLE {
    // --- model identity -------------------------------------------------
    fn order(&self) -> usize {
        self.order
    }

    fn smoothing_name(&self) -> &str {
        "mle"
    }

    fn smoothing(&self) -> &Smoothing {
        &self.smoothing
    }

    // --- trained state --------------------------------------------------
    fn vocabulary(&self) -> &Vocabulary {
        &self.vocabulary
    }

    fn vocabulary_mut(&mut self) -> &mut Vocabulary {
        &mut self.vocabulary
    }

    fn counts(&self) -> &CountTrie<String> {
        &self.counts
    }

    fn counts_mut(&mut self) -> &mut CountTrie<String> {
        &mut self.counts
    }

    // --- fit bookkeeping ------------------------------------------------
    fn fitted(&self) -> bool {
        self.fitted
    }

    fn set_fitted(&mut self, value: bool) {
        self.fitted = value;
    }
}
impl MLE {
    /// Creates an unfitted MLE model of the given n-gram order.
    ///
    /// # Errors
    /// Returns a `ValidationError` when `order` is zero.
    pub fn new(order: usize) -> Result<Self, ModelError> {
        if order == 0 {
            return Err(ModelError::ValidationError(
                "order must be >= 1".to_string(),
            ));
        }
        Ok(Self {
            order,
            smoothing: Smoothing::Mle,
            vocabulary: Vocabulary::new(),
            counts: CountTrie::new(),
            fitted: false,
        })
    }
}
/// N-gram model with additive (Lidstone) smoothing.
#[derive(Clone)]
pub struct Lidstone {
    // Highest n-gram order counted (>= 1, enforced by `new`).
    order: usize,
    // Additive pseudo-count; duplicated inside `smoothing` so the
    // trait's `smoothing()` accessor can expose it.
    gamma: f64,
    smoothing: Smoothing,
    vocabulary: Vocabulary,
    counts: CountTrie<String>,
    // Set by `fit` / `load_from_path`.
    fitted: bool,
}
impl BaseLanguageModel for Lidstone {
    // --- model identity -------------------------------------------------
    fn order(&self) -> usize {
        self.order
    }

    fn smoothing_name(&self) -> &str {
        "lidstone"
    }

    fn smoothing(&self) -> &Smoothing {
        &self.smoothing
    }

    // --- trained state --------------------------------------------------
    fn vocabulary(&self) -> &Vocabulary {
        &self.vocabulary
    }

    fn vocabulary_mut(&mut self) -> &mut Vocabulary {
        &mut self.vocabulary
    }

    fn counts(&self) -> &CountTrie<String> {
        &self.counts
    }

    fn counts_mut(&mut self) -> &mut CountTrie<String> {
        &mut self.counts
    }

    // --- fit bookkeeping ------------------------------------------------
    fn fitted(&self) -> bool {
        self.fitted
    }

    fn set_fitted(&mut self, value: bool) {
        self.fitted = value;
    }
}
impl Lidstone {
    /// Creates an unfitted Lidstone-smoothed model of the given order.
    ///
    /// # Errors
    /// Returns a `ValidationError` when `order` is zero or `gamma` is
    /// not a finite value strictly greater than zero (NaN and infinities
    /// are rejected).
    pub fn new(order: usize, gamma: f64) -> Result<Self, ModelError> {
        if order == 0 {
            return Err(ModelError::ValidationError(
                "order must be >= 1".to_string(),
            ));
        }
        // `gamma <= 0.0` alone would accept NaN (every comparison with
        // NaN is false) and +inf; require a finite, strictly positive value.
        if !gamma.is_finite() || gamma <= 0.0 {
            return Err(ModelError::ValidationError(
                "gamma must be > 0 and finite".to_string(),
            ));
        }
        Ok(Self {
            order,
            gamma,
            smoothing: Smoothing::Lidstone { gamma },
            vocabulary: Vocabulary::new(),
            counts: CountTrie::new(),
            fitted: false,
        })
    }

    /// The smoothing constant supplied at construction.
    pub fn gamma(&self) -> f64 {
        self.gamma
    }
}
/// Add-one smoothed n-gram model: Lidstone with gamma fixed at 1.0.
#[derive(Clone)]
pub struct Laplace {
    // Highest n-gram order counted (>= 1, enforced by `new`).
    order: usize,
    // Always `Smoothing::Lidstone { gamma: 1.0 }` for this type.
    smoothing: Smoothing,
    vocabulary: Vocabulary,
    counts: CountTrie<String>,
    // Set by `fit` / `load_from_path`.
    fitted: bool,
}
impl BaseLanguageModel for Laplace {
    // --- model identity -------------------------------------------------
    fn order(&self) -> usize {
        self.order
    }

    fn smoothing_name(&self) -> &str {
        "laplace"
    }

    fn smoothing(&self) -> &Smoothing {
        &self.smoothing
    }

    // --- trained state --------------------------------------------------
    fn vocabulary(&self) -> &Vocabulary {
        &self.vocabulary
    }

    fn vocabulary_mut(&mut self) -> &mut Vocabulary {
        &mut self.vocabulary
    }

    fn counts(&self) -> &CountTrie<String> {
        &self.counts
    }

    fn counts_mut(&mut self) -> &mut CountTrie<String> {
        &mut self.counts
    }

    // --- fit bookkeeping ------------------------------------------------
    fn fitted(&self) -> bool {
        self.fitted
    }

    fn set_fitted(&mut self, value: bool) {
        self.fitted = value;
    }
}
impl Laplace {
    /// Creates an unfitted add-one model: Lidstone smoothing with the
    /// gamma constant fixed at 1.0.
    ///
    /// # Errors
    /// Returns a `ValidationError` when `order` is zero.
    pub fn new(order: usize) -> Result<Self, ModelError> {
        if order == 0 {
            return Err(ModelError::ValidationError(
                "order must be >= 1".to_string(),
            ));
        }
        Ok(Self {
            order,
            smoothing: Smoothing::Lidstone { gamma: 1.0 },
            vocabulary: Vocabulary::new(),
            counts: CountTrie::new(),
            fitted: false,
        })
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests covering construction, fitting, scoring, generation
    //! and persistence round-trips for MLE, Lidstone and Laplace.
    use super::*;

    /// Shared corpus: three 3-word sentences with 5 distinct words
    /// ("the", "cat", "sat", "dog", "ran"); vocabulary size is 8 once
    /// the three special tokens are added.
    fn training_data() -> Vec<Vec<String>> {
        vec![
            vec!["the".into(), "cat".into(), "sat".into()],
            vec!["the".into(), "dog".into(), "ran".into()],
            vec!["the".into(), "cat".into(), "ran".into()],
        ]
    }

    // --- construction ---------------------------------------------------

    #[test]
    fn test_new_mle() {
        let model = MLE::new(2).unwrap();
        assert_eq!(model.order, 2);
        assert!(!model.fitted);
    }

    #[test]
    fn test_new_invalid_order() {
        let result = MLE::new(0);
        assert!(result.is_err());
    }

    #[test]
    fn test_new_lidstone_invalid_gamma() {
        let result = Lidstone::new(2, 0.0);
        assert!(result.is_err());
        let result = Lidstone::new(2, -1.0);
        assert!(result.is_err());
    }

    // --- fitting and scoring --------------------------------------------

    #[test]
    fn test_fit_builds_vocabulary() {
        let mut model = MLE::new(2).unwrap();
        model.fit(training_data());
        assert!(model.fitted);
        // 5 corpus words + <UNK> + <s> + </s>.
        assert_eq!(model.vocabulary.len(), 8);
    }

    #[test]
    fn test_score_before_fit() {
        let model = MLE::new(2).unwrap();
        let result = model.score("cat".into(), Some(vec!["the".into()]));
        assert!(result.is_err());
    }

    #[test]
    fn test_generate_before_fit() {
        let model = MLE::new(2).unwrap();
        let result = model.generate(5, None, Some(42));
        assert!(result.is_err());
    }

    #[test]
    fn test_mle_bigram_score() {
        let mut model = MLE::new(2).unwrap();
        model.fit(training_data());
        // "the" occurs 3 times; followed by "cat" twice and "dog" once.
        let score = model.score("cat".into(), Some(vec!["the".into()])).unwrap();
        assert!((score - 2.0 / 3.0).abs() < 1e-9);
        let score = model.score("dog".into(), Some(vec!["the".into()])).unwrap();
        assert!((score - 1.0 / 3.0).abs() < 1e-9);
    }

    #[test]
    fn test_mle_unseen_is_zero() {
        let mut model = MLE::new(2).unwrap();
        model.fit(training_data());
        let score = model
            .score("fish".into(), Some(vec!["the".into()]))
            .unwrap();
        assert_eq!(score, 0.0);
    }

    #[test]
    fn test_mle_unigram() {
        let mut model = MLE::new(1).unwrap();
        model.fit(training_data());
        // 12 unigram tokens total (9 words + 3 </s>); "the" appears 3 times.
        let score = model.score("the".into(), None).unwrap();
        assert!((score - 3.0 / 12.0).abs() < 1e-9);
    }

    #[test]
    fn test_score_vs_unmasked_score() {
        let mut model = MLE::new(2).unwrap();
        model.fit(training_data());
        // In-vocabulary inputs: masked and unmasked scoring agree.
        let s1 = model.score("cat".into(), Some(vec!["the".into()])).unwrap();
        let s2 = model
            .unmasked_score("cat".into(), Some(vec!["the".into()]))
            .unwrap();
        assert!((s1 - s2).abs() < 1e-9);
        let s1 = model
            .score("fish".into(), Some(vec!["the".into()]))
            .unwrap();
        let s2 = model
            .unmasked_score("fish".into(), Some(vec!["the".into()]))
            .unwrap();
        assert_eq!(s1, 0.0);
        assert_eq!(s2, 0.0);
    }

    #[test]
    fn test_lidstone_unseen_nonzero() {
        let mut model = Lidstone::new(2, 0.5).unwrap();
        model.fit(training_data());
        let score = model
            .score("fish".into(), Some(vec!["the".into()]))
            .unwrap();
        assert!(score > 0.0);
    }

    #[test]
    fn test_lidstone_score_formula() {
        let mut model = Lidstone::new(2, 0.5).unwrap();
        model.fit(training_data());
        // (2 + 0.5) / (3 + 8 * 0.5) = 2.5 / 7.
        let score = model.score("cat".into(), Some(vec!["the".into()])).unwrap();
        assert!((score - 2.5 / 7.0).abs() < 1e-9);
    }

    #[test]
    fn test_laplace_is_lidstone_gamma_one() {
        let mut laplace = Laplace::new(2).unwrap();
        let mut lidstone = Lidstone::new(2, 1.0).unwrap();
        let data = training_data();
        laplace.fit(data.clone());
        lidstone.fit(data);
        for word in &["cat", "dog", "sat", "ran", "fish"] {
            for ctx in &[vec!["the".into()], vec!["cat".into()]] {
                let s1 = laplace.score(word.to_string(), Some(ctx.clone())).unwrap();
                let s2 = lidstone.score(word.to_string(), Some(ctx.clone())).unwrap();
                assert!(
                    (s1 - s2).abs() < 1e-9,
                    "Mismatch for word={} ctx={:?}: {} vs {}",
                    word,
                    ctx,
                    s1,
                    s2
                );
            }
        }
    }

    #[test]
    fn test_logscore() {
        let mut model = MLE::new(2).unwrap();
        model.fit(training_data());
        let score = model.score("cat".into(), Some(vec!["the".into()])).unwrap();
        let logscore = model
            .logscore("cat".into(), Some(vec!["the".into()]))
            .unwrap();
        assert!((logscore - score.log2()).abs() < 1e-9);
    }

    #[test]
    fn test_logscore_zero_is_neg_inf() {
        let mut model = MLE::new(2).unwrap();
        model.fit(training_data());
        let logscore = model
            .logscore("fish".into(), Some(vec!["the".into()]))
            .unwrap();
        assert!(logscore.is_infinite() && logscore.is_sign_negative());
    }

    // --- generation -----------------------------------------------------

    #[test]
    fn test_generate_deterministic_with_seed() {
        let mut model = MLE::new(2).unwrap();
        model.fit(training_data());
        let result1 = model.generate(5, None, Some(42)).unwrap();
        let result2 = model.generate(5, None, Some(42)).unwrap();
        assert_eq!(result1, result2);
    }

    #[test]
    fn test_generate_returns_words() {
        let mut model = MLE::new(2).unwrap();
        model.fit(training_data());
        let result = model.generate(3, None, Some(42)).unwrap();
        assert!(!result.is_empty());
        assert!(result.len() <= 3);
        // Special padding tokens must never be emitted.
        for word in &result {
            assert_ne!(word, BOS_LABEL);
            assert_ne!(word, EOS_LABEL);
        }
    }

    #[test]
    fn test_generate_with_text_seed() {
        let mut model = MLE::new(2).unwrap();
        model.fit(training_data());
        let result = model
            .generate(2, Some(vec!["the".into()]), Some(42))
            .unwrap();
        assert!(!result.is_empty());
    }

    #[test]
    fn test_vocabulary_lookup() {
        let vocab = Vocabulary::build(&[vec!["hello".into(), "world".into()]]);
        assert_eq!(vocab.lookup("hello"), "hello");
        assert_eq!(vocab.lookup("unknown"), UNK_LABEL);
    }

    #[test]
    fn test_context_trimming() {
        let mut model = MLE::new(2).unwrap();
        model.fit(training_data());
        // Extra leading context beyond order - 1 must be ignored.
        let s1 = model.score("cat".into(), Some(vec!["the".into()])).unwrap();
        let s2 = model
            .score("cat".into(), Some(vec!["blah".into(), "the".into()]))
            .unwrap();
        assert!((s1 - s2).abs() < 1e-9);
    }

    // --- persistence ----------------------------------------------------

    #[test]
    fn test_save_and_load_mle() {
        let mut model = MLE::new(2).unwrap();
        model.fit(training_data());
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("mle_model.bin");
        let path_str = path.to_str().unwrap();
        model.save_to_path(path_str).unwrap();
        let mut loaded = MLE::new(2).unwrap();
        loaded.load_from_path(path_str).unwrap();
        assert!(loaded.fitted());
        assert_eq!(loaded.vocabulary().len(), model.vocabulary().len());
        let s1 = model.score("cat".into(), Some(vec!["the".into()])).unwrap();
        let s2 = loaded
            .score("cat".into(), Some(vec!["the".into()]))
            .unwrap();
        assert!((s1 - s2).abs() < 1e-9);
    }

    #[test]
    fn test_save_and_load_lidstone() {
        let mut model = Lidstone::new(2, 0.5).unwrap();
        model.fit(training_data());
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("lidstone_model.bin");
        let path_str = path.to_str().unwrap();
        model.save_to_path(path_str).unwrap();
        let mut loaded = Lidstone::new(2, 0.5).unwrap();
        loaded.load_from_path(path_str).unwrap();
        let s1 = model.score("cat".into(), Some(vec!["the".into()])).unwrap();
        let s2 = loaded
            .score("cat".into(), Some(vec!["the".into()]))
            .unwrap();
        assert!((s1 - s2).abs() < 1e-9);
    }

    #[test]
    fn test_save_and_load_laplace() {
        let mut model = Laplace::new(2).unwrap();
        model.fit(training_data());
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("laplace_model.bin");
        let path_str = path.to_str().unwrap();
        model.save_to_path(path_str).unwrap();
        let mut loaded = Laplace::new(2).unwrap();
        loaded.load_from_path(path_str).unwrap();
        let s1 = model.score("cat".into(), Some(vec!["the".into()])).unwrap();
        let s2 = loaded
            .score("cat".into(), Some(vec!["the".into()]))
            .unwrap();
        assert!((s1 - s2).abs() < 1e-9);
    }

    #[test]
    fn test_load_smoothing_mismatch() {
        // Loading an MLE file into a Lidstone model must be rejected.
        let mut model = MLE::new(2).unwrap();
        model.fit(training_data());
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("model.bin");
        let path_str = path.to_str().unwrap();
        model.save_to_path(path_str).unwrap();
        let mut wrong = Lidstone::new(2, 0.5).unwrap();
        let result = wrong.load_from_path(path_str);
        assert!(result.is_err());
    }

    #[test]
    fn test_load_nonexistent_file() {
        let mut model = MLE::new(2).unwrap();
        let result = model.load_from_path("/nonexistent/path/model.bin");
        assert!(result.is_err());
    }
}