use libdictenstein::persistent_artrie_char::PersistentARTrieChar;
use std::path::{Path, PathBuf};
use thiserror::Error;
/// Errors surfaced by the n-gram accumulator.
///
/// In the code visible in this file only [`AccumulatorError::Dictionary`] is
/// constructed explicitly; every trie failure (including checkpoint failures)
/// is reported through it as a formatted message.
#[derive(Error, Debug)]
pub enum AccumulatorError {
/// An underlying I/O failure. The `#[from]` attribute asks `thiserror` to
/// provide the conversion from `std::io::Error`, so `?` can be used on I/O
/// results inside functions returning [`AccumulatorResult`].
#[error("Storage I/O error: {0}")]
Storage(#[from] std::io::Error),
/// A failure reported by the persistent trie, carried as a human-readable
/// message (the original error is formatted into the string, not preserved).
#[error("Dictionary error: {0}")]
Dictionary(String),
/// A checkpoint-related failure, carried as a message.
/// NOTE(review): not constructed anywhere in the visible code — checkpoint
/// failures are currently mapped to `Dictionary`; confirm intended usage.
#[error("Checkpoint error: {0}")]
Checkpoint(String),
}
/// Shorthand for a `Result` whose error type is [`AccumulatorError`].
pub type AccumulatorResult<T> = std::result::Result<T, AccumulatorError>;
/// Accumulates n-gram occurrence counts in a persistent, on-disk trie.
///
/// Keys are n-gram strings and values are signed 64-bit counts.
pub struct NgramAccumulator {
// Persistent trie mapping each n-gram key to its `i64` count.
trie: PersistentARTrieChar<i64>,
// Location the trie was created at / opened from; exposed via `path()`.
path: PathBuf,
// Cached number of distinct n-grams, served by `len()`. Starts at 0 in
// `create`, is seeded from `trie.len()` in `open`, and is bumped by
// `increment` when a count becomes 1.
unique_count: std::sync::atomic::AtomicUsize,
}
impl NgramAccumulator {
    /// Creates a brand-new accumulator whose persistent trie lives at `path`.
    ///
    /// The cached unique-n-gram counter starts at zero.
    ///
    /// # Errors
    ///
    /// Returns [`AccumulatorError::Dictionary`] when the trie cannot be created.
    pub fn create(path: &Path) -> AccumulatorResult<Self> {
        let trie = PersistentARTrieChar::create(path).map_err(|e| {
            AccumulatorError::Dictionary(format!("Failed to create persistent trie: {}", e))
        })?;
        Ok(Self {
            trie,
            path: path.to_path_buf(),
            unique_count: std::sync::atomic::AtomicUsize::new(0),
        })
    }

    /// Opens an existing accumulator at `path` through the trie's
    /// open-with-recovery entry point.
    ///
    /// When the returned report says a recovery took place, a summary is
    /// logged at `info` level. The cached unique-n-gram counter is seeded from
    /// the trie's own length.
    ///
    /// # Errors
    ///
    /// Returns [`AccumulatorError::Dictionary`] when the trie cannot be opened
    /// or recovered.
    pub fn open(path: &Path) -> AccumulatorResult<Self> {
        let (trie, report) = PersistentARTrieChar::open_with_recovery(path).map_err(|e| {
            AccumulatorError::Dictionary(format!("Failed to open/recover persistent trie: {}", e))
        })?;
        if report.mode.recovered() {
            log::info!(
                "Recovered accumulator from crash: mode={:?}, {} records replayed, {} terms recovered",
                report.mode,
                report.records_replayed,
                report.terms_recovered
            );
        }
        // Read the length before `trie` is moved into the struct.
        let count = trie.len();
        Ok(Self {
            trie,
            path: path.to_path_buf(),
            unique_count: std::sync::atomic::AtomicUsize::new(count),
        })
    }

    /// Returns the path this accumulator was created at or opened from.
    pub fn path(&self) -> &Path {
        &self.path
    }

    /// Adds one occurrence of `ngram` and returns its updated count.
    ///
    /// A resulting count of exactly 1 is treated as the first sighting of the
    /// n-gram and bumps the cached unique counter. This avoids a second trie
    /// lookup on the hot path; it is exact as long as no count is driven back
    /// to zero (for example through a negative `increment_by`).
    ///
    /// # Errors
    ///
    /// Returns [`AccumulatorError::Dictionary`] when the trie update fails.
    pub fn increment(&mut self, ngram: &str) -> AccumulatorResult<i64> {
        let result = self.trie.increment(ngram, 1).map_err(|e| {
            AccumulatorError::Dictionary(format!("Failed to increment n-gram: {}", e))
        })?;
        if result == 1 {
            self.unique_count
                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        }
        Ok(result)
    }

