// armdb 0.2.0: sharded bitcask key-value storage optimized for NVMe.
use std::io::BufReader;
use std::net::{SocketAddr, TcpStream};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::thread::{self, JoinHandle};
use std::time::Duration;

use crate::error::{DbError, DbResult};
use crate::shard::Shard;
use crate::shutdown::ShutdownSignal;

use super::cursor::ReplicationCursor;
use super::protocol::*;
use super::{ReplicationEntry, ReplicationRegistry};

/// Catch-up phase: acknowledge after every batch so the leader's pipelined
/// catch-up sender is never left waiting on us (C1).
const ACK_INTERVAL_CATCHUP: usize = 1;
/// Streaming phase: acknowledge once per this many applied entries.
const ACK_INTERVAL: usize = 1_000;
/// Initial reconnect delay after a replication error, in milliseconds.
const RECONNECT_BASE_MS: u64 = 1_000;
/// Upper bound on the (doubling) reconnect delay, in milliseconds.
const RECONNECT_MAX_MS: u64 = 30 * 1_000;
/// Persist the replication cursor after this many applied entries.
const CURSOR_SAVE_INTERVAL: usize = 1_000;
/// Name of the cursor file, joined onto each shard's directory.
const CURSOR_FILENAME: &str = "repl.cursor";

/// Replication client running on a follower node.
/// Connects to the leader and receives entries per-shard.
pub struct ReplicationClient {
    stop: ShutdownSignal,
    handles: Vec<JoinHandle<()>>,
}

/// Options for tuning the replication client (e.g. in tests).
///
/// Both values are in milliseconds. After a failed session a shard thread
/// waits `reconnect_base_ms`, doubling the wait after each further failure,
/// and never waits longer than `reconnect_max_ms`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ReplicationClientOptions {
    /// Initial delay before reconnecting after a replication error.
    pub reconnect_base_ms: u64,
    /// Upper bound on the reconnect delay.
    pub reconnect_max_ms: u64,
}

impl Default for ReplicationClientOptions {
    /// Production defaults: `RECONNECT_BASE_MS` and `RECONNECT_MAX_MS`.
    fn default() -> Self {
        Self {
            reconnect_base_ms: RECONNECT_BASE_MS,
            reconnect_max_ms: RECONNECT_MAX_MS,
        }
    }
}

impl ReplicationClient {
    /// Start the replication client.
    ///
    /// Spawns one thread per shard to connect to the leader and receive entries.
    pub fn start(
        leader_addr: SocketAddr,
        shards: Arc<Vec<Shard>>,
        registry: Arc<ReplicationRegistry>,
        key_len: u16,
        signal: ShutdownSignal,
    ) -> DbResult<Self> {
        Self::start_with_options(
            leader_addr,
            shards,
            registry,
            key_len,
            signal,
            ReplicationClientOptions::default(),
        )
    }

    pub fn start_with_options(
        leader_addr: SocketAddr,
        shards: Arc<Vec<Shard>>,
        registry: Arc<ReplicationRegistry>,
        key_len: u16,
        signal: ShutdownSignal,
        options: ReplicationClientOptions,
    ) -> DbResult<Self> {
        let mut handles = Vec::with_capacity(shards.len());
        let reconnect_base_ms = options.reconnect_base_ms;
        let reconnect_max_ms = options.reconnect_max_ms;

        for shard_id in 0..shards.len() {
            let shards = shards.clone();
            let registry = registry.clone();
            let stop = signal.clone();

            let handle = thread::spawn(move || {
                run_shard_client(
                    leader_addr,
                    &shards,
                    shard_id,
                    &registry,
                    key_len,
                    &stop,
                    reconnect_base_ms,
                    reconnect_max_ms,
                );
            });
            handles.push(handle);
        }

        Ok(Self {
            stop: signal,
            handles,
        })
    }

    pub fn stop(&self) {
        self.stop.shutdown();
    }
}

