// hive_btle/relay.rs

// Copyright (c) 2025-2026 (r)evolve - Revolve Team LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Multi-hop relay support for HIVE BLE mesh
//!
//! This module provides message deduplication and hop tracking for multi-hop
//! relay scenarios. Without deduplication, messages would bounce infinitely
//! between mesh nodes.
//!
//! ## Wire Format
//!
//! The relay envelope wraps documents for multi-hop transmission:
//!
//! ```text
//! [1 byte:   marker (0xB1)]
//! [1 byte:   flags]
//!   - bit 0: requires_ack
//!   - bit 1: is_broadcast
//!   - bits 2-7: reserved
//! [16 bytes: message_id (opaque 128-bit identifier)]
//! [1 byte:   hop_count (current)]
//! [1 byte:   max_hops (TTL)]
//! [4 bytes:  origin_node_id (u32, little-endian)]
//! [4 bytes:  payload_len (u32, little-endian)]
//! [N bytes:  payload (encrypted document)]
//! ```
//!
//! The fixed header before the payload is 28 bytes long.
//!
//! ## Deduplication
//!
//! The `SeenMessageCache` tracks message IDs with TTL expiration to prevent
//! infinite relay loops while allowing legitimate re-transmissions after
//! the TTL expires.

#[cfg(not(feature = "std"))]
use alloc::{collections::{btree_map::Entry, BTreeMap}, vec::Vec};
#[cfg(feature = "std")]
use std::collections::{hash_map::Entry, HashMap};

use crate::NodeId;

/// Leading byte of every relay envelope.
///
/// `RelayEnvelope::is_relay_envelope` and `RelayEnvelope::decode` test for it.
pub const RELAY_ENVELOPE_MARKER: u8 = 0xB1;

/// Hop limit (TTL) given to envelopes built by `RelayEnvelope::new` and
/// `RelayEnvelope::broadcast`.
pub const DEFAULT_MAX_HOPS: u8 = 7;

/// How long a message ID is remembered for deduplication: 5 minutes, in ms.
pub const DEFAULT_SEEN_TTL_MS: u64 = 5 * 60 * 1000;

/// Entry-count threshold for the dedup cache; when more entries than this
/// remain after expiring old ones, cleanup evicts the oldest.
pub const MAX_CACHE_SIZE: usize = 1000;

65/// A 128-bit message identifier for deduplication
66#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
67#[cfg_attr(not(feature = "std"), derive(Ord, PartialOrd))]
68pub struct MessageId([u8; 16]);
69
70impl MessageId {
71    /// Create a new random message ID
72    #[cfg(feature = "std")]
73    pub fn new() -> Self {
74        use std::time::SystemTime;
75
76        // Generate pseudo-random ID from timestamp and random bits
77        let now = SystemTime::now()
78            .duration_since(SystemTime::UNIX_EPOCH)
79            .map(|d| d.as_nanos())
80            .unwrap_or(0);
81
82        let mut id = [0u8; 16];
83
84        // Use timestamp for first 8 bytes
85        id[0..8].copy_from_slice(&now.to_le_bytes()[0..8]);
86
87        // Use LCG for remaining bytes (pseudo-random but fast)
88        let mut seed = now as u64;
89        for i in 0..8 {
90            seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
91            id[8 + i] = (seed >> 32) as u8;
92        }
93
94        Self(id)
95    }
96
97    /// Create from raw bytes
98    pub fn from_bytes(bytes: [u8; 16]) -> Self {
99        Self(bytes)
100    }
101
102    /// Get the raw bytes
103    pub fn as_bytes(&self) -> &[u8; 16] {
104        &self.0
105    }
106
107    /// Create a deterministic ID from content hash
108    ///
109    /// Use this when you need consistent IDs for the same content,
110    /// e.g., for idempotent operations.
111    pub fn from_content(origin: NodeId, timestamp_ms: u64, payload_hash: u32) -> Self {
112        let mut id = [0u8; 16];
113
114        // Origin node (4 bytes)
115        id[0..4].copy_from_slice(&origin.as_u32().to_le_bytes());
116
117        // Timestamp (8 bytes)
118        id[4..12].copy_from_slice(&timestamp_ms.to_le_bytes());
119
120        // Payload hash (4 bytes)
121        id[12..16].copy_from_slice(&payload_hash.to_le_bytes());
122
123        Self(id)
124    }
125}
126
#[cfg(feature = "std")]
impl Default for MessageId {
    /// Same as [`MessageId::new`]: each default ID is freshly generated,
    /// not a fixed "zero" value.
    fn default() -> Self {
        MessageId::new()
    }
}

