use rustc_hash::FxHashMap;
use smallvec::SmallVec;
use crate::semiring::{LogWeight, Semiring};
use crate::wfst::{MutableWfst, StateId, VectorWfst};
/// Integer identifier of a vocabulary entry; used as the arc label type of the
/// transducers built in this module.
pub type VocabId = u32;
/// Reserved vocabulary id 0. Presumably the unknown-word token (per its name);
/// it is not referenced elsewhere in this section.
pub const UNK_ID: VocabId = 0;
/// Reserved vocabulary id 1. Presumably the beginning-of-sentence token (per
/// its name); it is not referenced elsewhere in this section.
pub const BOS_ID: VocabId = 1;
/// Reserved vocabulary id 2. `NgramLmBuilder::build` marks every context whose
/// last element is this id as a final state.
pub const EOS_ID: VocabId = 2;
/// A single n-gram record: `word` scored in the history `context`.
#[derive(Clone, Debug)]
pub struct NgramEntry {
/// History preceding `word`; empty for a unigram.
pub context: SmallVec<[VocabId; 4]>,
/// The word this entry scores.
pub word: VocabId,
/// Score passed to `LogWeight::new` when the arc is created.
/// NOTE(review): the pruning rule and the additive backoff in this file
/// suggest this is a cost (negative log-probability) — confirm against
/// `LogWeight`.
pub log_prob: f64,
}
/// A backoff weight attached to a history `context`.
///
/// Not referenced by the code visible in this section; `NgramLmBuilder`
/// stores its backoff weights in a map keyed by context instead.
#[derive(Clone, Debug)]
pub struct BackoffWeight {
/// History the weight applies to.
pub context: SmallVec<[VocabId; 4]>,
/// Weight applied when backing off from `context`.
pub weight: f64,
}
/// Settings that control how `NgramLmBuilder` compiles a model.
#[derive(Clone, Debug)]
pub struct NgramLmConfig {
    /// Model order `n`; the builder keeps histories of at most `order - 1` words.
    pub order: usize,
    /// Read by `NgramLmBuilder::build` when emitting backoff arcs.
    /// NOTE(review): both branches of that check currently emit the same
    /// epsilon arc, so the flag has no observable effect.
    pub use_backoff_symbol: bool,
    /// Declared vocabulary size; not consulted by the builder code in this section.
    pub vocab_size: usize,
    /// When set, `add_ngram` silently drops entries whose `log_prob` is
    /// strictly greater than this value.
    pub prune_threshold: Option<f64>,
}

