use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::Hash;
use crate::semiring::Semiring;
use crate::wfst::{MutableWfst, StateId, VectorWfst};
/// Integer id of a vocabulary word; also the arc label type of the built transducer.
pub type WordId = u32;
/// N-gram order (e.g. 2 for a bigram model, 3 for a trigram model).
pub type NgramOrder = usize;
/// Weight attached to an n-gram: an alias for the semiring weight type `W` itself.
pub type NgramWeight<W> = W;
/// Epsilon label. The builder labels backoff arcs with `None` on both input and output.
pub const NGRAM_EPSILON: Option<WordId> = None;
/// A word history paired with its n-gram order.
///
/// The constructors `new` and `initial` keep `order == history.len()`. Both
/// fields are public, so callers that mutate them directly can break that.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct BackoffState {
/// Word ids of the history (presumably oldest first, as in `NgramState` — TODO confirm).
pub history: Vec<WordId>,
/// Order of this state; equals `history.len()` when built via `new`/`initial`.
pub order: usize,
}
impl BackoffState {
    /// Builds a state for `history`, with `order` set to the history length.
    pub fn new(history: Vec<WordId>) -> Self {
        // `order` is evaluated before `history` is moved into the struct.
        Self {
            order: history.len(),
            history,
        }
    }

    /// Returns the empty-history state, whose order is therefore zero.
    pub fn initial() -> Self {
        Self::new(Vec::new())
    }
}
/// Configuration shared by `NgramBuilder` and the `NgramTransducer` it builds.
///
/// `NgramBuilder::build` reads only `order`, which limits each state's history
/// to `order - 1` words. The other fields are stored and carried into the
/// built transducer without being interpreted by the builder.
#[derive(Clone, Debug)]
pub struct NgramConfig {
/// N-gram order; `NgramConfig::default()` uses 3 (trigram).
pub order: NgramOrder,
/// Sentence-marker flag; not read by the builder. Presumably requests use of
/// `sos_id`/`eos_id` downstream — TODO confirm with its consumer.
pub add_sentence_markers: bool,
/// Presumably the start-of-sentence word id, if any (inferred from the name).
pub sos_id: Option<WordId>,
/// Presumably the end-of-sentence word id, if any (inferred from the name).
pub eos_id: Option<WordId>,
/// Presumably the unknown-word id, if any (inferred from the name).
pub unk_id: Option<WordId>,
}
impl Default for NgramConfig {
    /// A trigram configuration with sentence markers off and no special word ids.
    fn default() -> Self {
        let no_id = None;
        Self {
            order: 3,
            add_sentence_markers: false,
            sos_id: no_id,
            eos_id: no_id,
            unk_id: no_id,
        }
    }
}
/// A state of the n-gram transducer, identified by its word history and a backoff flag.
///
/// `history` is ordered oldest word first: `extend` appends to the end and
/// drops from the front, and `backed_off` drops the first word.
/// `NgramBuilder::build` keys its state table on the whole struct, so two
/// states with the same history but different `is_backoff` are distinct.
/// The builder places the unigram arcs on `NgramState::backoff(vec![])`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct NgramState {
/// Most recent words, oldest first.
pub history: Vec<WordId>,
/// Whether this state was constructed as a backoff state.
pub is_backoff: bool,
}
impl NgramState {
    /// Creates an ordinary (non-backoff) state for `history`.
    pub fn with_history(history: Vec<WordId>) -> Self {
        Self {
            history,
            is_backoff: false,
        }
    }

    /// Creates a backoff-flagged state for `history`.
    pub fn backoff(history: Vec<WordId>) -> Self {
        Self {
            history,
            is_backoff: true,
        }
    }

    /// The start state: empty history, not flagged as backoff.
    pub fn initial() -> Self {
        Self::with_history(Vec::new())
    }

    /// Returns a backoff-flagged copy with the oldest (first) word removed.
    /// An empty history stays empty.
    pub fn backed_off(&self) -> Self {
        let shorter = self
            .history
            .split_first()
            .map(|(_, rest)| rest.to_vec())
            .unwrap_or_default();
        Self::backoff(shorter)
    }

