use std::io::{self, Read, Write};
/// Type byte written at the start of every frame by `write_frame`.
///
/// The explicit discriminants are the on-wire values; with `#[repr(u8)]`,
/// `variant as u8` yields exactly that byte.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageType {
    /// Payload produced by `SyncRequest::encode`.
    SyncRequest = 0x01,
    /// Payload produced by `ShardInfo::encode`.
    ShardInfo = 0x02,
    /// Payload produced by `EntryBatch::encode`.
    EntryBatch = 0x03,
    /// Payload produced by `CaughtUp::encode`.
    CaughtUp = 0x04,
    /// Payload produced by `AckMessage::encode`.
    Ack = 0x05,
    /// Empty payload (see `encode_heartbeat`).
    Heartbeat = 0x06,
    /// Payload is message text from `encode_error`, read back by `decode_error`.
    Error = 0xFF,
}
impl MessageType {
    /// Converts a raw type byte read off the wire back into a variant.
    ///
    /// Returns `None` for any byte that is not one of the known discriminants.
    fn from_u8(v: u8) -> Option<Self> {
        let known = [
            Self::SyncRequest,
            Self::ShardInfo,
            Self::EntryBatch,
            Self::CaughtUp,
            Self::Ack,
            Self::Heartbeat,
            Self::Error,
        ];
        // Each variant's discriminant is its wire byte, so scanning the
        // variants for a matching `as u8` is the exact inverse of encoding.
        known.iter().copied().find(|&candidate| candidate as u8 == v)
    }
}
pub struct Frame {
pub msg_type: MessageType,
pub payload: Vec<u8>,
}
/// Writes `frame` to `w` and flushes it.
///
/// Wire layout: the type byte, the payload length as a little-endian `u32`,
/// then the payload bytes.
///
/// # Errors
/// Returns `InvalidInput` if the payload is longer than `u32::MAX` bytes.
/// Its length cannot be represented in the prefix, and truncating it would
/// desynchronize the reader. Any I/O error from `w` is propagated.
pub fn write_frame(w: &mut impl Write, frame: &Frame) -> io::Result<()> {
    let len = u32::try_from(frame.payload.len()).map_err(|_| {
        io::Error::new(
            io::ErrorKind::InvalidInput,
            format!(
                "frame payload too large for u32 length prefix: {} bytes",
                frame.payload.len()
            ),
        )
    })?;
    // Emit type byte + length prefix together so unbuffered writers see one write.
    let mut header = [0u8; 5];
    header[0] = frame.msg_type as u8;
    header[1..].copy_from_slice(&len.to_le_bytes());
    w.write_all(&header)?;
    w.write_all(&frame.payload)?;
    w.flush()
}
/// Reads one frame (type byte, little-endian `u32` length, payload) from `r`.
///
/// The length prefix comes from the peer and is not trusted for allocation.
/// At most `MAX_PREALLOC` bytes are reserved up front; beyond that the
/// buffer grows only as payload bytes actually arrive.
///
/// # Errors
/// - `InvalidData` if the type byte is not a known `MessageType`.
/// - `UnexpectedEof` if the stream ends before the header or the full
///   payload has been read.
/// - Any other I/O error from `r` is propagated.
pub fn read_frame(r: &mut impl Read) -> io::Result<Frame> {
    // Upper bound on the buffer reserved from the untrusted length prefix.
    const MAX_PREALLOC: usize = 64 * 1024;

    let mut type_buf = [0u8; 1];
    r.read_exact(&mut type_buf)?;
    let msg_type = MessageType::from_u8(type_buf[0]).ok_or_else(|| {
        io::Error::new(
            io::ErrorKind::InvalidData,
            format!("unknown message type: {}", type_buf[0]),
        )
    })?;

    let mut len_buf = [0u8; 4];
    r.read_exact(&mut len_buf)?;
    let len = u32::from_le_bytes(len_buf);
    let expected = len as usize;

    let mut payload = Vec::with_capacity(expected.min(MAX_PREALLOC));
    r.by_ref().take(u64::from(len)).read_to_end(&mut payload)?;
    if payload.len() != expected {
        return Err(io::Error::new(
            io::ErrorKind::UnexpectedEof,
            format!(
                "frame payload truncated: expected {} bytes, got {}",
                expected,
                payload.len()
            ),
        ));
    }
    Ok(Frame { msg_type, payload })
}
/// Body of a `MessageType::SyncRequest` frame.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SyncRequest {
    /// Shard the request refers to (one byte on the wire).
    pub shard_id: u8,
    /// GSN carried by the request (presumably where syncing should start — confirm with the handler).
    pub from_gsn: u64,
    /// Per-key lengths; the wire format stores their count in one byte, so at most 255 entries.
    pub key_lens: Vec<u16>,
}
impl SyncRequest {
    /// Serializes into a `SyncRequest` frame.
    ///
    /// Layout (integers little-endian): `shard_id` (1 byte), `from_gsn`
    /// (`u64`), number of `key_lens` (1 byte), then each key length as a `u16`.
    ///
    /// # Panics
    /// Panics if `key_lens` has more than 255 entries. The count is a single
    /// byte on the wire, and truncating it would make the decoder read the
    /// wrong number of entries.
    pub fn encode(&self) -> Frame {
        let count = u8::try_from(self.key_lens.len())
            .expect("SyncRequest can carry at most 255 key_lens (count is one byte on the wire)");
        let mut payload = Vec::with_capacity(10 + self.key_lens.len() * 2);
        payload.push(self.shard_id);
        payload.extend_from_slice(&self.from_gsn.to_le_bytes());
        payload.push(count);
        for &kl in &self.key_lens {
            payload.extend_from_slice(&kl.to_le_bytes());
        }
        Frame {
            msg_type: MessageType::SyncRequest,
            payload,
        }
    }