impl Drop for ReplicationClient {
    /// Raise the shutdown signal, then block until every shard thread exits.
    /// Join errors (a panicked thread) are deliberately discarded.
    fn drop(&mut self) {
        self.stop.shutdown();
        self.handles.drain(..).for_each(|worker| {
            let _ = worker.join();
        });
    }
}

/// Per-shard worker: connect to the leader, stream entries for `shard_id`,
/// and reconnect with exponential backoff until `stop` is signalled.
///
/// Before every attempt the shard's cursor is reloaded from
/// `<shard dir>/CURSOR_FILENAME`, so each session resumes after the last
/// persisted GSN. An `Ok` session outcome ends the worker (`connect_and_stream`
/// only returns `Ok` after observing the stop flag). An error is logged and
/// followed by a wait that starts at `reconnect_base_ms`, doubles per failure,
/// and never exceeds `reconnect_max_ms`. The wait is cut short if `stop` fires.
#[allow(clippy::too_many_arguments)]
fn run_shard_client(
    leader_addr: SocketAddr,
    shards: &[Shard],
    shard_id: usize,
    registry: &ReplicationRegistry,
    key_len: u16,
    stop: &ShutdownSignal,
    reconnect_base_ms: u64,
    reconnect_max_ms: u64,
) {
    // The cursor location never changes for this worker; compute it once.
    let cursor_path = shards[shard_id].dir().join(CURSOR_FILENAME);
    // Clamp so a base larger than the cap still honours the cap.
    let mut backoff_ms = reconnect_base_ms.min(reconnect_max_ms);

    while !stop.is_shutdown() {
        let cursor = match ReplicationCursor::load(&cursor_path) {
            Ok(Some(saved)) => saved,
            Ok(None) => ReplicationCursor::default(),
            // Keep the best-effort fallback to a default cursor, but make it
            // visible: it can mean re-requesting a large amount of history.
            Err(e) => {
                tracing::warn!(
                    shard_id,
                    error = %e,
                    path = %cursor_path.display(),
                    "failed to load replication cursor, falling back to default"
                );
                ReplicationCursor::default()
            }
        };

        match connect_and_stream(
            leader_addr,
            shards,
            shard_id,
            &cursor,
            &cursor_path,
            registry,
            key_len,
            stop.as_flag(),
        ) {
            Ok(()) => {
                tracing::info!(shard_id, "replication stream ended cleanly");
                return;
            }
            Err(e) => {
                tracing::error!(shard_id, error = %e, backoff_ms, "replication error, reconnecting");
                if stop.wait_timeout(Duration::from_millis(backoff_ms)) {
                    return;
                }
                // Saturate so a very large cap cannot overflow the doubling.
                backoff_ms = backoff_ms.saturating_mul(2).min(reconnect_max_ms);
            }
        }
    }
}

/// Encode an `AckMessage` reporting `last_gsn` for `shard_id` and write it
/// to the leader as a single frame.
fn send_ack(writer: &mut TcpStream, shard_id: u8, last_gsn: u64) -> DbResult<()> {
    let frame = AckMessage { shard_id, last_gsn }.encode();
    write_frame(writer, &frame)?;
    Ok(())
}

