1use std::io::{self, Read, Write};
10
11use chrono::{DateTime, Utc};
12use serde::{Deserialize, Serialize};
13
14use moloch_core::block::{Block, BlockHash, BlockHeader};
15use moloch_core::crypto::{Hash, PublicKey, Sig};
16use moloch_core::event::{AuditEvent, EventId};
17
/// Semantic version of the peer-to-peer wire protocol.
///
/// Two versions are compatible when their `major` components match
/// (see `ProtocolVersion::is_compatible_with`). Field order is part of the
/// serialized form, so do not reorder fields.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProtocolVersion {
    /// Major version; peers must agree on it to be compatible.
    pub major: u16,
    /// Minor version; differences do not affect compatibility.
    pub minor: u16,
    /// Patch version; differences do not affect compatibility.
    pub patch: u16,
}
25
26impl ProtocolVersion {
27 pub const CURRENT: Self = Self {
29 major: 1,
30 minor: 0,
31 patch: 0,
32 };
33
34 pub fn is_compatible_with(&self, other: &Self) -> bool {
36 self.major == other.major
38 }
39
40 pub fn new(major: u16, minor: u16, patch: u16) -> Self {
42 Self {
43 major,
44 minor,
45 patch,
46 }
47 }
48}
49
50impl std::fmt::Display for ProtocolVersion {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 write!(f, "{}.{}.{}", self.major, self.minor, self.patch)
53 }
54}
55
56impl Default for ProtocolVersion {
57 fn default() -> Self {
58 Self::CURRENT
59 }
60}
61
/// Identifier used to correlate a request with its response.
///
/// Requests carry it in their `id` field; responses echo it back in
/// `request_id`. [`generate_message_id`] yields sequential values from a
/// process-wide counter.
pub type MessageId = u64;
64
/// Identity of a remote peer, wrapping its public node key.
///
/// Equality and hashing are derived from `key`, so a `PeerId` can be used as a
/// map or set key.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct PeerId {
    /// The peer's public node key.
    pub key: PublicKey,
}
71
72impl PeerId {
73 pub fn new(key: PublicKey) -> Self {
75 Self { key }
76 }
77
78 pub fn id(&self) -> Hash {
80 self.key.id()
81 }
82}
83
84impl std::fmt::Display for PeerId {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 write!(f, "{}", hex::encode(&self.key.as_bytes()[..8]))
87 }
88}
89
/// Every message that can be exchanged between peers.
///
/// Messages are framed and serialized by [`MessageCodec`] with bincode, which
/// encodes the variant by its position in this list. Reordering or removing
/// variants therefore changes the wire format; append new variants at the end.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Message {
    // --- Handshake and connection management ---
    /// Handshake carrying the sender's version, chain id, key and head.
    Hello(HelloMessage),

    /// Reply to a `Hello`, carrying the responder's own handshake data.
    HelloAck(HelloAckMessage),

    /// Unsolicited status update (height, head, peer count, sync state).
    Status(StatusMessage),

    /// Notice that the sender is closing the connection, with a reason.
    Goodbye(GoodbyeMessage),

    // --- Gossip ---
    /// A single audit event being propagated.
    NewEvent(NewEventMessage),

    /// A full block being propagated.
    NewBlock(NewBlockMessage),

    /// Lightweight announcement of available blocks, events or chain tip.
    Announce(AnnounceMessage),

    // --- Sync requests and responses ---
    /// Request for a range of blocks.
    GetBlocks(GetBlocksMessage),

    /// Response to `GetBlocks`.
    Blocks(BlocksMessage),

    /// Request for a range of block headers.
    GetHeaders(GetHeadersMessage),

    /// Response to `GetHeaders`.
    Headers(HeadersMessage),

    /// Request for specific events by id.
    GetEvents(GetEventsMessage),

    /// Response to `GetEvents`.
    Events(EventsMessage),

    /// Request for a state snapshot.
    GetSnapshot(GetSnapshotMessage),

    /// Response to `GetSnapshot`.
    Snapshot(SnapshotMessage),

    // --- Consensus ---
    /// A signed block proposal.
    Proposal(ProposalMessage),

    /// A signed vote on a block.
    Vote(VoteMessage),

    /// Request for the votes collected on a block.
    GetVotes(GetVotesMessage),

    /// Response to `GetVotes`.
    Votes(VotesMessage),

    // --- Liveness ---
    /// Liveness probe.
    Ping(PingMessage),

    /// Response to `Ping`.
    Pong(PongMessage),
}
161
162impl Message {
163 pub fn type_name(&self) -> &'static str {
165 match self {
166 Message::Hello(_) => "Hello",
167 Message::HelloAck(_) => "HelloAck",
168 Message::Status(_) => "Status",
169 Message::Goodbye(_) => "Goodbye",
170 Message::NewEvent(_) => "NewEvent",
171 Message::NewBlock(_) => "NewBlock",
172 Message::Announce(_) => "Announce",
173 Message::GetBlocks(_) => "GetBlocks",
174 Message::Blocks(_) => "Blocks",
175 Message::GetHeaders(_) => "GetHeaders",
176 Message::Headers(_) => "Headers",
177 Message::GetEvents(_) => "GetEvents",
178 Message::Events(_) => "Events",
179 Message::GetSnapshot(_) => "GetSnapshot",
180 Message::Snapshot(_) => "Snapshot",
181 Message::Proposal(_) => "Proposal",
182 Message::Vote(_) => "Vote",
183 Message::GetVotes(_) => "GetVotes",
184 Message::Votes(_) => "Votes",
185 Message::Ping(_) => "Ping",
186 Message::Pong(_) => "Pong",
187 }
188 }
189
190 pub fn is_request(&self) -> bool {
192 matches!(
193 self,
194 Message::Hello(_)
195 | Message::GetBlocks(_)
196 | Message::GetHeaders(_)
197 | Message::GetEvents(_)
198 | Message::GetSnapshot(_)
199 | Message::GetVotes(_)
200 | Message::Ping(_)
201 )
202 }
203
204 pub fn message_id(&self) -> Option<MessageId> {
206 match self {
207 Message::Hello(m) => Some(m.id),
208 Message::HelloAck(m) => Some(m.request_id),
209 Message::GetBlocks(m) => Some(m.id),
210 Message::Blocks(m) => Some(m.request_id),
211 Message::GetHeaders(m) => Some(m.id),
212 Message::Headers(m) => Some(m.request_id),
213 Message::GetEvents(m) => Some(m.id),
214 Message::Events(m) => Some(m.request_id),
215 Message::GetSnapshot(m) => Some(m.id),
216 Message::Snapshot(m) => Some(m.request_id),
217 Message::GetVotes(m) => Some(m.id),
218 Message::Votes(m) => Some(m.request_id),
219 Message::Ping(m) => Some(m.id),
220 Message::Pong(m) => Some(m.request_id),
221 _ => None,
222 }
223 }
224}
225
/// Handshake sent to a peer, describing the sender's protocol, chain and head.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HelloMessage {
    /// Request id; the matching `HelloAck` carries it back as `request_id`.
    pub id: MessageId,
    /// Protocol version spoken by the sender.
    pub version: ProtocolVersion,
    /// Identifier of the chain the sender is on.
    pub chain_id: String,
    /// The sender's public node key.
    pub node_key: PublicKey,
    /// Sender's current height; `None` presumably means it has no blocks yet.
    pub height: Option<u64>,
    /// Hash of the sender's head block, if it has one.
    pub head_hash: Option<BlockHash>,
    /// Time the message was created, serialized at millisecond precision.
    #[serde(with = "chrono::serde::ts_milliseconds")]
    pub timestamp: DateTime<Utc>,
    /// Signature over the handshake, presumably made with `node_key`.
    /// NOTE(review): the exact signed bytes are not defined here — confirm.
    pub signature: Sig,
}
249
/// Reply to a [`HelloMessage`], carrying the responder's own handshake data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HelloAckMessage {
    /// The `id` of the `Hello` being answered.
    pub request_id: MessageId,
    /// Protocol version spoken by the responder.
    pub version: ProtocolVersion,
    /// Identifier of the chain the responder is on.
    pub chain_id: String,
    /// The responder's public node key.
    pub node_key: PublicKey,
    /// Responder's current height, if any.
    pub height: Option<u64>,
    /// Hash of the responder's head block, if it has one.
    pub head_hash: Option<BlockHash>,
    /// Time the reply was created, serialized at millisecond precision.
    #[serde(with = "chrono::serde::ts_milliseconds")]
    pub timestamp: DateTime<Utc>,
    /// Signature over the reply, presumably made with `node_key`.
    /// NOTE(review): the exact signed bytes are not defined here — confirm.
    pub signature: Sig,
}
271
/// Unsolicited status update describing the sender's chain and sync state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StatusMessage {
    /// Sender's current height, if any.
    pub height: Option<u64>,
    /// Hash of the sender's head block, if it has one.
    pub head_hash: Option<BlockHash>,
    /// Number of peers the sender reports being connected to.
    pub peer_count: usize,
    /// Whether the sender reports that it is currently syncing.
    pub syncing: bool,
    /// Time the status was created, serialized at millisecond precision.
    #[serde(with = "chrono::serde::ts_milliseconds")]
    pub timestamp: DateTime<Utc>,
}
287
/// Notice that the sender is closing the connection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GoodbyeMessage {
    /// Machine-readable reason for disconnecting.
    pub reason: DisconnectReason,
    /// Optional free-form detail accompanying `reason`.
    pub message: Option<String>,
}
296
/// Reason given in a [`GoodbyeMessage`].
///
/// `rename_all` only affects self-describing serde formats; bincode (used by
/// `MessageCodec`) encodes the variant index, so keep variant order stable and
/// append new reasons at the end.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DisconnectReason {
    /// The sender is shutting down.
    Shutdown,
    /// The peers' protocol versions are incompatible.
    ProtocolMismatch,
    /// The peers are on different chains.
    ChainMismatch,
    /// The sender has too many connections.
    TooManyConnections,
    /// The remote peer misbehaved.
    Misbehavior,
    /// The remote peer timed out.
    Timeout,
    /// A connection to this peer already exists.
    DuplicateConnection,
    /// The disconnect was explicitly requested.
    Requested,
    /// Any other reason; details may be in `GoodbyeMessage::message`.
    Other,
}
320
/// Gossip of a single audit event.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NewEventMessage {
    /// The event being propagated.
    pub event: AuditEvent,
}
329
/// Gossip of a full block.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NewBlockMessage {
    /// The block being propagated.
    pub block: Block,
}
336
/// Announcement of data the sender has, without the data itself.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnnounceMessage {
    /// What is being announced.
    pub announcement: Announcement,
}
343
/// Payload of an [`AnnounceMessage`]; variant order is part of the wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Announcement {
    /// A block identified by its height and hash.
    Block { height: u64, hash: BlockHash },
    /// A set of events identified by id.
    Events { ids: Vec<EventId> },
    /// The sender's chain tip (head height and hash).
    ChainTip { height: u64, hash: BlockHash },
}
354
/// Request for a contiguous range of blocks; answered by [`BlocksMessage`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GetBlocksMessage {
    /// Request id, echoed back as `BlocksMessage::request_id`.
    pub id: MessageId,
    /// Height of the first block requested.
    pub start_height: u64,
    /// Number of blocks requested, starting at `start_height`.
    pub count: u32,
}
367
/// Response to a [`GetBlocksMessage`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BlocksMessage {
    /// The `id` of the request being answered.
    pub request_id: MessageId,
    /// Blocks returned for the requested range.
    pub blocks: Vec<Block>,
    /// Set by the responder to signal that more blocks remain beyond this batch.
    pub has_more: bool,
}
378
/// Request for a contiguous range of block headers; answered by [`HeadersMessage`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GetHeadersMessage {
    /// Request id, echoed back as `HeadersMessage::request_id`.
    pub id: MessageId,
    /// Height of the first header requested.
    pub start_height: u64,
    /// Number of headers requested, starting at `start_height`.
    pub count: u32,
}
389
/// Response to a [`GetHeadersMessage`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HeadersMessage {
    /// The `id` of the request being answered.
    pub request_id: MessageId,
    /// Headers returned for the requested range.
    pub headers: Vec<BlockHeader>,
    /// Set by the responder to signal that more headers remain beyond this batch.
    pub has_more: bool,
}
400
/// Request for specific events by id; answered by [`EventsMessage`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GetEventsMessage {
    /// Request id, echoed back as `EventsMessage::request_id`.
    pub id: MessageId,
    /// Ids of the events being requested.
    pub event_ids: Vec<EventId>,
}
409
/// Response to a [`GetEventsMessage`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EventsMessage {
    /// The `id` of the request being answered.
    pub request_id: MessageId,
    /// Events the responder was able to supply.
    pub events: Vec<AuditEvent>,
    /// Requested ids the responder reports it could not supply.
    pub not_found: Vec<EventId>,
}
420
/// Request for a state snapshot; answered by [`SnapshotMessage`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GetSnapshotMessage {
    /// Request id, echoed back as `SnapshotMessage::request_id`.
    pub id: MessageId,
    /// Height to snapshot at; `None` presumably asks for the responder's
    /// latest state — TODO confirm with the handler.
    pub height: Option<u64>,
}
429
/// Response to a [`GetSnapshotMessage`], summarizing chain state at `height`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SnapshotMessage {
    /// The `id` of the request being answered.
    pub request_id: MessageId,
    /// Height the snapshot describes.
    pub height: u64,
    /// Hash of the head block at `height`.
    pub head_hash: BlockHash,
    /// MMR root at `height` (presumably a Merkle Mountain Range over events —
    /// confirm against the storage layer).
    pub mmr_root: Hash,
    /// Number of blocks covered by the snapshot.
    pub block_count: u64,
    /// Number of events covered by the snapshot.
    pub event_count: u64,
    /// Validator public keys at `height`.
    pub validators: Vec<PublicKey>,
}
448
/// A block proposed for consensus.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProposalMessage {
    /// The proposed block.
    pub block: Block,
    /// Signature over the proposal, presumably by the proposer.
    /// NOTE(review): signer and signed bytes are not defined here — confirm.
    pub signature: Sig,
}
459
/// A single validator's vote on a block.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VoteMessage {
    /// Hash of the block being voted on.
    pub block_hash: BlockHash,
    /// Height of the block being voted on.
    pub height: u64,
    /// Public key of the voter.
    pub voter: PublicKey,
    /// Signature over the vote, presumably made with `voter`'s key.
    /// NOTE(review): the exact signed bytes are not defined here — confirm.
    pub signature: Sig,
}
472
/// Request for the votes on a block; answered by [`VotesMessage`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GetVotesMessage {
    /// Request id, echoed back as `VotesMessage::request_id`.
    pub id: MessageId,
    /// Hash of the block whose votes are requested.
    pub block_hash: BlockHash,
}
481
/// Response to a [`GetVotesMessage`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VotesMessage {
    /// The `id` of the request being answered.
    pub request_id: MessageId,
    /// Hash of the block the votes refer to.
    pub block_hash: BlockHash,
    /// Votes the responder holds for `block_hash`.
    pub votes: Vec<VoteMessage>,
}
492
/// Liveness probe; answered by [`PongMessage`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PingMessage {
    /// Request id, echoed back as `PongMessage::request_id`.
    pub id: MessageId,
    /// Time the ping was sent, serialized at millisecond precision.
    #[serde(with = "chrono::serde::ts_milliseconds")]
    pub timestamp: DateTime<Utc>,
}
504
/// Response to a [`PingMessage`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PongMessage {
    /// The `id` of the ping being answered.
    pub request_id: MessageId,
    /// The ping's `timestamp`, copied back to the sender (millisecond precision).
    #[serde(with = "chrono::serde::ts_milliseconds")]
    pub ping_timestamp: DateTime<Utc>,
    /// Time the pong was created (millisecond precision).
    #[serde(with = "chrono::serde::ts_milliseconds")]
    pub pong_timestamp: DateTime<Utc>,
}
517
/// Length-prefixed bincode framing for [`Message`]s.
///
/// Each frame is a 4-byte big-endian payload length followed by the
/// bincode-serialized payload.
#[derive(Debug, Clone)]
pub struct MessageCodec {
    /// Largest payload, in bytes, the codec will encode or accept.
    max_size: usize,
}
524
525impl MessageCodec {
526 pub const DEFAULT_MAX_SIZE: usize = 16 * 1024 * 1024;
528
529 pub fn new() -> Self {
531 Self {
532 max_size: Self::DEFAULT_MAX_SIZE,
533 }
534 }
535
536 pub fn with_max_size(max_size: usize) -> Self {
538 Self { max_size }
539 }
540
541 pub fn encode(&self, message: &Message) -> Result<Vec<u8>, CodecError> {
543 let payload = bincode::serialize(message)?;
544
545 if payload.len() > self.max_size {
546 return Err(CodecError::MessageTooLarge {
547 size: payload.len(),
548 max: self.max_size,
549 });
550 }
551
552 let mut frame = Vec::with_capacity(4 + payload.len());
554 frame.extend_from_slice(&(payload.len() as u32).to_be_bytes());
555 frame.extend_from_slice(&payload);
556
557 Ok(frame)
558 }
559
560 pub fn decode(&self, data: &[u8]) -> Result<Message, CodecError> {
562 if data.len() < 4 {
563 return Err(CodecError::IncompletFrame);
564 }
565
566 let length = u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize;
567
568 if length > self.max_size {
569 return Err(CodecError::MessageTooLarge {
570 size: length,
571 max: self.max_size,
572 });
573 }
574
575 if data.len() < 4 + length {
576 return Err(CodecError::IncompletFrame);
577 }
578
579 let message = bincode::deserialize(&data[4..4 + length])?;
580 Ok(message)
581 }
582
583 pub fn read_message<R: Read>(&self, reader: &mut R) -> Result<Message, CodecError> {
585 let mut len_buf = [0u8; 4];
587 reader.read_exact(&mut len_buf)?;
588 let length = u32::from_be_bytes(len_buf) as usize;
589
590 if length > self.max_size {
591 return Err(CodecError::MessageTooLarge {
592 size: length,
593 max: self.max_size,
594 });
595 }
596
597 let mut payload = vec![0u8; length];
599 reader.read_exact(&mut payload)?;
600
601 let message = bincode::deserialize(&payload)?;
602 Ok(message)
603 }
604
605 pub fn write_message<W: Write>(
607 &self,
608 writer: &mut W,
609 message: &Message,
610 ) -> Result<(), CodecError> {
611 let frame = self.encode(message)?;
612 writer.write_all(&frame)?;
613 Ok(())
614 }
615}
616
617impl Default for MessageCodec {
618 fn default() -> Self {
619 Self::new()
620 }
621}
622
/// Errors produced by [`MessageCodec`] while framing, encoding or decoding messages.
#[derive(Debug, thiserror::Error)]
pub enum CodecError {
    /// An outgoing payload or an incoming length prefix exceeds the size limit.
    #[error("message too large: {size} bytes exceeds limit of {max} bytes")]
    MessageTooLarge { size: usize, max: usize },