    /// Parses the layout written by [`SyncRequest::encode`]. Bytes after the
    /// last key length are ignored.
    ///
    /// # Errors
    /// - `InvalidData` ("SyncRequest too short") if there are fewer than 10 bytes.
    /// - `InvalidData` ("SyncRequest truncated") if fewer key lengths are present
    ///   than the count byte announces.
    pub fn decode(payload: &[u8]) -> io::Result<Self> {
        if payload.len() < 10 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "SyncRequest too short",
            ));
        }
        let shard_id = payload[0];
        let from_gsn = u64::from_le_bytes(payload[1..9].try_into().expect("slice is exactly 8 bytes"));
        let count = usize::from(payload[9]);
        let key_bytes = payload[10..].get(..count * 2).ok_or_else(|| {
            io::Error::new(io::ErrorKind::InvalidData, "SyncRequest truncated")
        })?;
        let key_lens = key_bytes
            .chunks_exact(2)
            .map(|pair| u16::from_le_bytes([pair[0], pair[1]]))
            .collect();
        Ok(Self {
            shard_id,
            from_gsn,
            key_lens,
        })
    }
}
/// Body of a `MessageType::ShardInfo` frame.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShardInfo {
    /// Shard count advertised by the sender (one byte on the wire).
    pub shard_count: u8,
    /// Sent as a little-endian `u64`; units not visible here (presumably bytes — confirm).
    pub max_file_size: u64,
}
impl ShardInfo {
    /// Serializes into a `ShardInfo` frame: `shard_count` (1 byte) followed
    /// by `max_file_size` as a little-endian `u64`, 9 bytes in total.
    pub fn encode(&self) -> Frame {
        let payload = [&[self.shard_count][..], &self.max_file_size.to_le_bytes()[..]].concat();
        Frame {
            msg_type: MessageType::ShardInfo,
            payload,
        }
    }

    /// Parses the 9-byte layout written by [`ShardInfo::encode`]; any extra
    /// trailing bytes are ignored.
    ///
    /// # Errors
    /// `InvalidData` ("ShardInfo too short") when fewer than 9 bytes are supplied.
    pub fn decode(payload: &[u8]) -> io::Result<Self> {
        let too_short = || io::Error::new(io::ErrorKind::InvalidData, "ShardInfo too short");
        let (&shard_count, rest) = payload.split_first().ok_or_else(too_short)?;
        let mut size_bytes = [0u8; 8];
        size_bytes.copy_from_slice(rest.get(..8).ok_or_else(too_short)?);
        Ok(Self {
            shard_count,
            max_file_size: u64::from_le_bytes(size_bytes),
        })
    }
}
/// One entry inside an [`EntryBatch`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WireEntry {
    /// Length prefix sent before `data`; `EntryBatch::decode` reads exactly this many data bytes.
    pub entry_len: u32,
    /// Carried through unchanged; meaning not visible here (presumably the key portion of `data` — confirm).
    pub key_len: u16,
    /// GSN attached to the entry (sent as a little-endian `u64`).
    pub gsn: u64,
    /// Raw entry bytes.
    pub data: Vec<u8>,
}

