1use bytes::{Buf, BufMut, Bytes, BytesMut};
4use guts_storage::{GitObject, ObjectId, ObjectType};
5
6use crate::{P2PError, Result};
7
/// Largest object-id / want / object count a decoder in this file accepts in a
/// single message. The count is read from the wire and checked against this
/// limit before any space is reserved for the entries.
const MAX_OBJECTS_PER_MESSAGE: usize = 100_000;

/// Largest ref count `RepoAnnounce::decode` accepts in a single message.
const MAX_REFS_PER_MESSAGE: usize = 10_000;

/// Largest payload, in bytes, that `ObjectData::decode` accepts for a single
/// object (100 MiB).
const MAX_OBJECT_DATA_SIZE: usize = 100 * 1024 * 1024;
17
18#[derive(Debug, Clone, Copy, PartialEq, Eq)]
20#[repr(u8)]
21pub enum MessageType {
22 RepoAnnounce = 1,
24 SyncRequest = 2,
26 ObjectData = 3,
28 RefUpdate = 4,
30}
31
32impl MessageType {
33 pub fn from_byte(b: u8) -> Result<Self> {
35 match b {
36 1 => Ok(MessageType::RepoAnnounce),
37 2 => Ok(MessageType::SyncRequest),
38 3 => Ok(MessageType::ObjectData),
39 4 => Ok(MessageType::RefUpdate),
40 _ => Err(P2PError::InvalidMessage(format!(
41 "unknown message type: {}",
42 b
43 ))),
44 }
45 }
46}
47
48#[derive(Debug, Clone)]
52pub struct RepoAnnounce {
53 pub repo_key: String,
55 pub object_ids: Vec<ObjectId>,
57 pub refs: Vec<(String, ObjectId)>,
59}
60
61impl RepoAnnounce {
62 pub fn encode(&self) -> Bytes {
64 let mut buf = BytesMut::new();
65 buf.put_u8(MessageType::RepoAnnounce as u8);
66
67 let repo_bytes = self.repo_key.as_bytes();
69 buf.put_u16(repo_bytes.len() as u16);
70 buf.put_slice(repo_bytes);
71
72 buf.put_u32(self.object_ids.len() as u32);
74 for oid in &self.object_ids {
75 buf.put_slice(oid.as_bytes());
76 }
77
78 buf.put_u32(self.refs.len() as u32);
80 for (name, oid) in &self.refs {
81 let name_bytes = name.as_bytes();
82 buf.put_u16(name_bytes.len() as u16);
83 buf.put_slice(name_bytes);
84 buf.put_slice(oid.as_bytes());
85 }
86
87 buf.freeze()
88 }
89
90 pub fn decode(mut buf: &[u8]) -> Result<Self> {
92 if buf.remaining() < 2 {
94 return Err(P2PError::InvalidMessage("truncated repo key length".into()));
95 }
96 let repo_len = buf.get_u16() as usize;
97 if buf.remaining() < repo_len {
98 return Err(P2PError::InvalidMessage("truncated repo key".into()));
99 }
100 let repo_key = String::from_utf8(buf[..repo_len].to_vec())
101 .map_err(|e| P2PError::InvalidMessage(format!("invalid repo key: {}", e)))?;
102 buf.advance(repo_len);
103
104 if buf.remaining() < 4 {
106 return Err(P2PError::InvalidMessage("truncated object count".into()));
107 }
108 let obj_count = buf.get_u32() as usize;
109 if obj_count > MAX_OBJECTS_PER_MESSAGE {
110 return Err(P2PError::InvalidMessage(format!(
111 "object count {} exceeds maximum {}",
112 obj_count, MAX_OBJECTS_PER_MESSAGE
113 )));
114 }
115 let mut object_ids = Vec::with_capacity(obj_count);
116 for _ in 0..obj_count {
117 if buf.remaining() < 20 {
118 return Err(P2PError::InvalidMessage("truncated object id".into()));
119 }
120 let mut oid_bytes = [0u8; 20];
121 buf.copy_to_slice(&mut oid_bytes);
122 object_ids.push(ObjectId::from_bytes(oid_bytes));
123 }
124
125 if buf.remaining() < 4 {
127 return Err(P2PError::InvalidMessage("truncated ref count".into()));
128 }
129 let ref_count = buf.get_u32() as usize;
130 if ref_count > MAX_REFS_PER_MESSAGE {
131 return Err(P2PError::InvalidMessage(format!(
132 "ref count {} exceeds maximum {}",
133 ref_count, MAX_REFS_PER_MESSAGE
134 )));
135 }
136 let mut refs = Vec::with_capacity(ref_count);
137 for _ in 0..ref_count {
138 if buf.remaining() < 2 {
139 return Err(P2PError::InvalidMessage("truncated ref name length".into()));
140 }
141 let name_len = buf.get_u16() as usize;
142 if buf.remaining() < name_len + 20 {
143 return Err(P2PError::InvalidMessage("truncated ref data".into()));
144 }
145 let name = String::from_utf8(buf[..name_len].to_vec())
146 .map_err(|e| P2PError::InvalidMessage(format!("invalid ref name: {}", e)))?;
147 buf.advance(name_len);
148
149 let mut oid_bytes = [0u8; 20];
150 buf.copy_to_slice(&mut oid_bytes);
151 refs.push((name, ObjectId::from_bytes(oid_bytes)));
152 }
153
154 Ok(RepoAnnounce {
155 repo_key,
156 object_ids,
157 refs,
158 })
159 }
160}
161
162#[derive(Debug, Clone)]
164pub struct SyncRequest {
165 pub repo_key: String,
167 pub want: Vec<ObjectId>,
169}
170
171impl SyncRequest {
172 pub fn encode(&self) -> Bytes {
174 let mut buf = BytesMut::new();
175 buf.put_u8(MessageType::SyncRequest as u8);
176
177 let repo_bytes = self.repo_key.as_bytes();
179 buf.put_u16(repo_bytes.len() as u16);
180 buf.put_slice(repo_bytes);
181
182 buf.put_u32(self.want.len() as u32);
184 for oid in &self.want {
185 buf.put_slice(oid.as_bytes());
186 }
187
188 buf.freeze()
189 }
190
191 pub fn decode(mut buf: &[u8]) -> Result<Self> {
193 if buf.remaining() < 2 {
195 return Err(P2PError::InvalidMessage("truncated repo key length".into()));
196 }
197 let repo_len = buf.get_u16() as usize;
198 if buf.remaining() < repo_len {
199 return Err(P2PError::InvalidMessage("truncated repo key".into()));
200 }
201 let repo_key = String::from_utf8(buf[..repo_len].to_vec())
202 .map_err(|e| P2PError::InvalidMessage(format!("invalid repo key: {}", e)))?;
203 buf.advance(repo_len);
204
205 if buf.remaining() < 4 {
207 return Err(P2PError::InvalidMessage("truncated want count".into()));
208 }
209 let want_count = buf.get_u32() as usize;
210 if want_count > MAX_OBJECTS_PER_MESSAGE {
211 return Err(P2PError::InvalidMessage(format!(
212 "want count {} exceeds maximum {}",
213 want_count, MAX_OBJECTS_PER_MESSAGE
214 )));
215 }
216 let mut want = Vec::with_capacity(want_count);
217 for _ in 0..want_count {
218 if buf.remaining() < 20 {
219 return Err(P2PError::InvalidMessage("truncated object id".into()));
220 }
221 let mut oid_bytes = [0u8; 20];
222 buf.copy_to_slice(&mut oid_bytes);
223 want.push(ObjectId::from_bytes(oid_bytes));
224 }
225
226 Ok(SyncRequest { repo_key, want })
227 }
228}
229
230#[derive(Debug, Clone)]
232pub struct ObjectData {
233 pub repo_key: String,
235 pub objects: Vec<GitObject>,
237}
238
239impl ObjectData {
240 pub fn encode(&self) -> Bytes {
242 let mut buf = BytesMut::new();
243 buf.put_u8(MessageType::ObjectData as u8);
244
245 let repo_bytes = self.repo_key.as_bytes();
247 buf.put_u16(repo_bytes.len() as u16);
248 buf.put_slice(repo_bytes);
249
250 buf.put_u32(self.objects.len() as u32);
252 for obj in &self.objects {
253 buf.put_u8(match obj.object_type {
255 ObjectType::Blob => 1,
256 ObjectType::Tree => 2,
257 ObjectType::Commit => 3,
258 ObjectType::Tag => 4,
259 });
260 buf.put_u32(obj.data.len() as u32);
262 buf.put_slice(&obj.data);
263 }
264
265 buf.freeze()
266 }
267
268 pub fn decode(mut buf: &[u8]) -> Result<Self> {
270 if buf.remaining() < 2 {
272 return Err(P2PError::InvalidMessage("truncated repo key length".into()));
273 }
274 let repo_len = buf.get_u16() as usize;
275 if buf.remaining() < repo_len {
276 return Err(P2PError::InvalidMessage("truncated repo key".into()));
277 }
278 let repo_key = String::from_utf8(buf[..repo_len].to_vec())
279 .map_err(|e| P2PError::InvalidMessage(format!("invalid repo key: {}", e)))?;
280 buf.advance(repo_len);
281
282 if buf.remaining() < 4 {
284 return Err(P2PError::InvalidMessage("truncated object count".into()));
285 }
286 let obj_count = buf.get_u32() as usize;
287 if obj_count > MAX_OBJECTS_PER_MESSAGE {
288 return Err(P2PError::InvalidMessage(format!(
289 "object count {} exceeds maximum {}",
290 obj_count, MAX_OBJECTS_PER_MESSAGE
291 )));
292 }
293 let mut objects = Vec::with_capacity(obj_count);
294 for _ in 0..obj_count {
295 if buf.remaining() < 5 {
296 return Err(P2PError::InvalidMessage("truncated object header".into()));
297 }
298 let obj_type = match buf.get_u8() {
299 1 => ObjectType::Blob,
300 2 => ObjectType::Tree,
301 3 => ObjectType::Commit,
302 4 => ObjectType::Tag,
303 t => {
304 return Err(P2PError::InvalidMessage(format!(
305 "invalid object type: {}",
306 t
307 )))
308 }
309 };
310 let data_len = buf.get_u32() as usize;
311 if data_len > MAX_OBJECT_DATA_SIZE {
312 return Err(P2PError::InvalidMessage(format!(
313 "object data size {} exceeds maximum {}",
314 data_len, MAX_OBJECT_DATA_SIZE
315 )));
316 }
317 if buf.remaining() < data_len {
318 return Err(P2PError::InvalidMessage("truncated object data".into()));
319 }
320 let data = Bytes::copy_from_slice(&buf[..data_len]);
321 buf.advance(data_len);
322 objects.push(GitObject::new(obj_type, data));
323 }
324
325 Ok(ObjectData { repo_key, objects })
326 }
327}
328
329#[derive(Debug, Clone)]
331pub struct RefUpdate {
332 pub repo_key: String,
334 pub ref_name: String,
336 pub old_id: ObjectId,
338 pub new_id: ObjectId,
340}
341
342impl RefUpdate {
343 pub fn encode(&self) -> Bytes {
345 let mut buf = BytesMut::new();
346 buf.put_u8(MessageType::RefUpdate as u8);
347
348 let repo_bytes = self.repo_key.as_bytes();
350 buf.put_u16(repo_bytes.len() as u16);
351 buf.put_slice(repo_bytes);
352
353 let ref_bytes = self.ref_name.as_bytes();
355 buf.put_u16(ref_bytes.len() as u16);
356 buf.put_slice(ref_bytes);
357
358 buf.put_slice(self.old_id.as_bytes());
360 buf.put_slice(self.new_id.as_bytes());
361
362 buf.freeze()
363 }
364
365 pub fn decode(mut buf: &[u8]) -> Result<Self> {
367 if buf.remaining() < 2 {
369 return Err(P2PError::InvalidMessage("truncated repo key length".into()));
370 }
371 let repo_len = buf.get_u16() as usize;
372 if buf.remaining() < repo_len {
373 return Err(P2PError::InvalidMessage("truncated repo key".into()));
374 }
375 let repo_key = String::from_utf8(buf[..repo_len].to_vec())
376 .map_err(|e| P2PError::InvalidMessage(format!("invalid repo key: {}", e)))?;
377 buf.advance(repo_len);
378
379 if buf.remaining() < 2 {
381 return Err(P2PError::InvalidMessage("truncated ref name length".into()));
382 }
383 let ref_len = buf.get_u16() as usize;
384 if buf.remaining() < ref_len + 40 {
385 return Err(P2PError::InvalidMessage("truncated ref data".into()));
386 }
387 let ref_name = String::from_utf8(buf[..ref_len].to_vec())
388 .map_err(|e| P2PError::InvalidMessage(format!("invalid ref name: {}", e)))?;
389 buf.advance(ref_len);
390
391 let mut old_bytes = [0u8; 20];
393 let mut new_bytes = [0u8; 20];
394 buf.copy_to_slice(&mut old_bytes);
395 buf.copy_to_slice(&mut new_bytes);
396
397 Ok(RefUpdate {
398 repo_key,
399 ref_name,
400 old_id: ObjectId::from_bytes(old_bytes),
401 new_id: ObjectId::from_bytes(new_bytes),
402 })
403 }
404}
405
406#[derive(Debug, Clone)]
408pub enum Message {
409 RepoAnnounce(RepoAnnounce),
410 SyncRequest(SyncRequest),
411 ObjectData(ObjectData),
412 RefUpdate(RefUpdate),
413}
414
415impl Message {
416 pub fn encode(&self) -> Bytes {
418 match self {
419 Message::RepoAnnounce(m) => m.encode(),
420 Message::SyncRequest(m) => m.encode(),
421 Message::ObjectData(m) => m.encode(),
422 Message::RefUpdate(m) => m.encode(),
423 }
424 }
425
426 pub fn decode(data: &[u8]) -> Result<Self> {
428 if data.is_empty() {
429 return Err(P2PError::InvalidMessage("empty message".into()));
430 }
431
432 let msg_type = MessageType::from_byte(data[0])?;
433 let payload = &data[1..];
434
435 match msg_type {
436 MessageType::RepoAnnounce => Ok(Message::RepoAnnounce(RepoAnnounce::decode(payload)?)),
437 MessageType::SyncRequest => Ok(Message::SyncRequest(SyncRequest::decode(payload)?)),
438 MessageType::ObjectData => Ok(Message::ObjectData(ObjectData::decode(payload)?)),
439 MessageType::RefUpdate => Ok(Message::RefUpdate(RefUpdate::decode(payload)?)),
440 }
441 }
442}
443
#[cfg(test)]
mod tests {
    //! Example-based tests for the message codec: round trips through
    //! `Message::decode`, and rejection of malformed payloads.

