use crate::sync_compat::RwLock;
use log::warn;
use std::sync::atomic::{AtomicBool, AtomicUsize};
use std::sync::Arc;
use super::bucket::StringBucket;
#[allow(unused_imports)]
use super::error::Result;
use crate::value::DictionaryValue;
#[allow(unused_imports)]
use crate::{Dictionary, MappedDictionary, SyncStrategy};
#[allow(unused_imports)]
use super::nodes::ArtNode;
use super::nodes::Node;
use super::serialization;
use super::swizzled_ptr::{NodeType, SwizzledPtr};
use super::transitions::ChildNode;
use super::arena_manager::ArenaManager;
use super::block_storage::BlockStorage;
use super::buffer_manager::BufferManager;
use super::disk_manager::MmapDiskManager;
use super::wal::AsyncWalWriter;
use super::wal_managed::WalManaged;
/// Lexicographically compares two byte slices, with the same result as `<[u8]>::cmp`.
///
/// On x86_64 builds compiled with SSE4.2 enabled, the common prefix is scanned
/// 16 bytes at a time with SSE2 compare/movemask intrinsics. The remaining
/// tail (fewer than 16 bytes) and the length tie-break are delegated to the
/// standard slice comparison.
#[cfg(all(target_arch = "x86_64", target_feature = "sse4.2"))]
#[inline]
fn simd_cmp_bytes(a: &[u8], b: &[u8]) -> std::cmp::Ordering {
    use std::arch::x86_64::{__m128i, _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8};

    let min_len = a.len().min(b.len());
    let mut offset = 0;
    while offset + 16 <= min_len {
        // SAFETY: `offset + 16 <= min_len <= min(a.len(), b.len())`, so both
        // 16-byte loads stay inside their slices. `_mm_loadu_si128` has no
        // alignment requirement. SSE2 is implied by the statically enabled
        // SSE4.2 target feature gating this function.
        let equal_mask = unsafe {
            let va = _mm_loadu_si128(a.as_ptr().add(offset) as *const __m128i);
            let vb = _mm_loadu_si128(b.as_ptr().add(offset) as *const __m128i);
            // Bit i of the mask is set iff byte i is equal in both chunks.
            _mm_movemask_epi8(_mm_cmpeq_epi8(va, vb))
        };
        if equal_mask != 0xFFFF {
            // The lowest clear bit among the low 16 bits marks the first differing byte.
            let pos = offset + (!equal_mask as u32).trailing_zeros() as usize;
            return a[pos].cmp(&b[pos]);
        }
        offset += 16;
    }
    // Both suffixes start at the same `offset`. Comparing them is therefore the
    // same as comparing the remaining bytes and then the total lengths.
    a[offset..].cmp(&b[offset..])
}

