// armdb 0.1.13 — sharded bitcask key-value storage optimized for NVMe.

use std::io::BufReader;
use std::net::{SocketAddr, TcpStream};
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::thread::{self, JoinHandle};
use std::time::Duration;

use crate::error::{DbError, DbResult};
use crate::shard::Shard;
use crate::shutdown::ShutdownSignal;

use super::cursor::ReplicationCursor;
use super::protocol::*;
use super::{ReplicationEntry, ReplicationRegistry};

/// Applied entries between two acknowledgements sent back to the leader.
const ACK_INTERVAL: usize = 1_000;
/// Applied entries between two persisted replication-cursor checkpoints.
const CURSOR_SAVE_INTERVAL: usize = 1_000;
/// First delay, in milliseconds, before reconnecting after a replication error.
const RECONNECT_BASE_MS: u64 = 1_000;
/// Ceiling, in milliseconds, for the doubling reconnect backoff.
const RECONNECT_MAX_MS: u64 = 30 * RECONNECT_BASE_MS;

/// Follower-side replication client.
///
/// Owns one background thread per shard; each thread connects to the leader
/// and applies the entries streamed for its shard.
pub struct ReplicationClient {
    /// Shutdown signal; a clone of it is handed to every shard thread.
    stop: ShutdownSignal,
    /// Join handles of the per-shard threads, in shard-id order.
    handles: Vec<JoinHandle<()>>,
}

impl ReplicationClient {
    /// Start the replication client.
    ///
    /// Spawns one thread per shard to connect to the leader and receive entries.
    pub fn start(
        leader_addr: SocketAddr,
        shards: Arc<Vec<Shard>>,
        registry: Arc<ReplicationRegistry>,
        key_lens: Vec<u16>,
        signal: ShutdownSignal,
    ) -> DbResult<Self> {
        let mut handles = Vec::with_capacity(shards.len());

        for shard_id in 0..shards.len() {
            let shards = shards.clone();
            let registry = registry.clone();
            let stop = signal.clone();
            let key_lens = key_lens.clone();

            let handle = thread::spawn(move || {
                run_shard_client(leader_addr, &shards, shard_id, &registry, &key_lens, &stop);
            });
            handles.push(handle);
        }

        Ok(Self {
            stop: signal,
            handles,
        })
    }

    pub fn stop(&self) {
        self.stop.shutdown();
    }
}

impl Drop for ReplicationClient {
    fn drop(&mut self) {
        self.stop.shutdown();
        for h in self.handles.drain(..) {
            let _ = h.join();
        }
    }
}

/// Drive replication for one shard until shutdown.
///
/// Each attempt reloads the shard's persisted cursor and calls
/// [`connect_and_stream`]. If no cursor is stored, or the stored one cannot be
/// read, a fresh cursor is used; a read failure is logged.
///
/// - A clean return ends the thread.
/// - An error is logged and retried after a doubling backoff, from
///   [`RECONNECT_BASE_MS`] up to [`RECONNECT_MAX_MS`].
///
/// The backoff resets to the base delay once the persisted cursor has advanced
/// since the previous attempt. Without that reset, a connection that streamed
/// successfully for a long time would, after dropping, still wait whatever
/// delay earlier failures had built up.
fn run_shard_client(
    leader_addr: SocketAddr,
    shards: &[Shard],
    shard_id: usize,
    registry: &ReplicationRegistry,
    key_lens: &[u16],
    stop: &ShutdownSignal,
) {
    let shard = &shards[shard_id];
    let shard_dir = shard.dir().to_path_buf();
    let mut backoff_ms = RECONNECT_BASE_MS;
    // Cursor GSN the previous attempt started from; `None` before the first attempt.
    let mut prev_start_gsn = None;

    loop {
        if stop.is_shutdown() {
            return;
        }

        let cursor = match ReplicationCursor::load(shard_id as u8, &shard_dir) {
            Ok(Some(cursor)) => cursor,
            Ok(None) => ReplicationCursor::new(shard_id as u8),
            Err(e) => {
                // Falling back to a fresh cursor is intentional, but it must be visible.
                tracing::warn!(
                    shard_id,
                    error = %e,
                    "failed to load replication cursor, starting from a fresh cursor"
                );
                ReplicationCursor::new(shard_id as u8)
            }
        };

        // The stored cursor is only advanced after entries were applied. If it
        // moved since the last attempt, that connection made progress, so
        // restart the backoff from the base delay.
        let prev_start = prev_start_gsn.replace(cursor.last_gsn);
        if matches!(prev_start, Some(prev) if cursor.last_gsn > prev) {
            backoff_ms = RECONNECT_BASE_MS;
        }

        match connect_and_stream(
            leader_addr,
            shards,
            shard_id,
            &cursor,
            registry,
            key_lens,
            stop.as_flag(),
        ) {
            Ok(()) => {
                tracing::info!(shard_id, "replication stream ended cleanly");
                return;
            }
            Err(e) => {
                tracing::error!(shard_id, error = %e, backoff_ms, "replication error, reconnecting");
                // `wait_timeout` returning true means shutdown was requested while waiting.
                if stop.wait_timeout(Duration::from_millis(backoff_ms)) {
                    return;
                }
                backoff_ms = (backoff_ms * 2).min(RECONNECT_MAX_MS);
            }
        }
    }
}