#[allow(clippy::too_many_arguments)]
fn connect_and_stream(
    leader_addr: SocketAddr,
    shards: &[Shard],
    shard_id: usize,
    cursor: &ReplicationCursor,
    cursor_path: &std::path::Path,
    registry: &ReplicationRegistry,
    key_len: u16,
    stop: &std::sync::atomic::AtomicBool,
) -> DbResult<()> {
    let stream = TcpStream::connect_timeout(&leader_addr, Duration::from_secs(5))?;
    let _ = stream.set_nodelay(true);

    let mut writer = stream.try_clone()?;
    let mut reader = BufReader::new(stream);

    // Send SyncRequest — ask for the entry *after* last_gsn so the leader
    // does not resend the already-applied entry (C12).
    let req = SyncRequest {
        shard_id: shard_id as u8,
        from_gsn: cursor.last_gsn.saturating_add(1),
        key_len,
    };
    write_frame(&mut writer, &req.encode())?;

    // Receive ShardInfo
    let info_frame = read_frame(&mut reader)?;
    if info_frame.msg_type != MessageType::ShardInfo {
        return Err(DbError::Replication(format!(
            "expected ShardInfo, got {:?}",
            info_frame.msg_type
        )));
    }
    let info = ShardInfo::decode(&info_frame.payload)?;
    if info.shard_count as usize != shards.len() {
        return Err(DbError::ShardCountMismatch {
            leader: info.shard_count as usize,
            follower: shards.len(),
        });
    }

    tracing::info!(shard_id, from_gsn = cursor.last_gsn, "connected to leader");

    let shard = &shards[shard_id];
    // Atomic tracks the highest GSN successfully applied. Initialized from the
    // saved cursor and advanced by apply_streaming / apply_catchup (C12).
    let last_applied = AtomicU64::new(cursor.last_gsn);
    let mut entries_since_ack = 0usize;
    let mut entries_since_save = 0usize;
    // Track catch-up phase: true until the server sends CaughtUp (C1).
    let mut in_catchup = true;

    loop {
        if stop.load(Ordering::Relaxed) {
            if entries_since_save > 0 {
                let gsn = last_applied.load(Ordering::Relaxed);
                let _ = ReplicationCursor { last_gsn: gsn }.save(cursor_path);
            }
            return Ok(());
        }

        let frame = read_frame(&mut reader)?;

        match frame.msg_type {
            MessageType::EntryBatch => {
                let batch = EntryBatch::decode(&frame.payload)?;
                // C15: reject batches whose shard_id doesn't match this stream.
                if batch.shard_id != shard_id as u8 {
                    return Err(DbError::Replication("EntryBatch shard_id mismatch".into()));
                }

                for wire_entry in &batch.entries {
                    let repl_entry = ReplicationEntry {
                        data: wire_entry.data.clone(),
                        key_len: wire_entry.key_len,
                    };
                    registry.apply_streaming(shard, &repl_entry, &last_applied)?;

                    entries_since_ack += 1;
                    entries_since_save += 1;
                }

                // Send Ack at phase-appropriate interval.
                // During catch-up: ack every batch (ACK_INTERVAL_CATCHUP=1) so the
                // server's catch-up loop is never stalled waiting for an ack (C1).
                // During streaming: ack every ACK_INTERVAL entries (batched).
                let ack_threshold = if in_catchup {
                    ACK_INTERVAL_CATCHUP
                } else {
                    ACK_INTERVAL
                };
                if entries_since_ack >= ack_threshold {
                    send_ack(
                        &mut writer,
                        shard_id as u8,
                        last_applied.load(Ordering::Relaxed),
                    )?;
                    entries_since_ack = 0;
                }

                // Save cursor periodically
                if entries_since_save >= CURSOR_SAVE_INTERVAL {
                    let last_gsn = last_applied.load(Ordering::Relaxed);
                    ReplicationCursor { last_gsn }.save(cursor_path)?;
                    entries_since_save = 0;
                }
            }
            MessageType::CaughtUp => {
                let caught_up = CaughtUp::decode(&frame.payload)?;
                tracing::info!(shard_id, leader_gsn = caught_up.leader_gsn, "caught up");
                in_catchup = false;
                // Continue — streaming will follow
            }
            MessageType::Heartbeat => {
                // Respond with heartbeat
                write_frame(&mut writer, &encode_heartbeat())?;
            }
            MessageType::Error => {
                let msg = decode_error(&frame.payload);
                return Err(DbError::Replication(msg));
            }
            other => {
                return Err(DbError::Replication(format!(
                    "unexpected message: {other:?}"
                )));
            }
        }
    }
}