use std::fmt;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};
use bytes::Bytes;
use serde::{Deserialize, Serialize};
/// Process-wide counter that supplies `MessageId::sequence`.
///
/// Starts at 0; `MessageId::new` takes the current value and increments it
/// atomically, so every ID created in this process gets a distinct sequence.
static MESSAGE_COUNTER: AtomicU64 = AtomicU64::new(0);
/// Identifier of a [`Message`].
///
/// `MessageId::new` fills `sequence` from the process-wide `MESSAGE_COUNTER`
/// and `timestamp` from the system clock. Deriving `Copy`, `Eq` and `Hash`
/// lets IDs be passed by value and used as hash-map/hash-set keys.
///
/// NOTE(review): positional encoders such as bincode (used in this module's
/// tests) serialize fields in declaration order, so reordering the fields
/// below changes the wire format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct MessageId {
/// Address the message originated from.
pub origin: SocketAddr,
/// Sequence number; `MessageId::new` takes it from `MESSAGE_COUNTER`.
pub sequence: u64,
/// Creation time in milliseconds since the Unix epoch, as set by `MessageId::new`.
pub timestamp: u64,
}
impl MessageId {
    /// Creates a fresh identifier for a message originating at `origin`.
    ///
    /// `sequence` is taken from the process-wide `MESSAGE_COUNTER`, so IDs
    /// created in the same process get distinct sequence numbers (until the
    /// `u64` counter wraps, which `fetch_add` does on overflow).
    ///
    /// `timestamp` is the wall-clock time in milliseconds since the Unix
    /// epoch. If the system clock reports a time before the epoch, the
    /// timestamp is `0` instead of panicking. If the millisecond count does
    /// not fit in a `u64`, it saturates at `u64::MAX`.
    pub fn new(origin: SocketAddr) -> Self {
        // Relaxed is sufficient: `fetch_add` is atomic under any ordering, and
        // uniqueness of the returned value is the only property needed here.
        let sequence = MESSAGE_COUNTER.fetch_add(1, Ordering::Relaxed);
        let timestamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            // Saturating arithmetic instead of `as_millis() as u64`, which
            // would silently truncate the u128 millisecond count.
            .map(|elapsed| {
                elapsed
                    .as_secs()
                    .saturating_mul(1_000)
                    .saturating_add(u64::from(elapsed.subsec_millis()))
            })
            // A clock set before 1970 is an environment problem, not a bug in
            // the caller; don't take the process down for it.
            .unwrap_or(0);
        Self {
            origin,
            sequence,
            timestamp,
        }
    }
}
impl fmt::Display for MessageId {
    /// Renders the ID as `origin:sequence:timestamp`,
    /// e.g. `127.0.0.1:8000:3:1700000000000`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            origin,
            sequence,
            timestamp,
        } = self;
        write!(f, "{origin}:{sequence}:{timestamp}")
    }
}
/// A [`Payload`] together with its identifier and time-to-live counter.
///
/// NOTE(review): `signature` exists only when the `crypto` feature is
/// enabled, so builds with and without that feature serialize `Message`
/// differently. With positional encoders such as bincode (used in this
/// module's tests), such builds cannot decode each other's messages. Field
/// order is likewise part of the wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
/// Identifier of this message; the constructors allocate it via `MessageId::new`.
pub id: MessageId,
/// Time-to-live counter. `Message::new` starts it at 10; `Message::decrement_ttl`
/// lowers it by one and never below 0.
pub ttl: u8,
/// Contents of the message.
pub payload: Payload,
/// Optional signature bytes (only with the `crypto` feature). The constructors
/// in this file leave it `None`; presumably it is filled in by a signing step
/// elsewhere — TODO confirm.
#[cfg(feature = "crypto")]
pub signature: Option<Vec<u8>>,
}
impl Message {
    /// TTL given to messages created with [`Message::new`].
    pub const DEFAULT_TTL: u8 = 10;

    /// Creates a message from `origin` carrying `payload`.
    ///
    /// A fresh [`MessageId`] is allocated and the TTL is set to
    /// [`Message::DEFAULT_TTL`].
    pub fn new(origin: SocketAddr, payload: Payload) -> Self {
        Self::with_ttl(origin, payload, Self::DEFAULT_TTL)
    }

    /// Creates a message from `origin` carrying `payload` with an explicit `ttl`.
    ///
    /// A fresh [`MessageId`] is allocated for `origin`. When the `crypto`
    /// feature is enabled, the message starts unsigned (`signature: None`).
    pub fn with_ttl(origin: SocketAddr, payload: Payload, ttl: u8) -> Self {
        Self {
            id: MessageId::new(origin),
            ttl,
            payload,
            #[cfg(feature = "crypto")]
            signature: None,
        }
    }