    /// Adds `delta` occurrences of `ngram` and returns its updated count.
    ///
    /// Unlike the single-step `increment`, the resulting count cannot tell us
    /// whether the key is new (any `delta` is allowed), so existence is checked
    /// explicitly. Previously this method never touched the cached unique
    /// counter, which left `len()` / `is_empty()` stale for n-grams first
    /// inserted through it.
    ///
    /// # Errors
    ///
    /// Returns [`AccumulatorError::Dictionary`] when the trie update fails; the
    /// cached counter is left untouched in that case.
    pub fn increment_by(&mut self, ngram: &str, delta: i64) -> AccumulatorResult<i64> {
        let existed = self.trie.contains(ngram);
        let total = self.trie.increment(ngram, delta).map_err(|e| {
            AccumulatorError::Dictionary(format!("Failed to increment n-gram: {}", e))
        })?;
        // Only count the key when it was absent before and is present now. The
        // post-check costs one extra lookup on the new-key path only and keeps
        // the counter right even if some delta leaves the key uninserted.
        if !existed && self.trie.contains(ngram) {
            self.unique_count
                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        }
        Ok(total)
    }

    /// Looks up the stored count for `ngram`, or `None` when the trie has no
    /// value for it.
    pub fn get(&self, ngram: &str) -> Option<i64> {
        self.trie.get(ngram)
    }

    /// Reports whether the trie contains `ngram`.
    pub fn contains(&self, ngram: &str) -> bool {
        self.trie.contains(ngram)
    }

    /// Flushes pending state to durable storage.
    ///
    /// NOTE(review): this currently issues the same trie `checkpoint()` call
    /// as [`NgramAccumulator::checkpoint`]; only the error message differs.
    /// Confirm whether a lighter WAL-only sync was intended.
    ///
    /// # Errors
    ///
    /// Returns [`AccumulatorError::Dictionary`] when the trie call fails.
    pub fn sync(&mut self) -> AccumulatorResult<()> {
        self.trie
            .checkpoint()
            .map_err(|e| AccumulatorError::Dictionary(format!("Failed to sync WAL: {}", e)))
    }

    /// Asks the trie to write a checkpoint.
    ///
    /// # Errors
    ///
    /// Returns [`AccumulatorError::Dictionary`] when the checkpoint fails (the
    /// `Checkpoint` variant is not used here, so existing callers matching on
    /// `Dictionary` keep working).
    pub fn checkpoint(&mut self) -> AccumulatorResult<()> {
        self.trie
            .checkpoint()
            .map_err(|e| AccumulatorError::Dictionary(format!("Failed to checkpoint: {}", e)))
    }

    /// Iterates over every `(ngram, count)` pair by querying the empty prefix.
    ///
    /// Best effort: if the trie query fails or yields nothing, the error is
    /// discarded and the iterator is simply empty.
    pub fn iter_with_counts(&self) -> impl Iterator<Item = (String, i64)> {
        self.trie
            .iter_prefix_with_values("")
            .ok()
            .flatten()
            .unwrap_or_default()
            .into_iter()
    }

    /// Iterates over every stored n-gram key by querying the empty prefix.
    ///
    /// Best effort: if the trie query fails or yields nothing, the error is
    /// discarded and the iterator is simply empty.
    pub fn iter(&self) -> impl Iterator<Item = String> {
        self.trie
            .iter_prefix("")
            .ok()
            .flatten()
            .unwrap_or_default()
            .into_iter()
    }

    /// Returns the cached number of distinct n-grams.
    ///
    /// This reads the in-memory counter rather than the trie; use
    /// [`NgramAccumulator::exact_len`] to ask the trie directly.
    pub fn len(&self) -> usize {
        self.unique_count.load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Returns `true` when the cached distinct-n-gram count is zero.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the number of entries as reported by the trie itself.
    pub fn exact_len(&self) -> usize {
        self.trie.len()
    }
}
/// Helpers for turning a sequence of tokens into a flat n-gram key and back.
///
/// A key is the tokens written one after another with [`SEPARATOR`] between
/// neighbours. Tokens are not escaped, so a token that itself contains the
/// separator will be split apart again by [`parse_key`].
pub mod key_format {
    /// Character written between consecutive tokens of a key.
    pub const SEPARATOR: char = '|';

