#[cfg(not(feature = "std"))]
use alloc::{collections::BTreeMap, vec::Vec};
#[cfg(feature = "std")]
use std::collections::HashMap;
use crate::NodeId;
/// First byte of every frame produced by `RelayEnvelope::encode`; `decode`
/// and `is_relay_envelope` check for it to recognise relay envelopes.
pub const RELAY_ENVELOPE_MARKER: u8 = 0xB1;

/// Hop limit given to envelopes created by `RelayEnvelope::new` / `broadcast`.
pub const DEFAULT_MAX_HOPS: u8 = 7;

/// How long `SeenMessageCache::new` remembers a message id: five minutes,
/// expressed in milliseconds.
pub const DEFAULT_SEEN_TTL_MS: u64 = 5 * 60 * 1_000;

/// Size threshold for `SeenMessageCache`. In the `std` build, `cleanup`
/// evicts the oldest entries down to half this size whenever more than this
/// many remain after expired entries are dropped.
pub const MAX_CACHE_SIZE: usize = 1_000;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(not(feature = "std"), derive(Ord, PartialOrd))]
pub struct MessageId([u8; 16]);
impl MessageId {
#[cfg(feature = "std")]
pub fn new() -> Self {
use std::time::SystemTime;
let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.map(|d| d.as_nanos())
.unwrap_or(0);
let mut id = [0u8; 16];
id[0..8].copy_from_slice(&now.to_le_bytes()[0..8]);
let mut seed = now as u64;
for i in 0..8 {
seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
id[8 + i] = (seed >> 32) as u8;
}
Self(id)
}
pub fn from_bytes(bytes: [u8; 16]) -> Self {
Self(bytes)
}
pub fn as_bytes(&self) -> &[u8; 16] {
&self.0
}
pub fn from_content(origin: NodeId, timestamp_ms: u64, payload_hash: u32) -> Self {
let mut id = [0u8; 16];
id[0..4].copy_from_slice(&origin.as_u32().to_le_bytes());
id[4..12].copy_from_slice(×tamp_ms.to_le_bytes());
id[12..16].copy_from_slice(&payload_hash.to_le_bytes());
Self(id)
}
}
/// Equivalent to [`MessageId::new`]: a freshly generated, time-based id.
#[cfg(feature = "std")]
impl Default for MessageId {
    fn default() -> Self {
        MessageId::new()
    }
}
/// Short hexadecimal rendering: only the first 8 of the 16 id bytes are
/// printed, as 16 lowercase hex digits.
impl core::fmt::Display for MessageId {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        self.0[..8]
            .iter()
            .try_for_each(|byte| write!(f, "{:02x}", byte))
    }
}
/// Option bits carried by a [`RelayEnvelope`], packed into one wire byte.
///
/// Bit `0x01` is `requires_ack` and bit `0x02` is `is_broadcast`. All other
/// bits are written as zero by `to_byte` and ignored by `from_byte`.
#[derive(Debug, Clone, Copy, Default)]
pub struct RelayFlags {
    /// Encoded as bit `0x01`.
    pub requires_ack: bool,
    /// Encoded as bit `0x02`.
    pub is_broadcast: bool,
}

impl RelayFlags {
    const ACK_BIT: u8 = 0x01;
    const BROADCAST_BIT: u8 = 0x02;

    /// Packs the flags into their single-byte wire form.
    pub fn to_byte(&self) -> u8 {
        let ack = if self.requires_ack { Self::ACK_BIT } else { 0 };
        let broadcast = if self.is_broadcast { Self::BROADCAST_BIT } else { 0 };
        ack | broadcast
    }

