use bytes::{Buf, BufMut, Bytes, BytesMut};
use ipfrs_core::Cid;
use std::io::{self, Cursor};
use thiserror::Error;
/// Protocol version stamped into the header of newly built messages.
pub const PROTOCOL_VERSION: u8 = 1;
/// Four-byte marker (ASCII `IPFS`) that opens every encoded frame.
pub const MAGIC: [u8; 4] = [b'I', b'P', b'F', b'S'];
/// Largest permitted encoded frame, header included, in bytes (16 MiB).
pub const MAX_MESSAGE_SIZE: usize = 16 << 20;
/// One-byte tag written into a frame header to say what kind of message it is.
///
/// Each discriminant is the exact on-the-wire value; [`MessageType::from_u8`]
/// and [`MessageType::to_u8`] convert in each direction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum MessageType {
    // Request-style tags.
    GetBlock = 0x01,
    PutBlock = 0x02,
    HasBlock = 0x03,
    DeleteBlock = 0x04,
    BatchGet = 0x05,
    BatchPut = 0x06,
    BatchHas = 0x07,
    // Reply tags.
    Success = 0x80,
    Error = 0x81,
}
impl MessageType {
    /// Maps a wire byte back to its message type.
    ///
    /// # Errors
    /// Returns `ProtocolError::InvalidMessageType` carrying `value` when the byte
    /// is not one of the discriminants listed on the enum.
    pub fn from_u8(value: u8) -> Result<Self, ProtocolError> {
        let kind = match value {
            0x01 => Self::GetBlock,
            0x02 => Self::PutBlock,
            0x03 => Self::HasBlock,
            0x04 => Self::DeleteBlock,
            0x05 => Self::BatchGet,
            0x06 => Self::BatchPut,
            0x07 => Self::BatchHas,
            0x80 => Self::Success,
            0x81 => Self::Error,
            unknown => return Err(ProtocolError::InvalidMessageType(unknown)),
        };
        Ok(kind)
    }

    /// Returns the wire byte for this type (its `repr(u8)` discriminant).
    pub fn to_u8(self) -> u8 {
        self as u8
    }
}
#[derive(Debug, Clone)]
pub struct BinaryMessage {
pub version: u8,
pub msg_type: MessageType,
pub message_id: u32,
pub payload: Bytes,
}
impl BinaryMessage {
pub fn new(msg_type: MessageType, message_id: u32, payload: Bytes) -> Self {
Self {
version: PROTOCOL_VERSION,
msg_type,
message_id,
payload,
}
}
pub fn encode(&self) -> Result<Bytes, ProtocolError> {
let total_size = 4 + 1 + 1 + 4 + self.payload.len();
if total_size > MAX_MESSAGE_SIZE {
return Err(ProtocolError::MessageTooLarge(total_size));
}
let mut buf = BytesMut::with_capacity(total_size);
buf.put_slice(&MAGIC);
buf.put_u8(self.version);
buf.put_u8(self.msg_type.to_u8());
buf.put_u32(self.message_id);
buf.put_slice(&self.payload);
Ok(buf.freeze())
}
pub fn decode(data: &[u8]) -> Result<Self, ProtocolError> {
if data.len() < 10 {
return Err(ProtocolError::InvalidMessageSize(data.len()));
}
if data.len() > MAX_MESSAGE_SIZE {
return Err(ProtocolError::MessageTooLarge(data.len()));
}
let mut cursor = Cursor::new(data);
let mut magic = [0u8; 4];
cursor.copy_to_slice(&mut magic);
if magic != MAGIC {
return Err(ProtocolError::InvalidMagic(magic));
}
let version = cursor.get_u8();
if version > PROTOCOL_VERSION {
return Err(ProtocolError::UnsupportedVersion(version));
}
let msg_type = MessageType::from_u8(cursor.get_u8())?;
let message_id = cursor.get_u32();
let position = cursor.position() as usize;
let payload = Bytes::copy_from_slice(&data[position..]);
Ok(Self {
version,
msg_type,
message_id,
payload,
})
}
}
#[derive(Debug, Clone)]
pub struct GetBlockRequest {
pub cid: Cid,
}
impl GetBlockRequest {
pub fn encode(&self) -> Result<Bytes, ProtocolError> {
let cid_bytes = self.cid.to_bytes();
let mut buf = BytesMut::with_capacity(cid_bytes.len());
buf.put_slice(&cid_bytes);
Ok(buf.freeze())
}
pub fn decode(data: &[u8]) -> Result<Self, ProtocolError> {
let cid = Cid::try_from(data).map_err(|e| ProtocolError::InvalidCid(e.to_string()))?;
Ok(Self { cid })
}
}
/// Payload carrying raw block data to be stored.
#[derive(Debug, Clone)]
pub struct PutBlockRequest {
    /// The block's bytes, sent verbatim.
    pub data: Bytes,
}
impl PutBlockRequest {
    /// Encodes the request as the raw block bytes themselves.
    pub fn encode(&self) -> Result<Bytes, ProtocolError> {
        Ok(Bytes::clone(&self.data))
    }

