use crate::ngram::vocabulary::{
encode_ngram_key_bytes, encode_ngram_key_existing_bytes, SharedVocabARTrie,
};
use crate::sources::google_books::sharding::coordinator::ShardCoordinator;
#[allow(deprecated)]
use liblevenshtein::dictionary::{
Dictionary, DictionaryNode, MappedDictionary, MappedDictionaryNode, SyncStrategy,
};
use std::sync::Arc;
/// Settings for [`AggregatedLanguageModelDictionary`].
#[derive(Debug, Clone)]
pub struct AggregationConfig {
    /// Separator used to split a single string into n-gram tokens for the
    /// string-based `Dictionary` / `MappedDictionary` lookups (default `' '`).
    pub delimiter: char,
    /// Intended cap on simultaneously open shards (default `100`).
    ///
    /// NOTE(review): this file only stores and exposes the value; it is not
    /// passed to the `ShardCoordinator` here — confirm where it is enforced.
    pub max_open_shards: usize,
}

impl Default for AggregationConfig {
    /// Space-separated tokens and a limit of 100 open shards.
    fn default() -> Self {
        let delimiter = ' ';
        let max_open_shards = 100;
        Self {
            delimiter,
            max_open_shards,
        }
    }
}
/// Sharded n-gram count store addressed by vocabulary indices.
///
/// Each word of an n-gram is mapped to a numeric index through the shared
/// vocabulary, the index sequence is encoded into a byte key, and the key is
/// read from / written to the shard chosen by the [`ShardCoordinator`].
pub struct AggregatedLanguageModelDictionary {
    /// Shared word-to-index vocabulary (a cloneable handle; clones observe the
    /// same contents).
    vocabulary: SharedVocabARTrie,
    /// Routes n-grams to shards and performs the shard reads and writes.
    coordinator: Arc<ShardCoordinator>,
    /// Token delimiter and shard-limit settings.
    config: AggregationConfig,
}
impl std::fmt::Debug for AggregatedLanguageModelDictionary {
    /// Summarises the dictionary as vocabulary size, open shard count and
    /// configuration rather than dumping the underlying stores.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut summary = f.debug_struct("AggregatedLanguageModelDictionary");
        let vocabulary_size = self.vocabulary.len();
        summary.field("vocabulary_size", &vocabulary_size);
        let open_shards = self.coordinator.open_shard_count();
        summary.field("open_shards", &open_shards);
        summary.field("config", &self.config);
        summary.finish()
    }
}
impl AggregatedLanguageModelDictionary {
    /// Creates a dictionary over `coordinator`'s shards and the shared
    /// `vocabulary`, using [`AggregationConfig::default`].
    pub fn new(coordinator: Arc<ShardCoordinator>, vocabulary: SharedVocabARTrie) -> Self {
        Self::with_config(coordinator, vocabulary, AggregationConfig::default())
    }

    /// Creates a dictionary with an explicit [`AggregationConfig`].
    pub fn with_config(
        coordinator: Arc<ShardCoordinator>,
        vocabulary: SharedVocabARTrie,
        config: AggregationConfig,
    ) -> Self {
        Self {
            vocabulary,
            coordinator,
            config,
        }
    }

    /// Returns the shared vocabulary used to map words to numeric indices.
    pub fn vocabulary(&self) -> &SharedVocabARTrie {
        &self.vocabulary
    }

    /// Returns the shard coordinator that stores the encoded n-gram counts.
    pub fn coordinator(&self) -> &Arc<ShardCoordinator> {
        &self.coordinator
    }

    /// Returns the active configuration.
    pub fn config(&self) -> &AggregationConfig {
        &self.config
    }

    /// Returns the delimiter used by the string-based `Dictionary` APIs.
    pub fn delimiter(&self) -> char {
        self.config.delimiter
    }