    /// Unpacks flags from a wire byte; unknown bits are discarded.
    pub fn from_byte(byte: u8) -> Self {
        let is_set = |bit: u8| (byte & bit) != 0;
        Self {
            requires_ack: is_set(Self::ACK_BIT),
            is_broadcast: is_set(Self::BROADCAST_BIT),
        }
    }
}
/// A payload wrapped with the metadata needed for multi-hop relaying.
///
/// `RelayEnvelope::encode` lays it out on the wire as (integers little-endian):
/// byte 0 marker (`RELAY_ENVELOPE_MARKER`), byte 1 flags, bytes 2..18 message
/// id, byte 18 hop count, byte 19 max hops, bytes 20..24 origin node (`u32`),
/// bytes 24..28 payload length (`u32`), then the payload bytes.
#[derive(Debug, Clone)]
pub struct RelayEnvelope {
    /// Identifier copied unchanged into every relayed copy by `relay`.
    pub message_id: MessageId,
    /// Number of relays performed so far; `relay` increments it.
    pub hop_count: u8,
    /// Hop limit: `relay` returns `None` once `hop_count` reaches this value.
    pub max_hops: u8,
    /// Node passed to `new`/`broadcast` when the envelope was created; preserved by `relay`.
    pub origin_node: NodeId,
    /// Per-envelope option bits (ack request, broadcast).
    pub flags: RelayFlags,
    /// Payload bytes, carried unchanged.
    pub payload: Vec<u8>,
}
impl RelayEnvelope {
    /// Length in bytes of the fixed header preceding the payload on the wire:
    /// marker (1) + flags (1) + message id (16) + hop count (1) + max hops (1)
    /// + origin node (4) + payload length (4).
    const HEADER_LEN: usize = 28;

    /// Creates an envelope with a freshly generated [`MessageId`], zero hops
    /// taken, a hop limit of [`DEFAULT_MAX_HOPS`] and default flags (no ack
    /// requested, not broadcast).
    #[cfg(feature = "std")]
    pub fn new(origin_node: NodeId, payload: Vec<u8>) -> Self {
        Self {
            message_id: MessageId::new(),
            hop_count: 0,
            max_hops: DEFAULT_MAX_HOPS,
            origin_node,
            flags: RelayFlags::default(),
            payload,
        }
    }

    /// Same as [`RelayEnvelope::new`], but with the `is_broadcast` flag set.
    #[cfg(feature = "std")]
    pub fn broadcast(origin_node: NodeId, payload: Vec<u8>) -> Self {
        let mut envelope = Self::new(origin_node, payload);
        envelope.flags.is_broadcast = true;
        envelope
    }

    /// Builder-style override of the hop limit.
    pub fn with_max_hops(mut self, max_hops: u8) -> Self {
        self.max_hops = max_hops;
        self
    }

    /// Returns `true` while `hop_count < max_hops`, i.e. while
    /// [`relay`](Self::relay) would succeed.
    pub fn can_relay(&self) -> bool {
        self.hop_count < self.max_hops
    }

    /// Hops left before the limit is reached (0 once `hop_count >= max_hops`).
    pub fn remaining_hops(&self) -> u8 {
        self.max_hops.saturating_sub(self.hop_count)
    }

    /// Returns a copy of this envelope with `hop_count` incremented, or `None`
    /// if the hop limit has been reached. Message id, origin, flags, hop limit
    /// and payload are preserved (the payload is cloned).
    pub fn relay(&self) -> Option<Self> {
        if !self.can_relay() {
            return None;
        }
        Some(Self {
            // `can_relay` guarantees `hop_count < max_hops <= u8::MAX`, so this cannot overflow.
            hop_count: self.hop_count + 1,
            ..self.clone()
        })
    }

    /// Serialises the envelope into its wire form.
    ///
    /// Layout (integers little-endian): marker, flags byte, 16-byte message
    /// id, hop count, max hops, origin node as `u32`, payload length as `u32`,
    /// payload. The length field is a `u32`, so payloads longer than
    /// `u32::MAX` bytes cannot be represented faithfully.
    pub fn encode(&self) -> Vec<u8> {
        let mut buf = Vec::with_capacity(Self::HEADER_LEN + self.payload.len());
        buf.push(RELAY_ENVELOPE_MARKER);
        buf.push(self.flags.to_byte());
        buf.extend_from_slice(self.message_id.as_bytes());
        buf.push(self.hop_count);
        buf.push(self.max_hops);
        buf.extend_from_slice(&self.origin_node.as_u32().to_le_bytes());
        buf.extend_from_slice(&(self.payload.len() as u32).to_le_bytes());
        buf.extend_from_slice(&self.payload);
        buf
    }