    /// Takes an owned copy of `data` as the block contents; never fails.
    pub fn decode(data: &[u8]) -> Result<Self, ProtocolError> {
        let owned = Bytes::from(data.to_vec());
        Ok(Self { data: owned })
    }
}
#[derive(Debug, Clone)]
pub struct HasBlockRequest {
pub cid: Cid,
}
impl HasBlockRequest {
pub fn encode(&self) -> Result<Bytes, ProtocolError> {
let cid_bytes = self.cid.to_bytes();
let mut buf = BytesMut::with_capacity(cid_bytes.len());
buf.put_slice(&cid_bytes);
Ok(buf.freeze())
}
pub fn decode(data: &[u8]) -> Result<Self, ProtocolError> {
let cid = Cid::try_from(data).map_err(|e| ProtocolError::InvalidCid(e.to_string()))?;
Ok(Self { cid })
}
}
/// Payload asking for several blocks at once.
#[derive(Debug, Clone)]
pub struct BatchGetRequest {
    /// CIDs of the requested blocks, in request order.
    pub cids: Vec<Cid>,
}
impl BatchGetRequest {
    /// Encodes as `count: u32` followed by `count` entries of `len: u16 | cid bytes`,
    /// with integers big-endian.
    ///
    /// # Errors
    /// - `ProtocolError::MessageTooLarge` (with the CID count) if there are more
    ///   than `u32::MAX` CIDs.
    /// - `ProtocolError::InvalidCid` if a CID's binary form is longer than
    ///   `u16::MAX` bytes.
    pub fn encode(&self) -> Result<Bytes, ProtocolError> {
        // Both length fields are checked rather than `as`-cast: a silent truncation
        // would produce a frame whose prefixes disagree with its contents.
        let count = u32::try_from(self.cids.len())
            .map_err(|_| ProtocolError::MessageTooLarge(self.cids.len()))?;
        let mut buf = BytesMut::new();
        buf.put_u32(count);
        for cid in &self.cids {
            let cid_bytes = cid.to_bytes();
            let len = u16::try_from(cid_bytes.len()).map_err(|_| {
                ProtocolError::InvalidCid(format!(
                    "encoded CID is {} bytes; the length prefix allows at most {}",
                    cid_bytes.len(),
                    u16::MAX
                ))
            })?;
            buf.put_u16(len);
            buf.put_slice(&cid_bytes);
        }
        Ok(buf.freeze())
    }