impl Default for NgramLmConfig {
    /// Trigram model, backoff symbol enabled, vocabulary size 0, no pruning.
    fn default() -> Self {
        NgramLmConfig {
            prune_threshold: None,
            vocab_size: 0,
            use_backoff_symbol: true,
            order: 3,
        }
    }
}
/// Accumulates n-grams and backoff weights, then compiles them into a
/// `VectorWfst` via `build`.
pub struct NgramLmBuilder {
// Order, pruning and backoff-symbol settings supplied at construction.
config: NgramLmConfig,
// Maps each history (context) to its FST state; filled in by `build`.
context_to_state: FxHashMap<SmallVec<[VocabId; 4]>, StateId>,
// Backoff weight per context, as registered through `add_backoff`.
backoff_weights: FxHashMap<SmallVec<[VocabId; 4]>, f64>,
// Every n-gram accepted by `add_ngram`, in insertion order.
ngrams: Vec<NgramEntry>,
// Ids seen in accepted n-grams; the `bool` value is always `true`, so this
// is effectively used as a set (only its length is read, in `stats`).
vocab: FxHashMap<VocabId, bool>,
}
impl NgramLmBuilder {
pub fn new(config: NgramLmConfig) -> Self {
Self {
config,
context_to_state: FxHashMap::default(),
backoff_weights: FxHashMap::default(),
ngrams: Vec::new(),
vocab: FxHashMap::default(),
}
}
/// Registers the n-gram `context -> word` with score `log_prob`.
///
/// If a prune threshold is configured and `log_prob` is strictly greater
/// than it, the entry is discarded and nothing is recorded. Otherwise the
/// word and every context id are added to the vocabulary and the entry is
/// appended to the n-gram list.
pub fn add_ngram(&mut self, context: &[VocabId], word: VocabId, log_prob: f64) {
    let pruned = matches!(self.config.prune_threshold, Some(limit) if log_prob > limit);
    if pruned {
        return;
    }

    // Record the predicted word first, then the history words.
    for id in std::iter::once(word).chain(context.iter().copied()) {
        self.vocab.insert(id, true);
    }

    let entry = NgramEntry {
        context: SmallVec::from_slice(context),
        word,
        log_prob,
    };
    self.ngrams.push(entry);
}
/// Sets the backoff weight for `context`, replacing any weight stored for
/// the same context earlier.
pub fn add_backoff(&mut self, context: &[VocabId], weight: f64) {
    let key = SmallVec::from_slice(context);
    self.backoff_weights.insert(key, weight);
}
/// Returns the state registered for `context`, adding a fresh state to `fst`
/// and remembering it when the context has not been seen before.
fn get_or_create_state<L: Clone + Send + Sync>(
    &mut self,
    fst: &mut VectorWfst<L, LogWeight>,
    context: &[VocabId],
) -> StateId {
    let key: SmallVec<[VocabId; 4]> = SmallVec::from_slice(context);
    // `add_state` only runs when the key is vacant.
    *self
        .context_to_state
        .entry(key)
        .or_insert_with(|| fst.add_state())
}
/// Returns `context` with its first (oldest) word dropped; an empty context
/// backs off to an empty context.
fn backoff_context(context: &[VocabId]) -> SmallVec<[VocabId; 4]> {
    match context.split_first() {
        Some((_oldest, rest)) => SmallVec::from_slice(rest),
        None => SmallVec::new(),
    }
}
/// Consumes the builder and compiles the collected n-grams into a weighted
/// transducer.
///
/// Layout of the result:
/// * one state per context (history); the empty context is the start state;
/// * for every n-gram `(context, word)` an arc labelled `word` (input and
///   output) from the context's state to the state of the successor context
///   (`context + word`, with the oldest word dropped once if that exceeds
///   `order - 1` words), weighted with the n-gram's `log_prob`;
/// * for every non-empty context an epsilon (`None`/`None`) arc to the state
///   of the context with its oldest word dropped, weighted with the registered
///   backoff weight, or `0.0` when none was registered;
/// * the start state and every context ending in `EOS_ID` are final with
///   weight `LogWeight::one()`.
///
/// The full backoff chain of every context is materialised, so sparse
/// high-order models (e.g. a 4-gram whose shorter histories were never added)
/// no longer hit the "backoff context exists" panic. An `order` of 0 is
/// treated like 1 instead of underflowing.
///
/// NOTE(review): `config.use_backoff_symbol` is currently not consulted —
/// the original code emitted the same epsilon arc in both branches. Emitting a
/// dedicated backoff label needs a symbol id this module does not define.
pub fn build(mut self) -> VectorWfst<VocabId, LogWeight> {
    let mut fst: VectorWfst<VocabId, LogWeight> = VectorWfst::new();
    let initial = fst.add_state();
    fst.set_start(initial);
    self.context_to_state.insert(SmallVec::new(), initial);

    // Longest history a state may carry. Saturating keeps `order == 0` from
    // underflowing (it then behaves like a unigram model).
    let max_history = self.config.order.saturating_sub(1);

    // Move the n-grams out instead of cloning them so `self` stays free for
    // the `&mut self` state lookups below.
    let ngrams = std::mem::take(&mut self.ngrams);

    // N-gram arcs. Source and target states are created on demand, in the
    // same order as before (source context, then successor context).
    for ngram in &ngrams {
        let source = self.get_or_create_state(&mut fst, &ngram.context);

        let mut next_context = ngram.context.clone();
        next_context.push(ngram.word);
        if next_context.len() > max_history {
            next_context.remove(0);
        }
        let target = self.get_or_create_state(&mut fst, &next_context);

        fst.add_arc(
            source,
            Some(ngram.word),
            Some(ngram.word),
            target,
            LogWeight::new(ngram.log_prob),
        );
    }

    // Materialise every proper non-empty suffix of every known context so
    // that each backoff arc below has a destination. (The empty suffix is
    // the start state and already exists.)
    let contexts: Vec<_> = self.context_to_state.keys().cloned().collect();
    for context in &contexts {
        for start in 1..context.len() {
            self.get_or_create_state(&mut fst, &context[start..]);
        }
    }

    // Backoff arcs: each non-empty context falls back to the context with
    // its oldest word removed.
    for (context, &state) in &self.context_to_state {
        if context.is_empty() {
            continue;
        }
        let shorter = Self::backoff_context(context);
        let backoff_state = *self
            .context_to_state
            .get(&shorter)
            .expect("backoff context exists");
        let backoff_weight = self.backoff_weights.get(context).copied().unwrap_or(0.0);
        fst.add_arc(
            state,
            None,
            None,
            backoff_state,
            LogWeight::new(backoff_weight),
        );
    }

    // Histories that end in the end-of-sentence id accept.
    for (context, &state) in &self.context_to_state {
        if context.last() == Some(&EOS_ID) {
            fst.set_final(state, LogWeight::one());
        }
    }
    fst.set_final(initial, LogWeight::one());
    fst
}
/// Summarises what has been added to the builder so far.
///
/// `order_counts[k]` is the number of stored n-grams of order `k`
/// (context length + 1); n-grams of order 8 or higher are not counted there.
///
/// NOTE(review): `num_contexts` reports the size of the context-to-state
/// map, which in the code visible here is only filled by `build` (and `build`
/// consumes the builder), so it appears to be 0 whenever this is callable.
pub fn stats(&self) -> NgramStats {
    let mut order_counts = [0usize; 8];
    for entry in &self.ngrams {
        // Out-of-range orders are skipped rather than panicking.
        if let Some(slot) = order_counts.get_mut(entry.context.len() + 1) {
            *slot += 1;
        }
    }

    let num_ngrams = self.ngrams.len();
    let num_contexts = self.context_to_state.len();
    let num_backoffs = self.backoff_weights.len();
    let vocab_size = self.vocab.len();
    NgramStats {
        num_ngrams,
        num_contexts,
        num_backoffs,
        vocab_size,
        order_counts,
    }
}
}
/// Counters reported by `NgramLmBuilder::stats`.
#[derive(Clone, Debug, Default)]
pub struct NgramStats {
/// Number of n-grams stored in the builder.
pub num_ngrams: usize,
/// Number of entries in the builder's context-to-state map.
pub num_contexts: usize,
/// Number of contexts with a registered backoff weight.
pub num_backoffs: usize,
/// Number of distinct vocabulary ids seen in stored n-grams.
pub vocab_size: usize,
/// `order_counts[k]` is the number of stored n-grams of order `k`
/// (context length + 1); index 0 is never incremented and orders >= 8 are
/// not counted.
pub order_counts: [usize; 8],
}
/// A bigram language model with dense unigram/backoff tables and a sparse
/// bigram table.
pub struct BigramLm {
// Unigram score per word id; `f64::INFINITY` marks "not set".
unigram_probs: Vec<f64>,
// Explicit bigram scores keyed by `(previous word, word)`.
bigram_probs: FxHashMap<(VocabId, VocabId), f64>,
// Backoff weight per previous word; initialised to 0.0.
backoff_weights: Vec<f64>,
// Number of word ids the dense tables were sized for.
vocab_size: usize,
}
impl BigramLm {
pub fn new(vocab_size: usize) -> Self {
Self {
unigram_probs: vec![f64::INFINITY; vocab_size], bigram_probs: FxHashMap::default(),
backoff_weights: vec![0.0; vocab_size], vocab_size,
}
}
/// Stores the unigram score of `word`; ids at or beyond `vocab_size` are
/// silently ignored.
pub fn set_unigram(&mut self, word: VocabId, log_prob: f64) {
    let index = word as usize;
    if index >= self.vocab_size {
        return;
    }
    self.unigram_probs[index] = log_prob;
}
/// Stores the score of `w2` following `w1`, replacing any earlier value.
/// No range check is done here; `to_wfst` skips pairs outside the vocabulary.
pub fn set_bigram(&mut self, w1: VocabId, w2: VocabId, log_prob: f64) {
    let pair = (w1, w2);
    self.bigram_probs.insert(pair, log_prob);
}
/// Stores the backoff weight applied after `word`; ids at or beyond
/// `vocab_size` are silently ignored.
pub fn set_backoff(&mut self, word: VocabId, weight: f64) {
    let index = word as usize;
    if index >= self.vocab_size {
        return;
    }
    self.backoff_weights[index] = weight;
}
/// Scores `w2` following `w1`.
///
/// Returns the explicit bigram score when one is stored. Otherwise returns
/// the unigram score of `w2` plus the backoff weight of `w1`, where an
/// out-of-range `w2` contributes `f64::INFINITY` and an out-of-range `w1`
/// contributes 0.0.
pub fn prob(&self, w1: VocabId, w2: VocabId) -> f64 {
    match self.bigram_probs.get(&(w1, w2)) {
        Some(&explicit) => explicit,
        None => {
            let unigram = match self.unigram_probs.get(w2 as usize) {
                Some(&value) => value,
                None => f64::INFINITY,
            };
            let backoff = match self.backoff_weights.get(w1 as usize) {
                Some(&value) => value,
                None => 0.0,
            };
            unigram + backoff
        }
    }
}
/// Compiles the model into a transducer with `vocab_size + 1` states.
///
/// * A shared backoff state is the start state.
/// * One state per word id follows, in id order.
/// * Each set unigram (score below `f64::INFINITY`) becomes an arc from the
///   backoff state to that word's state.
/// * Each stored bigram whose ids are both inside the vocabulary becomes an
///   arc between the two word states.
/// * Every word state gets an epsilon (`None`/`None`) arc back to the backoff
///   state carrying its backoff weight.
/// * All states are final with weight `LogWeight::one()`.
pub fn to_wfst(&self) -> VectorWfst<VocabId, LogWeight> {
    let mut fst: VectorWfst<VocabId, LogWeight> = VectorWfst::new();
    let root = fst.add_state();
    fst.set_start(root);

    // One state per word id, allocated in id order.
    let word_states: Vec<StateId> = (0..self.vocab_size).map(|_| fst.add_state()).collect();

    // Unigram arcs, skipping words whose score was never set.
    let set_unigrams = self
        .unigram_probs
        .iter()
        .enumerate()
        .filter(|&(_, &cost)| cost < f64::INFINITY);
    for (word, &cost) in set_unigrams {
        let label = Some(word as VocabId);
        fst.add_arc(root, label, label, word_states[word], LogWeight::new(cost));
    }

    // Bigram arcs, skipping pairs that fall outside the vocabulary.
    for (&(prev, next), &cost) in &self.bigram_probs {
        let (from, to) = (prev as usize, next as usize);
        if from >= self.vocab_size || to >= self.vocab_size {
            continue;
        }
        fst.add_arc(
            word_states[from],
            Some(next),
            Some(next),
            word_states[to],
            LogWeight::new(cost),
        );
    }

    // Epsilon backoff arcs from every word state to the shared backoff state.
    for (word, &weight) in self.backoff_weights.iter().enumerate() {
        fst.add_arc(word_states[word], None, None, root, LogWeight::new(weight));
    }

    fst.set_final(root, LogWeight::one());
    for state in word_states.iter().copied() {
        fst.set_final(state, LogWeight::one());
    }
    fst
}
/// Summarises the model's size.
///
/// `num_unigrams` counts unigram slots holding a score below
/// `f64::INFINITY`. `sparsity` is `1 - num_bigrams / vocab_size²`, i.e. the
/// fraction of possible bigrams that are *not* stored explicitly.
///
/// The square is computed in floating point so large vocabularies cannot
/// overflow `usize`, and an empty vocabulary (no bigram slots at all) is
/// reported as fully sparse (`1.0`) instead of producing NaN or -inf.
pub fn stats(&self) -> BigramStats {
    let num_unigrams = self
        .unigram_probs
        .iter()
        .filter(|&&p| p < f64::INFINITY)
        .count();
    let num_bigrams = self.bigram_probs.len();

    let possible_bigrams = (self.vocab_size as f64) * (self.vocab_size as f64);
    let sparsity = if possible_bigrams > 0.0 {
        1.0 - num_bigrams as f64 / possible_bigrams
    } else {
        1.0
    };

    BigramStats {
        vocab_size: self.vocab_size,
        num_unigrams,
        num_bigrams,
        sparsity,
    }
}
}
/// Size summary returned by `BigramLm::stats`.
#[derive(Clone, Debug)]
pub struct BigramStats {
/// Vocabulary size the model was created with.
pub vocab_size: usize,
/// Number of unigram slots holding a score below `f64::INFINITY`.
pub num_unigrams: usize,
/// Number of explicitly stored bigrams.
pub num_bigrams: usize,
/// `1 - num_bigrams / vocab_size²`.
pub sparsity: f64,
}
/// Ways an n-gram model could be pruned.
///
/// NOTE(review): nothing in this section consumes this enum; the variant
/// descriptions below are inferred from their names and should be confirmed
/// against the code that uses them.
#[derive(Clone, Debug)]
pub enum PruningStrategy {
/// No pruning.
None,
/// Presumably a minimum count threshold.
CountThreshold(usize),
/// Presumably a probability (or log-probability) cut-off.
ProbabilityThreshold(f64),
/// Presumably an entropy-based cut-off.
EntropyThreshold(f64),
}
/// Compares the size of a fully dense n-gram automaton with a rough estimate
/// of the sparse (backoff) one.
///
/// * `dense_states = vocab_size^(order - 1)`, `dense_arcs = vocab_size^order`.
/// * `sparse_states = num_observed / vocab_size + 1`,
///   `sparse_arcs = num_observed + sparse_states`.
/// * Each reduction is `1 - sparse / dense`, or `0.0` when the dense count is 0.
///
/// All integer arithmetic saturates instead of overflowing, so realistic
/// inputs such as a 100k vocabulary at order 5 no longer panic (debug) or wrap
/// (release). `order == 0` is treated as having an empty history
/// (`dense_states == 1`), and `vocab_size == 0` yields `sparse_states == 1`
/// rather than dividing by zero.
pub fn compute_size_reduction(
    vocab_size: usize,
    num_observed: usize,
    order: usize,
) -> SizeReduction {
    // Clamp the exponent instead of truncating it with a bare `as u32`.
    let clamp_exp = |value: usize| value.min(u32::MAX as usize) as u32;

    let dense_states = vocab_size.saturating_pow(clamp_exp(order.saturating_sub(1)));
    let dense_arcs = vocab_size.saturating_pow(clamp_exp(order));

    let sparse_states = num_observed
        .checked_div(vocab_size)
        .unwrap_or(0)
        .saturating_add(1);
    let sparse_arcs = num_observed.saturating_add(sparse_states);

    // Fraction saved relative to the dense layout; 0.0 when there is nothing
    // to compare against.
    let reduction = |sparse: usize, dense: usize| -> f64 {
        if dense > 0 {
            1.0 - (sparse as f64 / dense as f64)
        } else {
            0.0
        }
    };

    SizeReduction {
        dense_states,
        dense_arcs,
        sparse_states,
        sparse_arcs,
        state_reduction: reduction(sparse_states, dense_states),
        arc_reduction: reduction(sparse_arcs, dense_arcs),
    }
}

