Skip to main content

hashtree_network/
runtime_control.rs

1use anyhow::Result;
2use nostr_sdk::nostr::{Event, Keys, Kind};
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::sync::Mutex;
7use tracing::{debug, info};
8
9use crate::local_bus::SharedLocalNostrBus;
10use crate::mesh_session::{forward_mesh_frame_to_sessions, MeshSession};
11use crate::nostr::{decode_signaling_event, encode_signaling_event};
12use crate::runtime_peer::{
13    can_track_signal_path_peer, remember_peer_signal_path, ConnectionState, MeshPeerEntry,
14    PeerSignalPath,
15};
16use crate::runtime_state::MeshRuntimeState;
17use crate::signaling::MeshRouter;
18use crate::transport::{PeerLinkFactory, SignalingTransport};
19use crate::types::{MeshNostrFrame, PeerId, SignalingMessage, TimedSeenSet, MESH_DEFAULT_HTL};
20
21#[derive(Debug, Clone)]
22pub enum PeerStateEvent {
23    Connected(PeerId),
24    Failed(PeerId),
25    Disconnected(PeerId),
26}
27
28pub fn can_track_source_peer<P>(
29    source: &str,
30    peer_key: &str,
31    peers: &HashMap<String, MeshPeerEntry<P>>,
32    max_peers: Option<usize>,
33) -> bool {
34    match max_peers {
35        Some(max_peers) => can_track_signal_path_peer(
36            PeerSignalPath::from_source_name(source),
37            max_peers,
38            peer_key,
39            peers,
40        ),
41        None => true,
42    }
43}
44
45pub async fn forward_mesh_frame_from_runtime<P>(
46    runtime: &MeshRuntimeState<P>,
47    frame: &MeshNostrFrame,
48    exclude_peer_id: Option<&str>,
49) -> usize
50where
51    P: MeshSession + Clone + Send + Sync + 'static,
52{
53    let peers = runtime.peers.read().await;
54    let peer_refs: Vec<(String, Arc<dyn MeshSession>)> = peers
55        .values()
56        .filter(|entry| entry.state == ConnectionState::Connected)
57        .filter_map(|entry| {
58            entry.peer.as_ref().map(|peer| {
59                (
60                    entry.peer_id.to_string(),
61                    Arc::new(peer.clone()) as Arc<dyn MeshSession>,
62                )
63            })
64        })
65        .collect();
66    drop(peers);
67
68    forward_mesh_frame_to_sessions(peer_refs, frame, exclude_peer_id).await
69}
70
71pub async fn create_signaling_event(
72    keys: &Keys,
73    msg: &SignalingMessage,
74    signaling_kind: u64,
75) -> Result<Event> {
76    encode_signaling_event(
77        keys,
78        msg.peer_id(),
79        msg,
80        Kind::Ephemeral(signaling_kind as u16),
81    )
82    .map_err(|e| anyhow::anyhow!(e.to_string()))
83}
84
85pub async fn handle_signaling_event<P, R, F>(
86    signaling_enabled: bool,
87    my_peer_id: &PeerId,
88    keys: &Keys,
89    runtime: &MeshRuntimeState<P>,
90    source: &str,
91    source_max_peers: Option<usize>,
92    event: &Event,
93    shared_router: Option<&Arc<MeshRouter<R, F>>>,
94) -> Result<()>
95where
96    P: MeshSession + Send + Sync + 'static,
97    R: SignalingTransport + 'static,
98    F: PeerLinkFactory + 'static,
99{
100    if !signaling_enabled {
101        return Ok(());
102    }
103
104    let Some(msg) = decode_signaling_event(
105        event,
106        &my_peer_id.to_string(),
107        &keys.public_key().to_hex(),
108        keys,
109    ) else {
110        return Ok(());
111    };
112
113    handle_signaling_message(runtime, source, source_max_peers, msg, shared_router).await
114}
115
116pub async fn handle_signaling_message<P, R, F>(
117    runtime: &MeshRuntimeState<P>,
118    source: &str,
119    source_max_peers: Option<usize>,
120    msg: SignalingMessage,
121    shared_router: Option<&Arc<MeshRouter<R, F>>>,
122) -> Result<()>
123where
124    P: MeshSession + Send + Sync + 'static,
125    R: SignalingTransport + 'static,
126    F: PeerLinkFactory + 'static,
127{
128    let Some(shared_router) = shared_router else {
129        return Ok(());
130    };
131
132    if matches!(
133        msg,
134        SignalingMessage::Hello { .. } | SignalingMessage::Offer { .. }
135    ) {
136        let peers = runtime.peers.read().await;
137        if !can_track_source_peer(source, msg.peer_id(), &peers, source_max_peers) {
138            return Ok(());
139        }
140    }
141
142    debug!(
143        "Received {} from {} via {}",
144        msg.msg_type(),
145        msg.peer_id(),
146        source
147    );
148    let peer_id = msg.peer_id().to_string();
149    shared_router
150        .handle_message(msg)
151        .await
152        .map_err(|e| anyhow::anyhow!(e.to_string()))?;
153    remember_peer_signal_path(runtime.peers.as_ref(), &peer_id, source).await;
154
155    Ok(())
156}
157
158pub async fn dispatch_signaling_message<P, S>(
159    signaling_enabled: bool,
160    keys: &Keys,
161    my_peer_id: &PeerId,
162    runtime: &MeshRuntimeState<P>,
163    relay_transport: Option<&S>,
164    local_buses: &[SharedLocalNostrBus],
165    seen_frame_ids: &Arc<Mutex<TimedSeenSet>>,
166    seen_event_ids: &Arc<Mutex<TimedSeenSet>>,
167    msg: SignalingMessage,
168    signaling_kind: u64,
169) -> Result<()>
170where
171    P: MeshSession + Clone + Send + Sync + 'static,
172    S: SignalingTransport + Send + Sync + 'static,
173{
174    if !signaling_enabled {
175        debug!(
176            "Skipping signaling message {} because signaling is disabled",
177            msg.msg_type()
178        );
179        return Ok(());
180    }
181
182    if let Some(relay_transport) = relay_transport {
183        if let Err(err) = relay_transport.publish(msg.clone()).await {
184            debug!(
185                "Failed to publish signaling message {} via relay transport: {}",
186                msg.msg_type(),
187                err
188            );
189        }
190    }
191
192    let event = create_signaling_event(keys, &msg, signaling_kind).await?;
193
194    for bus in local_buses {
195        if let Err(err) = bus.broadcast_event(&event).await {
196            debug!(
197                "Failed to broadcast signaling event over {} ({}): {}",
198                bus.source_name(),
199                msg.msg_type(),
200                err
201            );
202        }
203    }
204
205    let mut frame = MeshNostrFrame::new_event(event, &my_peer_id.to_string(), MESH_DEFAULT_HTL);
206    if !mark_seen(seen_frame_ids, frame.frame_id.clone()).await {
207        runtime.record_mesh_duplicate_drop();
208        return Ok(());
209    }
210    if !mark_seen(seen_event_ids, frame.event().id.to_hex()).await {
211        runtime.record_mesh_duplicate_drop();
212        return Ok(());
213    }
214
215    frame.sender_peer_id = my_peer_id.to_string();
216    let forwarded = forward_mesh_frame_from_runtime(runtime, &frame, None).await;
217    if forwarded > 0 {
218        runtime.record_mesh_forwarded(forwarded as u64);
219    }
220
221    Ok(())
222}
223
224pub async fn handle_peer_state_event<P, R, F>(
225    runtime: &MeshRuntimeState<P>,
226    event: PeerStateEvent,
227    shared_router: Option<&Arc<MeshRouter<R, F>>>,
228) where
229    P: MeshSession + Send + Sync + 'static,
230    R: SignalingTransport + 'static,
231    F: PeerLinkFactory + 'static,
232{
233    match event {
234        PeerStateEvent::Connected(peer_id) => {
235            let peer_key = peer_id.to_string();
236            let mut emit_hello = false;
237            let mut peers = runtime.peers.write().await;
238            if let Some(entry) = peers.get_mut(&peer_key) {
239                if entry.state != ConnectionState::Connected {
240                    info!("Peer {} connected (via state event)", peer_id.short());
241                    entry.state = ConnectionState::Connected;
242                    emit_hello = true;
243                    runtime
244                        .connected_count
245                        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
246                }
247            }
248            drop(peers);
249            if emit_hello {
250                if let Some(shared_router) = shared_router {
251                    let _ = shared_router.send_hello(Vec::new()).await;
252                }
253            }
254        }
255        PeerStateEvent::Failed(peer_id) => {
256            remove_peer_from_runtime(runtime, shared_router, peer_id, "connection failed").await;
257        }
258        PeerStateEvent::Disconnected(peer_id) => {
259            remove_peer_from_runtime(runtime, shared_router, peer_id, "disconnected").await;
260        }
261    }
262}
263
264pub async fn cleanup_stale_peers<P>(runtime: &MeshRuntimeState<P>, stale_timeout: Duration)
265where
266    P: MeshSession + Send + Sync + 'static,
267{
268    let mut peers = runtime.peers.write().await;
269    let mut connected_count = 0usize;
270    let mut to_remove = Vec::new();
271
272    for (key, entry) in peers.iter_mut() {
273        if let Some(ref peer) = entry.peer {
274            if peer.is_connected() {
275                if entry.state != ConnectionState::Connected {
276                    info!(
277                        "Peer {} is now connected (sync fallback)",
278                        entry.peer_id.short()
279                    );
280                    entry.state = ConnectionState::Connected;
281                }
282                connected_count += 1;
283            } else if entry.state == ConnectionState::Connected {
284                info!(
285                    "Removing disconnected peer {} after transport closed",
286                    entry.peer_id.short()
287                );
288                to_remove.push(key.clone());
289            } else if entry.state == ConnectionState::Connecting
290                && entry.last_seen.elapsed() > stale_timeout
291            {
292                info!(
293                    "Removing stale peer {} (stuck in Connecting for {:?})",
294                    entry.peer_id.short(),
295                    entry.last_seen.elapsed()
296                );
297                to_remove.push(key.clone());
298            }
299        } else if entry.state == ConnectionState::Discovered
300            && entry.last_seen.elapsed() > stale_timeout
301        {
302            debug!("Removing stale discovered peer {}", entry.peer_id.short());
303            to_remove.push(key.clone());
304        }
305    }
306
307    let mut removed_peers = Vec::new();
308    for key in to_remove {
309        if let Some(entry) = peers.remove(&key) {
310            removed_peers.push(entry);
311        }
312    }
313    drop(peers);
314
315    for entry in removed_peers {
316        if let Some(peer) = entry.peer {
317            let _ = peer.close().await;
318        }
319    }
320
321    runtime
322        .connected_count
323        .store(connected_count, std::sync::atomic::Ordering::Relaxed);
324}
325
326async fn mark_seen(seen: &Arc<Mutex<TimedSeenSet>>, id: String) -> bool {
327    let mut seen = seen.lock().await;
328    seen.insert_if_new(id)
329}
330
331async fn remove_peer_from_runtime<P, R, F>(
332    runtime: &MeshRuntimeState<P>,
333    shared_router: Option<&Arc<MeshRouter<R, F>>>,
334    peer_id: PeerId,
335    reason: &str,
336) where
337    P: MeshSession + Send + Sync + 'static,
338    R: SignalingTransport + 'static,
339    F: PeerLinkFactory + 'static,
340{
341    let peer_key = peer_id.to_string();
342    info!("Peer {} {} - removing from pool", peer_id.short(), reason);
343    let removed = {
344        let mut peers = runtime.peers.write().await;
345        peers.remove(&peer_key)
346    };
347    if let Some(entry) = removed {
348        if entry.state == ConnectionState::Connected {
349            runtime
350                .connected_count
351                .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
352        }
353        if let Some(peer) = entry.peer {
354            let _ = peer.close().await;
355        }
356    }
357    if let Some(shared_router) = shared_router {
358        if let Some(channel) = shared_router.remove_peer(&peer_key).await {
359            channel.close().await;
360        }
361    }
362}
363
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result as AnyResult;
    use async_trait::async_trait;
    use nostr_sdk::nostr::{EventBuilder, Filter, Kind};
    use std::collections::BTreeSet;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Instant;

    use crate::runtime_peer::{MeshPeerEntry, PeerDirection, PeerTransport};
    use crate::types::{PeerHTLConfig, PeerPool};

    /// Scripted [`MeshSession`]: fixed `is_connected()` answer, no-op I/O, and a `close()`
    /// that optionally sleeps before setting `closed`.
    #[derive(Clone)]
    struct TestSession {
        connected: bool,
        close_delay: Duration,
        closed: Arc<AtomicBool>,
    }

    impl TestSession {
        fn new(connected: bool, close_delay: Duration, closed: Arc<AtomicBool>) -> Self {
            Self {
                connected,
                close_delay,
                closed,
            }
        }
    }

    #[async_trait]
    impl MeshSession for TestSession {
        fn is_ready(&self) -> bool {
            true
        }

        fn is_connected(&self) -> bool {
            self.connected
        }

        fn htl_config(&self) -> PeerHTLConfig {
            PeerHTLConfig::from_flags(false, false)
        }

        async fn request(&self, _hash_hex: &str, _timeout: Duration) -> AnyResult<Option<Vec<u8>>> {
            Ok(None)
        }

        async fn query_nostr_events(
            &self,
            _filters: Vec<Filter>,
            _timeout: Duration,
        ) -> AnyResult<Vec<Event>> {
            Ok(Vec::new())
        }

        async fn send_mesh_frame_text(&self, _frame: &MeshNostrFrame) -> AnyResult<()> {
            Ok(())
        }

        async fn close(&self) -> AnyResult<()> {
            if self.close_delay > Duration::ZERO {
                tokio::time::sleep(self.close_delay).await;
            }
            self.closed.store(true, Ordering::Relaxed);
            Ok(())
        }
    }

    /// Builds an outbound entry in `PeerPool::Other` with zeroed traffic counters.
    fn outbound_entry(
        peer_id: PeerId,
        state: ConnectionState,
        last_seen: Instant,
        peer: Option<TestSession>,
        transport: PeerTransport,
        signal_paths: BTreeSet<PeerSignalPath>,
    ) -> MeshPeerEntry<TestSession> {
        MeshPeerEntry {
            peer_id,
            direction: PeerDirection::Outbound,
            state,
            last_seen,
            peer,
            pool: PeerPool::Other,
            transport,
            signal_paths,
            bytes_sent: 0,
            bytes_received: 0,
        }
    }

    #[test]
    fn can_track_source_peer_respects_optional_limits() {
        let tracked_id = PeerId::new("peer-a".to_string());
        let tracked_key = tracked_id.to_string();
        let mut peers = HashMap::new();
        peers.insert(
            tracked_key.clone(),
            outbound_entry(
                tracked_id,
                ConnectionState::Discovered,
                Instant::now(),
                None,
                PeerTransport::WebRtc,
                BTreeSet::from([PeerSignalPath::WifiAware]),
            ),
        );

        // No limit configured: any peer is trackable.
        assert!(can_track_source_peer("relay", "peer-b", &peers, None));
        // Limit of one on wifi-aware: the peer already on that path still passes...
        assert!(can_track_source_peer(
            "wifi-aware",
            &tracked_key,
            &peers,
            Some(1)
        ));
        // ...while a newcomer is refused.
        assert!(!can_track_source_peer(
            "wifi-aware",
            "peer-b",
            &peers,
            Some(1),
        ));
    }

    #[tokio::test]
    async fn cleanup_stale_peers_removes_stale_entries_and_syncs_connected_count() {
        let runtime = MeshRuntimeState::<TestSession>::new();
        let stale_id = PeerId::new("peer-stale".to_string());
        let active_id = PeerId::new("peer-active".to_string());
        {
            let mut peers = runtime.peers.write().await;
            // Session-less Discovered entry last seen two minutes ago (timeout below is 60 s).
            peers.insert(
                stale_id.to_string(),
                outbound_entry(
                    stale_id,
                    ConnectionState::Discovered,
                    Instant::now() - Duration::from_secs(120),
                    None,
                    PeerTransport::WebRtc,
                    BTreeSet::new(),
                ),
            );
            // Still marked Connecting, but its session already reports connected.
            peers.insert(
                active_id.to_string(),
                outbound_entry(
                    active_id.clone(),
                    ConnectionState::Connecting,
                    Instant::now(),
                    Some(TestSession::new(
                        true,
                        Duration::ZERO,
                        Arc::new(AtomicBool::new(false)),
                    )),
                    PeerTransport::WebRtc,
                    BTreeSet::new(),
                ),
            );
        }

        cleanup_stale_peers(&runtime, Duration::from_secs(60)).await;

        let peers = runtime.peers.read().await;
        assert!(!peers.contains_key("peer-stale"));
        let active = peers.get(&active_id.to_string()).unwrap();
        assert_eq!(active.state, ConnectionState::Connected);
        assert_eq!(runtime.connected_count.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn handle_peer_state_event_does_not_hold_peer_map_lock_while_closing() {
        let runtime = Arc::new(MeshRuntimeState::<TestSession>::new());
        let peer_id = PeerId::new("peer-a-pub".to_string());
        runtime.peers.write().await.insert(
            peer_id.to_string(),
            outbound_entry(
                peer_id.clone(),
                ConnectionState::Connected,
                Instant::now(),
                // close() takes 200 ms, far longer than the 50 ms read timeout below.
                Some(TestSession::new(
                    false,
                    Duration::from_millis(200),
                    Arc::new(AtomicBool::new(false)),
                )),
                PeerTransport::Bluetooth,
                BTreeSet::from([PeerSignalPath::Bluetooth]),
            ),
        );

        let task_runtime = Arc::clone(&runtime);
        let task_peer_id = peer_id.clone();
        let cleanup_task = tokio::spawn(async move {
            handle_peer_state_event::<
                TestSession,
                crate::mock::MockRelayTransport,
                crate::mock::MockConnectionFactory,
            >(
                task_runtime.as_ref(),
                PeerStateEvent::Failed(task_peer_id),
                None,
            )
            .await;
        });

        // Give the spawned task time to remove the entry and start the slow close().
        tokio::time::sleep(Duration::from_millis(20)).await;

        let remaining = tokio::time::timeout(Duration::from_millis(50), async {
            runtime.peers.read().await.len()
        })
        .await
        .expect("peer map read should not block on close");

        assert_eq!(remaining, 0);
        cleanup_task.await.expect("cleanup task");
    }

    #[tokio::test]
    async fn forward_mesh_frame_from_runtime_sends_to_connected_peers() {
        let runtime = MeshRuntimeState::<TestSession>::new();
        let closed = Arc::new(AtomicBool::new(false));
        let peer_id = PeerId::new("peer-a".to_string());
        runtime.peers.write().await.insert(
            peer_id.to_string(),
            outbound_entry(
                peer_id.clone(),
                ConnectionState::Connected,
                Instant::now(),
                Some(TestSession::new(true, Duration::ZERO, Arc::clone(&closed))),
                PeerTransport::WebRtc,
                BTreeSet::new(),
            ),
        );

        let keys = Keys::generate();
        let event = EventBuilder::new(Kind::Custom(25050), "mesh", [])
            .to_event(&keys)
            .unwrap();
        let frame = MeshNostrFrame::new_event_with_id(event, "sender", "frame-1", 4);

        let forwarded = forward_mesh_frame_from_runtime(&runtime, &frame, None).await;
        assert_eq!(forwarded, 1);
        // Forwarding must not close the session.
        assert!(!closed.load(Ordering::Relaxed));
    }
}