    /// `decode` was given fewer bytes than the header or declared payload needs.
    ///
    /// NOTE: the variant name is misspelled ("Incomplet"). It is public and
    /// matched by callers, so a rename should go through a deprecation cycle.
    #[error("incomplete frame")]
    IncompletFrame,

    /// bincode failed to serialize or deserialize a message.
    #[error("serialization error: {0}")]
    Serialization(#[from] bincode::Error),

    /// Reading from or writing to the underlying stream failed.
    #[error("I/O error: {0}")]
    Io(#[from] io::Error),
}
638
639pub fn generate_message_id() -> MessageId {
641 use std::sync::atomic::{AtomicU64, Ordering};
642 static COUNTER: AtomicU64 = AtomicU64::new(0);
643 COUNTER.fetch_add(1, Ordering::SeqCst)
644}
645
#[cfg(test)]
mod tests {
    use super::*;
    use moloch_core::crypto::SecretKey;
    use std::io::{Cursor, ErrorKind};

    /// Builds a small `Ping` message used by the framing tests.
    fn ping(id: MessageId) -> Message {
        Message::Ping(PingMessage {
            id,
            timestamp: Utc::now(),
        })
    }

    #[test]
    fn test_protocol_version_compatibility() {
        let v1 = ProtocolVersion::new(1, 0, 0);
        let v1_1 = ProtocolVersion::new(1, 1, 0);
        let v2 = ProtocolVersion::new(2, 0, 0);

        assert!(v1.is_compatible_with(&v1_1));
        assert!(v1_1.is_compatible_with(&v1));
        assert!(!v1.is_compatible_with(&v2));
    }

    #[test]
    fn test_protocol_version_display() {
        let v = ProtocolVersion::new(1, 2, 3);
        assert_eq!(format!("{}", v), "1.2.3");
    }

    #[test]
    fn test_peer_id() {
        let key = SecretKey::generate();
        let peer_id = PeerId::new(key.public_key());

        // The derived id is deterministic for a given key.
        let id1 = peer_id.id();
        let id2 = peer_id.id();
        assert_eq!(id1, id2);

        // Display shows the first 8 key bytes as 16 hex characters.
        let display = format!("{}", peer_id);
        assert_eq!(display.len(), 16);
        assert!(display.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn test_message_type_names() {
        let key = SecretKey::generate();
        let hello = Message::Hello(HelloMessage {
            id: 1,
            version: ProtocolVersion::CURRENT,
            chain_id: "test".into(),
            node_key: key.public_key(),
            height: Some(100),
            head_hash: None,
            timestamp: Utc::now(),
            signature: key.sign(b"hello"),
        });

        assert_eq!(hello.type_name(), "Hello");
        assert!(hello.is_request());
        assert_eq!(hello.message_id(), Some(1));
    }

    #[test]
    fn test_unsolicited_message_has_no_id() {
        let status = Message::Status(StatusMessage {
            height: None,
            head_hash: None,
            peer_count: 0,
            syncing: true,
            timestamp: Utc::now(),
        });

        assert_eq!(status.type_name(), "Status");
        assert!(!status.is_request());
        assert_eq!(status.message_id(), None);
    }

    #[test]
    fn test_response_carries_request_id() {
        let pong = Message::Pong(PongMessage {
            request_id: 9,
            ping_timestamp: Utc::now(),
            pong_timestamp: Utc::now(),
        });

        assert_eq!(pong.type_name(), "Pong");
        assert!(!pong.is_request());
        assert_eq!(pong.message_id(), Some(9));
    }

    #[test]
    fn test_message_codec_roundtrip() {
        let codec = MessageCodec::new();

        let original = Message::Status(StatusMessage {
            height: Some(50),
            head_hash: None,
            peer_count: 5,
            syncing: false,
            timestamp: Utc::now(),
        });

        let encoded = codec.encode(&original).unwrap();
        let decoded = codec.decode(&encoded).unwrap();

        match (&original, &decoded) {
            (Message::Status(orig), Message::Status(dec)) => {
                assert_eq!(orig.height, dec.height);
                assert_eq!(orig.peer_count, dec.peer_count);
                assert_eq!(orig.syncing, dec.syncing);
            }
            _ => panic!("message type mismatch"),
        }
    }

    #[test]
    fn test_message_codec_size_limit() {
        let codec = MessageCodec::with_max_size(100);

        let large_message = Message::Goodbye(GoodbyeMessage {
            reason: DisconnectReason::Other,
            message: Some("x".repeat(200)),
        });

        let result = codec.encode(&large_message);
        assert!(matches!(result, Err(CodecError::MessageTooLarge { .. })));
    }

    #[test]
    fn test_message_codec_incomplete_frame() {
        let codec = MessageCodec::new();
        let result = codec.decode(&[0, 0, 0]);
        assert!(matches!(result, Err(CodecError::IncompletFrame)));
    }

    #[test]
    fn test_message_codec_truncated_payload() {
        let codec = MessageCodec::new();
        // The header declares 10 payload bytes but only 2 follow.
        let result = codec.decode(&[0, 0, 0, 10, 1, 2]);
        assert!(matches!(result, Err(CodecError::IncompletFrame)));
    }

    #[test]
    fn test_message_codec_rejects_oversized_length_prefix() {
        let codec = MessageCodec::with_max_size(8);
        let frame = [0u8, 0, 0, 9];

        assert!(matches!(
            codec.decode(&frame),
            Err(CodecError::MessageTooLarge { size: 9, max: 8 })
        ));
        assert!(matches!(
            codec.read_message(&mut Cursor::new(frame.to_vec())),
            Err(CodecError::MessageTooLarge { size: 9, max: 8 })
        ));
    }

    #[test]
    fn test_write_then_read_message_roundtrip() {
        let codec = MessageCodec::new();
        let original = ping(7);

        let mut single = Vec::new();
        codec.write_message(&mut single, &original).unwrap();
        // `write_message` emits exactly the frame `encode` produces.
        assert_eq!(single, codec.encode(&original).unwrap());

        // Back-to-back frames on one stream are read one at a time.
        let mut wire = Vec::new();
        codec.write_message(&mut wire, &ping(1)).unwrap();
        codec.write_message(&mut wire, &ping(2)).unwrap();
        let mut cursor = Cursor::new(wire);
        for expected in 1..=2u64 {
            match codec.read_message(&mut cursor).unwrap() {
                Message::Ping(p) => assert_eq!(p.id, expected),
                other => panic!("wrong message type: {}", other.type_name()),
            }
        }

        // The stream is exhausted now, so another read fails with an I/O error.
        assert!(matches!(
            codec.read_message(&mut cursor),
            Err(CodecError::Io(_))
        ));
    }

    #[test]
    fn test_read_message_truncated_payload() {
        let codec = MessageCodec::new();
        let mut frame = codec.encode(&ping(1)).unwrap();
        frame.pop();

        let result = codec.read_message(&mut Cursor::new(frame));
        assert!(matches!(
            result,
            Err(CodecError::Io(ref e)) if e.kind() == ErrorKind::UnexpectedEof
        ));
    }

    #[test]
    fn test_ping_pong_messages() {
        let ping = PingMessage {
            id: 42,
            timestamp: Utc::now(),
        };

        let pong = PongMessage {
            request_id: 42,
            ping_timestamp: ping.timestamp,
            pong_timestamp: Utc::now(),
        };

        assert_eq!(pong.request_id, ping.id);
    }

    #[test]
    fn test_disconnect_reasons() {
        let reasons = vec![
            DisconnectReason::Shutdown,
            DisconnectReason::ProtocolMismatch,
            DisconnectReason::ChainMismatch,
            DisconnectReason::TooManyConnections,
            DisconnectReason::Misbehavior,
            DisconnectReason::Timeout,
            DisconnectReason::DuplicateConnection,
            DisconnectReason::Requested,
            DisconnectReason::Other,
        ];

        let codec = MessageCodec::new();

        for reason in reasons {
            let msg = Message::Goodbye(GoodbyeMessage {
                reason,
                message: None,
            });

            let encoded = codec.encode(&msg).unwrap();
            let decoded = codec.decode(&encoded).unwrap();

            match decoded {
                Message::Goodbye(g) => assert_eq!(g.reason, reason),
                _ => panic!("wrong message type"),
            }
        }
    }

    #[test]
    fn test_get_blocks_message() {
        let msg = GetBlocksMessage {
            id: generate_message_id(),
            start_height: 100,
            count: 50,
        };

        let codec = MessageCodec::new();
        let encoded = codec.encode(&Message::GetBlocks(msg.clone())).unwrap();
        let decoded = codec.decode(&encoded).unwrap();

        match decoded {
            Message::GetBlocks(m) => {
                assert_eq!(m.id, msg.id);
                assert_eq!(m.start_height, 100);
                assert_eq!(m.count, 50);
            }
            _ => panic!("wrong message type"),
        }
    }

    #[test]
    fn test_announcement_variants() {
        use moloch_core::crypto::hash;

        let announcements = vec![
            Announcement::Block {
                height: 100,
                hash: moloch_core::block::BlockHash(hash(b"block")),
            },
            Announcement::Events {
                ids: vec![moloch_core::event::EventId(hash(b"event1"))],
            },
            Announcement::ChainTip {
                height: 200,
                hash: moloch_core::block::BlockHash(hash(b"tip")),
            },
        ];

        let codec = MessageCodec::new();

        for ann in announcements {
            let msg = Message::Announce(AnnounceMessage { announcement: ann });
            let encoded = codec.encode(&msg).unwrap();
            let decoded = codec.decode(&encoded).unwrap();
            assert!(matches!(decoded, Message::Announce(_)));
        }
    }

    #[test]
    fn test_generate_message_id_unique() {
        let id1 = generate_message_id();
        let id2 = generate_message_id();
        let id3 = generate_message_id();

        assert_ne!(id1, id2);
        assert_ne!(id2, id3);
        assert_ne!(id1, id3);
    }
}