use super::entry::NgramEntry;
use libdictenstein::persistent_artrie::SharedTrieAccess;
use liblevenshtein::dictionary::{MappedDictionaryNode, MutableMappedDictionary};
use std::marker::PhantomData;
use std::sync::Arc;
/// A [`MutableMappedDictionary`] of [`NgramEntry`] values whose complete
/// contents can be enumerated as owned `(key, entry)` pairs.
///
/// Implementations in this module either wrap the backend's own iterator or
/// take an eager snapshot into a `Vec`; callers should not rely on either.
pub trait IterableDictionary: MutableMappedDictionary<Value = NgramEntry> {
/// Returns every stored key together with an owned copy of its entry.
fn iter_all(&self) -> Box<dyn Iterator<Item = (String, NgramEntry)> + '_>;
}
impl IterableDictionary
    for liblevenshtein::dictionary::dynamic_dawg_char::DynamicDawgChar<NgramEntry>
{
    /// Boxes the DAWG's own `iter()`; no entries are collected at this level.
    fn iter_all(&self) -> Box<dyn Iterator<Item = (String, NgramEntry)> + '_> {
        let entries = self.iter();
        Box::new(entries)
    }
}
impl IterableDictionary for liblevenshtein::dictionary::pathmap::PathMapDictionary<NgramEntry> {
    /// Boxes the path map's own `iter()`; no entries are collected at this level.
    fn iter_all(&self) -> Box<dyn Iterator<Item = (String, NgramEntry)> + '_> {
        let entries = self.iter();
        Box::new(entries)
    }
}
impl IterableDictionary for libdictenstein::persistent_artrie_char::SharedCharARTrie<NgramEntry> {
    /// Takes an eager snapshot of every `(key, entry)` pair.
    ///
    /// The value returned by `self.read()` only lives for the statement that
    /// collects, so the returned iterator owns its data and does not keep
    /// whatever `read()` hands out alive.
    fn iter_all(&self) -> Box<dyn Iterator<Item = (String, NgramEntry)> + '_> {
        let snapshot = self
            .read()
            .iter_with_values()
            .collect::<Vec<(String, NgramEntry)>>();
        Box::new(snapshot.into_iter())
    }
}
impl<D> IterableDictionary for super::vocabulary_indexed::VocabularyIndexedDictionary<D>
where
    D: IterableDictionary,
    D::Node: MappedDictionaryNode<Unit = char>,
{
    /// Enumerates the backend and turns each index-encoded key back into
    /// words joined by this dictionary's delimiter.
    ///
    /// Entries are skipped when their key decodes to no indices, or when any
    /// decoded index has no term in the vocabulary. The vocabulary read guard
    /// is held for the whole backend pass and released before returning, so
    /// the returned iterator owns a fully materialised `Vec`.
    fn iter_all(&self) -> Box<dyn Iterator<Item = (String, NgramEntry)> + '_> {
        let separator = self.delimiter().to_string();
        let vocab = self.vocabulary().read();
        let mut rows: Vec<(String, NgramEntry)> = Vec::new();
        for (encoded, entry) in self.backend().iter_all() {
            let ids = super::vocabulary_indexed::decode_key_to_indices(&encoded);
            if ids.is_empty() {
                continue;
            }
            // Stops at the first index without a term; the whole entry is dropped.
            let words: Option<Vec<_>> = ids.into_iter().map(|id| vocab.get_term(id)).collect();
            if let Some(words) = words {
                rows.push((words.join(&separator), entry));
            }
        }
        drop(vocab);
        Box::new(rows.into_iter())
    }
}
/// Separator historically used to join n-gram tokens into one dictionary key.
///
/// Deprecated because the joined key is ambiguous whenever a token itself
/// contains `'|'`: splitting the key back apart yields the wrong tokens.
#[deprecated(
since = "0.3.0",
note = "Use vocabulary-indexed encoding via crate::ngram::vocabulary instead. \
Pipe-separated keys can corrupt data if tokens contain '|'."
)]
pub const NGRAM_SEPARATOR: char = '|';
/// Crate-internal copy of the legacy separator used by the legacy key encoding.
// Presumably duplicated so internal callers avoid the deprecation warning that
// referencing `NGRAM_SEPARATOR` would trigger.
pub(crate) const LEGACY_NGRAM_SEPARATOR: char = '|';
/// N-gram count store layered over a shared, mutable mapped dictionary.
///
/// The backend lives behind an [`Arc`]. All mutating methods on this type take
/// `&self`, so handles that share the same `Arc` see each other's writes.
#[derive(serde::Serialize, serde::Deserialize)]
// Explicit bound replaces the bounds serde would otherwise infer for `D`.
#[serde(bound = "D: serde::Serialize + serde::de::DeserializeOwned")]
pub struct NgramTrie<D>
where
D: MutableMappedDictionary<Value = NgramEntry>,
{
// Shared backend mapping each encoded n-gram key to its `NgramEntry`.
dictionary: Arc<D>,
// Highest n-gram order this trie was configured for; not enforced on insert.
max_order: usize,
// NOTE(review): `D` is already owned through `Arc<D>`, so this marker looks
// redundant; it is skipped by serde and always holds `PhantomData`.
#[serde(skip)]
_marker: PhantomData<D>,
}
impl<D> NgramTrie<D>
where
    D: MutableMappedDictionary<Value = NgramEntry>,
{
    /// Creates a trie that takes ownership of `dictionary` and is configured
    /// for n-grams of up to `max_order` tokens.
    pub fn new(dictionary: D, max_order: usize) -> Self {
        Self::from_arc(Arc::new(dictionary), max_order)
    }

    /// Creates a trie over an already-shared backend.
    ///
    /// The `Arc` is stored as-is, so every other holder of it observes the
    /// counts written through this trie.
    pub fn from_arc(dictionary: Arc<D>, max_order: usize) -> Self {
        Self {
            dictionary,
            max_order,
            _marker: PhantomData,
        }
    }

    /// Highest n-gram order this trie was configured for.
    ///
    /// Informational only: none of the insertion methods reject longer n-grams.
    #[inline]
    pub fn max_order(&self) -> usize {
        self.max_order
    }

    /// Borrows the backing dictionary.
    #[inline]
    pub fn dictionary(&self) -> &D {
        &self.dictionary
    }

    /// Returns another shared handle to the backing dictionary.
    #[inline]
    pub fn dictionary_arc(&self) -> Arc<D> {
        Arc::clone(&self.dictionary)
    }

    /// Joins `tokens` with `'|'` to form a dictionary key.
    ///
    /// Deprecated because the result is ambiguous when a token contains `'|'`.
    #[inline]
    #[deprecated(
        since = "0.3.0",
        note = "Use vocabulary::encode_ngram_key() instead. \
                Pipe-separated keys can corrupt data if tokens contain '|'."
    )]
    pub fn encode_key(tokens: &[&str]) -> String {
        Self::encode_key_legacy(tokens)
    }

    /// Joins `tokens` with [`LEGACY_NGRAM_SEPARATOR`] to form a dictionary key.
    ///
    /// An empty slice yields the empty key `""`. Tokens that already contain
    /// the separator produce keys that cannot be split back unambiguously.
    #[inline]
    pub(crate) fn encode_key_legacy(tokens: &[&str]) -> String {
        // Borrow the separator as a `&str` from a stack buffer rather than
        // heap-allocating a one-character `String` on every call.
        let mut buf = [0u8; 4];
        let separator: &str = LEGACY_NGRAM_SEPARATOR.encode_utf8(&mut buf);
        tokens.join(separator)
    }

    /// Increments the count of the n-gram `tokens`, inserting it with a count
    /// of 1 if it is absent.
    ///
    /// Returns the backend's `update_or_insert` result.
    pub fn insert(&self, tokens: &[&str]) -> bool {
        self.insert_with_key(&Self::encode_key_legacy(tokens))
    }

    /// Same as [`Self::insert`], but for an already-encoded `key`.
    pub fn insert_with_key(&self, key: &str) -> bool {
        self.dictionary
            .update_or_insert(key, NgramEntry::new(1), |entry| entry.increment())
    }

    /// Stores `count` for the n-gram `tokens` via the backend's `insert_with_value`.
    ///
    /// NOTE(review): unlike [`Self::insert`] this does not increment. Presumably
    /// an existing entry is replaced; confirm against the backend.
    pub fn insert_with_count(&self, tokens: &[&str], count: u64) -> bool {
        self.insert_with_key_and_count(&Self::encode_key_legacy(tokens), count)
    }

    /// Same as [`Self::insert_with_count`], but for an already-encoded `key`.
    pub fn insert_with_key_and_count(&self, key: &str, count: u64) -> bool {
        self.dictionary
            .insert_with_value(key, NgramEntry::new(count))
    }

    /// Looks up the entry for the n-gram `tokens`.
    pub fn get(&self, tokens: &[&str]) -> Option<NgramEntry> {
        self.get_by_key(&Self::encode_key_legacy(tokens))
    }

    /// Looks up the entry stored under an already-encoded `key`.
    pub fn get_by_key(&self, key: &str) -> Option<NgramEntry> {
        self.dictionary.get_value(key)
    }

    /// Reports whether the n-gram `tokens` has an entry.
    pub fn contains(&self, tokens: &[&str]) -> bool {
        self.contains_key(&Self::encode_key_legacy(tokens))
    }

    /// Reports whether an already-encoded `key` has an entry.
    pub fn contains_key(&self, key: &str) -> bool {
        self.dictionary.contains(key)
    }

    /// Occurrence count of the n-gram `tokens`, or 0 if it is absent.
    #[inline]
    pub fn count(&self, tokens: &[&str]) -> u64 {
        self.get(tokens).map_or(0, |entry| entry.count())
    }

    /// Occurrence count stored under an already-encoded `key`, or 0 if absent.
    #[inline]
    pub fn count_by_key(&self, key: &str) -> u64 {
        self.get_by_key(key).map_or(0, |entry| entry.count())
    }

    /// Sets the continuation count of the n-gram `tokens`.
    ///
    /// If the n-gram is absent, it is created with an occurrence count of 0.
    pub fn update_continuation_count(&self, tokens: &[&str], continuation_count: u32) {
        self.update_continuation_count_by_key(
            &Self::encode_key_legacy(tokens),
            continuation_count,
        );
    }

    /// Same as [`Self::update_continuation_count`], but for an already-encoded `key`.
    pub fn update_continuation_count_by_key(&self, key: &str, continuation_count: u32) {
        self.dictionary.update_or_insert(
            key,
            NgramEntry::with_stats(0, continuation_count, 0),
            |entry| entry.set_continuation_count(continuation_count),
        );
    }

    /// Sets the unique-continuation count of the n-gram `tokens`.
    ///
    /// If the n-gram is absent, it is created with an occurrence count of 0.
    pub fn update_unique_continuations(&self, tokens: &[&str], unique_continuations: u32) {
        self.update_unique_continuations_by_key(
            &Self::encode_key_legacy(tokens),
            unique_continuations,
        );
    }

    /// Same as [`Self::update_unique_continuations`], but for an already-encoded `key`.
    pub fn update_unique_continuations_by_key(&self, key: &str, unique_continuations: u32) {
        self.dictionary.update_or_insert(
            key,
            NgramEntry::with_stats(0, 0, unique_continuations),
            |entry| entry.set_unique_continuations(unique_continuations),
        );
    }

    /// Number of stored entries.
    ///
    /// Returns 0 when the backend's `len()` yields `None`.
    pub fn len(&self) -> usize {
        self.dictionary.len().unwrap_or(0)
    }

    /// Reports whether the trie holds no entries.
    ///
    /// This includes the case where the backend cannot report its size.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Iterates over every `(key, entry)` pair in the backing dictionary.
    pub fn iter_entries(&self) -> impl Iterator<Item = (String, NgramEntry)> + '_
    where
        D: IterableDictionary,
    {
        self.dictionary.iter_all()
    }
}
impl<D> Clone for NgramTrie<D>
where
D: MutableMappedDictionary<Value = NgramEntry>,
{
fn clone(&self) -> Self {
Self {
dictionary: Arc::clone(&self.dictionary),
max_order: self.max_order,
_marker: PhantomData,
}
}
}
/// Order-sensitive 64-bit hash of an n-gram's tokens.
///
/// Each token is hashed with its position as the seed. The result is mixed into
/// an accumulator via wrap-around add-then-multiply by the 64-bit golden-ratio
/// constant. The final value XOR-folds the high half into the low half.
#[inline]
// NOTE(review): dead-code allowance kept as-is; the tests in this file exercise
// this function, but confirm whether non-test callers exist.
#[allow(dead_code)]
pub fn hash_ngram_key(tokens: &[&str]) -> u64 {
    use crate::util::hash::safe_hash_with_seed;
    // 2^64 / phi, the usual Fibonacci-hashing multiplier.
    const GOLDEN_RATIO: u64 = 0x9e3779b97f4a7c15;
    // ASCII bytes of "ngram_se".
    const NGRAM_SEED: u64 = 0x6e6772616d5f7365;
    let mixed = tokens
        .iter()
        .enumerate()
        .fold(NGRAM_SEED, |acc, (position, token)| {
            let per_token = safe_hash_with_seed(token.as_bytes(), position as u64);
            acc.wrapping_add(per_token).wrapping_mul(GOLDEN_RATIO)
        });
    mixed ^ (mixed >> 32)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Legacy keys join tokens with `'|'`; an empty token list yields `""`.
    #[test]
    fn test_encode_key_legacy() {
        type Trie = NgramTrie<liblevenshtein::dictionary::pathmap::PathMapDictionary<NgramEntry>>;
        assert_eq!(Trie::encode_key_legacy(&[]), "");
        assert_eq!(Trie::encode_key_legacy(&["the"]), "the");
        assert_eq!(Trie::encode_key_legacy(&["the", "quick"]), "the|quick");
        assert_eq!(
            Trie::encode_key_legacy(&["the", "quick", "brown"]),
            "the|quick|brown"
        );
    }

    /// Characterizes the known flaw of the legacy encoding: a token containing
    /// `'|'` cannot be recovered by splitting the key.
    #[test]
    fn test_legacy_encoding_pipe_bug() {
        type Trie = NgramTrie<liblevenshtein::dictionary::pathmap::PathMapDictionary<NgramEntry>>;
        let tokens = ["foo|bar", "baz"];
        let encoded = Trie::encode_key_legacy(&tokens);
        let decoded: Vec<_> = encoded.split(LEGACY_NGRAM_SEPARATOR).collect();
        assert_eq!(decoded.len(), 3, "Bug: pipe in token causes wrong split");
        assert_eq!(
            decoded,
            ["foo", "bar", "baz"],
            "Bug: original tokens corrupted"
        );
    }

    /// Token order must affect the hash.
    #[test]
    fn test_hash_ngram_key_order_matters() {
        let hash1 = hash_ngram_key(&["a", "b"]);
        let hash2 = hash_ngram_key(&["b", "a"]);
        assert_ne!(
            hash1, hash2,
            "Different orderings should have different hashes"
        );
    }

    /// Hashing the same tokens twice gives the same value.
    #[test]
    fn test_hash_ngram_key_deterministic() {
        let hash1 = hash_ngram_key(&["the", "quick", "brown"]);
        let hash2 = hash_ngram_key(&["the", "quick", "brown"]);
        assert_eq!(hash1, hash2, "Same input should produce same hash");
    }

    /// Entries written to a persistent char ARTrie come back from `iter_all`.
    #[test]
    fn iter_all_shared_char_artrie_roundtrip() {
        use libdictenstein::persistent_artrie_char::{PersistentARTrieChar, SharedCharARTrie};
        use std::collections::HashMap;
        use std::sync::Arc;
        let dir = tempfile::tempdir().expect("tempdir");
        let trie = PersistentARTrieChar::<NgramEntry>::create(dir.path().join("c.artrie"))
            .expect("create counts trie");
        let backend: SharedCharARTrie<NgramEntry> = Arc::new(trie);
        backend.insert_with_value("ab", NgramEntry::new(3));
        backend.insert_with_value("cd", NgramEntry::with_stats(5, 2, 1));
        let got: HashMap<String, u64> = backend.iter_all().map(|(k, v)| (k, v.count())).collect();
        assert_eq!(got.get("ab"), Some(&3));
        assert_eq!(got.get("cd"), Some(&5));
        assert_eq!(got.len(), 2);
    }

    /// Vocabulary-indexed keys are decoded back into delimiter-joined words.
    #[test]
    fn iter_all_vocab_indexed_reconstructs_words() {
        use crate::ngram::vocabulary::create_vocabulary;
        use crate::ngram::vocabulary_indexed::VocabularyIndexedDictionary;
        use libdictenstein::persistent_artrie_char::{PersistentARTrieChar, SharedCharARTrie};
        use std::collections::HashMap;
        use std::sync::Arc;
        let dir = tempfile::tempdir().expect("tempdir");
        let vocab = create_vocabulary(&dir.path().join("v.artrie")).expect("vocab");
        let counts: SharedCharARTrie<NgramEntry> = Arc::new(
            PersistentARTrieChar::<NgramEntry>::create(dir.path().join("c.artrie"))
                .expect("counts"),
        );
        let dict = VocabularyIndexedDictionary::with_delimiter(counts, vocab, '|');
        dict.insert_with_value("the|quick|brown", NgramEntry::new(2));
        dict.insert_with_value("the|lazy", NgramEntry::new(5));
        let got: HashMap<String, u64> = dict.iter_all().map(|(k, v)| (k, v.count())).collect();
        assert_eq!(
            got.get("the|quick|brown"),
            Some(&2),
            "trigram reconstructed"
        );
        assert_eq!(got.get("the|lazy"), Some(&5), "bigram reconstructed");
        assert_eq!(got.len(), 2);
    }

    /// A backend key whose index has no vocabulary term is silently skipped.
    #[test]
    fn iter_all_vocab_indexed_skips_missing_index() {
        use crate::ngram::vocabulary::{create_vocabulary, encode_varint};
        use crate::ngram::vocabulary_indexed::{
            decode_key_to_indices, VocabularyIndexedDictionary,
        };
        use libdictenstein::persistent_artrie_char::{PersistentARTrieChar, SharedCharARTrie};
        use std::sync::Arc;
        let dir = tempfile::tempdir().expect("tempdir");
        let vocab = create_vocabulary(&dir.path().join("v.artrie")).expect("vocab");
        let counts: SharedCharARTrie<NgramEntry> = Arc::new(
            PersistentARTrieChar::<NgramEntry>::create(dir.path().join("c.artrie"))
                .expect("counts"),
        );
        let dict = VocabularyIndexedDictionary::with_delimiter(counts.clone(), vocab, '|');
        dict.insert_with_value("alpha|beta", NgramEntry::new(1));
        let mut buf = Vec::new();
        encode_varint(9999, &mut buf);
        // Map each varint byte to the char with the same code point to build a
        // raw key, and confirm it decodes to the single index 9999.
        let bogus_key: String = buf.iter().map(|&b| b as char).collect();
        assert_eq!(decode_key_to_indices(&bogus_key), vec![9999]);
        // Write straight to the backend, bypassing the vocabulary. With only two
        // words interned, index 9999 should not resolve to any term.
        counts.insert_with_value(&bogus_key, NgramEntry::new(7));
        let got: Vec<String> = dict.iter_all().map(|(k, _)| k).collect();
        assert_eq!(got, vec!["alpha|beta".to_string()]);
    }
}