use std::collections::HashMap;
use std::marker::PhantomData;
use std::sync::{Arc, RwLock};
use lling_llang::semiring::Semiring;
use lling_llang::wfst::{LazyState, StateId, StateSource, WeightedTransition};
use smallvec::SmallVec;
#[allow(deprecated)]
use crate::ngram::{IterableDictionary, NgramEntry, NgramModel, NGRAM_SEPARATOR};
use liblevenshtein::dictionary::MutableMappedDictionary;
use super::vocabulary::{WordId, WordVocabulary};
use super::wfst_export::FromLogProb;
/// Key describing which n-gram context a lazily expanded WFST state represents.
///
/// Keys are hashed and compared structurally, so two states built from the same
/// variant and the same word ids are the same state.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum NgramHistoryKey {
    /// The distinguished initial state.
    Start,
    /// A backoff state carrying the context `history`.
    Backoff { history: Vec<WordId> },
    /// A regular context state whose recent words are `words`.
    History { words: Vec<WordId> },
}

impl NgramHistoryKey {
    /// Builds the [`NgramHistoryKey::Start`] key.
    pub fn start() -> Self {
        NgramHistoryKey::Start
    }

    /// Builds a [`NgramHistoryKey::Backoff`] key for `history`.
    pub fn backoff(history: Vec<WordId>) -> Self {
        NgramHistoryKey::Backoff { history }
    }

    /// Builds a [`NgramHistoryKey::History`] key for `words`.
    pub fn history(words: Vec<WordId>) -> Self {
        NgramHistoryKey::History { words }
    }
}
#[derive(Debug)]
pub struct NgramStateRegistry {
id_to_history: Vec<NgramHistoryKey>,
history_to_id: HashMap<NgramHistoryKey, StateId>,
next_id: StateId,
}
impl NgramStateRegistry {
pub fn new() -> Self {
let mut registry = Self {
id_to_history: Vec::with_capacity(64),
history_to_id: HashMap::with_capacity(64),
next_id: 0,
};
registry.register(NgramHistoryKey::Start);
registry.register(NgramHistoryKey::Backoff { history: vec![] });
registry
}
pub fn register(&mut self, key: NgramHistoryKey) -> StateId {
if let Some(&id) = self.history_to_id.get(&key) {
return id;
}
let id = self.next_id;
self.next_id += 1;
self.id_to_history.push(key.clone());
self.history_to_id.insert(key, id);
id
}
pub fn get_history(&self, state: StateId) -> Option<&NgramHistoryKey> {
self.id_to_history.get(state as usize)
}
pub fn get_state(&self, key: &NgramHistoryKey) -> Option<StateId> {
self.history_to_id.get(key).copied()
}
pub fn len(&self) -> usize {
self.id_to_history.len()
}
pub fn is_empty(&self) -> bool {
self.id_to_history.is_empty()
}
}
impl Default for NgramStateRegistry {
fn default() -> Self {
Self::new()
}
}
/// Lazily expanded WFST state source backed by an [`NgramModel`].
///
/// States are created on demand and recorded in an [`NgramStateRegistry`]. The
/// registry sits behind an `RwLock` so state expansion, which only has `&self`,
/// can still register newly discovered states.
pub struct NgramStateSource<
    D: MutableMappedDictionary<Value = NgramEntry>,
    W: Semiring + FromLogProb,