    /// Splits a delimiter-joined n-gram string into its tokens.
    ///
    /// Follows `str::split` semantics: consecutive delimiters produce empty
    /// tokens and an empty `term` yields `[""]`.
    fn split_term<'a>(&self, term: &'a str) -> Vec<&'a str> {
        term.split(self.config.delimiter).collect()
    }

    /// Returns `true` if the n-gram `words` has a stored count.
    ///
    /// Equivalent to `self.get_ngram(words).is_some()`; see [`Self::get_ngram`]
    /// for the read-only lookup semantics. Empty input returns `false`.
    pub fn contains_ngram(&self, words: &[&str]) -> bool {
        self.get_ngram(words).is_some()
    }

    /// Looks up the stored count for the n-gram `words`.
    ///
    /// Keys are built with `encode_ngram_key_existing_bytes`, which returns
    /// `None` for words not already in the vocabulary, so a lookup never adds
    /// vocabulary entries. Returns `None` for empty input, unknown words, or
    /// an n-gram absent from its routed shard.
    pub fn get_ngram(&self, words: &[&str]) -> Option<u64> {
        if words.is_empty() {
            return None;
        }
        let encoded_key = encode_ngram_key_existing_bytes(words, &self.vocabulary)?;
        let shard_key = self.coordinator.route_tokens(words);
        self.coordinator.get_in_shard(&shard_key, &encoded_key)
    }

    /// Stores `count` for the n-gram `words` in the shard it routes to.
    ///
    /// Unlike the read paths, the key is built with `encode_ngram_key_bytes`
    /// (presumably allocating indices for unseen words — TODO confirm).
    /// Empty input is a no-op returning `Ok(false)`; otherwise the `bool` is
    /// whatever `store_in_shard` reports (presumably "key was new" — verify).
    ///
    /// # Errors
    ///
    /// Propagates the coordinator's error from `store_in_shard`.
    pub fn insert_ngram(
        &self,
        words: &[&str],
        count: u64,
    ) -> Result<bool, crate::sources::google_books::sharding::coordinator::CoordinatorError> {
        if words.is_empty() {
            return Ok(false);
        }
        let encoded_key = encode_ngram_key_bytes(words, &self.vocabulary);
        let shard_key = self.coordinator.route_tokens(words);
        self.coordinator
            .store_in_shard(&shard_key, &encoded_key, count)
    }

    /// Total number of entries reported by the shard coordinator.
    ///
    /// NOTE(review): may include the `0x00`-prefixed keys that the iterators
    /// below skip — confirm against `total_entry_count`.
    pub fn total_ngram_count(&self) -> u64 {
        self.coordinator.total_entry_count()
    }

    /// Decodes an encoded n-gram key back into its vocabulary indices.
    pub fn decode_key(&self, key: &[u8]) -> Vec<u64> {
        crate::ngram::vocabulary::decode_ngram_key_bytes(key)
    }

    /// Builds an index-to-word map from the vocabulary.
    ///
    /// Probes indices `1..=len` while holding the vocabulary read lock.
    /// NOTE(review): assumes indices are dense and start at 1; any word with
    /// an index outside that range is missing from the map — confirm.
    pub fn build_reverse_vocabulary(&self) -> std::collections::HashMap<u64, String> {
        let guard = self.vocabulary.read();
        let len = guard.len();
        let mut map = std::collections::HashMap::with_capacity(len);
        for index in 1..=(len as u64) {
            if let Some(term) = guard.get_term(index) {
                map.insert(index, term);
            }
        }
        map
    }

    /// Checkpoints the vocabulary, then every shard, stopping at the first
    /// failure.
    ///
    /// # Errors
    ///
    /// Returns a message prefixed with `Vocabulary checkpoint failed:` or
    /// `Shard checkpoint failed:` describing the underlying error.
    pub fn checkpoint(&self) -> Result<(), String> {
        self.vocabulary
            .write()
            .checkpoint()
            .map_err(|e| format!("Vocabulary checkpoint failed: {}", e))?;
        self.coordinator
            .checkpoint_all()
            .map_err(|e| format!("Shard checkpoint failed: {:?}", e))?;
        Ok(())
    }

    /// Gathers `(encoded key, count)` pairs from every currently open shard,
    /// dropping keys whose first byte is `0x00` as they are read.
    ///
    /// Best-effort: a shard that fails to open or iterate is skipped silently.
    /// Only shards listed by `open_shard_keys` are visited.
    /// NOTE(review): if the coordinator evicts shards, their entries are not
    /// returned here — confirm this is intended.
    fn collect_open_shard_entries(&self) -> Vec<(Vec<u8>, u64)> {
        let mut all_entries: Vec<(Vec<u8>, u64)> = Vec::new();
        for shard_key in self.coordinator.open_shard_keys() {
            if let Ok(shard) = self.coordinator.get_or_create_shard(&shard_key) {
                let guard = shard.read();
                if let Ok(entries) = guard.iter_with_counts() {
                    // Filter while scanning so skipped keys are never buffered.
                    all_entries.extend(
                        entries
                            .into_iter()
                            .filter(|(key, _)| !key.starts_with(&[0x00])),
                    );
                }
            }
        }
        all_entries
    }

    /// Iterates the stored n-grams as `(words, count)` pairs.
    ///
    /// Entries are snapshotted eagerly from the open shards. Keys containing
    /// an index missing from [`Self::build_reverse_vocabulary`] are skipped.
    ///
    /// # Errors
    ///
    /// Currently never returns `Err`; per-shard failures are skipped. The
    /// `Result` is kept for API compatibility.
    pub fn iter_ngrams(&self) -> Result<impl Iterator<Item = (Vec<String>, u64)> + '_, String> {
        let reverse_vocab = self.build_reverse_vocabulary();
        let entries = self.collect_open_shard_entries();
        Ok(entries.into_iter().filter_map(move |(key, count)| {
            let words = crate::ngram::vocabulary::decode_ngram_key_bytes(&key)
                .into_iter()
                .map(|idx| reverse_vocab.get(&idx).cloned())
                .collect::<Option<Vec<String>>>()?;
            Some((words, count))
        }))
    }

    /// Iterates the stored n-grams as raw `(encoded key, count)` pairs.
    ///
    /// Entries are snapshotted eagerly from the open shards; keys beginning
    /// with `0x00` are excluded.
    ///
    /// # Errors
    ///
    /// Currently never returns `Err`; per-shard failures are skipped. The
    /// `Result` is kept for API compatibility.
    pub fn iter_ngrams_raw(&self) -> Result<impl Iterator<Item = (Vec<u8>, u64)> + '_, String> {
        Ok(self.collect_open_shard_entries().into_iter())
    }

    /// Alias of [`Self::total_ngram_count`].
    pub fn ngram_count(&self) -> u64 {
        self.total_ngram_count()
    }
}
impl Dictionary for AggregatedLanguageModelDictionary {
    type Node = AggregatedDictionaryNode;