    /// Appends `word`, then keeps only the newest `max_history` words.
    /// The result is always a non-backoff state.
    pub fn extend(&self, word: WordId, max_history: usize) -> Self {
        let mut words = self.history.clone();
        words.push(word);
        // Drop however many of the oldest words exceed the limit (possibly none).
        let overflow = words.len().saturating_sub(max_history);
        words.drain(..overflow);
        Self::with_history(words)
    }
}
/// An n-gram language model compiled into a weighted transducer by `NgramBuilder::build`.
///
/// Word arcs carry the same word id on input and output. Backoff arcs are
/// epsilon, labelled `None` on both sides.
pub struct NgramTransducer<W: Semiring> {
/// The compiled transducer.
pub fst: VectorWfst<WordId, W>,
/// The configuration the model was built with.
pub config: NgramConfig,
/// Vocabulary size recorded by the builder. It is grown to cover every word
/// id passed to `add_unigram`/`add_ngram`.
pub vocab_size: usize,
}
/// Incremental builder for an [`NgramTransducer`].
///
/// Collect unigram, higher-order n-gram, and backoff weights, then call
/// `build` to compile them into a transducer.
pub struct NgramBuilder<W: Semiring> {
/// Model configuration; `order` bounds state histories.
config: NgramConfig,
/// Vocabulary size tracked so far.
vocab_size: usize,
/// Unigram weight per word id.
unigrams: HashMap<WordId, W>,
/// N-gram weights: history -> (next word -> weight).
ngrams: HashMap<Vec<WordId>, HashMap<WordId, W>>,
/// Backoff weight per history.
backoffs: HashMap<Vec<WordId>, W>,
}
impl<W: Semiring + Clone> NgramBuilder<W> {
    /// Creates an empty builder for a model of the given `order`. All other
    /// configuration comes from `NgramConfig::default()`.
    pub fn new(order: NgramOrder) -> Self {
        Self {
            config: NgramConfig {
                order,
                ..Default::default()
            },
            vocab_size: 0,
            unigrams: HashMap::new(),
            ngrams: HashMap::new(),
            backoffs: HashMap::new(),
        }
    }

    /// Sets the vocabulary size. Words added afterwards still grow it when needed.
    pub fn vocab_size(mut self, size: usize) -> Self {
        self.vocab_size = size;
        self
    }

    /// Replaces the whole configuration, including the order given to `new`.
    pub fn config(mut self, config: NgramConfig) -> Self {
        self.config = config;
        self
    }

    /// Records the unigram weight of `word`, replacing any previous value.
    pub fn add_unigram(&mut self, word: WordId, weight: W) {
        self.unigrams.insert(word, weight);
        self.track_word(word);
    }

    /// Records a bigram weight; identical to [`Self::add_ngram`].
    pub fn add_bigram(&mut self, history: &[WordId], word: WordId, weight: W) {
        self.add_ngram(history, word, weight);
    }

    /// Records the weight of `word` following `history`, replacing any previous value.
    pub fn add_ngram(&mut self, history: &[WordId], word: WordId, weight: W) {
        self.ngrams
            .entry(history.to_vec())
            .or_default()
            .insert(word, weight);
        self.track_word(word);
        self.track_history(history);
    }

    /// Sets the backoff weight for `history`, replacing any previous value.
    ///
    /// The empty history's weight is used on the start state's arc to the
    /// unigram state. Histories without a weight back off with `W::one()`.
    pub fn set_backoff(&mut self, history: &[WordId], weight: W) {
        self.backoffs.insert(history.to_vec(), weight);
        self.track_history(history);
    }

