use std::io::{self, Read, Write};
use crate::fixed::slot::{STATUS_OCCUPIED, status_of};
/// Protocol version number; the tests in this file send it as `SyncRequest::protocol_version`.
pub const PROTOCOL_VERSION: u8 = 1;
/// NOTE(review): not referenced in this file — presumably the maximum number of events a
/// sender packs into one `SlotBatch` frame; confirm at the use site.
pub const BATCH_MAX_ENTRIES: usize = 256;
/// Initial capacity (64 KiB) reserved by `SlotBatchEncoder::new`. Nothing in this file
/// enforces it as a limit — presumably callers flush once `bytes()` reaches it; TODO confirm.
pub const BATCH_MAX_BYTES: usize = 64 * 1024;
/// NOTE(review): not referenced in this file — by its name a poll interval in milliseconds;
/// confirm at the use site.
pub const TAIL_POLL_MS: u64 = 1;
/// NOTE(review): not referenced in this file — by its name the number of seconds between
/// `Heartbeat` frames; confirm at the use site.
pub const HEARTBEAT_INTERVAL_SECS: u64 = 5;
/// NOTE(review): not referenced in this file — presumably how many applied events elapse
/// between `Ack` frames; confirm at the use site.
pub const ACK_INTERVAL: usize = 1000;
/// Bit 0 of `SyncRequest::flags`. NOTE(review): the name suggests the requester holds no
/// local state yet; confirm the semantics with the code that handles the request.
pub const FLAG_EMPTY_STATE: u8 = 1 << 0;
/// One-byte tag written at the start of every frame; it tells the reader how to
/// interpret the payload that follows.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FixedMessageType {
    /// Payload produced by `SyncRequest::encode`.
    SyncRequest = 1,
    /// Payload produced by `ShardInfo::encode`.
    ShardInfo = 2,
    /// Payload produced by `SlotBatchEncoder::finish`.
    SlotBatch = 3,
    /// Payload produced by `CaughtUp::encode`.
    CaughtUp = 4,
    /// Payload produced by `Ack::encode`.
    Ack = 5,
    /// Produced by `encode_heartbeat` with an empty payload.
    Heartbeat = 6,
    /// Produced by `encode_error`; the payload is the message text.
    Error = 255,
}

impl FixedMessageType {
    /// Converts a raw tag byte into its variant.
    ///
    /// Returns `None` for any byte that has no variant assigned.
    pub fn from_u8(v: u8) -> Option<Self> {
        match v {
            1 => Some(Self::SyncRequest),
            2 => Some(Self::ShardInfo),
            3 => Some(Self::SlotBatch),
            4 => Some(Self::CaughtUp),
            5 => Some(Self::Ack),
            6 => Some(Self::Heartbeat),
            255 => Some(Self::Error),
            _ => None,
        }
    }
}
/// A single protocol message: a type tag plus its raw, still-encoded payload.
///
/// `write_frame` serializes it as one tag byte, a little-endian `u32` payload
/// length, then the payload bytes; `read_frame` parses the same layout.
#[derive(Debug)]
pub struct Frame {
/// Tag identifying how `payload` is encoded.
pub msg_type: FixedMessageType,
/// Message body; its layout depends on `msg_type`.
pub payload: Vec<u8>,
}
pub fn write_frame(w: &mut impl Write, frame: &Frame) -> io::Result<()> {
w.write_all(&[frame.msg_type as u8])?;
w.write_all(&(frame.payload.len() as u32).to_le_bytes())?;
w.write_all(&frame.payload)?;
w.flush()
}
/// Reads one frame from `r`: a message-type byte, a little-endian `u32`
/// payload length, then exactly that many payload bytes.
///
/// The length field comes from the peer, so the payload buffer grows as bytes
/// actually arrive instead of being allocated up front; a corrupt or hostile
/// header claiming ~4 GiB cannot force a 4 GiB allocation.
///
/// # Errors
///
/// * `InvalidData` if the type byte is not a known `FixedMessageType`.
/// * `UnexpectedEof` if the stream ends before the header or the declared
///   payload has been read completely.
/// * Any other error from the underlying reader is propagated.
pub fn read_frame(r: &mut impl Read) -> io::Result<Frame> {
    // Most that is reserved before any payload byte has been received.
    const MAX_PREALLOC: usize = 64 * 1024;

    let mut tag = [0u8; 1];
    r.read_exact(&mut tag)?;
    let msg_type = FixedMessageType::from_u8(tag[0]).ok_or_else(|| {
        io::Error::new(
            io::ErrorKind::InvalidData,
            format!("unknown message type: {}", tag[0]),
        )
    })?;

    let mut len_buf = [0u8; 4];
    r.read_exact(&mut len_buf)?;
    let len = u32::from_le_bytes(len_buf);
    let expected = len as usize;

    let mut payload = Vec::with_capacity(expected.min(MAX_PREALLOC));
    // `take` caps the read at the declared length, so bytes that belong to
    // the next frame are never consumed.
    r.by_ref().take(u64::from(len)).read_to_end(&mut payload)?;
    if payload.len() != expected {
        return Err(io::Error::new(
            io::ErrorKind::UnexpectedEof,
            "frame payload truncated",
        ));
    }
    Ok(Frame { msg_type, payload })
}
/// Payload of a `SyncRequest` frame: three single-byte fields, encoded in
/// declaration order.
#[derive(Debug)]
pub struct SyncRequest {
    /// Shard the request refers to.
    pub shard_id: u8,
    /// Protocol version byte (the tests in this file send `PROTOCOL_VERSION`).
    pub protocol_version: u8,
    /// Bit flags; `FLAG_EMPTY_STATE` is the only bit defined in this file.
    pub flags: u8,
}