/// Body of a `MessageType::EntryBatch` frame: entries for a single shard.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EntryBatch {
    /// Shard the entries belong to (one byte on the wire).
    pub shard_id: u8,
    /// Entries in wire order.
    pub entries: Vec<WireEntry>,
}
impl EntryBatch {
    /// Size of the fixed per-entry header: data length (`u32`) + `key_len` (`u16`) + `gsn` (`u64`).
    const ENTRY_HEADER_LEN: usize = 4 + 2 + 8;

    /// Serializes into an `EntryBatch` frame.
    ///
    /// Layout (integers little-endian): `shard_id` (1 byte), entry count
    /// (`u32`), then for each entry: data length (`u32`), `key_len` (`u16`),
    /// `gsn` (`u64`) and the `data` bytes.
    ///
    /// The per-entry length prefix is taken from `data.len()`, because that
    /// is exactly how many bytes follow and how many [`EntryBatch::decode`]
    /// consumes. An inconsistent `entry_len` field can no longer corrupt the frame.
    ///
    /// # Panics
    /// Panics if there are more than `u32::MAX` entries, or if an entry's
    /// `data` exceeds `u32::MAX` bytes. Neither fits its wire field.
    pub fn encode(&self) -> Frame {
        let count = u32::try_from(self.entries.len())
            .expect("EntryBatch entry count does not fit in the u32 wire field");
        let body_len: usize = self
            .entries
            .iter()
            .map(|e| Self::ENTRY_HEADER_LEN + e.data.len())
            .sum();
        let mut payload = Vec::with_capacity(5 + body_len);
        payload.push(self.shard_id);
        payload.extend_from_slice(&count.to_le_bytes());
        for e in &self.entries {
            let data_len = u32::try_from(e.data.len())
                .expect("EntryBatch entry data does not fit in the u32 length field");
            payload.extend_from_slice(&data_len.to_le_bytes());
            payload.extend_from_slice(&e.key_len.to_le_bytes());
            payload.extend_from_slice(&e.gsn.to_le_bytes());
            payload.extend_from_slice(&e.data);
        }
        Frame {
            msg_type: MessageType::EntryBatch,
            payload,
        }
    }

    /// Parses the layout written by [`EntryBatch::encode`]. Bytes after the
    /// last announced entry are ignored.
    ///
    /// # Errors
    /// - `InvalidData` ("EntryBatch too short") if there are fewer than 5 bytes.
    /// - `InvalidData` ("EntryBatch truncated") if an entry header is cut off.
    /// - `InvalidData` ("EntryBatch data truncated") if an entry's data is cut off.
    pub fn decode(payload: &[u8]) -> io::Result<Self> {
        if payload.len() < 5 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "EntryBatch too short",
            ));
        }
        let shard_id = payload[0];
        let count = u32::from_le_bytes(payload[1..5].try_into().expect("slice is exactly 4 bytes")) as usize;
        // `count` comes from the peer. Every entry needs at least a header, so
        // never reserve more slots than the payload could actually contain.
        let max_entries = (payload.len() - 5) / Self::ENTRY_HEADER_LEN;
        let mut entries = Vec::with_capacity(count.min(max_entries));
        let mut off = 5;
        for _ in 0..count {
            let header = payload
                .get(off..off + Self::ENTRY_HEADER_LEN)
                .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "EntryBatch truncated"))?;
            let entry_len = u32::from_le_bytes(header[0..4].try_into().expect("slice is exactly 4 bytes"));
            let key_len = u16::from_le_bytes(header[4..6].try_into().expect("slice is exactly 2 bytes"));
            let gsn = u64::from_le_bytes(header[6..14].try_into().expect("slice is exactly 8 bytes"));
            off += Self::ENTRY_HEADER_LEN;
            // Compare against the remaining length rather than computing
            // `off + entry_len`, which could overflow `usize` on 32-bit targets.
            let data_len = entry_len as usize;
            if data_len > payload.len() - off {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "EntryBatch data truncated",
                ));
            }
            let data = payload[off..off + data_len].to_vec();
            off += data_len;
            entries.push(WireEntry {
                entry_len,
                key_len,
                gsn,
                data,
            });
        }
        Ok(Self { shard_id, entries })
    }
}
/// Body of a `MessageType::CaughtUp` frame.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CaughtUp {
    /// Shard the message refers to (one byte on the wire).
    pub shard_id: u8,
    /// GSN reported by the sender (presumably the leader's current GSN — confirm).
    pub leader_gsn: u64,
}
impl CaughtUp {
    /// Serializes into a `CaughtUp` frame: `shard_id` (1 byte) followed by
    /// `leader_gsn` as a little-endian `u64`.
    pub fn encode(&self) -> Frame {
        let payload: Vec<u8> = std::iter::once(self.shard_id)
            .chain(self.leader_gsn.to_le_bytes().iter().copied())
            .collect();
        Frame {
            msg_type: MessageType::CaughtUp,
            payload,
        }
    }