/// Portable fallback: plain lexicographic byte-slice comparison.
#[cfg(not(all(target_arch = "x86_64", target_feature = "sse4.2")))]
#[inline]
fn simd_cmp_bytes(a: &[u8], b: &[u8]) -> std::cmp::Ordering {
    a.cmp(b)
}
/// Test-only helper: `true` when `a` sorts at or before `b` in byte-lexicographic order.
#[cfg(test)]
#[inline]
pub(super) fn bytes_le(a: &[u8], b: &[u8]) -> bool {
    // For a total order, "less than or equal" is exactly `Ordering::is_le`.
    simd_cmp_bytes(a, b).is_le()
}
/// Returns `true` when `a` sorts strictly after `b` in byte-lexicographic order.
#[inline]
pub(super) fn bytes_gt(a: &[u8], b: &[u8]) -> bool {
    simd_cmp_bytes(a, b) == std::cmp::Ordering::Greater
}
/// Data describing a single child, in one of two shapes.
///
/// NOTE(review): the producers/consumers of this type are not visible in this
/// file; the descriptions below only restate the field types.
pub(super) enum SingleChildData {
    /// The child is a string bucket.
    Bucket(StringBucket),
    /// The child is an ART node, carried together with its flag, child pointers and value.
    ArtNodePartial {
        /// The node itself.
        node: Node,
        /// Whether the node is marked final (presumably: a stored term ends here).
        is_final: bool,
        /// Child pointers, each paired with a key byte (presumably the edge byte).
        child_ptrs: Vec<(u8, SwizzledPtr)>,
        /// Optional raw value bytes (presumably the serialized value for this node).
        value: Option<Vec<u8>>,
    },
}
/// Replaces a `ChildNode::DiskRef` in place with a node loaded through the buffer manager.
///
/// Returns `true` when `child` can be mutated afterwards. That is the case when
/// it was not a `DiskRef` to begin with (it is left untouched), or when the
/// referenced node was loaded and written into `child`.
///
/// Returns `false`, after logging a warning and leaving `child` unchanged, when:
/// - the reference has no disk location;
/// - no buffer manager is available;
/// - the page cannot be fetched;
/// - the recorded offset lies outside the fetched page;
/// - the ART node fails to deserialize;
/// - the reference points at a char-level node type.
///
/// Loaded ART nodes are produced with `value: None` and an empty `children` list.
///
/// NOTE(review): a `NodeType::Bucket` reference resolves to an *empty*
/// `StringBucket` without reading the on-disk bytes. If buckets can live on
/// disk, their contents would be lost once the result is mutated and written
/// back. Confirm whether a bucket decoder should be used here.
pub(super) fn resolve_child_for_mutation_with_bm<S: BlockStorage>(
    child: &mut ChildNode,
    buffer_manager: Option<&Arc<RwLock<BufferManager<S>>>>,
) -> bool {
    let ChildNode::DiskRef { ptr } = child else {
        // Already resident: nothing to resolve.
        return true;
    };
    let Some(disk_location) = ptr.disk_location() else {
        warn!("DiskRef has no valid disk location");
        return false;
    };
    let Some(bm_arc) = buffer_manager else {
        warn!("No buffer manager available for resolving DiskRef");
        return false;
    };
    let resolved: ChildNode = {
        // The read lock and page guard only live for the duration of decoding.
        let bm = bm_arc.read();
        let page_guard = match bm.fetch_page(disk_location.block_id) {
            Ok(pg) => pg,
            Err(e) => {
                warn!(
                    "Failed to fetch page for DiskRef at block {}: {}",
                    disk_location.block_id, e
                );
                return false;
            }
        };
        let page_data = page_guard.data();
        let offset = disk_location.offset as usize;
        // The offset comes from persisted data, so a corrupt value must be
        // reported instead of panicking on an out-of-range slice.
        let Some(node_data) = page_data.get(offset..) else {
            warn!(
                "DiskRef offset {} is out of bounds for page at block {}",
                disk_location.offset, disk_location.block_id
            );
            return false;
        };
        match disk_location.node_type {
            NodeType::Bucket => ChildNode::Bucket(StringBucket::new()),
            NodeType::Node4 | NodeType::Node16 | NodeType::Node48 | NodeType::Node256 => {
                match serialization::from_bytes(node_data) {
                    Ok(node) => {
                        let is_final = node.header().is_final();
                        ChildNode::ArtNode {
                            node,
                            is_final,
                            value: None,
                            children: Vec::new(),
                        }
                    }
                    Err(e) => {
                        warn!(
                            "Failed to deserialize ART node at block {}, offset {}: {}",
                            disk_location.block_id, disk_location.offset, e
                        );
                        return false;
                    }
                }
            }
            NodeType::CharNode4
            | NodeType::CharNode16
            | NodeType::CharNode48
            | NodeType::CharBucket => {
                warn!(
                    "Char-level node type encountered in byte-level PersistentARTrie at block {}, offset {}",
                    disk_location.block_id, disk_location.offset
                );
                return false;
            }
        }
    };
    *child = resolved;
    true
}
/// Adaptive-radix-trie dictionary that can be used purely in memory or backed
/// by block storage.
///
/// Type parameters:
/// - `V`: value stored per term (defaults to `()`, i.e. a plain term set).
/// - `S`: block storage backend (defaults to `MmapDiskManager`).
///
/// The tests in this file build it both with `PersistentARTrie::new()` (in
/// memory) and with `create`/`open` on a file path. The storage-related handles
/// below are `Option`s. Presumably they are `None` for the in-memory form;
/// TODO confirm against the constructors.
///
/// Dropping a trie runs [`PersistentARTrie::close`].
pub struct PersistentARTrie<V: DictionaryValue = (), S: BlockStorage = MmapDiskManager> {
    /// Term counter (presumably what `len()` reports).
    pub(crate) term_count: AtomicUsize,
    /// Dirty flag (presumably what `is_dirty()` reports).
    pub(crate) dirty: AtomicBool,
    /// Buffer manager used to fetch pages by block id; `close()` calls `flush_all` on it.
    pub(crate) buffer_manager: Option<Arc<RwLock<BufferManager<S>>>>,
    /// Write-ahead-log writer, exposed via `WalManaged::wal_writer`.
    /// `close()` stops its background sync and then syncs it.
    pub(crate) wal_writer: Option<Arc<AsyncWalWriter>>,
    /// Log-sequence-number counter (presumably the next LSN to assign; see `current_lsn()`).
    pub(crate) next_lsn: std::sync::atomic::AtomicU64,
    /// Child prefetcher (see `prefetch::Prefetcher`).
    pub(crate) prefetcher: super::prefetch::Prefetcher,
    /// Optional arena manager over the same storage backend `S`.
    pub(crate) arena_manager: Option<Arc<RwLock<ArenaManager<S>>>>,
    /// Durability policy held in an atomic cell (presumably adjustable at runtime).
    pub(crate) durability_policy:
        crate::persistent_artrie_core::shared_access::AtomicEnumCell<DurabilityPolicy>,
    /// Shared epoch manager (semantics defined in `concurrency`).
    pub(crate) epoch_manager: Arc<super::concurrency::EpochManager>,
    /// Shared trie statistics.
    pub(crate) stats: Arc<super::concurrency::TrieStats>,
    /// Background eviction coordinator, if any; `close()` takes it out and calls `shutdown()`.
    pub(crate) eviction_coordinator:
        std::sync::Mutex<Option<Arc<super::eviction::EvictionCoordinator>>>,
    /// Root pointer for the lock-free path (only with the `persistent-artrie` feature).
    #[cfg(feature = "persistent-artrie")]
    pub(crate) lockfree_root: Option<super::nodes::AtomicNodePtr<V>>,
    /// Byte-key to bool map used by the lock-free path (purpose not visible here).
    #[cfg(feature = "persistent-artrie")]
    pub(crate) lockfree_cache: Option<dashmap::DashMap<Vec<u8>, bool>>,
    /// Counter, presumably of compare-and-swap retries on the lock-free path.
    #[cfg(feature = "persistent-artrie")]
    pub(crate) cas_retries: std::sync::atomic::AtomicU64,
    /// Committed watermark (see `persistent_artrie_core::committed_watermark`).
    pub(crate) committed_watermark:
        crate::persistent_artrie_core::committed_watermark::CommittedWatermark,
    /// Shared mutex, presumably serializing checkpoints.
    pub(crate) checkpoint_lock: std::sync::Arc<parking_lot::Mutex<()>>,
    /// Shared mutex, presumably serializing merges.
    pub(crate) merge_lock: std::sync::Arc<parking_lot::Mutex<()>>,
    /// Commit sequence counter (semantics defined elsewhere).
    pub(crate) commit_seq: std::sync::atomic::AtomicU64,
}
pub use super::prefix_term::{PrefixTermWithArena, PrefixTermWithValueAndArena};
pub use super::iterators::{TermIterator, TermValueIterator};
pub use super::transactions::TransactionState;
pub use crate::persistent_artrie_core::durability::DurabilityPolicy;
pub use super::compaction::{CompactionConfig, CompactionProgress, CompactionStats};
pub use super::transactions::DocumentTransaction;
impl<V, S> WalManaged for PersistentARTrie<V, S>
where
    V: DictionaryValue,
    S: BlockStorage,
{
    /// Borrows the attached WAL writer, or returns `None` when the trie has none.
    fn wal_writer(&self) -> Option<&Arc<AsyncWalWriter>> {
        Option::as_ref(&self.wal_writer)
    }
}
impl<V: DictionaryValue> Default for PersistentARTrie<V> {
    /// Builds a trie by delegating to `PersistentARTrie::new`.
    // The allow presumably exists because `new` carries a deprecation marker.
    #[allow(deprecated)]
    fn default() -> Self {
        <Self>::new()
    }
}
// NOTE(review): this inherent impl block is empty. It looks like a leftover
// and can probably be removed.
impl<V, S> PersistentARTrie<V, S>
where
    V: DictionaryValue,
    S: BlockStorage,
{
}
// Root-kind tag bytes (presumably written alongside persisted roots to tell
// which representation follows).
/// Tag: there is no root.
pub(super) const ROOT_TYPE_EMPTY: u8 = 0x00;
/// Tag: the root is a string bucket.
pub(super) const ROOT_TYPE_BUCKET: u8 = 0x01;
/// Tag: the root is an ART node.
pub(super) const ROOT_TYPE_ART_NODE: u8 = 0x02;
#[cfg(feature = "parallel-merge")]
pub use super::parallel_merge::SharedARTrieParallelExt;
impl<V: DictionaryValue, S: BlockStorage> PersistentARTrie<V, S> {
pub fn close(&self) {
let coordinator = self
.eviction_coordinator
.lock()
.expect("eviction_coordinator mutex poisoned")
.take();
if let Some(coordinator) = coordinator {
coordinator.shutdown();
}
if let Some(ref wal_writer) = self.wal_writer {
wal_writer.stop_sync();
let _ = wal_writer.sync();
}
if let Some(ref buffer_manager) = self.buffer_manager {
let bm = buffer_manager.read();
let _ = bm.flush_all();
}
}
}
impl<V: DictionaryValue, S: BlockStorage> Drop for PersistentARTrie<V, S> {
    /// Runs the same shutdown sequence as an explicit `close()` call.
    fn drop(&mut self) {
        Self::close(self);
    }
}
#[cfg(test)]
mod tests {
use super::*;
/// A freshly constructed in-memory trie is empty and not dirty.
#[test]
fn test_new_dictionary() {
    let trie: PersistentARTrie = PersistentARTrie::new();
    assert!(!trie.is_dirty());
    assert_eq!(trie.len(), Some(0));
}
/// Inserted terms are found; other terms are not.
#[test]
fn test_insert_and_contains() {
    let trie: PersistentARTrie = PersistentARTrie::new();
    let fruits = ["apple", "banana", "cherry"];
    for fruit in fruits {
        assert!(trie.insert(fruit));
    }
    for fruit in fruits {
        assert!(trie.contains(fruit));
    }
    assert!(!trie.contains("date"));
    assert_eq!(trie.len(), Some(3));
}
/// Re-inserting an existing term reports `false` and does not grow the trie.
#[test]
fn test_duplicate_insert() {
    let trie: PersistentARTrie = PersistentARTrie::new();
    let term = "test";
    assert!(trie.insert(term));
    assert!(!trie.insert(term));
    assert_eq!(trie.len(), Some(1));
}
/// Removing a present term deletes only that term.
#[test]
fn test_remove() {
    let trie: PersistentARTrie = PersistentARTrie::new();
    for term in ["apple", "banana"] {
        trie.insert(term);
    }
    assert!(trie.remove("apple"));
    assert!(!trie.contains("apple"));
    assert!(trie.contains("banana"));
    assert_eq!(trie.len(), Some(1));
}
/// Removing an absent term reports `false` and leaves the count unchanged.
#[test]
fn test_remove_not_found() {
    let trie: PersistentARTrie = PersistentARTrie::new();
    trie.insert("apple");
    let removed = trie.remove("banana");
    assert!(!removed);
    assert_eq!(trie.len(), Some(1));
}
/// The empty string is a valid term and survives later insertions.
#[test]
fn test_empty_string() {
    let trie: PersistentARTrie = PersistentARTrie::new();
    let empty = "";
    assert!(trie.insert(empty));
    assert!(trie.contains(empty));
    trie.insert("test");
    assert!(trie.contains(empty));
    assert!(trie.contains("test"));
}
/// The trie can be used through a `&dyn Dictionary` trait object.
#[test]
fn test_dictionary_trait() {
    let trie: PersistentARTrie = PersistentARTrie::new();
    for word in ["hello", "world"] {
        trie.insert(word);
    }
    let as_dyn: &dyn Dictionary<Node = _> = &trie;
    assert!(as_dyn.contains("hello"));
    assert!(!as_dyn.contains("hi"));
}
/// One hundred distinct zero-padded terms are all stored and found.
#[test]
fn test_many_insertions() {
    let trie: PersistentARTrie = PersistentARTrie::new();
    let words: Vec<String> = (0..100).map(|n| format!("word{:03}", n)).collect();
    for word in &words {
        trie.insert(word);
    }
    assert_eq!(trie.len(), Some(100));
    for word in &words {
        assert!(trie.contains(word));
    }
}
/// The trie advertises internal synchronization.
#[test]
fn test_sync_strategy() {
    let trie: PersistentARTrie = PersistentARTrie::new();
    let strategy = trie.sync_strategy();
    assert_eq!(strategy, SyncStrategy::InternalSync);
}
/// Iterating an empty trie yields nothing.
#[test]
fn test_iter_empty() {
    let trie: PersistentARTrie = PersistentARTrie::new();
    let collected: Vec<_> = trie.iter().collect();
    assert!(collected.is_empty());
}
/// Iterating a one-term trie yields exactly that term's bytes.
#[test]
fn test_iter_single() {
    let trie: PersistentARTrie = PersistentARTrie::new();
    trie.insert("hello");
    let collected: Vec<_> = trie.iter().collect();
    assert_eq!(collected.len(), 1);
    assert_eq!(collected[0], b"hello".to_vec());
}
/// `iter_strings` yields every inserted term.
#[test]
fn test_iter_multiple() {
    let trie: PersistentARTrie = PersistentARTrie::new();
    let fruits = ["apple", "banana", "cherry"];
    for fruit in fruits {
        trie.insert(fruit);
    }
    let collected: Vec<String> = trie.iter_strings().collect();
    assert_eq!(collected.len(), 3);
    for fruit in fruits {
        assert!(collected.contains(&fruit.to_string()));
    }
}
/// Iteration includes the empty-string term alongside ordinary terms.
#[test]
fn test_iter_with_empty_string() {
    let trie: PersistentARTrie = PersistentARTrie::new();
    for term in ["", "hello"] {
        trie.insert(term);
    }
    let collected: Vec<String> = trie.iter_strings().collect();
    assert_eq!(collected.len(), 2);
    for term in ["", "hello"] {
        assert!(collected.contains(&term.to_string()));
    }
}
/// Terms sharing the prefix "test" (including "test" itself) are all yielded.
#[test]
fn test_iter_common_prefix() {
    let trie: PersistentARTrie = PersistentARTrie::new();
    let family = ["test", "testing", "tested", "tester"];
    for term in family {
        trie.insert(term);
    }
    let collected: Vec<String> = trie.iter_strings().collect();
    assert_eq!(collected.len(), 4);
    for term in family {
        assert!(collected.contains(&term.to_string()));
    }
}
/// Iteration over an unchanged trie yields every term, and repeated iteration
/// yields the same sequence regardless of insertion order.
///
/// NOTE(review): this still does not pin a specific ordering (e.g. sorted byte
/// order). Tighten it once the iterator's ordering contract is documented.
#[test]
fn test_iter_preserves_order() {
    let dict: PersistentARTrie = PersistentARTrie::new();
    dict.insert("cherry");
    dict.insert("apple");
    dict.insert("banana");
    let terms: Vec<String> = dict.iter_strings().collect();
    assert_eq!(terms.len(), 3);
    for expected in ["apple", "banana", "cherry"] {
        assert!(
            terms.iter().any(|t| t.as_str() == expected),
            "missing {}",
            expected
        );
    }
    let again: Vec<String> = dict.iter_strings().collect();
    assert_eq!(terms, again, "iteration order should be stable across calls");
}
mod persistent_tests {
    use super::*;
    use tempfile::TempDir;