> {
    /// Shared model queried for word log-probabilities.
    model: Arc<NgramModel<D>>,
    /// Shared word <-> id mapping used for arc labels.
    vocabulary: Arc<WordVocabulary>,
    /// State id <-> history key mapping, grown as states are discovered.
    state_registry: RwLock<NgramStateRegistry>,
    /// Ties the weight semiring type to the source without storing a value.
    _weight: PhantomData<W>,
}
impl<D, W> NgramStateSource<D, W>
where
    D: MutableMappedDictionary<Value = NgramEntry> + IterableDictionary,
    W: Semiring + FromLogProb,
{
    /// Id of [`NgramHistoryKey::Start`]; `NgramStateRegistry::new` registers it first.
    const START_STATE: StateId = 0;

    /// Id of the empty-history [`NgramHistoryKey::Backoff`] state;
    /// `NgramStateRegistry::new` registers it second.
    const EMPTY_BACKOFF_STATE: StateId = 1;

    /// Number of leading vocabulary entries that never become word arcs.
    // NOTE(review): presumably reserved/special word ids — confirm against `WordVocabulary`.
    const SKIPPED_VOCAB_ENTRIES: usize = 2;

    /// Creates a source over `model`, building the vocabulary from the model's
    /// single-token keys (see [`Self::build_vocabulary`]).
    pub fn new(model: Arc<NgramModel<D>>) -> Self {
        let vocabulary = Arc::new(Self::build_vocabulary(&model));
        Self::with_vocabulary(model, vocabulary)
    }

    /// Creates a source over `model` using a caller-supplied `vocabulary`.
    ///
    /// The first `SKIPPED_VOCAB_ENTRIES` (2) entries of the vocabulary are never
    /// emitted as word arcs. NOTE(review): assumes `vocabulary` follows the same
    /// layout as one produced by `build_vocabulary` — confirm at call sites.
    pub fn with_vocabulary(model: Arc<NgramModel<D>>, vocabulary: Arc<WordVocabulary>) -> Self {
        Self {
            model,
            vocabulary,
            state_registry: RwLock::new(NgramStateRegistry::new()),
            _weight: PhantomData,
        }
    }

    /// The vocabulary mapping words to arc labels.
    pub fn vocabulary(&self) -> &WordVocabulary {
        &self.vocabulary
    }

    /// The underlying n-gram model.
    pub fn model(&self) -> &NgramModel<D> {
        &self.model
    }

    /// Builds a vocabulary from every trie key that contains no `NGRAM_SEPARATOR`
    /// (i.e. single-token entries).
    // Some of the n-gram APIs used here are deprecated; the allow keeps the build
    // warning-free until a replacement is adopted.
    #[allow(deprecated)]
    fn build_vocabulary(model: &NgramModel<D>) -> WordVocabulary {
        let mut vocab = WordVocabulary::with_capacity(model.vocab_size());
        for (key, _entry) in model.trie().iter_entries() {
            if !key.contains(NGRAM_SEPARATOR) {
                vocab.add_word(&key);
            }
        }
        vocab
    }

    /// Returns the id for `key`, registering it if it is new.
    ///
    /// The shared read lock is tried first so cache hits do not serialize. On a
    /// miss the write lock is taken; `NgramStateRegistry::register` re-checks the
    /// map, so a racing registration of the same key still yields one id.
    fn get_or_register_state(&self, key: NgramHistoryKey) -> StateId {
        {
            let registry = self.state_registry.read().expect("Lock poisoned");
            if let Some(id) = registry.get_state(&key) {
                return id;
            }
        }
        let mut registry = self.state_registry.write().expect("Lock poisoned");
        registry.register(key)
    }

    /// Resolves word ids to their surface strings for `NgramModel::log_prob`.
    /// Ids unknown to the vocabulary are silently skipped.
    fn history_words(&self, ids: &[WordId]) -> Vec<String> {
        ids.iter()
            .filter_map(|&id| self.vocabulary.get_word(id).map(|s| s.to_string()))
            .collect()
    }

    /// Appends one arc from `source_state` for every vocabulary word (after the
    /// skipped leading entries) whose log-probability given `history` is finite.
    ///
    /// The target of each arc is the [`NgramHistoryKey::History`] formed by
    /// appending the word to `history` and keeping only the most recent
    /// `order - 1` words.
    fn push_word_transitions(
        &self,
        source_state: StateId,
        history: &[WordId],
        transitions: &mut SmallVec<[WeightedTransition<WordId, W>; 4]>,
    ) {
        let history_strs = self.history_words(history);
        let history_refs: Vec<&str> = history_strs.iter().map(String::as_str).collect();
        // `saturating_sub` keeps a degenerate order-0 model from underflowing.
        let max_context = self.model.order().saturating_sub(1);

        for (word, word_id) in self.vocabulary.iter().skip(Self::SKIPPED_VOCAB_ENTRIES) {
            let log_prob = self.model.log_prob(word, &history_refs);
            if !log_prob.is_finite() {
                continue;
            }

            let mut target_history = Vec::with_capacity(history.len() + 1);
            target_history.extend_from_slice(history);
            target_history.push(word_id);
            if target_history.len() > max_context {
                // Newest words are at the end; drop the oldest ones.
                let excess = target_history.len() - max_context;
                target_history.drain(..excess);
            }

            let target_state = self.get_or_register_state(NgramHistoryKey::History {
                words: target_history,
            });
            transitions.push(WeightedTransition::new(
                source_state,
                Some(word_id),
                Some(word_id),
                target_state,
                W::from_log_prob(log_prob),
            ));
        }
    }

    /// Expands the start state: final with weight one, with a single unlabeled
    /// (`None`/`None`) arc into the empty-history backoff state.
    fn compute_start_state(&self) -> LazyState<WordId, W> {
        let mut transitions: SmallVec<[WeightedTransition<WordId, W>; 4]> = SmallVec::new();
        transitions.push(WeightedTransition::new(
            Self::START_STATE,
            None,
            None,
            Self::EMPTY_BACKOFF_STATE,
            W::one(),
        ));
        LazyState::final_state(W::one(), transitions)
    }

    /// Expands a [`NgramHistoryKey::Backoff`] state: final with weight one, plus
    /// one word arc per word with a finite probability given `history`.
    fn compute_backoff_state(&self, history: &[WordId]) -> LazyState<WordId, W> {
        let mut transitions: SmallVec<[WeightedTransition<WordId, W>; 4]> = SmallVec::new();
        // Loop-invariant: the id of the state being expanded. The empty-history
        // backoff state is pre-registered, so it needs no registry lookup.
        let source_state = if history.is_empty() {
            Self::EMPTY_BACKOFF_STATE
        } else {
            self.get_or_register_state(NgramHistoryKey::Backoff {
                history: history.to_vec(),
            })
        };
        self.push_word_transitions(source_state, history, &mut transitions);
        LazyState::final_state(W::one(), transitions)
    }

    /// Expands a [`NgramHistoryKey::History`] state: final with weight one, word
    /// arcs as in [`Self::push_word_transitions`], and, for non-empty `words`, an
    /// unlabeled arc to the state for the history with its oldest word dropped.
    fn compute_history_state(&self, words: &[WordId]) -> LazyState<WordId, W> {
        let mut transitions: SmallVec<[WeightedTransition<WordId, W>; 4]> = SmallVec::new();
        let source_state = self.get_or_register_state(NgramHistoryKey::History {
            words: words.to_vec(),
        });
        self.push_word_transitions(source_state, words, &mut transitions);

        if let Some((_, shorter)) = words.split_first() {
            let backoff_key = if shorter.is_empty() {
                NgramHistoryKey::Backoff { history: Vec::new() }
            } else {
                NgramHistoryKey::History {
                    words: shorter.to_vec(),
                }
            };
            let backoff_state = self.get_or_register_state(backoff_key);
            // NOTE(review): this arc carries `W::one()`, while every word already gets
            // a direct arc weighted by `model.log_prob`. If `log_prob` already applies
            // smoothing/backoff, paths through this arc skip the backoff penalty —
            // verify against the model's semantics.
            transitions.push(WeightedTransition::new(
                source_state,
                None,
                None,
                backoff_state,
                W::one(),
            ));
        }
        LazyState::final_state(W::one(), transitions)
    }
}
impl<D, W> Clone for NgramStateSource<D, W>
where
D: MutableMappedDictionary<Value = NgramEntry> + IterableDictionary,
W: Semiring + FromLogProb,
{
fn clone(&self) -> Self {
Self {
model: Arc::clone(&self.model),
vocabulary: Arc::clone(&self.vocabulary),
state_registry: RwLock::new(NgramStateRegistry::new()),
_weight: PhantomData,
}
}
}
impl<D, W> StateSource<WordId, W> for NgramStateSource<D, W>
where
    D: MutableMappedDictionary<Value = NgramEntry> + IterableDictionary + Send + Sync,
    W: Semiring + FromLogProb,
{
    /// Expands `state` by looking up the history key it was registered with and
    /// dispatching on its variant. Ids unknown to the registry expand to a
    /// non-final state with no outgoing arcs.
    fn compute_state(&self, state: StateId) -> LazyState<WordId, W> {
        // The read guard is a temporary of this statement, so the lock is released
        // before expansion, which may take the write lock to register new states.
        let key = self
            .state_registry
            .read()
            .expect("Lock poisoned")
            .get_history(state)
            .cloned();

        match key {
            None => LazyState::non_final(SmallVec::new()),
            Some(NgramHistoryKey::Start) => self.compute_start_state(),
            Some(NgramHistoryKey::Backoff { history }) => self.compute_backoff_state(&history),
            Some(NgramHistoryKey::History { words }) => self.compute_history_state(&words),
        }
    }

    /// Always id 0: `NgramStateRegistry::new` registers `NgramHistoryKey::Start`
    /// before any other key.
    fn start(&self) -> StateId {
        0
    }

    /// States are discovered lazily, so no total is known in advance.
    fn num_states_hint(&self) -> Option<usize> {
        None
    }
}
/// Conversion of an n-gram model into a lazily expanded WFST state source.
pub trait NgramLazyWfst<D: MutableMappedDictionary<Value = NgramEntry> + IterableDictionary> {
    /// Builds an [`NgramStateSource`] over this model whose arc weights are of
    /// semiring type `W`.
    fn to_lazy_wfst_source<W: Semiring + FromLogProb>(&self) -> NgramStateSource<D, W>;
}
impl<D> NgramLazyWfst<D> for NgramModel<D>
where
D: MutableMappedDictionary<Value = NgramEntry> + IterableDictionary,
{
fn to_lazy_wfst_source<W>(&self) -> NgramStateSource<D, W>
where
W: Semiring + FromLogProb,
{
NgramStateSource::new(Arc::new(self.clone()))
}
}
impl<D: MutableMappedDictionary<Value = NgramEntry> + IterableDictionary> NgramLazyWfst<D>
    for Arc<NgramModel<D>>
{
    /// Builds a source that shares this already-`Arc`ed model; only the
    /// reference count is bumped.
    fn to_lazy_wfst_source<W: Semiring + FromLogProb>(&self) -> NgramStateSource<D, W> {
        let shared = Self::clone(self);
        NgramStateSource::new(shared)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::corpus::PlaintextReader;
    use crate::ngram::TrainerBuilder;
    use liblevenshtein::dictionary::pathmap::PathMapDictionary;
    use lling_llang::semiring::LogWeight;
    use lling_llang::wfst::{LazyWfst, LazyWfstWrapper};
    use std::io::Write;
    use tempfile::TempDir;

    type TestDict = PathMapDictionary<NgramEntry>;

    /// Trains an order-3 model on a one-line corpus written to a temporary file.
    fn create_test_model() -> NgramModel<TestDict> {
        const CORPUS: &str = "the quick brown fox the quick brown dog";
        let tmp_dir = TempDir::new().expect("Failed to create temp dir");
        let corpus_path = tmp_dir.path().join("test.txt");
        let mut corpus_file =
            std::fs::File::create(&corpus_path).expect("Failed to create test file");
        write!(corpus_file, "{}", CORPUS).expect("Failed to write test file");
        let reader = PlaintextReader::from_file(&corpus_path).expect("Failed to create reader");
        TrainerBuilder::new(TestDict::new())
            .order(3)
            .train(reader)
            .expect("Training failed")
    }

    /// Builds a log-semiring lazy source over a freshly trained test model.
    fn log_source() -> NgramStateSource<TestDict, LogWeight> {
        create_test_model().to_lazy_wfst_source()
    }

    #[test]
    fn test_state_registry() {
        let mut registry = NgramStateRegistry::new();
        let empty_backoff = NgramHistoryKey::Backoff { history: vec![] };
        assert_eq!(registry.get_state(&NgramHistoryKey::Start), Some(0));
        assert_eq!(registry.get_state(&empty_backoff), Some(1));

        let one_word = registry.register(NgramHistoryKey::History { words: vec![5] });
        let two_words = registry.register(NgramHistoryKey::History { words: vec![5, 6] });
        assert_eq!(one_word, 2);
        assert_eq!(two_words, 3);
        assert_eq!(
            registry.register(NgramHistoryKey::History { words: vec![5] }),
            one_word
        );
    }

    #[test]
    fn test_ngram_state_source_creation() {
        let source = log_source();
        assert_eq!(source.start(), 0);
        // More entries than the two leading ones the source skips for word arcs.
        assert!(source.vocabulary().len() > 2);
    }

    #[test]
    fn test_lazy_wfst_start_state() {
        let mut lazy = LazyWfstWrapper::new(log_source());
        lazy.expand(0);
        assert!(lazy.is_expanded(0));
        assert!(!lazy.transitions_lazy(0).is_empty());
    }

    #[test]
    fn test_lazy_wfst_expansion() {
        let mut lazy = LazyWfstWrapper::new(log_source());
        lazy.expand(0);
        let before = lazy.computed_states();
        lazy.expand(1);
        assert!(lazy.computed_states() > before);
    }
}