    /// Parses the layout written by [`BatchGetRequest::encode`].
    ///
    /// Bytes after the last declared entry are ignored.
    ///
    /// # Errors
    /// - `ProtocolError::InvalidMessageSize` (with the remaining byte count) if the
    ///   count, a length prefix, or a CID body is truncated.
    /// - `ProtocolError::InvalidCid` if an entry does not parse as a CID.
    pub fn decode(data: &[u8]) -> Result<Self, ProtocolError> {
        // Cap on the up-front reservation; the vector still grows normally if more
        // entries actually parse.
        const MAX_PREALLOCATED: usize = 1024;

        let mut cursor = Cursor::new(data);
        if cursor.remaining() < 4 {
            return Err(ProtocolError::InvalidMessageSize(cursor.remaining()));
        }
        let count = cursor.get_u32() as usize;
        // `count` comes straight off the wire, so it must not size the allocation
        // on its own: a 4-byte payload could otherwise demand billions of slots.
        // Each entry needs at least its 2-byte length prefix, so the remaining input
        // bounds how many entries can genuinely follow.
        let mut cids =
            Vec::with_capacity(count.min(cursor.remaining() / 2).min(MAX_PREALLOCATED));
        for _ in 0..count {
            if cursor.remaining() < 2 {
                return Err(ProtocolError::InvalidMessageSize(cursor.remaining()));
            }
            let len = cursor.get_u16() as usize;
            if cursor.remaining() < len {
                return Err(ProtocolError::InvalidMessageSize(cursor.remaining()));
            }
            let position = cursor.position() as usize;
            let cid_data = &data[position..position + len];
            let cid =
                Cid::try_from(cid_data).map_err(|e| ProtocolError::InvalidCid(e.to_string()))?;
            cids.push(cid);
            cursor.set_position((position + len) as u64);
        }
        Ok(Self { cids })
    }
}
/// Reply payload for a successful operation; its contents are passed through untouched.
#[derive(Debug, Clone)]
pub struct SuccessResponse {
    /// Result bytes, sent verbatim.
    pub data: Bytes,
}
impl SuccessResponse {
    /// Encodes the response as its data bytes with no framing.
    pub fn encode(&self) -> Result<Bytes, ProtocolError> {
        let body = Bytes::clone(&self.data);
        Ok(body)
    }

    /// Takes an owned copy of `data` as the response body; never fails.
    pub fn decode(data: &[u8]) -> Result<Self, ProtocolError> {
        Ok(SuccessResponse {
            data: data.to_vec().into(),
        })
    }
}
/// Reply payload describing a failure with a numeric code and a text message.
#[derive(Debug, Clone)]
pub struct ErrorResponse {
    /// Numeric error code, sent as a big-endian `u16`.
    pub error_code: u16,
    /// Human-readable description, sent as UTF-8.
    pub message: String,
}
impl ErrorResponse {
    /// Encodes as `error_code: u16 | message_len: u16 | message bytes`, with
    /// integers big-endian.
    ///
    /// # Errors
    /// `ProtocolError::MessageTooLarge` (with the message's byte length) if the
    /// UTF-8 message is longer than `u16::MAX` bytes and so cannot be described
    /// by the 16-bit length prefix.
    pub fn encode(&self) -> Result<Bytes, ProtocolError> {
        let message_bytes = self.message.as_bytes();
        // An `as u16` cast would silently wrap, yielding a frame whose length prefix
        // disagrees with its body. The decoder would then truncate the message or
        // reject it as invalid UTF-8.
        let message_len = u16::try_from(message_bytes.len())
            .map_err(|_| ProtocolError::MessageTooLarge(message_bytes.len()))?;
        let mut buf = BytesMut::with_capacity(2 + 2 + message_bytes.len());
        buf.put_u16(self.error_code);
        buf.put_u16(message_len);
        buf.put_slice(message_bytes);
        Ok(buf.freeze())
    }

    /// Parses the layout written by [`ErrorResponse::encode`].
    ///
    /// Bytes after the declared message are ignored.
    ///
    /// # Errors
    /// - `ProtocolError::InvalidMessageSize` (with the remaining byte count) if the
    ///   4-byte header or the declared message body is truncated.
    /// - `ProtocolError::InvalidUtf8` if the message bytes are not valid UTF-8.
    pub fn decode(data: &[u8]) -> Result<Self, ProtocolError> {
        let mut cursor = Cursor::new(data);
        if cursor.remaining() < 4 {
            return Err(ProtocolError::InvalidMessageSize(cursor.remaining()));
        }
        let error_code = cursor.get_u16();
        let message_len = cursor.get_u16() as usize;
        if cursor.remaining() < message_len {
            return Err(ProtocolError::InvalidMessageSize(cursor.remaining()));
        }
        let position = cursor.position() as usize;
        let message_bytes = &data[position..position + message_len];
        // Validate in place; allocate only once the bytes are known to be UTF-8.
        let message = std::str::from_utf8(message_bytes)
            .map_err(|e| ProtocolError::InvalidUtf8(e.to_string()))?
            .to_owned();
        Ok(Self {
            error_code,
            message,
        })
    }
}
/// Errors produced while encoding or decoding protocol messages and payloads.
#[derive(Debug, Error)]
pub enum ProtocolError {
    /// The first four bytes of a frame were not [`MAGIC`]; carries the bytes found.
    #[error("Invalid magic bytes: {0:?}")]
    InvalidMagic([u8; 4]),
    /// The frame header's version is newer than [`PROTOCOL_VERSION`].
    #[error("Unsupported protocol version: {0}")]
    UnsupportedVersion(u8),
    /// The type byte does not correspond to any [`MessageType`] variant.
    #[error("Invalid message type: {0}")]
    InvalidMessageType(u8),
    /// The input was too short for the field being read; carries the relevant
    /// byte count (total length or bytes remaining, depending on the call site).
    #[error("Invalid message size: {0}")]
    InvalidMessageSize(usize),
    /// An encoded or received frame exceeded the size limit; carries its length in bytes.
    #[error("Message too large: {0} bytes")]
    MessageTooLarge(usize),
    /// Bytes could not be parsed as a CID; carries the parser's error text.
    #[error("Invalid CID: {0}")]
    InvalidCid(String),
    /// A text field (e.g. an `ErrorResponse` message) was not valid UTF-8.
    #[error("Invalid UTF-8: {0}")]
    InvalidUtf8(String),
    /// Wraps a `std::io::Error`; `#[from]` lets `?` convert one automatically.
    #[error("IO error: {0}")]
    Io(#[from] io::Error),
}
#[cfg(test)]
mod tests {
    use super::*;
    use ipfrs_core::Block;