    /// Compiles the collected weights into a transducer.
    ///
    /// Layout:
    /// - The start state (empty history) has exactly one epsilon arc. It goes
    ///   to the unigram state and is weighted by the empty history's backoff weight.
    /// - The unigram state (`NgramState::backoff(vec![])`) has one arc per unigram.
    /// - Each history state has its n-gram arcs. Each arc leads to the state
    ///   for the history extended by that word, keeping the newest `order - 1` words.
    /// - Every state with a non-empty history has one epsilon backoff arc. It
    ///   goes to the state for the history minus its oldest word, or to the
    ///   unigram state once that history is empty. Its weight comes from
    ///   `set_backoff`, or is `W::one()` when unset.
    ///
    /// Every state is final with weight `W::one()`. Keys are visited in sorted
    /// order, so building the same model twice yields identical state numbering.
    pub fn build(self) -> NgramTransducer<W> {
        // `saturating_sub` keeps a (degenerate) order of 0 from underflowing.
        let max_history = self.config.order.saturating_sub(1);
        let mut fst: VectorWfst<WordId, W> = VectorWfst::new();
        let mut state_map: HashMap<NgramState, StateId> = HashMap::new();

        let start_id = fst.add_state();
        fst.set_start(start_id);
        fst.set_final(start_id, W::one());
        state_map.insert(NgramState::initial(), start_id);

        let unigram_id = fst.add_state();
        fst.set_final(unigram_id, W::one());
        state_map.insert(NgramState::backoff(Vec::new()), unigram_id);

        // HashMap iteration order is unspecified; sort for reproducible output.
        let mut unigrams: Vec<(&WordId, &W)> = self.unigrams.iter().collect();
        unigrams.sort_unstable_by_key(|&(word, _)| *word);
        for (&word, weight) in unigrams {
            // Truncate like higher-order arcs, so an order-1 model returns to
            // the start state instead of entering a dead-end `[word]` state.
            let next_state = NgramState::initial().extend(word, max_history);
            let next_id = Self::get_or_create_state(&mut fst, &mut state_map, &next_state);
            fst.add_arc(unigram_id, Some(word), Some(word), next_id, weight.clone());
        }

        let mut ngrams: Vec<(&Vec<WordId>, &HashMap<WordId, W>)> = self.ngrams.iter().collect();
        ngrams.sort_unstable_by(|a, b| a.0.cmp(b.0));
        for (history, word_weights) in ngrams {
            let from_state = NgramState::with_history(history.clone());
            let from_id = Self::get_or_create_state(&mut fst, &mut state_map, &from_state);
            let mut words: Vec<(&WordId, &W)> = word_weights.iter().collect();
            words.sort_unstable_by_key(|&(word, _)| *word);
            for (&word, weight) in words {
                let next_state = from_state.extend(word, max_history);
                let next_id = Self::get_or_create_state(&mut fst, &mut state_map, &next_state);
                fst.add_arc(from_id, Some(word), Some(word), next_id, weight.clone());
            }
        }

        self.add_backoff_arcs(&mut fst, &mut state_map);

        // Added exactly once, so an empty-history backoff weight is never
        // applied on two parallel arcs.
        let start_backoff = self.backoff_weight(&[]);
        fst.add_arc(start_id, None, None, unigram_id, start_backoff);

        NgramTransducer {
            fst,
            config: self.config,
            vocab_size: self.vocab_size,
        }
    }

    /// Adds one epsilon backoff arc from every state with a non-empty history
    /// to its backoff target. Missing targets are created, and each new one
    /// gets its own backoff arc in turn.
    fn add_backoff_arcs(
        &self,
        fst: &mut VectorWfst<WordId, W>,
        state_map: &mut HashMap<NgramState, StateId>,
    ) {
        let mut pending: Vec<NgramState> = state_map
            .keys()
            .filter(|state| !state.history.is_empty())
            .cloned()
            .collect();
        // Sorted so any states created below get deterministic ids.
        pending.sort_unstable_by(|a, b| a.history.cmp(&b.history));
        while let Some(state) = pending.pop() {
            let from_id = state_map[&state];
            let target = Self::backoff_target(&state);
            if !target.history.is_empty() && !state_map.contains_key(&target) {
                pending.push(target.clone());
            }
            let to_id = Self::get_or_create_state(fst, state_map, &target);
            fst.add_arc(from_id, None, None, to_id, self.backoff_weight(&state.history));
        }
    }

