use std::io::{self, Read, Write};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use moloch_core::block::{Block, BlockHash, BlockHeader};
use moloch_core::crypto::{Hash, PublicKey, Sig};
use moloch_core::event::{AuditEvent, EventId};
/// Three-part protocol version (`major.minor.patch`).
///
/// Carried in the `version` field of [`HelloMessage`] and [`HelloAckMessage`].
/// Only `major` takes part in the compatibility check performed by
/// `ProtocolVersion::is_compatible_with`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProtocolVersion {
/// Major component; two versions are treated as compatible when these match.
pub major: u16,
/// Minor component; not consulted by the compatibility check.
pub minor: u16,
/// Patch component; not consulted by the compatibility check.
pub patch: u16,
}
impl ProtocolVersion {
    /// The protocol version this build reports: `1.0.0`.
    pub const CURRENT: Self = Self {
        major: 1,
        minor: 0,
        patch: 0,
    };

    /// Returns `true` when `self` and `other` share the same major version.
    ///
    /// Minor and patch components are ignored. The relation is symmetric.
    /// Declared `const` so compatibility can be evaluated in const contexts.
    pub const fn is_compatible_with(&self, other: &Self) -> bool {
        self.major == other.major
    }

    /// Builds a version from its three components.
    ///
    /// Declared `const` so versions can be used to initialise `const` and
    /// `static` items, just like [`ProtocolVersion::CURRENT`].
    pub const fn new(major: u16, minor: u16, patch: u16) -> Self {
        Self {
            major,
            minor,
            patch,
        }
    }
}
/// Renders the version as dot-separated decimal components, e.g. `1.2.3`.
impl std::fmt::Display for ProtocolVersion {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructure once so the format call reads like the output shape.
        let Self {
            major,
            minor,
            patch,
        } = self;
        write!(f, "{}.{}.{}", major, minor, patch)
    }
}
impl Default for ProtocolVersion {
fn default() -> Self {
Self::CURRENT
}
}
/// Identifier used to pair a request with its response.
///
/// Requests carry it in an `id` field and responses echo it in `request_id`
/// (see `Message::message_id`). Fresh values come from `generate_message_id`.
pub type MessageId = u64;
/// Identity of a peer, represented by its public key.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct PeerId {
/// The peer's public key.
pub key: PublicKey,
}
impl PeerId {
pub fn new(key: PublicKey) -> Self {
Self { key }
}
pub fn id(&self) -> Hash {
self.key.id()
}
}
/// Shows a short form of the peer: the hex encoding of the first 8 bytes of
/// the key (16 hex characters).
impl std::fmt::Display for PeerId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Slicing panics if the key is shorter than 8 bytes, same as before.
        let short = hex::encode(&self.key.as_bytes()[..8]);
        f.write_str(&short)
    }
}
/// Every message that can be framed by [`MessageCodec`].
///
/// Each variant wraps a dedicated payload struct of the same name with a
/// `Message` suffix. The codec serializes this enum with bincode, which
/// identifies a variant by its position in this list, so reordering, inserting
/// or removing variants changes the wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Message {
/// Handshake request; see [`HelloMessage`].
Hello(HelloMessage),
/// Handshake response; see [`HelloAckMessage`].
HelloAck(HelloAckMessage),
/// Node status report; see [`StatusMessage`].
Status(StatusMessage),
/// Disconnect notice; see [`GoodbyeMessage`].
Goodbye(GoodbyeMessage),
/// A single audit event; see [`NewEventMessage`].
NewEvent(NewEventMessage),
/// A single full block; see [`NewBlockMessage`].
NewBlock(NewBlockMessage),
/// Lightweight announcement; see [`AnnounceMessage`].
Announce(AnnounceMessage),
/// Request for a range of blocks; see [`GetBlocksMessage`].
GetBlocks(GetBlocksMessage),
/// Response to `GetBlocks`; see [`BlocksMessage`].
Blocks(BlocksMessage),
/// Request for a range of headers; see [`GetHeadersMessage`].
GetHeaders(GetHeadersMessage),
/// Response to `GetHeaders`; see [`HeadersMessage`].
Headers(HeadersMessage),
/// Request for events by id; see [`GetEventsMessage`].
GetEvents(GetEventsMessage),
/// Response to `GetEvents`; see [`EventsMessage`].
Events(EventsMessage),
/// Request for a chain snapshot summary; see [`GetSnapshotMessage`].
GetSnapshot(GetSnapshotMessage),
/// Response to `GetSnapshot`; see [`SnapshotMessage`].
Snapshot(SnapshotMessage),
/// Signed block proposal; see [`ProposalMessage`].
Proposal(ProposalMessage),
/// Signed vote on a block; see [`VoteMessage`].
Vote(VoteMessage),
/// Request for the votes on a block; see [`GetVotesMessage`].
GetVotes(GetVotesMessage),
/// Response to `GetVotes`; see [`VotesMessage`].
Votes(VotesMessage),
/// Ping request; see [`PingMessage`].
Ping(PingMessage),
/// Ping response; see [`PongMessage`].
Pong(PongMessage),
}
impl Message {
pub fn type_name(&self) -> &'static str {
match self {
Message::Hello(_) => "Hello",
Message::HelloAck(_) => "HelloAck",
Message::Status(_) => "Status",
Message::Goodbye(_) => "Goodbye",
Message::NewEvent(_) => "NewEvent",
Message::NewBlock(_) => "NewBlock",
Message::Announce(_) => "Announce",
Message::GetBlocks(_) => "GetBlocks",
Message::Blocks(_) => "Blocks",
Message::GetHeaders(_) => "GetHeaders",
Message::Headers(_) => "Headers",
Message::GetEvents(_) => "GetEvents",
Message::Events(_) => "Events",
Message::GetSnapshot(_) => "GetSnapshot",
Message::Snapshot(_) => "Snapshot",
Message::Proposal(_) => "Proposal",
Message::Vote(_) => "Vote",
Message::GetVotes(_) => "GetVotes",
Message::Votes(_) => "Votes",
Message::Ping(_) => "Ping",
Message::Pong(_) => "Pong",
}
}
pub fn is_request(&self) -> bool {
matches!(
self,
Message::Hello(_)
| Message::GetBlocks(_)
| Message::GetHeaders(_)
| Message::GetEvents(_)
| Message::GetSnapshot(_)
| Message::GetVotes(_)
| Message::Ping(_)
)
}
pub fn message_id(&self) -> Option<MessageId> {
match self {
Message::Hello(m) => Some(m.id),
Message::HelloAck(m) => Some(m.request_id),
Message::GetBlocks(m) => Some(m.id),
Message::Blocks(m) => Some(m.request_id),
Message::GetHeaders(m) => Some(m.id),
Message::Headers(m) => Some(m.request_id),
Message::GetEvents(m) => Some(m.id),
Message::Events(m) => Some(m.request_id),
Message::GetSnapshot(m) => Some(m.id),
Message::Snapshot(m) => Some(m.request_id),
Message::GetVotes(m) => Some(m.id),
Message::Votes(m) => Some(m.request_id),
Message::Ping(m) => Some(m.id),
Message::Pong(m) => Some(m.request_id),
_ => None,
}
}
}
/// Payload of [`Message::Hello`], the handshake request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HelloMessage {
/// Correlation id; echoed back as `HelloAckMessage::request_id`.
pub id: MessageId,
/// Protocol version the sender reports.
pub version: ProtocolVersion,
/// Identifier of the chain the sender is on.
pub chain_id: String,
/// Public key of the sending node.
pub node_key: PublicKey,
/// Sender's chain height; `None` presumably means an empty chain (confirm with caller).
pub height: Option<u64>,
/// Hash of the sender's head block, if any.
pub head_hash: Option<BlockHash>,
/// Send time; serialized as integer milliseconds, so finer precision is dropped on the wire.
#[serde(with = "chrono::serde::ts_milliseconds")]
pub timestamp: DateTime<Utc>,
/// Signature attached by the sender.
/// NOTE(review): the exact bytes that are signed are not defined in this file.
pub signature: Sig,
}
/// Payload of [`Message::HelloAck`], the handshake response.
///
/// Carries the same fields as [`HelloMessage`], except that the correlation id
/// is the id of the `Hello` being answered.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HelloAckMessage {
/// Id of the `Hello` request this message answers.
pub request_id: MessageId,
/// Protocol version the responder reports.
pub version: ProtocolVersion,
/// Identifier of the chain the responder is on.
pub chain_id: String,
/// Public key of the responding node.
pub node_key: PublicKey,
/// Responder's chain height; `None` presumably means an empty chain (confirm with caller).
pub height: Option<u64>,
/// Hash of the responder's head block, if any.
pub head_hash: Option<BlockHash>,
/// Send time; serialized as integer milliseconds, so finer precision is dropped on the wire.
#[serde(with = "chrono::serde::ts_milliseconds")]
pub timestamp: DateTime<Utc>,
/// Signature attached by the responder.
/// NOTE(review): the exact bytes that are signed are not defined in this file.
pub signature: Sig,
}
/// Payload of [`Message::Status`]: a one-way report of the sender's state.
///
/// It has no correlation id (`Message::message_id` returns `None` for it).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StatusMessage {
/// Sender's chain height, if it has one.
pub height: Option<u64>,
/// Hash of the sender's head block, if any.
pub head_hash: Option<BlockHash>,
/// Number of peers the sender reports.
pub peer_count: usize,
/// Whether the sender reports that it is syncing.
pub syncing: bool,
/// Report time; serialized as integer milliseconds, so finer precision is dropped on the wire.
#[serde(with = "chrono::serde::ts_milliseconds")]
pub timestamp: DateTime<Utc>,
}
/// Payload of [`Message::Goodbye`]: a disconnect notice.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GoodbyeMessage {
/// Machine-readable reason for the disconnect.
pub reason: DisconnectReason,
/// Optional free-form text accompanying the reason.
pub message: Option<String>,
}
/// Reason code carried by [`GoodbyeMessage`].
///
/// `rename_all = "snake_case"` makes serde expose the variants under
/// snake_case names (e.g. `protocol_mismatch`) to formats that encode variant
/// names. Variant order should be kept stable, since bincode (used by
/// [`MessageCodec`]) encodes a variant by its index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DisconnectReason {
/// The sender is shutting down.
Shutdown,
/// Protocol versions do not match.
ProtocolMismatch,
/// Chain identifiers do not match.
ChainMismatch,
/// The sender has too many connections.
TooManyConnections,
/// The peer misbehaved.
Misbehavior,
/// A timeout occurred.
Timeout,
/// A connection to this peer already exists.
DuplicateConnection,
/// The disconnect was requested.
Requested,
/// Any reason not covered by the variants above.
Other,
}
/// Payload of [`Message::NewEvent`]: one audit event sent without a correlation id.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NewEventMessage {
/// The event being sent.
pub event: AuditEvent,
}
/// Payload of [`Message::NewBlock`]: one full block sent without a correlation id.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NewBlockMessage {
/// The block being sent.
pub block: Block,
}
/// Payload of [`Message::Announce`]: wraps a single [`Announcement`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnnounceMessage {
/// What is being announced.
pub announcement: Announcement,
}
/// Content of an [`AnnounceMessage`]: identifiers only, never full blocks or events.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Announcement {
/// A block, identified by height and hash.
Block { height: u64, hash: BlockHash },
/// A set of events, identified by their ids.
Events { ids: Vec<EventId> },
/// The announcer's chain tip, identified by height and hash.
ChainTip { height: u64, hash: BlockHash },
}
/// Payload of [`Message::GetBlocks`]: request for a run of blocks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GetBlocksMessage {
/// Correlation id; echoed back as `BlocksMessage::request_id`.
pub id: MessageId,
/// Height at which the requested range starts (presumably inclusive; confirm with the handler).
pub start_height: u64,
/// Number of blocks requested (presumably an upper bound; confirm with the handler).
pub count: u32,
}
/// Payload of [`Message::Blocks`]: response to a `GetBlocks` request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BlocksMessage {
/// Id of the `GetBlocks` request this message answers.
pub request_id: MessageId,
/// The blocks being returned.
pub blocks: Vec<Block>,
/// Set by the responder; presumably signals that further blocks exist beyond this batch (confirm).
pub has_more: bool,
}
/// Payload of [`Message::GetHeaders`]: request for a run of block headers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GetHeadersMessage {
/// Correlation id; echoed back as `HeadersMessage::request_id`.
pub id: MessageId,
/// Height at which the requested range starts (presumably inclusive; confirm with the handler).
pub start_height: u64,
/// Number of headers requested (presumably an upper bound; confirm with the handler).
pub count: u32,
}
/// Payload of [`Message::Headers`]: response to a `GetHeaders` request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HeadersMessage {
/// Id of the `GetHeaders` request this message answers.
pub request_id: MessageId,
/// The headers being returned.
pub headers: Vec<BlockHeader>,
/// Set by the responder; presumably signals that further headers exist beyond this batch (confirm).
pub has_more: bool,
}
/// Payload of [`Message::GetEvents`]: request for specific events by id.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GetEventsMessage {
/// Correlation id; echoed back as `EventsMessage::request_id`.
pub id: MessageId,
/// Ids of the events being requested.
pub event_ids: Vec<EventId>,
}
/// Payload of [`Message::Events`]: response to a `GetEvents` request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EventsMessage {
/// Id of the `GetEvents` request this message answers.
pub request_id: MessageId,
/// The events being returned.
pub events: Vec<AuditEvent>,
/// Ids reported as not found (presumably the requested ids the responder lacks; confirm).
pub not_found: Vec<EventId>,
}
/// Payload of [`Message::GetSnapshot`]: request for a chain snapshot summary.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GetSnapshotMessage {
/// Correlation id; echoed back as `SnapshotMessage::request_id`.
pub id: MessageId,
/// Height to snapshot at; `None` presumably selects the responder's current head (confirm).
pub height: Option<u64>,
}
/// Payload of [`Message::Snapshot`]: response to a `GetSnapshot` request.
///
/// Holds summary values only; no block or event bodies are included.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SnapshotMessage {
/// Id of the `GetSnapshot` request this message answers.
pub request_id: MessageId,
/// Height the snapshot describes.
pub height: u64,
/// Hash of the head block at that height.
pub head_hash: BlockHash,
/// MMR root reported by the responder (presumably a Merkle Mountain Range root; confirm).
pub mmr_root: Hash,
/// Number of blocks reported by the responder.
pub block_count: u64,
/// Number of events reported by the responder.
pub event_count: u64,
/// Public keys of the validators reported by the responder.
pub validators: Vec<PublicKey>,
}
/// Payload of [`Message::Proposal`]: a block together with a signature.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProposalMessage {
/// The proposed block.
pub block: Block,
/// Signature accompanying the proposal.
/// NOTE(review): signer and signed bytes are not defined in this file.
pub signature: Sig,
}
/// Payload of [`Message::Vote`]; also the element type of `VotesMessage::votes`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VoteMessage {
/// Hash of the block being voted on.
pub block_hash: BlockHash,
/// Height associated with the vote.
pub height: u64,
/// Public key of the voter.
pub voter: PublicKey,
/// Signature accompanying the vote.
/// NOTE(review): the signed bytes are not defined in this file.
pub signature: Sig,
}
/// Payload of [`Message::GetVotes`]: request for the votes on one block.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GetVotesMessage {
/// Correlation id; echoed back as `VotesMessage::request_id`.
pub id: MessageId,
/// Hash of the block whose votes are requested.
pub block_hash: BlockHash,
}
/// Payload of [`Message::Votes`]: response to a `GetVotes` request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VotesMessage {
/// Id of the `GetVotes` request this message answers.
pub request_id: MessageId,
/// Hash of the block the votes refer to.
pub block_hash: BlockHash,
/// The votes being returned.
pub votes: Vec<VoteMessage>,
}
/// Payload of [`Message::Ping`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PingMessage {
/// Correlation id; echoed back as `PongMessage::request_id`.
pub id: MessageId,
/// Time recorded by the sender; serialized as integer milliseconds.
#[serde(with = "chrono::serde::ts_milliseconds")]
pub timestamp: DateTime<Utc>,
}
/// Payload of [`Message::Pong`]: response to a `Ping`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PongMessage {
/// Id of the `Ping` this message answers.
pub request_id: MessageId,
/// Timestamp of the ping being answered (presumably copied from `PingMessage::timestamp`; confirm with the handler).
/// Serialized as integer milliseconds.
#[serde(with = "chrono::serde::ts_milliseconds")]
pub ping_timestamp: DateTime<Utc>,
/// Time recorded by the responder; serialized as integer milliseconds.
#[serde(with = "chrono::serde::ts_milliseconds")]
pub pong_timestamp: DateTime<Utc>,
}
/// Encodes and decodes [`Message`] values as length-prefixed frames.
///
/// A frame is a 4-byte big-endian payload length followed by the
/// bincode-serialized message.
#[derive(Debug, Clone)]
pub struct MessageCodec {
// Largest payload size, in bytes (excluding the 4-byte prefix), that the
// codec will encode or accept.
max_size: usize,
}
impl MessageCodec {
    /// Default payload size limit: 16 MiB.
    pub const DEFAULT_MAX_SIZE: usize = 16 * 1024 * 1024;

