1use bytes::{Buf, BufMut, Bytes, BytesMut};
10use ipfrs_core::Cid;
11use std::io::{self, Cursor};
12use thiserror::Error;
13
/// Highest protocol version this implementation emits and accepts.
///
/// [`BinaryMessage::decode`] rejects any frame whose version byte is greater
/// than this value.
pub const PROTOCOL_VERSION: u8 = 0x01;

/// Four-byte marker that every frame starts with (ASCII `"IPFS"`).
pub const MAGIC: [u8; 4] = [b'I', b'P', b'F', b'S'];

/// Upper bound, in bytes, on an entire encoded frame (header plus payload): 16 MiB.
pub const MAX_MESSAGE_SIZE: usize = 16 << 20;
26
/// Kind of a protocol frame, encoded as the single type byte of the header.
///
/// Request kinds occupy the low range (`0x01..=0x07`); response kinds have the
/// high bit set (`0x80`, `0x81`). `Hash` is derived so the type can key
/// dispatch tables (`HashMap<MessageType, _>`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum MessageType {
    /// Fetch a single block by CID.
    GetBlock = 0x01,
    /// Store a single block.
    PutBlock = 0x02,
    /// Ask whether a block is present.
    HasBlock = 0x03,
    /// Remove a block.
    DeleteBlock = 0x04,
    /// Fetch several blocks in one request.
    BatchGet = 0x05,
    /// Store several blocks in one request.
    BatchPut = 0x06,
    /// Query presence of several blocks in one request.
    BatchHas = 0x07,
    /// Successful response.
    Success = 0x80,
    /// Error response.
    Error = 0x81,
}
54
55impl MessageType {
56 pub fn from_u8(value: u8) -> Result<Self, ProtocolError> {
58 match value {
59 0x01 => Ok(MessageType::GetBlock),
60 0x02 => Ok(MessageType::PutBlock),
61 0x03 => Ok(MessageType::HasBlock),
62 0x04 => Ok(MessageType::DeleteBlock),
63 0x05 => Ok(MessageType::BatchGet),
64 0x06 => Ok(MessageType::BatchPut),
65 0x07 => Ok(MessageType::BatchHas),
66 0x80 => Ok(MessageType::Success),
67 0x81 => Ok(MessageType::Error),
68 _ => Err(ProtocolError::InvalidMessageType(value)),
69 }
70 }
71
72 pub fn to_u8(self) -> u8 {
74 self as u8
75 }
76}
77
78#[derive(Debug, Clone)]
91pub struct BinaryMessage {
92 pub version: u8,
94 pub msg_type: MessageType,
96 pub message_id: u32,
98 pub payload: Bytes,
100}
101
102impl BinaryMessage {
103 pub fn new(msg_type: MessageType, message_id: u32, payload: Bytes) -> Self {
105 Self {
106 version: PROTOCOL_VERSION,
107 msg_type,
108 message_id,
109 payload,
110 }
111 }
112
113 pub fn encode(&self) -> Result<Bytes, ProtocolError> {
115 let total_size = 4 + 1 + 1 + 4 + self.payload.len();
116 if total_size > MAX_MESSAGE_SIZE {
117 return Err(ProtocolError::MessageTooLarge(total_size));
118 }
119
120 let mut buf = BytesMut::with_capacity(total_size);
121
122 buf.put_slice(&MAGIC);
124 buf.put_u8(self.version);
126 buf.put_u8(self.msg_type.to_u8());
128 buf.put_u32(self.message_id);
130 buf.put_slice(&self.payload);
132
133 Ok(buf.freeze())
134 }
135
136 pub fn decode(data: &[u8]) -> Result<Self, ProtocolError> {
138 if data.len() < 10 {
139 return Err(ProtocolError::InvalidMessageSize(data.len()));
140 }
141
142 if data.len() > MAX_MESSAGE_SIZE {
143 return Err(ProtocolError::MessageTooLarge(data.len()));
144 }
145
146 let mut cursor = Cursor::new(data);
147
148 let mut magic = [0u8; 4];
150 cursor.copy_to_slice(&mut magic);
151 if magic != MAGIC {
152 return Err(ProtocolError::InvalidMagic(magic));
153 }
154
155 let version = cursor.get_u8();
157 if version > PROTOCOL_VERSION {
158 return Err(ProtocolError::UnsupportedVersion(version));
159 }
160
161 let msg_type = MessageType::from_u8(cursor.get_u8())?;
163
164 let message_id = cursor.get_u32();
166
167 let position = cursor.position() as usize;
169 let payload = Bytes::copy_from_slice(&data[position..]);
170
171 Ok(Self {
172 version,
173 msg_type,
174 message_id,
175 payload,
176 })
177 }
178}
179
180#[derive(Debug, Clone)]
186pub struct GetBlockRequest {
187 pub cid: Cid,
188}
189
190impl GetBlockRequest {
191 pub fn encode(&self) -> Result<Bytes, ProtocolError> {
193 let cid_bytes = self.cid.to_bytes();
194 let mut buf = BytesMut::with_capacity(cid_bytes.len());
195 buf.put_slice(&cid_bytes);
196 Ok(buf.freeze())
197 }
198
199 pub fn decode(data: &[u8]) -> Result<Self, ProtocolError> {
201 let cid = Cid::try_from(data).map_err(|e| ProtocolError::InvalidCid(e.to_string()))?;
202 Ok(Self { cid })
203 }
204}
205
206#[derive(Debug, Clone)]
208pub struct PutBlockRequest {
209 pub data: Bytes,
210}
211
212impl PutBlockRequest {
213 pub fn encode(&self) -> Result<Bytes, ProtocolError> {
215 Ok(self.data.clone())
216 }
217
218 pub fn decode(data: &[u8]) -> Result<Self, ProtocolError> {
220 Ok(Self {
221 data: Bytes::copy_from_slice(data),
222 })
223 }
224}
225
226#[derive(Debug, Clone)]
228pub struct HasBlockRequest {
229 pub cid: Cid,
230}
231
232impl HasBlockRequest {
233 pub fn encode(&self) -> Result<Bytes, ProtocolError> {
235 let cid_bytes = self.cid.to_bytes();
236 let mut buf = BytesMut::with_capacity(cid_bytes.len());
237 buf.put_slice(&cid_bytes);
238 Ok(buf.freeze())
239 }
240
241 pub fn decode(data: &[u8]) -> Result<Self, ProtocolError> {
243 let cid = Cid::try_from(data).map_err(|e| ProtocolError::InvalidCid(e.to_string()))?;
244 Ok(Self { cid })
245 }
246}
247
248#[derive(Debug, Clone)]
250pub struct BatchGetRequest {
251 pub cids: Vec<Cid>,
252}
253
254impl BatchGetRequest {
255 pub fn encode(&self) -> Result<Bytes, ProtocolError> {
257 let mut buf = BytesMut::new();
258
259 buf.put_u32(self.cids.len() as u32);
261
262 for cid in &self.cids {
264 let cid_bytes = cid.to_bytes();
265 buf.put_u16(cid_bytes.len() as u16);
266 buf.put_slice(&cid_bytes);
267 }
268
269 Ok(buf.freeze())
270 }
271
272 pub fn decode(data: &[u8]) -> Result<Self, ProtocolError> {
274 let mut cursor = Cursor::new(data);
275
276 if cursor.remaining() < 4 {
278 return Err(ProtocolError::InvalidMessageSize(cursor.remaining()));
279 }
280 let count = cursor.get_u32() as usize;
281
282 let mut cids = Vec::with_capacity(count);
283
284 for _ in 0..count {
286 if cursor.remaining() < 2 {
287 return Err(ProtocolError::InvalidMessageSize(cursor.remaining()));
288 }
289 let len = cursor.get_u16() as usize;
290
291 if cursor.remaining() < len {
292 return Err(ProtocolError::InvalidMessageSize(cursor.remaining()));
293 }
294
295 let position = cursor.position() as usize;
296 let cid_data = &data[position..position + len];
297 let cid =
298 Cid::try_from(cid_data).map_err(|e| ProtocolError::InvalidCid(e.to_string()))?;
299 cids.push(cid);
300 cursor.set_position((position + len) as u64);
301 }
302
303 Ok(Self { cids })
304 }
305}
306
307#[derive(Debug, Clone)]
309pub struct SuccessResponse {
310 pub data: Bytes,
311}
312
313impl SuccessResponse {
314 pub fn encode(&self) -> Result<Bytes, ProtocolError> {
316 Ok(self.data.clone())
317 }
318
319 pub fn decode(data: &[u8]) -> Result<Self, ProtocolError> {
321 Ok(Self {
322 data: Bytes::copy_from_slice(data),
323 })
324 }
325}
326
327#[derive(Debug, Clone)]
329pub struct ErrorResponse {
330 pub error_code: u16,
331 pub message: String,
332}
333
334impl ErrorResponse {
335 pub fn encode(&self) -> Result<Bytes, ProtocolError> {
337 let message_bytes = self.message.as_bytes();
338 let mut buf = BytesMut::with_capacity(2 + 2 + message_bytes.len());
339
340 buf.put_u16(self.error_code);
342 buf.put_u16(message_bytes.len() as u16);
344 buf.put_slice(message_bytes);
346
347 Ok(buf.freeze())
348 }
349
350 pub fn decode(data: &[u8]) -> Result<Self, ProtocolError> {
352 let mut cursor = Cursor::new(data);
353
354 if cursor.remaining() < 4 {
355 return Err(ProtocolError::InvalidMessageSize(cursor.remaining()));
356 }
357
358 let error_code = cursor.get_u16();
359 let message_len = cursor.get_u16() as usize;
360
361 if cursor.remaining() < message_len {
362 return Err(ProtocolError::InvalidMessageSize(cursor.remaining()));
363 }
364
365 let position = cursor.position() as usize;
366 let message_bytes = &data[position..position + message_len];
367 let message = String::from_utf8(message_bytes.to_vec())
368 .map_err(|e| ProtocolError::InvalidUtf8(e.to_string()))?;
369
370 Ok(Self {
371 error_code,
372 message,
373 })
374 }
375}
376
377#[derive(Debug, Error)]
383pub enum ProtocolError {
384 #[error("Invalid magic bytes: {0:?}")]
385 InvalidMagic([u8; 4]),
386
387 #[error("Unsupported protocol version: {0}")]
388 UnsupportedVersion(u8),
389
390 #[error("Invalid message type: {0}")]
391 InvalidMessageType(u8),
392
393 #[error("Invalid message size: {0}")]
394 InvalidMessageSize(usize),
395
396 #[error("Message too large: {0} bytes")]
397 MessageTooLarge(usize),
398
399 #[error("Invalid CID: {0}")]
400 InvalidCid(String),
401
402 #[error("Invalid UTF-8: {0}")]
403 InvalidUtf8(String),
404
405 #[error("IO error: {0}")]
406 Io(#[from] io::Error),
407}
408
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_type_conversion() {
        assert_eq!(MessageType::from_u8(0x01).unwrap(), MessageType::GetBlock);
        assert_eq!(MessageType::GetBlock.to_u8(), 0x01);
        assert!(MessageType::from_u8(0xFF).is_err());
    }

    #[test]
    fn test_message_type_round_trip_all_variants() {
        let all = [
            MessageType::GetBlock,
            MessageType::PutBlock,
            MessageType::HasBlock,
            MessageType::DeleteBlock,
            MessageType::BatchGet,
            MessageType::BatchPut,
            MessageType::BatchHas,
            MessageType::Success,
            MessageType::Error,
        ];
        for ty in all {
            assert_eq!(MessageType::from_u8(ty.to_u8()).unwrap(), ty);
        }
    }

    #[test]
    fn test_binary_message_encode_decode() {
        let payload = Bytes::from("test payload");
        let msg = BinaryMessage::new(MessageType::GetBlock, 42, payload.clone());

        let encoded = msg.encode().unwrap();
        let decoded = BinaryMessage::decode(&encoded).unwrap();

        assert_eq!(decoded.version, PROTOCOL_VERSION);
        assert_eq!(decoded.msg_type, MessageType::GetBlock);
        assert_eq!(decoded.message_id, 42);
        assert_eq!(decoded.payload, payload);
    }

    #[test]
    fn test_empty_payload_round_trip() {
        let msg = BinaryMessage::new(MessageType::HasBlock, 7, Bytes::new());

        let encoded = msg.encode().unwrap();
        // Header only: magic (4) + version (1) + type (1) + id (4).
        assert_eq!(encoded.len(), 10);

        let decoded = BinaryMessage::decode(&encoded).unwrap();
        assert_eq!(decoded.msg_type, MessageType::HasBlock);
        assert_eq!(decoded.message_id, 7);
        assert!(decoded.payload.is_empty());
    }

    #[test]
    fn test_message_too_large() {
        let large_payload = Bytes::from(vec![0u8; MAX_MESSAGE_SIZE]);
        let msg = BinaryMessage::new(MessageType::GetBlock, 1, large_payload);
        assert!(msg.encode().is_err());
    }

    #[test]
    fn test_decode_rejects_short_input() {
        assert!(matches!(
            BinaryMessage::decode(&[0u8; 9]),
            Err(ProtocolError::InvalidMessageSize(9))
        ));
    }

    #[test]
    fn test_invalid_magic() {
        let data = vec![0xFF, 0xFF, 0xFF, 0xFF, 1, 1, 0, 0, 0, 42];
        let result = BinaryMessage::decode(&data);
        assert!(matches!(
            result,
            Err(ProtocolError::InvalidMagic([0xFF, 0xFF, 0xFF, 0xFF]))
        ));
    }

    #[test]
    fn test_decode_rejects_unknown_message_type() {
        let mut data = MAGIC.to_vec();
        data.extend_from_slice(&[PROTOCOL_VERSION, 0x42, 0, 0, 0, 1]);
        assert!(matches!(
            BinaryMessage::decode(&data),
            Err(ProtocolError::InvalidMessageType(0x42))
        ));
    }

    #[test]
    fn test_batch_get_request_encode_decode() {
        use ipfrs_core::Block;

        let block1 = Block::new(Bytes::from("test data 1")).unwrap();
        let block2 = Block::new(Bytes::from("test data 2")).unwrap();
        let cid1 = *block1.cid();
        let cid2 = *block2.cid();

        let request = BatchGetRequest {
            cids: vec![cid1, cid2],
        };

        let encoded = request.encode().unwrap();
        let decoded = BatchGetRequest::decode(&encoded).unwrap();

        assert_eq!(decoded.cids.len(), 2);
        assert_eq!(decoded.cids[0], cid1);
        assert_eq!(decoded.cids[1], cid2);
    }

    #[test]
    fn test_batch_get_request_empty() {
        let decoded = BatchGetRequest::decode(&[0, 0, 0, 0]).unwrap();
        assert!(decoded.cids.is_empty());
    }

    #[test]
    fn test_batch_get_request_truncated() {
        // Too short to hold the count.
        assert!(matches!(
            BatchGetRequest::decode(&[0, 0]),
            Err(ProtocolError::InvalidMessageSize(2))
        ));
        // Declares one entry but carries no length prefix.
        assert!(matches!(
            BatchGetRequest::decode(&[0, 0, 0, 1]),
            Err(ProtocolError::InvalidMessageSize(0))
        ));
    }

    #[test]
    fn test_error_response_encode_decode() {
        let response = ErrorResponse {
            error_code: 404,
            message: "Block not found".to_string(),
        };

        let encoded = response.encode().unwrap();
        let decoded = ErrorResponse::decode(&encoded).unwrap();

        assert_eq!(decoded.error_code, 404);
        assert_eq!(decoded.message, "Block not found");
    }

    #[test]
    fn test_error_response_truncated_message() {
        // Declares a 5-byte message but carries only 1 byte.
        assert!(matches!(
            ErrorResponse::decode(&[0x01, 0x94, 0x00, 0x05, b'a']),
            Err(ProtocolError::InvalidMessageSize(1))
        ));
    }

    #[test]
    fn test_error_response_invalid_utf8() {
        assert!(matches!(
            ErrorResponse::decode(&[0, 1, 0, 1, 0xFF]),
            Err(ProtocolError::InvalidUtf8(_))
        ));
    }

    #[test]
    fn test_protocol_versioning() {
        let payload = Bytes::from("test");
        let mut msg = BinaryMessage::new(MessageType::GetBlock, 1, payload);

        msg.version = PROTOCOL_VERSION;
        let encoded = msg.encode().unwrap();
        assert!(BinaryMessage::decode(&encoded).is_ok());

        msg.version = PROTOCOL_VERSION + 1;
        let encoded = msg.encode().unwrap();
        assert!(BinaryMessage::decode(&encoded).is_err());
    }
}