armdb 0.2.0

sharded bitcask key-value storage optimized for NVMe
Documentation
use std::io::{self, Read, Write};

/// Upper bound, in bytes, on the payload of a single frame (16 MiB).
///
/// `read_frame` compares the decoded length prefix against this value before
/// it allocates the payload buffer, so a corrupt or hostile header cannot
/// trigger a multi-GiB allocation.
pub const MAX_FRAME_SIZE: usize = 16 << 20;

/// Discriminator byte that opens every replication-protocol frame.
///
/// Each variant is cast to `u8` when a frame is written, so the numeric
/// values are part of the wire format and must not be renumbered.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageType {
    /// Frame payload is an encoded `SyncRequest`.
    SyncRequest = 1,
    /// Frame payload is an encoded `ShardInfo`.
    ShardInfo = 2,
    /// Frame payload is an encoded `EntryBatch`.
    EntryBatch = 3,
    /// Frame payload is an encoded `CaughtUp`.
    CaughtUp = 4,
    /// Frame payload is an encoded `AckMessage`.
    Ack = 5,
    /// Frame with an empty payload (see `encode_heartbeat`).
    Heartbeat = 6,
    /// Frame payload is error text (see `encode_error` / `decode_error`).
    Error = 255,
}

impl MessageType {
    /// Maps a raw wire byte back to its variant.
    ///
    /// Returns `None` for any byte that is not assigned to a variant.
    fn from_u8(v: u8) -> Option<Self> {
        // Every known variant; the lookup compares against the enum's own
        // discriminants so the numbers live in exactly one place.
        const ALL: [MessageType; 7] = [
            MessageType::SyncRequest,
            MessageType::ShardInfo,
            MessageType::EntryBatch,
            MessageType::CaughtUp,
            MessageType::Ack,
            MessageType::Heartbeat,
            MessageType::Error,
        ];
        ALL.iter().copied().find(|kind| *kind as u8 == v)
    }
}

/// A single protocol frame as it travels on the wire.
///
/// Layout: `[type: u8][len: u32 LE][payload: len bytes]`.
#[derive(Debug)]
pub struct Frame {
    /// Message kind; written as the first byte of the frame.
    pub msg_type: MessageType,
    /// Raw message body; its length is sent as the little-endian `u32` prefix.
    pub payload: Vec<u8>,
}

pub fn write_frame(w: &mut impl Write, frame: &Frame) -> io::Result<()> {
    w.write_all(&[frame.msg_type as u8])?;
    w.write_all(&(frame.payload.len() as u32).to_le_bytes())?;
    w.write_all(&frame.payload)?;
    w.flush()
}

/// Reads one frame from `r`: a type byte, a little-endian `u32` length and
/// then exactly that many payload bytes.
///
/// # Errors
///
/// * any error returned by `read_exact` on the underlying reader;
/// * `io::ErrorKind::InvalidData` when the type byte is not a known
///   `MessageType`;
/// * `io::ErrorKind::InvalidData` when the announced length is greater than
///   `MAX_FRAME_SIZE` (checked before the payload buffer is allocated).
pub fn read_frame(r: &mut impl Read) -> io::Result<Frame> {
    let mut tag = [0u8; 1];
    r.read_exact(&mut tag)?;
    let msg_type = match MessageType::from_u8(tag[0]) {
        Some(kind) => kind,
        None => {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("unknown message type: {}", tag[0]),
            ));
        }
    };

    let mut len_bytes = [0u8; 4];
    r.read_exact(&mut len_bytes)?;
    let len = u32::from_le_bytes(len_bytes) as usize;
    // Refuse oversized frames before reserving memory for them.
    if len > MAX_FRAME_SIZE {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "frame too large",
        ));
    }

    let mut payload = vec![0u8; len];
    if !payload.is_empty() {
        r.read_exact(&mut payload)?;
    }
    Ok(Frame { msg_type, payload })
}

// --- SyncRequest: shard_id:u8, from_gsn:u64, key_len:u16 ---
// Payload layout: shard_id (1) + from_gsn (8 LE) + key_len (2 LE) = 11 bytes.
// One Engine = one Tree, so a single key_len suffices (C7/C20 cleanup).

/// Body of a `SyncRequest` message.
///
/// Encoded as 11 bytes: `shard_id` (1), `from_gsn` (8, little-endian) and
/// `key_len` (2, little-endian).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SyncRequest {
    /// Shard the request refers to.
    pub shard_id: u8,
    /// GSN carried by the request — presumably the position to start syncing
    /// from; confirm against the caller.
    pub from_gsn: u64,
    /// Key length carried by the request (a single value per request).
    pub key_len: u16,
}

