use super::metadata_filtering_zipper::{MetadataFilteringZipper, METADATA_PREFIX};
use super::vocabulary::{decode_varint, encode_varint, SharedVocabARTrie};
use liblevenshtein::dictionary::{
Dictionary, DictionaryNode, MappedDictionary, MappedDictionaryNode, MutableMappedDictionary,
SyncStrategy,
};
/// Builds a `String` in which every input byte becomes the Unicode scalar
/// with the same numeric value (U+0000 through U+00FF).
///
/// The mapping is one `char` per byte, so the result has exactly
/// `bytes.len()` characters.
#[inline]
fn bytes_to_latin1(bytes: &[u8]) -> String {
    let mut text = String::with_capacity(bytes.len());
    for &byte in bytes {
        // `u8 as char` is the lossless byte -> U+00xx conversion.
        text.push(byte as char);
    }
    text
}
/// Converts a string back into bytes, producing one byte per `char`.
///
/// Each scalar is narrowed with `as u8`, so only the low eight bits of the
/// code point are kept; scalars above U+00FF are silently truncated rather
/// than rejected.
#[inline]
fn latin1_to_bytes(s: &str) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(s.len());
    for scalar in s.chars() {
        // Deliberately truncating cast: keeps the low byte of the code point.
        bytes.push(scalar as u8);
    }
    bytes
}
/// Decodes a Latin-1 encoded key into the sequence of varint values it holds.
///
/// The key is first turned back into bytes with `latin1_to_bytes`, then
/// `decode_varint` is applied repeatedly to the unread remainder. Decoding
/// stops when the bytes are exhausted or as soon as `decode_varint` returns
/// `None`; values decoded up to that point are returned.
pub fn decode_key_to_indices(key: &str) -> Vec<u64> {
    let raw = latin1_to_bytes(key);
    let mut decoded = Vec::new();
    let mut rest: &[u8] = &raw;
    while !rest.is_empty() {
        match decode_varint(rest) {
            Some((value, used)) => {
                decoded.push(value);
                // An out-of-range `used` simply ends the loop instead of panicking.
                rest = rest.get(used..).unwrap_or(&[]);
            }
            None => break,
        }
    }
    decoded
}
/// Dictionary wrapper that pairs a backend dictionary `D` with a shared vocabulary.
///
/// The methods in this file translate word sequences into keys made of
/// varint-encoded vocabulary indices (stored as Latin-1 text) before
/// delegating to the backend.
#[derive(Clone)]
pub struct VocabularyIndexedDictionary<D> {
/// Dictionary that stores the encoded keys.
backend: D,
/// Shared vocabulary used to look up or insert the index of each word.
vocabulary: SharedVocabARTrie,
/// Character on which string terms are split into words.
delimiter: char,
}
/// Debug output lists the backend and the delimiter; the vocabulary is
/// omitted and the struct is rendered as non-exhaustive.
impl<D: std::fmt::Debug> std::fmt::Debug for VocabularyIndexedDictionary<D> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("VocabularyIndexedDictionary");
        out.field("backend", &self.backend);
        out.field("delimiter", &self.delimiter);
        out.finish_non_exhaustive()
    }
}
impl<D> VocabularyIndexedDictionary<D> {
    /// Creates a dictionary whose string terms are split on a single space.
    pub fn new(backend: D, vocabulary: SharedVocabARTrie) -> Self {
        Self::with_delimiter(backend, vocabulary, ' ')
    }

    /// Creates a dictionary whose string terms are split on `delimiter`.
    pub fn with_delimiter(backend: D, vocabulary: SharedVocabARTrie, delimiter: char) -> Self {
        Self {
            backend,
            vocabulary,
            delimiter,
        }
    }

    /// Returns the wrapped backend dictionary.
    pub fn backend(&self) -> &D {
        &self.backend
    }

    /// Returns the shared vocabulary handle.
    pub fn vocabulary(&self) -> &SharedVocabARTrie {
        &self.vocabulary
    }

    /// Returns the character used to split string terms into words.
    pub fn delimiter(&self) -> char {
        self.delimiter
    }