    use super::*;

    /// A populated announcement survives encode/decode with its contents intact.
    #[test]
    fn test_repo_announce_roundtrip() {
        let msg = RepoAnnounce {
            repo_key: "alice/test-repo".to_string(),
            object_ids: vec![
                ObjectId::from_bytes([1u8; 20]),
                ObjectId::from_bytes([2u8; 20]),
            ],
            refs: vec![(
                "refs/heads/main".to_string(),
                ObjectId::from_bytes([3u8; 20]),
            )],
        };

        let encoded = msg.encode();
        let decoded = Message::decode(&encoded).unwrap();

        match decoded {
            Message::RepoAnnounce(d) => {
                assert_eq!(d.repo_key, msg.repo_key);
                // Compare contents, not just lengths, so corrupted ids or
                // names are caught.
                assert_eq!(d.object_ids, msg.object_ids);
                assert_eq!(d.refs, msg.refs);
            }
            _ => panic!("wrong message type"),
        }
    }

    /// A sync request survives encode/decode with its wanted ids intact.
    #[test]
    fn test_sync_request_roundtrip() {
        let msg = SyncRequest {
            repo_key: "bob/my-repo".to_string(),
            want: vec![ObjectId::from_bytes([5u8; 20])],
        };

        let encoded = msg.encode();
        let decoded = Message::decode(&encoded).unwrap();

        match decoded {
            Message::SyncRequest(d) => {
                assert_eq!(d.repo_key, msg.repo_key);
                assert_eq!(d.want, msg.want);
            }
            _ => panic!("wrong message type"),
        }
    }