impl SyncRequest {
    pub fn encode(&self) -> Frame {
        let mut payload = Vec::with_capacity(11);
        payload.push(self.shard_id);
        payload.extend_from_slice(&self.from_gsn.to_le_bytes());
        payload.extend_from_slice(&self.key_len.to_le_bytes());
        Frame {
            msg_type: MessageType::SyncRequest,
            payload,
        }
    }

    pub fn decode(payload: &[u8]) -> io::Result<Self> {
        if payload.len() < 11 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "SyncRequest too short",
            ));
        }
        let shard_id = payload[0];
        let from_gsn = u64::from_le_bytes(payload[1..9].try_into().expect("impossible"));
        let key_len = u16::from_le_bytes(payload[9..11].try_into().expect("impossible"));
        Ok(Self {
            shard_id,
            from_gsn,
            key_len,
        })
    }
}

// --- ShardInfo: shard_count:u8, max_file_size:u64 ---

/// Body of a `ShardInfo` message.
///
/// Encoded as 9 bytes: `shard_count` (1) followed by `max_file_size`
/// (8, little-endian).
///
/// Derives the same traits as `SyncRequest` so the message can be logged,
/// cloned and compared in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShardInfo {
    /// Number of shards reported by the sender.
    pub shard_count: u8,
    /// Maximum file size reported by the sender — presumably in bytes;
    /// confirm against the producer.
    pub max_file_size: u64,
}

impl ShardInfo {
    pub fn encode(&self) -> Frame {
        let mut payload = Vec::with_capacity(9);
        payload.push(self.shard_count);
        payload.extend_from_slice(&self.max_file_size.to_le_bytes());
        Frame {
            msg_type: MessageType::ShardInfo,
            payload,
        }
    }

    pub fn decode(payload: &[u8]) -> io::Result<Self> {
        if payload.len() < 9 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "ShardInfo too short",
            ));
        }
        Ok(Self {
            shard_count: payload[0],
            max_file_size: u64::from_le_bytes(payload[1..9].try_into().expect("impossible")),
        })
    }
}

// --- EntryBatch: shard_id:u8, count:u32, entries:[entry_len:u32 + key_len:u16 + gsn:u64 + data] ---

/// One entry inside an `EntryBatch` payload.
///
/// On the wire each entry is a 14-byte header (`entry_len: u32`,
/// `key_len: u16`, `gsn: u64`, all little-endian) followed by the `data`
/// bytes.
///
/// Derives the same traits as `SyncRequest` so entries can be logged, cloned
/// and compared in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WireEntry {
    /// Number of `data` bytes that follow the entry header on the wire;
    /// `EntryBatch::decode` reads exactly this many bytes into `data`.
    pub entry_len: u32,
    /// Key length carried with the entry — presumably the length of the key
    /// inside `data`; confirm against the producer.
    pub key_len: u16,
    /// GSN carried with the entry.
    pub gsn: u64,
    /// Raw entry bytes.
    pub data: Vec<u8>,
}

/// Body of an `EntryBatch` message: a shard id plus the entries sent for it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EntryBatch {
    /// Shard the entries belong to.
    pub shard_id: u8,
    /// Entries in the order they appear in the payload.
    pub entries: Vec<WireEntry>,
}

impl EntryBatch {
    pub fn encode(&self) -> Frame {
        let mut payload = Vec::with_capacity(5 + self.entries.len() * 64);
        payload.push(self.shard_id);
        payload.extend_from_slice(&(self.entries.len() as u32).to_le_bytes());
        for e in &self.entries {
            payload.extend_from_slice(&e.entry_len.to_le_bytes());
            payload.extend_from_slice(&e.key_len.to_le_bytes());
            payload.extend_from_slice(&e.gsn.to_le_bytes());
            payload.extend_from_slice(&e.data);
        }
        Frame {
            msg_type: MessageType::EntryBatch,
            payload,
        }
    }

    pub fn decode(payload: &[u8]) -> io::Result<Self> {
        if payload.len() < 5 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "EntryBatch too short",
            ));
        }
        let shard_id = payload[0];
        let count = u32::from_le_bytes(payload[1..5].try_into().expect("impossible")) as usize;
        let mut entries = Vec::with_capacity(count);
        let mut off = 5;
        for _ in 0..count {
            if off + 14 > payload.len() {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "EntryBatch truncated",
                ));
            }
            let entry_len =
                u32::from_le_bytes(payload[off..off + 4].try_into().expect("impossible"));
            let key_len =
                u16::from_le_bytes(payload[off + 4..off + 6].try_into().expect("impossible"));
            let gsn =
                u64::from_le_bytes(payload[off + 6..off + 14].try_into().expect("impossible"));
            off += 14;
            if off + entry_len as usize > payload.len() {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "EntryBatch data truncated",
                ));
            }
            let data = payload[off..off + entry_len as usize].to_vec();
            off += entry_len as usize;
            entries.push(WireEntry {
                entry_len,
                key_len,
                gsn,
                data,
            });
        }
        Ok(Self { shard_id, entries })
    }
}