    /// Concatenates `parts`, placing [`SEPARATOR`] between neighbours.
    fn join_with_separator<S: AsRef<str>>(parts: &[S]) -> String {
        let mut key = String::new();
        for (position, part) in parts.iter().enumerate() {
            if position != 0 {
                key.push(SEPARATOR);
            }
            key.push_str(part.as_ref());
        }
        key
    }

    /// Builds a key from borrowed tokens. An empty slice yields an empty key.
    pub fn build_key(tokens: &[&str]) -> String {
        join_with_separator(tokens)
    }

    /// Builds a key from owned tokens. An empty slice yields an empty key.
    pub fn build_key_owned(tokens: &[String]) -> String {
        join_with_separator(tokens)
    }

    /// Splits `key` at every [`SEPARATOR`], borrowing the pieces from `key`.
    ///
    /// An empty key produces a single empty token.
    pub fn parse_key(key: &str) -> Vec<&str> {
        let mut pieces = Vec::new();
        pieces.extend(key.split(SEPARATOR));
        pieces
    }

    /// Returns the n-gram order of `key`: the number of separator-delimited
    /// pieces, i.e. one more than the number of separators (so an empty key
    /// has order 1).
    pub fn order(key: &str) -> usize {
        key.matches(SEPARATOR).count() + 1
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Makes a scratch directory and a trie path inside it.
    ///
    /// The `TempDir` guard is handed back so each test keeps the directory
    /// alive for as long as it uses the path.
    fn scratch_trie_path() -> (TempDir, std::path::PathBuf) {
        let dir = TempDir::new().expect("Failed to create temp dir");
        let path = dir.path().join("test.artrie");
        (dir, path)
    }

    /// Counts grow per key and `len()` reports distinct keys.
    #[test]
    fn test_create_and_increment() {
        let (_dir, path) = scratch_trie_path();
        let mut acc = NgramAccumulator::create(&path).expect("Failed to create accumulator");

        // (n-gram to increment, count expected back), applied in order.
        let steps = [("the|quick", 1), ("the|quick", 2), ("quick|brown", 1)];
        for &(ngram, expected) in steps.iter() {
            let count = acc.increment(ngram).expect("Failed to increment");
            assert_eq!(count, expected);
        }

        assert_eq!(acc.len(), 2);
    }

    /// Counts written and synced by one instance are visible after reopening.
    #[test]
    fn test_persistence_and_recovery() {
        let (_dir, path) = scratch_trie_path();

        let mut writer = NgramAccumulator::create(&path).expect("Failed to create accumulator");
        for ngram in ["the|quick", "the|quick", "quick|brown"].iter() {
            writer.increment(ngram).expect("Failed to increment");
        }
        writer.sync().expect("Failed to sync");
        // Release the first instance before the file is opened again.
        drop(writer);

        let reader = NgramAccumulator::open(&path).expect("Failed to open accumulator");
        assert_eq!(reader.get("the|quick"), Some(2));
        assert_eq!(reader.get("quick|brown"), Some(1));
        assert_eq!(reader.get("nonexistent"), None);
    }

    /// `iter_with_counts` yields every key with its accumulated count.
    #[test]
    fn test_iteration() {
        let (_dir, path) = scratch_trie_path();
        let mut acc = NgramAccumulator::create(&path).expect("Failed to create accumulator");
        for ngram in ["a|b", "a|b", "c|d"].iter() {
            acc.increment(ngram).expect("Failed to increment");
        }

        // Sort by key so the assertion does not depend on iteration order.
        let mut entries: Vec<_> = acc.iter_with_counts().collect();
        entries.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0));

        assert_eq!(
            entries,
            vec![("a|b".to_string(), 2), ("c|d".to_string(), 1)]
        );
    }

    /// Keys join with `|`, split back into tokens, and report their order.
    #[test]
    fn test_key_format() {
        let joined = key_format::build_key(&["the", "quick", "brown"]);
        assert_eq!(joined, "the|quick|brown");

        let tokens = key_format::parse_key("the|quick|brown");
        assert_eq!(tokens, vec!["the", "quick", "brown"]);

        assert_eq!(key_format::order("the|quick|brown"), 3);
        assert_eq!(key_format::order("unigram"), 1);
    }
}