    /// Creates a codec using [`MessageCodec::DEFAULT_MAX_SIZE`] as the limit.
    pub fn new() -> Self {
        Self {
            max_size: Self::DEFAULT_MAX_SIZE,
        }
    }

    /// Creates a codec with a custom payload size limit, in bytes.
    ///
    /// The limit applies to the serialized payload only, not to the 4-byte
    /// length prefix. Because the prefix is a `u32`, payloads larger than
    /// `u32::MAX` bytes are rejected by [`MessageCodec::encode`] regardless of
    /// the limit given here.
    pub fn with_max_size(max_size: usize) -> Self {
        Self { max_size }
    }

    /// Serializes `message` into a complete frame: a 4-byte big-endian payload
    /// length followed by the bincode payload.
    ///
    /// # Errors
    ///
    /// * [`CodecError::Serialization`] if bincode fails to serialize.
    /// * [`CodecError::MessageTooLarge`] if the payload exceeds the configured
    ///   limit, or cannot be described by the 32-bit length prefix.
    pub fn encode(&self, message: &Message) -> Result<Vec<u8>, CodecError> {
        let payload = bincode::serialize(message)?;
        let size = payload.len();
        if size > self.max_size {
            return Err(CodecError::MessageTooLarge {
                size,
                max: self.max_size,
            });
        }
        // A plain `as u32` cast would silently truncate the prefix when the
        // configured limit permits payloads above `u32::MAX`, producing a
        // frame the peer cannot parse. Convert with a check instead.
        let prefix = <u32 as std::convert::TryFrom<usize>>::try_from(size).map_err(|_| {
            CodecError::MessageTooLarge {
                size,
                max: u32::MAX as usize,
            }
        })?;
        let mut frame = Vec::with_capacity(4 + size);
        frame.extend_from_slice(&prefix.to_be_bytes());
        frame.extend_from_slice(&payload);
        Ok(frame)
    }