    /// Returns the state a backoff arc from `state` leads to, which is `state`
    /// with its oldest history word dropped.
    ///
    /// For a non-empty shorter history, the lower-order arcs live on the
    /// ordinary (non-backoff) state, so that is the target. For the empty
    /// history, the unigram arcs live on `NgramState::backoff(vec![])`.
    fn backoff_target(state: &NgramState) -> NgramState {
        let shorter = state.backed_off();
        if shorter.history.is_empty() {
            shorter
        } else {
            NgramState::with_history(shorter.history)
        }
    }

    /// Returns the backoff weight set for `history`, or `W::one()` when none was set.
    /// This matches the ARPA convention that an absent backoff weight is log 1.
    fn backoff_weight(&self, history: &[WordId]) -> W {
        self.backoffs.get(history).cloned().unwrap_or_else(W::one)
    }

    /// Grows the vocabulary size so that `word` is a valid id.
    fn track_word(&mut self, word: WordId) {
        self.vocab_size = self.vocab_size.max(word as usize + 1);
    }

    /// Grows the vocabulary size to cover every word in `history`.
    fn track_history(&mut self, history: &[WordId]) {
        for &word in history {
            self.track_word(word);
        }
    }

    /// Returns the id of `state`. A new state is added to `fst` (final with
    /// `W::one()`) and to `state_map` first.
    fn get_or_create_state(
        fst: &mut VectorWfst<WordId, W>,
        state_map: &mut HashMap<NgramState, StateId>,
        state: &NgramState,
    ) -> StateId {
        if let Some(&id) = state_map.get(state) {
            return id;
        }
        let id = fst.add_state();
        fst.set_final(id, W::one());
        state_map.insert(state.clone(), id);
        id
    }
}
impl<W: Semiring> NgramTransducer<W> {
    /// Borrows the compiled transducer.
    pub fn as_fst(&self) -> &VectorWfst<WordId, W> {
        let Self { fst, .. } = self;
        fst
    }

    /// The n-gram order from the configuration the model was built with.
    pub fn order(&self) -> NgramOrder {
        let config = &self.config;
        config.order
    }

