use std::io;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// Messages exchanged between cluster nodes.
///
/// [`MessageCodec`] encodes these as JSON. No `#[serde(...)]` attributes are present, so
/// serde's default externally tagged representation applies, e.g.
/// `{"Ping":{"sender_id":"n1","seq":1}}`. Renaming a variant or a field therefore changes
/// the wire format seen by peers.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ClusterMessage {
/// A versioned key/value entry announced by the node `sender_id`.
// NOTE(review): presumably receivers keep the entry with the highest `version` per `key`;
// the handling code is not visible here — confirm.
Gossip {
sender_id: String,
key: String,
value: u64,
version: u64,
},
/// Probe sent by the node `sender_id`.
// NOTE(review): `seq` presumably lets the sender match the corresponding `Pong` — confirm.
Ping {
sender_id: String,
seq: u64,
},
/// Reply sent by the node `sender_id`, presumably answering a `Ping` with the same `seq`.
Pong {
sender_id: String,
seq: u64,
},
/// Replication message from `leader_id` for log position `index` in `term`.
// NOTE(review): what `checksum` covers (entry bytes? whole log prefix?) and which algorithm
// produces it are not visible in this file — confirm before relying on it.
Replicate {
leader_id: String,
index: u64,
term: u64,
checksum: u64,
},
/// Response from `follower_id` about log position `index`; `success` reports the outcome.
// NOTE(review): presumably sent in reply to a `Replicate` with the same `index` — confirm.
ReplicateAck {
follower_id: String,
index: u64,
success: bool,
},
}
/// Size, in bytes, of the big-endian `u32` length prefix written before every frame body.
const LEN_PREFIX_BYTES: usize = 4;

/// Length-prefixed JSON framing for [`ClusterMessage`] over tokio byte streams.
///
/// A frame is a big-endian `u32` body length followed by that many bytes of
/// `serde_json`-encoded message. The codec holds no state; all methods are associated
/// functions.
pub struct MessageCodec;

impl MessageCodec {
    /// Largest frame body, in bytes, that [`MessageCodec::write`] will send and
    /// [`MessageCodec::read`] will accept (8 MiB).
    ///
    /// The length prefix is supplied by the peer. Without a cap, one corrupt or hostile
    /// header could make `read` allocate up to 4 GiB before a single body byte arrives.
    pub const MAX_FRAME_LEN: usize = 8 * 1024 * 1024;