    /// Returns an inert traversal node.
    ///
    /// This dictionary answers whole-term queries only. The node has no edges
    /// and is never final, so character-level traversal matches nothing.
    fn root(&self) -> Self::Node {
        AggregatedDictionaryNode {
            _phantom: std::marker::PhantomData,
        }
    }

    /// Splits `term` on the configured delimiter and reports whether that
    /// n-gram has a stored count.
    fn contains(&self, term: &str) -> bool {
        // `split_term` already yields `Vec<&str>`; borrow it as a slice
        // instead of re-collecting an identical vector.
        self.contains_ngram(&self.split_term(term))
    }

    /// Total entry count reported by the shard coordinator.
    ///
    /// Returns `None` when the count does not fit in `usize` (possible on
    /// 32-bit targets) rather than a silently truncated length.
    fn len(&self) -> Option<usize> {
        usize::try_from(self.coordinator.total_entry_count()).ok()
    }

    /// `true` when the coordinator reports zero entries.
    fn is_empty(&self) -> bool {
        self.coordinator.total_entry_count() == 0
    }

    /// Reports that synchronisation is handled internally.
    fn sync_strategy(&self) -> SyncStrategy {
        SyncStrategy::InternalSync
    }
}
impl MappedDictionary for AggregatedLanguageModelDictionary {
    type Value = u64;

    /// Splits `term` on the configured delimiter and returns that n-gram's
    /// stored count, if any.
    fn get_value(&self, term: &str) -> Option<Self::Value> {
        // Pass `split_term`'s `Vec<&str>` straight through as a slice; no
        // second allocation is needed.
        self.get_ngram(&self.split_term(term))
    }

    /// `true` if `term` has a stored count that satisfies `predicate`.
    fn contains_with_value<F>(&self, term: &str, predicate: F) -> bool
    where
        F: Fn(&Self::Value) -> bool,
    {
        self.get_value(term).is_some_and(|v| predicate(&v))
    }
}
/// Stateless placeholder node returned by the aggregated dictionary's
/// `Dictionary::root`.
///
/// The aggregated dictionary only supports whole-term lookups, so this node
/// carries no data and exists solely to satisfy the traversal traits.
#[derive(Clone)]
pub struct AggregatedDictionaryNode {
    _phantom: std::marker::PhantomData<()>,
}

