//! armdb 0.1.12
//!
//! Sharded bitcask key-value storage optimized for NVMe.
use std::io::BufWriter;
use std::net::{SocketAddr, TcpListener, TcpStream};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::thread::{self, JoinHandle};
use std::time::Duration;

use crate::error::DbResult;
use crate::shard::Shard;
use crate::shutdown::ShutdownSignal;

use super::ReplicationEntry;
use super::log_reader::ShardLogReader;
use super::protocol::*;

/// Maximum number of entries packed into a single `EntryBatch` frame.
const BATCH_MAX_ENTRIES: usize = 256;
/// Soft cap on total payload bytes per batch (checked before each push, so a
/// batch may exceed this by at most one entry).
const BATCH_MAX_BYTES: usize = 64 * 1024;
/// Sleep between polls of the SPSC ring when no entries are pending (streaming mode).
const TAIL_POLL_MS: u64 = 1;

/// Join handles of the per-(follower, shard) streaming threads, shared between
/// the acceptor thread (which spawns them) and `Drop` (which joins them).
type HandlerHandles = Arc<crate::sync::Mutex<Vec<JoinHandle<()>>>>;

/// Replication server running on the leader node.
/// Accepts follower connections and streams entries per-shard.
pub struct ReplicationServer {
    /// Shared shutdown signal; `stop()` / `Drop` trip it to unwind the
    /// acceptor and all streaming threads.
    stop: ShutdownSignal,
    /// Acceptor-loop thread; `Some` until joined in `Drop`.
    acceptor_handle: Option<JoinHandle<()>>,
    /// Handles of the per-(follower, shard) streaming threads, joined in `Drop`.
    handler_handles: HandlerHandles,
    /// Per-shard minimum replicated GSN across all followers.
    /// NOTE(review): `serve_shard` advances this via `fetch_max` on every
    /// follower ack, so with multiple followers it holds the *highest* acked
    /// GSN — confirm the "minimum" wording against the compaction-guard contract.
    pub min_replicated_gsn: Arc<Vec<AtomicU64>>,
}

impl ReplicationServer {
    /// Start the replication server.
    ///
    /// Binds `bind_addr`, then spawns an acceptor thread that performs the
    /// handshake for each follower connection (`handle_connection`), which in
    /// turn spawns one streaming thread per (follower, shard) pair.
    ///
    /// `consumers`: one SPSC consumer per shard (taken from the ring buffers
    /// installed via `Shard::set_replication_producer`).
    ///
    /// `key_lens`: all known key lengths across all trees.
    ///
    /// # Errors
    /// Returns an error if the listener cannot be bound or switched to
    /// non-blocking mode.
    pub fn start(
        bind_addr: SocketAddr,
        shards: Arc<Vec<Shard>>,
        consumers: Vec<rtrb::Consumer<ReplicationEntry>>,
        key_lens: Vec<usize>,
        max_file_size: u64,
        signal: ShutdownSignal,
    ) -> DbResult<Self> {
        let shard_count = shards.len();

        let min_replicated_gsn: Arc<Vec<AtomicU64>> =
            Arc::new((0..shard_count).map(|_| AtomicU64::new(0)).collect());

        // Wrap each consumer in Mutex<Option<...>> so the acceptor thread can
        // hand it out exactly once to the first follower requesting that shard.
        let consumers: Arc<Vec<crate::sync::Mutex<Option<rtrb::Consumer<ReplicationEntry>>>>> =
            Arc::new(
                consumers
                    .into_iter()
                    .map(|c| crate::sync::Mutex::new(Some(c)))
                    .collect(),
            );

        let listener = TcpListener::bind(bind_addr)?;
        // Non-blocking accept lets the loop periodically observe the shutdown signal.
        listener.set_nonblocking(true)?;

        let handler_handles: HandlerHandles = Arc::new(crate::sync::Mutex::new(Vec::new()));

        // Clone only what `Self` also retains; `shards`, `consumers` and
        // `key_lens` are not used after the spawn, so they are moved into the
        // acceptor thread directly instead of being redundantly cloned.
        let stop2 = signal.clone();
        let min_gsn2 = min_replicated_gsn.clone();
        let hh2 = handler_handles.clone();

        let acceptor = thread::spawn(move || {
            tracing::info!(%bind_addr, "replication server started");
            while !stop2.is_shutdown() {
                match listener.accept() {
                    Ok((stream, addr)) => {
                        tracing::info!(%addr, "follower connected");
                        // Replication frames are small and latency-sensitive;
                        // disable Nagle (best effort).
                        let _ = stream.set_nodelay(true);
                        handle_connection(
                            stream,
                            &shards,
                            &consumers,
                            &key_lens,
                            max_file_size,
                            &stop2,
                            &min_gsn2,
                            &hh2,
                        );
                    }
                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                        // No pending connection: park briefly, then re-check shutdown.
                        stop2.wait_timeout(Duration::from_millis(50));
                    }
                    Err(e) => {
                        tracing::error!(error = %e, "accept error");
                        stop2.wait_timeout(Duration::from_millis(100));
                    }
                }
            }
            tracing::info!("replication server stopped");
        });

