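// hashtree_network/runtime_peer.rs
//
// Runtime peer bookkeeping: transport and signal-path enums, the shared peer
// entry type, and the registrar for transport-native peer sessions.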

1use std::collections::{BTreeSet, HashMap};
2use std::sync::atomic::{AtomicUsize, Ordering};
3use std::sync::Arc;
4use std::time::Instant;
5
6use tokio::sync::RwLock;
7
8use crate::mesh_session::MeshSession;
9use crate::types::{PeerId, PeerPool, PoolSettings};
10
11pub type PeerClassifier = Arc<dyn Fn(&str) -> PeerPool + Send + Sync>;
12
/// Data transport carrying a live peer session.
///
/// The derived `Ord` follows declaration order (`WebRtc` < `Bluetooth`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum PeerTransport {
    WebRtc,
    Bluetooth,
}

impl PeerTransport {
    /// Stable lowercase identifier for the transport; `Display` prints the same text.
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::WebRtc => "webrtc",
            Self::Bluetooth => "bluetooth",
        }
    }
}

impl std::fmt::Display for PeerTransport {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = self.as_str();
        f.write_str(label)
    }
}
34
/// Signaling or discovery channel through which a peer was observed.
///
/// A single peer entry can accumulate several of these (see
/// `MeshPeerEntry::signal_paths`); the derived `Ord` follows declaration order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum PeerSignalPath {
    Relay,
    Multicast,
    WifiAware,
    Bluetooth,
}

impl PeerSignalPath {
    /// Stable lowercase identifier for the path; `Display` prints the same text.
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Relay => "relay",
            Self::Multicast => "multicast",
            Self::WifiAware => "wifi-aware",
            Self::Bluetooth => "bluetooth",
        }
    }

    /// Maps a discovery source name to its path.
    ///
    /// `"multicast"`, `"wifi-aware"` and `"bluetooth"` (exact, case-sensitive)
    /// select the matching variant. Every other string, including `"relay"`,
    /// falls back to [`PeerSignalPath::Relay`].
    pub fn from_source_name(source: &str) -> Self {
        [Self::Multicast, Self::WifiAware, Self::Bluetooth]
            .iter()
            .copied()
            .find(|candidate| candidate.as_str() == source)
            .unwrap_or(Self::Relay)
    }
}

impl std::fmt::Display for PeerSignalPath {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = self.as_str();
        f.write_str(label)
    }
}
69
/// Direction of a peer connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PeerDirection {
    Inbound,
    Outbound,
}

impl PeerDirection {
    /// Stable lowercase identifier for the direction; `Display` prints the same text.
    ///
    /// Mirrors `PeerTransport::as_str` / `PeerSignalPath::as_str` so every peer
    /// enum can be turned into a `&'static str` without allocating.
    pub const fn as_str(self) -> &'static str {
        match self {
            PeerDirection::Inbound => "inbound",
            PeerDirection::Outbound => "outbound",
        }
    }
}

impl std::fmt::Display for PeerDirection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str((*self).as_str())
    }
}
85
/// Connection state of a peer entry.
///
/// Only `Connected` entries count toward the registrar's pool and transport
/// limits; `Failed` entries are ignored by `can_track_signal_path_peer`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    Discovered,
    Connecting,
    Connected,
    Failed,
}

impl ConnectionState {
    /// Stable lowercase identifier for the state; `Display` prints the same text.
    ///
    /// Mirrors `PeerTransport::as_str` / `PeerSignalPath::as_str` so every peer
    /// enum can be turned into a `&'static str` without allocating.
    pub const fn as_str(self) -> &'static str {
        match self {
            ConnectionState::Discovered => "discovered",
            ConnectionState::Connecting => "connecting",
            ConnectionState::Connected => "connected",
            ConnectionState::Failed => "failed",
        }
    }
}

impl std::fmt::Display for ConnectionState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str((*self).as_str())
    }
}
105
/// Shared peer entry for transport-backed runtime peers.
///
/// Entries live in a `HashMap<String, MeshPeerEntry<P>>` keyed by the peer's
/// key string; `TransportPeerRegistrar` uses `peer_id.to_string()` as that key.
pub struct MeshPeerEntry<P> {
    /// Identity of the peer; `peer_id.pubkey` is used for pool classification
    /// and for detecting duplicate sessions of the same pubkey.
    pub peer_id: PeerId,
    pub direction: PeerDirection,
    pub state: ConnectionState,
    /// Set to `Instant::now()` when the registrar inserts the entry; other
    /// updates presumably happen elsewhere — TODO confirm.
    pub last_seen: Instant,
    /// Live session, if any. Entries that are only discovered may carry `None`
    /// (as the tests do); the registrar always stores `Some`.
    pub peer: Option<P>,
    /// Connection pool this peer was classified into.
    pub pool: PeerPool,
    /// Transport carrying the session.
    pub transport: PeerTransport,
    /// Every signaling/discovery path this peer has been seen through.
    pub signal_paths: BTreeSet<PeerSignalPath>,
    /// Traffic counters; initialised to 0 on registration. NOTE(review): the
    /// code that increments these is not in this file.
    pub bytes_sent: u64,
    pub bytes_received: u64,
}
119
120pub async fn remember_peer_signal_path<P>(
121    peers: &RwLock<HashMap<String, MeshPeerEntry<P>>>,
122    peer_id: &str,
123    source: &str,
124) {
125    if let Some(entry) = peers.write().await.get_mut(peer_id) {
126        entry
127            .signal_paths
128            .insert(PeerSignalPath::from_source_name(source));
129    }
130}
131
132pub fn can_track_signal_path_peer<P>(
133    signal_path: PeerSignalPath,
134    max_peers: usize,
135    peer_key: &str,
136    peers: &HashMap<String, MeshPeerEntry<P>>,
137) -> bool {
138    if peers.contains_key(peer_key) {
139        return true;
140    }
141    if max_peers == 0 {
142        return false;
143    }
144    peers
145        .values()
146        .filter(|entry| {
147            entry.signal_paths.contains(&signal_path) && entry.state != ConnectionState::Failed
148        })
149        .count()
150        < max_peers
151}
152
153/// Shared registrar for transport-native direct peer sessions.
154#[derive(Clone)]
155pub struct TransportPeerRegistrar<P> {
156    peers: Arc<RwLock<HashMap<String, MeshPeerEntry<P>>>>,
157    connected_count: Arc<AtomicUsize>,
158    peer_classifier: PeerClassifier,
159    pools: PoolSettings,
160    transport: PeerTransport,
161    signal_path: PeerSignalPath,
162    max_transport_peers: usize,
163}
164
165impl<P> TransportPeerRegistrar<P>
166where
167    P: MeshSession + Send + Sync + 'static,
168{
169    pub fn new(
170        peers: Arc<RwLock<HashMap<String, MeshPeerEntry<P>>>>,
171        connected_count: Arc<AtomicUsize>,
172        peer_classifier: PeerClassifier,
173        pools: PoolSettings,
174        transport: PeerTransport,
175        signal_path: PeerSignalPath,
176        max_transport_peers: usize,
177    ) -> Self {
178        Self {
179            peers,
180            connected_count,
181            peer_classifier,
182            pools,
183            transport,
184            signal_path,
185            max_transport_peers,
186        }
187    }
188
189    async fn pool_counts(&self) -> (usize, usize) {
190        let peers = self.peers.read().await;
191        let mut follows = 0usize;
192        let mut other = 0usize;
193        for entry in peers.values() {
194            if entry.state != ConnectionState::Connected {
195                continue;
196            }
197            match entry.pool {
198                PeerPool::Follows => follows += 1,
199                PeerPool::Other => other += 1,
200            }
201        }
202        (follows, other)
203    }
204
205    async fn transport_peer_count(&self, peer_key: &str) -> usize {
206        let peers = self.peers.read().await;
207        peers
208            .values()
209            .filter(|entry| entry.transport == self.transport)
210            .filter(|entry| entry.state == ConnectionState::Connected)
211            .filter(|entry| entry.peer_id.to_string() != peer_key)
212            .count()
213    }
214
215    pub async fn register_connected_peer(
216        &self,
217        peer_id: PeerId,
218        direction: PeerDirection,
219        peer: P,
220    ) -> bool {
221        let peer_key = peer_id.to_string();
222        let pool = (self.peer_classifier)(&peer_id.pubkey);
223        let (follows, other) = self.pool_counts().await;
224        let can_accept_pool = match pool {
225            PeerPool::Follows => follows < self.pools.follows.max_connections,
226            PeerPool::Other => other < self.pools.other.max_connections,
227        };
228        if !can_accept_pool {
229            return false;
230        }
231
232        if self.max_transport_peers == 0
233            || self.transport_peer_count(&peer_key).await >= self.max_transport_peers
234        {
235            return false;
236        }
237
238        let mut peers = self.peers.write().await;
239        let duplicate_keys = peers
240            .iter()
241            .filter(|(key, entry)| {
242                key.as_str() != peer_key
243                    && entry.transport == self.transport
244                    && entry.peer_id.pubkey == peer_id.pubkey
245            })
246            .map(|(key, _)| key.clone())
247            .collect::<Vec<_>>();
248        let was_connected = peers
249            .get(&peer_key)
250            .map(|entry| entry.state == ConnectionState::Connected)
251            .unwrap_or(false);
252        let replaced = peers.insert(
253            peer_key,
254            MeshPeerEntry {
255                peer_id,
256                direction,
257                state: ConnectionState::Connected,
258                last_seen: Instant::now(),
259                peer: Some(peer),
260                pool,
261                transport: self.transport,
262                signal_paths: BTreeSet::from([self.signal_path]),
263                bytes_sent: 0,
264                bytes_received: 0,
265            },
266        );
267        let removed_duplicates = duplicate_keys
268            .into_iter()
269            .filter_map(|key| peers.remove(&key))
270            .collect::<Vec<_>>();
271        drop(peers);
272
273        if let Some(previous) = replaced.and_then(|entry| entry.peer) {
274            let _ = previous.close().await;
275        }
276        for duplicate in &removed_duplicates {
277            if let Some(peer) = duplicate.peer.as_ref() {
278                let _ = peer.close().await;
279            }
280        }
281
282        let removed_connected_duplicates = removed_duplicates
283            .iter()
284            .filter(|entry| entry.state == ConnectionState::Connected)
285            .count() as isize;
286        let connected_delta =
287            1isize - if was_connected { 1 } else { 0 } - removed_connected_duplicates;
288        if connected_delta > 0 {
289            self.connected_count
290                .fetch_add(connected_delta as usize, Ordering::Relaxed);
291        } else if connected_delta < 0 {
292            self.connected_count
293                .fetch_sub((-connected_delta) as usize, Ordering::Relaxed);
294        }
295        true
296    }
297
298    pub async fn unregister_peer(&self, peer_id: &PeerId) {
299        let peer_key = peer_id.to_string();
300        let removed = self.peers.write().await.remove(&peer_key);
301        self.finish_unregister(removed).await;
302    }
303
304    pub async fn unregister_peer_if<F>(&self, peer_id: &PeerId, predicate: F)
305    where
306        F: FnOnce(&P) -> bool + Send,
307    {
308        let peer_key = peer_id.to_string();
309        let removed = {
310            let mut peers = self.peers.write().await;
311            let matches_current = peers
312                .get(&peer_key)
313                .and_then(|entry| entry.peer.as_ref())
314                .map(predicate)
315                .unwrap_or(false);
316            if matches_current {
317                peers.remove(&peer_key)
318            } else {
319                None
320            }
321        };
322        self.finish_unregister(removed).await;
323    }
324
325    async fn finish_unregister(&self, removed: Option<MeshPeerEntry<P>>) {
326        if let Some(entry) = removed {
327            if entry.state == ConnectionState::Connected {
328                self.connected_count.fetch_sub(1, Ordering::Relaxed);
329            }
330            if let Some(peer) = entry.peer {
331                let _ = peer.close().await;
332            }
333        }
334    }
335}
336
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use async_trait::async_trait;
    use nostr_sdk::nostr::{Event, Filter};
    use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering};
    use std::time::Duration;

    use crate::types::{MeshNostrFrame, PeerHTLConfig, PoolConfig};

    /// Minimal `MeshSession` stand-in that only records whether `close` was called.
    struct TestSession {
        closed: AtomicBool,
    }

    impl TestSession {
        fn new() -> Self {
            Self {
                closed: AtomicBool::new(false),
            }
        }

        /// Returns `true` once `close` has been called on this session.
        fn is_closed(&self) -> bool {
            self.closed.load(AtomicOrdering::Relaxed)
        }
    }

    // Implemented for `Arc<TestSession>` so a test can keep its own clone of a
    // session handed to the registrar and inspect it afterwards. Every method
    // except `close` is an inert stub.
    #[async_trait]
    impl MeshSession for Arc<TestSession> {
        fn is_ready(&self) -> bool {
            true
        }

        fn is_connected(&self) -> bool {
            true
        }

        fn htl_config(&self) -> PeerHTLConfig {
            PeerHTLConfig::from_flags(false, false)
        }

        async fn request(&self, _hash_hex: &str, _timeout: Duration) -> Result<Option<Vec<u8>>> {
            Ok(None)
        }

        async fn query_nostr_events(
            &self,
            _filters: Vec<Filter>,
            _timeout: Duration,
        ) -> Result<Vec<Event>> {
            Ok(Vec::new())
        }

        async fn send_mesh_frame_text(&self, _frame: &MeshNostrFrame) -> Result<()> {
            Ok(())
        }

        async fn close(&self) -> Result<()> {
            self.closed.store(true, AtomicOrdering::Relaxed);
            Ok(())
        }
    }

    /// Pool settings allowing 4 connections in each pool.
    fn test_pools() -> PoolSettings {
        PoolSettings {
            follows: PoolConfig {
                max_connections: 4,
                satisfied_connections: 0,
            },
            other: PoolConfig {
                max_connections: 4,
                satisfied_connections: 0,
            },
        }
    }

    /// Builds a Bluetooth registrar (at most 2 transport peers, every peer
    /// classified into `PeerPool::Other`) and returns it together with handles
    /// to its shared peer table and connected counter.
    fn test_registrar() -> (
        TransportPeerRegistrar<Arc<TestSession>>,
        Arc<RwLock<HashMap<String, MeshPeerEntry<Arc<TestSession>>>>>,
        Arc<AtomicUsize>,
    ) {
        let peers = Arc::new(RwLock::new(HashMap::new()));
        let connected_count = Arc::new(AtomicUsize::new(0));
        let registrar = TransportPeerRegistrar::new(
            peers.clone(),
            connected_count.clone(),
            Arc::new(|_| PeerPool::Other),
            test_pools(),
            PeerTransport::Bluetooth,
            PeerSignalPath::Bluetooth,
            2,
        );
        (registrar, peers, connected_count)
    }

    // Registering a second session for the same peer must close the first one.
    #[tokio::test]
    async fn register_connected_peer_closes_replaced_session() {
        let (registrar, _peers, _connected_count) = test_registrar();
        let peer_id = PeerId::new("peer-pub".to_string());
        let first = Arc::new(TestSession::new());
        let second = Arc::new(TestSession::new());

        assert!(
            registrar
                .register_connected_peer(peer_id.clone(), PeerDirection::Outbound, first.clone())
                .await
        );
        assert!(
            registrar
                .register_connected_peer(peer_id, PeerDirection::Outbound, second)
                .await
        );

        assert!(first.is_closed());
    }

    // After replacement only the new session remains and the connected
    // counter is not double-counted.
    // NOTE(review): both ids are built from the same string, so they
    // presumably map to the same key; this exercises same-key replacement, not
    // the cross-key "same pubkey, different key" duplicate path. Consider a
    // PeerId with a distinct key but equal pubkey if PeerId supports one.
    #[tokio::test]
    async fn register_connected_peer_replaces_existing_transport_session_for_same_pubkey() {
        let (registrar, peers, connected_count) = test_registrar();
        let first_peer_id = PeerId::new("peer-pub".to_string());
        let second_peer_id = PeerId::new("peer-pub".to_string());
        let first = Arc::new(TestSession::new());
        let second = Arc::new(TestSession::new());

        assert!(
            registrar
                .register_connected_peer(
                    first_peer_id.clone(),
                    PeerDirection::Outbound,
                    first.clone(),
                )
                .await
        );
        assert!(
            registrar
                .register_connected_peer(second_peer_id.clone(), PeerDirection::Outbound, second,)
                .await
        );

        assert!(first.is_closed());
        let peers = peers.read().await;
        assert!(peers.contains_key(&second_peer_id.to_string()));
        assert_eq!(peers.len(), 1);
        assert_eq!(connected_count.load(Ordering::Relaxed), 1);
    }

    // A matching predicate removes the entry, closes the session and
    // decrements the connected counter.
    #[tokio::test]
    async fn unregister_peer_if_respects_current_predicate() {
        let (registrar, peers, connected_count) = test_registrar();
        let peer_id = PeerId::new("peer-pub".to_string());
        let session = Arc::new(TestSession::new());

        assert!(
            registrar
                .register_connected_peer(peer_id.clone(), PeerDirection::Outbound, session.clone(),)
                .await
        );
        registrar
            .unregister_peer_if(&peer_id, |current| Arc::ptr_eq(current, &session))
            .await;

        assert!(session.is_closed());
        assert!(!peers.read().await.contains_key(&peer_id.to_string()));
        assert_eq!(connected_count.load(Ordering::Relaxed), 0);
    }

    // With a limit of 1 per path, one existing WifiAware peer:
    // - an already-tracked key is still allowed;
    // - a new WifiAware key is rejected;
    // - a new key on a different path (Relay) is allowed.
    #[test]
    fn can_track_signal_path_peer_enforces_limit() {
        let existing_peer = PeerId::new("peer-a".to_string());
        let existing_key = existing_peer.to_string();
        let mut peers = HashMap::new();
        peers.insert(
            existing_key.clone(),
            MeshPeerEntry::<Arc<TestSession>> {
                peer_id: existing_peer,
                direction: PeerDirection::Outbound,
                state: ConnectionState::Discovered,
                last_seen: Instant::now(),
                peer: None,
                pool: PeerPool::Other,
                transport: PeerTransport::WebRtc,
                signal_paths: BTreeSet::from([PeerSignalPath::WifiAware]),
                bytes_sent: 0,
                bytes_received: 0,
            },
        );

        assert!(can_track_signal_path_peer(
            PeerSignalPath::WifiAware,
            1,
            &existing_key,
            &peers
        ));
        assert!(!can_track_signal_path_peer(
            PeerSignalPath::WifiAware,
            1,
            "peer-b",
            &peers
        ));
        assert!(can_track_signal_path_peer(
            PeerSignalPath::Relay,
            1,
            "peer-c",
            &peers
        ));
    }
}