    /// Terms written and synced through `create` are visible after `open`.
    #[test]
    fn test_create_and_open() {
        let scratch = TempDir::new().expect("create temp dir");
        let path = scratch.path().join("test.part");
        {
            let trie: PersistentARTrie<()> =
                PersistentARTrie::create(&path).expect("create dict");
            for term in ["hello", "world"] {
                trie.insert(term);
            }
            trie.sync().expect("sync");
        }
        {
            let reopened: PersistentARTrie<()> =
                PersistentARTrie::open(&path).expect("open dict");
            for term in ["hello", "world"] {
                assert!(reopened.contains(term));
            }
            assert_eq!(reopened.len(), Some(2));
        }
    }

    /// `create` refuses to overwrite an existing file.
    #[test]
    fn test_create_fails_if_exists() {
        let scratch = TempDir::new().expect("create temp dir");
        let path = scratch.path().join("test.part");
        std::fs::write(&path, b"dummy").expect("create file");
        let outcome: Result<PersistentARTrie<()>> = PersistentARTrie::create(&path);
        assert!(outcome.is_err());
    }

    /// `open` fails for a path that does not exist.
    #[test]
    fn test_open_fails_if_not_exists() {
        let scratch = TempDir::new().expect("create temp dir");
        let path = scratch.path().join("nonexistent.part");
        let outcome: Result<PersistentARTrie<()>> = PersistentARTrie::open(&path);
        assert!(outcome.is_err());
    }

    /// Terms survive a close/reopen cycle.
    ///
    /// NOTE(review): `sync()` runs before the trie is dropped, so this does not
    /// isolate WAL replay from the regular flush path.
    #[test]
    fn test_wal_recovery() {
        let scratch = TempDir::new().expect("create temp dir");
        let path = scratch.path().join("test.part");
        let fruits = ["apple", "banana", "cherry"];
        {
            let trie: PersistentARTrie<()> =
                PersistentARTrie::create(&path).expect("create dict");
            for fruit in fruits {
                trie.insert(fruit);
            }
            trie.sync().expect("sync");
        }
        {
            let reopened: PersistentARTrie<()> =
                PersistentARTrie::open(&path).expect("open dict");
            for fruit in fruits {
                assert!(reopened.contains(fruit));
            }
        }
    }