        Ok(Self {
            stop: signal,
            acceptor_handle: Some(acceptor),
            handler_handles,
            min_replicated_gsn,
        })
    }

    /// Request shutdown of the acceptor and all streaming threads.
    /// Threads are joined in `Drop`, not here.
    pub fn stop(&self) {
        self.stop.shutdown();
    }
}

impl crate::compaction::CompactionGuard for ReplicationServer {
    /// Replicated-GSN watermark for `shard_id`, as tracked from follower acks.
    /// An out-of-range shard id yields `u64::MAX` (nothing held back).
    fn min_replicated_gsn(&self, shard_id: u8) -> u64 {
        match self.min_replicated_gsn.get(usize::from(shard_id)) {
            Some(gsn) => gsn.load(Ordering::Relaxed),
            None => u64::MAX,
        }
    }
}

impl Drop for ReplicationServer {
    /// Signals shutdown, then joins the acceptor and every streaming thread
    /// so no worker outlives the server.
    fn drop(&mut self) {
        self.stop.shutdown();
        if let Some(acceptor) = self.acceptor_handle.take() {
            let _ = acceptor.join();
        }
        for worker in crate::sync::lock(&self.handler_handles).drain(..) {
            let _ = worker.join();
        }
    }
}

#[allow(clippy::too_many_arguments)]
fn handle_connection(
    stream: TcpStream,
    shards: &Arc<Vec<Shard>>,
    consumers: &Arc<Vec<crate::sync::Mutex<Option<rtrb::Consumer<ReplicationEntry>>>>>,
    key_lens: &[usize],
    max_file_size: u64,
    stop: &ShutdownSignal,
    min_gsn: &Arc<Vec<AtomicU64>>,
    handler_handles: &HandlerHandles,
) {
    // Read initial SyncRequest to determine which shard
    let mut reader = stream.try_clone().expect("clone tcp stream");
    let frame = match read_frame(&mut reader) {
        Ok(f) => f,
        Err(e) => {
            tracing::error!(error = %e, "failed to read SyncRequest");
            return;
        }
    };

    if frame.msg_type != MessageType::SyncRequest {
        tracing::error!("expected SyncRequest, got {:?}", frame.msg_type);
        return;
    }

    let req = match SyncRequest::decode(&frame.payload) {
        Ok(r) => r,
        Err(e) => {
            tracing::error!(error = %e, "invalid SyncRequest");
            return;
        }
    };

    let shard_id = req.shard_id as usize;
    if shard_id >= shards.len() {
        tracing::error!(shard_id, "invalid shard_id");
        return;
    }

    // Send ShardInfo
    let info = ShardInfo {
        shard_count: shards.len() as u8,
        max_file_size,
    };
    let mut writer = BufWriter::new(stream);
    if write_frame(&mut writer, &info.encode()).is_err() {
        return;
    }

    // Try to take SPSC consumer for this shard (if available)
    let consumer = {
        let mut guard = crate::sync::lock(&consumers[shard_id]);
        guard.take()
    };

    let shards = shards.clone();
    let key_lens = key_lens.to_vec();
    let stop = stop.clone();
    let min_gsn = min_gsn.clone();

    // Spawn a dedicated thread for this (follower, shard) stream
    let handle = thread::spawn(move || {
        if let Err(e) = serve_shard(
            &mut reader,
            &mut writer,
            &shards,
            shard_id,
            req.from_gsn,
            consumer,
            &key_lens,
            stop.as_flag(),
            &min_gsn,
        ) {
            tracing::error!(shard_id, error = %e, "shard stream error");
        }
    });
    crate::sync::lock(handler_handles).push(handle);
}