impl SyncRequest {
    /// Builds the frame whose payload is `[shard_id, protocol_version, flags]`.
    pub fn encode(&self) -> Frame {
        let Self {
            shard_id,
            protocol_version,
            flags,
        } = *self;
        Frame {
            msg_type: FixedMessageType::SyncRequest,
            payload: vec![shard_id, protocol_version, flags],
        }
    }

    /// Parses the first three payload bytes; any bytes after them are ignored.
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` when fewer than three bytes are supplied.
    pub fn decode(payload: &[u8]) -> io::Result<Self> {
        match payload {
            &[shard_id, protocol_version, flags, ..] => Ok(Self {
                shard_id,
                protocol_version,
                flags,
            }),
            _ => Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "SyncRequest too short",
            )),
        }
    }
}
/// Payload of a `ShardInfo` frame.
///
/// Encoded as 12 bytes, integers little-endian, in field declaration order:
/// `shard_count` (1), `key_len` (2), `value_len` (2), `slot_size` (2),
/// `current_slot_count` (4), `shard_prefix_bits` (1).
///
/// NOTE(review): this file only serializes the fields; their meaning is taken
/// from their names and should be confirmed against the producer.
///
/// Derives `Debug` like the other message types so values can be logged and
/// `decode` results can be used with `unwrap_err`/`expect_err`.
#[derive(Debug)]
pub struct ShardInfo {
    pub shard_count: u8,
    pub key_len: u16,
    pub value_len: u16,
    pub slot_size: u16,
    pub current_slot_count: u32,
    pub shard_prefix_bits: u8,
}

impl ShardInfo {
    /// Builds the `ShardInfo` frame with the 12-byte payload described on the type.
    pub fn encode(&self) -> Frame {
        let mut payload = Vec::with_capacity(12);
        payload.push(self.shard_count);
        payload.extend_from_slice(&self.key_len.to_le_bytes());
        payload.extend_from_slice(&self.value_len.to_le_bytes());
        payload.extend_from_slice(&self.slot_size.to_le_bytes());
        payload.extend_from_slice(&self.current_slot_count.to_le_bytes());
        payload.push(self.shard_prefix_bits);
        Frame {
            msg_type: FixedMessageType::ShardInfo,
            payload,
        }
    }

    /// Parses the first 12 payload bytes; any bytes after them are ignored.
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` when fewer than 12 bytes are supplied.
    pub fn decode(payload: &[u8]) -> io::Result<Self> {
        // The slice pattern proves the length at compile time, so no
        // fallible slice-to-array conversion is needed.
        match payload {
            &[shard_count, k0, k1, v0, v1, s0, s1, c0, c1, c2, c3, shard_prefix_bits, ..] => {
                Ok(Self {
                    shard_count,
                    key_len: u16::from_le_bytes([k0, k1]),
                    value_len: u16::from_le_bytes([v0, v1]),
                    slot_size: u16::from_le_bytes([s0, s1]),
                    current_slot_count: u32::from_le_bytes([c0, c1, c2, c3]),
                    shard_prefix_bits,
                })
            }
            _ => Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "ShardInfo too short",
            )),
        }
    }
}
/// Incrementally builds the payload of a `SlotBatch` frame.
///
/// Payload layout (integers little-endian):
///
/// * `shard_id` — 1 byte
/// * event count — `u32`; written as a zero placeholder by `new` and patched
///   by `finish`
/// * events, back to back; each is `slot_id: u32`, `meta: u32`, the key
///   (`key_len` bytes) and, for occupied slots only, the value (`value_len`
///   bytes)
///
/// The encoder enforces no entry or byte limit itself; callers can poll
/// `len` and `bytes` to decide when to finish the batch.
pub struct SlotBatchEncoder {
    payload: Vec<u8>,
    count: u32,
    key_len: usize,
    value_len: usize,
    count_offset: usize,
}