    /// A single blob survives encode/decode with the same id, type and data.
    #[test]
    fn test_object_data_roundtrip() {
        let obj = GitObject::blob(b"hello world".to_vec());
        let msg = ObjectData {
            repo_key: "carol/repo".to_string(),
            objects: vec![obj.clone()],
        };

        let encoded = msg.encode();
        let decoded = Message::decode(&encoded).unwrap();

        match decoded {
            Message::ObjectData(d) => {
                assert_eq!(d.repo_key, msg.repo_key);
                assert_eq!(d.objects.len(), 1);
                assert_eq!(d.objects[0].id, obj.id);
                assert_eq!(d.objects[0].object_type, obj.object_type);
                assert_eq!(d.objects[0].data, obj.data);
            }
            _ => panic!("wrong message type"),
        }
    }

    /// A ref update survives encode/decode with every field intact.
    #[test]
    fn test_ref_update_roundtrip() {
        let msg = RefUpdate {
            repo_key: "dave/code".to_string(),
            ref_name: "refs/heads/feature".to_string(),
            old_id: ObjectId::from_bytes([0u8; 20]),
            new_id: ObjectId::from_bytes([7u8; 20]),
        };

        let encoded = msg.encode();
        let decoded = Message::decode(&encoded).unwrap();

        match decoded {
            Message::RefUpdate(d) => {
                assert_eq!(d.repo_key, msg.repo_key);
                assert_eq!(d.ref_name, msg.ref_name);
                assert_eq!(d.old_id, msg.old_id);
                assert_eq!(d.new_id, msg.new_id);
            }
            _ => panic!("wrong message type"),
        }
    }