    /// Parses an envelope produced by [`encode`](Self::encode).
    ///
    /// Returns `None` if `data` is shorter than the 28-byte header, does not
    /// start with [`RELAY_ENVELOPE_MARKER`], or is shorter than the header
    /// plus the declared payload length. Bytes after the payload are ignored.
    ///
    /// The declared length is untrusted input. It is converted and added with
    /// checked arithmetic, so on 16/32-bit targets a huge length is rejected
    /// instead of overflowing and panicking.
    pub fn decode(data: &[u8]) -> Option<Self> {
        let header = data.get(..Self::HEADER_LEN)?;
        if header[0] != RELAY_ENVELOPE_MARKER {
            return None;
        }
        let flags = RelayFlags::from_byte(header[1]);
        let mut id_bytes = [0u8; 16];
        id_bytes.copy_from_slice(&header[2..18]);
        let message_id = MessageId::from_bytes(id_bytes);
        let hop_count = header[18];
        let max_hops = header[19];
        let origin_node =
            NodeId::new(u32::from_le_bytes([header[20], header[21], header[22], header[23]]));
        let declared_len = u32::from_le_bytes([header[24], header[25], header[26], header[27]]);
        let payload_len = usize::try_from(declared_len).ok()?;
        let payload_end = Self::HEADER_LEN.checked_add(payload_len)?;
        let payload = data.get(Self::HEADER_LEN..payload_end)?.to_vec();
        Some(Self {
            message_id,
            hop_count,
            max_hops,
            origin_node,
            flags,
            payload,
        })
    }

    /// Cheap check of whether `data` starts with [`RELAY_ENVELOPE_MARKER`].
    /// Does not validate the rest of the frame; use [`decode`](Self::decode) for that.
    pub fn is_relay_envelope(data: &[u8]) -> bool {
        data.first() == Some(&RELAY_ENVELOPE_MARKER)
    }
}
/// Bookkeeping kept by `SeenMessageCache` for each message id it has recorded.
#[derive(Debug, Clone)]
struct SeenEntry {
    /// `now_ms` of the call that first recorded the id; TTL expiry is measured
    /// from this value, and the std cleanup uses it to pick the oldest entries to evict.
    first_seen_ms: u64,
    /// How many times the id has been marked: 1 on first sight, +1 per duplicate.
    count: u32,
    /// Origin node supplied by the call that first recorded the id.
    origin: NodeId,
}
/// Remembers recently seen message ids so duplicates can be detected
/// (`std` build, backed by a `HashMap`).
#[cfg(feature = "std")]
#[derive(Debug)]
pub struct SeenMessageCache {
    /// Recorded ids and their bookkeeping.
    cache: HashMap<MessageId, SeenEntry>,
    /// Entries whose age (`now_ms - first_seen_ms`) reaches this many
    /// milliseconds are dropped by `cleanup`.
    ttl_ms: u64,
    /// `now_ms` of the most recent `cleanup`; `mark_seen` compares against it
    /// to decide when to clean up again.
    last_cleanup_ms: u64,
}
#[cfg(feature = "std")]
impl SeenMessageCache {
    /// Creates an empty cache whose entries expire after [`DEFAULT_SEEN_TTL_MS`].
    pub fn new() -> Self {
        Self::with_ttl(DEFAULT_SEEN_TTL_MS)
    }

    /// Creates an empty cache whose entries expire `ttl_ms` milliseconds after
    /// they were first seen.
    pub fn with_ttl(ttl_ms: u64) -> Self {
        Self {
            cache: HashMap::new(),
            ttl_ms,
            last_cleanup_ms: 0,
        }
    }

    /// Returns `true` if `message_id` is currently recorded.
    ///
    /// Expiry is not checked here. An entry past its TTL still counts until
    /// the next [`cleanup`](Self::cleanup) removes it.
    pub fn has_seen(&self, message_id: &MessageId) -> bool {
        self.cache.contains_key(message_id)
    }