impl std::fmt::Debug for AggregatedDictionaryNode {
    /// Prints only the type name; there is no state to show.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("AggregatedDictionaryNode")
    }
}
impl DictionaryNode for AggregatedDictionaryNode {
    type Unit = char;

    /// Never final: no term terminates at this node.
    fn is_final(&self) -> bool {
        false
    }

    /// Every label is a dead end.
    fn transition(&self, _label: Self::Unit) -> Option<Self> {
        None
    }

    /// Yields no outgoing edges.
    fn edges(&self) -> Box<dyn Iterator<Item = (Self::Unit, Self)> + '_> {
        let no_edges = std::iter::empty::<(Self::Unit, Self)>();
        Box::new(no_edges)
    }
}
impl MappedDictionaryNode for AggregatedDictionaryNode {
    type Value = u64;

    /// Always `None`: counts are only reachable through whole-term lookups on
    /// the dictionary (`MappedDictionary::get_value`), not through nodes.
    fn value(&self) -> Option<Self::Value> {
        Option::None
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ngram::vocabulary::create_vocabulary;
    use crate::sources::google_books::sharding::{ShardConfig, ShardGranularity};
    use liblevenshtein::dictionary::{Dictionary, MappedDictionary};
    use tempfile::TempDir;

    #[test]
    fn test_aggregation_config_default() {
        let AggregationConfig {
            delimiter,
            max_open_shards,
        } = AggregationConfig::default();
        assert_eq!(delimiter, ' ');
        assert_eq!(max_open_shards, 100);
    }

    /// Compile-time check that the shared vocabulary exposes the methods and
    /// return types this module depends on. The helper is never invoked.
    #[test]
    fn test_shared_vocabulary_methods() {
        fn _check_api(v: &SharedVocabARTrie) {
            let _: libdictenstein::persistent_artrie::error::Result<u64> = v.write().insert("word");
            let _: Option<u64> = v.read().get_index("word");
            let _: bool = v.read().contains("word");
            let _: usize = v.read().len();
            let _: bool = v.read().is_empty();
        }
    }

    #[test]
    fn test_aggregated_node_is_empty_traversal_adapter() {
        let root = AggregatedDictionaryNode {
            _phantom: std::marker::PhantomData,
        };
        assert!(!root.is_final());
        assert!(root.transition('a').is_none());
        assert_eq!(root.edges().count(), 0);
        assert!(root.value().is_none());
    }

    #[test]
    fn vocabulary_query_aggregated_dictionary_routes_and_reads_exact_values() {
        let tmp = TempDir::new().expect("Failed to create temp dir");
        let vocab_file = tmp.path().join("vocab.artrie");
        let shards_dir = tmp.path().join("shards");
        let vocab = create_vocabulary(&vocab_file).expect("Failed to create vocabulary");

        let shard_config = ShardConfig::new(shards_dir)
            .with_granularity(ShardGranularity::TwoChar)
            .with_max_open_shards(8);
        let coord =
            Arc::new(ShardCoordinator::new(shard_config).expect("Failed to create coordinator"));
        let dict = AggregatedLanguageModelDictionary::new(Arc::clone(&coord), vocab.clone());

        // A lookup of an unknown word must leave the vocabulary untouched.
        assert_eq!(dict.get_ngram(&["missing"]), None);
        assert_eq!(
            vocab.read().len(),
            0,
            "read-only aggregated queries must not allocate vocabulary indices"
        );

        assert!(dict
            .insert_ngram(&["the", "quick"], 7)
            .expect("insert the quick"));
        assert!(dict.insert_ngram(&["apple"], 3).expect("insert apple"));

        // Slice-based and delimiter-string-based lookups must agree.
        assert!(dict.contains_ngram(&["the", "quick"]));
        assert_eq!(dict.get_ngram(&["the", "quick"]), Some(7));
        assert_eq!(dict.get_value("the quick"), Some(7));
        assert!(dict.contains("apple"));
        assert_eq!(dict.get_value("apple"), Some(3));
        assert_eq!(dict.get_ngram(&["the", "slow"]), None);

        // With TwoChar granularity these n-grams route to the "th" and "ap"
        // shard prefixes.
        assert_eq!(coord.route_tokens(&["the", "quick"]).prefix, "th");
        assert_eq!(coord.route_tokens(&["apple"]).prefix, "ap");
    }
}