    /// Returns the CID of a block holding `content`.
    fn cid_for(content: &'static str) -> Cid {
        *Block::new(Bytes::from(content)).unwrap().cid()
    }

    #[test]
    fn test_message_type_conversion() {
        assert_eq!(MessageType::from_u8(0x01).unwrap(), MessageType::GetBlock);
        assert_eq!(MessageType::GetBlock.to_u8(), 0x01);
        assert!(matches!(
            MessageType::from_u8(0xFF),
            Err(ProtocolError::InvalidMessageType(0xFF))
        ));
    }

    #[test]
    fn test_message_type_round_trips_every_variant() {
        let all = [
            MessageType::GetBlock,
            MessageType::PutBlock,
            MessageType::HasBlock,
            MessageType::DeleteBlock,
            MessageType::BatchGet,
            MessageType::BatchPut,
            MessageType::BatchHas,
            MessageType::Success,
            MessageType::Error,
        ];
        for &ty in &all {
            assert_eq!(MessageType::from_u8(ty.to_u8()).unwrap(), ty);
        }
    }

    #[test]
    fn test_binary_message_encode_decode() {
        let payload = Bytes::from("test payload");
        let msg = BinaryMessage::new(MessageType::GetBlock, 42, payload.clone());
        let encoded = msg.encode().unwrap();
        let decoded = BinaryMessage::decode(&encoded).unwrap();
        assert_eq!(decoded.version, PROTOCOL_VERSION);
        assert_eq!(decoded.msg_type, MessageType::GetBlock);
        assert_eq!(decoded.message_id, 42);
        assert_eq!(decoded.payload, payload);
    }

    #[test]
    fn test_message_too_large() {
        let large_payload = Bytes::from(vec![0u8; MAX_MESSAGE_SIZE]);
        let msg = BinaryMessage::new(MessageType::GetBlock, 1, large_payload);
        // The 10-byte header pushes the frame just past the limit.
        assert!(matches!(
            msg.encode(),
            Err(ProtocolError::MessageTooLarge(n)) if n == MAX_MESSAGE_SIZE + 10
        ));
    }

    #[test]
    fn test_invalid_magic() {
        let data = vec![0xFF, 0xFF, 0xFF, 0xFF, 1, 1, 0, 0, 0, 42];
        let result = BinaryMessage::decode(&data);
        assert!(matches!(
            result,
            Err(ProtocolError::InvalidMagic([0xFF, 0xFF, 0xFF, 0xFF]))
        ));
    }

    #[test]
    fn test_decode_rejects_input_shorter_than_header() {
        assert!(matches!(
            BinaryMessage::decode(&[0u8; 9]),
            Err(ProtocolError::InvalidMessageSize(9))
        ));
    }

    #[test]
    fn test_decode_rejects_unknown_message_type() {
        let mut data = MAGIC.to_vec();
        data.extend_from_slice(&[PROTOCOL_VERSION, 0xFF, 0, 0, 0, 1]);
        assert!(matches!(
            BinaryMessage::decode(&data),
            Err(ProtocolError::InvalidMessageType(0xFF))
        ));
    }