134impl core::fmt::Display for MessageId {
135    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
136        // Display as hex (first 8 bytes for brevity)
137        write!(
138            f,
139            "{:02x}{:02x}{:02x}{:02x}{:02x}{:02x}{:02x}{:02x}",
140            self.0[0], self.0[1], self.0[2], self.0[3], self.0[4], self.0[5], self.0[6], self.0[7]
141        )
142    }
143}
144
/// Flags carried in byte 1 of a relay envelope.
#[derive(Debug, Clone, Copy, Default)]
pub struct RelayFlags {
    /// Whether this message requires acknowledgment (wire bit 0).
    pub requires_ack: bool,
    /// Whether this is a broadcast rather than a targeted message (wire bit 1).
    pub is_broadcast: bool,
}

impl RelayFlags {
    const REQUIRES_ACK_BIT: u8 = 0x01;
    const IS_BROADCAST_BIT: u8 = 0x02;

    /// Pack the flags into their wire byte; reserved bits 2-7 are always zero.
    pub fn to_byte(&self) -> u8 {
        let ack = if self.requires_ack { Self::REQUIRES_ACK_BIT } else { 0 };
        let broadcast = if self.is_broadcast { Self::IS_BROADCAST_BIT } else { 0 };
        ack | broadcast
    }

    /// Unpack flags from a wire byte; reserved bits 2-7 are ignored.
    pub fn from_byte(byte: u8) -> Self {
        let is_set = |mask: u8| (byte & mask) != 0;
        Self {
            requires_ack: is_set(Self::REQUIRES_ACK_BIT),
            is_broadcast: is_set(Self::IS_BROADCAST_BIT),
        }
    }
}