    /// The vocabulary size recorded by the builder.
    pub fn vocabulary_size(&self) -> usize {
        let Self { vocab_size, .. } = self;
        *vocab_size
    }
}
// Unit tests for the `NgramState` helpers and for the structure that
// `NgramBuilder::build` produces.
#[cfg(test)]
mod tests {
use super::*;
use crate::semiring::LogWeight;
use crate::wfst::{Wfst, NO_STATE};
// The initial state has no history and is not a backoff state.
#[test]
fn test_ngram_state_initial() {
let state = NgramState::initial();
assert!(state.history.is_empty());
assert!(!state.is_backoff);
}
// `extend` appends and then keeps only the newest `max_history` words.
#[test]
fn test_ngram_state_extend() {
let state = NgramState::initial();
let state1 = state.extend(1, 2);
assert_eq!(state1.history, vec![1]);
let state2 = state1.extend(2, 2);
assert_eq!(state2.history, vec![1, 2]);
let state3 = state2.extend(3, 2);
assert_eq!(state3.history, vec![2, 3]);
}
// `backed_off` drops the oldest word and sets the backoff flag.
#[test]
fn test_ngram_state_backoff() {
let state = NgramState::with_history(vec![1, 2, 3]);
let backoff = state.backed_off();
assert_eq!(backoff.history, vec![2, 3]);
assert!(backoff.is_backoff);
}
// Smoke test: a small bigram model builds and has a start state.
#[test]
fn test_bigram_builder() {
let mut builder = NgramBuilder::<LogWeight>::new(2);
builder.add_unigram(1, LogWeight::new(5.0));
builder.add_unigram(2, LogWeight::new(4.0));
builder.add_unigram(3, LogWeight::new(6.0));
builder.add_bigram(&[1], 2, LogWeight::new(2.0));
builder.add_bigram(&[1], 3, LogWeight::new(3.0));
builder.add_bigram(&[2], 1, LogWeight::new(2.5));
builder.set_backoff(&[1], LogWeight::new(0.5));
builder.set_backoff(&[2], LogWeight::new(0.6));
let lm = builder.build();
assert!(lm.fst.start() != NO_STATE);
assert!(lm.fst.num_states() > 0);
}
// Smoke test: a trigram model keeps its order and creates several states.
#[test]
fn test_trigram_builder() {
let mut builder = NgramBuilder::<LogWeight>::new(3);
builder.add_unigram(1, LogWeight::new(5.0));
builder.add_unigram(2, LogWeight::new(4.0));
builder.add_ngram(&[1], 2, LogWeight::new(2.0));
builder.add_ngram(&[1, 2], 1, LogWeight::new(1.5));
let lm = builder.build();
assert_eq!(lm.order(), 3);
assert!(lm.fst.num_states() >= 3);
}
// The vocabulary size must cover the largest word id added (10 -> at least 11).
#[test]
fn test_vocabulary_tracking() {
let mut builder = NgramBuilder::<LogWeight>::new(2);
builder.add_unigram(5, LogWeight::new(3.0));
builder.add_unigram(10, LogWeight::new(4.0));
let lm = builder.build();
assert!(lm.vocabulary_size() >= 11);
}
// At least one epsilon (input `None`) arc exists once a backoff weight is set.
#[test]
fn test_backoff_transitions() {
let mut builder = NgramBuilder::<LogWeight>::new(2);
builder.add_unigram(1, LogWeight::new(5.0));
builder.add_unigram(2, LogWeight::new(4.0));
builder.add_bigram(&[1], 2, LogWeight::new(2.0));
builder.set_backoff(&[1], LogWeight::new(0.5));
let lm = builder.build();
let mut has_epsilon = false;
for state in 0..lm.fst.num_states() as StateId {
for trans in lm.fst.transitions(state) {
if trans.input.is_none() {
has_epsilon = true;
break;
}
}
if has_epsilon {
break;
}
}
assert!(has_epsilon, "Should have backoff epsilon transitions");
}
// The builder marks every state it creates as final.
#[test]
fn test_all_states_final() {
let mut builder = NgramBuilder::<LogWeight>::new(2);
builder.add_unigram(1, LogWeight::new(5.0));
builder.add_unigram(2, LogWeight::new(4.0));
let lm = builder.build();
for state in 0..lm.fst.num_states() as StateId {
assert!(lm.fst.is_final(state));
}
}
}
// Property-based tests (proptest) for the state types, the default
// configuration, and the builder / transducer structure.
#[cfg(test)]
mod property_tests {
use super::*;
use crate::semiring::LogWeight;
use crate::wfst::{Wfst, NO_STATE};
use proptest::prelude::*;
// `BackoffState` constructors: `order` always tracks `history.len()`.
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
#[test]
fn initial_backoff_empty(_seed in any::<u64>()) {
let state = BackoffState::initial();
prop_assert!(state.history.is_empty());
prop_assert_eq!(state.order, 0);
}
#[test]
fn backoff_order_matches_history(history in prop::collection::vec(0u32..100, 0..5)) {
let expected_order = history.len();
let state = BackoffState::new(history);
prop_assert_eq!(state.order, expected_order);
}
#[test]
fn backoff_preserves_history(history in prop::collection::vec(0u32..100, 0..5)) {
let state = BackoffState::new(history.clone());
prop_assert_eq!(state.history, history);
}
}
// `NgramConfig::default()`: trigram, no markers, no special ids.
// The `_seed` inputs only drive repetition; these checks are deterministic.
proptest! {
#![proptest_config(ProptestConfig::with_cases(50))]
#[test]
fn default_config_trigram(_seed in any::<u64>()) {
let config = NgramConfig::default();
prop_assert_eq!(config.order, 3);
}
#[test]
fn default_config_no_markers(_seed in any::<u64>()) {
let config = NgramConfig::default();
prop_assert!(!config.add_sentence_markers);
prop_assert!(config.sos_id.is_none());
prop_assert!(config.eos_id.is_none());
}
#[test]
fn default_config_no_unk(_seed in any::<u64>()) {
let config = NgramConfig::default();
prop_assert!(config.unk_id.is_none());
}
}
// `NgramState` constructors, `backed_off`, and `extend`.
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
#[test]
fn initial_not_backoff(_seed in any::<u64>()) {
let state = NgramState::initial();
prop_assert!(!state.is_backoff);
prop_assert!(state.history.is_empty());
}
#[test]
fn with_history_not_backoff(history in prop::collection::vec(0u32..100, 0..5)) {
let state = NgramState::with_history(history);
prop_assert!(!state.is_backoff);
}
#[test]
fn backoff_creates_backoff(history in prop::collection::vec(0u32..100, 0..5)) {
let state = NgramState::backoff(history);
prop_assert!(state.is_backoff);
}
// Non-empty histories lose exactly one word when backed off.
#[test]
fn backed_off_shortens(history in prop::collection::vec(0u32..100, 1..5)) {
let state = NgramState::with_history(history.clone());
let backed = state.backed_off();
prop_assert_eq!(backed.history.len(), history.len() - 1);
prop_assert!(backed.is_backoff);
}
#[test]
fn backed_off_empty_stays_empty(_seed in any::<u64>()) {
let state = NgramState::initial();
let backed = state.backed_off();
prop_assert!(backed.history.is_empty());
prop_assert!(backed.is_backoff);
}
// The word removed by `backed_off` is the first (oldest) one.
#[test]
fn backed_off_removes_first(history in prop::collection::vec(0u32..100, 2..5)) {
let state = NgramState::with_history(history.clone());
let backed = state.backed_off();
prop_assert_eq!(backed.history, history[1..].to_vec());
}
#[test]
fn extend_adds_word(word in 0u32..100, max_history in 1usize..5) {
let state = NgramState::initial();
let extended = state.extend(word, max_history);
prop_assert!(extended.history.contains(&word));
prop_assert!(!extended.is_backoff);
}
// Repeated `extend` never lets the history exceed `max_history`.
#[test]
fn extend_respects_max(
words in prop::collection::vec(0u32..100, 1..10),
max_history in 1usize..5
) {
let mut state = NgramState::initial();
for &word in &words {
state = state.extend(word, max_history);
prop_assert!(state.history.len() <= max_history);
}
}
#[test]
fn extend_clears_backoff(word in 0u32..100, max_history in 1usize..5) {
let state = NgramState::backoff(vec![1, 2]);
let extended = state.extend(word, max_history);
prop_assert!(!extended.is_backoff);
}
#[test]
fn ngram_state_equality(history in prop::collection::vec(0u32..50, 0..4)) {
let state1 = NgramState::with_history(history.clone());
let state2 = NgramState::with_history(history);
prop_assert_eq!(state1, state2);
}
}
// Builder bookkeeping: order, configuration, and vocabulary size.
proptest! {
#![proptest_config(ProptestConfig::with_cases(30))]
#[test]
fn builder_preserves_order(order in 1usize..5) {
let builder = NgramBuilder::<LogWeight>::new(order);
let lm = builder.build();
prop_assert_eq!(lm.order(), order);
}
#[test]
fn unigrams_update_vocab(word_id in 1u32..100) {
let mut builder = NgramBuilder::<LogWeight>::new(2);
builder.add_unigram(word_id, LogWeight::new(1.0));
let lm = builder.build();
prop_assert!(lm.vocabulary_size() >= word_id as usize + 1);
}
// Both the history words and the predicted word count toward the vocabulary.
#[test]
fn ngrams_update_vocab(
history in prop::collection::vec(0u32..50, 1..3),
word in 50u32..100
) {
let mut builder = NgramBuilder::<LogWeight>::new(3);
builder.add_ngram(&history, word, LogWeight::new(1.0));
let lm = builder.build();
let max_word = history.iter().cloned().max().unwrap_or(0).max(word);
prop_assert!(lm.vocabulary_size() >= max_word as usize + 1);
}
// `config` replaces the configuration wholesale, including the order given to `new`.
#[test]
fn builder_config_updates(order in 1usize..5) {
let config = NgramConfig {
order,
add_sentence_markers: true,
..Default::default()
};
let builder = NgramBuilder::<LogWeight>::new(2).config(config);
let lm = builder.build();
prop_assert_eq!(lm.config.order, order);
prop_assert!(lm.config.add_sentence_markers);
}
#[test]
fn builder_vocab_size(size in 10usize..100) {
let builder = NgramBuilder::<LogWeight>::new(2).vocab_size(size);
let lm = builder.build();
prop_assert!(lm.vocabulary_size() <= size);
}
}
// Structural properties of the built transducer.
proptest! {
#![proptest_config(ProptestConfig::with_cases(25))]
#[test]
fn transducer_has_start(order in 1usize..4) {
let builder = NgramBuilder::<LogWeight>::new(order);
let lm = builder.build();
prop_assert!(lm.fst.start() != NO_STATE);
}
// Even an empty model has the start state and the unigram state.
#[test]
fn transducer_min_states(order in 1usize..4) {
let builder = NgramBuilder::<LogWeight>::new(order);
let lm = builder.build();
prop_assert!(lm.fst.num_states() >= 2);
}
#[test]
fn transducer_all_final(order in 1usize..4) {
let mut builder = NgramBuilder::<LogWeight>::new(order);
builder.add_unigram(1, LogWeight::new(1.0));
let lm = builder.build();
for state in 0..lm.fst.num_states() as StateId {
prop_assert!(lm.fst.is_final(state));
}
}
// The start state's epsilon arc to the unigram state exists even without explicit backoffs.
#[test]
fn transducer_has_backoff_arcs(order in 2usize..4) {
let mut builder = NgramBuilder::<LogWeight>::new(order);
builder.add_unigram(1, LogWeight::new(1.0));
builder.add_unigram(2, LogWeight::new(1.0));
let lm = builder.build();
let mut has_epsilon = false;
for state in 0..lm.fst.num_states() as StateId {
for trans in lm.fst.transitions(state) {
if trans.input.is_none() {
has_epsilon = true;
break;
}
}
}
prop_assert!(has_epsilon);
}
// Each distinct unigram contributes at least one labelled arc.
#[test]
fn unigram_transitions_exist(words in prop::collection::vec(1u32..10, 1..5)) {
let mut builder = NgramBuilder::<LogWeight>::new(2);
for &word in &words {
builder.add_unigram(word, LogWeight::new(1.0));
}
let lm = builder.build();
let mut word_transitions = 0;
for state in 0..lm.fst.num_states() as StateId {
for trans in lm.fst.transitions(state) {
if trans.input.is_some() {
word_transitions += 1;
}
}
}
let unique_words: std::collections::HashSet<_> = words.iter().collect();
prop_assert!(word_transitions >= unique_words.len());
}
}
// A bigram with a backoff weight yields at least start, unigram, and one history state.
proptest! {
#![proptest_config(ProptestConfig::with_cases(20))]
#[test]
fn bigram_backoff_structure(
word1 in 1u32..10,
word2 in 10u32..20
) {
let mut builder = NgramBuilder::<LogWeight>::new(2);
builder.add_unigram(word1, LogWeight::new(1.0));
builder.add_unigram(word2, LogWeight::new(1.0));
builder.add_bigram(&[word1], word2, LogWeight::new(0.5));
builder.set_backoff(&[word1], LogWeight::new(0.3));
let lm = builder.build();
prop_assert!(lm.fst.num_states() >= 3);
}
}
}