/// Result of [`compute_size_reduction`].
#[derive(Clone, Debug)]
pub struct SizeReduction {
    /// State count of the dense automaton (`vocab_size^(order - 1)`, saturating).
    pub dense_states: usize,
    /// Arc count of the dense automaton (`vocab_size^order`, saturating).
    pub dense_arcs: usize,
    /// Estimated state count of the sparse automaton.
    pub sparse_states: usize,
    /// Estimated arc count of the sparse automaton.
    pub sparse_arcs: usize,
    /// `1 - sparse_states / dense_states`, or 0.0 when `dense_states` is 0.
    pub state_reduction: f64,
    /// `1 - sparse_arcs / dense_arcs`, or 0.0 when `dense_arcs` is 0.
    pub arc_reduction: f64,
}
// Unit tests for the bigram model, the n-gram builder and the size estimate.
#[cfg(test)]
mod tests {
use super::*;
use crate::wfst::Wfst;
// A stored bigram is returned as-is; a missing bigram falls back to
// unigram(w2) + backoff(w1).
#[test]
fn test_bigram_lm_basic() {
let mut lm = BigramLm::new(5);
lm.set_unigram(0, 2.0); lm.set_unigram(1, 1.5);
lm.set_unigram(2, 1.0);
lm.set_bigram(0, 1, 0.5); lm.set_bigram(1, 2, 0.3);
lm.set_backoff(0, 0.1);
assert!((lm.prob(0, 1) - 0.5).abs() < 1e-10); assert!((lm.prob(0, 2) - (1.0 + 0.1)).abs() < 1e-10); }
// The compiled FST has one backoff state plus one state per word id.
#[test]
fn test_bigram_lm_to_wfst() {
let mut lm = BigramLm::new(3);
lm.set_unigram(0, 1.0);
lm.set_unigram(1, 2.0);
lm.set_bigram(0, 1, 0.5);
let fst = lm.to_wfst();
assert_eq!(fst.num_states(), 4);
}
// Stats reflect the added n-grams, and building a bigram-order model succeeds.
#[test]
fn test_ngram_builder_basic() {
let config = NgramLmConfig {
order: 2,
use_backoff_symbol: true,
vocab_size: 5,
prune_threshold: None,
};
let mut builder = NgramLmBuilder::new(config);
builder.add_ngram(&[], 0, 2.0);
builder.add_ngram(&[], 1, 1.5);
builder.add_ngram(&[], 2, 1.0);
builder.add_ngram(&[0], 1, 0.5);
builder.add_ngram(&[1], 2, 0.3);
builder.add_backoff(&[0], 0.1);
builder.add_backoff(&[1], 0.2);
let stats = builder.stats();
assert_eq!(stats.num_ngrams, 5);
assert_eq!(stats.vocab_size, 3);
let _fst = builder.build();
}
// With a threshold of 1.0, entries scored above it (1.5 and 2.0) are
// dropped, so only the 0.5 entry remains.
#[test]
fn test_ngram_builder_with_pruning() {
let config = NgramLmConfig {
order: 2,
use_backoff_symbol: true,
vocab_size: 5,
prune_threshold: Some(1.0), };
let mut builder = NgramLmBuilder::new(config);
builder.add_ngram(&[], 0, 0.5); builder.add_ngram(&[], 1, 1.5); builder.add_ngram(&[], 2, 2.0);
let stats = builder.stats();
assert_eq!(stats.num_ngrams, 1); }
// Dense counts are vocab^(order-1) and vocab^order; the sparse estimate is
// far smaller for these inputs.
#[test]
fn test_size_reduction() {
let reduction = compute_size_reduction(1000, 50000, 2);
assert_eq!(reduction.dense_states, 1000);
assert_eq!(reduction.dense_arcs, 1_000_000);
assert!(reduction.sparse_arcs < reduction.dense_arcs);
assert!(reduction.arc_reduction > 0.9); }
// Per-order counts for a mixed uni/bi/trigram model, followed by a
// successful build.
#[test]
fn test_trigram_builder() {
let config = NgramLmConfig {
order: 3,
use_backoff_symbol: true,
vocab_size: 10,
prune_threshold: None,
};
let mut builder = NgramLmBuilder::new(config);
builder.add_ngram(&[0, 1], 2, 0.5);
builder.add_ngram(&[1, 2], 3, 0.3);
builder.add_ngram(&[0], 1, 0.8);
builder.add_ngram(&[1], 2, 0.6);
builder.add_ngram(&[], 0, 2.0);
builder.add_ngram(&[], 1, 1.8);
builder.add_ngram(&[], 2, 1.5);
builder.add_ngram(&[], 3, 1.2);
builder.add_backoff(&[0, 1], 0.1);
builder.add_backoff(&[1, 2], 0.1);
builder.add_backoff(&[0], 0.2);
builder.add_backoff(&[1], 0.2);
let stats = builder.stats();
assert_eq!(stats.order_counts[1], 4); assert_eq!(stats.order_counts[2], 2); assert_eq!(stats.order_counts[3], 2);
let _fst = builder.build();
}
// Backing off drops the oldest word; an empty context stays empty.
#[test]
fn test_backoff_context() {
let ctx: SmallVec<[VocabId; 4]> = SmallVec::from_slice(&[1, 2, 3]);
let backoff = NgramLmBuilder::backoff_context(&ctx);
assert_eq!(backoff.as_slice(), &[2, 3]);
let empty: SmallVec<[VocabId; 4]> = SmallVec::new();
let backoff_empty = NgramLmBuilder::backoff_context(&empty);
assert!(backoff_empty.is_empty());
}
// 50 set unigrams and 100 bigrams out of 100 * 100 possible pairs give a
// sparsity of 0.99.
#[test]
fn test_bigram_stats() {
let mut lm = BigramLm::new(100);
for i in 0..50 {
lm.set_unigram(i, 1.0);
}
for i in 0..100 {
lm.set_bigram(i, (i + 1) % 100, 0.5);
}
let stats = lm.stats();
assert_eq!(stats.vocab_size, 100);
assert_eq!(stats.num_unigrams, 50);
assert_eq!(stats.num_bigrams, 100);
assert!(stats.sparsity >= 0.99); }
}