use ferntree::{OptimisticRead, Tree};
use rand::prelude::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::Duration;
/// Several writer threads each insert their own non-overlapping key range; afterwards every
/// key must be present with the value `key * 10`.
#[test]
fn concurrent_insert_disjoint_ranges() {
    let num_threads = 4;
    let entries_per_thread = 100;
    let tree = Arc::new(Tree::<i32, i32>::new());

    let mut workers = Vec::new();
    for t in 0..num_threads {
        let shared = Arc::clone(&tree);
        workers.push(thread::spawn(move || {
            // Thread `t` owns the half-open key range [base, base + entries_per_thread).
            let base = t * entries_per_thread;
            for key in base..base + entries_per_thread {
                shared.insert(key, key * 10);
            }
        }));
    }
    for worker in workers {
        worker.join().unwrap();
    }

    tree.assert_invariants();
    assert_eq!(tree.len(), (num_threads * entries_per_thread) as usize);
    // The per-thread ranges are contiguous, so together they cover 0..total in order.
    for key in 0..num_threads * entries_per_thread {
        assert_eq!(tree.lookup(&key, |v| *v), Some(key * 10), "Missing key {}", key);
    }
}
/// All writer threads repeatedly overwrite the same ten keys; the tree must end up with exactly
/// ten entries, each holding the id of one of the writers.
#[test]
fn concurrent_insert_same_keys() {
    let num_threads = 4;
    let iterations = 100;
    let tree = Arc::new(Tree::<i32, i32>::new());

    let writers: Vec<_> = (0..num_threads)
        .map(|writer_id| {
            let shared = Arc::clone(&tree);
            thread::spawn(move || {
                // Fold the iteration counter onto the ten contended keys 0..10.
                (0..iterations).map(|i| i % 10).for_each(|key| {
                    shared.insert(key, writer_id);
                });
            })
        })
        .collect();
    writers.into_iter().for_each(|w| w.join().unwrap());

    tree.assert_invariants();
    assert_eq!(tree.len(), 10);
    for key in 0..10 {
        let value = tree.lookup(&key, |v| *v).expect("Key should exist");
        // Whichever writer won the race, the stored value must be a valid writer id.
        assert!(value < num_threads, "Invalid value {} for key {}", value, key);
    }
}
/// Pre-populates the tree, then lets several reader threads scan it from the first entry in
/// parallel; each scan must see every entry with the value `key * 10`.
#[test]
fn many_concurrent_readers() {
    let num_readers = 4;
    let entries = 100;
    let tree = Arc::new(Tree::<i32, i32>::new());
    (0..entries).for_each(|i| {
        tree.insert(i, i * 10);
    });
    tree.assert_invariants();

    let mut readers = Vec::new();
    for _ in 0..num_readers {
        let shared = Arc::clone(&tree);
        readers.push(thread::spawn(move || {
            let mut cursor = shared.raw_iter();
            cursor.seek_to_first();
            let mut seen = 0;
            while let Some((key, value)) = cursor.next() {
                assert_eq!(*value, *key * 10);
                seen += 1;
            }
            seen
        }));
    }
    for reader in readers {
        let seen = reader.join().unwrap();
        assert_eq!(seen, entries);
    }
    tree.assert_invariants();
}
/// A reader looks up pre-existing keys while a writer concurrently appends new keys.
///
/// The reader only probes keys `0..entries`, all of which are inserted before either thread
/// starts and are never removed or overwritten. Every one of those lookups must therefore hit
/// and return the original value; anything less means a concurrent insert hid an existing key
/// from the reader.
#[test]
fn concurrent_insert_and_lookup() {
    let tree = Arc::new(Tree::<i32, i32>::new());
    let entries = 100;
    let extra = 50;
    for i in 0..entries {
        tree.insert(i, i * 10);
    }
    tree.assert_invariants();

    let tree_writer = Arc::clone(&tree);
    let tree_reader = Arc::clone(&tree);
    let writer = thread::spawn(move || {
        // New keys only: the writer never touches the range the reader is probing.
        for i in entries..(entries + extra) {
            tree_writer.insert(i, i * 10);
        }
    });
    let reader = thread::spawn(move || {
        let mut found = 0;
        for i in 0..entries {
            if tree_reader.lookup(&i, |v| *v) == Some(i * 10) {
                found += 1;
            }
        }
        found
    });

    writer.join().unwrap();
    let found = reader.join().unwrap();
    tree.assert_invariants();
    assert_eq!(
        found, entries,
        "Reader saw only {} of {} pre-existing keys with their original values",
        found, entries
    );
    // The writer has been joined, so all of its inserts must be accounted for.
    assert_eq!(tree.len(), (entries + extra) as usize);
}
/// Scans the tree with a raw iterator while another thread appends larger keys; the keys
/// observed by the scan must be strictly increasing and the scan must yield at least one entry.
#[test]
fn iterate_while_inserting() {
    let tree = Arc::new(Tree::<i32, i32>::new());
    (0..50).for_each(|i| {
        tree.insert(i, i);
    });
    tree.assert_invariants();

    let writer = {
        let tree_writer = Arc::clone(&tree);
        thread::spawn(move || {
            (50..75).for_each(|i| {
                tree_writer.insert(i, i);
            });
        })
    };
    let reader = {
        let tree_reader = Arc::clone(&tree);
        thread::spawn(move || {
            let mut cursor = tree_reader.raw_iter();
            cursor.seek_to_first();
            // Starts below every inserted key so the first entry always passes the check.
            let mut prev = -1i32;
            let mut count = 0;
            while let Some((k, _)) = cursor.next() {
                assert!(*k > prev, "Order violation: {} not > {}", *k, prev);
                prev = *k;
                count += 1;
            }
            count
        })
    };

    writer.join().unwrap();
    let count = reader.join().unwrap();
    tree.assert_invariants();
    assert!(count > 0);
}
/// Two writers insert disjoint key ranges; the test then checks that the tree grew beyond a
/// single level and that every key maps to its offset within its thread's range.
#[test]
fn concurrent_inserts_cause_splits() {
    let num_threads = 2;
    let entries_per_thread = 50;
    let tree = Arc::new(Tree::<i32, i32>::new());

    let mut workers = Vec::new();
    for t in 0..num_threads {
        let shared = Arc::clone(&tree);
        workers.push(thread::spawn(move || {
            let base = t * entries_per_thread;
            for offset in 0..entries_per_thread {
                shared.insert(base + offset, offset);
            }
        }));
    }
    for worker in workers {
        worker.join().unwrap();
    }

    tree.assert_invariants();
    assert!(tree.height() > 1, "Tree should have split");
    assert_eq!(tree.len(), (num_threads * entries_per_thread) as usize);
    for key in 0..num_threads * entries_per_thread {
        // The stored value is the key's offset inside the range of the thread that wrote it.
        let expected = key % entries_per_thread;
        assert_eq!(tree.lookup(&key, |v| *v), Some(expected), "Missing key {}", key);
    }
}
/// Fills the tree, then has two threads remove disjoint halves of the key space; the tree must
/// be empty afterwards.
#[test]
fn concurrent_removes() {
    let entries = 100;
    let tree = Arc::new(Tree::<i32, i32>::new());
    (0..entries).for_each(|key| {
        tree.insert(key, key);
    });
    tree.assert_invariants();

    let num_threads = 2;
    let entries_per_thread = entries / num_threads;
    let mut removers = Vec::new();
    for t in 0..num_threads {
        let shared = Arc::clone(&tree);
        removers.push(thread::spawn(move || {
            let start = t * entries_per_thread;
            for key in start..start + entries_per_thread {
                shared.remove(&key);
            }
        }));
    }
    for remover in removers {
        remover.join().unwrap();
    }

    tree.assert_invariants();
    assert!(tree.is_empty());
}
/// Eight writers insert random keys drawn from one shared, small key range; the final entry
/// count can never exceed the number of distinct keys in that range.
#[test]
fn stress_concurrent_insert_overlapping_ranges() {
    let num_threads = 8;
    let entries_per_thread = 1000;
    let key_range = 500;
    let tree = Arc::new(Tree::<i32, i32>::new());

    let mut workers = Vec::new();
    for writer_id in 0..num_threads {
        let shared = Arc::clone(&tree);
        workers.push(thread::spawn(move || {
            let mut rng = rand::rng();
            for _ in 0..entries_per_thread {
                let key: i32 = rng.random_range(0..key_range);
                shared.insert(key, writer_id);
            }
        }));
    }
    for worker in workers {
        worker.join().unwrap();
    }

    tree.assert_invariants();
    assert!(tree.len() <= key_range as usize);
}
/// Eight threads perform a random mix of inserts, removes and lookups over a pre-populated key
/// range; the tree invariants must hold afterwards and the size stays bounded by the range.
#[test]
fn stress_concurrent_mixed_operations() {
    let num_threads = 8;
    let operations_per_thread = 1000;
    let key_range = 500;
    let tree = Arc::new(Tree::<i32, i32>::new());
    (0..key_range).for_each(|key| {
        tree.insert(key, key);
    });
    tree.assert_invariants();

    let mut workers = Vec::new();
    for worker_id in 0..num_threads {
        let shared = Arc::clone(&tree);
        workers.push(thread::spawn(move || {
            let mut rng = rand::rng();
            for _ in 0..operations_per_thread {
                let key: i32 = rng.random_range(0..key_range);
                // Uniformly pick one of the three operations.
                let choice: u8 = rng.random_range(0..3);
                match choice {
                    0 => {
                        shared.insert(key, worker_id);
                    }
                    1 => {
                        shared.remove(&key);
                    }
                    2 => {
                        shared.lookup(&key, |v| *v);
                    }
                    _ => unreachable!(),
                }
            }
        }));
    }
    for worker in workers {
        worker.join().unwrap();
    }

    tree.assert_invariants();
    assert!(tree.len() <= key_range as usize);
}
/// Eight threads alternate between overwriting and reading one single key; the key must still
/// be present afterwards and remain the only entry in the tree.
#[test]
fn stress_high_contention_single_key() {
    let num_threads = 8;
    let iterations = 1000;
    let tree = Arc::new(Tree::<i32, i32>::new());
    tree.insert(42, 0);

    let mut workers = Vec::new();
    for worker_id in 0..num_threads {
        let shared = Arc::clone(&tree);
        workers.push(thread::spawn(move || {
            for step in 0..iterations {
                // Even steps write, odd steps read.
                match step % 2 {
                    0 => {
                        shared.insert(42, worker_id);
                    }
                    _ => {
                        shared.lookup(&42, |v| *v);
                    }
                }
            }
        }));
    }
    for worker in workers {
        worker.join().unwrap();
    }

    tree.assert_invariants();
    assert!(tree.lookup(&42, |v| *v).is_some());
    assert_eq!(tree.len(), 1);
}
/// Runs a 40% insert / 20% remove / 40% lookup workload on four threads for a fixed wall-clock
/// duration, then checks the tree invariants and that a minimum number of operations completed.
#[test]
fn stress_sustained_mixed_operations() {
    let num_threads = 4;
    let duration_ms = 500;
    let tree = Arc::new(Tree::<i32, i32>::new());
    // 1 while the workers should keep going; the main thread stores 0 to stop them.
    let running = Arc::new(AtomicUsize::new(1));

    let mut workers = Vec::new();
    for worker_id in 0..num_threads {
        let shared = Arc::clone(&tree);
        let keep_going = Arc::clone(&running);
        workers.push(thread::spawn(move || {
            let mut rng = rand::rng();
            let mut completed = 0u64;
            loop {
                if keep_going.load(Ordering::Relaxed) != 1 {
                    break;
                }
                let key: i32 = rng.random_range(0..1000);
                let choice: u8 = rng.random_range(0..10);
                match choice {
                    0..=3 => {
                        shared.insert(key, worker_id);
                    }
                    4..=5 => {
                        shared.remove(&key);
                    }
                    6..=9 => {
                        shared.lookup(&key, |v| *v);
                    }
                    _ => unreachable!(),
                }
                completed += 1;
            }
            completed
        }));
    }

    thread::sleep(Duration::from_millis(duration_ms));
    running.store(0, Ordering::Relaxed);

    let mut total_ops = 0u64;
    for worker in workers {
        total_ops += worker.join().unwrap();
    }
    tree.assert_invariants();
    assert!(total_ops > 100, "Only {} operations performed", total_ops);
}
/// Eight writers each insert 1000 keys from their own disjoint range; the final entry count
/// must equal the total number of inserts.
#[test]
fn stress_large_scale_concurrent_inserts() {
    let num_threads = 8;
    let entries_per_thread = 1000;
    let tree = Arc::new(Tree::<i32, i32>::new());

    let mut workers = Vec::new();
    for t in 0..num_threads {
        let shared = Arc::clone(&tree);
        workers.push(thread::spawn(move || {
            let base = t * entries_per_thread;
            for key in base..base + entries_per_thread {
                shared.insert(key, key * 10);
            }
        }));
    }
    for worker in workers {
        worker.join().unwrap();
    }

    tree.assert_invariants();
    assert_eq!(tree.len(), (num_threads * entries_per_thread) as usize);
}
/// Producers insert disjoint key ranges while consumers probe random keys until production has
/// finished.
///
/// Each consumer samples the `produced` counter *before* a lookup and stops only after a lookup
/// that followed a "production complete" observation. That guarantees at least one lookup per
/// consumer, and that this last lookup runs when every key is already in the tree. Without it,
/// a consumer that starts after the producers are done would never look anything up, and the
/// `total_found > 0` assertion could fail spuriously.
#[test]
fn stress_producer_consumer() {
    let tree = Arc::new(Tree::<i32, i32>::new());
    let num_producers = 4;
    let num_consumers = 4;
    let entries_per_producer = 500;
    let produced = Arc::new(AtomicUsize::new(0));

    let producer_handles: Vec<_> = (0..num_producers)
        .map(|p| {
            let tree = Arc::clone(&tree);
            let produced = Arc::clone(&produced);
            thread::spawn(move || {
                for i in 0..entries_per_producer {
                    let key = p * entries_per_producer + i;
                    tree.insert(key, key * 10);
                    // Release: publishes the insert above to any consumer that observes
                    // this increment with an Acquire load.
                    produced.fetch_add(1, Ordering::Release);
                }
            })
        })
        .collect();

    let consumer_handles: Vec<_> = (0..num_consumers)
        .map(|_| {
            let tree = Arc::clone(&tree);
            let produced = Arc::clone(&produced);
            thread::spawn(move || {
                let mut rng = rand::rng();
                let mut found = 0u64;
                let total_entries = num_producers * entries_per_producer;
                loop {
                    // Sample completion first so the lookup below is known to run after
                    // every insert whenever `done` is true.
                    let done = produced.load(Ordering::Acquire) >= total_entries as usize;
                    let key: i32 = rng.random_range(0..total_entries);
                    if tree.lookup(&key, |v| *v).is_some() {
                        found += 1;
                    }
                    if done {
                        break;
                    }
                }
                found
            })
        })
        .collect();

    for h in producer_handles {
        h.join().unwrap();
    }
    let total_found: u64 = consumer_handles.into_iter().map(|h| h.join().unwrap()).sum();
    tree.assert_invariants();
    assert!(total_found > 0);
    assert_eq!(tree.len(), (num_producers * entries_per_producer) as usize);
}
/// Writers grow and shrink heap-backed `Vec<u64>` values in place through a mutable raw
/// iterator while readers sum the same vectors through `lookup`.
///
/// The test makes no assertion about the sums; it passes when no worker panics and the tree
/// invariants hold at the end.
#[test]
fn concurrent_lookup_with_interior_pointer_values() {
    const NUM_KEYS: u32 = 16;
    const WRITERS: usize = 12;
    const READERS: usize = 12;
    const OPS_PER_THREAD: usize = 2_000;
    type Versions = Vec<u64>;

    let tree: Arc<Tree<u32, Versions>> = Arc::new(Tree::new());
    (0..NUM_KEYS).for_each(|k| {
        let initial: Versions = vec![0];
        tree.insert(k, initial);
    });

    let mut handles = Vec::new();

    // Writers are spawned first, each with its own deterministic seed.
    handles.extend((0..WRITERS).map(|w| {
        let shared = Arc::clone(&tree);
        thread::spawn(move || {
            let mut rng = StdRng::seed_from_u64(0xA110C + w as u64);
            for i in 0..OPS_PER_THREAD {
                let key = rng.random_range(0..NUM_KEYS);
                let mut cursor = shared.raw_iter_mut();
                if !cursor.seek_exact(&key) {
                    continue;
                }
                let (_, versions) = cursor.next().expect("seek_exact returned true");
                // Shrink only on even iterations once the vector has grown past four
                // elements; otherwise append.
                let shrink = versions.len() > 4 && (i & 1) == 0;
                if shrink {
                    versions.truncate(2);
                } else {
                    versions.push(i as u64);
                }
            }
        })
    }));

    handles.extend((0..READERS).map(|r| {
        let shared = Arc::clone(&tree);
        thread::spawn(move || {
            let mut rng = StdRng::seed_from_u64(0xBEE5 + r as u64);
            for _ in 0..OPS_PER_THREAD {
                let key = rng.random_range(0..NUM_KEYS);
                let _ = shared.lookup(&key, |v| v.iter().copied().sum::<u64>());
            }
        })
    }));

    for handle in handles {
        handle.join().expect("worker thread panicked");
    }
    tree.assert_invariants();
}
// Test value type: a byte buffer shared behind an `Arc`, so cloning it bumps a reference count.
// `dead_code` is allowed because, in the visible part of this file, the only code that builds
// or reads a `RefcountedBlob` is a test that is compiled out with `#[cfg(any())]`.
#[derive(Clone)]
#[allow(dead_code)] struct RefcountedBlob(Arc<Vec<u8>>);
// SAFETY: NOTE(review): the contract of `OptimisticRead` is defined in `ferntree` and is not
// visible from this file. Presumably opting into `EPOCH_DEFERRED_DROP` with a `BoxedSlot` is
// what makes optimistic reads of a non-`Copy` value sound — confirm against the trait's
// `# Safety` documentation before relying on this impl.
unsafe impl OptimisticRead for RefcountedBlob {
const EPOCH_DEFERRED_DROP: bool = true;
type Slot = ferntree::atomic_slot::BoxedSlot<Self>;
}
// DISABLED: `#[cfg(any())]` is never true, so this test is stripped before compilation.
// NOTE(review): the body uses `AtomicBool`, which is not in the `use` lines at the top of this
// file; it would need to be imported before the test can be re-enabled.
//
// Four reader threads repeatedly call `lookup_optimistic` on keys 0..200, cloning the stored
// `RefcountedBlob` and summing its bytes, while one writer overwrites every key through
// `insert_defer` for 50 rounds. There are no explicit assertions: the test passes when every
// thread joins without panicking.
#[cfg(any())]
#[test]
fn epoch_deferred_drop_optimistic_reader_vs_defer_writer() {
let tree: Arc<Tree<i32, RefcountedBlob>> = Arc::new(Tree::new());
let stop = Arc::new(AtomicBool::new(false));
// Seed the tree so the readers have something to find from the start.
for i in 0..200 {
tree.insert_defer(i, RefcountedBlob(Arc::new(vec![i as u8; 32])));
}
let mut handles = Vec::new();
for _ in 0..4 {
let tree = Arc::clone(&tree);
let stop = Arc::clone(&stop);
handles.push(thread::spawn(move || {
// Readers spin until the writer signals completion through `stop`.
while !stop.load(Ordering::Relaxed) {
for k in 0..200 {
if let Some(blob) = tree.lookup_optimistic(&k, |v| v.clone()) {
// Read the cloned buffer; `black_box` keeps the sum from being optimized away.
let s: usize = blob.0.iter().map(|&b| b as usize).sum();
std::hint::black_box(s);
}
}
}
}));
}
let writer = {
let tree = Arc::clone(&tree);
let stop = Arc::clone(&stop);
thread::spawn(move || {
for round in 0..50 {
for k in 0..200 {
let v = RefcountedBlob(Arc::new(vec![(k + round) as u8; 32]));
tree.insert_defer(k, v);
}
}
// Tell the readers to finish once all rounds are written.
stop.store(true, Ordering::Relaxed);
})
};
writer.join().unwrap();
for h in handles {
h.join().unwrap();
}
}
// DISABLED: `#[cfg(any())]` is never true, so this test is stripped before compilation.
// NOTE(review): the body uses `RcKey` and `AtomicBool`; neither is defined or imported in the
// visible part of this file, so both must be supplied before the test can be re-enabled.
//
// Four reader threads repeatedly call `lookup_optimistic` with freshly built keys while one
// writer alternates rounds: even rounds remove every key with `remove_defer`, odd rounds put
// them back with `insert_defer`. There are no explicit assertions: the test passes when every
// thread joins without panicking.
#[cfg(any())]
#[test]
fn k_deferred_drop_optimistic_reader_vs_defer_writer() {
const KEYS: u8 = 64;
// Builds a key wrapping an `Arc`-backed buffer of eight copies of `i`.
let mk_key = |i: u8| RcKey(Arc::new(vec![i; 8]));
let tree: Arc<Tree<RcKey, u64>> = Arc::new(Tree::new());
let stop = Arc::new(AtomicBool::new(false));
for i in 0..KEYS {
tree.insert_defer(mk_key(i), i as u64);
}
let mut handles = Vec::new();
for _ in 0..4 {
let tree = Arc::clone(&tree);
let stop = Arc::clone(&stop);
handles.push(thread::spawn(move || {
// Readers spin until the writer signals completion through `stop`.
while !stop.load(Ordering::Relaxed) {
for i in 0..KEYS {
let k = mk_key(i);
let _ = tree.lookup_optimistic(&k, |v| *v);
}
}
}));
}
let writer = {
let tree = Arc::clone(&tree);
let stop = Arc::clone(&stop);
thread::spawn(move || {
for round in 0..50u64 {
for i in 0..KEYS {
let k = mk_key(i);
// Even rounds delete the key, odd rounds re-insert it with a round-specific value.
if round % 2 == 0 {
tree.remove_defer(&k);
} else {
tree.insert_defer(k, round * 100 + i as u64);
}
}
}
// Tell the readers to finish once all rounds are done.
stop.store(true, Ordering::Relaxed);
})
};
writer.join().unwrap();
for h in handles {
h.join().unwrap();
}
}