    /// Decrements the TTL by one if it is still positive.
    ///
    /// Returns `true` if the TTL was decremented. Returns `false`, leaving the
    /// TTL at `0`, if it had already reached zero.
    pub fn decrement_ttl(&mut self) -> bool {
        match self.ttl.checked_sub(1) {
            Some(remaining) => {
                self.ttl = remaining;
                true
            }
            None => false,
        }
    }
}
/// Contents carried by a [`Message`].
///
/// `Payload::is_protocol_message` classifies every variant except
/// `Application` and `DirectMessage` as protocol (control) traffic.
///
/// `#[non_exhaustive]`: code outside this crate must include a wildcard arm
/// when matching on this enum, so variants can be added without breaking it.
///
/// NOTE(review): positional encoders such as bincode (used in this module's
/// tests) identify variants by their index. Reordering variants, or inserting
/// one anywhere but the end, changes the wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum Payload {
/// Opaque application data.
Application(Bytes),
/// Peer-discovery signal; carries no data.
PeerDiscovery,
/// A list of peer addresses being announced.
PeerAnnouncement {
peers: Vec<SocketAddr>,
},
/// Heartbeat; `from` is presumably the sender's address — confirm with callers.
Heartbeat {
from: SocketAddr,
},
/// Request for a peer list; carries no data.
PeerListRequest,
/// A list of peer addresses, presumably answering `PeerListRequest`.
PeerListResponse {
peers: Vec<SocketAddr>,
},
/// A digest of message IDs, presumably the ones the sender already holds,
/// used for anti-entropy reconciliation — TODO confirm.
AntiEntropyDigest {
message_ids: Vec<MessageId>,
},
/// Request for the messages with the given IDs.
MessageRequest {
ids: Vec<MessageId>,
},
/// Full messages, presumably answering `MessageRequest`.
MessageResponse {
messages: Vec<Message>,
},
/// Departure notice with a human-readable `reason`.
Goodbye {
reason: String,
},
/// Application data addressed to a single `recipient`.
DirectMessage {
recipient: SocketAddr,
data: Bytes,
},
}
impl Payload {
    /// Returns `true` for protocol/control payloads. Returns `false` for the
    /// variants that carry user data, [`Payload::Application`] and
    /// [`Payload::DirectMessage`].
    pub fn is_protocol_message(&self) -> bool {
        // Exhaustive on purpose: adding a variant must force an explicit
        // classification here instead of silently defaulting to "protocol".
        match self {
            Self::Application(_) | Self::DirectMessage { .. } => false,
            Self::PeerDiscovery
            | Self::PeerAnnouncement { .. }
            | Self::Heartbeat { .. }
            | Self::PeerListRequest
            | Self::PeerListResponse { .. }
            | Self::AntiEntropyDigest { .. }
            | Self::MessageRequest { .. }
            | Self::MessageResponse { .. }
            | Self::Goodbye { .. } => true,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn create_message_id() {
        let addr = "127.0.0.1:8000".parse().unwrap();
        let id1 = MessageId::new(addr);
        let id2 = MessageId::new(addr);
        assert_ne!(id1.sequence, id2.sequence);
        assert_eq!(id1.origin, addr);
        assert_eq!(id2.origin, addr);
    }

    #[test]
    fn message_id_display() {
        let id = MessageId {
            origin: "127.0.0.1:8000".parse().unwrap(),
            sequence: 7,
            timestamp: 42,
        };
        assert_eq!(id.to_string(), "127.0.0.1:8000:7:42");
    }

    #[test]
    fn new_message_uses_default_ttl() {
        let addr = "127.0.0.1:8000".parse().unwrap();
        let msg = Message::new(addr, Payload::PeerDiscovery);
        assert_eq!(msg.ttl, 10);
        assert_eq!(msg.id.origin, addr);
    }

    #[test]
    fn decrement_ttl() {
        let addr = "127.0.0.1:8000".parse().unwrap();
        let mut msg = Message::with_ttl(addr, Payload::PeerDiscovery, 2);
        assert!(msg.decrement_ttl());
        assert_eq!(msg.ttl, 1);
        assert!(msg.decrement_ttl());
        assert_eq!(msg.ttl, 0);
        assert!(!msg.decrement_ttl());
        assert_eq!(msg.ttl, 0);

        // A message created with a TTL of zero can never be decremented.
        let mut expired = Message::with_ttl(addr, Payload::PeerDiscovery, 0);
        assert!(!expired.decrement_ttl());
        assert_eq!(expired.ttl, 0);
    }

    #[test]
    fn payload_types() {
        let data = Bytes::from("test data");
        let payload = Payload::Application(data.clone());
        assert!(!payload.is_protocol_message());
        match payload {
            Payload::Application(d) => assert_eq!(d, data),
            _ => panic!("Expected Application payload"),
        }

        let payload = Payload::PeerDiscovery;
        assert!(payload.is_protocol_message());

        let peers = vec![
            "127.0.0.1:8001".parse().unwrap(),
            "127.0.0.1:8002".parse().unwrap(),
        ];
        let payload = Payload::PeerAnnouncement {
            peers: peers.clone(),
        };
        assert!(payload.is_protocol_message());
        match payload {
            Payload::PeerAnnouncement { peers: p } => assert_eq!(p, peers),
            _ => panic!("Expected PeerAnnouncement payload"),
        }

        let addr = "127.0.0.1:8000".parse().unwrap();
        let payload = Payload::Heartbeat { from: addr };
        assert!(payload.is_protocol_message());
        match payload {
            Payload::Heartbeat { from } => assert_eq!(from, addr),
            _ => panic!("Expected Heartbeat payload"),
        }

        let payload = Payload::PeerListRequest;
        assert!(payload.is_protocol_message());

        let peers = vec!["127.0.0.1:8001".parse().unwrap()];
        let payload = Payload::PeerListResponse {
            peers: peers.clone(),
        };
        assert!(payload.is_protocol_message());
        match payload {
            Payload::PeerListResponse { peers: p } => assert_eq!(p, peers),
            _ => panic!("Expected PeerListResponse payload"),
        }

        let recipient = "127.0.0.1:8001".parse().unwrap();
        let data = Bytes::from("private message");
        let payload = Payload::DirectMessage {
            recipient,
            data: data.clone(),
        };
        assert!(!payload.is_protocol_message());
        match payload {
            Payload::DirectMessage {
                recipient: r,
                data: d,
            } => {
                assert_eq!(r, recipient);
                assert_eq!(d, data);
            }
            _ => panic!("Expected DirectMessage payload"),
        }

        let reason = "Normal shutdown".to_string();
        let payload = Payload::Goodbye {
            reason: reason.clone(),
        };
        assert!(payload.is_protocol_message());
        match payload {
            Payload::Goodbye { reason: r } => assert_eq!(r, reason),
            _ => panic!("Expected Goodbye payload"),
        }

        // The anti-entropy variants are protocol traffic as well.
        let id = MessageId::new(addr);
        assert!(Payload::AntiEntropyDigest {
            message_ids: vec![id],
        }
        .is_protocol_message());
        assert!(Payload::MessageRequest { ids: vec![id] }.is_protocol_message());
        assert!(Payload::MessageResponse {
            messages: Vec::new(),
        }
        .is_protocol_message());
    }

    #[test]
    fn multiple_messages_different_sequences() {
        let addr = "127.0.0.1:8000".parse().unwrap();
        // `insert` returns false on a duplicate, so this checks pairwise
        // uniqueness in a single pass.
        let mut seen = HashSet::new();
        for _ in 0..10 {
            let msg = Message::new(addr, Payload::PeerDiscovery);
            assert!(
                seen.insert(msg.id.sequence),
                "duplicate sequence {}",
                msg.id.sequence
            );
        }
    }

    #[test]
    fn direct_message_serialization() {
        let sender = "127.0.0.1:8000".parse().unwrap();
        let recipient = "127.0.0.1:8001".parse().unwrap();
        let data = Bytes::from("test direct message");
        let message = Message::new(
            sender,
            Payload::DirectMessage {
                recipient,
                data: data.clone(),
            },
        );
        let serialized =
            bincode::serde::encode_to_vec(&message, bincode::config::standard()).unwrap();
        let (deserialized, _): (Message, _) =
            bincode::serde::decode_from_slice(&serialized, bincode::config::standard()).unwrap();
        assert_eq!(message.id, deserialized.id);
        assert_eq!(message.ttl, deserialized.ttl);
        match deserialized.payload {
            Payload::DirectMessage {
                recipient: r,
                data: d,
            } => {
                assert_eq!(r, recipient);
                assert_eq!(d, data);
            }
            _ => panic!("Expected DirectMessage payload after deserialization"),
        }
    }
}