use std::collections::HashMap;
use std::sync::Arc;
use lling_llang::asr::{NgramBuilder, NgramConfig, NgramTransducer};
use lling_llang::semiring::{LogWeight, ProbabilityWeight, Semiring, TropicalWeight};
use lling_llang::wfst::{MutableWfst, StateId, VectorWfst};
#[allow(deprecated)]
use crate::ngram::{IterableDictionary, NgramEntry, NgramModel, NGRAM_SEPARATOR};
use liblevenshtein::dictionary::MutableMappedDictionary;
use super::vocabulary::{WordId, WordVocabulary, EOS_WORD_ID, UNK_WORD_ID};
/// Conversion from log-domain probabilities into a semiring weight.
///
/// Each implementor decides how a log-probability maps onto its own weight
/// representation.
pub trait FromLogProb: Semiring {
/// Builds a weight from the log-probability `log_prob`.
fn from_log_prob(log_prob: f64) -> Self;
/// Builds a weight from the negated log-probability `neg_log_prob`.
fn from_neg_log_prob(neg_log_prob: f64) -> Self;
}
/// `LogWeight` is constructed from the negated log-probability, so a plain
/// log-probability only needs its sign flipped before construction.
impl FromLogProb for LogWeight {
    /// Negates `log_prob` and delegates to [`FromLogProb::from_neg_log_prob`].
    #[inline]
    fn from_log_prob(log_prob: f64) -> Self {
        Self::from_neg_log_prob(-log_prob)
    }

    /// Passes `neg_log_prob` to the weight constructor unchanged.
    #[inline]
    fn from_neg_log_prob(neg_log_prob: f64) -> Self {
        Self::new(neg_log_prob)
    }
}
/// `TropicalWeight` is constructed from the negated log-probability, so a
/// plain log-probability only needs its sign flipped before construction.
impl FromLogProb for TropicalWeight {
    /// Negates `log_prob` and delegates to [`FromLogProb::from_neg_log_prob`].
    #[inline]
    fn from_log_prob(log_prob: f64) -> Self {
        Self::from_neg_log_prob(-log_prob)
    }

    /// Passes `neg_log_prob` to the weight constructor unchanged.
    #[inline]
    fn from_neg_log_prob(neg_log_prob: f64) -> Self {
        Self::new(neg_log_prob)
    }
}
/// `ProbabilityWeight` is constructed from a linear-domain value, so the
/// log-probability is exponentiated (natural base) before construction.
impl FromLogProb for ProbabilityWeight {
    /// Constructs the weight from `exp(log_prob)`.
    #[inline]
    fn from_log_prob(log_prob: f64) -> Self {
        Self::new(log_prob.exp())
    }