impl SlotBatchEncoder {
    /// Starts an empty batch for `shard_id` whose events use fixed-size keys
    /// and values of `key_len` / `value_len` bytes.
    pub fn new(shard_id: u8, key_len: usize, value_len: usize) -> Self {
        let mut payload = Vec::with_capacity(BATCH_MAX_BYTES);
        payload.push(shard_id);
        // Remember where the count lives so `finish` can patch it in place.
        let count_offset = payload.len();
        payload.extend_from_slice(&0u32.to_le_bytes());
        Self {
            payload,
            count: 0,
            key_len,
            value_len,
            count_offset,
        }
    }

    /// Appends an occupied-slot event: header, key and value.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `key`/`value` do not have the configured
    /// lengths or if `meta` does not carry `STATUS_OCCUPIED`.
    pub fn add_occupied(&mut self, slot_id: u32, meta: u32, key: &[u8], value: &[u8]) {
        debug_assert_eq!(key.len(), self.key_len);
        debug_assert_eq!(value.len(), self.value_len);
        debug_assert_eq!(status_of(meta), STATUS_OCCUPIED);
        self.push_header_and_key(slot_id, meta, key);
        self.payload.extend_from_slice(value);
        self.count += 1;
    }

    /// Appends an event that carries no value bytes (header and key only).
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `key` does not have the configured length
    /// or if `meta` carries `STATUS_OCCUPIED`.
    pub fn add_deleted(&mut self, slot_id: u32, meta: u32, key: &[u8]) {
        debug_assert_eq!(key.len(), self.key_len);
        // The decoder reads `value_len` value bytes whenever the status is
        // occupied, so an occupied `meta` without a value would make it
        // misparse every following event in the batch.
        debug_assert_ne!(status_of(meta), STATUS_OCCUPIED);
        self.push_header_and_key(slot_id, meta, key);
        self.count += 1;
    }

    /// Number of events added so far.
    pub fn len(&self) -> u32 {
        self.count
    }

    /// Current payload size in bytes, including the 5-byte batch header.
    pub fn bytes(&self) -> usize {
        self.payload.len()
    }