    /// An empty buffer has no tag byte and must be rejected.
    #[test]
    fn test_message_decode_empty() {
        let result = Message::decode(&[]);
        assert!(result.is_err());
    }

    /// A tag outside `1..=4` must be rejected.
    #[test]
    fn test_message_decode_invalid_type() {
        let result = Message::decode(&[0xFF]);
        assert!(result.is_err());
    }

    /// Announcement payloads cut off inside the repo key framing are rejected.
    #[test]
    fn test_repo_announce_truncated() {
        // No bytes at all.
        let result = RepoAnnounce::decode(&[]);
        assert!(result.is_err());

        // Only half of the key length.
        let result = RepoAnnounce::decode(&[0x00]);
        assert!(result.is_err());

        // Key length of 5 with no key bytes following.
        let result = RepoAnnounce::decode(&[0x00, 0x05]);
        assert!(result.is_err());
    }

    /// An object count above the per-message limit is rejected.
    #[test]
    fn test_repo_announce_rejects_oversized_count() {
        // Key "r", then an object count of u32::MAX.
        let payload = [0x00, 0x01, b'r', 0xFF, 0xFF, 0xFF, 0xFF];
        assert!(RepoAnnounce::decode(&payload).is_err());
    }

    /// Sync request payloads that end early are rejected.
    #[test]
    fn test_sync_request_truncated() {
        let result = SyncRequest::decode(&[]);
        assert!(result.is_err());

        // Key "r", a want count of 2, but only one 20-byte id present.
        let mut short = vec![0x00, 0x01, b'r', 0x00, 0x00, 0x00, 0x02];
        short.extend_from_slice(&[0xAA; 20]);
        assert!(SyncRequest::decode(&short).is_err());
    }

