ipfrs_transport/
messages.rs

//! Protocol message definitions for TensorSwap/Bitswap
//!
//! Defines the block-exchange messages, modeled on IPFS Bitswap, with
//! TensorSwap extensions. Messages are encoded with oxicode (see
//! `Message::to_bytes`), so the byte format is not the Bitswap protobuf
//! wire format.
//!
//! # Example
//!
//! ```
//! use ipfrs_transport::messages::{Message, WantEntry};
//! use multihash::Multihash;
//! use cid::Cid;
//!
//! // Create a test CID
//! let hash = Multihash::wrap(0x12, &[0u8; 32]).unwrap();
//! let cid = Cid::new_v1(0x55, hash);
//!
//! // Create a want list message
//! let want_entry = WantEntry::with_priority(cid, 10);
//! let message = Message::want_list(vec![want_entry], false);
//!
//! // Serialize to bytes
//! let bytes = message.to_bytes().unwrap();
//!
//! // Deserialize back
//! let decoded = Message::from_bytes(&bytes).unwrap();
//!
//! // Verify roundtrip
//! match decoded {
//!     Message::WantList(wl) => {
//!         assert_eq!(wl.entries.len(), 1);
//!         assert_eq!(wl.entries[0].priority, 10);
//!     }
//!     _ => panic!("Expected WantList message"),
//! }
//! ```
36
37use ipfrs_core::Cid;
38use serde::{Deserialize, Deserializer, Serialize, Serializer};
39
40/// Serialize CID as string
41fn serialize_cid<S>(cid: &Cid, serializer: S) -> Result<S::Ok, S::Error>
42where
43    S: Serializer,
44{
45    serializer.serialize_str(&cid.to_string())
46}
47
48/// Deserialize CID from string
49fn deserialize_cid<'de, D>(deserializer: D) -> Result<Cid, D::Error>
50where
51    D: Deserializer<'de>,
52{
53    let s = String::deserialize(deserializer)?;
54    s.parse().map_err(serde::de::Error::custom)
55}
56
57/// Message type for block exchange protocol
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub enum Message {
60    /// Want list - request blocks from peers
61    WantList(WantList),
62    /// Block data response
63    Block(BlockMessage),
64    /// Notification that peer has a block
65    Have(HaveMessage),
66    /// Notification that peer doesn't have a block
67    DontHave(DontHaveMessage),
68    /// Cancel a previous want
69    Cancel(CancelMessage),
70}
71
72/// Want list containing block requests
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct WantList {
75    /// List of wanted blocks
76    pub entries: Vec<WantEntry>,
77    /// Whether this is a full want list or incremental update
78    pub full: bool,
79}
80
81/// Entry in a want list
82#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
83pub struct WantEntry {
84    /// CID of wanted block
85    #[serde(serialize_with = "serialize_cid", deserialize_with = "deserialize_cid")]
86    pub cid: Cid,
87    /// Priority (higher = more important)
88    pub priority: i32,
89    /// Whether to send the block or just confirmation
90    pub send_dont_have: bool,
91    /// Cancel this want
92    pub cancel: bool,
93}
94
95impl WantEntry {
96    /// Create a new want entry with default priority
97    pub fn new(cid: Cid) -> Self {
98        Self {
99            cid,
100            priority: 0,
101            send_dont_have: false,
102            cancel: false,
103        }
104    }
105
106    /// Create a want entry with specific priority
107    pub fn with_priority(cid: Cid, priority: i32) -> Self {
108        Self {
109            cid,
110            priority,
111            send_dont_have: false,
112            cancel: false,
113        }
114    }
115
116    /// Create a cancel entry
117    pub fn cancel(cid: Cid) -> Self {
118        Self {
119            cid,
120            priority: 0,
121            send_dont_have: false,
122            cancel: true,
123        }
124    }
125}
126
127/// Block message containing block data
128#[derive(Debug, Clone, Serialize, Deserialize)]
129pub struct BlockMessage {
130    /// CID of the block
131    #[serde(serialize_with = "serialize_cid", deserialize_with = "deserialize_cid")]
132    pub cid: Cid,
133    /// Block data
134    pub data: Vec<u8>,
135}
136
137/// Have message - notify peer we have a block
138#[derive(Debug, Clone, Serialize, Deserialize)]
139pub struct HaveMessage {
140    /// CID of the block we have
141    #[serde(serialize_with = "serialize_cid", deserialize_with = "deserialize_cid")]
142    pub cid: Cid,
143}
144
145/// Don't have message - notify peer we don't have a block
146#[derive(Debug, Clone, Serialize, Deserialize)]
147pub struct DontHaveMessage {
148    /// CID of the block we don't have
149    #[serde(serialize_with = "serialize_cid", deserialize_with = "deserialize_cid")]
150    pub cid: Cid,
151}
152
153/// Cancel message - cancel a previous want
154#[derive(Debug, Clone, Serialize, Deserialize)]
155pub struct CancelMessage {
156    /// CID to cancel
157    #[serde(serialize_with = "serialize_cid", deserialize_with = "deserialize_cid")]
158    pub cid: Cid,
159}
160
161impl Message {
162    /// Create a want list message
163    pub fn want_list(entries: Vec<WantEntry>, full: bool) -> Self {
164        Message::WantList(WantList { entries, full })
165    }
166
167    /// Create a block message
168    pub fn block(cid: Cid, data: Vec<u8>) -> Self {
169        Message::Block(BlockMessage { cid, data })
170    }
171
172    /// Create a have message
173    pub fn have(cid: Cid) -> Self {
174        Message::Have(HaveMessage { cid })
175    }
176
177    /// Create a don't have message
178    pub fn dont_have(cid: Cid) -> Self {
179        Message::DontHave(DontHaveMessage { cid })
180    }
181
182    /// Create a cancel message
183    pub fn cancel(cid: Cid) -> Self {
184        Message::Cancel(CancelMessage { cid })
185    }
186
187    /// Serialize message to bytes
188    pub fn to_bytes(&self) -> Result<Vec<u8>, oxicode::Error> {
189        oxicode::serde::encode_to_vec(self, oxicode::config::standard())
190    }
191
192    /// Deserialize message from bytes
193    pub fn from_bytes(data: &[u8]) -> Result<Self, oxicode::Error> {
194        oxicode::serde::decode_owned_from_slice(data, oxicode::config::standard()).map(|(v, _)| v)
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// First fixed CIDv1 shared by the tests.
    fn test_cid() -> Cid {
        "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi"
            .parse::<Cid>()
            .unwrap()
    }

    /// Second fixed CIDv1, distinct from `test_cid()`.
    fn test_cid2() -> Cid {
        "bafybeihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"
            .parse::<Cid>()
            .unwrap()
    }

    // Basic WantEntry Tests
    #[test]
    fn test_want_entry_creation() {
        let cid = test_cid();

        let entry = WantEntry::new(cid);
        assert_eq!(entry.cid, cid);
        assert_eq!(entry.priority, 0);
        assert!(!entry.cancel);
        assert!(!entry.send_dont_have);

        let priority_entry = WantEntry::with_priority(cid, 10);
        assert_eq!(priority_entry.priority, 10);
        assert!(!priority_entry.cancel);
        assert!(!priority_entry.send_dont_have);

        let cancel_entry = WantEntry::cancel(cid);
        assert!(cancel_entry.cancel);
        assert_eq!(cancel_entry.priority, 0);
        assert!(!cancel_entry.send_dont_have);
    }

    #[test]
    fn test_want_entry_edge_cases() {
        let cid = test_cid();

        // Max priority
        let max_entry = WantEntry::with_priority(cid, i32::MAX);
        assert_eq!(max_entry.priority, i32::MAX);

        // Min priority
        let min_entry = WantEntry::with_priority(cid, i32::MIN);
        assert_eq!(min_entry.priority, i32::MIN);

        // Zero priority
        let zero_entry = WantEntry::with_priority(cid, 0);
        assert_eq!(zero_entry.priority, 0);

        // Negative priority
        let neg_entry = WantEntry::with_priority(cid, -100);
        assert_eq!(neg_entry.priority, -100);
    }

    // Message Serialization Tests
    #[test]
    fn test_want_list_serialization_roundtrip() {
        let cid1 = test_cid();
        let cid2 = test_cid2();

        let entries = vec![
            WantEntry::with_priority(cid1, 10),
            WantEntry::with_priority(cid2, 5),
        ];

        let msg = Message::want_list(entries.clone(), true);
        let bytes = msg.to_bytes().unwrap();
        let decoded = Message::from_bytes(&bytes).unwrap();

        match decoded {
            Message::WantList(want_list) => {
                assert!(want_list.full);
                assert_eq!(want_list.entries.len(), 2);
                assert_eq!(want_list.entries[0].cid, cid1);
                assert_eq!(want_list.entries[0].priority, 10);
                assert_eq!(want_list.entries[1].cid, cid2);
                assert_eq!(want_list.entries[1].priority, 5);
                // Whole-entry comparison also covers the flag fields.
                assert_eq!(want_list.entries, entries);
            }
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_block_message_serialization_roundtrip() {
        let cid = test_cid();
        let data = vec![1, 2, 3, 4, 5];

        let msg = Message::block(cid, data.clone());
        let bytes = msg.to_bytes().unwrap();
        let decoded = Message::from_bytes(&bytes).unwrap();

        match decoded {
            Message::Block(block) => {
                assert_eq!(block.cid, cid);
                assert_eq!(block.data, data);
            }
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_have_message_serialization_roundtrip() {
        let cid = test_cid();

        let msg = Message::have(cid);
        let bytes = msg.to_bytes().unwrap();
        let decoded = Message::from_bytes(&bytes).unwrap();

        match decoded {
            Message::Have(have) => assert_eq!(have.cid, cid),
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_dont_have_message_serialization_roundtrip() {
        let cid = test_cid();

        let msg = Message::dont_have(cid);
        let bytes = msg.to_bytes().unwrap();
        let decoded = Message::from_bytes(&bytes).unwrap();

        match decoded {
            Message::DontHave(dont_have) => assert_eq!(dont_have.cid, cid),
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_cancel_message_serialization_roundtrip() {
        let cid = test_cid();

        let msg = Message::cancel(cid);
        let bytes = msg.to_bytes().unwrap();
        let decoded = Message::from_bytes(&bytes).unwrap();

        match decoded {
            Message::Cancel(cancel) => assert_eq!(cancel.cid, cid),
            _ => panic!("Wrong message type"),
        }
    }

    // Edge Case Tests
    #[test]
    fn test_empty_want_list() {
        let msg = Message::want_list(vec![], false);
        let bytes = msg.to_bytes().unwrap();
        let decoded = Message::from_bytes(&bytes).unwrap();

        match decoded {
            Message::WantList(want_list) => {
                assert!(!want_list.full);
                assert!(want_list.entries.is_empty());
            }
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_block_with_empty_data() {
        let cid = test_cid();
        let msg = Message::block(cid, vec![]);
        let bytes = msg.to_bytes().unwrap();
        let decoded = Message::from_bytes(&bytes).unwrap();

        match decoded {
            Message::Block(block) => {
                assert_eq!(block.cid, cid);
                assert!(block.data.is_empty());
            }
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_block_with_large_data() {
        let cid = test_cid();
        let large_data = vec![42u8; 1_000_000]; // 1 MB
        let msg = Message::block(cid, large_data.clone());
        let bytes = msg.to_bytes().unwrap();
        let decoded = Message::from_bytes(&bytes).unwrap();

        match decoded {
            Message::Block(block) => {
                assert_eq!(block.cid, cid);
                assert_eq!(block.data.len(), 1_000_000);
                assert_eq!(block.data, large_data);
            }
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_want_list_with_many_entries() {
        let cid = test_cid();
        let entries: Vec<WantEntry> = (0..1000)
            .map(|i| WantEntry::with_priority(cid, i))
            .collect();

        let msg = Message::want_list(entries, true);
        let bytes = msg.to_bytes().unwrap();
        let decoded = Message::from_bytes(&bytes).unwrap();

        match decoded {
            Message::WantList(want_list) => {
                assert!(want_list.full);
                assert_eq!(want_list.entries.len(), 1000);
                assert_eq!(want_list.entries[0].priority, 0);
                assert_eq!(want_list.entries[500].priority, 500);
                assert_eq!(want_list.entries[999].priority, 999);
            }
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_want_entry_with_all_flags() {
        let cid = test_cid();
        let mut entry = WantEntry::with_priority(cid, 100);
        entry.send_dont_have = true;
        entry.cancel = true;

        let msg = Message::want_list(vec![entry.clone()], false);
        let bytes = msg.to_bytes().unwrap();
        let decoded = Message::from_bytes(&bytes).unwrap();

        match decoded {
            Message::WantList(want_list) => {
                assert!(!want_list.full);
                assert_eq!(want_list.entries[0].priority, 100);
                assert!(want_list.entries[0].send_dont_have);
                assert!(want_list.entries[0].cancel);
                assert_eq!(want_list.entries, vec![entry]);
            }
            _ => panic!("Wrong message type"),
        }
    }

    // Malformed Input Tests
    #[test]
    fn test_invalid_message_bytes() {
        let invalid_bytes = vec![0xFF, 0xFF, 0xFF, 0xFF];
        let result = Message::from_bytes(&invalid_bytes);
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_bytes() {
        let empty_bytes: Vec<u8> = vec![];
        let result = Message::from_bytes(&empty_bytes);
        assert!(result.is_err());
    }

    #[test]
    fn test_truncated_message() {
        let cid = test_cid();
        let msg = Message::have(cid);
        let bytes = msg.to_bytes().unwrap();

        // Take only first half of bytes
        let truncated = &bytes[..bytes.len() / 2];
        let result = Message::from_bytes(truncated);
        assert!(result.is_err());
    }

    #[test]
    fn test_corrupted_message() {
        let cid = test_cid();
        let msg = Message::have(cid);
        let mut bytes = msg.to_bytes().unwrap();

        // The CID is encoded as a string several dozen characters long, so the
        // message always has room to flip these bytes. Assert it instead of
        // silently skipping the corruption.
        assert!(bytes.len() > 10, "encoded Have message unexpectedly short");
        bytes[5] = !bytes[5];
        bytes[10] = !bytes[10];

        // May or may not deserialize, but shouldn't panic
        let _ = Message::from_bytes(&bytes);
    }

    // JSON Serialization Tests
    #[test]
    fn test_json_serialization_want_list() {
        let cid = test_cid();
        let entries = vec![WantEntry::with_priority(cid, 10)];
        let msg = Message::want_list(entries.clone(), true);

        let json = serde_json::to_string(&msg).unwrap();
        let decoded: Message = serde_json::from_str(&json).unwrap();

        match decoded {
            Message::WantList(want_list) => {
                assert!(want_list.full);
                assert_eq!(want_list.entries.len(), 1);
                assert_eq!(want_list.entries[0].priority, 10);
                assert_eq!(want_list.entries, entries);
            }
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_json_serialization_block() {
        let cid = test_cid();
        let data = vec![1, 2, 3];
        let msg = Message::block(cid, data.clone());

        let json = serde_json::to_string(&msg).unwrap();
        let decoded: Message = serde_json::from_str(&json).unwrap();

        match decoded {
            Message::Block(block) => {
                assert_eq!(block.cid, cid);
                assert_eq!(block.data, data);
            }
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_json_serialization_have() {
        let cid = test_cid();
        let msg = Message::have(cid);

        let json = serde_json::to_string(&msg).unwrap();
        let decoded: Message = serde_json::from_str(&json).unwrap();

        match decoded {
            Message::Have(have) => assert_eq!(have.cid, cid),
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_json_serialization_dont_have() {
        let cid = test_cid();
        let msg = Message::dont_have(cid);

        let json = serde_json::to_string(&msg).unwrap();
        let decoded: Message = serde_json::from_str(&json).unwrap();

        match decoded {
            Message::DontHave(dont_have) => assert_eq!(dont_have.cid, cid),
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_json_serialization_cancel() {
        let cid = test_cid();
        let msg = Message::cancel(cid);

        let json = serde_json::to_string(&msg).unwrap();
        let decoded: Message = serde_json::from_str(&json).unwrap();

        match decoded {
            Message::Cancel(cancel) => assert_eq!(cancel.cid, cid),
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_invalid_json() {
        let invalid_json = r#"{"invalid": "structure"}"#;
        let result: Result<Message, _> = serde_json::from_str(invalid_json);
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_cid_in_json() {
        let invalid_json = r#"{"Have":{"cid":"not-a-valid-cid"}}"#;
        let result: Result<Message, _> = serde_json::from_str(invalid_json);
        assert!(result.is_err());
    }

    // WantEntry Equality Tests
    #[test]
    fn test_want_entry_equality() {
        let cid = test_cid();
        let entry1 = WantEntry::with_priority(cid, 10);
        let entry2 = WantEntry::with_priority(cid, 10);
        assert_eq!(entry1, entry2);

        let entry3 = WantEntry::with_priority(cid, 20);
        assert_ne!(entry1, entry3);

        let cid2 = test_cid2();
        let entry4 = WantEntry::with_priority(cid2, 10);
        assert_ne!(entry1, entry4);
    }
}