use super::entry::NgramEntry;
use super::smoothing::KneserNeySmoothing;
use super::trie::NgramTrie;
use liblevenshtein::dictionary::MutableMappedDictionary;
#[cfg(feature = "serde-extras")]
use std::path::Path;
/// An n-gram language model: stored n-gram entries plus the parameters used to score them.
///
/// Bundles an [`NgramTrie`] with a [`KneserNeySmoothing`] configuration and two
/// corpus-level numbers (`vocab_size`, `total_count`). All four are handed to the
/// smoothing whenever a log-probability is requested.
///
/// `D` is the dictionary backing the trie; it must map keys to [`NgramEntry`] values.
#[derive(serde::Serialize, serde::Deserialize)]
// Replace serde's inferred bounds: (de)serializing the model only requires `D` itself
// to be serializable and deserializable.
#[serde(bound = "D: serde::Serialize + serde::de::DeserializeOwned")]
pub struct NgramModel<D>
where
D: MutableMappedDictionary<Value = NgramEntry>,
{
/// Trie holding the n-gram entries; also the source of the model's order.
trie: NgramTrie<D>,
/// Smoothing used to turn trie lookups into log-probabilities.
smoothing: KneserNeySmoothing,
/// Vocabulary size supplied at construction; used for the OOV estimate and passed to the smoothing.
vocab_size: usize,
/// Total count supplied at construction and passed to the smoothing.
/// NOTE(review): presumably the number of tokens seen during training — confirm against the trainer.
total_count: u64,
}
impl<D> NgramModel<D>
where
    D: MutableMappedDictionary<Value = NgramEntry>,
{
    /// Assembles a model from an already-populated trie, a smoothing configuration
    /// and the two corpus statistics.
    ///
    /// The arguments are stored as given; no consistency checks are performed.
    pub fn new(
        trie: NgramTrie<D>,
        smoothing: KneserNeySmoothing,
        vocab_size: usize,
        total_count: u64,
    ) -> Self {
        Self {
            trie,
            smoothing,
            vocab_size,
            total_count,
        }
    }

    /// Returns the maximum n-gram order, as reported by the underlying trie.
    #[inline]
    pub fn order(&self) -> usize {
        self.trie.max_order()
    }

    /// Returns the vocabulary size this model was built with.
    #[inline]
    pub fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    /// Returns the total count this model was built with.
    #[inline]
    pub fn total_count(&self) -> u64 {
        self.total_count
    }

    /// Borrows the underlying n-gram trie.
    #[inline]
    pub fn trie(&self) -> &NgramTrie<D> {
        &self.trie
    }

    /// Returns the count the trie reports for the n-gram `tokens`.
    #[inline]
    pub fn count(&self, tokens: &[&str]) -> u64 {
        self.trie.count(tokens)
    }

    /// Returns the log-probability of `word` following `context`, as computed by the
    /// configured smoothing from the trie, the vocabulary size and the total count.
    pub fn log_prob(&self, word: &str, context: &[&str]) -> f64 {
        self.smoothing
            .log_prob(word, context, &self.trie, self.vocab_size, self.total_count)
    }

    /// Returns the sum of [`log_prob`](Self::log_prob) over every token of `tokens`.
    ///
    /// Each token is conditioned on the tokens immediately before it, limited to
    /// at most `order() - 1` of them. An empty slice yields `0.0`.
    pub fn sentence_log_prob(&self, tokens: &[&str]) -> f64 {
        if tokens.is_empty() {
            return 0.0;
        }

        // A trie reporting order 0 must not underflow here: `order - 1` would panic
        // in debug builds and wrap to `usize::MAX` in release builds. Saturating
        // gives such a model an empty context for every token instead.
        let max_context = self.order().saturating_sub(1);

        let mut total_log_prob = 0.0;
        for (i, &word) in tokens.iter().enumerate() {
            // Near the start of the sentence fewer predecessors are available.
            let context_start = i.saturating_sub(max_context);
            total_log_prob += self.log_prob(word, &tokens[context_start..i]);
        }
        total_log_prob
    }

    /// Returns `true` when the trie contains `word` as a single-token n-gram.
    #[inline]
    pub fn in_vocabulary(&self, word: &str) -> bool {
        self.trie.contains(&[word])
    }

    /// Returns the number of entries in the trie, as reported by its `len()`.
    #[inline]
    pub fn ngram_count(&self) -> usize {
        self.trie.len()
    }

    /// Returns a fallback log-probability for out-of-vocabulary words.
    ///
    /// This is `-ln(vocab_size)`, i.e. a uniform distribution over the vocabulary,
    /// or negative infinity when the vocabulary size is zero.
    #[inline]
    pub fn oov_log_prob(&self) -> f64 {
        if self.vocab_size == 0 {
            f64::NEG_INFINITY
        } else {
            -(self.vocab_size as f64).ln()
        }
    }
}
impl<D> Clone for NgramModel<D>
where
    D: MutableMappedDictionary<Value = NgramEntry>,
{
    /// Produces an independent copy by cloning the trie and the smoothing and
    /// copying the two scalar statistics.
    fn clone(&self) -> Self {
        // Destructure so that adding a field to the struct forces this impl to be revisited.
        let Self {
            trie,
            smoothing,
            vocab_size,
            total_count,
        } = self;

        Self {
            trie: trie.clone(),
            smoothing: smoothing.clone(),
            vocab_size: *vocab_size,
            total_count: *total_count,
        }
    }
}
#[cfg(feature = "serde-extras")]
impl<D> NgramModel<D>
where
    D: MutableMappedDictionary<Value = NgramEntry> + serde::Serialize + serde::de::DeserializeOwned,
{
    /// Serializes the whole model (including the backing dictionary) with bincode
    /// into a file at `path`, creating or truncating it.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be created, if serialization fails, or
    /// if the buffered data cannot be flushed to the file.
    pub fn save<P: AsRef<Path>>(&self, path: P) -> crate::Result<()> {
        let file = std::fs::File::create(path)?;
        let mut writer = std::io::BufWriter::new(file);
        bincode::serialize_into(&mut writer, self)?;
        // Flush explicitly: a `BufWriter` that is merely dropped discards any error
        // from its final write, which would report success for a truncated file.
        std::io::Write::flush(&mut writer)?;
        Ok(())
    }

    /// Reads a model previously written by [`save`](Self::save) from `path`.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened or its contents cannot be
    /// deserialized into a model.
    pub fn load<P: AsRef<Path>>(path: P) -> crate::Result<Self> {
        let file = std::fs::File::open(path)?;
        let reader = std::io::BufReader::new(file);
        let model: Self = bincode::deserialize_from(reader)?;
        Ok(model)
    }
}
/// Serializable snapshot of a vocabulary as a plain list of words.
///
/// When produced by `NgramModel::to_portable_with_vocabulary`, the words are collected
/// by looking up term IDs `1..=len` in ascending order and skipping IDs that have no term.
#[cfg(feature = "serde-extras")]
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
pub struct PortableVocabulary {
/// Vocabulary terms in ascending ID order.
pub words: Vec<String>,
}
/// Dictionary-independent, serializable form of an `NgramModel`.
///
/// Holds the trie contents as a flat list of key/entry pairs instead of a concrete
/// dictionary type, so a model can be rebuilt on top of any dictionary supplied by
/// the caller when loading.
#[cfg(feature = "serde-extras")]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct PortableNgramModel {
/// Every trie entry as `(key, snapshot of its entry)`.
pub entries: Vec<(String, crate::ngram::NgramEntrySnapshot)>,
/// Maximum n-gram order of the exported trie.
pub max_order: usize,
/// Vocabulary size of the exported model.
pub vocab_size: usize,
/// Total count of the exported model.
pub total_count: u64,
/// Smoothing configuration of the exported model.
pub smoothing: KneserNeySmoothing,
/// Optional vocabulary snapshot; `None` when the model was exported without one.
#[serde(default)]
pub vocabulary: Option<PortableVocabulary>,
}
#[cfg(feature = "serde-extras")]
impl<D> NgramModel<D>
where
    D: MutableMappedDictionary<Value = NgramEntry>,
{
    /// Converts the model into its dictionary-independent form, without a vocabulary.
    pub fn to_portable(&self) -> PortableNgramModel
    where
        D: crate::ngram::trie::IterableDictionary,
    {
        self.to_portable_with_vocabulary(None)
    }

    /// Converts the model into its dictionary-independent form, optionally
    /// embedding a snapshot of `vocabulary`.
    ///
    /// Every trie entry is copied into the result. When a vocabulary is given it is
    /// read-locked for the duration of the snapshot.
    pub fn to_portable_with_vocabulary(
        &self,
        vocabulary: Option<&crate::ngram::SharedVocabARTrie>,
    ) -> PortableNgramModel
    where
        D: crate::ngram::trie::IterableDictionary,
    {
        let entries: Vec<(String, crate::ngram::NgramEntrySnapshot)> = self
            .trie
            .iter_entries()
            .map(|(key, entry)| (key, crate::ngram::NgramEntrySnapshot::from(&entry)))
            .collect();

        let portable_vocab = vocabulary.map(|vocab| {
            let guard = vocab.read();
            let len = guard.len();
            let mut words = Vec::with_capacity(len);
            // IDs are probed from 1 through `len`; IDs without a term are skipped.
            // NOTE(review): presumably the vocabulary assigns dense 1-based IDs — confirm.
            for i in 1..=(len as u64) {
                if let Some(term) = guard.get_term(i) {
                    words.push(term);
                }
            }
            PortableVocabulary { words }
        });

        PortableNgramModel {
            entries,
            max_order: self.trie.max_order(),
            vocab_size: self.vocab_size,
            total_count: self.total_count,
            smoothing: self.smoothing.clone(),
            vocabulary: portable_vocab,
        }
    }

    /// Writes `portable` with bincode into a file at `path`, creating or truncating it.
    fn write_portable_file(portable: &PortableNgramModel, path: &Path) -> crate::Result<()> {
        let file = std::fs::File::create(path)?;
        let mut writer = std::io::BufWriter::new(file);
        bincode::serialize_into(&mut writer, portable)?;
        // Flush explicitly: a `BufWriter` that is merely dropped discards any error
        // from its final write, which would report success for a truncated file.
        std::io::Write::flush(&mut writer)?;
        Ok(())
    }

    /// Saves the model in portable form (no vocabulary) to `path`.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be created, if serialization fails, or
    /// if the buffered data cannot be flushed to the file.
    pub fn save_portable<P: AsRef<Path>>(&self, path: P) -> crate::Result<()>
    where
        D: crate::ngram::trie::IterableDictionary,
    {
        let portable = self.to_portable();
        Self::write_portable_file(&portable, path.as_ref())
    }

    /// Saves the model in portable form, together with a snapshot of `vocabulary`, to `path`.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be created, if serialization fails, or
    /// if the buffered data cannot be flushed to the file.
    pub fn save_portable_with_vocabulary<P: AsRef<Path>>(
        &self,
        path: P,
        vocabulary: &crate::ngram::SharedVocabARTrie,
    ) -> crate::Result<()>
    where
        D: crate::ngram::trie::IterableDictionary,
    {
        let portable = self.to_portable_with_vocabulary(Some(vocabulary));
        Self::write_portable_file(&portable, path.as_ref())
    }

    /// Loads a portable model file from `path` and rebuilds it on top of the
    /// dictionary returned by `dictionary_factory`.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened or its contents cannot be
    /// deserialized as a [`PortableNgramModel`].
    pub fn load_portable<P, F>(path: P, dictionary_factory: F) -> crate::Result<Self>
    where
        P: AsRef<Path>,
        F: FnOnce() -> D,
    {
        let file = std::fs::File::open(path)?;
        let reader = std::io::BufReader::new(file);
        let portable: PortableNgramModel = bincode::deserialize_from(reader)?;
        Self::load_portable_from_portable(portable, dictionary_factory)
    }

    /// Rebuilds a model from an in-memory [`PortableNgramModel`].
    ///
    /// Calls `dictionary_factory` once, inserts every portable entry into the
    /// resulting dictionary, and wraps it in a trie with the recorded maximum order.
    /// Any vocabulary carried by `portable` is not used here.
    pub fn load_portable_from_portable<F>(
        portable: PortableNgramModel,
        dictionary_factory: F,
    ) -> crate::Result<Self>
    where
        F: FnOnce() -> D,
    {
        let dictionary = dictionary_factory();
        for (key, snapshot) in portable.entries {
            dictionary.insert_with_value(&key, NgramEntry::from(snapshot));
        }
        let trie = NgramTrie::new(dictionary, portable.max_order);
        Ok(Self {
            trie,
            smoothing: portable.smoothing,
            vocab_size: portable.vocab_size,
            total_count: portable.total_count,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::corpus::PlaintextReader;
    use crate::ngram::TrainerBuilder;
    use liblevenshtein::dictionary::dynamic_dawg_char::DynamicDawgChar;
    use liblevenshtein::dictionary::pathmap::PathMapDictionary;
    use std::io::Write;
    use tempfile::TempDir;

    /// Writes `content` to `test.txt` inside `dir` and returns the file's path.
    fn create_test_corpus(dir: &std::path::Path, content: &str) -> std::path::PathBuf {
        let corpus_path = dir.join("test.txt");
        let mut corpus_file =
            std::fs::File::create(&corpus_path).expect("Failed to create test file");
        corpus_file
            .write_all(content.as_bytes())
            .expect("Failed to write test file");
        corpus_path
    }

    /// Trains a trigram model on a small repeated corpus, backed by a `DynamicDawgChar`.
    fn create_test_ngram_model() -> NgramModel<DynamicDawgChar<NgramEntry>> {
        // The same sentence three times, separated by single spaces.
        const SENTENCE: &str = "the quick brown fox the quick brown dog the lazy fox";
        let corpus_text = [SENTENCE; 3].join(" ");

        let workspace = TempDir::new().expect("Failed to create temp dir");
        let corpus_path = create_test_corpus(workspace.path(), &corpus_text);
        let corpus = PlaintextReader::from_file(&corpus_path).expect("Failed to create reader");

        TrainerBuilder::new(DynamicDawgChar::<NgramEntry>::new())
            .order(3)
            .train(corpus)
            .expect("N-gram training failed")
    }

    /// The trained model reports the requested order and non-empty statistics.
    #[test]
    fn test_model_properties() {
        let lm = create_test_ngram_model();
        assert_eq!(lm.order(), 3);
        assert!(lm.vocab_size() > 0);
        assert!(lm.total_count() > 0);
    }

    /// Conditional and unigram log-probabilities are finite; the conditional one is non-positive.
    #[test]
    fn test_log_prob() {
        let lm = create_test_ngram_model();

        let bigram_lp = lm.log_prob("fox", &["brown"]);
        assert!(bigram_lp.is_finite());
        assert!(bigram_lp <= 0.0);

        let unigram_lp = lm.log_prob("the", &[]);
        assert!(unigram_lp.is_finite());
    }

    /// A seen sentence gets a finite, strictly negative log-probability.
    #[test]
    fn test_sentence_log_prob() {
        let lm = create_test_ngram_model();
        let sentence_lp = lm.sentence_log_prob(&["the", "quick", "brown", "fox"]);
        assert!(sentence_lp.is_finite());
        assert!(sentence_lp < 0.0);
    }

    /// Saving and reloading a model preserves its statistics and its scores.
    #[cfg(feature = "serde-extras")]
    #[test]
    fn test_ngram_save_load_roundtrip() {
        let model = create_test_ngram_model();
        let temp_file = tempfile::NamedTempFile::new().expect("Failed to create temp file");
        let model_path = temp_file.path();

        model.save(model_path).expect("Failed to save model");
        let saved_bytes = std::fs::metadata(model_path)
            .expect("Failed to get file metadata")
            .len();
        assert!(saved_bytes > 0, "Saved model file should not be empty");

        let loaded: NgramModel<DynamicDawgChar<NgramEntry>> =
            NgramModel::load(model_path).expect("Failed to load model");

        assert_eq!(model.order(), loaded.order());
        assert_eq!(model.vocab_size(), loaded.vocab_size());
        assert_eq!(model.total_count(), loaded.total_count());

        let context = ["the", "quick"];
        let orig_prob = model.log_prob("fox", &context);
        let loaded_prob = loaded.log_prob("fox", &context);
        assert!(
            probs_equal(orig_prob, loaded_prob),
            "Log probabilities should match after roundtrip: {} vs {}",
            orig_prob,
            loaded_prob
        );

        let sentence = ["the", "quick", "brown", "fox"];
        let orig_sentence_prob = model.sentence_log_prob(&sentence);
        let loaded_sentence_prob = loaded.sentence_log_prob(&sentence);
        assert!(
            probs_equal(orig_sentence_prob, loaded_sentence_prob),
            "Sentence log probabilities should match: {} vs {}",
            orig_sentence_prob,
            loaded_sentence_prob
        );
    }

    /// Compares two log-probabilities: NaN never matches, two infinities match when
    /// they share a sign, and everything else must agree to within `1e-10`.
    #[cfg(feature = "serde-extras")]
    fn probs_equal(a: f64, b: f64) -> bool {
        if a.is_nan() || b.is_nan() {
            return false;
        }
        if a.is_infinite() && b.is_infinite() {
            return a.signum() == b.signum();
        }
        (a - b).abs() < 1e-10
    }

    /// Training also works on top of a `PathMapDictionary`.
    #[test]
    fn test_pathmap_model() {
        let workspace = TempDir::new().expect("Failed to create temp dir");
        let corpus_path = create_test_corpus(workspace.path(), "the quick brown fox");
        let corpus = PlaintextReader::from_file(&corpus_path).expect("Failed to create reader");

        let lm = TrainerBuilder::new(PathMapDictionary::<NgramEntry>::new())
            .order(3)
            .train(corpus)
            .expect("N-gram training failed");

        let bigram_lp = lm.log_prob("fox", &["brown"]);
        assert!(bigram_lp.is_finite());
    }
}