    /// Object payloads that end early are rejected.
    #[test]
    fn test_object_data_truncated() {
        let result = ObjectData::decode(&[]);
        assert!(result.is_err());

        // One blob declaring 10 data bytes while only 5 follow.
        let mut buf = BytesMut::new();
        buf.put_u16(4);
        buf.put_slice(b"test");
        buf.put_u32(1);
        buf.put_u8(1);
        buf.put_u32(10);
        buf.put_slice(b"short");
        assert!(ObjectData::decode(&buf).is_err());
    }

    /// Ref update payloads that end early are rejected.
    #[test]
    fn test_ref_update_truncated() {
        let result = RefUpdate::decode(&[]);
        assert!(result.is_err());

        // Key "r", ref name "main", then 39 of the 40 id bytes.
        let mut short = vec![0x00, 0x01, b'r', 0x00, 0x04];
        short.extend_from_slice(b"main");
        short.extend_from_slice(&[0u8; 39]);
        assert!(RefUpdate::decode(&short).is_err());
    }

    /// An object whose type tag is not `1..=4` is rejected.
    #[test]
    fn test_object_data_invalid_type() {
        let mut buf = BytesMut::new();
        // Repo key "test".
        buf.put_u16(4);
        buf.put_slice(b"test");
        // One object with the unknown type tag 99 and 5 bytes of data.
        buf.put_u32(1);
        buf.put_u8(99);
        buf.put_u32(5);
        buf.put_slice(b"hello");

        let result = ObjectData::decode(&buf);
        assert!(result.is_err());
    }

