use std::collections::{HashMap, HashSet};
use crate::semiring::{LogWeight, Semiring};
use crate::wfst::{MutableWfst, StateId, VectorWfst, Wfst};
/// Integer identifier of a vocabulary token. The graph builders in this file
/// treat ids `0..vocab_size` as the vocabulary and use them as arc labels.
pub type TokenId = u32;
/// Order `n` of an n-gram model, as stored in `PrunedNgramConfig::order`.
pub type NgramOrder = usize;
#[derive(Clone, Debug)]
pub struct PrunedNgramConfig {
pub order: NgramOrder,
pub min_count: usize,
pub use_backoff: bool,
pub backoff_weight: f64,
pub smoothing: bool,
pub discount: f64,
}
impl Default for PrunedNgramConfig {
fn default() -> Self {
Self {
order: 2,
min_count: 1,
use_backoff: true,
backoff_weight: 0.0,
smoothing: false,
discount: 0.5,
}
}
}
/// Raw n-gram frequency tables accumulated from token sequences.
///
/// Unigram, bigram and trigram counts are collected together. N-grams never
/// span two separate calls to [`NgramCounts::add_sequence`].
#[derive(Clone, Debug, Default)]
pub struct NgramCounts {
    /// Occurrences of each token.
    pub unigrams: HashMap<TokenId, usize>,
    /// Occurrences of each adjacent `(prev, curr)` pair.
    pub bigrams: HashMap<(TokenId, TokenId), usize>,
    /// Occurrences of each adjacent `(prev2, prev1, curr)` triple.
    pub trigrams: HashMap<(TokenId, TokenId, TokenId), usize>,
    /// Number of tokens added across all sequences.
    pub total: usize,
}

impl NgramCounts {
    /// Creates empty count tables.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds every unigram, bigram and trigram that occurs inside `tokens`.
    pub fn add_sequence(&mut self, tokens: &[TokenId]) {
        for &token in tokens {
            *self.unigrams.entry(token).or_default() += 1;
        }
        self.total += tokens.len();

        for pair in tokens.windows(2) {
            if let &[prev, curr] = pair {
                *self.bigrams.entry((prev, curr)).or_default() += 1;
            }
        }
        for triple in tokens.windows(3) {
            if let &[prev2, prev1, curr] = triple {
                *self.trigrams.entry((prev2, prev1, curr)).or_default() += 1;
            }
        }
    }

    /// Number of times `token` was seen (0 if never).
    pub fn unigram_count(&self, token: TokenId) -> usize {
        *self.unigrams.get(&token).unwrap_or(&0)
    }

    /// Number of times `curr` directly followed `prev` (0 if never).
    pub fn bigram_count(&self, prev: TokenId, curr: TokenId) -> usize {
        *self.bigrams.get(&(prev, curr)).unwrap_or(&0)
    }

    /// Number of times the triple `prev2 prev1 curr` was seen (0 if never).
    pub fn trigram_count(&self, prev2: TokenId, prev1: TokenId, curr: TokenId) -> usize {
        *self.trigrams.get(&(prev2, prev1, curr)).unwrap_or(&0)
    }

    /// Relative frequency `count(token) / total`; 0.0 when nothing was counted.
    pub fn unigram_prob(&self, token: TokenId) -> f64 {
        match self.total {
            0 => 0.0,
            total => self.unigram_count(token) as f64 / total as f64,
        }
    }