/// Connect to the leader, perform the sync handshake for one shard, then apply
/// streamed entries.
///
/// **Handshake.** Sends a `SyncRequest` for `shard_id` with `from_gsn` set to
/// `cursor.last_gsn`. It then expects a `ShardInfo` reply whose shard count
/// equals `shards.len()`.
///
/// **Streaming.**
/// - Every `EntryBatch` entry is applied with `registry.apply_streaming`.
/// - An `AckMessage` with the highest applied GSN is sent once at least
///   [`ACK_INTERVAL`] entries have been applied since the previous ack.
/// - The cursor is persisted once at least [`CURSOR_SAVE_INTERVAL`] entries
///   have been applied since the previous save.
/// - A `Heartbeat` is answered with a heartbeat.
/// - `CaughtUp` is only logged.
///
/// **Shutdown.** Returns `Ok(())` only after observing `stop`, persisting any
/// unsaved progress first. That save is best-effort: a failure is logged, not
/// returned. `stop` is checked only between frames, and no read timeout is set
/// on the socket here, so shutdown waits for the leader's next frame.
///
/// # Errors
///
/// Connection or socket I/O failures, a first frame that is not `ShardInfo`, a
/// shard-count mismatch, frame decode failures, an `Error` frame from the
/// leader, an unexpected message type, or a failure to apply an entry or to
/// persist a periodic cursor checkpoint.
fn connect_and_stream(
    leader_addr: SocketAddr,
    shards: &[Shard],
    shard_id: usize,
    cursor: &ReplicationCursor,
    registry: &ReplicationRegistry,
    key_lens: &[u16],
    stop: &std::sync::atomic::AtomicBool,
) -> DbResult<()> {
    let stream = TcpStream::connect_timeout(&leader_addr, Duration::from_secs(5))?;
    // Latency tweak only; a failure here is ignored.
    let _ = stream.set_nodelay(true);

    // Same socket, two handles: raw writes through the clone, buffered reads.
    let mut writer = stream.try_clone()?;
    let mut reader = BufReader::new(stream);

    // Handshake: ask the leader to stream this shard from our cursor position.
    let req = SyncRequest {
        shard_id: shard_id as u8,
        from_gsn: cursor.last_gsn,
        key_lens: key_lens.to_vec(),
    };
    write_frame(&mut writer, &req.encode())?;

    let info_frame = read_frame(&mut reader)?;
    if info_frame.msg_type != MessageType::ShardInfo {
        return Err(DbError::Replication(format!(
            "expected ShardInfo, got {:?}",
            info_frame.msg_type
        )));
    }
    let info = ShardInfo::decode(&info_frame.payload)?;
    if info.shard_count as usize != shards.len() {
        return Err(DbError::ShardCountMismatch {
            leader: info.shard_count as usize,
            follower: shards.len(),
        });
    }

    tracing::info!(shard_id, from_gsn = cursor.last_gsn, "connected to leader");

    let shard = &shards[shard_id];
    let shard_dir = shard.dir().to_path_buf();

    // Persist `gsn` as this shard's replication cursor. Shared by the periodic
    // checkpoint and the shutdown path.
    let save_cursor = |gsn| -> DbResult<()> {
        let mut checkpoint = ReplicationCursor::new(shard_id as u8);
        checkpoint.advance(gsn, 0, 0);
        checkpoint.save(&shard_dir)?;
        Ok(())
    };

    let mut last_gsn = cursor.last_gsn;
    let mut entries_since_ack = 0usize;
    let mut entries_since_save = 0usize;

    loop {
        if stop.load(Ordering::Relaxed) {
            if entries_since_save > 0 {
                // Best-effort on shutdown, but do not hide the failure.
                if let Err(e) = save_cursor(last_gsn) {
                    tracing::warn!(
                        shard_id,
                        error = %e,
                        "failed to persist replication cursor on shutdown"
                    );
                }
            }
            return Ok(());
        }

        let frame = read_frame(&mut reader)?;

        match frame.msg_type {
            MessageType::EntryBatch => {
                let batch = EntryBatch::decode(&frame.payload)?;

                // Consume the batch so each payload is moved into the
                // replication entry rather than cloned.
                for wire_entry in batch.entries {
                    let gsn = wire_entry.gsn;
                    let repl_entry = ReplicationEntry {
                        data: wire_entry.data,
                        key_len: wire_entry.key_len,
                    };
                    registry.apply_streaming(shard, &repl_entry)?;

                    // Track the highest GSN applied so far (only after a successful apply).
                    if gsn > last_gsn {
                        last_gsn = gsn;
                    }
                    entries_since_ack += 1;
                    entries_since_save += 1;
                }

                if entries_since_ack >= ACK_INTERVAL {
                    let ack = AckMessage {
                        shard_id: shard_id as u8,
                        last_gsn,
                    };
                    write_frame(&mut writer, &ack.encode())?;
                    entries_since_ack = 0;
                }

                if entries_since_save >= CURSOR_SAVE_INTERVAL {
                    save_cursor(last_gsn)?;
                    entries_since_save = 0;
                }
            }
            MessageType::CaughtUp => {
                let caught_up = CaughtUp::decode(&frame.payload)?;
                tracing::info!(shard_id, leader_gsn = caught_up.leader_gsn, "caught up");
                // Keep reading: live entries continue on the same stream.
            }
            MessageType::Heartbeat => {
                write_frame(&mut writer, &encode_heartbeat())?;
            }
            MessageType::Error => {
                let msg = decode_error(&frame.payload);
                return Err(DbError::Replication(msg));
            }
            other => {
                return Err(DbError::Replication(format!(
                    "unexpected message: {other:?}"
                )));
            }
        }
    }
}