    /// An announcement with no objects and no refs round-trips as empty lists.
    #[test]
    fn test_repo_announce_empty_lists() {
        let msg = RepoAnnounce {
            repo_key: "test/repo".to_string(),
            object_ids: vec![],
            refs: vec![],
        };

        let encoded = msg.encode();
        let decoded = Message::decode(&encoded).unwrap();

        match decoded {
            Message::RepoAnnounce(d) => {
                assert_eq!(d.repo_key, "test/repo");
                assert!(d.object_ids.is_empty());
                assert!(d.refs.is_empty());
            }
            _ => panic!("wrong message type"),
        }
    }

    /// `from_byte` accepts exactly the four defined tags.
    #[test]
    fn test_message_type_from_byte() {
        assert_eq!(
            MessageType::from_byte(1).unwrap(),
            MessageType::RepoAnnounce
        );
        assert_eq!(MessageType::from_byte(2).unwrap(), MessageType::SyncRequest);
        assert_eq!(MessageType::from_byte(3).unwrap(), MessageType::ObjectData);
        assert_eq!(MessageType::from_byte(4).unwrap(), MessageType::RefUpdate);
        assert!(MessageType::from_byte(0).is_err());
        assert!(MessageType::from_byte(5).is_err());
        assert!(MessageType::from_byte(255).is_err());
    }

    /// Each object type keeps its type and data through a round trip.
    #[test]
    fn test_object_data_all_types() {
        let objects = vec![
            GitObject::blob(b"blob data".to_vec()),
            GitObject::new(ObjectType::Tree, Bytes::from_static(b"tree data")),
            GitObject::new(ObjectType::Commit, Bytes::from_static(b"commit data")),
            GitObject::new(ObjectType::Tag, Bytes::from_static(b"tag data")),
        ];

        let msg = ObjectData {
            repo_key: "test/repo".to_string(),
            objects: objects.clone(),
        };

        let encoded = msg.encode();
        let decoded = Message::decode(&encoded).unwrap();

        match decoded {
            Message::ObjectData(d) => {
                assert_eq!(d.objects.len(), 4);
                assert_eq!(d.objects[0].object_type, ObjectType::Blob);
                assert_eq!(d.objects[1].object_type, ObjectType::Tree);
                assert_eq!(d.objects[2].object_type, ObjectType::Commit);
                assert_eq!(d.objects[3].object_type, ObjectType::Tag);
                for (orig, dec) in objects.iter().zip(d.objects.iter()) {
                    assert_eq!(orig.data, dec.data);
                }
            }
            _ => panic!("wrong message type"),
        }
    }
}
665
#[cfg(test)]
mod proptests {
    //! Property-based tests for the message codec.