    /// `true` while no event has been added.
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }

    /// Patches the event count into the header and returns the finished frame.
    pub fn finish(mut self) -> Frame {
        let count_field = self.count_offset..self.count_offset + 4;
        self.payload[count_field].copy_from_slice(&self.count.to_le_bytes());
        Frame {
            msg_type: FixedMessageType::SlotBatch,
            payload: self.payload,
        }
    }

    /// Writes the part shared by every event: `slot_id`, `meta`, then the key.
    fn push_header_and_key(&mut self, slot_id: u32, meta: u32, key: &[u8]) {
        self.payload.extend_from_slice(&slot_id.to_le_bytes());
        self.payload.extend_from_slice(&meta.to_le_bytes());
        self.payload.extend_from_slice(key);
    }
}
pub struct SlotBatchDecoder<'a> {
pub shard_id: u8,
pub count: u32,
pub events_bytes: &'a [u8],
key_len: usize,
value_len: usize,
}
#[derive(Debug)]
pub struct SlotEventRef<'a> {
pub slot_id: u32,
pub meta: u32,
pub key: &'a [u8],
pub value: &'a [u8], }
impl<'a> SlotBatchDecoder<'a> {
pub fn new(payload: &'a [u8], key_len: usize, value_len: usize) -> io::Result<Self> {
if payload.len() < 5 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"SlotBatch too short",
));
}
let shard_id = payload[0];
let count = u32::from_le_bytes(payload[1..5].try_into().expect("4 bytes"));
Ok(Self {
shard_id,
count,
events_bytes: &payload[5..],
key_len,
value_len,
})
}
pub fn iter(&self) -> SlotEventIter<'a> {
SlotEventIter {
remaining: self.events_bytes,
left: self.count,
key_len: self.key_len,
value_len: self.value_len,
}
}
}
/// Lazy decoder over the event section of a `SlotBatch` payload.
///
/// Yields at most `left` events. A truncated event is reported once as an
/// `InvalidData` error, after which the iterator is exhausted. Bytes that
/// remain after the declared number of events are ignored.
pub struct SlotEventIter<'a> {
    remaining: &'a [u8],
    left: u32,
    key_len: usize,
    value_len: usize,
}

impl<'a> Iterator for SlotEventIter<'a> {
    type Item = io::Result<SlotEventRef<'a>>;

    /// Decodes the next event: `slot_id` and `meta` (both little-endian
    /// `u32`), `key_len` key bytes, and `value_len` value bytes when `meta`
    /// carries `STATUS_OCCUPIED`.
    fn next(&mut self) -> Option<Self::Item> {
        // Size of the fixed `slot_id` + `meta` prefix of every event.
        const HEADER_LEN: usize = 8;

        if self.left == 0 {
            return None;
        }

        let buf = self.remaining;
        let key_end = HEADER_LEN + self.key_len;
        if buf.len() < key_end {
            // Nothing after a truncated event can be decoded. Without this
            // reset every later call would return the same error again and
            // consumers that skip errors would never terminate.
            self.left = 0;
            return Some(Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "SlotBatch event truncated at header/key",
            )));
        }

        let slot_id = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]);
        let meta = u32::from_le_bytes([buf[4], buf[5], buf[6], buf[7]]);

        // Only occupied slots carry a value on the wire.
        let value_len = if status_of(meta) == STATUS_OCCUPIED {
            self.value_len
        } else {
            0
        };
        let value_end = key_end + value_len;
        if buf.len() < value_end {
            // Same reasoning as above: report once, then stop.
            self.left = 0;
            return Some(Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "SlotBatch event truncated at value",
            )));
        }

        self.remaining = &buf[value_end..];
        self.left -= 1;
        Some(Ok(SlotEventRef {
            slot_id,
            meta,
            key: &buf[HEADER_LEN..key_end],
            value: &buf[key_end..value_end],
        }))
    }
}
/// Payload of a `CaughtUp` frame: `shard_id` (1 byte) followed by
/// `total_scanned` as a little-endian `u64`.
///
/// NOTE(review): this file only serializes the fields; their meaning is taken
/// from their names and should be confirmed against the sender.
///
/// Derives `Debug` like the other message types so values can be logged and
/// `decode` results can be used with `unwrap_err`/`expect_err`.
#[derive(Debug)]
pub struct CaughtUp {
    pub shard_id: u8,
    pub total_scanned: u64,
}

impl CaughtUp {
    /// Builds the `CaughtUp` frame with its 9-byte payload.
    pub fn encode(&self) -> Frame {
        let mut payload = Vec::with_capacity(9);
        payload.push(self.shard_id);
        payload.extend_from_slice(&self.total_scanned.to_le_bytes());
        Frame {
            msg_type: FixedMessageType::CaughtUp,
            payload,
        }
    }