    /// Parses the 9-byte layout written by [`CaughtUp::encode`]; trailing
    /// bytes beyond the ninth are ignored.
    ///
    /// # Errors
    /// `InvalidData` ("CaughtUp too short") when fewer than 9 bytes are supplied.
    pub fn decode(payload: &[u8]) -> io::Result<Self> {
        match payload {
            &[shard_id, b0, b1, b2, b3, b4, b5, b6, b7, ..] => Ok(Self {
                shard_id,
                leader_gsn: u64::from_le_bytes([b0, b1, b2, b3, b4, b5, b6, b7]),
            }),
            _ => Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "CaughtUp too short",
            )),
        }
    }
}
/// Body of a `MessageType::Ack` frame.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AckMessage {
    /// Shard being acknowledged (one byte on the wire).
    pub shard_id: u8,
    /// GSN carried by the ack (presumably the last GSN the sender has applied — confirm).
    pub last_gsn: u64,
}
impl AckMessage {
    /// Serializes into an `Ack` frame: `shard_id` (1 byte) followed by
    /// `last_gsn` as a little-endian `u64`.
    pub fn encode(&self) -> Frame {
        let mut payload = vec![0u8; 9];
        payload[0] = self.shard_id;
        payload[1..].copy_from_slice(&self.last_gsn.to_le_bytes());
        Frame {
            msg_type: MessageType::Ack,
            payload,
        }
    }

    /// Parses the 9-byte layout written by [`AckMessage::encode`]; anything
    /// after the ninth byte is ignored.
    ///
    /// # Errors
    /// `InvalidData` ("Ack too short") when fewer than 9 bytes are supplied.
    pub fn decode(payload: &[u8]) -> io::Result<Self> {
        let head = match payload.get(..9) {
            Some(head) => head,
            None => return Err(io::Error::new(io::ErrorKind::InvalidData, "Ack too short")),
        };
        let (id_byte, gsn_bytes) = head.split_at(1);
        Ok(Self {
            shard_id: id_byte[0],
            last_gsn: u64::from_le_bytes(
                gsn_bytes
                    .try_into()
                    .expect("split_at(1) of a 9-byte slice leaves 8 bytes"),
            ),
        })
    }
}
/// Builds a `Heartbeat` frame; heartbeats carry no payload bytes.
pub fn encode_heartbeat() -> Frame {
    let payload = Vec::new();
    Frame {
        payload,
        msg_type: MessageType::Heartbeat,
    }
}
/// Builds an `Error` frame whose payload is the UTF-8 bytes of `msg`.
pub fn encode_error(msg: &str) -> Frame {
    Frame {
        msg_type: MessageType::Error,
        payload: Vec::from(msg),
    }
}
/// Turns an `Error` frame payload back into text.
///
/// Invalid UTF-8 sequences are replaced with U+FFFD rather than rejected.
pub fn decode_error(payload: &[u8]) -> String {
    String::from(String::from_utf8_lossy(payload))
}