// --- CaughtUp: shard_id:u8, leader_gsn:u64 ---

/// Body of a `CaughtUp` message.
///
/// Encoded as 9 bytes: `shard_id` (1) followed by `leader_gsn`
/// (8, little-endian).
///
/// Derives the same traits as `SyncRequest` so the message can be logged,
/// cloned and compared in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CaughtUp {
    /// Shard the message refers to.
    pub shard_id: u8,
    /// GSN carried by the message — presumably the leader's current GSN for
    /// the shard; confirm against the producer.
    pub leader_gsn: u64,
}

impl CaughtUp {
    pub fn encode(&self) -> Frame {
        let mut payload = Vec::with_capacity(9);
        payload.push(self.shard_id);
        payload.extend_from_slice(&self.leader_gsn.to_le_bytes());
        Frame {
            msg_type: MessageType::CaughtUp,
            payload,
        }
    }

    pub fn decode(payload: &[u8]) -> io::Result<Self> {
        if payload.len() < 9 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "CaughtUp too short",
            ));
        }
        Ok(Self {
            shard_id: payload[0],
            leader_gsn: u64::from_le_bytes(payload[1..9].try_into().expect("impossible")),
        })
    }
}

// --- Ack: shard_id:u8, last_gsn:u64 ---

/// Body of an `Ack` message.
///
/// Encoded as 9 bytes: `shard_id` (1) followed by `last_gsn`
/// (8, little-endian).
///
/// Derives the same traits as `SyncRequest` so the message can be logged,
/// cloned and compared in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AckMessage {
    /// Shard the acknowledgement refers to.
    pub shard_id: u8,
    /// GSN carried by the acknowledgement — presumably the last GSN the
    /// sender has applied; confirm against the producer.
    pub last_gsn: u64,
}

impl AckMessage {
    pub fn encode(&self) -> Frame {
        let mut payload = Vec::with_capacity(9);
        payload.push(self.shard_id);
        payload.extend_from_slice(&self.last_gsn.to_le_bytes());
        Frame {
            msg_type: MessageType::Ack,
            payload,
        }
    }

    pub fn decode(payload: &[u8]) -> io::Result<Self> {
        if payload.len() < 9 {
            return Err(io::Error::new(io::ErrorKind::InvalidData, "Ack too short"));
        }
        Ok(Self {
            shard_id: payload[0],
            last_gsn: u64::from_le_bytes(payload[1..9].try_into().expect("impossible")),
        })
    }
}

// --- Heartbeat ---

/// Builds a `Heartbeat` frame; heartbeats carry no payload.
pub fn encode_heartbeat() -> Frame {
    let payload: Vec<u8> = Vec::new();
    Frame {
        payload,
        msg_type: MessageType::Heartbeat,
    }
}

// --- Error ---

/// Builds an `Error` frame whose payload is the UTF-8 bytes of `msg`.
pub fn encode_error(msg: &str) -> Frame {
    let payload = Vec::from(msg.as_bytes());
    Frame {
        payload,
        msg_type: MessageType::Error,
    }
}

/// Turns the payload of an `Error` frame back into text.
///
/// Decoding is lossy: byte sequences that are not valid UTF-8 are replaced
/// with U+FFFD instead of causing a failure.
pub fn decode_error(payload: &[u8]) -> String {
    match String::from_utf8_lossy(payload) {
        std::borrow::Cow::Borrowed(text) => text.to_owned(),
        std::borrow::Cow::Owned(text) => text,
    }
}

#[cfg(test)]
mod max_frame_size_tests {
    use super::*;
    use std::io::Cursor;

    /// A header announcing one byte more than `MAX_FRAME_SIZE` must be
    /// refused with the "frame too large" error.
    #[test]
    fn read_frame_rejects_oversized_length() {
        // Header only: a valid type byte followed by an over-limit LE length.
        let oversized = (MAX_FRAME_SIZE + 1) as u32;
        let mut header = vec![MessageType::SyncRequest as u8];
        header.extend_from_slice(&oversized.to_le_bytes());

        let mut input = Cursor::new(header);
        let err = read_frame(&mut input).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("frame too large"), "got: {msg}");
    }
}

#[cfg(test)]
mod sync_request_tests {
    use super::*;

    /// Encoding a request and decoding the resulting payload must yield the
    /// original value.
    #[test]
    fn encode_decode_round_trip() {
        let original = SyncRequest {
            shard_id: 7,
            from_gsn: 0xDEAD_BEEF_CAFE_BABE,
            key_len: 32,
        };
        let frame = original.encode();
        let decoded = SyncRequest::decode(&frame.payload).expect("decode ok");
        assert_eq!(decoded, original);
    }
}