use std::path::PathBuf;
use std::sync::atomic::Ordering;
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize};
use std::sync::Arc;
use parking_lot::RwLock;
use crate::persistent_artrie::block_storage::BlockStorage;
use crate::persistent_artrie::buffer_manager::BufferManager;
use crate::persistent_artrie::dict_impl::DurabilityPolicy;
use crate::persistent_artrie::disk_manager::MmapDiskManager;
use crate::persistent_artrie::wal::AsyncWalWriter;
use crate::persistent_artrie::wal::WalConfig;
use crate::persistent_artrie::wal_managed::WalManaged;
use crate::persistent_artrie_char::arena_manager::ArenaManager;
use crate::persistent_artrie_char::nodes::AtomicNodePtr;
use dashmap::DashMap;
/// Default buffer pool size for vocabulary tries.
///
/// NOTE(review): not referenced in the part of the file visible here, and the
/// unit (pages, frames, ...) is not shown — confirm at the use site.
const DEFAULT_VOCAB_BUFFER_POOL_SIZE: usize = 64;
pub use super::sync_handle::VocabSyncHandle;
/// Persistent vocabulary trie mapping terms to `u64` indices, generic over the
/// block storage backend `S` (defaults to [`MmapDiskManager`]).
///
/// The struct bundles counters, WAL bookkeeping, optional storage managers and
/// optional lock-free structures. `Clone` copies only the counters, WAL handle
/// and configuration; the storage managers and lock-free structures are reset
/// to `None` in the clone.
pub struct PersistentVocabARTrie<S: BlockStorage = MmapDiskManager> {
/// Filesystem path associated with this vocabulary.
/// NOTE(review): presumably the path passed to `create`/`open` — confirm.
pub(super) path: PathBuf,
/// Number of entries; reported as `len` in the `Debug` output.
pub(super) entry_count: AtomicUsize,
/// Base index of the vocabulary. The tests in this file create a vocab with a
/// start index of 100 and observe 100 as the index of the first inserted term.
pub(super) start_index: u64,
/// NOTE(review): presumably the index the next new term will receive — confirm.
pub(super) next_index: AtomicU64,
/// Reported as `is_dirty` in the `Debug` output.
/// NOTE(review): presumably set while there are un-checkpointed changes — confirm.
pub(super) dirty: AtomicBool,
/// Optional write-ahead-log writer, exposed through `WalManaged::wal_writer`.
/// Shared (via `Arc`) with clones of this struct.
pub(super) wal_writer: Option<Arc<AsyncWalWriter>>,
/// WAL configuration; copied by `Clone`.
pub(super) wal_config: WalConfig,
/// NOTE(review): presumably the next log sequence number to hand out — confirm.
pub(super) next_lsn: AtomicU64,
/// NOTE(review): presumably the highest LSN known to be synced — confirm.
pub(super) synced_lsn: AtomicU64,
/// Durability policy currently in effect; copied by value in `Clone`.
pub(super) durability_policy: DurabilityPolicy,
/// Optional arena manager over the storage backend. Not carried over by `Clone`.
pub(super) arena_manager: Option<Arc<RwLock<ArenaManager<S>>>>,
/// Optional buffer manager over the storage backend. Not carried over by `Clone`.
pub(super) buffer_manager: Option<Arc<RwLock<BufferManager<S>>>>,
/// Optional eviction coordinator. Not carried over by `Clone`.
pub(crate) eviction_coordinator:
Option<Arc<crate::persistent_artrie::eviction::EvictionCoordinator>>,
/// Optional root pointer of the lock-free trie. Not carried over by `Clone`.
pub(super) lockfree_root: Option<AtomicNodePtr<u64>>,
/// Optional concurrent `String -> u64` map. Not carried over by `Clone`.
/// NOTE(review): presumably a term -> index cache — confirm.
pub(super) lockfree_cache: Option<DashMap<String, u64>>,
/// Counter that `Clone` resets to 0.
/// NOTE(review): presumably counts compare-and-swap retries — confirm.
pub(super) cas_retries: AtomicU64,
/// Commit sequence counter; its current value is copied by `Clone`.
pub(crate) commit_seq: AtomicU64,
/// Committed watermark; `Clone` creates a fresh one with `new(0)`.
pub(crate) committed_watermark:
crate::persistent_artrie_core::committed_watermark::CommittedWatermark,
/// Epoch manager; `Clone` creates a fresh, unshared instance.
pub(crate) epoch_manager: Arc<crate::persistent_artrie_core::concurrency::EpochManager>,
/// Optional concurrent `u64 -> String` map. Not carried over by `Clone`.
/// NOTE(review): presumably the index -> term reverse lookup — confirm.
pub(super) reverse_term_map: Option<DashMap<u64, String>>,
}
/// Exposes the trie's optional WAL writer to code that works through the
/// `WalManaged` trait.
impl<S: BlockStorage> WalManaged for PersistentVocabARTrie<S> {
    /// Borrows the WAL writer handle, or returns `None` when no writer is attached.
    fn wal_writer(&self) -> Option<&Arc<AsyncWalWriter>> {
        Option::as_ref(&self.wal_writer)
    }
}
// SAFETY: NOTE(review): the soundness argument is not visible in this file.
// Presumably these manual impls exist because a field (e.g. `AtomicNodePtr`)
// is not automatically `Send`/`Sync` — confirm. Also note that the impls are
// unconditional in `S`: unless `BlockStorage` itself requires `Send + Sync`,
// a non-thread-safe storage backend would be shared across threads through
// `arena_manager` / `buffer_manager` — verify against the trait definition.
unsafe impl<S: BlockStorage> Send for PersistentVocabARTrie<S> {}
unsafe impl<S: BlockStorage> Sync for PersistentVocabARTrie<S> {}
/// A [`PersistentVocabARTrie`] wrapped in `Arc<RwLock<..>>` (the `RwLock`
/// imported from `parking_lot`) so one instance can be shared between owners.
pub type SharedVocabARTrie<S = MmapDiskManager> = Arc<RwLock<PersistentVocabARTrie<S>>>;
/// Runs a checkpoint when the trie is dropped.
impl<S: BlockStorage> Drop for PersistentVocabARTrie<S> {
fn drop(&mut self) {
// `drop` cannot report failure, so the checkpoint result is deliberately
// discarded. Tests that want to simulate a crash skip this path with
// `std::mem::forget`.
// NOTE(review): a value produced by `Clone` also runs this on drop, with
// its storage managers set to `None` — confirm `checkpoint` is safe there.
let _ = self.checkpoint();
}
}
/// Produces a detached copy of the trie's bookkeeping state.
///
/// Counters are snapshotted with `Acquire` loads, and the WAL writer handle,
/// WAL configuration and durability policy are copied. The storage managers,
/// eviction coordinator, lock-free root/cache and reverse term map are NOT
/// copied (the clone gets `None`), `cas_retries` restarts at 0, and the clone
/// receives a fresh committed watermark and epoch manager.
impl<S: BlockStorage> Clone for PersistentVocabARTrie<S> {
    fn clone(&self) -> Self {
        use crate::persistent_artrie_core::committed_watermark::CommittedWatermark;
        use crate::persistent_artrie_core::concurrency::EpochManager;

        // Snapshot every atomic counter up front.
        let entry_count = self.entry_count.load(Ordering::Acquire);
        let next_index = self.next_index.load(Ordering::Acquire);
        let dirty = self.dirty.load(Ordering::Acquire);
        let next_lsn = self.next_lsn.load(Ordering::Acquire);
        let synced_lsn = self.synced_lsn.load(Ordering::Acquire);
        let commit_seq = self.commit_seq.load(Ordering::Acquire);

        Self {
            path: self.path.clone(),
            entry_count: AtomicUsize::new(entry_count),
            start_index: self.start_index,
            next_index: AtomicU64::new(next_index),
            dirty: AtomicBool::new(dirty),
            wal_writer: self.wal_writer.clone(),
            wal_config: self.wal_config.clone(),
            next_lsn: AtomicU64::new(next_lsn),
            synced_lsn: AtomicU64::new(synced_lsn),
            durability_policy: self.durability_policy,
            // Storage-backed and lock-free state is intentionally left out.
            arena_manager: None,
            buffer_manager: None,
            eviction_coordinator: None,
            lockfree_root: None,
            lockfree_cache: None,
            cas_retries: AtomicU64::new(0),
            commit_seq: AtomicU64::new(commit_seq),
            committed_watermark: CommittedWatermark::new(0),
            epoch_manager: Arc::new(EpochManager::new()),
            reverse_term_map: None,
        }
    }
}
/// Debug output listing the path, counters, LSNs and durability policy, plus
/// booleans telling whether the arena and buffer managers are attached.
impl<S: BlockStorage> std::fmt::Debug for PersistentVocabARTrie<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The managers themselves are not printed, only their presence.
        let has_arena_manager = self.arena_manager.is_some();
        let has_buffer_manager = self.buffer_manager.is_some();

        let mut out = f.debug_struct("PersistentVocabARTrie");
        out.field("path", &self.path);
        out.field("len", &self.entry_count);
        out.field("start_index", &self.start_index);
        out.field("next_index", &self.next_index);
        out.field("is_dirty", &self.dirty);
        out.field("next_lsn", &self.next_lsn);
        out.field("synced_lsn", &self.synced_lsn);
        out.field("durability_policy", &self.durability_policy);
        out.field("has_arena_manager", &has_arena_manager);
        out.field("has_buffer_manager", &has_buffer_manager);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