    /// Splits `term` on the configured delimiter.
    ///
    /// This is a plain `str::split`, so leading, trailing, or repeated
    /// delimiters produce empty words.
    fn split_term<'a>(&self, term: &'a str) -> impl Iterator<Item = &'a str> {
        let separator = self.delimiter;
        term.split(separator)
    }

    /// Encodes `words` as a backend key using only indices that already exist.
    ///
    /// Holds the guard returned by `vocabulary.read()` while looking up every
    /// word with `get_index`. Returns `None` as soon as one word has no index.
    fn encode_key_existing(&self, words: &[&str]) -> Option<String> {
        let mut encoded = Vec::with_capacity(words.len() * 2);
        let vocab = self.vocabulary.read();
        for word in words {
            match vocab.get_index(word) {
                Some(index) => {
                    encode_varint(index, &mut encoded);
                }
                None => return None,
            }
        }
        Some(bytes_to_latin1(&encoded))
    }

    /// Encodes `words` as a backend key, calling `insert` on the vocabulary
    /// for every word to obtain its index.
    ///
    /// Holds the guard returned by `vocabulary.write()` for the whole encoding.
    ///
    /// # Panics
    ///
    /// Panics if the vocabulary's `insert` returns an error.
    fn encode_key_inserting(&self, words: &[&str]) -> String {
        let mut encoded = Vec::with_capacity(words.len() * 2);
        let vocab = self.vocabulary.write();
        words.iter().for_each(|word| {
            let index = vocab
                .insert(word)
                .expect("vocabulary insert: persistent ARTrie I/O failed");
            encode_varint(index, &mut encoded);
        });
        bytes_to_latin1(&encoded)
    }
}
impl<D> VocabularyIndexedDictionary<D>
where
    D: MappedDictionary,
{
    /// Reports whether the n-gram `words` is present in the backend.
    ///
    /// Returns `false` without touching the backend when any word has no
    /// vocabulary index.
    pub fn contains_ngram(&self, words: &[&str]) -> bool {
        match self.encode_key_existing(words) {
            Some(key) => self.backend.contains(&key),
            None => false,
        }
    }

    /// Returns the value stored for the n-gram `words`, if any.
    ///
    /// Yields `None` when a word has no vocabulary index or when the backend
    /// has no value for the encoded key.
    pub fn get_ngram(&self, words: &[&str]) -> Option<D::Value> {
        self.encode_key_existing(words)
            .and_then(|key| self.backend.get_value(&key))
    }
}
impl<D> VocabularyIndexedDictionary<D>
where
    D: MutableMappedDictionary,
{
    /// Stores `value` under the n-gram `words`, inserting each word into the
    /// vocabulary while building the key.
    ///
    /// Returns whatever the backend's `insert_with_value` returns.
    ///
    /// # Panics
    ///
    /// Panics if the vocabulary insert fails (see `encode_key_inserting`).
    pub fn insert_ngram(&self, words: &[&str], value: D::Value) -> bool {
        let encoded = self.encode_key_inserting(words);
        let backend = &self.backend;
        backend.insert_with_value(&encoded, value)
    }

    /// Forwards to the backend's `update_or_insert` for the n-gram `words`,
    /// inserting each word into the vocabulary while building the key.
    ///
    /// Returns whatever the backend's `update_or_insert` returns.
    ///
    /// # Panics
    ///
    /// Panics if the vocabulary insert fails (see `encode_key_inserting`).
    pub fn update_or_insert_ngram<F>(
        &self,
        words: &[&str],
        default_value: D::Value,
        update_fn: F,
    ) -> bool
    where
        F: FnOnce(&mut D::Value),
    {
        let encoded = self.encode_key_inserting(words);
        let backend = &self.backend;
        backend.update_or_insert(&encoded, default_value, update_fn)
    }
}
impl<D> Dictionary for VocabularyIndexedDictionary<D>
where
D: Dictionary,
D::Node: DictionaryNode<Unit = char>,
{
type Node = VocabularyIndexedNode<D::Node>;
fn root(&self) -> Self::Node {
VocabularyIndexedNode {
inner: self.backend.root(),
at_root: true,
}
}
fn contains(&self, term: &str) -> bool {
let words: Vec<&str> = self.split_term(term).collect();
self.encode_key_existing(&words)
.map(|key| self.backend.contains(&key))
.unwrap_or(false)
}
fn len(&self) -> Option<usize> {
let backend_len = self.backend.len()?;
match self.backend.root().transition(METADATA_PREFIX) {
Some(meta_subtree) => Some(backend_len.saturating_sub(count_finals(meta_subtree))),
None => Some(backend_len),
}
}
fn is_empty(&self) -> bool {
if self.backend.is_empty() {
return true;
}
if self.backend.root().has_edge(METADATA_PREFIX) {
!has_visible_final(self.root())
} else {
false
}
}
fn sync_strategy(&self) -> SyncStrategy {
self.backend.sync_strategy()
}
}
/// Value lookups keyed by delimiter-separated terms.
impl<D> MappedDictionary for VocabularyIndexedDictionary<D>
where
    D: MappedDictionary,
    D::Node: MappedDictionaryNode<Unit = char>,
{
    type Value = D::Value;

    /// Splits `term` on the delimiter and returns the backend value for the
    /// encoded key; `None` when a word has no vocabulary index or the backend
    /// holds no value.
    fn get_value(&self, term: &str) -> Option<Self::Value> {
        let words: Vec<&str> = term.split(self.delimiter).collect();
        self.encode_key_existing(&words)
            .and_then(|key| self.backend.get_value(&key))
    }

    /// Splits `term` on the delimiter and forwards the encoded key and
    /// `predicate` to the backend; `false` when a word has no vocabulary index.
    fn contains_with_value<F>(&self, term: &str, predicate: F) -> bool
    where
        F: Fn(&Self::Value) -> bool,
    {
        let words: Vec<&str> = term.split(self.delimiter).collect();
        match self.encode_key_existing(&words) {
            Some(key) => self.backend.contains_with_value(&key, predicate),
            None => false,
        }
    }
}
/// Mutating operations keyed by delimiter-separated terms.
impl<D> MutableMappedDictionary for VocabularyIndexedDictionary<D>
where
    D: MutableMappedDictionary,
    D::Node: MappedDictionaryNode<Unit = char>,
{
    /// Splits `term` on the delimiter, inserts each word into the vocabulary,
    /// and stores `value` under the encoded key.
    ///
    /// # Panics
    ///
    /// Panics if the vocabulary insert fails (see `encode_key_inserting`).
    fn insert_with_value(&self, term: &str, value: Self::Value) -> bool {
        let words: Vec<&str> = term.split(self.delimiter).collect();
        let encoded = self.encode_key_inserting(&words);
        self.backend.insert_with_value(&encoded, value)
    }

    /// Merges `other`'s backend into this backend using `merge_fn`.
    ///
    /// NOTE(review): only the backends are merged; neither vocabulary is
    /// consulted. This presumably requires both dictionaries to share the same
    /// vocabulary so that encoded keys mean the same words - confirm with callers.
    fn union_with<F>(&self, other: &Self, merge_fn: F) -> usize
    where
        F: Fn(&Self::Value, &Self::Value) -> Self::Value,
        Self::Value: Clone,
    {
        let theirs = &other.backend;
        self.backend.union_with(theirs, merge_fn)
    }

    /// Splits `term` on the delimiter, inserts each word into the vocabulary,
    /// and forwards the encoded key to the backend's `update_or_insert`.
    ///
    /// # Panics
    ///
    /// Panics if the vocabulary insert fails (see `encode_key_inserting`).
    fn update_or_insert<F>(&self, term: &str, default_value: Self::Value, update_fn: F) -> bool
    where
        F: FnOnce(&mut Self::Value),
    {
        let words: Vec<&str> = term.split(self.delimiter).collect();
        let encoded = self.encode_key_inserting(&words);
        self.backend
            .update_or_insert(&encoded, default_value, update_fn)
    }
}
/// Node wrapper that remembers whether it is the dictionary root, so the
/// `METADATA_PREFIX` edge can be filtered at that position.
#[derive(Clone)]
pub struct VocabularyIndexedNode<N> {
/// Backend node being wrapped.
inner: N,
/// `true` only for the node returned by `root()`; nodes reached through
/// `transition` or `edges` are created with `false`.
at_root: bool,
}
/// Debug output shows the wrapped node and the root flag.
impl<N> std::fmt::Debug for VocabularyIndexedNode<N>
where
    N: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("VocabularyIndexedNode");
        out.field("inner", &self.inner);
        out.field("at_root", &self.at_root);
        out.finish()
    }
}
/// Node view that hides the `METADATA_PREFIX` edge at the root and otherwise
/// delegates to the wrapped backend node.
///
/// All edge-related queries (`transition`, `edges`, `has_edge`, `edge_count`)
/// apply the same root-level filter so they never disagree with each other.
impl<N: DictionaryNode<Unit = char>> DictionaryNode for VocabularyIndexedNode<N> {
    type Unit = char;