    /// A checkpoint after an insert succeeds.
    #[test]
    fn test_checkpoint() {
        let scratch = TempDir::new().expect("create temp dir");
        let path = scratch.path().join("test.part");
        let trie: PersistentARTrie<()> =
            PersistentARTrie::create(&path).expect("create dict");
        trie.insert("test");
        trie.checkpoint().expect("checkpoint");
    }

    /// A sync after an insert succeeds.
    #[test]
    fn test_sync() {
        let scratch = TempDir::new().expect("create temp dir");
        let path = scratch.path().join("test.part");
        let trie: PersistentARTrie<()> =
            PersistentARTrie::create(&path).expect("create dict");
        trie.insert("test");
        trie.sync().expect("sync");
    }

    /// Fifty synced terms are all present after reopening.
    #[test]
    fn test_many_insertions_persistent() {
        let scratch = TempDir::new().expect("create temp dir");
        let path = scratch.path().join("test.part");
        {
            let trie: PersistentARTrie<()> =
                PersistentARTrie::create(&path).expect("create dict");
            for n in 0..50 {
                trie.insert(&format!("word{:03}", n));
            }
            trie.sync().expect("sync");
        }
        {
            let reopened: PersistentARTrie<()> =
                PersistentARTrie::open(&path).expect("open dict");
            assert_eq!(reopened.len(), Some(50));
            for n in 0..50 {
                assert!(
                    reopened.contains(&format!("word{:03}", n)),
                    "missing word{:03}",
                    n
                );
            }
        }
    }
}
mod atomic_ops_tests {
use super::*;
use tempfile::tempdir;
#[test]
fn test_increment_new_term() {
let dir = tempdir().expect("create temp dir");
let dict_path = dir.path().join("atomic_test.part");
let mut dict: PersistentARTrie<i64> =
PersistentARTrie::create(&dict_path).expect("create dict");
let result = dict.increment("counter", 1).expect("increment");
assert_eq!(result, 1, "First increment should return delta value");
assert!(dict.contains("counter"));
}
#[test]
fn test_upsert_new_term() {
let dir = tempdir().expect("create temp dir");
let dict_path = dir.path().join("atomic_test.part");
let dict: PersistentARTrie<String> =
PersistentARTrie::create(&dict_path).expect("create dict");
let is_new = dict
.upsert("greeting", "hello".to_string())
.expect("upsert");
assert!(is_new, "Should return true for new insertion");
let value = dict.get_value("greeting");
assert_eq!(value, Some("hello".to_string()));
}
#[test]
fn test_upsert_existing_term() {
let dir = tempdir().expect("create temp dir");
let dict_path = dir.path().join("atomic_test.part");
let dict: PersistentARTrie<String> =
PersistentARTrie::create(&dict_path).expect("create dict");
dict.upsert("greeting", "hello".to_string())
.expect("upsert");
let is_new = dict.upsert("greeting", "hi".to_string()).expect("upsert");
assert!(!is_new, "Should return false for update");
let value = dict.get_value("greeting");
assert_eq!(value, Some("hi".to_string()));
}
#[test]
fn test_compare_and_swap_success() {
let dir = tempdir().expect("create temp dir");
let dict_path = dir.path().join("atomic_test.part");
let mut dict: PersistentARTrie<i32> =
PersistentARTrie::create(&dict_path).expect("create dict");
dict.upsert("counter", 0i32).expect("upsert");
let success = dict.compare_and_swap("counter", Some(0), 1).expect("cas");
assert!(success, "CAS should succeed when expected matches");
assert_eq!(dict.get_value("counter"), Some(1));
}
#[test]
fn test_compare_and_swap_failure() {
let dir = tempdir().expect("create temp dir");
let dict_path = dir.path().join("atomic_test.part");
let mut dict: PersistentARTrie<i32> =
PersistentARTrie::create(&dict_path).expect("create dict");
dict.upsert("counter", 5i32).expect("upsert");
let success = dict.compare_and_swap("counter", Some(0), 10).expect("cas");
assert!(!success, "CAS should fail when expected doesn't match");
assert_eq!(dict.get_value("counter"), Some(5));
}
#[test]
fn test_compare_and_swap_none_expected() {
let dir = tempdir().expect("create temp dir");
let dict_path = dir.path().join("atomic_test.part");
let mut dict: PersistentARTrie<i32> =
PersistentARTrie::create(&dict_path).expect("create dict");
let success = dict.compare_and_swap("new_key", None, 42).expect("cas");
assert!(
success,
"CAS should succeed when expecting None and key doesn't exist"
);
assert_eq!(dict.get_value("new_key"), Some(42));
}
#[test]
fn test_fetch_add() {
let dir = tempdir().expect("create temp dir");
let dict_path = dir.path().join("atomic_test.part");
let mut dict: PersistentARTrie<i64> =
PersistentARTrie::create(&dict_path).expect("create dict");
dict.upsert("counter", 10i64).expect("upsert");
let old = dict.fetch_add("counter", 5).expect("fetch_add");
assert_eq!(old, 10, "fetch_add should return old value");
let new_val = dict.increment("counter", 0).expect("read");
assert_eq!(new_val, 15);
}
#[test]
fn test_get_or_insert_new() {
let dir = tempdir().expect("create temp dir");
let dict_path = dir.path().join("atomic_test.part");
let mut dict: PersistentARTrie<i32> =
PersistentARTrie::create(&dict_path).expect("create dict");
let value = dict.get_or_insert("key", 42).expect("get_or_insert");
assert_eq!(value, 42);
assert!(dict.contains("key"));
}
#[test]
fn test_get_or_insert_existing() {
let dir = tempdir().expect("create temp dir");
let dict_path = dir.path().join("atomic_test.part");
let mut dict: PersistentARTrie<i32> =
PersistentARTrie::create(&dict_path).expect("create dict");
dict.upsert("key", 100i32).expect("upsert");
let value = dict.get_or_insert("key", 42).expect("get_or_insert");
assert_eq!(value, 100, "Should return existing value, not default");
}
#[test]
fn test_document_transaction_commit() {
let dir = tempdir().expect("create temp dir");
let dict_path = dir.path().join("tx_test.part");
let dict: PersistentARTrie<i64> =
PersistentARTrie::create(&dict_path).expect("create dict");
let mut tx = dict.begin_document("doc1").expect("begin transaction");
assert_eq!(tx.state, TransactionState::Active);
dict.tx_insert(&mut tx, "term1", Some(100));
dict.tx_insert(&mut tx, "term2", Some(200));
dict.tx_insert(&mut tx, "term3", None);
assert!(!dict.contains("term1"));
assert!(!dict.contains("term2"));
assert!(!dict.contains("term3"));
let count = dict.commit_document(tx).expect("commit");
assert_eq!(count, 3);
assert!(dict.contains("term1"));
assert!(dict.contains("term2"));
assert!(dict.contains("term3"));
assert_eq!(dict.get_value("term1"), Some(100));
assert_eq!(dict.get_value("term2"), Some(200));
assert_eq!(dict.get_value("term3"), None);
}
#[test]
fn test_document_transaction_abort() {
let dir = tempdir().expect("create temp dir");
let dict_path = dir.path().join("tx_test.part");
let dict: PersistentARTrie<i64> =
PersistentARTrie::create(&dict_path).expect("create dict");
dict.insert_with_value("existing", 42);
assert!(dict.contains("existing"));
let mut tx = dict.begin_document("doc1").expect("begin transaction");
dict.tx_insert(&mut tx, "term1", Some(100));
dict.tx_insert(&mut tx, "term2", Some(200));
dict.abort_document(tx).expect("abort");
assert!(!dict.contains("term1"));
assert!(!dict.contains("term2"));
assert!(dict.contains("existing"));
assert_eq!(dict.get_value("existing"), Some(42));
}
#[test]
fn test_document_transaction_empty_commit() {
let dir = tempdir().expect("create temp dir");
let dict_path = dir.path().join("tx_test.part");
let dict: PersistentARTrie<i64> =
PersistentARTrie::create(&dict_path).expect("create dict");
let tx = dict.begin_document("empty_doc").expect("begin transaction");
let count = dict.commit_document(tx).expect("commit");
assert_eq!(count, 0);
}
#[test]
fn test_document_transaction_bytes() {
let dir = tempdir().expect("create temp dir");
let dict_path = dir.path().join("tx_test.part");
let dict: PersistentARTrie<i64> =
PersistentARTrie::create(&dict_path).expect("create dict");
let mut tx = dict.begin_document("doc1").expect("begin transaction");
dict.tx_insert_bytes(&mut tx, b"binary_term", Some(999));
let count = dict.commit_document(tx).expect("commit");
assert_eq!(count, 1);
assert!(dict.contains("binary_term"));
assert_eq!(dict.get_value("binary_term"), Some(999));
}
#[test]
fn test_multiple_document_transactions() {
let dir = tempdir().expect("create temp dir");
let dict_path = dir.path().join("tx_test.part");
let dict: PersistentARTrie<i64> =
PersistentARTrie::create(&dict_path).expect("create dict");
let mut tx1 = dict.begin_document("doc1").expect("begin tx1");
dict.tx_insert(&mut tx1, "doc1_term1", Some(1));
dict.tx_insert(&mut tx1, "doc1_term2", Some(2));
dict.commit_document(tx1).expect("commit tx1");
let mut tx2 = dict.begin_document("doc2").expect("begin tx2");
dict.tx_insert(&mut tx2, "doc2_term1", Some(100));
dict.abort_document(tx2).expect("abort tx2");
let mut tx3 = dict.begin_document("doc3").expect("begin tx3");
dict.tx_insert(&mut tx3, "doc3_term1", Some(300));
dict.commit_document(tx3).expect("commit tx3");
assert!(dict.contains("doc1_term1"));
assert!(dict.contains("doc1_term2"));
assert!(!dict.contains("doc2_term1")); assert!(dict.contains("doc3_term1"));
assert_eq!(dict.get_value("doc1_term1"), Some(1));
assert_eq!(dict.get_value("doc3_term1"), Some(300));
}
}
mod sequential_siblings_tests {
    use super::*;
    use crate::persistent_artrie::nodes::{ChildStorage, Node, Node4};
    use crate::persistent_artrie::swizzled_ptr::{NodeType, SwizzledPtr};

