1use ipfrs_core::Cid;
38use serde::{Deserialize, Deserializer, Serialize, Serializer};
39
40fn serialize_cid<S>(cid: &Cid, serializer: S) -> Result<S::Ok, S::Error>
42where
43 S: Serializer,
44{
45 serializer.serialize_str(&cid.to_string())
46}
47
48fn deserialize_cid<'de, D>(deserializer: D) -> Result<Cid, D::Error>
50where
51 D: Deserializer<'de>,
52{
53 let s = String::deserialize(deserializer)?;
54 s.parse().map_err(serde::de::Error::custom)
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
59pub enum Message {
60 WantList(WantList),
62 Block(BlockMessage),
64 Have(HaveMessage),
66 DontHave(DontHaveMessage),
68 Cancel(CancelMessage),
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct WantList {
75 pub entries: Vec<WantEntry>,
77 pub full: bool,
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
83pub struct WantEntry {
84 #[serde(serialize_with = "serialize_cid", deserialize_with = "deserialize_cid")]
86 pub cid: Cid,
87 pub priority: i32,
89 pub send_dont_have: bool,
91 pub cancel: bool,
93}
94
95impl WantEntry {
96 pub fn new(cid: Cid) -> Self {
98 Self {
99 cid,
100 priority: 0,
101 send_dont_have: false,
102 cancel: false,
103 }
104 }
105
106 pub fn with_priority(cid: Cid, priority: i32) -> Self {
108 Self {
109 cid,
110 priority,
111 send_dont_have: false,
112 cancel: false,
113 }
114 }
115
116 pub fn cancel(cid: Cid) -> Self {
118 Self {
119 cid,
120 priority: 0,
121 send_dont_have: false,
122 cancel: true,
123 }
124 }
125}
126
127#[derive(Debug, Clone, Serialize, Deserialize)]
129pub struct BlockMessage {
130 #[serde(serialize_with = "serialize_cid", deserialize_with = "deserialize_cid")]
132 pub cid: Cid,
133 pub data: Vec<u8>,
135}
136
137#[derive(Debug, Clone, Serialize, Deserialize)]
139pub struct HaveMessage {
140 #[serde(serialize_with = "serialize_cid", deserialize_with = "deserialize_cid")]
142 pub cid: Cid,
143}
144
145#[derive(Debug, Clone, Serialize, Deserialize)]
147pub struct DontHaveMessage {
148 #[serde(serialize_with = "serialize_cid", deserialize_with = "deserialize_cid")]
150 pub cid: Cid,
151}
152
153#[derive(Debug, Clone, Serialize, Deserialize)]
155pub struct CancelMessage {
156 #[serde(serialize_with = "serialize_cid", deserialize_with = "deserialize_cid")]
158 pub cid: Cid,
159}
160
161impl Message {
162 pub fn want_list(entries: Vec<WantEntry>, full: bool) -> Self {
164 Message::WantList(WantList { entries, full })
165 }
166
167 pub fn block(cid: Cid, data: Vec<u8>) -> Self {
169 Message::Block(BlockMessage { cid, data })
170 }
171
172 pub fn have(cid: Cid) -> Self {
174 Message::Have(HaveMessage { cid })
175 }
176
177 pub fn dont_have(cid: Cid) -> Self {
179 Message::DontHave(DontHaveMessage { cid })
180 }
181
182 pub fn cancel(cid: Cid) -> Self {
184 Message::Cancel(CancelMessage { cid })
185 }
186
187 pub fn to_bytes(&self) -> Result<Vec<u8>, oxicode::Error> {
189 oxicode::serde::encode_to_vec(self, oxicode::config::standard())
190 }
191
192 pub fn from_bytes(data: &[u8]) -> Result<Self, oxicode::Error> {
194 oxicode::serde::decode_owned_from_slice(data, oxicode::config::standard()).map(|(v, _)| v)
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// First fixed, valid CID used as the primary fixture.
    fn test_cid() -> Cid {
        "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi"
            .parse::<Cid>()
            .unwrap()
    }

    /// Second valid CID, distinct from [`test_cid`].
    fn test_cid2() -> Cid {
        "bafybeihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"
            .parse::<Cid>()
            .unwrap()
    }

    /// Encodes `msg` with the binary codec and decodes it back.
    ///
    /// Panics if either direction fails; that is a test failure.
    fn binary_roundtrip(msg: &Message) -> Message {
        let bytes = msg.to_bytes().expect("encoding a Message must succeed");
        Message::from_bytes(&bytes).expect("decoding freshly encoded bytes must succeed")
    }

    /// Encodes `msg` as JSON and decodes it back.
    ///
    /// Panics if either direction fails; that is a test failure.
    fn json_roundtrip(msg: &Message) -> Message {
        let json = serde_json::to_string(msg).expect("JSON encoding must succeed");
        serde_json::from_str(&json).expect("JSON decoding must succeed")
    }

    // ---- WantEntry constructors -------------------------------------------

    #[test]
    fn test_want_entry_creation() {
        let cid = test_cid();

        let entry = WantEntry::new(cid);
        assert_eq!(entry.priority, 0);
        assert!(!entry.cancel);
        assert!(!entry.send_dont_have);

        let priority_entry = WantEntry::with_priority(cid, 10);
        assert_eq!(priority_entry.priority, 10);
        assert!(!priority_entry.cancel);

        let cancel_entry = WantEntry::cancel(cid);
        assert!(cancel_entry.cancel);
        assert_eq!(cancel_entry.priority, 0);
    }

    #[test]
    fn test_want_entry_edge_cases() {
        let cid = test_cid();
        // Extremes, zero and a negative value must all be stored verbatim.
        for priority in [i32::MAX, i32::MIN, 0, -100] {
            let entry = WantEntry::with_priority(cid, priority);
            assert_eq!(entry.priority, priority);
        }
    }

    #[test]
    fn test_want_entry_equality() {
        let cid = test_cid();
        let entry1 = WantEntry::with_priority(cid, 10);
        let entry2 = WantEntry::with_priority(cid, 10);
        assert_eq!(entry1, entry2);

        let entry3 = WantEntry::with_priority(cid, 20);
        assert_ne!(entry1, entry3);

        let cid2 = test_cid2();
        let entry4 = WantEntry::with_priority(cid2, 10);
        assert_ne!(entry1, entry4);
    }

    // ---- Binary round-trips -----------------------------------------------

    #[test]
    fn test_want_list_serialization_roundtrip() {
        let entries = vec![
            WantEntry::with_priority(test_cid(), 10),
            WantEntry::with_priority(test_cid2(), 5),
        ];

        match binary_roundtrip(&Message::want_list(entries.clone(), true)) {
            Message::WantList(want_list) => {
                assert!(want_list.full);
                // Compares CID, priority and both flags of every entry, in order.
                assert_eq!(want_list.entries, entries);
            }
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_block_message_serialization_roundtrip() {
        let cid = test_cid();
        let data = vec![1, 2, 3, 4, 5];

        match binary_roundtrip(&Message::block(cid, data.clone())) {
            Message::Block(block) => {
                assert_eq!(block.cid, cid);
                assert_eq!(block.data, data);
            }
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_have_message_serialization_roundtrip() {
        let cid = test_cid();
        match binary_roundtrip(&Message::have(cid)) {
            Message::Have(have) => assert_eq!(have.cid, cid),
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_dont_have_message_serialization_roundtrip() {
        let cid = test_cid();
        match binary_roundtrip(&Message::dont_have(cid)) {
            Message::DontHave(dont_have) => assert_eq!(dont_have.cid, cid),
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_cancel_message_serialization_roundtrip() {
        let cid = test_cid();
        match binary_roundtrip(&Message::cancel(cid)) {
            Message::Cancel(cancel) => assert_eq!(cancel.cid, cid),
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_empty_want_list() {
        match binary_roundtrip(&Message::want_list(vec![], false)) {
            Message::WantList(want_list) => {
                assert!(!want_list.full);
                assert!(want_list.entries.is_empty());
            }
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_block_with_empty_data() {
        let cid = test_cid();
        match binary_roundtrip(&Message::block(cid, vec![])) {
            Message::Block(block) => {
                assert_eq!(block.cid, cid);
                assert!(block.data.is_empty());
            }
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_block_with_large_data() {
        let cid = test_cid();
        let large_data = vec![42u8; 1_000_000];

        match binary_roundtrip(&Message::block(cid, large_data.clone())) {
            Message::Block(block) => {
                assert_eq!(block.cid, cid);
                assert_eq!(block.data.len(), 1_000_000);
                assert_eq!(block.data, large_data);
            }
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_want_list_with_many_entries() {
        let cid = test_cid();
        let entries: Vec<WantEntry> = (0..1000)
            .map(|i| WantEntry::with_priority(cid, i))
            .collect();

        match binary_roundtrip(&Message::want_list(entries, true)) {
            Message::WantList(want_list) => {
                assert!(want_list.full);
                assert_eq!(want_list.entries.len(), 1000);
                // Every entry must keep its CID and its position-derived priority.
                for (expected, entry) in (0..1000).zip(&want_list.entries) {
                    assert_eq!(entry.cid, cid);
                    assert_eq!(entry.priority, expected);
                }
            }
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_want_entry_with_all_flags() {
        let entry = WantEntry {
            send_dont_have: true,
            cancel: true,
            ..WantEntry::with_priority(test_cid(), 100)
        };

        match binary_roundtrip(&Message::want_list(vec![entry.clone()], false)) {
            Message::WantList(want_list) => {
                assert!(!want_list.full);
                assert_eq!(want_list.entries, vec![entry]);
            }
            _ => panic!("Wrong message type"),
        }
    }

    // ---- Malformed binary input -------------------------------------------

    #[test]
    fn test_invalid_message_bytes() {
        let invalid_bytes = vec![0xFF, 0xFF, 0xFF, 0xFF];
        assert!(Message::from_bytes(&invalid_bytes).is_err());
    }

    #[test]
    fn test_empty_bytes() {
        let empty_bytes: Vec<u8> = vec![];
        assert!(Message::from_bytes(&empty_bytes).is_err());
    }

    #[test]
    fn test_truncated_message() {
        let bytes = Message::have(test_cid()).to_bytes().unwrap();
        let truncated = &bytes[..bytes.len() / 2];
        assert!(Message::from_bytes(truncated).is_err());
    }

    /// Flipping arbitrary bytes may or may not produce a decodable message,
    /// depending on which bytes are hit. This test only checks that decoding
    /// corrupted input never panics, so the result is intentionally ignored.
    #[test]
    fn test_corrupted_message() {
        let mut bytes = Message::have(test_cid()).to_bytes().unwrap();

        // Guard the fixed indices so the test itself cannot panic on short output.
        if bytes.len() > 10 {
            bytes[5] = !bytes[5];
            bytes[10] = !bytes[10];
        }

        let _ = Message::from_bytes(&bytes);
    }

    // ---- JSON round-trips -------------------------------------------------

    #[test]
    fn test_json_serialization_want_list() {
        let entries = vec![WantEntry::with_priority(test_cid(), 10)];

        match json_roundtrip(&Message::want_list(entries.clone(), true)) {
            Message::WantList(want_list) => {
                assert!(want_list.full);
                assert_eq!(want_list.entries, entries);
            }
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_json_serialization_block() {
        let cid = test_cid();
        let data = vec![1, 2, 3];

        match json_roundtrip(&Message::block(cid, data.clone())) {
            Message::Block(block) => {
                assert_eq!(block.cid, cid);
                assert_eq!(block.data, data);
            }
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_json_serialization_have() {
        let cid = test_cid();
        match json_roundtrip(&Message::have(cid)) {
            Message::Have(have) => assert_eq!(have.cid, cid),
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_json_serialization_dont_have() {
        let cid = test_cid();
        match json_roundtrip(&Message::dont_have(cid)) {
            Message::DontHave(dont_have) => assert_eq!(dont_have.cid, cid),
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_json_serialization_cancel() {
        let cid = test_cid();
        match json_roundtrip(&Message::cancel(cid)) {
            Message::Cancel(cancel) => assert_eq!(cancel.cid, cid),
            _ => panic!("Wrong message type"),
        }
    }

    // ---- Malformed JSON input ---------------------------------------------

    #[test]
    fn test_invalid_json() {
        let invalid_json = r#"{"invalid": "structure"}"#;
        let result: Result<Message, _> = serde_json::from_str(invalid_json);
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_cid_in_json() {
        // Well-formed envelope, but the CID string must be rejected by deserialize_cid.
        let invalid_json = r#"{"Have":{"cid":"not-a-valid-cid"}}"#;
        let result: Result<Message, _> = serde_json::from_str(invalid_json);
        assert!(result.is_err());
    }
}