    #[test]
    fn test_single_cid_requests_round_trip() {
        let cid = cid_for("single cid");
        let get = GetBlockRequest { cid };
        assert_eq!(GetBlockRequest::decode(&get.encode().unwrap()).unwrap().cid, cid);
        let has = HasBlockRequest { cid };
        assert_eq!(HasBlockRequest::decode(&has.encode().unwrap()).unwrap().cid, cid);
    }

    #[test]
    fn test_opaque_payloads_round_trip() {
        let put = PutBlockRequest {
            data: Bytes::from("block bytes"),
        };
        assert_eq!(
            PutBlockRequest::decode(&put.encode().unwrap()).unwrap().data,
            put.data
        );
        let ok = SuccessResponse { data: Bytes::new() };
        assert!(SuccessResponse::decode(&ok.encode().unwrap())
            .unwrap()
            .data
            .is_empty());
    }

    #[test]
    fn test_batch_get_request_encode_decode() {
        let cid1 = cid_for("test data 1");
        let cid2 = cid_for("test data 2");
        let request = BatchGetRequest {
            cids: vec![cid1, cid2],
        };
        let encoded = request.encode().unwrap();
        let decoded = BatchGetRequest::decode(&encoded).unwrap();
        assert_eq!(decoded.cids.len(), 2);
        assert_eq!(decoded.cids[0], cid1);
        assert_eq!(decoded.cids[1], cid2);
    }

    #[test]
    fn test_batch_get_request_empty() {
        let encoded = BatchGetRequest { cids: Vec::new() }.encode().unwrap();
        assert_eq!(encoded, Bytes::from_static(&[0, 0, 0, 0]));
        assert!(BatchGetRequest::decode(&encoded).unwrap().cids.is_empty());
    }

    #[test]
    fn test_batch_get_request_rejects_truncated_input() {
        // Too short to hold the 4-byte count.
        assert!(matches!(
            BatchGetRequest::decode(&[0, 0]),
            Err(ProtocolError::InvalidMessageSize(2))
        ));
        // Declares one CID but provides no length prefix for it.
        assert!(matches!(
            BatchGetRequest::decode(&[0, 0, 0, 1]),
            Err(ProtocolError::InvalidMessageSize(0))
        ));
        // Declares a 5-byte CID but provides only one byte of it.
        assert!(matches!(
            BatchGetRequest::decode(&[0, 0, 0, 1, 0, 5, 0xAA]),
            Err(ProtocolError::InvalidMessageSize(1))
        ));
    }

    #[test]
    fn test_error_response_encode_decode() {
        let response = ErrorResponse {
            error_code: 404,
            message: "Block not found".to_string(),
        };
        let encoded = response.encode().unwrap();
        let decoded = ErrorResponse::decode(&encoded).unwrap();
        assert_eq!(decoded.error_code, 404);
        assert_eq!(decoded.message, "Block not found");
    }

    #[test]
    fn test_error_response_rejects_invalid_utf8() {
        let data: [u8; 6] = [0x01, 0x94, 0x00, 0x02, 0xFF, 0xFE];
        assert!(matches!(
            ErrorResponse::decode(&data),
            Err(ProtocolError::InvalidUtf8(_))
        ));
    }

    #[test]
    fn test_error_response_rejects_truncated_message() {
        // Declares a 5-byte message but provides only one byte.
        let data: [u8; 5] = [0x00, 0x01, 0x00, 0x05, b'a'];
        assert!(matches!(
            ErrorResponse::decode(&data),
            Err(ProtocolError::InvalidMessageSize(1))
        ));
    }

    #[test]
    fn test_protocol_versioning() {
        let payload = Bytes::from("test");
        let mut msg = BinaryMessage::new(MessageType::GetBlock, 1, payload);
        msg.version = PROTOCOL_VERSION;
        let encoded = msg.encode().unwrap();
        assert!(BinaryMessage::decode(&encoded).is_ok());
        msg.version = PROTOCOL_VERSION + 1;
        let encoded = msg.encode().unwrap();
        assert!(matches!(
            BinaryMessage::decode(&encoded),
            Err(ProtocolError::UnsupportedVersion(v)) if v == PROTOCOL_VERSION + 1
        ));
    }
}