    /// Records `message_id` as seen at `now_ms` and reports whether it was new.
    ///
    /// Returns `true` on first sighting; the id is stored with `origin` and
    /// `now_ms`. Returns `false` for a duplicate; only its hit count is bumped,
    /// and the stored origin and first-seen time are kept.
    ///
    /// Housekeeping: a [`cleanup`](Self::cleanup) runs when more than half a
    /// TTL has passed since the previous one. Also, before a new id is
    /// inserted into a cache holding [`MAX_CACHE_SIZE`] or more entries,
    /// expired entries are dropped. If that is not enough, the oldest entries
    /// are evicted down to `MAX_CACHE_SIZE / 2`. This keeps the cache bounded
    /// even when a burst of unique ids arrives between periodic cleanups.
    pub fn mark_seen(&mut self, message_id: MessageId, origin: NodeId, now_ms: u64) -> bool {
        if now_ms.saturating_sub(self.last_cleanup_ms) > self.ttl_ms / 2 {
            self.cleanup(now_ms);
        }
        if let Some(entry) = self.cache.get_mut(&message_id) {
            // Saturate instead of overflowing (which would panic in debug builds).
            entry.count = entry.count.saturating_add(1);
            return false;
        }
        // Enforce the size cap on insert, not only in the periodic cleanup;
        // otherwise the map could grow without bound between cleanups.
        if self.cache.len() >= MAX_CACHE_SIZE {
            self.cleanup(now_ms);
            if self.cache.len() >= MAX_CACHE_SIZE {
                self.evict_oldest(MAX_CACHE_SIZE / 2);
            }
        }
        self.cache.insert(
            message_id,
            SeenEntry {
                first_seen_ms: now_ms,
                count: 1,
                origin,
            },
        );
        true
    }

    /// Alias for [`mark_seen`](Self::mark_seen): returns `true` if the id is new.
    pub fn check_and_mark(&mut self, message_id: MessageId, origin: NodeId, now_ms: u64) -> bool {
        self.mark_seen(message_id, origin, now_ms)
    }

    /// Drops entries whose age (`now_ms - first_seen_ms`, saturating) has
    /// reached the TTL. If more than [`MAX_CACHE_SIZE`] entries remain, the
    /// oldest are evicted down to `MAX_CACHE_SIZE / 2`. Records `now_ms` as
    /// the time of the last cleanup.
    pub fn cleanup(&mut self, now_ms: u64) {
        self.last_cleanup_ms = now_ms;
        let ttl_ms = self.ttl_ms;
        self.cache
            .retain(|_, entry| now_ms.saturating_sub(entry.first_seen_ms) < ttl_ms);
        if self.cache.len() > MAX_CACHE_SIZE {
            self.evict_oldest(MAX_CACHE_SIZE / 2);
        }
    }

    /// Removes the entries with the smallest `first_seen_ms` until at most
    /// `keep` remain. Ties are broken arbitrarily.
    fn evict_oldest(&mut self, keep: usize) {
        let excess = self.cache.len().saturating_sub(keep);
        if excess == 0 {
            return;
        }
        let mut by_age: Vec<(u64, MessageId)> = self
            .cache
            .iter()
            .map(|(id, entry)| (entry.first_seen_ms, *id))
            .collect();
        by_age.sort_unstable_by_key(|&(first_seen_ms, _)| first_seen_ms);
        for (_, id) in by_age.into_iter().take(excess) {
            self.cache.remove(&id);
        }
    }

    /// Number of ids currently recorded (including any not yet cleaned up).
    pub fn len(&self) -> usize {
        self.cache.len()
    }