    /// Flips the sign back and delegates to [`FromLogProb::from_log_prob`],
    /// i.e. constructs the weight from `exp(-neg_log_prob)`.
    #[inline]
    fn from_neg_log_prob(neg_log_prob: f64) -> Self {
        Self::from_log_prob(-neg_log_prob)
    }
}
/// Builder that turns an [`NgramModel`] into a `VectorWfst` over word ids.
///
/// States are keyed by word-id history; arcs carry weights derived from the
/// model's log-probabilities through [`FromLogProb`].
pub struct NgramWfstBuilder<D, W>
where
D: MutableMappedDictionary<Value = NgramEntry> + IterableDictionary,
W: Semiring + FromLogProb,
{
/// Shared handle to the n-gram model being exported.
model: Arc<NgramModel<D>>,
/// Word <-> id mapping built from the model's unigram keys.
vocabulary: WordVocabulary,
/// The transducer under construction.
wfst: VectorWfst<WordId, W>,
/// Maps a word-id history to the state representing it; the empty history
/// maps to `backoff_state`.
history_to_state: HashMap<Vec<WordId>, StateId>,
/// Start state of the transducer.
start_state: StateId,
/// State for the empty history; the fallback target of backoff arcs.
backoff_state: StateId,
}
impl<D, W> NgramWfstBuilder<D, W>
where
D: MutableMappedDictionary<Value = NgramEntry> + IterableDictionary,
W: Semiring + FromLogProb,
{
/// Creates a builder for `model`.
///
/// Two states are allocated up front, in this order: the start state and the
/// backoff state (the state for the empty history). Both are marked final
/// with weight `W::one()`.
pub fn new(model: Arc<NgramModel<D>>) -> Self {
    // Sizing heuristic only: roughly two states per vocabulary word.
    let capacity_hint = model.vocab_size().saturating_mul(2);

    let mut wfst = VectorWfst::with_capacity(capacity_hint);
    let start_state = wfst.add_state();
    let backoff_state = wfst.add_state();
    wfst.set_start(start_state);
    wfst.set_final(start_state, W::one());
    wfst.set_final(backoff_state, W::one());

    // The empty history is represented by the backoff state.
    let mut history_to_state = HashMap::with_capacity(capacity_hint);
    history_to_state.insert(Vec::new(), backoff_state);

    let vocabulary = Self::build_vocabulary(&model);

    Self {
        model,
        vocabulary,
        wfst,
        history_to_state,
        start_state,
        backoff_state,
    }
}
/// Builds the word vocabulary from the model's unigram entries.
///
/// A trie key counts as a unigram when it does not contain
/// `NGRAM_SEPARATOR`; every such key is added to the vocabulary in trie
/// iteration order.
// NOTE(review): presumably silences a deprecation on the trie-iteration API
// used below; confirm which item is deprecated.
#[allow(deprecated)]
fn build_vocabulary(model: &NgramModel<D>) -> WordVocabulary {
    let mut vocab = WordVocabulary::with_capacity(model.vocab_size());
    model
        .trie()
        .iter_entries()
        .filter(|(key, _)| !key.contains(NGRAM_SEPARATOR))
        .for_each(|(key, _)| {
            vocab.add_word(&key);
        });
    vocab
}
/// Returns the state registered for `history`, creating it on first use.
///
/// A newly created state is marked final with weight `W::one()` and recorded
/// in `history_to_state`.
fn get_or_create_state(&mut self, history: &[WordId]) -> StateId {
    match self.history_to_state.get(history).copied() {
        Some(existing) => existing,
        None => {
            let fresh = self.wfst.add_state();
            self.wfst.set_final(fresh, W::one());
            self.history_to_state.insert(history.to_vec(), fresh);
            fresh
        }
    }
}
/// Returns `history` with its first (oldest) word removed.
///
/// Yields `None` for an empty history, which has nothing left to back off to.
fn get_backoff_history(history: &[WordId]) -> Option<Vec<WordId>> {
    history.split_first().map(|(_, shorter)| shorter.to_vec())
}
pub fn build(mut self) -> (VectorWfst<WordId, W>, WordVocabulary) {
let order = self.model.order();
self.wfst
.add_epsilon(self.start_state, self.backoff_state, W::one());
self.add_unigrams();
if order > 1 {
self.add_higher_order_ngrams();
}
self.add_backoff_transitions();
(self.wfst, self.vocabulary)
}
fn add_unigrams(&mut self) {
let backoff_state = self.backoff_state;
let unigram_words: Vec<String> = self
.vocabulary
.iter()
.skip(2) .map(|(word, _)| word.to_string())
.collect();
for word in unigram_words {
let word_id = self
.vocabulary
.get_id(&word)
.expect("Word must be in vocabulary");
let log_prob = self.model.log_prob(&word, &[]);
let weight = W::from_log_prob(log_prob);
let history = vec![word_id];
let target_state = self.get_or_create_state(&history);
self.wfst.add_arc(
backoff_state,
Some(word_id),
Some(word_id),
target_state,
weight,
);
}
}
/// Adds arcs for every stored n-gram of length 2 up to the model order.
///
/// For an n-gram `history + word`, the arc leaves the state for `history`,
/// is labelled with `word` on both sides, and enters the state for the new
/// history (`history + word`, trimmed to at most `order - 1` most recent
/// words). N-grams containing a word missing from the vocabulary are skipped.
// NOTE(review): presumably silences a deprecation on the trie-iteration API
// used below; confirm which item is deprecated.
#[allow(deprecated)]
fn add_higher_order_ngrams(&mut self) {
    let order = self.model.order();

    // Materialise the n-grams first; the loop below needs `&mut self`.
    let ngrams: Vec<(Vec<String>, String)> = self
        .model
        .trie()
        .iter_entries()
        .filter_map(|(key, _entry)| {
            let tokens: Vec<&str> = key.split(NGRAM_SEPARATOR).collect();
            if !(2..=order).contains(&tokens.len()) {
                return None;
            }
            let (word, context) = tokens.split_last()?;
            let history: Vec<String> = context.iter().map(|t| t.to_string()).collect();
            Some((history, word.to_string()))
        })
        .collect();

    for (history_words, word) in ngrams {
        // Resolve every history word; a single unknown word drops the n-gram.
        let resolved: Option<Vec<WordId>> = history_words
            .iter()
            .map(|w| self.vocabulary.get_id(w))
            .collect();
        let history_ids = match resolved {
            Some(ids) => ids,
            None => continue,
        };
        let word_id = match self.vocabulary.get_id(&word) {
            Some(id) => id,
            None => continue,
        };

        let source_state = self.get_or_create_state(&history_ids);

        // New history = old history + word, keeping only the most recent
        // `order - 1` words once the full order is reached.
        let mut target_history = history_ids;
        target_history.push(word_id);
        if target_history.len() >= order {
            let excess = target_history.len() - (order - 1);
            target_history.drain(..excess);
        }
        let target_state = self.get_or_create_state(&target_history);

        let context: Vec<&str> = history_words.iter().map(String::as_str).collect();
        let weight = W::from_log_prob(self.model.log_prob(&word, &context));
        self.wfst.add_arc(
            source_state,
            Some(word_id),
            Some(word_id),
            target_state,
            weight,
        );
    }
}
/// Adds an epsilon arc from every non-empty-history state to the state of
/// its shortened history (oldest word dropped).
///
/// When no state exists for the shortened history, the arc goes to the root
/// backoff state instead. Every backoff arc carries weight `W::one()`.
fn add_backoff_transitions(&mut self) {
    let fallback = self.backoff_state;

    // Work out all (source, target) pairs before touching the WFST, so the
    // history map is only read while it is being iterated.
    let arcs: Vec<(StateId, StateId)> = self
        .history_to_state
        .iter()
        .filter(|(history, _)| !history.is_empty())
        .filter_map(|(history, &state)| {
            Self::get_backoff_history(history).map(|shorter| {
                let target = self
                    .history_to_state
                    .get(&shorter)
                    .copied()
                    .unwrap_or(fallback);
                (state, target)
            })
        })
        .collect();

    for (from, to) in arcs {
        self.wfst.add_epsilon(from, to, W::one());
    }
}
}
/// Export of an n-gram model as a weighted finite-state transducer.
///
/// Each method returns the transducer together with the [`WordVocabulary`]
/// that maps words to the ids used on its arcs. The `to_*` variants borrow
/// the receiver; the `into_*` variants consume it.
pub trait NgramWfstExport<D>
where
D: MutableMappedDictionary<Value = NgramEntry> + IterableDictionary,
{
/// Builds a `VectorWfst` over word ids with weights in semiring `W`.
fn to_wfst<W>(&self) -> (VectorWfst<WordId, W>, WordVocabulary)
where
W: Semiring + FromLogProb;
/// Consuming variant of [`NgramWfstExport::to_wfst`].
fn into_wfst<W>(self) -> (VectorWfst<WordId, W>, WordVocabulary)
where
W: Semiring + FromLogProb;
/// Builds an `NgramTransducer` with weights in semiring `W`.
fn to_ngram_transducer<W>(&self) -> (NgramTransducer<W>, WordVocabulary)
where
W: Semiring + FromLogProb + Clone;
/// Consuming variant of [`NgramWfstExport::to_ngram_transducer`].
fn into_ngram_transducer<W>(self) -> (NgramTransducer<W>, WordVocabulary)
where
W: Semiring + FromLogProb + Clone;
}
/// Builder that feeds an [`NgramModel`] into an `NgramBuilder` to produce an
/// `NgramTransducer`.
pub struct NgramTransducerBuilder<D, W>
where
D: MutableMappedDictionary<Value = NgramEntry> + IterableDictionary,
W: Semiring + FromLogProb + Clone,
{
/// Shared handle to the n-gram model being exported.
model: Arc<NgramModel<D>>,
/// Word <-> id mapping built from the model's unigram keys.
vocabulary: WordVocabulary,
/// Underlying transducer builder that receives the n-grams.
builder: NgramBuilder<W>,
}
impl<D, W> NgramTransducerBuilder<D, W>
where
D: MutableMappedDictionary<Value = NgramEntry> + IterableDictionary,
W: Semiring + FromLogProb + Clone,
{
/// Creates a transducer builder for `model`.
///
/// The underlying `NgramBuilder` is configured with the model's order,
/// sentence markers enabled, no start-of-sentence id, and the reserved
/// `EOS_WORD_ID` / `UNK_WORD_ID` ids.
pub fn new(model: Arc<NgramModel<D>>) -> Self {
    let order = model.order();
    let vocabulary = NgramWfstBuilder::<D, W>::build_vocabulary(&model);

    let builder = NgramBuilder::new(order).config(NgramConfig {
        order,
        add_sentence_markers: true,
        sos_id: None,
        eos_id: Some(EOS_WORD_ID),
        unk_id: Some(UNK_WORD_ID),
    });

    Self {
        model,
        vocabulary,
        builder,
    }
}
pub fn build(mut self) -> (NgramTransducer<W>, WordVocabulary) {
self.builder = self.builder.vocab_size(self.vocabulary.len());
self.add_unigrams();
self.add_higher_order_ngrams();
let transducer = self.builder.build();
(transducer, self.vocabulary)
}
fn add_unigrams(&mut self) {
let words: Vec<(String, WordId)> = self
.vocabulary
.iter()
.skip(2) .map(|(word, id)| (word.to_string(), id))
.collect();
for (word, word_id) in words {
let log_prob = self.model.log_prob(&word, &[]);
let weight = W::from_log_prob(log_prob);
self.builder.add_unigram(word_id, weight);
}
}
/// Registers every stored n-gram of length 2 up to the model order with the
/// underlying builder, then sets a backoff weight of `W::one()` for each
/// distinct history that was registered.
///
/// N-grams containing a word missing from the vocabulary are skipped.
// NOTE(review): presumably silences a deprecation on the trie-iteration API
// used below; confirm which item is deprecated.
#[allow(deprecated)]
fn add_higher_order_ngrams(&mut self) {
    let order = self.model.order();

    // Materialise the n-grams first; the loop below mutates the builder.
    let ngrams: Vec<(Vec<String>, String)> = self
        .model
        .trie()
        .iter_entries()
        .filter_map(|(key, _entry)| {
            let tokens: Vec<&str> = key.split(NGRAM_SEPARATOR).collect();
            if !(2..=order).contains(&tokens.len()) {
                return None;
            }
            let (word, context) = tokens.split_last().expect("tokens not empty");
            let history: Vec<String> = context.iter().map(|t| t.to_string()).collect();
            Some((history, word.to_string()))
        })
        .collect();

    let mut histories_seen: std::collections::HashSet<Vec<WordId>> =
        std::collections::HashSet::new();

    for (history_words, word) in ngrams {
        // Both the full history and the predicted word must be in vocabulary.
        let resolved: Option<Vec<WordId>> = history_words
            .iter()
            .map(|w| self.vocabulary.get_id(w))
            .collect();
        let (history_ids, word_id) = match (resolved, self.vocabulary.get_id(&word)) {
            (Some(ids), Some(id)) => (ids, id),
            _ => continue,
        };

        let context: Vec<&str> = history_words.iter().map(String::as_str).collect();
        let weight = W::from_log_prob(self.model.log_prob(&word, &context));
        self.builder.add_ngram(&history_ids, word_id, weight);
        histories_seen.insert(history_ids);
    }

    for history_ids in histories_seen {
        self.builder.set_backoff(&history_ids, W::one());
    }
}
}
/// Export for an owned model: the borrowing variants clone the model into a
/// fresh `Arc`, the consuming variants move it into one.
impl<D> NgramWfstExport<D> for NgramModel<D>
where
    D: MutableMappedDictionary<Value = NgramEntry> + IterableDictionary,
{
    /// Clones the model and builds a WFST from the clone.
    fn to_wfst<W>(&self) -> (VectorWfst<WordId, W>, WordVocabulary)
    where
        W: Semiring + FromLogProb,
    {
        NgramWfstBuilder::<D, W>::new(Arc::new(self.clone())).build()
    }

    /// Moves the model into the builder and builds a WFST.
    fn into_wfst<W>(self) -> (VectorWfst<WordId, W>, WordVocabulary)
    where
        W: Semiring + FromLogProb,
    {
        NgramWfstBuilder::<D, W>::new(Arc::new(self)).build()
    }

    /// Clones the model and builds an `NgramTransducer` from the clone.
    fn to_ngram_transducer<W>(&self) -> (NgramTransducer<W>, WordVocabulary)
    where
        W: Semiring + FromLogProb + Clone,
    {
        NgramTransducerBuilder::<D, W>::new(Arc::new(self.clone())).build()
    }

    /// Moves the model into the builder and builds an `NgramTransducer`.
    fn into_ngram_transducer<W>(self) -> (NgramTransducer<W>, WordVocabulary)
    where
        W: Semiring + FromLogProb + Clone,
    {
        NgramTransducerBuilder::<D, W>::new(Arc::new(self)).build()
    }
}
/// Export for a shared model: the borrowing variants only clone the `Arc`
/// handle, the consuming variants hand the handle over as-is.
impl<D> NgramWfstExport<D> for Arc<NgramModel<D>>
where
    D: MutableMappedDictionary<Value = NgramEntry> + IterableDictionary,
{
    /// Builds a WFST from a new handle to the shared model.
    fn to_wfst<W>(&self) -> (VectorWfst<WordId, W>, WordVocabulary)
    where
        W: Semiring + FromLogProb,
    {
        NgramWfstBuilder::<D, W>::new(Arc::clone(self)).build()
    }

    /// Builds a WFST, consuming this handle.
    fn into_wfst<W>(self) -> (VectorWfst<WordId, W>, WordVocabulary)
    where
        W: Semiring + FromLogProb,
    {
        NgramWfstBuilder::<D, W>::new(self).build()
    }

    /// Builds an `NgramTransducer` from a new handle to the shared model.
    fn to_ngram_transducer<W>(&self) -> (NgramTransducer<W>, WordVocabulary)
    where
        W: Semiring + FromLogProb + Clone,
    {
        NgramTransducerBuilder::<D, W>::new(Arc::clone(self)).build()
    }

    /// Builds an `NgramTransducer`, consuming this handle.
    fn into_ngram_transducer<W>(self) -> (NgramTransducer<W>, WordVocabulary)
    where
        W: Semiring + FromLogProb + Clone,
    {
        NgramTransducerBuilder::<D, W>::new(self).build()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::corpus::PlaintextReader;
    use crate::ngram::TrainerBuilder;
    use liblevenshtein::dictionary::pathmap::PathMapDictionary;
    use lling_llang::wfst::Wfst;
    use std::io::Write;
    use tempfile::TempDir;

    /// Concrete model type used by every test in this module.
    type TestModel = NgramModel<PathMapDictionary<NgramEntry>>;

    /// Trains a small trigram model on a short, repetitive corpus written to
    /// a temporary file.
    fn create_test_model() -> TestModel {
        let content = "the quick brown fox the quick brown dog the lazy fox \
                       the quick brown fox the quick brown dog the lazy fox";

        let dir = TempDir::new().expect("Failed to create temp dir");
        let path = dir.path().join("test.txt");
        let mut file = std::fs::File::create(&path).expect("Failed to create test file");
        write!(file, "{}", content).expect("Failed to write test file");

        let reader = PlaintextReader::from_file(&path).expect("Failed to create reader");
        TrainerBuilder::new(PathMapDictionary::<NgramEntry>::new())
            .order(3)
            .train(reader)
            .expect("Training failed")
    }

    // ---- FromLogProb conversions -------------------------------------------

    #[test]
    fn test_from_log_prob_log_weight() {
        // log-prob -1 becomes weight value 1.
        let weight = LogWeight::from_log_prob(-1.0_f64);
        assert!((weight.value() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_from_log_prob_tropical_weight() {
        // log-prob -2 becomes weight value 2.
        let weight = TropicalWeight::from_log_prob(-2.0_f64);
        assert!((weight.value() - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_from_log_prob_probability_weight() {
        // log-prob 0 is probability 1.
        let certain = ProbabilityWeight::from_log_prob(0.0_f64);
        assert!((certain.value() - 1.0).abs() < 1e-10);

        // log-prob -1 is probability e^-1.
        let weight = ProbabilityWeight::from_log_prob(-1.0_f64);
        assert!((weight.value() - (-1.0_f64).exp()).abs() < 1e-10);
    }

    // ---- VectorWfst export -------------------------------------------------

    #[test]
    fn test_build_vocabulary() {
        let model = Arc::new(create_test_model());
        let vocab = NgramWfstBuilder::<_, LogWeight>::build_vocabulary(&model);
        assert!(vocab.len() >= 2);
        for word in ["the", "quick", "brown", "fox"].iter().copied() {
            assert!(vocab.contains(word));
        }
    }

    #[test]
    fn test_to_wfst_basic() {
        let (wfst, vocab) = create_test_model().to_wfst::<LogWeight>();
        assert!(wfst.num_states() > 0);
        assert!(wfst.start() != u32::MAX);
        assert!(vocab.len() > 2);
    }

    #[test]
    fn test_to_wfst_has_transitions() {
        let (wfst, _vocab) = create_test_model().to_wfst::<LogWeight>();
        assert!(!wfst.transitions(wfst.start()).is_empty());
    }

    #[test]
    fn test_to_wfst_final_states() {
        let (wfst, _vocab) = create_test_model().to_wfst::<LogWeight>();
        let state_count = wfst.num_states() as StateId;
        for state_id in 0..state_count {
            assert!(
                wfst.is_final(state_id),
                "State {} should be final",
                state_id
            );
        }
    }

    #[test]
    fn test_to_wfst_weights_finite() {
        let (wfst, _vocab) = create_test_model().to_wfst::<LogWeight>();
        let state_count = wfst.num_states() as StateId;
        for state_id in 0..state_count {
            for transition in wfst.transitions(state_id) {
                assert!(
                    transition.weight.value().is_finite(),
                    "Transition weight should be finite"
                );
            }
        }
    }

    #[test]
    fn test_tropical_weight_wfst() {
        let (wfst, _vocab) = create_test_model().to_wfst::<TropicalWeight>();
        assert!(wfst.num_states() > 0);
    }

    #[test]
    fn test_probability_weight_wfst() {
        let (wfst, _vocab) = create_test_model().to_wfst::<ProbabilityWeight>();
        assert!(wfst.num_states() > 0);
    }

    // ---- NgramTransducer export --------------------------------------------

    #[test]
    fn test_to_ngram_transducer_basic() {
        let (transducer, vocab) = create_test_model().to_ngram_transducer::<LogWeight>();
        assert!(transducer.fst.num_states() > 0);
        assert!(transducer.fst.start() != u32::MAX);
        assert!(vocab.len() > 2);
        assert_eq!(transducer.order(), 3);
    }

    #[test]
    fn test_to_ngram_transducer_has_unigram_transitions() {
        let (transducer, vocab) = create_test_model().to_ngram_transducer::<LogWeight>();

        // Count every arc that carries an input label, across all states.
        let state_count = transducer.fst.num_states() as StateId;
        let word_transitions: usize = (0..state_count)
            .map(|state_id| {
                transducer
                    .fst
                    .transitions(state_id)
                    .iter()
                    .filter(|trans| trans.input.is_some())
                    .count()
            })
            .sum();

        assert!(
            word_transitions >= vocab.len() - 2,
            "Expected at least {} word transitions, got {}",
            vocab.len() - 2,
            word_transitions
        );
    }

    #[test]
    fn test_to_ngram_transducer_final_states() {
        let (transducer, _vocab) = create_test_model().to_ngram_transducer::<LogWeight>();
        let state_count = transducer.fst.num_states() as StateId;
        for state_id in 0..state_count {
            assert!(
                transducer.fst.is_final(state_id),
                "State {} should be final",
                state_id
            );
        }
    }

    #[test]
    fn test_to_ngram_transducer_weights_finite() {
        let (transducer, _vocab) = create_test_model().to_ngram_transducer::<LogWeight>();
        let state_count = transducer.fst.num_states() as StateId;
        for state_id in 0..state_count {
            for transition in transducer.fst.transitions(state_id) {
                assert!(
                    transition.weight.value().is_finite(),
                    "Transition weight should be finite at state {}",
                    state_id
                );
            }
        }
    }

    #[test]
    fn test_to_ngram_transducer_tropical_weight() {
        let (transducer, _vocab) = create_test_model().to_ngram_transducer::<TropicalWeight>();
        assert!(transducer.fst.num_states() > 0);
        assert_eq!(transducer.order(), 3);
    }

    #[test]
    fn test_to_ngram_transducer_vocabulary_size() {
        let (transducer, vocab) = create_test_model().to_ngram_transducer::<LogWeight>();
        assert_eq!(transducer.vocabulary_size(), vocab.len());
    }

    #[test]
    fn test_to_ngram_transducer_has_backoff_arcs() {
        let (transducer, _vocab) = create_test_model().to_ngram_transducer::<LogWeight>();

        // An arc without an input label is an epsilon (backoff) arc.
        let state_count = transducer.fst.num_states() as StateId;
        let has_epsilon = (0..state_count).any(|state_id| {
            transducer
                .fst
                .transitions(state_id)
                .iter()
                .any(|trans| trans.input.is_none())
        });

        assert!(
            has_epsilon,
            "Transducer should have backoff epsilon transitions"
        );
    }

    #[test]
    fn test_into_ngram_transducer() {
        let (transducer, vocab) = create_test_model().into_ngram_transducer::<LogWeight>();
        assert!(transducer.fst.num_states() > 0);
        assert!(vocab.len() > 2);
    }

    #[test]
    fn test_arc_ngram_transducer() {
        let shared = Arc::new(create_test_model());
        let (transducer, vocab) = shared.to_ngram_transducer::<LogWeight>();
        assert!(transducer.fst.num_states() > 0);
        assert!(vocab.len() > 2);
    }
}