    /// Builds a `Node::N4` holding the given `(key byte, pointer)` children, in order.
    fn node4_with_children(children: Vec<(u8, SwizzledPtr)>) -> Node {
        let mut n4 = Node4::new();
        for (key, ptr) in children {
            let _ = n4.add_child(key, ptr);
        }
        Node::N4(Box::new(n4))
    }

    #[test]
    fn test_check_sequential_children_empty() {
        let node = node4_with_children(vec![]);
        let outcome = PersistentARTrie::<()>::check_sequential_children(&node, 0);
        assert!(outcome.is_none());
    }

    #[test]
    fn test_check_sequential_children_single_child() {
        let node = node4_with_children(vec![(b'a', SwizzledPtr::on_disk(1, 10, NodeType::Node4))]);
        let outcome = PersistentARTrie::<()>::check_sequential_children(&node, 0);
        assert!(outcome.is_none(), "Single child should not use sequential");
    }

    #[test]
    fn test_check_sequential_children_consecutive() {
        let node = node4_with_children(vec![
            (b'a', SwizzledPtr::on_disk(1, 10, NodeType::Node4)),
            (b'b', SwizzledPtr::on_disk(1, 11, NodeType::Node4)),
        ]);
        let outcome = PersistentARTrie::<()>::check_sequential_children(&node, 0);
        assert!(
            outcome.is_some(),
            "Consecutive children should use sequential"
        );
        let first = outcome.unwrap();
        assert_eq!(first.arena_id, 0);
        assert_eq!(first.slot_id, 10);
    }

    #[test]
    fn test_check_sequential_children_not_consecutive() {
        let node = node4_with_children(vec![
            (b'a', SwizzledPtr::on_disk(1, 10, NodeType::Node4)),
            (b'b', SwizzledPtr::on_disk(1, 15, NodeType::Node4)),
        ]);
        let outcome = PersistentARTrie::<()>::check_sequential_children(&node, 0);
        assert!(
            outcome.is_none(),
            "Non-consecutive slots should not use sequential"
        );
    }

    #[test]
    fn test_check_sequential_children_different_arenas() {
        let node = node4_with_children(vec![
            (b'a', SwizzledPtr::on_disk(1, 10, NodeType::Node4)),
            (b'b', SwizzledPtr::on_disk(2, 11, NodeType::Node4)),
        ]);
        let outcome = PersistentARTrie::<()>::check_sequential_children(&node, 0);
        assert!(
            outcome.is_none(),
            "Cross-arena children should not use sequential"
        );
    }

    #[test]
    fn test_check_sequential_children_wrong_parent_arena() {
        let node = node4_with_children(vec![
            (b'a', SwizzledPtr::on_disk(1, 10, NodeType::Node4)),
            (b'b', SwizzledPtr::on_disk(1, 11, NodeType::Node4)),
        ]);
        let outcome = PersistentARTrie::<()>::check_sequential_children(&node, 1);
        assert!(outcome.is_none(), "Children must be in same arena as parent");
    }

    #[test]
    fn test_check_sequential_children_three_consecutive() {
        let node = node4_with_children(vec![
            (b'a', SwizzledPtr::on_disk(1, 100, NodeType::Node4)),
            (b'b', SwizzledPtr::on_disk(1, 101, NodeType::Node4)),
            (b'c', SwizzledPtr::on_disk(1, 102, NodeType::Node4)),
        ]);
        let outcome = PersistentARTrie::<()>::check_sequential_children(&node, 0);
        assert!(outcome.is_some());
        let first = outcome.unwrap();
        assert_eq!(first.arena_id, 0);
        assert_eq!(first.slot_id, 100);
    }