    /// Returns `true` if no ids are recorded.
    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }

    /// Forgets every recorded id.
    pub fn clear(&mut self) {
        self.cache.clear();
    }

    /// Returns `(first_seen_ms, count, origin)` for a recorded id, or `None`.
    pub fn get_stats(&self, message_id: &MessageId) -> Option<(u64, u32, NodeId)> {
        self.cache
            .get(message_id)
            .map(|e| (e.first_seen_ms, e.count, e.origin))
    }
}
/// Equivalent to `SeenMessageCache::new`: an empty cache using [`DEFAULT_SEEN_TTL_MS`].
#[cfg(feature = "std")]
impl Default for SeenMessageCache {
    fn default() -> Self {
        Self::with_ttl(DEFAULT_SEEN_TTL_MS)
    }
}
/// Remembers recently seen message ids so duplicates can be detected
/// (`no_std` build, backed by a `BTreeMap` keyed by `MessageId`).
#[cfg(not(feature = "std"))]
#[derive(Debug)]
pub struct SeenMessageCache {
    /// Recorded ids and their bookkeeping.
    cache: BTreeMap<MessageId, SeenEntry>,
    /// Entries whose age (`now_ms - first_seen_ms`) reaches this many
    /// milliseconds are dropped by `cleanup`.
    ttl_ms: u64,
    /// `now_ms` of the most recent `cleanup`; `mark_seen` compares against it
    /// to decide when to clean up again.
    last_cleanup_ms: u64,
}
#[cfg(not(feature = "std"))]
impl SeenMessageCache {
pub fn new() -> Self {
Self {
cache: BTreeMap::new(),
ttl_ms: DEFAULT_SEEN_TTL_MS,
last_cleanup_ms: 0,
}
}
pub fn with_ttl(ttl_ms: u64) -> Self {
Self {
cache: BTreeMap::new(),
ttl_ms,
last_cleanup_ms: 0,
}
}
pub fn has_seen(&self, message_id: &MessageId) -> bool {
self.cache.contains_key(message_id)
}
pub fn mark_seen(&mut self, message_id: MessageId, origin: NodeId, now_ms: u64) -> bool {
if now_ms.saturating_sub(self.last_cleanup_ms) > self.ttl_ms / 2 {
self.cleanup(now_ms);
}
if let Some(entry) = self.cache.get_mut(&message_id) {
entry.count += 1;
false
} else {
self.cache.insert(
message_id,
SeenEntry {
first_seen_ms: now_ms,
count: 1,
origin,
},
);
true
}
}
pub fn check_and_mark(&mut self, message_id: MessageId, origin: NodeId, now_ms: u64) -> bool {
self.mark_seen(message_id, origin, now_ms)
}
pub fn cleanup(&mut self, now_ms: u64) {
self.last_cleanup_ms = now_ms;
let expired: Vec<_> = self
.cache
.iter()
.filter(|(_, e)| now_ms.saturating_sub(e.first_seen_ms) >= self.ttl_ms)
.map(|(id, _)| *id)
.collect();
for id in expired {
self.cache.remove(&id);
}
}
pub fn len(&self) -> usize {
self.cache.len()
}
pub fn is_empty(&self) -> bool {
self.cache.is_empty()
}
pub fn clear(&mut self) {
self.cache.clear();
}
}
#[cfg(not(feature = "std"))]
impl Default for SeenMessageCache {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_id_from_content() {
        let origin = NodeId::new(0x12345678);
        let id1 = MessageId::from_content(origin, 1000, 0xDEADBEEF);
        let id2 = MessageId::from_content(origin, 1000, 0xDEADBEEF);
        let id3 = MessageId::from_content(origin, 1001, 0xDEADBEEF);
        // Identical inputs give identical ids; a different timestamp changes the id.
        assert_eq!(id1, id2);
        assert_ne!(id1, id3);
    }

    #[test]
    fn test_relay_flags() {
        let flags = RelayFlags {
            requires_ack: true,
            is_broadcast: false,
        };
        let byte = flags.to_byte();
        let decoded = RelayFlags::from_byte(byte);
        assert!(decoded.requires_ack);
        assert!(!decoded.is_broadcast);

        let flags = RelayFlags {
            requires_ack: false,
            is_broadcast: true,
        };
        let byte = flags.to_byte();
        let decoded = RelayFlags::from_byte(byte);
        assert!(!decoded.requires_ack);
        assert!(decoded.is_broadcast);
    }

    #[test]
    fn test_relay_envelope_encode_decode() {
        let origin = NodeId::new(0x12345678);
        let payload = vec![1, 2, 3, 4, 5];
        let envelope = RelayEnvelope::new(origin, payload.clone());
        let encoded = envelope.encode();
        let decoded = RelayEnvelope::decode(&encoded).unwrap();
        assert_eq!(decoded.message_id, envelope.message_id);
        assert_eq!(decoded.hop_count, 0);
        assert_eq!(decoded.max_hops, DEFAULT_MAX_HOPS);
        assert_eq!(decoded.origin_node, origin);
        assert_eq!(decoded.payload, payload);
    }

    #[test]
    fn test_relay_envelope_decode_rejects_malformed() {
        let envelope = RelayEnvelope::new(NodeId::new(0x12345678), vec![1, 2, 3]);
        let encoded = envelope.encode();

        // Header cut short.
        assert!(RelayEnvelope::decode(&encoded[..27]).is_none());
        // Payload cut short.
        assert!(RelayEnvelope::decode(&encoded[..encoded.len() - 1]).is_none());
        // Empty input.
        assert!(RelayEnvelope::decode(&[]).is_none());

        // Wrong marker byte.
        let mut bad_marker = encoded.clone();
        bad_marker[0] = 0x00;
        assert!(RelayEnvelope::decode(&bad_marker).is_none());

        // Declared payload length far larger than the buffer.
        let mut huge_len = encoded.clone();
        huge_len[24..28].copy_from_slice(&u32::MAX.to_le_bytes());
        assert!(RelayEnvelope::decode(&huge_len).is_none());
    }