    /// Maximum likelihood estimate `count(prev, curr) / count(prev)`; 0.0 when
    /// `prev` was never seen.
    pub fn bigram_prob(&self, prev: TokenId, curr: TokenId) -> f64 {
        match self.unigram_count(prev) {
            0 => 0.0,
            history => self.bigram_count(prev, curr) as f64 / history as f64,
        }
    }
}
/// Builds a bigram language-model WFST over tokens `0..vocab_size`.
///
/// Every arc weight is a negative natural-log probability (cost) wrapped in
/// [`LogWeight`]. Probabilities are floored at `1e-10` before the log.
///
/// Graph layout (every state is final with weight one):
/// * a start state with one arc per token `t` into the state of `t`,
///   weighted by the unigram cost of `t` (0.0 for every token when `counts`
///   is empty);
/// * one state per token `t`, meaning "the previous token was `t`";
/// * an arc `prev -> curr` labelled `curr` for every in-vocabulary bigram
///   whose count is at least `config.min_count` (every pair when
///   `min_count == 0`). Its cost is the maximum likelihood bigram cost, or the
///   `compute_smoothed_prob` cost when `config.smoothing` is set;
/// * when `config.use_backoff` is set, a backoff state. Every token state
///   reaches it through an epsilon arc of weight `config.backoff_weight`, and
///   it re-emits every token at its unigram cost.
///
/// Within each state, bigram arcs are added in ascending `curr` order,
/// followed by the backoff epsilon arc. When `min_count > 0`, only observed
/// bigrams are visited, so construction scales with the number of distinct
/// bigrams rather than with `vocab_size²`.
pub fn build_pruned_bigram_graph(
    vocab_size: usize,
    counts: &NgramCounts,
    config: &PrunedNgramConfig,
) -> VectorWfst<TokenId, LogWeight> {
    let in_vocab = |token: TokenId| (token as usize) < vocab_size;

    // Cost of each token from a context-free state (start and backoff share it).
    let unigram_costs: Vec<f64> = (0..vocab_size as TokenId)
        .map(|token| {
            if counts.total > 0 {
                -counts.unigram_prob(token).max(1e-10).ln()
            } else {
                0.0
            }
        })
        .collect();

    let mut fst: VectorWfst<TokenId, LogWeight> = VectorWfst::new();
    let start = fst.add_state();
    fst.set_start(start);
    fst.set_final(start, LogWeight::one());

    // token_states[t] is the state entered after emitting token `t`.
    let token_states: Vec<StateId> = (0..vocab_size as TokenId)
        .map(|_| {
            let state = fst.add_state();
            fst.set_final(state, LogWeight::one());
            state
        })
        .collect();

    let backoff_state = config.use_backoff.then(|| {
        let state = fst.add_state();
        fst.set_final(state, LogWeight::one());
        state
    });

    for token in 0..vocab_size as TokenId {
        let idx = token as usize;
        fst.add_arc(
            start,
            Some(token),
            Some(token),
            token_states[idx],
            LogWeight::new(unigram_costs[idx]),
        );
    }

    // Bigrams that receive an explicit arc, in (prev, curr) order.
    let kept_bigrams: Vec<(TokenId, TokenId)> = if config.min_count == 0 {
        // `count >= 0` admits every pair, including never-observed ones.
        (0..vocab_size as TokenId)
            .flat_map(|prev| (0..vocab_size as TokenId).map(move |curr| (prev, curr)))
            .collect()
    } else {
        let mut observed: Vec<(TokenId, TokenId)> = counts
            .bigrams
            .iter()
            .filter(|&(&(prev, curr), &count)| {
                count >= config.min_count && in_vocab(prev) && in_vocab(curr)
            })
            .map(|(&pair, _)| pair)
            .collect();
        // Sorting restores a deterministic per-state arc order despite
        // HashMap iteration order.
        observed.sort_unstable();
        observed
    };

    for &(prev, curr) in &kept_bigrams {
        let cost = if config.smoothing {
            compute_smoothed_prob(counts, prev, curr, config)
        } else {
            -counts.bigram_prob(prev, curr).max(1e-10).ln()
        };
        fst.add_arc(
            token_states[prev as usize],
            Some(curr),
            Some(curr),
            token_states[curr as usize],
            LogWeight::new(cost),
        );
    }

    if let Some(backoff) = backoff_state {
        for &state in &token_states {
            fst.add_arc(
                state,
                None,
                None,
                backoff,
                LogWeight::new(config.backoff_weight),
            );
        }
        for token in 0..vocab_size as TokenId {
            let idx = token as usize;
            fst.add_arc(
                backoff,
                Some(token),
                Some(token),
                token_states[idx],
                LogWeight::new(unigram_costs[idx]),
            );
        }
    }

    fst
}
/// Absolute-discounting cost (negative natural-log probability) of `curr`
/// following `prev`.
///
/// With `D = config.discount` and `c(.)` the counts in `counts`, the
/// probability is
/// `max(c(prev, curr) - D, 0) / c(prev) + (D / c(prev)) * P_uni(curr)`.
/// It is floored at `1e-10` before the log is taken. When `prev` was never
/// observed, there is no history mass to discount, and the cost of the unigram
/// probability alone (also floored) is returned.
///
/// NOTE(review): textbook absolute discounting scales the interpolation
/// weight by the number of distinct successors of `prev`
/// (`D * N1+(prev) / c(prev)`). With plain `D / c(prev)` the smoothed
/// distribution does not sum to one. `NgramCounts` keeps no successor index,
/// so the simpler weight is retained here.
fn compute_smoothed_prob(
    counts: &NgramCounts,
    prev: TokenId,
    curr: TokenId,
    config: &PrunedNgramConfig,
) -> f64 {
    let prev_count = counts.unigram_count(prev) as f64;
    let unigram_prob = counts.unigram_prob(curr);
    if prev_count == 0.0 {
        // Unseen history: all mass comes from the lower-order distribution.
        // Returning cost 0.0 here would claim probability 1 for every token.
        return -unigram_prob.max(1e-10).ln();
    }
    let bigram_count = counts.bigram_count(prev, curr) as f64;
    let discounted = (bigram_count - config.discount).max(0.0) / prev_count;
    let lambda = config.discount / prev_count;
    let prob = (discounted + lambda * unigram_prob).max(1e-10);
    -prob.ln()
}
/// Builds a trigram language-model WFST over tokens `0..vocab_size`.
///
/// Every arc weight is a negative natural-log probability (cost) wrapped in
/// [`LogWeight`]. Probabilities are floored at `1e-10` before the log.
/// `config.smoothing` is not used by this builder.
///
/// Graph layout (every state is final with weight one):
/// * a start state with one arc per token `t` into unigram state `t`,
///   weighted by the unigram cost of `t` (0.0 for every token when `counts`
///   is empty);
/// * one unigram state per token `t` (history "`t`"). It emits `w` into
///   bigram state `(t, w)` for every kept bigram `(t, w)`, at the maximum
///   likelihood bigram cost;
/// * one bigram state per in-vocabulary bigram with count
///   `>= config.min_count` (history "`p1 p2`"). It emits `w` into bigram state
///   `(p2, w)` for every trigram `(p1, p2, w)` with count
///   `>= config.min_count` whose target state exists, at cost
///   `-ln(c(p1, p2, w) / c(p1, p2))`;
/// * when `config.use_backoff` is set, a backoff state that re-emits every
///   token into its unigram state at the same unigram cost as the start state.
///
/// With backoff enabled, every arc below has weight `config.backoff_weight`:
/// * each unigram state has an epsilon arc to the backoff state;
/// * each bigram state `(p1, p2)` has an epsilon arc to unigram state `p2`.
///
/// An unseen trigram therefore falls back to the bigram history before
/// reaching the unigram distribution.
///
/// Bigram states are allocated in sorted `(p1, p2)` order, and each state's
/// arcs are added in ascending label order, so the graph does not depend on
/// `HashMap` iteration order.
pub fn build_pruned_trigram_graph(
    vocab_size: usize,
    counts: &NgramCounts,
    config: &PrunedNgramConfig,
) -> VectorWfst<TokenId, LogWeight> {
    let in_vocab = |token: TokenId| (token as usize) < vocab_size;

    // Cost of each token from a context-free state (start and backoff share it).
    let unigram_costs: Vec<f64> = (0..vocab_size as TokenId)
        .map(|token| {
            if counts.total > 0 {
                -counts.unigram_prob(token).max(1e-10).ln()
            } else {
                0.0
            }
        })
        .collect();

    let mut fst: VectorWfst<TokenId, LogWeight> = VectorWfst::new();
    let start = fst.add_state();
    fst.set_start(start);
    fst.set_final(start, LogWeight::one());

    // unigram_states[t] is the history state "previous token was t".
    let unigram_states: Vec<StateId> = (0..vocab_size as TokenId)
        .map(|_| {
            let state = fst.add_state();
            fst.set_final(state, LogWeight::one());
            state
        })
        .collect();

    // Kept in-vocabulary bigram histories, sorted for reproducible state ids.
    // Sorting also makes all histories sharing a first token contiguous,
    // which the trigram pass below relies on.
    let mut kept_pairs: Vec<(TokenId, TokenId)> = counts
        .bigrams
        .iter()
        .filter(|&(&(prev, curr), &count)| {
            count >= config.min_count && in_vocab(prev) && in_vocab(curr)
        })
        .map(|(&pair, _)| pair)
        .collect();
    kept_pairs.sort_unstable();
    let bigram_states: Vec<((TokenId, TokenId), StateId)> = kept_pairs
        .into_iter()
        .map(|pair| {
            let state = fst.add_state();
            fst.set_final(state, LogWeight::one());
            (pair, state)
        })
        .collect();

    let backoff_state = config.use_backoff.then(|| {
        let state = fst.add_state();
        fst.set_final(state, LogWeight::one());
        state
    });

    for token in 0..vocab_size as TokenId {
        let idx = token as usize;
        fst.add_arc(
            start,
            Some(token),
            Some(token),
            unigram_states[idx],
            LogWeight::new(unigram_costs[idx]),
        );
    }

    // Unigram history `prev` -> bigram history `(prev, curr)`.
    for &((prev, curr), to_state) in &bigram_states {
        let prob = counts.bigram_prob(prev, curr).max(1e-10);
        fst.add_arc(
            unigram_states[prev as usize],
            Some(curr),
            Some(curr),
            to_state,
            LogWeight::new(-prob.ln()),
        );
    }

    if let Some(backoff) = backoff_state {
        for &state in &unigram_states {
            fst.add_arc(
                state,
                None,
                None,
                backoff,
                LogWeight::new(config.backoff_weight),
            );
        }
        for token in 0..vocab_size as TokenId {
            let idx = token as usize;
            fst.add_arc(
                backoff,
                Some(token),
                Some(token),
                unigram_states[idx],
                LogWeight::new(unigram_costs[idx]),
            );
        }
    }

    // Bigram history `(prev1, prev2)` -> bigram history `(prev2, curr)`.
    for &((prev1, prev2), from_state) in &bigram_states {
        let history_count = counts.bigram_count(prev1, prev2) as f64;
        // Candidate targets are exactly the kept histories starting with prev2.
        let lo = bigram_states.partition_point(|&((first, _), _)| first < prev2);
        let hi = bigram_states.partition_point(|&((first, _), _)| first <= prev2);
        for &((_, curr), to_state) in &bigram_states[lo..hi] {
            let count = counts.trigram_count(prev1, prev2, curr);
            if count < config.min_count {
                continue;
            }
            let prob = if history_count > 0.0 {
                (count as f64 / history_count).max(1e-10)
            } else {
                1e-10
            };
            fst.add_arc(
                from_state,
                Some(curr),
                Some(curr),
                to_state,
                LogWeight::new(-prob.ln()),
            );
        }
        if config.use_backoff {
            // Drop the oldest token: back off to history `prev2`, which still
            // carries the bigram arcs and its own backoff to unigrams.
            fst.add_arc(
                from_state,
                None,
                None,
                unigram_states[prev2 as usize],
                LogWeight::new(config.backoff_weight),
            );
        }
    }

    fst
}
/// Size summary of a graph built by this module.
#[derive(Clone, Debug, Default)]
pub struct PrunedNgramStats {
    /// Number of states in the graph.
    pub num_states: usize,
    /// Arcs summed over all states (unigram, n-gram and epsilon arcs alike).
    pub num_arcs: usize,
    /// Not computed by [`PrunedNgramStats::from_bigram_graph`]; left at 0.
    pub ngrams_kept: usize,
    /// Not computed by [`PrunedNgramStats::from_bigram_graph`]; left at 0.
    pub ngrams_pruned: usize,
    /// Not computed by [`PrunedNgramStats::from_bigram_graph`]; left at 0.0.
    pub pruning_ratio: f64,
    /// `vocab_size * vocab_size`: the arc count of a fully dense bigram table.
    pub dense_arcs: usize,
    /// `dense_arcs / num_arcs`, or 0.0 when the graph has no arcs.
    pub compression_ratio: f64,
}