    /// Serializes `msg` to JSON, writes it as one length-prefixed frame, then flushes.
    ///
    /// The prefix and body are assembled in a single buffer and passed to one `write_all`
    /// call rather than two.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::InvalidData`] if serialization fails or the encoded body is
    /// larger than [`Self::MAX_FRAME_LEN`]. Any I/O error from `writer` is propagated.
    pub async fn write<W>(writer: &mut W, msg: &ClusterMessage) -> io::Result<()>
    where
        W: AsyncWriteExt + Unpin,
    {
        // Reserve room for the prefix, serialize directly after it, then patch the prefix
        // in once the body length is known. This avoids a second buffer and copy.
        let mut frame = vec![0u8; LEN_PREFIX_BYTES];
        serde_json::to_writer(&mut frame, msg)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
        let body_len = Self::check_frame_len(frame.len() - LEN_PREFIX_BYTES)?;
        // Cannot fail while MAX_FRAME_LEN <= u32::MAX, but stays correct if the cap is raised.
        let prefix = u32::try_from(body_len).map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                "message body exceeds u32::MAX bytes",
            )
        })?;
        frame[..LEN_PREFIX_BYTES].copy_from_slice(&prefix.to_be_bytes());
        writer.write_all(&frame).await?;
        writer.flush().await
    }

    /// Reads one length-prefixed frame from `reader` and deserializes it.
    ///
    /// # Errors
    ///
    /// - [`io::ErrorKind::UnexpectedEof`]: the stream ends before a complete frame is read.
    /// - [`io::ErrorKind::InvalidData`]: the declared body length exceeds
    ///   [`Self::MAX_FRAME_LEN`], or the body does not decode as a [`ClusterMessage`].
    /// - Any other I/O error from `reader` is propagated.
    pub async fn read<R>(reader: &mut R) -> io::Result<ClusterMessage>
    where
        R: AsyncReadExt + Unpin,
    {
        let mut len_buf = [0u8; LEN_PREFIX_BYTES];
        reader.read_exact(&mut len_buf).await?;
        // A u32 fits in usize on 32/64-bit targets. Saturate elsewhere so the limit check
        // rejects the frame instead of silently truncating the length.
        let declared = usize::try_from(u32::from_be_bytes(len_buf)).unwrap_or(usize::MAX);
        // Validate before allocating: this length came off the wire.
        let len = Self::check_frame_len(declared)?;
        let mut body = vec![0u8; len];
        reader.read_exact(&mut body).await?;
        serde_json::from_slice(&body).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
    }

    /// Returns `len` unchanged if it is within [`Self::MAX_FRAME_LEN`], otherwise an
    /// `InvalidData` error describing the oversized frame.
    fn check_frame_len(len: usize) -> io::Result<usize> {
        if len <= Self::MAX_FRAME_LEN {
            Ok(len)
        } else {
            Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "frame body of {len} bytes exceeds the {}-byte limit",
                    Self::MAX_FRAME_LEN
                ),
            ))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{duplex, AsyncWriteExt};

    /// Writes `msg` through the codec into an in-memory pipe and reads it back.
    async fn roundtrip(msg: &ClusterMessage) -> ClusterMessage {
        let (mut client, mut server) = duplex(1024);
        MessageCodec::write(&mut client, msg).await.expect("write");
        MessageCodec::read(&mut server).await.expect("read")
    }

    /// Writes a raw frame (declared length prefix + `body`) without going through the
    /// codec, closes the writing half, then decodes with [`MessageCodec::read`].
    ///
    /// `body` must fit in the pipe buffer, since nothing reads concurrently with the writes.
    async fn read_raw_frame(declared_len: u32, body: &[u8]) -> io::Result<ClusterMessage> {
        let (mut client, mut server) = duplex(1024);
        client
            .write_all(&declared_len.to_be_bytes())
            .await
            .expect("write length prefix");
        client.write_all(body).await.expect("write body");
        // Dropping the writer makes the reader see EOF once the buffered bytes are drained.
        drop(client);
        MessageCodec::read(&mut server).await
    }

    #[tokio::test]
    async fn test_roundtrip_gossip() {
        let msg = ClusterMessage::Gossip {
            sender_id: "node-1".to_owned(),
            key: "alpha".to_owned(),
            value: 99,
            version: 7,
        };
        assert_eq!(msg, roundtrip(&msg).await);
    }

    #[tokio::test]
    async fn test_roundtrip_ping() {
        let msg = ClusterMessage::Ping {
            sender_id: "node-2".to_owned(),
            seq: 42,
        };
        assert_eq!(msg, roundtrip(&msg).await);
    }

    #[tokio::test]
    async fn test_roundtrip_replicate() {
        let msg = ClusterMessage::Replicate {
            leader_id: "leader".to_owned(),
            index: 100,
            term: 3,
            checksum: 0xDEAD_BEEF,
        };
        assert_eq!(msg, roundtrip(&msg).await);
    }

    #[tokio::test]
    async fn test_multiple_messages_in_sequence() {
        let msgs = vec![
            ClusterMessage::Ping {
                sender_id: "a".to_owned(),
                seq: 1,
            },
            ClusterMessage::Pong {
                sender_id: "b".to_owned(),
                seq: 1,
            },
            ClusterMessage::ReplicateAck {
                follower_id: "c".to_owned(),
                index: 5,
                success: true,
            },
        ];
        let (mut client, mut server) = duplex(4096);
        for msg in &msgs {
            MessageCodec::write(&mut client, msg).await.expect("write");
        }
        for expected in &msgs {
            let received = MessageCodec::read(&mut server).await.expect("read");
            assert_eq!(expected, &received);
        }
    }

    #[tokio::test]
    async fn test_eof_before_length_prefix_is_unexpected_eof() {
        let (client, mut server) = duplex(64);
        drop(client);
        let err = MessageCodec::read(&mut server)
            .await
            .expect_err("empty stream must not decode");
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn test_truncated_body_is_unexpected_eof() {
        let err = read_raw_frame(10, b"abc")
            .await
            .expect_err("truncated frame must not decode");
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn test_malformed_body_is_invalid_data() {
        let body = b"not json";
        let len = u32::try_from(body.len()).expect("tiny body fits in u32");
        let err = read_raw_frame(len, body)
            .await
            .expect_err("malformed body must not decode");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }
}