    /// Whether the wrapped node is final.
    fn is_final(&self) -> bool {
        self.inner.is_final()
    }

    /// Follows `label` in the backend, except that the metadata edge is
    /// unreachable from the root. The returned child is never a root node.
    fn transition(&self, label: Self::Unit) -> Option<Self> {
        if self.at_root && label == METADATA_PREFIX {
            return None;
        }
        self.inner.transition(label).map(|inner| Self {
            inner,
            at_root: false,
        })
    }

    /// Iterates the backend edges, skipping the metadata edge at the root.
    fn edges(&self) -> Box<dyn Iterator<Item = (Self::Unit, Self)> + '_> {
        let at_root = self.at_root;
        Box::new(
            self.inner
                .edges()
                .filter(move |(label, _)| !(at_root && *label == METADATA_PREFIX))
                .map(|(label, inner)| {
                    (
                        label,
                        Self {
                            inner,
                            at_root: false,
                        },
                    )
                }),
        )
    }

    /// Whether `label` can be followed from this node.
    ///
    /// Mirrors `transition`: the metadata edge is reported as absent at the
    /// root, so `has_edge(l)` is never `true` when `transition(l)` is `None`
    /// because of filtering.
    fn has_edge(&self, label: Self::Unit) -> bool {
        if self.at_root && label == METADATA_PREFIX {
            return false;
        }
        self.inner.has_edge(label)
    }

    /// Number of visible edges, when the backend can report a count.
    ///
    /// Mirrors `edges`: at the root the hidden metadata edge is subtracted
    /// from the backend's count.
    fn edge_count(&self) -> Option<usize> {
        let total = self.inner.edge_count()?;
        let hidden = usize::from(self.at_root && self.inner.has_edge(METADATA_PREFIX));
        Some(total.saturating_sub(hidden))
    }
}
impl<N: MappedDictionaryNode<Unit = char>> MappedDictionaryNode for VocabularyIndexedNode<N> {
type Value = N::Value;
fn value(&self) -> Option<Self::Value> {
self.inner.value()
}
}
/// Reports whether any final node is reachable from `root` through the
/// filtered node view (the root itself included).
///
/// Walks the nodes depth-first with an explicit stack and stops at the first
/// final node encountered.
fn has_visible_final<N>(root: VocabularyIndexedNode<N>) -> bool
where
    N: DictionaryNode<Unit = char>,
{
    let mut pending = vec![root];
    while let Some(current) = pending.pop() {
        if current.is_final() {
            return true;
        }
        pending.extend(current.edges().map(|(_, child)| child));
    }
    false
}
/// Counts the final nodes visited when walking every edge path from `node`
/// (the starting node included).
///
/// Uses an explicit stack; a node reachable along several paths is visited,
/// and counted, once per path.
fn count_finals<N>(node: N) -> usize
where
    N: DictionaryNode<Unit = char>,
{
    let mut pending = vec![node];
    let mut total: usize = 0;
    while let Some(current) = pending.pop() {
        total += usize::from(current.is_final());
        pending.extend(current.edges().map(|(_, child)| child));
    }
    total
}
use liblevenshtein::dictionary::dynamic_dawg_char::DynamicDawgChar;
use liblevenshtein::dictionary::dynamic_dawg_char_zipper::DynamicDawgCharZipper;
use liblevenshtein::dictionary::value::DictionaryValue;
/// Zipper type returned by `VocabularyIndexedDictionary::zipper` for
/// `DynamicDawgChar<V>` backends: a `DynamicDawgCharZipper<V>` wrapped in a
/// `MetadataFilteringZipper`.
pub type VocabularyIndexedDictionaryZipper<V> = MetadataFilteringZipper<DynamicDawgCharZipper<V>>;
impl<V: DictionaryValue> VocabularyIndexedDictionary<DynamicDawgChar<V>> {
    /// Creates a zipper over the backend DAWG and wraps it in a
    /// `MetadataFilteringZipper`.
    pub fn zipper(&self) -> VocabularyIndexedDictionaryZipper<V> {
        MetadataFilteringZipper::new(DynamicDawgCharZipper::new_from_dict(&self.backend))
    }
}
// Unit tests for `VocabularyIndexedDictionary` backed by `DynamicDawgChar<u64>`
// and a vocabulary created in a temporary directory.
#[cfg(test)]
mod tests {
use super::*;
use crate::ngram::vocabulary::create_vocabulary;
use liblevenshtein::dictionary::dynamic_dawg_char::DynamicDawgChar;
use tempfile::TempDir;
// Builds a space-delimited dictionary over a fresh vocabulary file. The
// `TempDir` is returned so the caller keeps the directory alive for the test.
fn create_test_dict() -> (TempDir, VocabularyIndexedDictionary<DynamicDawgChar<u64>>) {
let dir = TempDir::new().expect("Failed to create temp dir");
let vocab_path = dir.path().join("vocab.artrie");
let vocab = create_vocabulary(&vocab_path).expect("Failed to create vocabulary");
let backend: DynamicDawgChar<u64> = DynamicDawgChar::new();
let dict = VocabularyIndexedDictionary::new(backend, vocab);
(dir, dict)
}
// First insert reports `true`; re-inserting the same n-gram reports `false`
// and the later value (200) is the one read back.
#[test]
fn test_insert_ngram() {
let (_dir, dict) = create_test_dict();
assert!(dict.insert_ngram(&["the", "quick"], 100));
assert!(!dict.insert_ngram(&["the", "quick"], 200));
assert_eq!(dict.get_ngram(&["the", "quick"]), Some(200));
}
// Membership is false before insertion, true after, and false for unknown words.
#[test]
fn test_contains_ngram() {
let (_dir, dict) = create_test_dict();
assert!(!dict.contains_ngram(&["the", "quick"]));
dict.insert_ngram(&["the", "quick"], 100);
assert!(dict.contains_ngram(&["the", "quick"]));
assert!(!dict.contains_ngram(&["unknown", "word"]));
}
// Looking up out-of-vocabulary words yields `None` and leaves the vocabulary empty.
#[test]
fn test_oov_returns_none() {
let (_dir, dict) = create_test_dict();
assert!(dict.get_ngram(&["unknown", "word"]).is_none());
assert_eq!(dict.vocabulary().read().len(), 0);
}
// First call inserts the default (10) untouched; second call applies the
// update closure to the existing value (10 + 5 = 15).
#[test]
fn test_update_or_insert_ngram() {
let (_dir, dict) = create_test_dict();
let is_new = dict.update_or_insert_ngram(&["the", "quick"], 10, |v| *v += 5);
assert!(is_new);
assert_eq!(dict.get_ngram(&["the", "quick"]), Some(10));
let is_new = dict.update_or_insert_ngram(&["the", "quick"], 10, |v| *v += 5);
assert!(!is_new);
assert_eq!(dict.get_ngram(&["the", "quick"]), Some(15));
}
// `Dictionary::len` is available and `is_empty` is false after one insert.
#[test]
fn test_dictionary_trait() {
let (_dir, dict) = create_test_dict();
dict.insert_ngram(&["the", "quick"], 100);
assert!(dict.len().is_some());
assert!(!dict.is_empty());
}
// An n-gram inserted as a word slice is readable as a space-delimited term.
#[test]
fn test_mapped_dictionary_trait() {
let (_dir, dict) = create_test_dict();
dict.insert_ngram(&["hello", "world"], 42);
assert_eq!(dict.get_value("hello world"), Some(42));
assert!(dict.get_value("unknown word").is_none());
}
// String-term insert followed by an update of the existing entry (100 * 2 = 200).
#[test]
fn test_mutable_mapped_dictionary_trait() {
let (_dir, dict) = create_test_dict();
assert!(dict.insert_with_value("hello world", 100));
assert_eq!(dict.get_value("hello world"), Some(100));
let is_new = dict.update_or_insert("hello world", 50, |v| *v *= 2);
assert!(!is_new);
assert_eq!(dict.get_value("hello world"), Some(200));
}
// A dictionary built with `'|'` as delimiter round-trips a `|`-separated term.
#[test]
fn test_custom_delimiter() {
let dir = TempDir::new().expect("Failed to create temp dir");
let vocab_path = dir.path().join("vocab.artrie");
let vocab = create_vocabulary(&vocab_path).expect("Failed to create vocabulary");
let backend: DynamicDawgChar<u64> = DynamicDawgChar::new();
let dict = VocabularyIndexedDictionary::with_delimiter(backend, vocab, '|');
dict.insert_with_value("the|quick|brown", 123);
assert_eq!(dict.get_value("the|quick|brown"), Some(123));
}
// Ten threads insert the same n-gram through a shared `Arc`; the value is
// readable once all threads have joined.
#[test]
fn test_concurrent_access() {
use std::sync::Arc;
use std::thread;
let dir = TempDir::new().expect("Failed to create temp dir");
let vocab_path = dir.path().join("vocab.artrie");
let vocab = create_vocabulary(&vocab_path).expect("Failed to create vocabulary");
let backend: DynamicDawgChar<u64> = DynamicDawgChar::new();
let dict = Arc::new(VocabularyIndexedDictionary::new(backend, vocab));
let mut handles = vec![];
for _ in 0..10 {
let dict = Arc::clone(&dict);
handles.push(thread::spawn(move || {
dict.insert_ngram(&["shared", "ngram"], 42);
}));
}
for handle in handles {
handle.join().expect("thread should complete");
}
assert_eq!(dict.get_ngram(&["shared", "ngram"]), Some(42));
}
// After one insert the root is not final and exposes at least one edge.
#[test]
fn test_node_traversal() {
let (_dir, dict) = create_test_dict();
dict.insert_ngram(&["abc"], 1);
let root = dict.root();
assert!(!root.is_final());
let edges: Vec<_> = root.edges().collect();
assert!(!edges.is_empty());
}
// A key written straight into the backend with a leading `\x00` must be
// invisible to root `transition`/`edges` and excluded from `len`.
#[test]
fn vocabulary_query_root_traversal_filters_metadata() {
let (_dir, dict) = create_test_dict();
dict.insert_ngram(&["hello"], 1);
dict.backend().insert_with_value("\x00__meta__", 999);
let root = dict.root();
assert!(
root.transition(METADATA_PREFIX).is_none(),
"root traversal must not expose metadata keys"
);
let children: Vec<char> = root.edges().map(|(label, _)| label).collect();
assert!(
!children.contains(&METADATA_PREFIX),
"root edges must filter metadata keys"
);
assert!(
!children.is_empty(),
"data edges should remain visible after metadata filtering"
);
assert_eq!(
dict.len(),
Some(1),
"Dictionary::len should count visible query terms only"
);
}
// `is_empty` stays true while only a metadata key exists and flips to false
// once a visible term is inserted.
#[test]
fn vocabulary_query_is_empty_filters_metadata() {
let (_dir, dict) = create_test_dict();
assert!(dict.is_empty(), "fresh dictionary should be empty");
dict.backend().insert_with_value("\x00__meta__", 999);
assert!(
dict.is_empty(),
"metadata-only dictionary must report visible-empty"
);
dict.insert_ngram(&["hello"], 1);
assert!(
!dict.is_empty(),
"dictionary with a visible term must not report empty"
);
assert_eq!(dict.len(), Some(1), "only the visible term is counted");
}
// A recursive walk from the filtered root collects only the values of the
// visible n-grams, never the metadata value.
#[test]
fn vocabulary_query_value_traversal_never_emits_root_metadata() {
// Depth-first helper: records a node's value, then recurses into each child.
fn collect_values<N>(node: N, values: &mut Vec<u64>)
where
N: MappedDictionaryNode<Unit = char, Value = u64> + Clone,
{
if let Some(value) = node.value() {
values.push(value);
}
for (_, child) in node.edges() {
collect_values(child, values);
}
}
let (_dir, dict) = create_test_dict();
dict.insert_ngram(&["hello"], 1);
dict.insert_ngram(&["world"], 2);
dict.backend().insert_with_value("\x00__meta__", 999);
let mut values = Vec::new();
collect_values(dict.root(), &mut values);
values.sort_unstable();
assert_eq!(values, vec![1, 2]);
assert!(
!values.contains(&999),
"value-yielding traversal must not emit metadata values"
);
}
// Read-only lookups involving unknown words must leave the vocabulary size unchanged.
#[test]
fn vocabulary_query_oov_reads_do_not_mutate_vocabulary() {
let (_dir, dict) = create_test_dict();
dict.insert_ngram(&["known"], 7);
let len_before = dict.vocabulary().read().len();
assert!(dict.get_ngram(&["missing"]).is_none());
assert!(!dict.contains_ngram(&["known", "missing"]));
assert_eq!(
dict.vocabulary().read().len(),
len_before,
"read-only query paths must not allocate vocabulary indices"
);
}
// Pre-populates the vocabulary with 200 words, then round-trips an n-gram
// that mixes early and late entries.
#[test]
fn test_large_vocabulary_indices() {
let dir = TempDir::new().expect("Failed to create temp dir");
let vocab_path = dir.path().join("vocab.artrie");
let vocab = create_vocabulary(&vocab_path).expect("Failed to create vocabulary");
let backend: DynamicDawgChar<u64> = DynamicDawgChar::new();
let dict = VocabularyIndexedDictionary::new(backend, vocab.clone());
// Scope the write guard so it is released before the dictionary takes its own.
{
let guard = vocab.write();
for i in 0..200 {
guard.insert(&format!("word{}", i)).expect("insert word");
}
}
dict.insert_ngram(&["word0", "word127", "word199"], 999);
assert_eq!(dict.get_ngram(&["word0", "word127", "word199"]), Some(999));
}
// Varint-encodes several indices, converts to a Latin-1 key, and checks that
// `decode_key_to_indices` recovers the original sequence.
#[test]
fn test_decode_key_to_indices() {
let indices = vec![1u64, 127, 128, 16383];
let mut buf = Vec::new();
for &idx in &indices {
encode_varint(idx, &mut buf);
}
let key = bytes_to_latin1(&buf);
let decoded = decode_key_to_indices(&key);
assert_eq!(decoded, indices);
}
// An empty word slice is accepted as a key and its value can be read back.
#[test]
fn test_empty_ngram() {
let (_dir, dict) = create_test_dict();
assert!(dict.insert_ngram(&[], 42));
assert_eq!(dict.get_ngram(&[]), Some(42));
}
// A one-word n-gram round-trips.
#[test]
fn test_single_word_ngram() {
let (_dir, dict) = create_test_dict();
dict.insert_ngram(&["unigram"], 1);
assert_eq!(dict.get_ngram(&["unigram"]), Some(1));
}
// The zipper returned by `zipper()` neither lists nor descends into `'\x00'`
// at the root, while regular children remain.
#[test]
fn test_zipper_excludes_metadata() {
use liblevenshtein::dictionary::zipper::DictZipper;
let (_dir, dict) = create_test_dict();
dict.insert_ngram(&["hello"], 1);
dict.insert_ngram(&["world"], 2);
dict.backend().insert_with_value("\x00__meta__", 999);
let zipper = dict.zipper();
let children: Vec<char> = zipper.children().map(|(c, _)| c).collect();
assert!(
!children.contains(&'\x00'),
"Metadata prefix should be filtered from children"
);
assert!(
zipper.descend('\x00').is_none(),
"Should not be able to descend to metadata at root"
);
assert!(
!children.is_empty(),
"Should have children for regular n-grams"
);
}
}