impl PrunedNgramStats {
    /// Counts the states and arcs of `fst` and compares the arc count with
    /// that of a dense `vocab_size × vocab_size` bigram table.
    ///
    /// States are enumerated as ids `0..fst.num_states()`.
    pub fn from_bigram_graph<L: Clone + Send + Sync>(
        fst: &VectorWfst<L, LogWeight>,
        vocab_size: usize,
    ) -> Self {
        let num_states = fst.num_states();

        let mut num_arcs = 0;
        for state in 0..num_states as StateId {
            num_arcs += fst.transitions(state).len();
        }

        let dense_arcs = vocab_size * vocab_size;
        let compression_ratio = if num_arcs == 0 {
            0.0
        } else {
            dense_arcs as f64 / num_arcs as f64
        };

        Self {
            num_states,
            num_arcs,
            dense_arcs,
            compression_ratio,
            ..Self::default()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::wfst::NO_STATE;

    /// Total number of arcs in `fst`, summed over every state.
    fn total_arcs(fst: &VectorWfst<TokenId, LogWeight>) -> usize {
        PrunedNgramStats::from_bigram_graph(fst, 0).num_arcs
    }

    #[test]
    fn test_pruned_ngram_config_default() {
        let config = PrunedNgramConfig::default();
        assert_eq!(config.order, 2);
        assert_eq!(config.min_count, 1);
        assert!(config.use_backoff);
        assert!(!config.smoothing);
        assert!(config.backoff_weight.abs() < 1e-12);
        assert!((config.discount - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_ngram_counts_empty() {
        let counts = NgramCounts::new();
        assert_eq!(counts.total, 0);
        assert_eq!(counts.unigram_count(0), 0);
        // Empty tables must not divide by zero.
        assert_eq!(counts.unigram_prob(0), 0.0);
        assert_eq!(counts.bigram_prob(0, 1), 0.0);
    }

    #[test]
    fn test_ngram_counts_add_sequence() {
        let mut counts = NgramCounts::new();
        counts.add_sequence(&[1, 2, 3, 1, 2]);
        assert_eq!(counts.unigram_count(1), 2);
        assert_eq!(counts.unigram_count(2), 2);
        assert_eq!(counts.unigram_count(3), 1);
        assert_eq!(counts.total, 5);
        assert_eq!(counts.bigram_count(1, 2), 2);
        assert_eq!(counts.bigram_count(2, 3), 1);
        assert_eq!(counts.bigram_count(3, 1), 1);
        assert_eq!(counts.trigram_count(1, 2, 3), 1);
        assert_eq!(counts.trigram_count(2, 3, 1), 1);
        assert_eq!(counts.trigram_count(3, 1, 2), 1);
        assert_eq!(counts.trigram_count(2, 1, 2), 0);
    }

    #[test]
    fn test_ngram_counts_probabilities() {
        let mut counts = NgramCounts::new();
        counts.add_sequence(&[0, 1, 0, 1]);
        assert!((counts.unigram_prob(0) - 0.5).abs() < 1e-6);
        assert!((counts.unigram_prob(1) - 0.5).abs() < 1e-6);
        assert!((counts.bigram_prob(0, 1) - 1.0).abs() < 1e-6);
        // Unseen history yields probability 0.
        assert_eq!(counts.bigram_prob(7, 0), 0.0);
    }

    #[test]
    fn test_build_pruned_bigram_graph() {
        let mut counts = NgramCounts::new();
        counts.add_sequence(&[0, 1, 2, 0, 1]);
        let config = PrunedNgramConfig::default();
        let fst = build_pruned_bigram_graph(3, &counts, &config);
        assert!(fst.start() != NO_STATE);
        // start + 3 token states + 1 backoff state.
        assert_eq!(fst.num_states(), 5);
        // One unigram arc per token leaves the start state.
        assert_eq!(fst.transitions(fst.start()).len(), 3);
        // 3 unigram + 3 bigram ((0,1), (1,2), (2,0)) + 3 backoff epsilons
        // + 3 backoff unigram arcs.
        assert_eq!(total_arcs(&fst), 12);
    }

    #[test]
    fn test_pruned_bigram_with_threshold() {
        let mut counts = NgramCounts::new();
        counts.add_sequence(&[0, 1, 0, 1, 0, 1]);
        let config = PrunedNgramConfig {
            min_count: 2,
            ..Default::default()
        };
        let fst = build_pruned_bigram_graph(3, &counts, &config);
        assert_eq!(fst.num_states(), 5);
        // (0,1) x3 and (1,0) x2 both survive min_count = 2.
        assert_eq!(total_arcs(&fst), 3 + 2 + 3 + 3);

        let stricter = PrunedNgramConfig {
            min_count: 3,
            ..Default::default()
        };
        let fst = build_pruned_bigram_graph(3, &counts, &stricter);
        // Only (0,1) survives min_count = 3.
        assert_eq!(total_arcs(&fst), 3 + 1 + 3 + 3);
    }

    #[test]
    fn test_build_pruned_trigram_graph() {
        let mut counts = NgramCounts::new();
        counts.add_sequence(&[0, 1, 2, 0, 1, 2]);
        let config = PrunedNgramConfig {
            order: 3,
            ..Default::default()
        };
        let fst = build_pruned_trigram_graph(3, &counts, &config);
        assert!(fst.start() != NO_STATE);
        // start + 3 unigram states + 3 bigram states + 1 backoff state.
        assert_eq!(fst.num_states(), 8);
        // 3 start arcs + 3 bigram arcs + 3 unigram-state epsilons
        // + 3 backoff unigram arcs + 3 trigram arcs + 3 bigram-state epsilons.
        assert_eq!(total_arcs(&fst), 18);
    }

    #[test]
    fn test_pruned_ngram_stats() {
        let mut counts = NgramCounts::new();
        counts.add_sequence(&[0, 1, 2]);
        let config = PrunedNgramConfig::default();
        let fst = build_pruned_bigram_graph(10, &counts, &config);
        let stats = PrunedNgramStats::from_bigram_graph(&fst, 10);
        // start + 10 token states + 1 backoff state.
        assert_eq!(stats.num_states, 12);
        // 10 unigram + 2 bigram + 10 backoff epsilons + 10 backoff unigrams.
        assert_eq!(stats.num_arcs, 32);
        assert_eq!(stats.dense_arcs, 100);
        assert!((stats.compression_ratio - 100.0 / 32.0).abs() < 1e-12);
    }

    #[test]
    fn test_backoff_disabled() {
        let mut counts = NgramCounts::new();
        counts.add_sequence(&[0, 1]);
        let config = PrunedNgramConfig {
            use_backoff: false,
            ..Default::default()
        };
        let fst = build_pruned_bigram_graph(3, &counts, &config);
        assert!(fst.start() != NO_STATE);
        // No backoff state: start + 3 token states.
        assert_eq!(fst.num_states(), 4);
        // 3 unigram arcs + the single bigram (0,1); no epsilons.
        assert_eq!(total_arcs(&fst), 4);
    }

    #[test]
    fn test_smoothing_enabled() {
        let mut counts = NgramCounts::new();
        counts.add_sequence(&[0, 1, 2]);
        let config = PrunedNgramConfig {
            smoothing: true,
            discount: 0.5,
            ..Default::default()
        };
        let fst = build_pruned_bigram_graph(3, &counts, &config);
        assert_eq!(fst.num_states(), 5);
        // Smoothing changes weights, not topology.
        assert_eq!(total_arcs(&fst), 3 + 2 + 3 + 3);
    }

    #[test]
    fn test_compression_ratio() {
        let mut counts = NgramCounts::new();
        counts.add_sequence(&[0, 1, 0, 1]);
        let config = PrunedNgramConfig {
            min_count: 2,
            ..Default::default()
        };
        let fst = build_pruned_bigram_graph(100, &counts, &config);
        let stats = PrunedNgramStats::from_bigram_graph(&fst, 100);
        // 100 unigram + 1 bigram (0,1) + 100 epsilons + 100 backoff unigrams.
        assert_eq!(stats.num_arcs, 301);
        assert!(stats.num_arcs < stats.dense_arcs);
        assert!(stats.compression_ratio > 1.0);
    }
}