    /// Decodes one frame from the start of `data`.
    ///
    /// Bytes after the first complete frame are ignored.
    ///
    /// # Errors
    ///
    /// * [`CodecError::IncompletFrame`] if `data` is shorter than the length
    ///   prefix, or shorter than the payload the prefix announces.
    /// * [`CodecError::MessageTooLarge`] if the announced length exceeds the
    ///   configured limit.
    /// * [`CodecError::Serialization`] if bincode cannot decode the payload.
    pub fn decode(&self, data: &[u8]) -> Result<Message, CodecError> {
        let (prefix, rest) = match data {
            [b0, b1, b2, b3, rest @ ..] => ([*b0, *b1, *b2, *b3], rest),
            _ => return Err(CodecError::IncompletFrame),
        };
        let length = u32::from_be_bytes(prefix) as usize;
        if length > self.max_size {
            return Err(CodecError::MessageTooLarge {
                size: length,
                max: self.max_size,
            });
        }
        // `get` performs the bounds check without computing `4 + length`,
        // which could overflow `usize` on 32-bit targets.
        let payload = rest.get(..length).ok_or(CodecError::IncompletFrame)?;
        let message = bincode::deserialize(payload)?;
        Ok(message)
    }

    /// Reads exactly one frame from `reader` and decodes it.
    ///
    /// The announced length is checked against the limit before the payload
    /// buffer is allocated, so a peer cannot force an allocation larger than
    /// the configured maximum.
    ///
    /// # Errors
    ///
    /// * [`CodecError::Io`] if the reader fails, including reaching end of
    ///   input before a full frame was read.
    /// * [`CodecError::MessageTooLarge`] if the announced length exceeds the
    ///   configured limit.
    /// * [`CodecError::Serialization`] if bincode cannot decode the payload.
    pub fn read_message<R: Read>(&self, reader: &mut R) -> Result<Message, CodecError> {
        let mut len_buf = [0u8; 4];
        reader.read_exact(&mut len_buf)?;
        let length = u32::from_be_bytes(len_buf) as usize;
        if length > self.max_size {
            return Err(CodecError::MessageTooLarge {
                size: length,
                max: self.max_size,
            });
        }
        let mut payload = vec![0u8; length];
        reader.read_exact(&mut payload)?;
        let message = bincode::deserialize(&payload)?;
        Ok(message)
    }