    /// Parses the first 9 payload bytes; any bytes after them are ignored.
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` when fewer than 9 bytes are supplied.
    pub fn decode(payload: &[u8]) -> io::Result<Self> {
        // The slice pattern proves the length, so no fallible conversion
        // (and no `expect`) is needed.
        match payload {
            &[shard_id, t0, t1, t2, t3, t4, t5, t6, t7, ..] => Ok(Self {
                shard_id,
                total_scanned: u64::from_le_bytes([t0, t1, t2, t3, t4, t5, t6, t7]),
            }),
            _ => Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "CaughtUp too short",
            )),
        }
    }
}
/// Payload of an `Ack` frame: `shard_id` (1 byte), `applied_count`
/// (little-endian `u64`), then `max_version_seen` (little-endian `u32`).
///
/// NOTE(review): this file only serializes the fields; their meaning is taken
/// from their names and should be confirmed against the sender.
///
/// Derives `Debug` like the other message types so values can be logged and
/// `decode` results can be used with `unwrap_err`/`expect_err`.
#[derive(Debug)]
pub struct Ack {
    pub shard_id: u8,
    pub applied_count: u64,
    pub max_version_seen: u32,
}

impl Ack {
    /// Builds the `Ack` frame with its 13-byte payload.
    pub fn encode(&self) -> Frame {
        let mut payload = Vec::with_capacity(13);
        payload.push(self.shard_id);
        payload.extend_from_slice(&self.applied_count.to_le_bytes());
        payload.extend_from_slice(&self.max_version_seen.to_le_bytes());
        Frame {
            msg_type: FixedMessageType::Ack,
            payload,
        }
    }

    /// Parses the first 13 payload bytes; any bytes after them are ignored.
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` when fewer than 13 bytes are supplied.
    pub fn decode(payload: &[u8]) -> io::Result<Self> {
        // The slice pattern proves the length, so no fallible conversion
        // (and no `expect`) is needed.
        match payload {
            &[shard_id, a0, a1, a2, a3, a4, a5, a6, a7, v0, v1, v2, v3, ..] => Ok(Self {
                shard_id,
                applied_count: u64::from_le_bytes([a0, a1, a2, a3, a4, a5, a6, a7]),
                max_version_seen: u32::from_le_bytes([v0, v1, v2, v3]),
            }),
            _ => Err(io::Error::new(io::ErrorKind::InvalidData, "Ack too short")),
        }
    }
}
/// Builds a `Heartbeat` frame; it carries no payload.
pub fn encode_heartbeat() -> Frame {
    let payload = vec![];
    Frame {
        msg_type: FixedMessageType::Heartbeat,
        payload,
    }
}
/// Builds an `Error` frame whose payload is the UTF-8 bytes of `msg`.
pub fn encode_error(msg: &str) -> Frame {
    let payload = Vec::from(msg.as_bytes());
    Frame {
        msg_type: FixedMessageType::Error,
        payload,
    }
}
/// Turns the payload of an `Error` frame back into text.
///
/// Invalid UTF-8 sequences are replaced with U+FFFD rather than rejected.
pub fn decode_error(payload: &[u8]) -> String {
    match std::str::from_utf8(payload) {
        Ok(text) => text.to_owned(),
        Err(_) => String::from_utf8_lossy(payload).into_owned(),
    }
}
// Unit tests for the framing layer and for each message's encode/decode pair.
#[cfg(test)]
mod tests {
use super::*;
use crate::fixed::slot::{STATUS_DELETED, STATUS_OCCUPIED, pack_meta};
// A frame written with `write_frame` is read back unchanged by `read_frame`.
#[test]
fn test_frame_roundtrip() {
let f = Frame {
msg_type: FixedMessageType::Heartbeat,
payload: vec![1, 2, 3],
};
let mut buf = Vec::new();
write_frame(&mut buf, &f).unwrap();
let parsed = read_frame(&mut &buf[..]).unwrap();
assert_eq!(parsed.msg_type, FixedMessageType::Heartbeat);
assert_eq!(parsed.payload, vec![1, 2, 3]);
}
// A type byte with no assigned variant (42) is rejected as `InvalidData`.
#[test]
fn test_frame_unknown_type() {
let buf = [42u8, 0, 0, 0, 0];
let err = read_frame(&mut &buf[..]).unwrap_err();
assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
}
// `SyncRequest` survives an encode/decode round trip.
#[test]
fn test_sync_request_roundtrip() {
let r = SyncRequest {
shard_id: 7,
protocol_version: PROTOCOL_VERSION,
flags: FLAG_EMPTY_STATE,
};
let frame = r.encode();
let d = SyncRequest::decode(&frame.payload).unwrap();
assert_eq!(d.shard_id, 7);
assert_eq!(d.protocol_version, PROTOCOL_VERSION);
assert_eq!(d.flags, FLAG_EMPTY_STATE);
}
// A one-byte payload is too short for the three-byte `SyncRequest`.
#[test]
fn test_sync_request_truncated() {
assert_eq!(
SyncRequest::decode(&[1u8]).unwrap_err().kind(),
std::io::ErrorKind::InvalidData
);
}
// `ShardInfo` survives an encode/decode round trip.
#[test]
fn test_shard_info_roundtrip() {
let s = ShardInfo {
shard_count: 16,
key_len: 8,
value_len: 32,
slot_size: 48,
current_slot_count: 1_000_000,
shard_prefix_bits: 4,
};
let f = s.encode();
let d = ShardInfo::decode(&f.payload).unwrap();
assert_eq!(d.shard_count, 16);
assert_eq!(d.key_len, 8);
assert_eq!(d.value_len, 32);
assert_eq!(d.slot_size, 48);
assert_eq!(d.current_slot_count, 1_000_000);
assert_eq!(d.shard_prefix_bits, 4);
}
// A batch mixing one occupied event (with value) and one deleted event
// (key only) decodes back to the same two events.
#[test]
fn test_slot_batch_mixed_roundtrip() {
let key_len = 4usize;
let value_len = 8usize;
let mut b = SlotBatchEncoder::new(5, key_len, value_len);
let occ_meta = pack_meta(STATUS_OCCUPIED, 42);
b.add_occupied(100, occ_meta, b"key0", b"12345678");
let del_meta = pack_meta(STATUS_DELETED, 43);
b.add_deleted(101, del_meta, b"key1");
let frame = b.finish();
assert_eq!(frame.msg_type, FixedMessageType::SlotBatch);
let d = SlotBatchDecoder::new(&frame.payload, key_len, value_len).unwrap();
assert_eq!(d.shard_id, 5);
let events: Vec<_> = d.iter().collect::<Result<_, _>>().unwrap();
assert_eq!(events.len(), 2);
assert_eq!(events[0].slot_id, 100);
assert_eq!(events[0].meta, occ_meta);
assert_eq!(events[0].key, b"key0");
assert_eq!(events[0].value, b"12345678");
assert_eq!(events[1].slot_id, 101);
assert_eq!(events[1].meta, del_meta);
assert_eq!(events[1].key, b"key1");
assert!(events[1].value.is_empty());
}
// Cutting the last four value bytes off an occupied event makes the
// iterator report `InvalidData` for that event.
#[test]
fn test_slot_batch_truncated_event_fails() {
let key_len = 4usize;
let value_len = 8usize;
let mut b = SlotBatchEncoder::new(0, key_len, value_len);
b.add_occupied(1, pack_meta(STATUS_OCCUPIED, 1), b"key0", b"12345678");
let mut frame = b.finish();
frame.payload.truncate(frame.payload.len() - 4);
let d = SlotBatchDecoder::new(&frame.payload, key_len, value_len).unwrap();
let mut iter = d.iter();
let err = iter.next().unwrap().unwrap_err();
assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
}
// `CaughtUp` survives an encode/decode round trip.
#[test]
fn test_caught_up_roundtrip() {
let c = CaughtUp {
shard_id: 3,
total_scanned: 123_456_789,
};
let f = c.encode();
let d = CaughtUp::decode(&f.payload).unwrap();
assert_eq!(d.shard_id, 3);
assert_eq!(d.total_scanned, 123_456_789);
}
// `Ack` survives an encode/decode round trip.
#[test]
fn test_ack_roundtrip() {
let a = Ack {
shard_id: 4,
applied_count: 1000,
max_version_seen: 42,
};
let f = a.encode();
let d = Ack::decode(&f.payload).unwrap();
assert_eq!(d.shard_id, 4);
assert_eq!(d.applied_count, 1000);
assert_eq!(d.max_version_seen, 42);
}
}