    #[test]
    fn test_relay_envelope_hop_tracking() {
        let origin = NodeId::new(0x12345678);
        let envelope = RelayEnvelope::new(origin, vec![1, 2, 3]).with_max_hops(3);
        assert!(envelope.can_relay());
        assert_eq!(envelope.remaining_hops(), 3);

        let relayed = envelope.relay().unwrap();
        assert_eq!(relayed.hop_count, 1);
        assert!(relayed.can_relay());

        let relayed = relayed.relay().unwrap();
        assert_eq!(relayed.hop_count, 2);
        assert!(relayed.can_relay());

        let relayed = relayed.relay().unwrap();
        assert_eq!(relayed.hop_count, 3);
        assert!(!relayed.can_relay());
        // The hop limit has been reached, so no further relay is produced.
        assert!(relayed.relay().is_none());
    }

    #[test]
    fn test_relay_preserves_identity() {
        let envelope = RelayEnvelope::broadcast(NodeId::new(1), vec![9, 8, 7]);
        let relayed = envelope.relay().unwrap();
        assert_eq!(relayed.message_id, envelope.message_id);
        assert_eq!(relayed.origin_node, envelope.origin_node);
        assert_eq!(relayed.max_hops, envelope.max_hops);
        assert!(relayed.flags.is_broadcast);
        assert_eq!(relayed.payload, envelope.payload);
    }

    #[test]
    fn test_is_relay_envelope() {
        let data = vec![RELAY_ENVELOPE_MARKER, 0, 0, 0];
        assert!(RelayEnvelope::is_relay_envelope(&data));
        let data = vec![0x00, 0, 0, 0];
        assert!(!RelayEnvelope::is_relay_envelope(&data));
        let data: Vec<u8> = vec![];
        assert!(!RelayEnvelope::is_relay_envelope(&data));
    }

    #[test]
    fn test_seen_cache_basic() {
        let mut cache = SeenMessageCache::new();
        let origin = NodeId::new(0x12345678);
        let id1 = MessageId::from_content(origin, 1000, 0xAABBCCDD);
        let id2 = MessageId::from_content(origin, 1001, 0xAABBCCDD);
        assert!(cache.check_and_mark(id1, origin, 1000));
        assert!(!cache.has_seen(&id2));
        assert!(!cache.check_and_mark(id1, origin, 1001));
        assert!(cache.check_and_mark(id2, origin, 1002));
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn test_seen_cache_cleanup() {
        let mut cache = SeenMessageCache::with_ttl(1000);
        let origin = NodeId::new(0x12345678);
        let id1 = MessageId::from_content(origin, 1000, 0x11111111);
        let id2 = MessageId::from_content(origin, 2000, 0x22222222);
        cache.mark_seen(id1, origin, 0);
        assert_eq!(cache.len(), 1);
        cache.mark_seen(id2, origin, 500);
        assert_eq!(cache.len(), 2);

        // id1 is 1001 ms old (expired); id2 is 501 ms old (kept).
        cache.cleanup(1001);
        assert_eq!(cache.len(), 1);
        assert!(!cache.has_seen(&id1));
        assert!(cache.has_seen(&id2));

        // id2 is now 1001 ms old as well.
        cache.cleanup(1501);
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_seen_cache_stats() {
        let mut cache = SeenMessageCache::new();
        let origin = NodeId::new(0x12345678);
        let id = MessageId::from_content(origin, 1000, 0xDEADBEEF);
        cache.mark_seen(id, origin, 1000);
        let (first_seen, count, orig) = cache.get_stats(&id).unwrap();
        assert_eq!(first_seen, 1000);
        assert_eq!(count, 1);
        assert_eq!(orig, origin);

        cache.mark_seen(id, origin, 2000);
        let (first_seen, count, _) = cache.get_stats(&id).unwrap();
        // A duplicate bumps the count but keeps the original first-seen time.
        assert_eq!(first_seen, 1000);
        assert_eq!(count, 2);
    }
}