#![cfg(all(feature = "replication", feature = "var-collections"))]
use std::net::{SocketAddr, TcpListener};
use std::sync::Arc;
use std::time::{Duration, Instant};
use armdb::Config;
use armdb::ShutdownSignal;
use armdb::VarTree;
use armdb::replication::{ReplicationClientOptions, ReplicationRegistry, ReplicationServerOptions};
/// Picks a currently-free loopback address by binding a probe listener to
/// port 0 (OS-assigned) and reading back the address it received.
///
/// The probe is released before returning, so the port is only known to be
/// free at that instant; another socket could claim it before the caller
/// binds. That race is accepted here because this is test-only code.
fn next_bind_addr() -> SocketAddr {
    let probe = TcpListener::bind("127.0.0.1:0").unwrap();
    probe.local_addr().unwrap()
}
/// Polls `f` every 20 ms until it returns `true` or `timeout` has elapsed.
///
/// Once the deadline passes, `f` is evaluated one last time, so a condition
/// that became true during the final sleep is still reported. Returns the
/// result of the last call to `f`.
fn wait_until(timeout: Duration, mut f: impl FnMut() -> bool) -> bool {
    let started = Instant::now();
    while !f() {
        if started.elapsed() >= timeout {
            return f();
        }
        std::thread::sleep(Duration::from_millis(20));
    }
    true
}
/// Configuration used by tests that need no custom sharding or file-size
/// limits; a thin wrapper over `Config::test()`.
fn test_cfg() -> Config {
    Config::test()
}
#[test]
fn catch_up_smoke() {
    // Every key is written to the leader before replication starts, so the
    // follower can only obtain them through the catch-up path.
    const KEY_COUNT: u64 = 512;

    let leader_dir = tempfile::tempdir().unwrap();
    let follower_dir = tempfile::tempdir().unwrap();
    let addr = next_bind_addr();

    // Leader: pre-populate, then start serving.
    let leader = VarTree::<[u8; 8]>::open(leader_dir.path(), test_cfg()).unwrap();
    for key in 0..KEY_COUNT {
        let value = format!("v{key}");
        leader.put(&key.to_be_bytes(), value.as_bytes()).unwrap();
    }
    let leader_signal = ShutdownSignal::new();
    let _server = leader
        .start_replication_server(addr, leader_signal.clone())
        .unwrap();

    // Follower: empty tree attached to the leader.
    let follower = Arc::new(VarTree::<[u8; 8]>::open(follower_dir.path(), test_cfg()).unwrap());
    let registry = Arc::new(ReplicationRegistry::new(follower.as_replication_target()));
    let follower_signal = ShutdownSignal::new();
    let _client = follower
        .start_replication_client(addr, registry, follower_signal.clone())
        .unwrap();

    let every_key_present = || (0..KEY_COUNT).all(|key| follower.contains(&key.to_be_bytes()));
    let caught_up = wait_until(Duration::from_secs(15), every_key_present);
    assert!(
        caught_up,
        "follower did not catch up within 15s (len={})",
        follower.len()
    );

    follower_signal.shutdown();
    leader_signal.shutdown();
}
#[test]
fn streaming_after_catchup() {
    // Phase 1: keys 0..64 exist before the follower connects (catch-up).
    // Phase 2: keys 512..768 are written only after catch-up is observed
    // (live streaming). The two ranges are disjoint.
    const STREAM_START: u64 = 512;
    const STREAM_END: u64 = 768;

    let leader_dir = tempfile::tempdir().unwrap();
    let follower_dir = tempfile::tempdir().unwrap();
    let addr = next_bind_addr();

    let leader = VarTree::<[u8; 8]>::open(leader_dir.path(), test_cfg()).unwrap();
    for key in 0u64..64 {
        leader.put(&key.to_be_bytes(), b"initial").unwrap();
    }
    let leader_signal = ShutdownSignal::new();
    let _server = leader
        .start_replication_server(addr, leader_signal.clone())
        .unwrap();

    let follower = Arc::new(VarTree::<[u8; 8]>::open(follower_dir.path(), test_cfg()).unwrap());
    let registry = Arc::new(ReplicationRegistry::new(follower.as_replication_target()));
    let follower_signal = ShutdownSignal::new();
    let _client = follower
        .start_replication_client(addr, registry, follower_signal.clone())
        .unwrap();

    let seeded = wait_until(Duration::from_secs(10), || follower.len() >= 64);
    assert!(seeded, "catch-up did not complete within 10s");

    // Live writes after the follower is already attached.
    for key in STREAM_START..STREAM_END {
        leader.put(&key.to_be_bytes(), b"streamed").unwrap();
    }

    let streamed = wait_until(Duration::from_secs(5), || {
        (STREAM_START..STREAM_END).all(|key| follower.contains(&key.to_be_bytes()))
    });
    assert!(
        streamed,
        "follower did not receive streamed entries within 5s (len={})",
        follower.len()
    );

    follower_signal.shutdown();
    leader_signal.shutdown();
}
/// A follower that disconnects and later reconnects (same registry) must end
/// up with every key written in both sessions, and with no extra keys.
#[test]
fn reconnect_no_loss_no_duplicates() {
    let leader_dir = tempfile::tempdir().unwrap();
    let follower_dir = tempfile::tempdir().unwrap();
    let addr = next_bind_addr();
    let leader = VarTree::<[u8; 8]>::open(leader_dir.path(), test_cfg()).unwrap();
    let leader_signal = ShutdownSignal::new();
    // Short heartbeat: presumably lets the leader notice the first follower's
    // disconnect quickly so the reconnecting follower is not rejected as a
    // second concurrent follower (see `reject_second_follower`).
    let _server = leader
        .start_replication_server_with_options(
            addr,
            leader_signal.clone(),
            ReplicationServerOptions {
                heartbeat_interval_secs: 1,
            },
        )
        .unwrap();
    let follower = Arc::new(VarTree::<[u8; 8]>::open(follower_dir.path(), test_cfg()).unwrap());
    let registry = Arc::new(ReplicationRegistry::new(follower.as_replication_target()));

    // Session 1: receive batch 1 while connected.
    let signal1 = ShutdownSignal::new();
    let client1 = follower
        .start_replication_client(addr, registry.clone(), signal1.clone())
        .unwrap();
    for i in 0u64..100 {
        leader.put(&i.to_be_bytes(), b"batch1").unwrap();
    }
    assert!(
        wait_until(Duration::from_secs(10), || follower.len() >= 100),
        "follower did not receive first 100 entries"
    );

    // Disconnect. Signal the client to stop *before* releasing its handle,
    // matching the shutdown-then-drop order used by every other test. This
    // cannot block even if the handle's Drop waits for the worker to exit.
    signal1.shutdown();
    drop(client1);
    std::thread::sleep(Duration::from_millis(50));

    // Written while no follower is connected; must arrive after reconnect.
    for i in 100u64..200 {
        leader.put(&i.to_be_bytes(), b"batch2").unwrap();
    }

    // Session 2: reconnect through the same registry (same follower target),
    // with a fast retry backoff in case the leader still holds the old slot.
    let signal2 = ShutdownSignal::new();
    let _client2 = follower
        .start_replication_client_with_options(
            addr,
            registry,
            signal2.clone(),
            ReplicationClientOptions {
                reconnect_base_ms: 50,
                reconnect_max_ms: 200,
            },
        )
        .unwrap();

    // No loss: every expected key must be present individually. A bare
    // `len() >= 200` could be satisfied by stray keys while real ones are missing.
    assert!(
        wait_until(Duration::from_secs(15), || {
            (0u64..200).all(|i| follower.contains(&i.to_be_bytes()))
        }),
        "follower did not receive all 200 entries after reconnect (len={})",
        follower.len()
    );
    // No extras: all 200 expected keys are present, so an exact count rules
    // out anything else.
    assert_eq!(
        follower.len(),
        200,
        "expected 200 unique entries, got {}",
        follower.len()
    );
    signal2.shutdown();
    leader_signal.shutdown();
}
#[test]
fn rotation_during_streaming() {
    let leader_dir = tempfile::tempdir().unwrap();
    let follower_dir = tempfile::tempdir().unwrap();
    let addr = next_bind_addr();

    // Deliberately small file/buffer limits, intended to make the leader
    // rotate data files while entries are being streamed.
    let cfg = Config {
        shard_count: 2,
        max_file_size: 32 * 1024,
        write_buffer_size: 16 * 1024,
        ..Config::default()
    };

    let leader = VarTree::<[u8; 8]>::open(leader_dir.path(), cfg.clone()).unwrap();
    let leader_signal = ShutdownSignal::new();
    let _server = leader
        .start_replication_server(addr, leader_signal.clone())
        .unwrap();

    let follower = Arc::new(VarTree::<[u8; 8]>::open(follower_dir.path(), cfg).unwrap());
    let registry = Arc::new(ReplicationRegistry::new(follower.as_replication_target()));
    let follower_signal = ShutdownSignal::new();
    let _client = follower
        .start_replication_client(addr, registry, follower_signal.clone())
        .unwrap();

    // Written after the follower is attached, so these flow via streaming.
    let value = b"rotation_test_value_32bytes_xxxxx";
    for key in 0u64..1000 {
        leader.put(&key.to_be_bytes(), value).unwrap();
    }

    let complete = wait_until(Duration::from_secs(15), || follower.len() >= 1000);
    assert!(
        complete,
        "follower did not catch up after rotations (len={})",
        follower.len()
    );

    // Spot-check the first, middle and last keys.
    let spot_checks = [0u64, 499, 999];
    for &i in spot_checks.iter() {
        assert!(
            follower.get(&i.to_be_bytes()).is_some(),
            "key {i} missing from follower after rotation"
        );
    }

    follower_signal.shutdown();
    leader_signal.shutdown();
}
#[test]
fn multi_shard() {
    // Both trees use four shards; the pre-written keys are presumably
    // distributed across them, so catch-up must cover every shard.
    const KEY_COUNT: u64 = 200;

    let leader_dir = tempfile::tempdir().unwrap();
    let follower_dir = tempfile::tempdir().unwrap();
    let addr = next_bind_addr();
    let cfg = Config {
        shard_count: 4,
        ..Config::default()
    };

    let leader = VarTree::<[u8; 8]>::open(leader_dir.path(), cfg.clone()).unwrap();
    for key in 0..KEY_COUNT {
        leader.put(&key.to_be_bytes(), b"multi").unwrap();
    }
    let leader_signal = ShutdownSignal::new();
    let _server = leader
        .start_replication_server(addr, leader_signal.clone())
        .unwrap();

    let follower = Arc::new(VarTree::<[u8; 8]>::open(follower_dir.path(), cfg).unwrap());
    let registry = Arc::new(ReplicationRegistry::new(follower.as_replication_target()));
    let follower_signal = ShutdownSignal::new();
    let _client = follower
        .start_replication_client(addr, registry, follower_signal.clone())
        .unwrap();

    let every_key_present = || (0..KEY_COUNT).all(|key| follower.contains(&key.to_be_bytes()));
    let replicated = wait_until(Duration::from_secs(15), every_key_present);
    assert!(
        replicated,
        "multi-shard follower did not receive all entries (len={})",
        follower.len()
    );

    follower_signal.shutdown();
    leader_signal.shutdown();
}
/// Catch-up must work when leader and follower both encrypt data at rest
/// with the same key.
#[cfg(feature = "encryption")]
#[test]
fn encrypted_catchup() {
    let leader_dir = tempfile::tempdir().unwrap();
    let follower_dir = tempfile::tempdir().unwrap();
    let addr = next_bind_addr();
    let cfg = Config {
        shard_count: 2,
        write_buffer_size: 8192,
        // No per-field `#[cfg]` needed: this whole test is already gated on
        // the `encryption` feature.
        encryption_key: Some([0x42u8; 32]),
        ..Config::default()
    };
    let leader = VarTree::<[u8; 8]>::open(leader_dir.path(), cfg.clone()).unwrap();
    for i in 0u64..100 {
        leader.put(&i.to_be_bytes(), b"encrypted").unwrap();
    }
    let leader_signal = ShutdownSignal::new();
    let _server = leader
        .start_replication_server(addr, leader_signal.clone())
        .unwrap();
    let follower = Arc::new(VarTree::<[u8; 8]>::open(follower_dir.path(), cfg).unwrap());
    let registry = Arc::new(ReplicationRegistry::new(follower.as_replication_target()));
    let follower_signal = ShutdownSignal::new();
    let _client = follower
        .start_replication_client(addr, registry, follower_signal.clone())
        .unwrap();
    assert!(
        wait_until(Duration::from_secs(15), || {
            (0u64..100).all(|i| follower.contains(&i.to_be_bytes()))
        }),
        "encrypted follower did not catch up (len={})",
        follower.len()
    );
    follower_signal.shutdown();
    leader_signal.shutdown();
}
#[test]
fn reject_second_follower() {
    // Scenario: follower1 attaches first and must receive data; follower2
    // attaches afterwards to the same leader and must receive nothing.
    let leader_dir = tempfile::tempdir().unwrap();
    let follower1_dir = tempfile::tempdir().unwrap();
    let follower2_dir = tempfile::tempdir().unwrap();
    let addr = next_bind_addr();

    let leader = VarTree::<[u8; 8]>::open(leader_dir.path(), test_cfg()).unwrap();
    let leader_signal = ShutdownSignal::new();
    let _server = leader
        .start_replication_server(addr, leader_signal.clone())
        .unwrap();

    // First follower.
    let follower1 = Arc::new(VarTree::<[u8; 8]>::open(follower1_dir.path(), test_cfg()).unwrap());
    let registry1 = Arc::new(ReplicationRegistry::new(follower1.as_replication_target()));
    let signal1 = ShutdownSignal::new();
    let _client1 = follower1
        .start_replication_client(addr, registry1, signal1.clone())
        .unwrap();

    // Give follower1 a head start so it holds the session before follower2 tries.
    std::thread::sleep(Duration::from_millis(300));

    // Second follower.
    let follower2 = Arc::new(VarTree::<[u8; 8]>::open(follower2_dir.path(), test_cfg()).unwrap());
    let registry2 = Arc::new(ReplicationRegistry::new(follower2.as_replication_target()));
    let signal2 = ShutdownSignal::new();
    let _client2 = follower2
        .start_replication_client(addr, registry2, signal2.clone())
        .unwrap();

    for key in 0u64..50 {
        let value = format!("v{key}");
        leader.put(&key.to_be_bytes(), value.as_bytes()).unwrap();
    }

    let first_got_data = wait_until(Duration::from_secs(10), || follower1.len() >= 50);
    assert!(
        first_got_data,
        "follower1 did not receive entries (len={})",
        follower1.len()
    );

    // Leave time for any (incorrect) replication to follower2 to show up.
    std::thread::sleep(Duration::from_secs(3));
    assert_eq!(
        follower2.len(),
        0,
        "follower2 should be empty but has {} entries",
        follower2.len()
    );

    signal1.shutdown();
    signal2.shutdown();
    leader_signal.shutdown();
}
/// A follower must refuse an `EntryBatch` whose `shard_id` differs from the
/// shard it requested in its `SyncRequest`.
#[test]
fn shard_id_validation() {
    use armdb::replication::protocol::{
        EntryBatch, ShardInfo, SyncRequest, WireEntry, read_frame, write_frame,
    };
    use std::io::BufReader;

    // Hand-rolled "rogue" leader: performs the handshake, then answers with a
    // batch tagged for a shard other than the one requested.
    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let server_addr: SocketAddr = listener.local_addr().unwrap();
    let rogue_handle = std::thread::spawn(move || {
        let (stream, _) = listener.accept().expect("accept");
        let _ = stream.set_nodelay(true);
        let mut reader = BufReader::new(stream.try_clone().unwrap());
        let mut writer = stream;

        let frame = read_frame(&mut reader).expect("read SyncRequest");
        let req = SyncRequest::decode(&frame.payload).expect("decode SyncRequest");
        let info = ShardInfo {
            shard_count: 2,
            max_file_size: 256 * 1024 * 1024,
        };
        write_frame(&mut writer, &info.encode()).expect("write ShardInfo");

        // A shard id that is valid for the advertised 2-shard layout but is
        // NOT the one requested. This exercises the mismatch check itself
        // rather than a plain out-of-range index.
        let wrong_shard_id: u8 = if req.shard_id == 0 { 1 } else { 0 };
        let bad_batch = EntryBatch {
            shard_id: wrong_shard_id,
            entries: vec![WireEntry {
                entry_len: 0,
                key_len: 8,
                gsn: 1,
                data: vec![],
            }],
        };
        // Best effort: the follower may already have closed the connection.
        let _ = write_frame(&mut writer, &bad_batch.encode());
    });

    let follower_dir = tempfile::tempdir().unwrap();
    let follower = Arc::new(VarTree::<[u8; 8]>::open(follower_dir.path(), test_cfg()).unwrap());
    let registry = Arc::new(ReplicationRegistry::new(follower.as_replication_target()));
    let signal = ShutdownSignal::new();
    let _client = follower
        .start_replication_client(server_addr, registry, signal.clone())
        .unwrap();
    rogue_handle.join().expect("rogue server thread panicked");
    // Give the client a moment to process (and hopefully reject) the batch.
    std::thread::sleep(Duration::from_millis(500));
    assert_eq!(
        follower.len(),
        0,
        "follower must not apply entries from a mismatched shard_id"
    );
    signal.shutdown();
}