use tempfile::tempdir;
/// Two distinct terms receive indices 0 and 1; re-inserting an existing term
/// returns its original index and does not grow the vocabulary.
#[test]
fn test_create_and_insert() {
    let tmp = tempdir().unwrap();
    let vocab_path = tmp.path().join("test.vocab");
    let mut vocab = PersistentVocabARTrie::create(&vocab_path).unwrap();

    let hello_idx = vocab.insert("hello").expect("insert hello");
    let world_idx = vocab.insert("world").expect("insert world");
    let hello_again_idx = vocab.insert("hello").expect("insert duplicate hello");

    assert_eq!(hello_idx, 0);
    assert_eq!(world_idx, 1);
    assert_eq!(hello_again_idx, 0);
    assert_eq!(vocab.len(), 2);
}
/// Term -> index lookup returns the insertion-order index for known terms and
/// `None` for a term that was never inserted.
#[test]
fn test_forward_lookup() {
    let dir = tempdir().unwrap();
    let path = dir.path().join("test.vocab");
    let mut vocab = PersistentVocabARTrie::create(&path).unwrap();
    // Fail at the insert itself rather than at a later lookup if setup breaks.
    vocab.insert("apple").expect("insert apple");
    vocab.insert("banana").expect("insert banana");
    vocab.insert("cherry").expect("insert cherry");
    assert_eq!(vocab.get_index("apple"), Some(0));
    assert_eq!(vocab.get_index("banana"), Some(1));
    assert_eq!(vocab.get_index("cherry"), Some(2));
    assert_eq!(vocab.get_index("durian"), None);
}
/// Index -> term lookup returns the inserted term for assigned indices and
/// `None` for an index that was never assigned.
#[test]
fn test_reverse_lookup() {
    let dir = tempdir().unwrap();
    let path = dir.path().join("test.vocab");
    let mut vocab = PersistentVocabARTrie::create(&path).unwrap();
    // Fail at the insert itself rather than at a later lookup if setup breaks.
    vocab.insert("apple").expect("insert apple");
    vocab.insert("banana").expect("insert banana");
    vocab.insert("cherry").expect("insert cherry");
    assert_eq!(vocab.get_term(0), Some("apple".to_string()));
    assert_eq!(vocab.get_term(1), Some("banana".to_string()));
    assert_eq!(vocab.get_term(2), Some("cherry".to_string()));
    assert_eq!(vocab.get_term(999), None);
}
/// Japanese, Chinese and Korean terms round-trip through both the forward
/// (term -> index) and reverse (index -> term) lookups.
#[test]
fn test_unicode_terms() {
    let tmp = tempdir().unwrap();
    let vocab_path = tmp.path().join("test.vocab");
    let mut vocab = PersistentVocabARTrie::create(&vocab_path).unwrap();

    let japanese = vocab.insert("日本語").expect("insert Japanese term");
    let chinese = vocab.insert("中文").expect("insert Chinese term");
    let korean = vocab.insert("한글").expect("insert Korean term");
    let inserted = [("日本語", japanese), ("中文", chinese), ("한글", korean)];

    // All forward lookups first, then all reverse lookups.
    for (term, idx) in inserted {
        assert_eq!(vocab.get_index(term), Some(idx));
    }
    for (term, idx) in inserted {
        assert_eq!(vocab.get_term(idx), Some(term.to_string()));
    }
}
/// A vocabulary created with start index 100 assigns 100 and 101 to the first
/// two terms and reports 100 from `start_index()`.
#[test]
fn test_custom_start_index() {
    let tmp = tempdir().unwrap();
    let vocab_path = tmp.path().join("test.vocab");
    let mut vocab = PersistentVocabARTrie::create_with_start_index(&vocab_path, 100).unwrap();

    let first_idx = vocab.insert("first").expect("insert first");
    let second_idx = vocab.insert("second").expect("insert second");

    assert_eq!(first_idx, 100);
    assert_eq!(second_idx, 101);
    assert_eq!(vocab.start_index(), 100);
}
/// Terms checkpointed to disk are visible with the same indices after the
/// vocabulary is reopened, and the recovery report indicates a normal open.
#[test]
fn test_checkpoint_and_reopen() {
    let dir = tempdir().unwrap();
    let path = dir.path().join("test.vocab");
    {
        let mut vocab = PersistentVocabARTrie::create(&path).unwrap();
        // Fail at the insert itself rather than after reopen if setup breaks.
        vocab.insert("hello").expect("insert hello");
        vocab.insert("world").expect("insert world");
        vocab.insert("test").expect("insert test");
        vocab.checkpoint().unwrap();
    }
    {
        let (vocab, report) = PersistentVocabARTrie::open_with_recovery(&path).unwrap();
        assert!(report.mode.is_normal());
        assert_eq!(vocab.len(), 3);
        assert_eq!(vocab.get_index("hello"), Some(0));
        assert_eq!(vocab.get_index("world"), Some(1));
        assert_eq!(vocab.get_index("test"), Some(2));
    }
}
/// Checkpoint, reopen, add more terms, checkpoint again: the final reopen sees
/// all four terms with contiguous indices.
#[test]
fn test_checkpoint_reopen_modify_checkpoint() {
    let dir = tempdir().unwrap();
    let path = dir.path().join("test.vocab");
    {
        let mut vocab = PersistentVocabARTrie::create(&path).unwrap();
        vocab.insert("apple").expect("insert apple");
        vocab.insert("banana").expect("insert banana");
        vocab.checkpoint().unwrap();
    }
    {
        let (mut vocab, _) = PersistentVocabARTrie::open_with_recovery(&path).unwrap();
        assert_eq!(vocab.len(), 2);
        vocab.insert("cherry").expect("insert cherry");
        vocab.insert("durian").expect("insert durian");
        vocab.checkpoint().unwrap();
    }
    {
        let (vocab, _) = PersistentVocabARTrie::open_with_recovery(&path).unwrap();
        assert_eq!(vocab.len(), 4);
        assert_eq!(vocab.get_index("apple"), Some(0));
        assert_eq!(vocab.get_index("banana"), Some(1));
        assert_eq!(vocab.get_index("cherry"), Some(2));
        assert_eq!(vocab.get_index("durian"), Some(3));
    }
}
/// `contains` and `contains_index` report membership by term and by index.
#[test]
fn test_contains() {
    let dir = tempdir().unwrap();
    let path = dir.path().join("test.vocab");
    let mut vocab = PersistentVocabARTrie::create(&path).unwrap();
    vocab.insert("present").expect("insert present");
    assert!(vocab.contains("present"));
    assert!(!vocab.contains("absent"));
    assert!(vocab.contains_index(0));
    assert!(!vocab.contains_index(1));
}
/// The current LSN starts above zero and advances on insert; `synced_lsn()` is
/// `None` until the first `sync()`.
#[test]
fn test_lsn_tracking() {
    let dir = tempdir().unwrap();
    let path = dir.path().join("test.vocab");
    let mut vocab = PersistentVocabARTrie::create(&path).unwrap();
    let initial_lsn = vocab.current_lsn();
    assert!(initial_lsn > 0);
    assert!(vocab.synced_lsn().is_none());
    // A failed insert would otherwise surface only as a confusing LSN failure.
    vocab.insert("test").expect("insert test");
    assert!(vocab.current_lsn() > initial_lsn);
    vocab.sync().unwrap();
    assert!(vocab.synced_lsn().is_some());
}
/// A new vocabulary reports `DurabilityPolicy::Immediate`, and
/// `set_durability_policy` changes what `durability_policy()` returns.
#[test]
fn test_durability_policy() {
    let tmp = tempdir().unwrap();
    let vocab_path = tmp.path().join("test.vocab");
    let mut vocab = PersistentVocabARTrie::create(&vocab_path).unwrap();

    let initial_policy = vocab.durability_policy();
    assert_eq!(initial_policy, DurabilityPolicy::Immediate);

    vocab.set_durability_policy(DurabilityPolicy::Periodic);
    let updated_policy = vocab.durability_policy();
    assert_eq!(updated_policy, DurabilityPolicy::Periodic);
}
/// Terms inserted without a checkpoint are restored by replaying the WAL on
/// reopen.
#[test]
fn test_wal_recovery() {
    let dir = tempdir().unwrap();
    let path = dir.path().join("test.vocab");
    {
        let mut vocab = PersistentVocabARTrie::create(&path).unwrap();
        vocab.insert("term1").expect("insert term1");
        vocab.insert("term2").expect("insert term2");
        vocab.insert("term3").expect("insert term3");
        // Skip `Drop` (which checkpoints) to simulate a crash.
        std::mem::forget(vocab);
    }
    let (vocab, report) = PersistentVocabARTrie::open_with_recovery(&path).unwrap();
    assert!(report.records_replayed > 0);
    assert_eq!(vocab.len(), 3);
    assert_eq!(vocab.get_index("term1"), Some(0));
    assert_eq!(vocab.get_index("term2"), Some(1));
    assert_eq!(vocab.get_index("term3"), Some(2));
}
/// Recovery combines checkpointed terms with terms that exist only in the WAL:
/// two terms are checkpointed, two more are inserted and then "crashed".
#[test]
fn test_partial_wal_recovery() {
    let dir = tempdir().unwrap();
    let path = dir.path().join("test.vocab");
    {
        let mut vocab = PersistentVocabARTrie::create(&path).unwrap();
        vocab.insert("apple").expect("insert apple");
        vocab.insert("banana").expect("insert banana");
        vocab.checkpoint().unwrap();
    }
    {
        let (mut vocab, _) = PersistentVocabARTrie::open_with_recovery(&path).unwrap();
        vocab.insert("cherry").expect("insert cherry");
        vocab.insert("durian").expect("insert durian");
        // Skip `Drop` (which checkpoints) to simulate a crash.
        std::mem::forget(vocab);
    }
    let (vocab, report) = PersistentVocabARTrie::open_with_recovery(&path).unwrap();
    assert!(report.records_replayed >= 2);
    assert_eq!(vocab.len(), 4);
    assert_eq!(vocab.get_index("apple"), Some(0));
    assert_eq!(vocab.get_index("banana"), Some(1));
    assert_eq!(vocab.get_index("cherry"), Some(2));
    assert_eq!(vocab.get_index("durian"), Some(3));
}
/// Regression test: eight single-character sibling terms survive a checkpoint
/// and reopen, with forward and reverse lookups intact.
#[test]
fn test_regression_node_growth_during_load() {
    let tmp = tempdir().unwrap();
    let vocab_path = tmp.path().join("test.vocab");
    let terms = ["a", "b", "c", "d", "e", "f", "g", "h"];

    // Phase 1: populate and checkpoint.
    {
        let mut vocab = PersistentVocabARTrie::create(&vocab_path).unwrap();
        for (term, expected) in terms.iter().zip(0u64..) {
            let idx = vocab.insert(term).expect("insert term");
            assert_eq!(
                idx, expected,
                "Term '{}' should have index {}",
                term, expected
            );
        }
        vocab.checkpoint().unwrap();
    }

    // Phase 2: reopen and verify both lookup directions.
    {
        let (vocab, report) = PersistentVocabARTrie::open_with_recovery(&vocab_path).unwrap();
        assert!(
            report.mode.is_normal(),
            "Should load from disk without WAL replay"
        );
        assert_eq!(vocab.len(), terms.len());
        for (term, expected) in terms.iter().zip(0u64..) {
            assert_eq!(
                vocab.get_index(term),
                Some(expected),
                "Forward lookup failed for term '{}'",
                term
            );
        }
        for (term, index) in terms.iter().zip(0u64..) {
            assert_eq!(
                vocab.get_term(index),
                Some(term.to_string()),
                "Reverse lookup failed for index {} (expected '{}')",
                index,
                term
            );
        }
    }
}
/// Regression test: ten two-character terms sharing the prefix `a` keep their
/// indices across a checkpoint and reopen.
#[test]
fn test_regression_node_growth_during_serialization() {
    let tmp = tempdir().unwrap();
    let vocab_path = tmp.path().join("test.vocab");
    let prefixes = ["aa", "ab", "ac", "ad", "ae", "af", "ag", "ah", "ai", "aj"];

    {
        let mut vocab = PersistentVocabARTrie::create(&vocab_path).unwrap();
        for (term, expected) in prefixes.iter().zip(0u64..) {
            let idx = vocab.insert(term).expect("insert term");
            assert_eq!(idx, expected);
        }
        vocab.checkpoint().unwrap();
    }

    {
        let (vocab, _) = PersistentVocabARTrie::open_with_recovery(&vocab_path).unwrap();
        assert_eq!(vocab.len(), prefixes.len());
        for (term, expected) in prefixes.iter().zip(0u64..) {
            assert_eq!(
                vocab.get_index(term),
                Some(expected),
                "Term '{}' not found after serialization",
                term
            );
        }
    }
}
/// Regression test over a 50-term trie: checkpoint, reopen and verify, then
/// reopen again to add 25 more terms, checkpoint, and verify all 75.
#[test]
fn test_regression_large_trie_checkpoint_reopen() {
    let tmp = tempdir().unwrap();
    let vocab_path = tmp.path().join("test.vocab");
    let terms: Vec<String> = (0..50).map(|i| format!("term_{:03}", i)).collect();

    // Phase 1: insert the first 50 terms and checkpoint.
    {
        let mut vocab = PersistentVocabARTrie::create(&vocab_path).unwrap();
        for (term, expected) in terms.iter().zip(0u64..) {
            let idx = vocab.insert(term).expect("insert term");
            assert_eq!(idx, expected);
        }
        vocab.checkpoint().unwrap();
    }

    // Phase 2: reopen and check both lookup directions.
    {
        let (vocab, report) = PersistentVocabARTrie::open_with_recovery(&vocab_path).unwrap();
        assert!(report.mode.is_normal());
        assert_eq!(vocab.len(), terms.len());
        for (i, term) in terms.iter().enumerate() {
            let index = i as u64;
            assert_eq!(
                vocab.get_index(term),
                Some(index),
                "Forward lookup failed for '{}'",
                term
            );
            assert_eq!(
                vocab.get_term(index),
                Some(term.clone()),
                "Reverse lookup failed for index {}",
                i
            );
        }
    }

    // Phase 3: reopen, append terms 50..75, checkpoint again.
    {
        let (mut vocab, _) = PersistentVocabARTrie::open_with_recovery(&vocab_path).unwrap();
        let more_terms: Vec<String> = (50..75).map(|i| format!("term_{:03}", i)).collect();
        for (term, expected) in more_terms.iter().zip(50u64..) {
            let idx = vocab.insert(term).expect("insert term");
            assert_eq!(idx, expected);
        }
        vocab.checkpoint().unwrap();
    }

    // Phase 4: the final reopen must see all 75 terms.
    {
        let (vocab, _) = PersistentVocabARTrie::open_with_recovery(&vocab_path).unwrap();
        assert_eq!(vocab.len(), 75);
        for i in 0..75 {
            let expected_term = format!("term_{:03}", i);
            let index = i as u64;
            assert_eq!(vocab.get_index(&expected_term), Some(index));
            assert_eq!(vocab.get_term(index), Some(expected_term));
        }
    }
}
/// Regression test: the index stored for each term is preserved across a
/// checkpoint and reopen, in both lookup directions.
#[test]
fn test_regression_value_preservation() {
    let dir = tempdir().unwrap();
    let path = dir.path().join("test.vocab");
    {
        let mut vocab = PersistentVocabARTrie::create(&path).unwrap();
        vocab.insert("with_value_1").expect("insert with_value_1");
        vocab.insert("with_value_2").expect("insert with_value_2");
        vocab.insert("with_value_3").expect("insert with_value_3");
        assert_eq!(vocab.get_index("with_value_1"), Some(0));
        assert_eq!(vocab.get_index("with_value_2"), Some(1));
        assert_eq!(vocab.get_index("with_value_3"), Some(2));
        vocab.checkpoint().unwrap();
    }
    {
        let (vocab, _) = PersistentVocabARTrie::open_with_recovery(&path).unwrap();
        assert_eq!(
            vocab.get_index("with_value_1"),
            Some(0),
            "Value for 'with_value_1' was lost"
        );
        assert_eq!(
            vocab.get_index("with_value_2"),
            Some(1),
            "Value for 'with_value_2' was lost"
        );
        assert_eq!(
            vocab.get_index("with_value_3"),
            Some(2),
            "Value for 'with_value_3' was lost"
        );
        assert_eq!(vocab.get_term(0), Some("with_value_1".to_string()));
        assert_eq!(vocab.get_term(1), Some("with_value_2".to_string()));
        assert_eq!(vocab.get_term(2), Some("with_value_3".to_string()));
    }
}
/// While an async disk sync is in flight, reads and further writes through the
/// shared lock still work; both terms are present after the sync completes.
#[test]
fn test_sync_to_disk_async_non_blocking() {
    let dir = tempdir().expect("Failed to create temp dir");
    let path = dir.path().join("vocab.vocab");
    let vocab = Arc::new(RwLock::new(
        PersistentVocabARTrie::create(&path).expect("Failed to create vocab"),
    ));
    vocab.write().insert("hello").expect("insert hello");
    let handle = vocab
        .read()
        .sync_to_disk_async()
        .expect("Failed to start async sync");
    assert!(vocab.read().contains("hello"));
    // Insert before waiting on the handle, i.e. while the sync may be running.
    vocab.write().insert("world").expect("insert world");
    handle.wait().expect("Sync failed");
    assert!(vocab.read().contains("hello"));
    assert!(vocab.read().contains("world"));
}
/// Two overlapping async syncs can be started back to back and both handles
/// report synced after being waited on.
#[test]
fn test_sync_to_disk_async_multiple_calls() {
    let dir = tempdir().expect("Failed to create temp dir");
    let path = dir.path().join("vocab.vocab");
    let mut vocab = PersistentVocabARTrie::create(&path).expect("Failed to create vocab");
    vocab.insert("hello").expect("insert hello");
    let handle1 = vocab
        .sync_to_disk_async()
        .expect("Failed to start first async sync");
    vocab.insert("world").expect("insert world");
    let handle2 = vocab
        .sync_to_disk_async()
        .expect("Failed to start second async sync");
    handle1.wait().expect("First sync failed");
    handle2.wait().expect("Second sync failed");
    assert!(handle1.is_synced());
    assert!(handle2.is_synced());
}
/// Syncing twice with no new data in between must not grow the file.
#[test]
fn test_sync_to_disk_no_fragmentation() {
    let dir = tempdir().expect("Failed to create temp dir");
    let path = dir.path().join("vocab.vocab");
    {
        let mut vocab = PersistentVocabARTrie::create(&path).expect("Failed to create vocab");
        for i in 0..100 {
            vocab.insert(&format!("word{}", i)).expect("insert word");
        }

        vocab.sync_to_disk().expect("First sync failed");
        let size_after_first = std::fs::metadata(&path)
            .expect("Failed to get metadata")
            .len();

        vocab.sync_to_disk().expect("Second sync failed");
        let size_after_second = std::fs::metadata(&path)
            .expect("Failed to get metadata")
            .len();

        assert_eq!(
            size_after_first, size_after_second,
            "File grew without new data (fragmentation detected)"
        );
    }
}
/// Interleaved inserts and disk syncs followed by a checkpoint leave a file
/// that reopens in normal mode with both terms present.
#[test]
fn test_sync_to_disk_then_checkpoint() {
    let dir = tempdir().expect("Failed to create temp dir");
    let path = dir.path().join("vocab.vocab");
    let mut vocab = PersistentVocabARTrie::create(&path).expect("Failed to create vocab");
    vocab.insert("hello").expect("insert hello");
    vocab.sync_to_disk().expect("First sync failed");
    vocab.insert("world").expect("insert world");
    vocab.sync_to_disk().expect("Second sync failed");
    assert!(vocab.contains("hello"), "Missing 'hello' after sync");
    assert!(vocab.contains("world"), "Missing 'world' after sync");
    assert_eq!(vocab.len(), 2);
    vocab.checkpoint().expect("Checkpoint failed");
    drop(vocab);
    let (vocab, report) =
        PersistentVocabARTrie::open_with_recovery(&path).expect("Failed to open vocab");
    assert!(
        report.mode.is_normal(),
        "Should not need WAL replay after checkpoint"
    );
    assert!(vocab.contains("hello"), "Missing 'hello' after reopen");
    assert!(vocab.contains("world"), "Missing 'world' after reopen");
}
/// After a WAL sync and a simulated crash, reopening replays the WAL and
/// restores both terms.
#[test]
fn test_sync_to_disk_crash_recovery_via_wal() {
    let dir = tempdir().expect("Failed to create temp dir");
    let path = dir.path().join("vocab.vocab");
    {
        let mut vocab = PersistentVocabARTrie::create(&path).expect("Failed to create vocab");
        vocab.insert("hello").expect("insert hello");
        vocab.insert("world").expect("insert world");
        vocab.sync().expect("WAL sync failed");
        // Skip `Drop` (which checkpoints) to simulate a crash.
        std::mem::forget(vocab);
    }
    {
        let (vocab, report) =
            PersistentVocabARTrie::open_with_recovery(&path).expect("Failed to open vocab");
        assert!(report.records_replayed > 0, "Expected WAL replay");
        assert!(
            vocab.contains("hello"),
            "Missing 'hello' after WAL recovery"
        );
        assert!(
            vocab.contains("world"),
            "Missing 'world' after WAL recovery"
        );
    }
}
/// Starts an async sync, then runs a reader thread and a writer thread against
/// the shared vocab; afterwards every initial and concurrently inserted term
/// must be present.
#[test]
fn test_sync_to_disk_concurrent_reads_writes() {
    use std::thread;
    let dir = tempdir().expect("Failed to create temp dir");
    let path = dir.path().join("vocab.vocab");
    let vocab = Arc::new(RwLock::new(
        PersistentVocabARTrie::create(&path).expect("Failed to create vocab"),
    ));
    for i in 0..50 {
        vocab
            .write()
            .insert(&format!("initial_{}", i))
            .expect("insert initial term");
    }
    let handle = vocab
        .read()
        .sync_to_disk_async()
        .expect("Failed to start async sync");
    let vocab_clone = Arc::clone(&vocab);
    let reader_handle = thread::spawn(move || {
        for i in 0..50 {
            // Result intentionally unused: the reader only exercises the read path.
            let _found = vocab_clone.read().contains(&format!("initial_{}", i));
        }
    });
    let vocab_clone2 = Arc::clone(&vocab);
    let writer_handle = thread::spawn(move || {
        for i in 50..100 {
            // A panic here is reported through `join()` below.
            vocab_clone2
                .write()
                .insert(&format!("concurrent_{}", i))
                .expect("insert concurrent term");
        }
    });
    reader_handle.join().expect("Reader thread panicked");
    writer_handle.join().expect("Writer thread panicked");
    handle.wait().expect("Sync failed");
    let vocab_guard = vocab.read();
    for i in 0..50 {
        assert!(
            vocab_guard.contains(&format!("initial_{}", i)),
            "Missing initial_{}",
            i
        );
    }
    for i in 50..100 {
        assert!(
            vocab_guard.contains(&format!("concurrent_{}", i)),
            "Missing concurrent_{}",
            i
        );
    }
}
/// `wait_timeout` on an async sync handle returns `true` within ten seconds and
/// the handle then reports synced.
#[test]
fn test_sync_to_disk_wait_timeout() {
    let dir = tempdir().expect("Failed to create temp dir");
    let path = dir.path().join("vocab.vocab");
    let mut vocab = PersistentVocabARTrie::create(&path).expect("Failed to create vocab");
    vocab.insert("test").expect("insert test");
    let handle = vocab
        .sync_to_disk_async()
        .expect("Failed to start async sync");
    let completed = handle
        .wait_timeout(Duration::from_secs(10))
        .expect("Sync failed");
    assert!(completed, "Sync should complete within timeout");
    assert!(
        handle.is_synced(),
        "Handle should report synced after wait_timeout"
    );
}
/// The empty string is a valid term: it gets index 0 and is found by
/// `contains`, `get_index` and `get_term`.
#[test]
fn test_empty_string_insert() {
    let tmp = tempdir().unwrap();
    let vocab_path = tmp.path().join("test.vocab");
    let mut vocab = PersistentVocabARTrie::create(&vocab_path).unwrap();

    let empty_idx = vocab.insert("").expect("insert empty term");

    assert_eq!(empty_idx, 0);
    assert!(vocab.contains(""));
    assert_eq!(vocab.get_index(""), Some(0));
    assert_eq!(vocab.get_term(0), Some("".to_string()));
}
/// The empty-string term keeps index 0 (forward, reverse and membership) after
/// a checkpoint and reopen. Unlike the other tests, the scratch directory is
/// created under `target/test-tmp` relative to the working directory.
#[test]
fn test_empty_string_survives_checkpoint_reopen() {
    // Best effort: a failure here surfaces in the `tempdir_in` call below.
    std::fs::create_dir_all("target/test-tmp").ok();
    let scratch = tempfile::Builder::new()
        .prefix("vocab-es-reopen")
        .tempdir_in("target/test-tmp")
        .expect("scratch under target/test-tmp");
    let vocab_path = scratch.path().join("test.vocab");

    {
        let mut writer = PersistentVocabARTrie::create(&vocab_path).unwrap();
        writer.insert("").expect("insert empty term");
        writer.insert("hello").expect("insert hello");
        writer.checkpoint().unwrap();
    }

    let (reopened, _report) = PersistentVocabARTrie::open_with_recovery(&vocab_path).unwrap();
    assert_eq!(
        reopened.get_index(""),
        Some(0),
        "\"\" -> index 0 lost after reopen"
    );
    assert_eq!(
        reopened.get_term(0),
        Some("".to_string()),
        "index 0 -> \"\" lost after reopen (reverse-index root branch)"
    );
    assert!(reopened.contains(""), "\"\" membership lost after reopen");
    assert_eq!(reopened.get_index("hello"), Some(1));
}
/// A 1000-character term can be inserted and is found by `contains`,
/// `get_index` and `get_term`.
#[test]
fn test_long_string_insert() {
    let dir = tempdir().unwrap();
    let path = dir.path().join("test.vocab");
    let mut vocab = PersistentVocabARTrie::create(&path).unwrap();
    let long_term: String = "a".repeat(1000);
    let idx = vocab.insert(&long_term).expect("insert long term");
    assert_eq!(idx, 0);
    assert!(vocab.contains(&long_term));
    assert_eq!(vocab.get_index(&long_term), Some(0));
    // Last use of `long_term`: move it instead of cloning.
    assert_eq!(vocab.get_term(0), Some(long_term));
}
/// Terms containing NUL, control characters, emoji, Greek, pointed Hebrew and
/// a BOM are stored and retrieved intact.
#[test]
fn test_special_characters() {
    let dir = tempdir().unwrap();
    let path = dir.path().join("test.vocab");
    let mut vocab = PersistentVocabARTrie::create(&path).unwrap();
    // A fixed-size array is enough here; no heap-allocated `Vec` is needed.
    let special_chars = [
        "\0",
        "\t\n\r",
        "a\0b",
        "🎉🎊🎁",
        "αβγδε",
        "מְזָלֵל",
        "\u{FEFF}BOM",
    ];
    for (i, term) in special_chars.iter().enumerate() {
        let idx = vocab.insert(term).expect("insert special term");
        assert_eq!(idx, i as u64, "Failed for term: {:?}", term);
        assert!(vocab.contains(term), "Not found: {:?}", term);
        assert_eq!(
            vocab.get_index(term),
            Some(i as u64),
            "Index mismatch: {:?}",
            term
        );
        assert_eq!(
            vocab.get_term(i as u64),
            Some(term.to_string()),
            "Reverse lookup failed: {:?}",
            term
        );
    }
}
/// Opening a path that does not exist yields an error.
#[test]
fn test_open_nonexistent_file() {
    let tmp = tempdir().unwrap();
    let missing_path = tmp.path().join("nonexistent.vocab");

    let open_result = PersistentVocabARTrie::open(&missing_path);

    assert!(open_result.is_err());
}
/// `create` succeeds for a path whose parent directories do not exist yet.
#[test]
fn test_create_nested_path() {
    let tmp = tempdir().unwrap();
    let nested_path = tmp.path().join("deeply/nested/path/test.vocab");

    let created = PersistentVocabARTrie::create(&nested_path);

    assert!(created.is_ok(), "Should create nested directories");
}
/// A mix of ASCII, Unicode, empty, punctuated and 100-character terms keeps
/// its indices in both lookup directions across a checkpoint and reopen.
#[test]
fn test_serialization_roundtrip() {
    let dir = tempdir().unwrap();
    let path = dir.path().join("test.vocab");
    {
        let mut vocab = PersistentVocabARTrie::create(&path).unwrap();
        vocab.insert("simple").expect("insert simple");
        vocab.insert("日本語").expect("insert Japanese term");
        vocab.insert("").expect("insert empty term");
        vocab
            .insert("with spaces and punctuation!")
            .expect("insert punctuated term");
        vocab.insert(&"x".repeat(100)).expect("insert repeated term");
        vocab.checkpoint().unwrap();
    }
    {
        let (vocab, _) = PersistentVocabARTrie::open_with_recovery(&path).unwrap();
        assert_eq!(vocab.len(), 5);
        assert_eq!(vocab.get_index("simple"), Some(0));
        assert_eq!(vocab.get_index("日本語"), Some(1));
        assert_eq!(vocab.get_index(""), Some(2));
        assert_eq!(vocab.get_index("with spaces and punctuation!"), Some(3));
        assert_eq!(vocab.get_index(&"x".repeat(100)), Some(4));
        assert_eq!(vocab.get_term(0), Some("simple".to_string()));
        assert_eq!(vocab.get_term(1), Some("日本語".to_string()));
        assert_eq!(vocab.get_term(2), Some("".to_string()));
        assert_eq!(
            vocab.get_term(3),
            Some("with spaces and punctuation!".to_string())
        );
        assert_eq!(vocab.get_term(4), Some("x".repeat(100)));
    }
}
/// A 1000-term vocabulary survives a checkpoint and reopen; a sample of
/// indices is checked in both lookup directions.
#[test]
fn test_large_vocabulary_serialization() {
    let dir = tempdir().unwrap();
    let path = dir.path().join("test.vocab");
    {
        let mut vocab = PersistentVocabARTrie::create(&path).unwrap();
        for i in 0..1000 {
            vocab.insert(&format!("term_{:05}", i)).expect("insert term");
        }
        vocab.checkpoint().unwrap();
    }
    {
        let (vocab, _) = PersistentVocabARTrie::open_with_recovery(&path).unwrap();
        assert_eq!(vocab.len(), 1000);
        // Spot-check the first, last and two interior entries.
        for i in [0, 100, 500, 999] {
            let term = format!("term_{:05}", i);
            assert_eq!(vocab.get_index(&term), Some(i as u64));
            assert_eq!(vocab.get_term(i as u64), Some(term));
        }
    }
}
/// `MappedDictionary::get_value` returns the term's index for a known term and
/// `None` for an unknown one.
#[test]
fn test_get_value_trait() {
    use crate::MappedDictionary;
    let dir = tempdir().unwrap();
    let path = dir.path().join("test.vocab");
    let mut vocab = PersistentVocabARTrie::create(&path).unwrap();
    vocab.insert("test").expect("insert test");
    assert_eq!(MappedDictionary::get_value(&vocab, "test"), Some(0));
    assert_eq!(MappedDictionary::get_value(&vocab, "missing"), None);
}
/// Calling `checkpoint` repeatedly with no intervening changes succeeds and
/// leaves the contents untouched.
#[test]
fn test_checkpoint_idempotent() {
    let dir = tempdir().unwrap();
    let path = dir.path().join("test.vocab");
    let mut vocab = PersistentVocabARTrie::create(&path).unwrap();
    vocab.insert("test").expect("insert test");
    vocab.checkpoint().unwrap();
    vocab.checkpoint().unwrap();
    vocab.checkpoint().unwrap();
    assert_eq!(vocab.len(), 1);
    assert!(vocab.contains("test"));
}
/// Calling `sync` repeatedly with no intervening changes succeeds and leaves
/// the contents untouched.
#[test]
fn test_sync_idempotent() {
    let dir = tempdir().unwrap();
    let path = dir.path().join("test.vocab");
    let mut vocab = PersistentVocabARTrie::create(&path).unwrap();
    vocab.insert("test").expect("insert test");
    vocab.sync().unwrap();
    vocab.sync().unwrap();
    vocab.sync().unwrap();
    assert_eq!(vocab.len(), 1);
    assert!(vocab.contains("test"));
}
/// `next_index()` advances by one for each new term and stays put when a
/// duplicate is inserted.
#[test]
fn test_next_index_tracking() {
    let dir = tempdir().unwrap();
    let path = dir.path().join("test.vocab");
    let mut vocab = PersistentVocabARTrie::create(&path).unwrap();
    assert_eq!(vocab.next_index(), 0);
    vocab.insert("first").expect("insert first");
    assert_eq!(vocab.next_index(), 1);
    vocab.insert("second").expect("insert second");
    assert_eq!(vocab.next_index(), 2);
    // Duplicate insert: must not consume a new index.
    vocab.insert("first").expect("insert duplicate first");
    assert_eq!(vocab.next_index(), 2);
}
/// A custom start index (1000) and the indices derived from it are preserved
/// across a checkpoint and reopen.
#[test]
fn test_custom_start_index_serialization() {
    let dir = tempdir().unwrap();
    let path = dir.path().join("test.vocab");
    {
        let mut vocab = PersistentVocabARTrie::create_with_start_index(&path, 1000).unwrap();
        vocab.insert("test").expect("insert test");
        assert_eq!(vocab.get_index("test"), Some(1000));
        vocab.checkpoint().unwrap();
    }
    {
        let (vocab, _) = PersistentVocabARTrie::open_with_recovery(&path).unwrap();
        assert_eq!(vocab.start_index(), 1000);
        assert_eq!(vocab.get_index("test"), Some(1000));
    }
}
/// `insert_batch` assigns sequential indices in input order and the terms are
/// then reachable through both lookup directions.
#[test]
fn test_insert_batch_basic() {
    let tmp = tempdir().expect("Failed to create temp dir");
    let vocab_path = tmp.path().join("vocab.vocab");
    let mut vocab = PersistentVocabARTrie::create(&vocab_path).expect("Failed to create vocab");

    let assigned = vocab
        .insert_batch(&["apple", "banana", "cherry"])
        .expect("insert batch");
    assert_eq!(assigned, vec![0, 1, 2]);
    assert_eq!(vocab.len(), 3);

    let expected = [("apple", 0), ("banana", 1), ("cherry", 2)];
    // Forward lookups first, then reverse lookups.
    for (term, idx) in expected {
        assert_eq!(vocab.get_index(term), Some(idx));
    }
    for (term, idx) in expected {
        assert_eq!(vocab.get_term(idx), Some(term.to_string()));
    }
}
/// In a batch that mixes known and new terms, known terms return their
/// existing index and new terms get the next free indices.
#[test]
fn test_insert_batch_with_duplicates() {
    let dir = tempdir().expect("Failed to create temp dir");
    let path = dir.path().join("vocab.vocab");
    let mut vocab = PersistentVocabARTrie::create(&path).expect("Failed to create vocab");
    // The expected indices below depend on these two inserts succeeding.
    vocab.insert("apple").expect("insert apple");
    vocab.insert("banana").expect("insert banana");
    let indices = vocab
        .insert_batch(&["apple", "cherry", "banana", "date"])
        .expect("insert batch with duplicates");
    assert_eq!(indices, vec![0, 2, 1, 3]);
    assert_eq!(vocab.len(), 4);
}
/// An empty batch returns no indices and leaves the vocabulary empty.
#[test]
fn test_insert_batch_empty() {
    let tmp = tempdir().expect("Failed to create temp dir");
    let vocab_path = tmp.path().join("vocab.vocab");
    let mut vocab = PersistentVocabARTrie::create(&vocab_path).expect("Failed to create vocab");

    let assigned = vocab.insert_batch(&[]).expect("insert empty batch");

    assert!(assigned.is_empty());
    assert_eq!(vocab.len(), 0);
}
/// Terms inserted with `insert_batch` and synced are present, with the same
/// indices, after the vocabulary is reopened.
///
/// NOTE(review): the writer is dropped normally at the end of its scope, and
/// `Drop` runs `checkpoint()`, so the reopen may load from the checkpoint
/// rather than replay the WAL. Other crash tests in this module use
/// `std::mem::forget` — confirm whether that was intended here too.
#[test]
fn test_insert_batch_wal_recovery() {
    let tmp = tempdir().expect("Failed to create temp dir");
    let vocab_path = tmp.path().join("vocab.vocab");

    {
        let mut writer =
            PersistentVocabARTrie::create(&vocab_path).expect("Failed to create vocab");
        let assigned = writer
            .insert_batch(&["apple", "banana", "cherry"])
            .expect("insert batch");
        assert_eq!(assigned, vec![0, 1, 2]);
        writer.sync().expect("Sync failed");
    }

    {
        let (reopened, _report) = PersistentVocabARTrie::open_with_recovery(&vocab_path).unwrap();
        assert_eq!(reopened.len(), 3, "WAL recovery should restore all 3 terms");

        let expected = [("apple", 0), ("banana", 1), ("cherry", 2)];
        // Forward lookups first, then reverse lookups.
        for (term, idx) in expected {
            assert_eq!(reopened.get_index(term), Some(idx));
        }
        for (term, idx) in expected {
            assert_eq!(reopened.get_term(idx), Some(term.to_string()));
        }
    }
}
/// Terms inserted before `rotate_wal` are present after the vocabulary is
/// reopened.
///
/// NOTE(review): the writer is dropped normally, and `Drop` runs
/// `checkpoint()`, so the reopen may not go through WAL replay — confirm
/// whether `std::mem::forget` was intended, as in the other crash tests.
#[test]
fn test_rotate_wal_recovery() {
    let dir = tempdir().expect("Failed to create temp dir");
    let path = dir.path().join("vocab.vocab");
    {
        let mut vocab = PersistentVocabARTrie::create(&path).expect("Failed to create vocab");
        vocab.enable_slot_tracking();
        vocab.insert("apple").expect("insert apple");
        vocab.insert("banana").expect("insert banana");
        vocab.rotate_wal().expect("rotate_wal failed");
    }
    {
        let (vocab, _report) = PersistentVocabARTrie::open_with_recovery(&path).unwrap();
        assert_eq!(vocab.len(), 2, "WAL recovery should restore 2 terms");
        assert_eq!(vocab.get_index("apple"), Some(0));
        assert_eq!(vocab.get_index("banana"), Some(1));
    }
}
/// Alternating batch inserts and WAL rotations keep all terms and their
/// indices intact in the live vocabulary.
#[test]
fn test_rotate_wal_multiple_batches() {
    let dir = tempdir().expect("Failed to create temp dir");
    let path = dir.path().join("vocab.vocab");
    let mut vocab = PersistentVocabARTrie::create(&path).expect("Failed to create vocab");
    vocab.enable_slot_tracking();
    vocab
        .insert_batch(&["apple", "banana"])
        .expect("insert first batch");
    vocab.rotate_wal().expect("First rotate_wal failed");
    vocab
        .insert_batch(&["cherry", "date"])
        .expect("insert second batch");
    vocab.rotate_wal().expect("Second rotate_wal failed");
    vocab
        .insert_batch(&["elderberry"])
        .expect("insert third batch");
    vocab.rotate_wal().expect("Third rotate_wal failed");
    assert_eq!(vocab.len(), 5);
    assert_eq!(vocab.get_index("apple"), Some(0));
    assert_eq!(vocab.get_index("elderberry"), Some(4));
}
/// Four threads insert 100 distinct terms each through a plain
/// `Arc<PersistentVocabARTrie>` (no outer lock). Afterwards the vocabulary
/// holds all 400 terms and every term round-trips through its assigned index.
#[test]
fn concurrent_inserts_via_shared_arc_are_lock_free() {
    let tmp = tempdir().expect("temp dir");
    let vocab_path = tmp.path().join("concurrent.vocab");
    let vocab = Arc::new(PersistentVocabARTrie::create(&vocab_path).expect("create vocab"));
    let num_threads = 4usize;
    let per_thread = 100usize;

    // Spawn all workers before joining any of them.
    let mut workers = Vec::with_capacity(num_threads);
    for t in 0..num_threads {
        let shared = Arc::clone(&vocab);
        workers.push(std::thread::spawn(move || {
            for i in 0..per_thread {
                shared
                    .insert(&format!("t{t}_{i}"))
                    .expect("concurrent insert");
            }
        }));
    }
    for worker in workers {
        worker.join().expect("thread join");
    }

    assert_eq!(vocab.len(), num_threads * per_thread);
    for t in 0..num_threads {
        for i in 0..per_thread {
            let term = format!("t{t}_{i}");
            let id = vocab.get_index(&term).expect("forward lookup");
            assert_eq!(vocab.get_term(id).as_deref(), Some(term.as_str()));
        }
    }
}
}