    #[test]
    fn test_child_storage_enum() {
        let direct = ChildStorage::Direct;
        assert!(direct.is_direct());
        assert!(!direct.is_sequential());
        assert!(direct.first_slot().is_none());
        let sequential = ChildStorage::sequential(5, 100);
        assert!(!sequential.is_direct());
        assert!(sequential.is_sequential());
        assert_eq!(sequential.arena_id(), Some(5));
        assert_eq!(sequential.first_slot(), Some(100));
    }
}
mod optimization_tests {
use super::*;
#[test]
fn test_multi_level_prefetch_respects_depth_limit() {
use crate::persistent_artrie::prefetch::{PrefetchStrategy, Prefetcher};
use crate::persistent_artrie::swizzled_ptr::{NodeType, SwizzledPtr};
let prefetcher = Prefetcher::with_config(100, PrefetchStrategy::DepthLimited(2));
let children: Vec<(u8, SwizzledPtr)> = (0..5)
.map(|i| (i, SwizzledPtr::on_disk(i as u32, 0, NodeType::Node4)))
.collect();
prefetcher.prefetch_children_bounded(&children, 0);
assert_eq!(prefetcher.queue_len(), 5);
prefetcher.clear();
prefetcher.prefetch_children_bounded(&children, 1);
assert_eq!(prefetcher.queue_len(), 5);
prefetcher.clear();
prefetcher.prefetch_children_bounded(&children, 2);
assert_eq!(prefetcher.queue_len(), 5);
prefetcher.clear();
prefetcher.prefetch_children_bounded(&children, 3);
assert_eq!(prefetcher.queue_len(), 0);
}
#[test]
fn test_prefetch_children_bounded_with_first_n_strategy() {
use crate::persistent_artrie::prefetch::{PrefetchStrategy, Prefetcher};
use crate::persistent_artrie::swizzled_ptr::{NodeType, SwizzledPtr};
let prefetcher = Prefetcher::with_config(100, PrefetchStrategy::FirstN(3));
let children: Vec<(u8, SwizzledPtr)> = (0..10)
.map(|i| (i, SwizzledPtr::on_disk(i as u32, 0, NodeType::Node4)))
.collect();
prefetcher.prefetch_children_bounded(&children, 0);
assert_eq!(prefetcher.queue_len(), 3);
prefetcher.clear();
prefetcher.prefetch_children_bounded(&children, 5);
assert_eq!(prefetcher.queue_len(), 3);
}
#[test]
fn test_prefetch_disabled_strategy() {
use crate::persistent_artrie::prefetch::{PrefetchStrategy, Prefetcher};
use crate::persistent_artrie::swizzled_ptr::{NodeType, SwizzledPtr};
let prefetcher = Prefetcher::with_config(100, PrefetchStrategy::Disabled);
let children: Vec<(u8, SwizzledPtr)> = (0..5)
.map(|i| (i, SwizzledPtr::on_disk(i as u32, 0, NodeType::Node4)))
.collect();
prefetcher.prefetch_children_bounded(&children, 0);
assert_eq!(prefetcher.queue_len(), 0);
}
#[test]
fn test_merge_arena_sorting_preserves_correctness() {
let source: PersistentARTrie<u32> = PersistentARTrie::new();
source.insert_with_value("apple", 1);
source.insert_with_value("banana", 2);
source.insert_with_value("cherry", 3);
source.insert_with_value("apricot", 4);
source.insert_with_value("blueberry", 5);
let mut target: PersistentARTrie<u32> = PersistentARTrie::new();
target.insert_with_value("apple", 10);
target.insert_with_value("date", 6);
let result1 = target.merge_from_batched(&source, |a, b| a + b, 2);
assert!(result1.is_ok());
let count1 = result1.unwrap();
assert_eq!(count1, 5);
assert_eq!(target.get_value("apple"), Some(11)); assert_eq!(target.get_value("banana"), Some(2));
assert_eq!(target.get_value("cherry"), Some(3));
assert_eq!(target.get_value("apricot"), Some(4));
assert_eq!(target.get_value("blueberry"), Some(5));
assert_eq!(target.get_value("date"), Some(6)); }
#[test]
fn test_merge_arena_grouped_ordering() {
let source: PersistentARTrie<u32> = PersistentARTrie::new();
for i in 0..100 {
let term = format!("term{:03}", i);
source.insert_with_value(&term, i);
}
let mut target: PersistentARTrie<u32> = PersistentARTrie::new();
let result = target.merge_from_batched_grouped(&source, |a, b| a + b, 20);
assert!(result.is_ok());
let count = result.unwrap();
assert_eq!(count, 100);
for i in 0..100 {
let term = format!("term{:03}", i);
assert_eq!(target.get_value(&term), Some(i));
}
}
#[test]
fn test_insert_batch_arena_grouped_ordering() {
let trie: PersistentARTrie<u32> = PersistentARTrie::new();
let entries: Vec<(Vec<u8>, Option<u32>)> = vec![
(b"zebra".to_vec(), Some(1)),
(b"apple".to_vec(), Some(2)),
(b"apricot".to_vec(), Some(3)),
(b"zoo".to_vec(), Some(4)),
(b"banana".to_vec(), Some(5)),
(b"azure".to_vec(), Some(6)),
];
let count = trie.insert_batch_arena_grouped(entries);
assert_eq!(count, 6);
assert_eq!(trie.get_value("zebra"), Some(1));
assert_eq!(trie.get_value("apple"), Some(2));
assert_eq!(trie.get_value("apricot"), Some(3));
assert_eq!(trie.get_value("zoo"), Some(4));
assert_eq!(trie.get_value("banana"), Some(5));
assert_eq!(trie.get_value("azure"), Some(6));
}
#[test]
fn test_insert_batch_grouped_string_variant() {
let trie: PersistentARTrie<u32> = PersistentARTrie::new();
let entries: Vec<(String, Option<u32>)> = vec![
("zebra".to_string(), Some(1)),
("apple".to_string(), Some(2)),
("apricot".to_string(), Some(3)),
("zoo".to_string(), Some(4)),
("banana".to_string(), Some(5)),
("azure".to_string(), Some(6)),
];
let count = trie.insert_batch_grouped(entries);
assert_eq!(count, 6);
assert_eq!(trie.get_value("zebra"), Some(1));
assert_eq!(trie.get_value("apple"), Some(2));
assert_eq!(trie.get_value("apricot"), Some(3));
}
#[test]
fn test_insert_batch_arena_grouped_empty() {
let trie: PersistentARTrie<u32> = PersistentARTrie::new();
let entries: Vec<(Vec<u8>, Option<u32>)> = vec![];
let count = trie.insert_batch_arena_grouped(entries);
assert_eq!(count, 0);
assert_eq!(trie.len(), Some(0));
}
#[test]
fn test_insert_batch_grouped_preserves_values() {
let trie: PersistentARTrie<String> = PersistentARTrie::new();
let entries: Vec<(String, Option<String>)> = vec![
("key1".to_string(), Some("value1".to_string())),
("key2".to_string(), Some("value2".to_string())),
("akey".to_string(), Some("avalue".to_string())),
];
let count = trie.insert_batch_grouped(entries);
assert_eq!(count, 3);
assert_eq!(trie.get_value("key1"), Some("value1".to_string()));
assert_eq!(trie.get_value("key2"), Some("value2".to_string()));
assert_eq!(trie.get_value("akey"), Some("avalue".to_string()));
}
#[test]
fn test_arena_manager_flush_sequential() {
use crate::persistent_artrie::arena_manager::ArenaManager;
use crate::persistent_artrie::disk_manager::MmapDiskManager;
let mut manager: ArenaManager<MmapDiskManager> = ArenaManager::new();
manager.allocate(b"test1").expect("alloc 1");
manager.allocate(b"test2").expect("alloc 2");
let result = manager.flush_sequential();
assert!(result.is_ok());
}
}
mod lsn_api_tests {
use super::*;
use tempfile::tempdir;
/// A newly created file-backed trie reports LSN 1.
#[test]
fn test_current_lsn_starts_at_one_for_persistent() {
    let scratch = tempdir().expect("create temp dir");
    let path = scratch.path().join("lsn_test.part");
    let trie: PersistentARTrie<i32> =
        PersistentARTrie::create(&path).expect("create dict");
    let lsn = trie.current_lsn();
    assert_eq!(lsn, 1);
}
/// An in-memory trie reports LSN 0.
#[test]
fn test_current_lsn_starts_at_zero_for_in_memory() {
    let trie: PersistentARTrie<i32> = PersistentARTrie::new();
    let lsn = trie.current_lsn();
    assert_eq!(lsn, 0);
}
/// An insert on a file-backed trie advances the current LSN.
#[test]
fn test_current_lsn_increases_after_insert() {
    let scratch = tempdir().expect("create temp dir");
    let path = scratch.path().join("lsn_test.part");
    let trie: PersistentARTrie<i32> =
        PersistentARTrie::create(&path).expect("create dict");
    let lsn_before = trie.current_lsn();
    trie.insert_with_value("key1", 42);
    let lsn_after = trie.current_lsn();
    assert!(
        lsn_after > lsn_before,
        "LSN should increase after insert: before={}, after={}",
        lsn_before,
        lsn_after
    );
}
#[test]
fn test_current_lsn_increases_after_remove() {
let dir = tempdir().expect("create temp dir");
let dict_path = dir.path().join("lsn_test.part");
let dict: PersistentARTrie<i32> =
PersistentARTrie::create(&dict_path).expect("create dict");
dict.insert_with_value("key1", 42);
let before = dict.current_lsn();
dict.remove("key1");
let after = dict.current_lsn();
assert!(
after > before,
"LSN should increase after remove: before={}, after={}",
before,
after
);
}
#[test]
fn test_synced_lsn_none_for_in_memory() {
let dict: PersistentARTrie<i32> = PersistentARTrie::new();
assert!(
dict.synced_lsn().is_none(),
"In-memory trie should have no synced LSN"
);
}
#[test]
fn test_synced_lsn_after_sync() {
let dir = tempdir().expect("create temp dir");
let dict_path = dir.path().join("lsn_test.part");
let dict: PersistentARTrie<i32> =
PersistentARTrie::create(&dict_path).expect("create dict");
dict.insert_with_value("key1", 42);
dict.insert_with_value("key2", 43);
let synced_before = dict
.synced_lsn()
.expect("persistent trie should have synced_lsn");
assert_eq!(
synced_before,
dict.current_lsn().saturating_sub(1),
"Immediate durability should sync through the last acknowledged write"
);
dict.sync().expect("sync should succeed");
let synced_after = dict
.synced_lsn()
.expect("persistent trie should have synced_lsn");
assert!(
synced_after >= synced_before,
"synced_lsn should not go backwards after sync: before={}, after={}",
synced_before,
synced_after
);
}
#[test]
fn test_synced_lsn_invariant() {
let dir = tempdir().expect("create temp dir");
let dict_path = dir.path().join("lsn_test.part");
let dict: PersistentARTrie<i32> =
PersistentARTrie::create(&dict_path).expect("create dict");
dict.insert_with_value("key1", 42);
dict.sync().expect("sync should succeed");
dict.insert_with_value("key2", 43);
let current = dict.current_lsn();
let synced = dict
.synced_lsn()
.expect("persistent trie should have synced_lsn");
assert!(
synced < current,
"synced_lsn ({}) should be less than current_lsn ({})",
synced,
current
);
}
#[test]
fn test_lsn_monotonically_increasing() {
let dir = tempdir().expect("create temp dir");
let dict_path = dir.path().join("lsn_test.part");
let dict: PersistentARTrie<i32> =
PersistentARTrie::create(&dict_path).expect("create dict");
let mut prev_lsn = dict.current_lsn();
for i in 0..10 {
dict.insert_with_value(&format!("key{}", i), i);
let curr_lsn = dict.current_lsn();
assert!(
curr_lsn > prev_lsn,
"LSN should increase monotonically: prev={}, curr={}",
prev_lsn,
curr_lsn
);
prev_lsn = curr_lsn;
}
}
}
/// Two empty slices compare equal: `<=` holds and `>` does not.
#[test]
fn test_simd_cmp_empty_slices() {
    let empty: &[u8] = b"";
    assert!(bytes_le(empty, empty));
    assert!(!bytes_gt(empty, empty));
}
/// A strict prefix orders before any longer slice that extends it.
#[test]
fn test_simd_cmp_different_lengths_prefix() {
    let shorter = b"abc";
    let longer = b"abcd";
    assert!(bytes_le(shorter, longer));
    assert!(bytes_gt(longer, shorter));
}
/// Single-byte inputs that differ at index 0.
#[test]
fn test_simd_cmp_first_byte_difference() {
    let (low, high) = (b"a", b"b");
    assert!(bytes_le(low, high));
    assert!(bytes_gt(high, low));
}
/// Inputs that agree on byte 0 and first differ at byte 1.
#[test]
fn test_simd_cmp_position_1_difference() {
    let (low, high) = (b"aa", b"ab");
    assert!(bytes_le(low, high));
    assert!(bytes_gt(high, low));
}
/// 16-byte inputs differing at index 8, in the middle of one chunk
/// ('_' = 0x5F < 'z' = 0x7A).
#[test]
fn test_simd_cmp_mid_chunk_difference() {
    let low = b"aaaaaaaa_aaaaaaa";
    let high = b"aaaaaaaazaaaaaaa";
    assert!(bytes_le(low, high));
    assert!(bytes_gt(high, low));
}
/// 16-byte inputs differing only in the last byte of the chunk.
#[test]
fn test_simd_cmp_position_15_difference() {
    let low = b"aaaaaaaaaaaaaaax";
    let high = b"aaaaaaaaaaaaaaay";
    assert!(bytes_le(low, high));
    assert!(bytes_gt(high, low));
}
/// The first 16 bytes are identical; the difference ('_' < '~') is in
/// the byte right after that first chunk.
#[test]
fn test_simd_cmp_across_chunk_boundary() {
    let low = b"aaaaaaaaaaaaaaaa_bbbbbbbbbbbbbbb";
    let high = b"aaaaaaaaaaaaaaaa~bbbbbbbbbbbbbbb";
    assert!(bytes_le(low, high));
    assert!(bytes_gt(high, low));
}
/// A 100-byte shared prefix followed by one differing trailing byte.
#[test]
fn test_simd_cmp_long_equal_prefix() {
    let shared = vec![b'x'; 100];
    let mut low = shared.clone();
    low.push(b'a');
    let mut high = shared;
    high.push(b'b');
    assert!(bytes_le(&low, &high));
    assert!(bytes_gt(&high, &low));
}
/// Inputs shorter than 16 bytes are compared entirely by the byte-wise
/// tail loop ('o' > 'i' at index 4).
#[test]
fn test_simd_cmp_scalar_fallback() {
    let greater = b"hello";
    let lesser = b"helli";
    assert!(bytes_gt(greater, lesser));
    assert!(bytes_le(lesser, greater));
}
/// Two distinct buffers holding the same 16 bytes compare equal.
#[test]
fn test_simd_cmp_exactly_16_bytes() {
    let lhs: [u8; 16] = *b"abcdefghijklmnop";
    let rhs = lhs;
    assert!(bytes_le(&lhs, &rhs));
    assert!(!bytes_gt(&lhs, &rhs));
}
/// For each index of a 16-byte chunk, a single differing byte there must
/// decide the ordering.
#[test]
fn test_simd_cmp_all_positions_in_chunk() {
    for pos in 0..16 {
        let mut lhs = [b'a'; 16];
        let mut rhs = lhs;
        lhs[pos] = b'x';
        rhs[pos] = b'y';
        assert!(bytes_le(&lhs, &rhs), "bytes_le failed at position {}", pos);
        assert!(bytes_gt(&rhs, &lhs), "bytes_gt failed at position {}", pos);
    }
}
/// Byte-wise comparison of multi-byte UTF-8 text must agree with Rust's
/// `str` ordering.
///
/// Both strings share the ASCII prefix "hello". They first differ in the
/// lead byte of a 3-byte sequence: '世' (U+4E16) encodes as E4 B8 96 and
/// '地' (U+5730) as E5 9C B0, so `a` must order strictly before `b`.
#[test]
fn test_simd_cmp_utf8_multibyte() {
    let a = "hello世界";
    let b = "hello地球";
    // `str`'s Ord is lexicographic over bytes, so it serves as the oracle.
    assert!(a < b, "str ordering should place a before b");
    assert!(
        bytes_le(a.as_bytes(), b.as_bytes()),
        "bytes_le should agree with str ordering"
    );
    assert!(
        !bytes_gt(a.as_bytes(), b.as_bytes()),
        "bytes_gt must not report a > b"
    );
    assert!(
        bytes_gt(b.as_bytes(), a.as_bytes()),
        "bytes_gt should report b > a"
    );
}
/// Resolving a bucket child reports success even when no buffer manager
/// is supplied.
#[test]
fn test_resolve_child_already_in_memory() {
    let mut child = ChildNode::Bucket(StringBucket::new());
    let no_buffer_manager: Option<&Arc<RwLock<BufferManager>>> = None;
    let resolved = resolve_child_for_mutation_with_bm(&mut child, no_buffer_manager);
    assert!(resolved);
}
/// Resolving an in-memory ART node child reports success even when no
/// buffer manager is supplied.
#[test]
fn test_resolve_child_art_node_already_in_memory() {
    let mut child = ChildNode::ArtNode {
        node: Node::new_node4(),
        children: Vec::new(),
        is_final: false,
        value: None,
    };
    let no_buffer_manager: Option<&Arc<RwLock<BufferManager>>> = None;
    let resolved = resolve_child_for_mutation_with_bm(&mut child, no_buffer_manager);
    assert!(resolved);
}
/// A child that is only an on-disk reference cannot be resolved when no
/// buffer manager is supplied; resolution must report failure.
#[test]
fn test_resolve_child_disk_ref_no_buffer_manager() {
    let mut child = ChildNode::DiskRef {
        ptr: SwizzledPtr::on_disk(1, 0, NodeType::Node4),
    };
    let no_buffer_manager: Option<&Arc<RwLock<BufferManager>>> = None;
    let resolved = resolve_child_for_mutation_with_bm(&mut child, no_buffer_manager);
    assert!(!resolved);
}
/// Equal inputs satisfy `<=` but not `>`.
#[test]
fn test_bytes_le_equality() {
    let word = b"test";
    assert!(bytes_le(word, word));
    assert!(!bytes_gt(word, word));
}
/// Non-ASCII binary data must be ordered as unsigned bytes.
///
/// The first case is shorter than 16 bytes, so only the byte-wise tail
/// loop sees it. The second case is exactly 16 bytes. When the SSE4.2
/// path of `simd_cmp_bytes` is compiled in, it goes through the vector
/// loop. There, 0x80 must still order above 0x7F, which a signed
/// per-lane comparison would get backwards.
#[test]
fn test_simd_cmp_binary_data() {
    let a: &[u8] = &[0xFF, 0xFE, 0xFD, 0xFC];
    let b: &[u8] = &[0xFF, 0xFE, 0xFD, 0xFB];
    assert!(bytes_gt(a, b));
    assert!(bytes_le(b, a));

    let mut high = [0xFFu8; 16];
    let mut low = [0xFFu8; 16];
    high[7] = 0x80;
    low[7] = 0x7F;
    assert!(bytes_gt(&high, &low), "0x80 must order above 0x7F");
    assert!(bytes_le(&low, &high), "0x7F must order below 0x80");
}
/// Embedded NUL bytes are ordinary data, not terminators; ordering is
/// decided by the last byte.
#[test]
fn test_simd_cmp_null_bytes() {
    let low = b"\x00\x00\x01";
    let high = b"\x00\x00\x02";
    assert!(bytes_le(low, high));
    assert!(bytes_gt(high, low));
}
mod error_path_tests {
    use super::*;
    use tempfile::TempDir;