176/// A relay envelope wrapping a document for multi-hop transmission
177#[derive(Debug, Clone)]
178pub struct RelayEnvelope {
179    /// Unique message identifier for deduplication
180    pub message_id: MessageId,
181
182    /// Current hop count (increments with each relay)
183    pub hop_count: u8,
184
185    /// Maximum allowed hops (TTL)
186    pub max_hops: u8,
187
188    /// Original sender node ID
189    pub origin_node: NodeId,
190
191    /// Envelope flags
192    pub flags: RelayFlags,
193
194    /// The wrapped payload (typically an encrypted document)
195    pub payload: Vec<u8>,
196}
197
198impl RelayEnvelope {
199    /// Create a new relay envelope for a payload
200    #[cfg(feature = "std")]
201    pub fn new(origin_node: NodeId, payload: Vec<u8>) -> Self {
202        Self {
203            message_id: MessageId::new(),
204            hop_count: 0,
205            max_hops: DEFAULT_MAX_HOPS,
206            origin_node,
207            flags: RelayFlags::default(),
208            payload,
209        }
210    }
211
212    /// Create with broadcast flag
213    #[cfg(feature = "std")]
214    pub fn broadcast(origin_node: NodeId, payload: Vec<u8>) -> Self {
215        Self {
216            message_id: MessageId::new(),
217            hop_count: 0,
218            max_hops: DEFAULT_MAX_HOPS,
219            origin_node,
220            flags: RelayFlags {
221                requires_ack: false,
222                is_broadcast: true,
223            },
224            payload,
225        }
226    }
227
228    /// Create with custom max hops
229    pub fn with_max_hops(mut self, max_hops: u8) -> Self {
230        self.max_hops = max_hops;
231        self
232    }
233
234    /// Check if this envelope can be relayed further
235    pub fn can_relay(&self) -> bool {
236        self.hop_count < self.max_hops
237    }
238
239    /// Get remaining hops
240    pub fn remaining_hops(&self) -> u8 {
241        self.max_hops.saturating_sub(self.hop_count)
242    }
243
244    /// Create a relay copy with incremented hop count
245    ///
246    /// Returns None if TTL expired.
247    pub fn relay(&self) -> Option<Self> {
248        if !self.can_relay() {
249            return None;
250        }
251
252        Some(Self {
253            message_id: self.message_id,
254            hop_count: self.hop_count + 1,
255            max_hops: self.max_hops,
256            origin_node: self.origin_node,
257            flags: self.flags,
258            payload: self.payload.clone(),
259        })
260    }
261
262    /// Encode to bytes for transmission
263    pub fn encode(&self) -> Vec<u8> {
264        let size = 28 + self.payload.len(); // marker(1) + flags(1) + id(16) + hops(2) + origin(4) + len(4) + payload
265        let mut buf = Vec::with_capacity(size);
266
267        buf.push(RELAY_ENVELOPE_MARKER);
268        buf.push(self.flags.to_byte());
269        buf.extend_from_slice(self.message_id.as_bytes());
270        buf.push(self.hop_count);
271        buf.push(self.max_hops);
272        buf.extend_from_slice(&self.origin_node.as_u32().to_le_bytes());
273        buf.extend_from_slice(&(self.payload.len() as u32).to_le_bytes());
274        buf.extend_from_slice(&self.payload);
275
276        buf
277    }
278
279    /// Decode from bytes
280    pub fn decode(data: &[u8]) -> Option<Self> {
281        // Minimum size: marker(1) + flags(1) + id(16) + hops(2) + origin(4) + len(4) = 28
282        if data.len() < 28 {
283            return None;
284        }
285
286        if data[0] != RELAY_ENVELOPE_MARKER {
287            return None;
288        }
289
290        let flags = RelayFlags::from_byte(data[1]);
291
292        let mut id_bytes = [0u8; 16];
293        id_bytes.copy_from_slice(&data[2..18]);
294        let message_id = MessageId::from_bytes(id_bytes);
295
296        let hop_count = data[18];
297        let max_hops = data[19];
298
299        let origin_node = NodeId::new(u32::from_le_bytes([data[20], data[21], data[22], data[23]]));
300
301        let payload_len = u32::from_le_bytes([data[24], data[25], data[26], data[27]]) as usize;
302
303        if data.len() < 28 + payload_len {
304            return None;
305        }
306
307        let payload = data[28..28 + payload_len].to_vec();
308
309        Some(Self {
310            message_id,
311            hop_count,
312            max_hops,
313            origin_node,
314            flags,
315            payload,
316        })
317    }
318
319    /// Check if data starts with relay envelope marker
320    pub fn is_relay_envelope(data: &[u8]) -> bool {
321        !data.is_empty() && data[0] == RELAY_ENVELOPE_MARKER
322    }
323}
324
325/// Cache entry for a seen message
326#[derive(Debug, Clone)]
327struct SeenEntry {
328    /// When this message was first seen (ms)
329    first_seen_ms: u64,
330    /// How many times we've seen this message
331    count: u32,
332    /// Origin node that sent this message
333    origin: NodeId,
334}
335
/// Cache of seen message IDs for deduplication.
///
/// Each entry expires `ttl_ms` after it was first seen. An expired ID is
/// treated as new again, which allows legitimate re-transmissions after the
/// TTL. Expired entries are purged lazily by [`SeenMessageCache::cleanup`]:
/// `mark_seen` runs it at most every `ttl_ms / 2`, and also whenever the cache
/// holds more than [`MAX_CACHE_SIZE`] entries. A flood of unique IDs therefore
/// cannot grow the map without bound.
#[cfg(feature = "std")]
#[derive(Debug)]
pub struct SeenMessageCache {
    /// Map of message ID to entry
    cache: HashMap<MessageId, SeenEntry>,
    /// TTL for entries in milliseconds
    ttl_ms: u64,
    /// Timestamp (ms) of the last cleanup pass
    last_cleanup_ms: u64,
}