    /// Encodes `message` and writes the whole frame to `writer`.
    ///
    /// The writer is not flushed; callers using a buffered writer must flush
    /// it themselves.
    ///
    /// # Errors
    ///
    /// Any error from [`MessageCodec::encode`], or [`CodecError::Io`] if the
    /// write fails.
    pub fn write_message<W: Write>(
        &self,
        writer: &mut W,
        message: &Message,
    ) -> Result<(), CodecError> {
        let frame = self.encode(message)?;
        writer.write_all(&frame)?;
        Ok(())
    }
}
impl Default for MessageCodec {
fn default() -> Self {
Self::new()
}
}
/// Errors produced by [`MessageCodec`].
#[derive(Debug, thiserror::Error)]
pub enum CodecError {
/// The payload, or the length announced by a frame prefix, exceeds the limit.
/// `size` is the offending length and `max` the limit, both in bytes.
#[error("message too large: {size} bytes exceeds limit of {max} bytes")]
MessageTooLarge { size: usize, max: usize },
/// The input ended before a complete frame (prefix plus payload) was available.
// NOTE: the variant name is misspelled ("Incomplet"); it is part of the
// public API, so renaming it would break callers.
#[error("incomplete frame")]
IncompletFrame,
/// bincode failed to serialize or deserialize a message.
#[error("serialization error: {0}")]
Serialization(#[from] bincode::Error),
/// The underlying reader or writer reported an error.
#[error("I/O error: {0}")]
Io(#[from] io::Error),
}
/// Returns a fresh message id from a process-wide atomic counter.
///
/// The counter starts at 0 and each call returns the current value and then
/// advances it by one. The counter wraps around on overflow.
pub fn generate_message_id() -> MessageId {
    static NEXT_ID: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
    NEXT_ID.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
}
#[cfg(test)]
mod tests {
    use super::*;
    use moloch_core::crypto::SecretKey;

    /// Versions with the same major number are compatible in both directions.
    #[test]
    fn test_protocol_version_compatibility() {
        let v1 = ProtocolVersion::new(1, 0, 0);
        let v1_1 = ProtocolVersion::new(1, 1, 0);
        let v2 = ProtocolVersion::new(2, 0, 0);
        assert!(v1.is_compatible_with(&v1_1));
        assert!(v1_1.is_compatible_with(&v1));
        assert!(!v1.is_compatible_with(&v2));
    }

    #[test]
    fn test_protocol_version_display() {
        let v = ProtocolVersion::new(1, 2, 3);
        assert_eq!(format!("{}", v), "1.2.3");
    }

    #[test]
    fn test_peer_id() {
        let key = SecretKey::generate();
        let peer_id = PeerId::new(key.public_key());
        let id1 = peer_id.id();
        let id2 = peer_id.id();
        assert_eq!(id1, id2);
        // Display shows the first 8 key bytes as hex: 16 characters.
        let display = format!("{}", peer_id);
        assert_eq!(display.len(), 16);
    }

    #[test]
    fn test_message_type_names() {
        let key = SecretKey::generate();
        let hello = Message::Hello(HelloMessage {
            id: 1,
            version: ProtocolVersion::CURRENT,
            chain_id: "test".into(),
            node_key: key.public_key(),
            height: Some(100),
            head_hash: None,
            timestamp: Utc::now(),
            signature: key.sign(b"hello"),
        });
        assert_eq!(hello.type_name(), "Hello");
        assert!(hello.is_request());
        assert_eq!(hello.message_id(), Some(1));
    }

    #[test]
    fn test_message_codec_roundtrip() {
        let codec = MessageCodec::new();
        let original = Message::Status(StatusMessage {
            height: Some(50),
            head_hash: None,
            peer_count: 5,
            syncing: false,
            timestamp: Utc::now(),
        });
        let encoded = codec.encode(&original).unwrap();
        let decoded = codec.decode(&encoded).unwrap();
        match (&original, &decoded) {
            (Message::Status(orig), Message::Status(dec)) => {
                assert_eq!(orig.height, dec.height);
                assert_eq!(orig.peer_count, dec.peer_count);
                assert_eq!(orig.syncing, dec.syncing);
                // Timestamps travel as whole milliseconds, so compare at that
                // precision rather than comparing the `DateTime` values.
                assert_eq!(
                    orig.timestamp.timestamp_millis(),
                    dec.timestamp.timestamp_millis()
                );
            }
            _ => panic!("message type mismatch"),
        }
    }

    /// The stream API must produce and consume exactly one frame.
    #[test]
    fn test_message_codec_stream_roundtrip() {
        let codec = MessageCodec::new();
        let msg = Message::GetBlocks(GetBlocksMessage {
            id: 7,
            start_height: 10,
            count: 3,
        });
        let mut wire = Vec::new();
        codec.write_message(&mut wire, &msg).unwrap();
        assert_eq!(wire, codec.encode(&msg).unwrap());

        let mut reader = wire.as_slice();
        let decoded = codec.read_message(&mut reader).unwrap();
        assert!(reader.is_empty());
        match decoded {
            Message::GetBlocks(m) => {
                assert_eq!(m.id, 7);
                assert_eq!(m.start_height, 10);
                assert_eq!(m.count, 3);
            }
            _ => panic!("wrong message type"),
        }
    }

    #[test]
    fn test_message_codec_size_limit() {
        let codec = MessageCodec::with_max_size(100);
        let large_message = Message::Goodbye(GoodbyeMessage {
            reason: DisconnectReason::Other,
            message: Some("x".repeat(200)),
        });
        let result = codec.encode(&large_message);
        assert!(matches!(result, Err(CodecError::MessageTooLarge { .. })));
    }

    /// A length prefix above the limit is rejected before the payload is read.
    #[test]
    fn test_message_codec_oversized_prefix() {
        let codec = MessageCodec::with_max_size(8);
        let result = codec.decode(&[0, 0, 0, 9]);
        assert!(matches!(
            result,
            Err(CodecError::MessageTooLarge { size: 9, max: 8 })
        ));
    }

    #[test]
    fn test_message_codec_incomplete_frame() {
        let codec = MessageCodec::new();
        // Fewer bytes than the 4-byte length prefix.
        let result = codec.decode(&[0, 0, 0]);
        assert!(matches!(result, Err(CodecError::IncompletFrame)));
    }

    /// A frame whose payload is cut short is reported as incomplete.
    #[test]
    fn test_message_codec_truncated_payload() {
        let codec = MessageCodec::new();
        let mut encoded = codec
            .encode(&Message::Goodbye(GoodbyeMessage {
                reason: DisconnectReason::Shutdown,
                message: None,
            }))
            .unwrap();
        encoded.pop();
        let result = codec.decode(&encoded);
        assert!(matches!(result, Err(CodecError::IncompletFrame)));
    }

    #[test]
    fn test_ping_pong_messages() {
        let ping = PingMessage {
            id: 42,
            timestamp: Utc::now(),
        };
        let pong = PongMessage {
            request_id: 42,
            ping_timestamp: ping.timestamp,
            pong_timestamp: Utc::now(),
        };
        assert_eq!(pong.request_id, ping.id);
    }

    #[test]
    fn test_disconnect_reasons() {
        let reasons = vec![
            DisconnectReason::Shutdown,
            DisconnectReason::ProtocolMismatch,
            DisconnectReason::ChainMismatch,
            DisconnectReason::TooManyConnections,
            DisconnectReason::Misbehavior,
            DisconnectReason::Timeout,
            DisconnectReason::DuplicateConnection,
            DisconnectReason::Requested,
            DisconnectReason::Other,
        ];
        let codec = MessageCodec::new();
        for reason in reasons {
            let msg = Message::Goodbye(GoodbyeMessage {
                reason,
                message: None,
            });
            let encoded = codec.encode(&msg).unwrap();
            let decoded = codec.decode(&encoded).unwrap();
            match decoded {
                Message::Goodbye(g) => assert_eq!(g.reason, reason),
                _ => panic!("wrong message type"),
            }
        }
    }

    #[test]
    fn test_get_blocks_message() {
        let msg = GetBlocksMessage {
            id: generate_message_id(),
            start_height: 100,
            count: 50,
        };
        let codec = MessageCodec::new();
        let encoded = codec.encode(&Message::GetBlocks(msg.clone())).unwrap();
        let decoded = codec.decode(&encoded).unwrap();
        match decoded {
            Message::GetBlocks(m) => {
                // The correlation id must survive the roundtrip too.
                assert_eq!(m.id, msg.id);
                assert_eq!(m.start_height, 100);
                assert_eq!(m.count, 50);
            }
            _ => panic!("wrong message type"),
        }
    }

    #[test]
    fn test_announcement_variants() {
        use moloch_core::crypto::hash;
        let announcements = vec![
            Announcement::Block {
                height: 100,
                hash: moloch_core::block::BlockHash(hash(b"block")),
            },
            Announcement::Events {
                ids: vec![moloch_core::event::EventId(hash(b"event1"))],
            },
            Announcement::ChainTip {
                height: 200,
                hash: moloch_core::block::BlockHash(hash(b"tip")),
            },
        ];
        let codec = MessageCodec::new();
        for ann in announcements {
            let msg = Message::Announce(AnnounceMessage { announcement: ann });
            let encoded = codec.encode(&msg).unwrap();
            let decoded = codec.decode(&encoded).unwrap();
            assert!(matches!(decoded, Message::Announce(_)));
        }
    }

    #[test]
    fn test_generate_message_id_unique() {
        let id1 = generate_message_id();
        let id2 = generate_message_id();
        let id3 = generate_message_id();
        assert_ne!(id1, id2);
        assert_ne!(id2, id3);
        assert_ne!(id1, id3);
    }
}