/// Stream shard `shard_id` to a single connected follower.
///
/// Phase 1 (catch-up): if the follower's `from_gsn` is behind the shard's
/// current GSN, replay entries from the on-disk log via `ShardLogReader`,
/// one acked batch at a time, then send `CaughtUp`.
///
/// Phase 2 (streaming): if this connection owns the shard's SPSC `consumer`,
/// tail the ring buffer until `stop` is set, sending batches and reading
/// follower acks opportunistically in non-blocking mode. Without a consumer
/// (another follower already holds it), the function returns after catch-up.
///
/// Acks advance `min_gsn[shard_id]` via `fetch_max`.
/// NOTE(review): with multiple followers, `fetch_max` tracks the *highest*
/// acked GSN, not the minimum — confirm against the compaction-guard contract.
#[allow(clippy::too_many_arguments)]
fn serve_shard(
    reader: &mut TcpStream,
    writer: &mut BufWriter<TcpStream>,
    shards: &[Shard],
    shard_id: usize,
    from_gsn: u64,
    consumer: Option<rtrb::Consumer<ReplicationEntry>>,
    key_lens: &[usize],
    stop: &AtomicBool,
    min_gsn: &Arc<Vec<AtomicU64>>,
) -> DbResult<()> {
    let shard = &shards[shard_id];
    let shard_dir = shard.dir().to_path_buf();
    // Snapshot of the shard's GSN; entries written after this point are
    // expected to arrive through the SPSC ring in phase 2.
    let current_gsn = shard.gsn().load(Ordering::Relaxed);

    // Phase 1: Catch-up via ShardLogReader
    if from_gsn < current_gsn {
        tracing::info!(shard_id, from_gsn, current_gsn, "starting catch-up");

        // Flush write buffer so ShardLogReader can see all entries
        shard.flush_buf()?;

        let mut log_reader = ShardLogReader::new(shard_dir, from_gsn, key_lens.to_vec())?;
        let mut last_gsn = from_gsn;

        loop {
            if stop.load(Ordering::Relaxed) {
                return Ok(());
            }

            // Accumulate up to BATCH_MAX_ENTRIES / BATCH_MAX_BYTES per frame.
            let mut batch = Vec::new();
            let mut batch_bytes = 0;

            while batch.len() < BATCH_MAX_ENTRIES && batch_bytes < BATCH_MAX_BYTES {
                match log_reader.next_entry()? {
                    Some(entry) => {
                        last_gsn = entry.gsn;
                        batch_bytes += entry.data.len();
                        batch.push(WireEntry {
                            entry_len: entry.data.len() as u32,
                            key_len: entry.key_len,
                            gsn: entry.gsn,
                            data: entry.data,
                        });
                    }
                    None => break,
                }
            }

            if batch.is_empty() {
                break; // Caught up
            }

            let msg = EntryBatch {
                shard_id: shard_id as u8,
                entries: batch,
            };
            write_frame(writer, &msg.encode())?;

            // Read Ack (blocking): catch-up is lock-step, one ack per batch.
            let ack_frame = read_frame(reader)?;
            if ack_frame.msg_type == MessageType::Ack {
                let ack = AckMessage::decode(&ack_frame.payload)?;
                min_gsn[shard_id].fetch_max(ack.last_gsn, Ordering::Relaxed);
            }
        }

        // Send CaughtUp so the follower knows replay is complete.
        let caught_up = CaughtUp {
            shard_id: shard_id as u8,
            leader_gsn: last_gsn,
        };
        write_frame(writer, &caught_up.encode())?;

        tracing::info!(shard_id, last_gsn, "catch-up complete");
    }

    // Phase 2: Streaming via SPSC (if consumer available)
    if let Some(mut consumer) = consumer {
        tracing::info!(shard_id, "entering streaming mode");

        loop {
            if stop.load(Ordering::Relaxed) {
                return Ok(());
            }

            let mut batch = Vec::new();
            let mut batch_bytes = 0;

            while batch.len() < BATCH_MAX_ENTRIES && batch_bytes < BATCH_MAX_BYTES {
                match consumer.pop() {
                    Ok(entry) => {
                        batch_bytes += entry.data.len();
                        // Ring entries carry raw bytes; GSN is decoded from the
                        // entry prefix rather than stored alongside.
                        let gsn = extract_gsn(&entry.data);
                        batch.push(WireEntry {
                            entry_len: entry.data.len() as u32,
                            key_len: entry.key_len,
                            gsn,
                            data: entry.data,
                        });
                    }
                    Err(_) => break, // Empty
                }
            }

            if batch.is_empty() {
                // Nothing pending; poll the ring again shortly.
                thread::sleep(Duration::from_millis(TAIL_POLL_MS));
                continue;
            }

            let msg = EntryBatch {
                shard_id: shard_id as u8,
                entries: batch,
            };
            write_frame(writer, &msg.encode())?;

            // Non-blocking ack check: drain at most one ack per batch without
            // stalling the stream when the follower hasn't responded yet.
            // NOTE(review): if read_frame can consume a partial frame before
            // hitting WouldBlock, the stream would desync — confirm read_frame
            // is all-or-nothing under non-blocking mode.
            let _ = reader.set_nonblocking(true);
            if let Ok(ack_frame) = read_frame(reader)
                && ack_frame.msg_type == MessageType::Ack
                && let Ok(ack) = AckMessage::decode(&ack_frame.payload)
            {
                min_gsn[shard_id].fetch_max(ack.last_gsn, Ordering::Relaxed);
            }
            let _ = reader.set_nonblocking(false);
        }
    }

    Ok(())
}

/// Extract GSN (sequence only, no tombstone bit) from raw entry bytes.
/// Returns 0 when the slice is too short to hold the 8-byte GSN prefix.
fn extract_gsn(data: &[u8]) -> u64 {
    let Some(prefix) = data.get(..8) else {
        return 0;
    };
    let raw = u64::from_ne_bytes(prefix.try_into().expect("slice is exactly 8 bytes"));
    raw & !crate::entry::TOMBSTONE_BIT
}