    /// Opening a path that was never created must return an error.
    #[test]
    fn test_open_nonexistent_returns_error() {
        let temp_dir = TempDir::new().expect("create temp dir");
        let dict_path = temp_dir.path().join("nonexistent.part");
        let result: Result<PersistentARTrie<()>> = PersistentARTrie::open(&dict_path);
        assert!(result.is_err());
    }

    /// Creating a dictionary whose parent directories do not exist yet.
    ///
    /// Despite the name, this checks the success path: `create` is expected
    /// to cope with the missing `nested/deep/path` directories.
    #[test]
    fn test_create_with_invalid_parent_path() {
        let temp_dir = TempDir::new().expect("create temp dir");
        let dict_path = temp_dir.path().join("nested/deep/path/test.part");
        // `expect` prints the underlying error on failure; a bare
        // `assert!(result.is_ok())` would hide why `create` failed.
        let _dict: PersistentARTrie<()> = PersistentARTrie::create(&dict_path)
            .expect("Create should handle nested directory creation");
    }

    /// `sync` on a freshly created, empty dictionary must succeed.
    #[test]
    fn test_sync_on_new_dict() {
        let temp_dir = TempDir::new().expect("create temp dir");
        let dict_path = temp_dir.path().join("new.part");
        let dict: PersistentARTrie<()> =
            PersistentARTrie::create(&dict_path).expect("create dict");
        dict.sync().expect("sync empty dict");
    }

    /// `checkpoint` on a freshly created, empty dictionary must succeed.
    #[test]
    fn test_checkpoint_on_new_dict() {
        let temp_dir = TempDir::new().expect("create temp dir");
        let dict_path = temp_dir.path().join("new.part");
        let dict: PersistentARTrie<()> =
            PersistentARTrie::create(&dict_path).expect("create dict");
        dict.checkpoint().expect("checkpoint empty dict");
    }

    /// Write one key, sync, drop the handle, then reopen with recovery.
    ///
    /// The recovery report is expected to be in normal mode and the key
    /// must still be present.
    #[test]
    fn test_open_with_recovery_new_file() {
        let temp_dir = TempDir::new().expect("create temp dir");
        let dict_path = temp_dir.path().join("test.part");
        // Inner scope: the first handle is dropped before the file is reopened.
        {
            let dict: PersistentARTrie<()> =
                PersistentARTrie::create(&dict_path).expect("create dict");
            dict.insert("test");
            dict.sync().expect("sync");
        }
        let (dict, report) =
            PersistentARTrie::<()>::open_with_recovery(&dict_path).expect("open_with_recovery");
        assert!(report.mode.is_normal());
        assert!(dict.contains("test"));
    }
}
}