    use super::*;
    use proptest::prelude::*;

    /// Generates an object id from 20 arbitrary bytes.
    fn object_id_strategy() -> impl Strategy<Value = ObjectId> {
        prop::array::uniform20(any::<u8>()).prop_map(ObjectId::from_bytes)
    }

    /// Generates `owner/name` style repo keys from lowercase letters, digits
    /// and dashes.
    fn repo_key_strategy() -> impl Strategy<Value = String> {
        "[a-z][a-z0-9-]{0,30}/[a-z][a-z0-9-]{0,30}"
    }

    /// Generates ref names: a few fixed ones plus random `refs/...` paths.
    fn ref_name_strategy() -> impl Strategy<Value = String> {
        prop_oneof![
            Just("refs/heads/main".to_string()),
            Just("refs/heads/develop".to_string()),
            Just("refs/tags/v1.0.0".to_string()),
            "[a-z/]{1,50}".prop_map(|s| format!("refs/{}", s)),
        ]
    }

    proptest! {
        /// Any announcement decodes back to the same key, ids and refs.
        #[test]
        fn prop_repo_announce_roundtrip(
            repo_key in repo_key_strategy(),
            object_ids in prop::collection::vec(object_id_strategy(), 0..10),
            refs in prop::collection::vec(
                (ref_name_strategy(), object_id_strategy()),
                0..10
            )
        ) {
            let msg = RepoAnnounce {
                repo_key: repo_key.clone(),
                object_ids: object_ids.clone(),
                refs: refs.clone(),
            };

            let encoded = msg.encode();
            let decoded = Message::decode(&encoded).unwrap();

            match decoded {
                Message::RepoAnnounce(d) => {
                    prop_assert_eq!(d.repo_key, repo_key);
                    prop_assert_eq!(d.object_ids.len(), object_ids.len());
                    prop_assert_eq!(d.refs.len(), refs.len());
                    for (orig, dec) in object_ids.iter().zip(d.object_ids.iter()) {
                        prop_assert_eq!(orig, dec);
                    }
                    for ((orig_name, orig_id), (dec_name, dec_id)) in refs.iter().zip(d.refs.iter()) {
                        prop_assert_eq!(orig_name, dec_name);
                        prop_assert_eq!(orig_id, dec_id);
                    }
                }
                _ => prop_assert!(false, "wrong message type"),
            }
        }

        /// Any sync request decodes back to the same key and wanted ids.
        #[test]
        fn prop_sync_request_roundtrip(
            repo_key in repo_key_strategy(),
            want in prop::collection::vec(object_id_strategy(), 0..20)
        ) {
            let msg = SyncRequest {
                repo_key: repo_key.clone(),
                want: want.clone(),
            };

            let encoded = msg.encode();
            let decoded = Message::decode(&encoded).unwrap();

            match decoded {
                Message::SyncRequest(d) => {
                    prop_assert_eq!(d.repo_key, repo_key);
                    prop_assert_eq!(d.want.len(), want.len());
                    for (orig, dec) in want.iter().zip(d.want.iter()) {
                        prop_assert_eq!(orig, dec);
                    }
                }
                _ => prop_assert!(false, "wrong message type"),
            }
        }

        /// Any ref update decodes back to the same fields.
        #[test]
        fn prop_ref_update_roundtrip(
            repo_key in repo_key_strategy(),
            ref_name in ref_name_strategy(),
            old_id in object_id_strategy(),
            new_id in object_id_strategy()
        ) {
            let msg = RefUpdate {
                repo_key: repo_key.clone(),
                ref_name: ref_name.clone(),
                old_id,
                new_id,
            };

            let encoded = msg.encode();
            let decoded = Message::decode(&encoded).unwrap();

            match decoded {
                Message::RefUpdate(d) => {
                    prop_assert_eq!(d.repo_key, repo_key);
                    prop_assert_eq!(d.ref_name, ref_name);
                    prop_assert_eq!(d.old_id, old_id);
                    prop_assert_eq!(d.new_id, new_id);
                }
                _ => prop_assert!(false, "wrong message type"),
            }
        }

        /// Any batch of blobs decodes back to objects with the same id, type
        /// and data.
        #[test]
        fn prop_object_data_roundtrip(
            repo_key in repo_key_strategy(),
            blobs in prop::collection::vec(prop::collection::vec(any::<u8>(), 0..1000), 0..5)
        ) {
            let objects: Vec<GitObject> = blobs.iter()
                .map(|data| GitObject::blob(data.clone()))
                .collect();

            let msg = ObjectData {
                repo_key: repo_key.clone(),
                objects: objects.clone(),
            };

            let encoded = msg.encode();
            let decoded = Message::decode(&encoded).unwrap();

            match decoded {
                Message::ObjectData(d) => {
                    prop_assert_eq!(d.repo_key, repo_key);
                    prop_assert_eq!(d.objects.len(), objects.len());
                    for (orig, dec) in objects.iter().zip(d.objects.iter()) {
                        prop_assert_eq!(orig.id, dec.id);
                        prop_assert_eq!(orig.object_type, dec.object_type);
                        prop_assert_eq!(orig.data.as_ref(), dec.data.as_ref());
                    }
                }
                _ => prop_assert!(false, "wrong message type"),
            }
        }

        /// Every strict prefix of an announcement payload is rejected.
        #[test]
        fn prop_truncated_repo_announce_fails(
            repo_key in repo_key_strategy(),
            truncate_at in 0usize..50
        ) {
            let msg = RepoAnnounce {
                repo_key,
                object_ids: vec![ObjectId::from_bytes([1u8; 20])],
                refs: vec![],
            };

            let encoded = msg.encode();
            if truncate_at < encoded.len() {
                // Skip the tag byte and cut the payload short. Because
                // `truncate_at < encoded.len()`, at least one payload byte is
                // always missing.
                let truncated = &encoded[1..truncate_at.max(1)];
                let result = RepoAnnounce::decode(truncated);
                // The decoder consumes the payload field by field and checks
                // the remaining length before each read, so a strict prefix
                // must fail one of those checks.
                prop_assert!(result.is_err());
            }
        }

        /// Every tag byte above 4 is rejected.
        #[test]
        fn prop_invalid_message_type_fails(msg_type in 5u8..=255) {
            let result = MessageType::from_byte(msg_type);
            prop_assert!(result.is_err());
        }

        /// Decoding an encoded message and encoding it again reproduces the
        /// original bytes.
        #[test]
        fn prop_message_encode_decode_identity(
            msg_type in 1u8..=4,
            repo_key in repo_key_strategy()
        ) {
            let msg = match msg_type {
                1 => Message::RepoAnnounce(RepoAnnounce {
                    repo_key: repo_key.clone(),
                    object_ids: vec![],
                    refs: vec![],
                }),
                2 => Message::SyncRequest(SyncRequest {
                    repo_key: repo_key.clone(),
                    want: vec![],
                }),
                3 => Message::ObjectData(ObjectData {
                    repo_key: repo_key.clone(),
                    objects: vec![],
                }),
                4 => Message::RefUpdate(RefUpdate {
                    repo_key: repo_key.clone(),
                    ref_name: "refs/heads/main".to_string(),
                    old_id: ObjectId::from_bytes([0u8; 20]),
                    new_id: ObjectId::from_bytes([1u8; 20]),
                }),
                _ => unreachable!(),
            };

            let encoded = msg.encode();
            let decoded = Message::decode(&encoded);
            prop_assert!(decoded.is_ok());
            // `Message` has no `PartialEq`, so compare through the wire form.
            prop_assert_eq!(decoded.unwrap().encode(), encoded);
        }
    }
}