#[cfg(feature = "std")]
impl SeenMessageCache {
    /// Create an empty cache with [`DEFAULT_SEEN_TTL_MS`].
    pub fn new() -> Self {
        Self::with_ttl(DEFAULT_SEEN_TTL_MS)
    }

    /// Create an empty cache with a custom TTL in milliseconds.
    pub fn with_ttl(ttl_ms: u64) -> Self {
        Self {
            cache: HashMap::new(),
            ttl_ms,
            last_cleanup_ms: 0,
        }
    }

    /// Check whether a message ID is currently stored (read-only).
    ///
    /// This takes no clock argument, so it cannot check expiry. It may
    /// return `true` for an entry whose TTL has passed but which has not been
    /// purged yet. `mark_seen` / `check_and_mark` do apply the TTL.
    pub fn has_seen(&self, message_id: &MessageId) -> bool {
        self.cache.contains_key(message_id)
    }

    /// Record a sighting of `message_id` at `now_ms`.
    ///
    /// Returns `true` in two cases: the message is new, or its previous
    /// sighting is at least `ttl_ms` old (the entry is then reset). Returns
    /// `false` for a duplicate within the TTL, and increments its counter.
    pub fn mark_seen(&mut self, message_id: MessageId, origin: NodeId, now_ms: u64) -> bool {
        let cleanup_due = now_ms.saturating_sub(self.last_cleanup_ms) > self.ttl_ms / 2;
        if cleanup_due || self.cache.len() > MAX_CACHE_SIZE {
            self.cleanup(now_ms);
        }

        let ttl_ms = self.ttl_ms;
        let fresh = SeenEntry {
            first_seen_ms: now_ms,
            count: 1,
            origin,
        };

        match self.cache.entry(message_id) {
            Entry::Vacant(slot) => {
                slot.insert(fresh);
                true // New message
            }
            Entry::Occupied(mut slot) => {
                let entry = slot.get_mut();
                if now_ms.saturating_sub(entry.first_seen_ms) >= ttl_ms {
                    // Expired but not yet purged: a legitimate re-transmission.
                    *entry = fresh;
                    true
                } else {
                    entry.count = entry.count.saturating_add(1);
                    false // Duplicate within TTL
                }
            }
        }
    }

    /// Check and mark in one operation (alias of [`SeenMessageCache::mark_seen`]).
    ///
    /// Returns `true` if the message should be processed, `false` if it is a duplicate.
    pub fn check_and_mark(&mut self, message_id: MessageId, origin: NodeId, now_ms: u64) -> bool {
        self.mark_seen(message_id, origin, now_ms)
    }

    /// Remove entries at least `ttl_ms` old. If more than [`MAX_CACHE_SIZE`]
    /// entries still remain, evict the oldest down to `MAX_CACHE_SIZE / 2`.
    pub fn cleanup(&mut self, now_ms: u64) {
        self.last_cleanup_ms = now_ms;

        let ttl_ms = self.ttl_ms;
        self.cache
            .retain(|_, entry| now_ms.saturating_sub(entry.first_seen_ms) < ttl_ms);

        if self.cache.len() > MAX_CACHE_SIZE {
            let mut by_age: Vec<(MessageId, u64)> = self
                .cache
                .iter()
                .map(|(id, entry)| (*id, entry.first_seen_ms))
                .collect();
            by_age.sort_unstable_by_key(|&(_, first_seen_ms)| first_seen_ms);

            let excess = self.cache.len() - MAX_CACHE_SIZE / 2;
            for (id, _) in by_age.into_iter().take(excess) {
                self.cache.remove(&id);
            }
        }
    }

