// hashtree_network/runtime_control.rs

1use anyhow::Result;
2use nostr_sdk::nostr::{Event, Keys, Kind};
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::sync::Mutex;
7use tracing::{debug, info};
8
9use crate::local_bus::SharedLocalNostrBus;
10use crate::mesh_session::{forward_mesh_frame_to_sessions, MeshSession};
11use crate::nostr::{decode_signaling_event, encode_signaling_event};
12use crate::runtime_peer::{
13    can_track_signal_path_peer, remember_peer_signal_path, ConnectionState, MeshPeerEntry,
14    PeerSignalPath,
15};
16use crate::runtime_state::MeshRuntimeState;
17use crate::signaling::MeshRouter;
18use crate::transport::{PeerLinkFactory, SignalingTransport};
19use crate::types::{MeshNostrFrame, PeerId, SignalingMessage, TimedSeenSet, MESH_DEFAULT_HTL};
20
21#[derive(Debug, Clone)]
22pub enum PeerStateEvent {
23    Connected(PeerId),
24    SignalHints(PeerId, Vec<String>),
25    Failed(PeerId),
26    Disconnected(PeerId),
27}
28
29pub fn can_track_source_peer<P>(
30    source: &str,
31    peer_key: &str,
32    peers: &HashMap<String, MeshPeerEntry<P>>,
33    max_peers: Option<usize>,
34) -> bool {
35    match max_peers {
36        Some(max_peers) => can_track_signal_path_peer(
37            PeerSignalPath::from_source_name(source),
38            max_peers,
39            peer_key,
40            peers,
41        ),
42        None => true,
43    }
44}
45
46pub async fn forward_mesh_frame_from_runtime<P>(
47    runtime: &MeshRuntimeState<P>,
48    frame: &MeshNostrFrame,
49    exclude_peer_id: Option<&str>,
50) -> usize
51where
52    P: MeshSession + Clone + Send + Sync + 'static,
53{
54    let peers = runtime.peers.read().await;
55    let peer_refs: Vec<(String, Arc<dyn MeshSession>)> = peers
56        .values()
57        .filter(|entry| entry.state == ConnectionState::Connected)
58        .filter_map(|entry| {
59            entry.peer.as_ref().map(|peer| {
60                (
61                    entry.peer_id.to_string(),
62                    Arc::new(peer.clone()) as Arc<dyn MeshSession>,
63                )
64            })
65        })
66        .collect();
67    drop(peers);
68
69    forward_mesh_frame_to_sessions(peer_refs, frame, exclude_peer_id).await
70}
71
72pub async fn create_signaling_event(
73    keys: &Keys,
74    msg: &SignalingMessage,
75    signaling_kind: u64,
76) -> Result<Event> {
77    encode_signaling_event(
78        keys,
79        msg.peer_id(),
80        msg,
81        Kind::from_u16(signaling_kind as u16),
82    )
83    .map_err(|e| anyhow::anyhow!(e.to_string()))
84}
85
86#[allow(clippy::too_many_arguments)]
87pub async fn handle_signaling_event<P, R, F>(
88    signaling_enabled: bool,
89    my_peer_id: &PeerId,
90    keys: &Keys,
91    runtime: &MeshRuntimeState<P>,
92    source: &str,
93    source_max_peers: Option<usize>,
94    event: &Event,
95    shared_router: Option<&Arc<MeshRouter<R, F>>>,
96) -> Result<()>
97where
98    P: MeshSession + Send + Sync + 'static,
99    R: SignalingTransport + 'static,
100    F: PeerLinkFactory + 'static,
101{
102    if !signaling_enabled {
103        return Ok(());
104    }
105
106    let Some(msg) = decode_signaling_event(
107        event,
108        &my_peer_id.to_string(),
109        &keys.public_key().to_hex(),
110        keys,
111    ) else {
112        return Ok(());
113    };
114
115    handle_signaling_message(runtime, source, source_max_peers, msg, shared_router).await
116}
117
118pub async fn handle_signaling_message<P, R, F>(
119    runtime: &MeshRuntimeState<P>,
120    source: &str,
121    source_max_peers: Option<usize>,
122    msg: SignalingMessage,
123    shared_router: Option<&Arc<MeshRouter<R, F>>>,
124) -> Result<()>
125where
126    P: MeshSession + Send + Sync + 'static,
127    R: SignalingTransport + 'static,
128    F: PeerLinkFactory + 'static,
129{
130    let Some(shared_router) = shared_router else {
131        return Ok(());
132    };
133
134    if matches!(
135        msg,
136        SignalingMessage::Hello { .. } | SignalingMessage::Offer { .. }
137    ) {
138        let peers = runtime.peers.read().await;
139        if !can_track_source_peer(source, msg.peer_id(), &peers, source_max_peers) {
140            return Ok(());
141        }
142    }
143
144    debug!(
145        "Received {} from {} via {}",
146        msg.msg_type(),
147        msg.peer_id(),
148        source
149    );
150    let peer_id = msg.peer_id().to_string();
151    let peer_hash_get = match &msg {
152        SignalingMessage::Hello { hash_get, .. } => Some(*hash_get),
153        _ => None,
154    };
155    shared_router
156        .handle_message(msg)
157        .await
158        .map_err(|e| anyhow::anyhow!(e.to_string()))?;
159    if let Some(hash_get) = peer_hash_get {
160        runtime.set_peer_hash_get(&peer_id, hash_get).await;
161    }
162    remember_peer_signal_path(runtime.peers.as_ref(), &peer_id, source).await;
163
164    Ok(())
165}
166
167#[allow(clippy::too_many_arguments)]
168pub async fn dispatch_signaling_message<P, S>(
169    signaling_enabled: bool,
170    keys: &Keys,
171    my_peer_id: &PeerId,
172    runtime: &MeshRuntimeState<P>,
173    relay_transport: Option<&S>,
174    local_buses: &[SharedLocalNostrBus],
175    seen_frame_ids: &Arc<Mutex<TimedSeenSet>>,
176    seen_event_ids: &Arc<Mutex<TimedSeenSet>>,
177    msg: SignalingMessage,
178    signaling_kind: u64,
179) -> Result<()>
180where
181    P: MeshSession + Clone + Send + Sync + 'static,
182    S: SignalingTransport + Send + Sync + 'static,
183{
184    if !signaling_enabled {
185        debug!(
186            "Skipping signaling message {} because signaling is disabled",
187            msg.msg_type()
188        );
189        return Ok(());
190    }
191
192    let event = create_signaling_event(keys, &msg, signaling_kind).await?;
193
194    for bus in local_buses {
195        if let Err(err) = bus.broadcast_event(&event).await {
196            debug!(
197                "Failed to broadcast signaling event over {} ({}): {}",
198                bus.source_name(),
199                msg.msg_type(),
200                err
201            );
202        }
203    }
204
205    let mut frame = MeshNostrFrame::new_event(event, &my_peer_id.to_string(), MESH_DEFAULT_HTL);
206    if !mark_seen(seen_frame_ids, frame.frame_id.clone()).await {
207        runtime.record_mesh_duplicate_drop();
208        return Ok(());
209    }
210    if !mark_seen(seen_event_ids, frame.event().id.to_hex()).await {
211        runtime.record_mesh_duplicate_drop();
212        return Ok(());
213    }
214
215    frame.sender_peer_id = my_peer_id.to_string();
216    let forwarded = forward_mesh_frame_from_runtime(runtime, &frame, None).await;
217    if forwarded > 0 {
218        runtime.record_mesh_forwarded(forwarded as u64);
219    }
220
221    if let Some(relay_transport) = relay_transport {
222        match tokio::time::timeout(
223            Duration::from_millis(500),
224            relay_transport.publish(msg.clone()),
225        )
226        .await
227        {
228            Ok(Ok(())) => {}
229            Ok(Err(err)) => {
230                debug!(
231                    "Failed to publish signaling message {} via relay transport: {}",
232                    msg.msg_type(),
233                    err
234                );
235            }
236            Err(_) => {
237                debug!(
238                    "Timed out publishing signaling message {} via relay transport",
239                    msg.msg_type()
240                );
241            }
242        }
243    }
244
245    Ok(())
246}
247
248pub async fn handle_peer_state_event<P, R, F>(
249    runtime: &MeshRuntimeState<P>,
250    event: PeerStateEvent,
251    shared_router: Option<&Arc<MeshRouter<R, F>>>,
252) where
253    P: MeshSession + Send + Sync + 'static,
254    R: SignalingTransport + 'static,
255    F: PeerLinkFactory + 'static,
256{
257    match event {
258        PeerStateEvent::Connected(peer_id) => {
259            let peer_key = peer_id.to_string();
260            let mut emit_hello = false;
261            let mut peers = runtime.peers.write().await;
262            if let Some(entry) = peers.get_mut(&peer_key) {
263                if entry.state != ConnectionState::Connected {
264                    info!("Peer {} connected (via state event)", peer_id.short());
265                    entry.state = ConnectionState::Connected;
266                    emit_hello = true;
267                    runtime
268                        .connected_count
269                        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
270                }
271            }
272            drop(peers);
273            if emit_hello {
274                if let Some(shared_router) = shared_router {
275                    let _ = shared_router.send_hello(Vec::new()).await;
276                }
277            }
278        }
279        PeerStateEvent::SignalHints(peer_id, signal_urls) => {
280            runtime
281                .record_known_peer_signal_urls(&peer_id.to_string(), &signal_urls, "peer-link")
282                .await;
283        }
284        PeerStateEvent::Failed(peer_id) => {
285            remove_peer_from_runtime(runtime, shared_router, peer_id, "connection failed").await;
286        }
287        PeerStateEvent::Disconnected(peer_id) => {
288            remove_peer_from_runtime(runtime, shared_router, peer_id, "disconnected").await;
289        }
290    }
291}
292
293pub async fn cleanup_stale_peers<P>(runtime: &MeshRuntimeState<P>, stale_timeout: Duration)
294where
295    P: MeshSession + Send + Sync + 'static,
296{
297    let mut peers = runtime.peers.write().await;
298    let mut connected_count = 0usize;
299    let mut to_remove = Vec::new();
300
301    for (key, entry) in peers.iter_mut() {
302        if let Some(ref peer) = entry.peer {
303            if peer.is_connected() {
304                if entry.state != ConnectionState::Connected {
305                    info!(
306                        "Peer {} is now connected (sync fallback)",
307                        entry.peer_id.short()
308                    );
309                    entry.state = ConnectionState::Connected;
310                }
311                connected_count += 1;
312            } else if entry.state == ConnectionState::Connected {
313                info!(
314                    "Removing disconnected peer {} after transport closed",
315                    entry.peer_id.short()
316                );
317                to_remove.push(key.clone());
318            } else if entry.state == ConnectionState::Connecting
319                && entry.last_seen.elapsed() > stale_timeout
320            {
321                info!(
322                    "Removing stale peer {} (stuck in Connecting for {:?})",
323                    entry.peer_id.short(),
324                    entry.last_seen.elapsed()
325                );
326                to_remove.push(key.clone());
327            }
328        } else if entry.state == ConnectionState::Discovered
329            && entry.last_seen.elapsed() > stale_timeout
330        {
331            debug!("Removing stale discovered peer {}", entry.peer_id.short());
332            to_remove.push(key.clone());
333        }
334    }
335
336    let mut removed_peers = Vec::new();
337    for key in to_remove {
338        if let Some(entry) = peers.remove(&key) {
339            removed_peers.push(entry);
340        }
341    }
342    drop(peers);
343
344    for entry in removed_peers {
345        if let Some(peer) = entry.peer {
346            let _ = peer.close().await;
347        }
348    }
349
350    runtime
351        .connected_count
352        .store(connected_count, std::sync::atomic::Ordering::Relaxed);
353}
354
355async fn mark_seen(seen: &Arc<Mutex<TimedSeenSet>>, id: String) -> bool {
356    let mut seen = seen.lock().await;
357    seen.insert_if_new(id)
358}
359
360async fn remove_peer_from_runtime<P, R, F>(
361    runtime: &MeshRuntimeState<P>,
362    shared_router: Option<&Arc<MeshRouter<R, F>>>,
363    peer_id: PeerId,
364    reason: &str,
365) where
366    P: MeshSession + Send + Sync + 'static,
367    R: SignalingTransport + 'static,
368    F: PeerLinkFactory + 'static,
369{
370    let peer_key = peer_id.to_string();
371    info!("Peer {} {} - removing from pool", peer_id.short(), reason);
372    let removed = {
373        let mut peers = runtime.peers.write().await;
374        peers.remove(&peer_key)
375    };
376    runtime.clear_peer_hash_get(&peer_key).await;
377    if let Some(entry) = removed {
378        if entry.state == ConnectionState::Connected {
379            runtime
380                .connected_count
381                .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
382        }
383        if let Some(peer) = entry.peer {
384            let _ = peer.close().await;
385        }
386    }
387    if let Some(shared_router) = shared_router {
388        if let Some(channel) = shared_router.remove_peer(&peer_key).await {
389            channel.close().await;
390        }
391    }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result as AnyResult;
    use async_trait::async_trait;
    use nostr_sdk::nostr::{EventBuilder, Filter, Kind};
    use std::collections::BTreeSet;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Instant;

    use crate::runtime_peer::{MeshPeerEntry, PeerDirection, PeerTransport};
    use crate::types::{PeerHTLConfig, PeerPool};

    /// In-memory session stub with a fixed connectivity flag, an optional artificial close
    /// delay, and a shared flag that is set once `close` completes.
    #[derive(Clone)]
    struct TestSession {
        connected: bool,
        close_delay: Duration,
        closed: Arc<AtomicBool>,
    }

    impl TestSession {
        fn new(connected: bool, close_delay: Duration, closed: Arc<AtomicBool>) -> Self {
            Self {
                connected,
                close_delay,
                closed,
            }
        }
    }

    #[async_trait]
    impl MeshSession for TestSession {
        fn is_ready(&self) -> bool {
            true
        }

        fn is_connected(&self) -> bool {
            self.connected
        }

        fn htl_config(&self) -> PeerHTLConfig {
            PeerHTLConfig::from_flags(false, false)
        }

        async fn request(&self, _hash_hex: &str, _timeout: Duration) -> AnyResult<Option<Vec<u8>>> {
            Ok(None)
        }

        async fn query_nostr_events(
            &self,
            _filters: Vec<Filter>,
            _timeout: Duration,
        ) -> AnyResult<Vec<Event>> {
            Ok(Vec::new())
        }

        async fn send_mesh_frame_text(&self, _frame: &MeshNostrFrame) -> AnyResult<()> {
            Ok(())
        }

        async fn close(&self) -> AnyResult<()> {
            let delay = self.close_delay;
            if !delay.is_zero() {
                tokio::time::sleep(delay).await;
            }
            self.closed.store(true, Ordering::Relaxed);
            Ok(())
        }
    }

    /// Builds an outbound entry in the `Other` pool with zeroed byte counters.
    fn outbound_entry(
        peer_id: PeerId,
        state: ConnectionState,
        last_seen: Instant,
        peer: Option<TestSession>,
        transport: PeerTransport,
        signal_paths: BTreeSet<PeerSignalPath>,
    ) -> MeshPeerEntry<TestSession> {
        MeshPeerEntry {
            peer_id,
            direction: PeerDirection::Outbound,
            state,
            last_seen,
            peer,
            pool: PeerPool::Other,
            transport,
            signal_paths,
            bytes_sent: 0,
            bytes_received: 0,
        }
    }

    #[test]
    fn can_track_source_peer_respects_optional_limits() {
        let tracked_id = PeerId::new("peer-a".to_string());
        let tracked_key = tracked_id.to_string();
        let peers = HashMap::from([(
            tracked_key.clone(),
            outbound_entry(
                tracked_id,
                ConnectionState::Discovered,
                Instant::now(),
                None,
                PeerTransport::Direct,
                BTreeSet::from([PeerSignalPath::WifiAware]),
            ),
        )]);

        // No limit: any peer is accepted.
        assert!(can_track_source_peer("relay", "peer-b", &peers, None));
        // Limit of one, asking about the peer already tracked on that path.
        assert!(can_track_source_peer(
            "wifi-aware",
            &tracked_key,
            &peers,
            Some(1)
        ));
        // Limit of one, asking about a second peer on the same path.
        assert!(!can_track_source_peer(
            "wifi-aware",
            "peer-b",
            &peers,
            Some(1),
        ));
    }

    #[tokio::test]
    async fn cleanup_stale_peers_removes_stale_entries_and_syncs_connected_count() {
        let runtime = MeshRuntimeState::<TestSession>::new();
        let stale_id = PeerId::new("peer-stale".to_string());
        let active_id = PeerId::new("peer-active".to_string());

        {
            let mut peers = runtime.peers.write().await;
            peers.insert(
                stale_id.to_string(),
                outbound_entry(
                    stale_id,
                    ConnectionState::Discovered,
                    Instant::now() - Duration::from_secs(120),
                    None,
                    PeerTransport::Direct,
                    BTreeSet::new(),
                ),
            );
            peers.insert(
                active_id.to_string(),
                outbound_entry(
                    active_id.clone(),
                    ConnectionState::Connecting,
                    Instant::now(),
                    Some(TestSession::new(
                        true,
                        Duration::ZERO,
                        Arc::new(AtomicBool::new(false)),
                    )),
                    PeerTransport::Direct,
                    BTreeSet::new(),
                ),
            );
        }

        cleanup_stale_peers(&runtime, Duration::from_secs(60)).await;

        let peers = runtime.peers.read().await;
        assert!(!peers.contains_key("peer-stale"));
        assert_eq!(
            peers.get(&active_id.to_string()).unwrap().state,
            ConnectionState::Connected
        );
        assert_eq!(runtime.connected_count.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn handle_peer_state_event_does_not_hold_peer_map_lock_while_closing() {
        let runtime = Arc::new(MeshRuntimeState::<TestSession>::new());
        let peer_id = PeerId::new("peer-a-pub".to_string());
        let slow_session = TestSession::new(
            false,
            Duration::from_millis(200),
            Arc::new(AtomicBool::new(false)),
        );
        runtime.peers.write().await.insert(
            peer_id.to_string(),
            outbound_entry(
                peer_id.clone(),
                ConnectionState::Connected,
                Instant::now(),
                Some(slow_session),
                PeerTransport::Bluetooth,
                BTreeSet::from([PeerSignalPath::Bluetooth]),
            ),
        );

        let task_runtime = Arc::clone(&runtime);
        let task_peer_id = peer_id.clone();
        let cleanup_task = tokio::spawn(async move {
            handle_peer_state_event::<
                TestSession,
                crate::mock::MockRelayTransport,
                crate::mock::MockConnectionFactory,
            >(
                task_runtime.as_ref(),
                PeerStateEvent::Failed(task_peer_id),
                None,
            )
            .await;
        });

        // Give the spawned task time to reach the slow close.
        tokio::time::sleep(Duration::from_millis(20)).await;

        let read_len = async { runtime.peers.read().await.len() };
        let remaining = tokio::time::timeout(Duration::from_millis(50), read_len)
            .await
            .expect("peer map read should not block on close");

        assert_eq!(remaining, 0);
        cleanup_task.await.expect("cleanup task");
    }

    #[tokio::test]
    async fn forward_mesh_frame_from_runtime_sends_to_connected_peers() {
        let runtime = MeshRuntimeState::<TestSession>::new();
        let closed = Arc::new(AtomicBool::new(false));
        let peer_id = PeerId::new("peer-a".to_string());
        runtime.peers.write().await.insert(
            peer_id.to_string(),
            outbound_entry(
                peer_id.clone(),
                ConnectionState::Connected,
                Instant::now(),
                Some(TestSession::new(true, Duration::ZERO, closed.clone())),
                PeerTransport::Direct,
                BTreeSet::new(),
            ),
        );

        let keys = Keys::generate();
        let event = EventBuilder::new(Kind::Custom(25050), "mesh")
            .sign_with_keys(&keys)
            .unwrap();
        let frame = MeshNostrFrame::new_event_with_id(event, "sender", "frame-1", 4);

        let forwarded = forward_mesh_frame_from_runtime(&runtime, &frame, None).await;
        assert_eq!(forwarded, 1);
        assert!(!closed.load(Ordering::Relaxed));
    }
}
631}