    /// Get the number of entries in the cache (including expired, unpurged ones).
    pub fn len(&self) -> usize {
        self.cache.len()
    }

    /// Check if the cache is empty.
    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }

    /// Clear all entries.
    pub fn clear(&mut self) {
        self.cache.clear();
    }

    /// Get `(first_seen_ms, count, origin)` for a stored message, if present.
    pub fn get_stats(&self, message_id: &MessageId) -> Option<(u64, u32, NodeId)> {
        self.cache
            .get(message_id)
            .map(|e| (e.first_seen_ms, e.count, e.origin))
    }
}

#[cfg(feature = "std")]
impl Default for SeenMessageCache {
    /// An empty cache using [`DEFAULT_SEEN_TTL_MS`]; same as `SeenMessageCache::new()`.
    fn default() -> Self {
        SeenMessageCache::new()
    }
}

468/// No_std version using BTreeMap
469#[cfg(not(feature = "std"))]
470#[derive(Debug)]
471pub struct SeenMessageCache {
472    cache: BTreeMap<MessageId, SeenEntry>,
473    ttl_ms: u64,
474    last_cleanup_ms: u64,
475}
476
477#[cfg(not(feature = "std"))]
478impl SeenMessageCache {
479    pub fn new() -> Self {
480        Self {
481            cache: BTreeMap::new(),
482            ttl_ms: DEFAULT_SEEN_TTL_MS,
483            last_cleanup_ms: 0,
484        }
485    }
486
487    pub fn with_ttl(ttl_ms: u64) -> Self {
488        Self {
489            cache: BTreeMap::new(),
490            ttl_ms,
491            last_cleanup_ms: 0,
492        }
493    }
494
495    pub fn has_seen(&self, message_id: &MessageId) -> bool {
496        self.cache.contains_key(message_id)
497    }
498
499    pub fn mark_seen(&mut self, message_id: MessageId, origin: NodeId, now_ms: u64) -> bool {
500        if now_ms.saturating_sub(self.last_cleanup_ms) > self.ttl_ms / 2 {
501            self.cleanup(now_ms);
502        }
503
504        if let Some(entry) = self.cache.get_mut(&message_id) {
505            entry.count += 1;
506            false
507        } else {
508            self.cache.insert(
509                message_id,
510                SeenEntry {
511                    first_seen_ms: now_ms,
512                    count: 1,
513                    origin,
514                },
515            );
516            true
517        }
518    }
519
520    pub fn check_and_mark(&mut self, message_id: MessageId, origin: NodeId, now_ms: u64) -> bool {
521        self.mark_seen(message_id, origin, now_ms)
522    }
523
524    pub fn cleanup(&mut self, now_ms: u64) {
525        self.last_cleanup_ms = now_ms;
526
527        let expired: Vec<_> = self
528            .cache
529            .iter()
530            .filter(|(_, e)| now_ms.saturating_sub(e.first_seen_ms) >= self.ttl_ms)
531            .map(|(id, _)| *id)
532            .collect();
533
534        for id in expired {
535            self.cache.remove(&id);
536        }
537    }
538
539    pub fn len(&self) -> usize {
540        self.cache.len()
541    }
542
543    pub fn is_empty(&self) -> bool {
544        self.cache.is_empty()
545    }
546
547    pub fn clear(&mut self) {
548        self.cache.clear();
549    }
550}
551
552#[cfg(not(feature = "std"))]
553impl Default for SeenMessageCache {
554    fn default() -> Self {
555        Self::new()
556    }
557}
558
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_id_from_content() {
        let origin = NodeId::new(0x12345678);
        let id1 = MessageId::from_content(origin, 1000, 0xDEADBEEF);
        let id2 = MessageId::from_content(origin, 1000, 0xDEADBEEF);
        let id3 = MessageId::from_content(origin, 1001, 0xDEADBEEF);

        assert_eq!(id1, id2); // Same content = same ID
        assert_ne!(id1, id3); // Different timestamp = different ID
    }

    #[test]
    fn test_relay_flags() {
        let flags = RelayFlags {
            requires_ack: true,
            is_broadcast: false,
        };
        let byte = flags.to_byte();
        let decoded = RelayFlags::from_byte(byte);
        assert!(decoded.requires_ack);
        assert!(!decoded.is_broadcast);

        let flags = RelayFlags {
            requires_ack: false,
            is_broadcast: true,
        };
        let byte = flags.to_byte();
        let decoded = RelayFlags::from_byte(byte);
        assert!(!decoded.requires_ack);
        assert!(decoded.is_broadcast);
    }

    #[test]
    fn test_relay_envelope_encode_decode() {
        let origin = NodeId::new(0x12345678);
        let payload = vec![1, 2, 3, 4, 5];
        let envelope = RelayEnvelope::new(origin, payload.clone());

        let encoded = envelope.encode();
        let decoded = RelayEnvelope::decode(&encoded).unwrap();

        assert_eq!(decoded.message_id, envelope.message_id);
        assert_eq!(decoded.hop_count, 0);
        assert_eq!(decoded.max_hops, DEFAULT_MAX_HOPS);
        assert_eq!(decoded.origin_node, origin);
        assert_eq!(decoded.payload, payload);
    }

    #[test]
    fn test_relay_envelope_empty_payload() {
        let encoded = RelayEnvelope::new(NodeId::new(1), Vec::new()).encode();
        assert_eq!(encoded.len(), 28); // Header only

        let decoded = RelayEnvelope::decode(&encoded).unwrap();
        assert!(decoded.payload.is_empty());
    }

    #[test]
    fn test_relay_envelope_decode_rejects_malformed() {
        let origin = NodeId::new(0x12345678);
        let encoded = RelayEnvelope::new(origin, vec![9, 8, 7]).encode();

        // Shorter than the fixed 28-byte header
        assert!(RelayEnvelope::decode(&[]).is_none());
        assert!(RelayEnvelope::decode(&encoded[..27]).is_none());

        // Wrong marker byte
        let mut bad_marker = encoded.clone();
        bad_marker[0] = 0x00;
        assert!(RelayEnvelope::decode(&bad_marker).is_none());

        // Payload truncated by one byte
        assert!(RelayEnvelope::decode(&encoded[..encoded.len() - 1]).is_none());

        // Declared payload length far beyond the buffer (must not panic)
        let mut huge_len = encoded.clone();
        huge_len[24..28].copy_from_slice(&u32::MAX.to_le_bytes());
        assert!(RelayEnvelope::decode(&huge_len).is_none());
    }

    #[test]
    fn test_relay_envelope_hop_tracking() {
        let origin = NodeId::new(0x12345678);
        let envelope = RelayEnvelope::new(origin, vec![1, 2, 3]).with_max_hops(3);

        assert!(envelope.can_relay());
        assert_eq!(envelope.remaining_hops(), 3);

        let relayed = envelope.relay().unwrap();
        assert_eq!(relayed.hop_count, 1);
        assert!(relayed.can_relay());

        let relayed = relayed.relay().unwrap();
        assert_eq!(relayed.hop_count, 2);
        assert!(relayed.can_relay());

        let relayed = relayed.relay().unwrap();
        assert_eq!(relayed.hop_count, 3);
        assert!(!relayed.can_relay()); // TTL expired

        assert!(relayed.relay().is_none()); // Cannot relay further
    }

    #[test]
    fn test_relay_preserves_identity() {
        let origin = NodeId::new(0xCAFEBABE);
        let envelope = RelayEnvelope::broadcast(origin, vec![4, 2]);
        let relayed = envelope.relay().unwrap();

        assert_eq!(relayed.message_id, envelope.message_id);
        assert_eq!(relayed.origin_node, origin);
        assert_eq!(relayed.max_hops, envelope.max_hops);
        assert_eq!(relayed.payload, envelope.payload);
        assert!(relayed.flags.is_broadcast);
        assert!(!relayed.flags.requires_ack);
    }

    #[test]
    fn test_broadcast_flag_survives_encoding() {
        let envelope = RelayEnvelope::broadcast(NodeId::new(7), vec![1]);
        let decoded = RelayEnvelope::decode(&envelope.encode()).unwrap();

        assert!(decoded.flags.is_broadcast);
        assert!(!decoded.flags.requires_ack);
    }

    #[test]
    fn test_is_relay_envelope() {
        let data = vec![RELAY_ENVELOPE_MARKER, 0, 0, 0];
        assert!(RelayEnvelope::is_relay_envelope(&data));

        let data = vec![0x00, 0, 0, 0];
        assert!(!RelayEnvelope::is_relay_envelope(&data));

        let data: Vec<u8> = vec![];
        assert!(!RelayEnvelope::is_relay_envelope(&data));
    }

    #[test]
    fn test_seen_cache_basic() {
        let mut cache = SeenMessageCache::new();
        let origin = NodeId::new(0x12345678);

        let id1 = MessageId::from_content(origin, 1000, 0xAABBCCDD);
        let id2 = MessageId::from_content(origin, 1001, 0xAABBCCDD);

        // First time seeing id1
        assert!(cache.check_and_mark(id1, origin, 1000));
        assert!(!cache.has_seen(&id2));

        // Second time seeing id1 - should be rejected
        assert!(!cache.check_and_mark(id1, origin, 1001));

        // First time seeing id2
        assert!(cache.check_and_mark(id2, origin, 1002));

        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn test_seen_cache_cleanup() {
        let mut cache = SeenMessageCache::with_ttl(1000); // 1 second TTL
        let origin = NodeId::new(0x12345678);

        let id1 = MessageId::from_content(origin, 1000, 0x11111111);
        let id2 = MessageId::from_content(origin, 2000, 0x22222222);

        // Add id1 at t=0
        cache.mark_seen(id1, origin, 0);
        assert_eq!(cache.len(), 1);

        // Add id2 at t=500
        cache.mark_seen(id2, origin, 500);
        assert_eq!(cache.len(), 2);

        // Cleanup at t=1001 - id1 should be expired (1001 - 0 = 1001 >= 1000)
        // id2 should still be valid (1001 - 500 = 501 < 1000)
        cache.cleanup(1001);
        assert_eq!(cache.len(), 1);
        assert!(!cache.has_seen(&id1));
        assert!(cache.has_seen(&id2));

        // Cleanup at t=1501 - id2 should be expired too (1501 - 500 = 1001 >= 1000)
        cache.cleanup(1501);
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_seen_cache_clear() {
        let mut cache = SeenMessageCache::new();
        let origin = NodeId::new(1);
        let id = MessageId::from_content(origin, 5, 0);

        assert!(cache.is_empty());
        cache.mark_seen(id, origin, 5);
        assert!(!cache.is_empty());

        cache.clear();
        assert!(cache.is_empty());
        assert!(!cache.has_seen(&id));
    }

    #[test]
    fn test_seen_cache_stats() {
        let mut cache = SeenMessageCache::new();
        let origin = NodeId::new(0x12345678);
        let id = MessageId::from_content(origin, 1000, 0xDEADBEEF);

        // First mark
        cache.mark_seen(id, origin, 1000);
        let (first_seen, count, orig) = cache.get_stats(&id).unwrap();
        assert_eq!(first_seen, 1000);
        assert_eq!(count, 1);
        assert_eq!(orig, origin);

        // Second mark - count should increase
        cache.mark_seen(id, origin, 2000);
        let (first_seen, count, _) = cache.get_stats(&id).unwrap();
        assert_eq!(first_seen, 1000); // Still the first time
        assert_eq!(count, 2);
    }
}