Skip to main content

hashtree_network/
runtime_control.rs

1use anyhow::Result;
2use nostr_sdk::nostr::{Event, Keys, Kind};
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::sync::Mutex;
7use tracing::{debug, info};
8
9use crate::local_bus::SharedLocalNostrBus;
10use crate::mesh_session::{forward_mesh_frame_to_sessions, MeshSession};
11use crate::nostr::{decode_signaling_event, encode_signaling_event};
12use crate::runtime_peer::{
13    can_track_signal_path_peer, remember_peer_signal_path, ConnectionState, MeshPeerEntry,
14    PeerSignalPath,
15};
16use crate::runtime_state::MeshRuntimeState;
17use crate::signaling::MeshRouter;
18use crate::transport::{PeerLinkFactory, SignalingTransport};
19use crate::types::{MeshNostrFrame, PeerId, SignalingMessage, TimedSeenSet, MESH_DEFAULT_HTL};
20
21#[derive(Debug, Clone)]
22pub enum PeerStateEvent {
23    Connected(PeerId),
24    Failed(PeerId),
25    Disconnected(PeerId),
26}
27
28pub fn can_track_source_peer<P>(
29    source: &str,
30    peer_key: &str,
31    peers: &HashMap<String, MeshPeerEntry<P>>,
32    max_peers: Option<usize>,
33) -> bool {
34    match max_peers {
35        Some(max_peers) => can_track_signal_path_peer(
36            PeerSignalPath::from_source_name(source),
37            max_peers,
38            peer_key,
39            peers,
40        ),
41        None => true,
42    }
43}
44
45pub async fn forward_mesh_frame_from_runtime<P>(
46    runtime: &MeshRuntimeState<P>,
47    frame: &MeshNostrFrame,
48    exclude_peer_id: Option<&str>,
49) -> usize
50where
51    P: MeshSession + Clone + Send + Sync + 'static,
52{
53    let peers = runtime.peers.read().await;
54    let peer_refs: Vec<(String, Arc<dyn MeshSession>)> = peers
55        .values()
56        .filter(|entry| entry.state == ConnectionState::Connected)
57        .filter_map(|entry| {
58            entry.peer.as_ref().map(|peer| {
59                (
60                    entry.peer_id.to_string(),
61                    Arc::new(peer.clone()) as Arc<dyn MeshSession>,
62                )
63            })
64        })
65        .collect();
66    drop(peers);
67
68    forward_mesh_frame_to_sessions(peer_refs, frame, exclude_peer_id).await
69}
70
71pub async fn create_signaling_event(
72    keys: &Keys,
73    msg: &SignalingMessage,
74    signaling_kind: u64,
75) -> Result<Event> {
76    encode_signaling_event(
77        keys,
78        msg.peer_id(),
79        msg,
80        Kind::Ephemeral(signaling_kind as u16),
81    )
82    .map_err(|e| anyhow::anyhow!(e.to_string()))
83}
84
85pub async fn handle_signaling_event<P, R, F>(
86    signaling_enabled: bool,
87    my_peer_id: &PeerId,
88    keys: &Keys,
89    runtime: &MeshRuntimeState<P>,
90    source: &str,
91    source_max_peers: Option<usize>,
92    event: &Event,
93    shared_router: Option<&Arc<MeshRouter<R, F>>>,
94) -> Result<()>
95where
96    P: MeshSession + Send + Sync + 'static,
97    R: SignalingTransport + 'static,
98    F: PeerLinkFactory + 'static,
99{
100    if !signaling_enabled {
101        return Ok(());
102    }
103
104    let Some(msg) = decode_signaling_event(
105        event,
106        &my_peer_id.to_string(),
107        &keys.public_key().to_hex(),
108        keys,
109    ) else {
110        return Ok(());
111    };
112
113    handle_signaling_message(runtime, source, source_max_peers, msg, shared_router).await
114}
115
116pub async fn handle_signaling_message<P, R, F>(
117    runtime: &MeshRuntimeState<P>,
118    source: &str,
119    source_max_peers: Option<usize>,
120    msg: SignalingMessage,
121    shared_router: Option<&Arc<MeshRouter<R, F>>>,
122) -> Result<()>
123where
124    P: MeshSession + Send + Sync + 'static,
125    R: SignalingTransport + 'static,
126    F: PeerLinkFactory + 'static,
127{
128    let Some(shared_router) = shared_router else {
129        return Ok(());
130    };
131
132    if matches!(
133        msg,
134        SignalingMessage::Hello { .. } | SignalingMessage::Offer { .. }
135    ) {
136        let peers = runtime.peers.read().await;
137        if !can_track_source_peer(source, msg.peer_id(), &peers, source_max_peers) {
138            return Ok(());
139        }
140    }
141
142    debug!(
143        "Received {} from {} via {}",
144        msg.msg_type(),
145        msg.peer_id(),
146        source
147    );
148    let peer_id = msg.peer_id().to_string();
149    let peer_hash_get = match &msg {
150        SignalingMessage::Hello { hash_get, .. } => Some(*hash_get),
151        _ => None,
152    };
153    shared_router
154        .handle_message(msg)
155        .await
156        .map_err(|e| anyhow::anyhow!(e.to_string()))?;
157    if let Some(hash_get) = peer_hash_get {
158        runtime.set_peer_hash_get(&peer_id, hash_get).await;
159    }
160    remember_peer_signal_path(runtime.peers.as_ref(), &peer_id, source).await;
161
162    Ok(())
163}
164
165pub async fn dispatch_signaling_message<P, S>(
166    signaling_enabled: bool,
167    keys: &Keys,
168    my_peer_id: &PeerId,
169    runtime: &MeshRuntimeState<P>,
170    relay_transport: Option<&S>,
171    local_buses: &[SharedLocalNostrBus],
172    seen_frame_ids: &Arc<Mutex<TimedSeenSet>>,
173    seen_event_ids: &Arc<Mutex<TimedSeenSet>>,
174    msg: SignalingMessage,
175    signaling_kind: u64,
176) -> Result<()>
177where
178    P: MeshSession + Clone + Send + Sync + 'static,
179    S: SignalingTransport + Send + Sync + 'static,
180{
181    if !signaling_enabled {
182        debug!(
183            "Skipping signaling message {} because signaling is disabled",
184            msg.msg_type()
185        );
186        return Ok(());
187    }
188
189    if let Some(relay_transport) = relay_transport {
190        if let Err(err) = relay_transport.publish(msg.clone()).await {
191            debug!(
192                "Failed to publish signaling message {} via relay transport: {}",
193                msg.msg_type(),
194                err
195            );
196        }
197    }
198
199    let event = create_signaling_event(keys, &msg, signaling_kind).await?;
200
201    for bus in local_buses {
202        if let Err(err) = bus.broadcast_event(&event).await {
203            debug!(
204                "Failed to broadcast signaling event over {} ({}): {}",
205                bus.source_name(),
206                msg.msg_type(),
207                err
208            );
209        }
210    }
211
212    let mut frame = MeshNostrFrame::new_event(event, &my_peer_id.to_string(), MESH_DEFAULT_HTL);
213    if !mark_seen(seen_frame_ids, frame.frame_id.clone()).await {
214        runtime.record_mesh_duplicate_drop();
215        return Ok(());
216    }
217    if !mark_seen(seen_event_ids, frame.event().id.to_hex()).await {
218        runtime.record_mesh_duplicate_drop();
219        return Ok(());
220    }
221
222    frame.sender_peer_id = my_peer_id.to_string();
223    let forwarded = forward_mesh_frame_from_runtime(runtime, &frame, None).await;
224    if forwarded > 0 {
225        runtime.record_mesh_forwarded(forwarded as u64);
226    }
227
228    Ok(())
229}
230
231pub async fn handle_peer_state_event<P, R, F>(
232    runtime: &MeshRuntimeState<P>,
233    event: PeerStateEvent,
234    shared_router: Option<&Arc<MeshRouter<R, F>>>,
235) where
236    P: MeshSession + Send + Sync + 'static,
237    R: SignalingTransport + 'static,
238    F: PeerLinkFactory + 'static,
239{
240    match event {
241        PeerStateEvent::Connected(peer_id) => {
242            let peer_key = peer_id.to_string();
243            let mut emit_hello = false;
244            let mut peers = runtime.peers.write().await;
245            if let Some(entry) = peers.get_mut(&peer_key) {
246                if entry.state != ConnectionState::Connected {
247                    info!("Peer {} connected (via state event)", peer_id.short());
248                    entry.state = ConnectionState::Connected;
249                    emit_hello = true;
250                    runtime
251                        .connected_count
252                        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
253                }
254            }
255            drop(peers);
256            if emit_hello {
257                if let Some(shared_router) = shared_router {
258                    let _ = shared_router.send_hello(Vec::new()).await;
259                }
260            }
261        }
262        PeerStateEvent::Failed(peer_id) => {
263            remove_peer_from_runtime(runtime, shared_router, peer_id, "connection failed").await;
264        }
265        PeerStateEvent::Disconnected(peer_id) => {
266            remove_peer_from_runtime(runtime, shared_router, peer_id, "disconnected").await;
267        }
268    }
269}
270
271pub async fn cleanup_stale_peers<P>(runtime: &MeshRuntimeState<P>, stale_timeout: Duration)
272where
273    P: MeshSession + Send + Sync + 'static,
274{
275    let mut peers = runtime.peers.write().await;
276    let mut connected_count = 0usize;
277    let mut to_remove = Vec::new();
278
279    for (key, entry) in peers.iter_mut() {
280        if let Some(ref peer) = entry.peer {
281            if peer.is_connected() {
282                if entry.state != ConnectionState::Connected {
283                    info!(
284                        "Peer {} is now connected (sync fallback)",
285                        entry.peer_id.short()
286                    );
287                    entry.state = ConnectionState::Connected;
288                }
289                connected_count += 1;
290            } else if entry.state == ConnectionState::Connected {
291                info!(
292                    "Removing disconnected peer {} after transport closed",
293                    entry.peer_id.short()
294                );
295                to_remove.push(key.clone());
296            } else if entry.state == ConnectionState::Connecting
297                && entry.last_seen.elapsed() > stale_timeout
298            {
299                info!(
300                    "Removing stale peer {} (stuck in Connecting for {:?})",
301                    entry.peer_id.short(),
302                    entry.last_seen.elapsed()
303                );
304                to_remove.push(key.clone());
305            }
306        } else if entry.state == ConnectionState::Discovered
307            && entry.last_seen.elapsed() > stale_timeout
308        {
309            debug!("Removing stale discovered peer {}", entry.peer_id.short());
310            to_remove.push(key.clone());
311        }
312    }
313
314    let mut removed_peers = Vec::new();
315    for key in to_remove {
316        if let Some(entry) = peers.remove(&key) {
317            removed_peers.push(entry);
318        }
319    }
320    drop(peers);
321
322    for entry in removed_peers {
323        if let Some(peer) = entry.peer {
324            let _ = peer.close().await;
325        }
326    }
327
328    runtime
329        .connected_count
330        .store(connected_count, std::sync::atomic::Ordering::Relaxed);
331}
332
333async fn mark_seen(seen: &Arc<Mutex<TimedSeenSet>>, id: String) -> bool {
334    let mut seen = seen.lock().await;
335    seen.insert_if_new(id)
336}
337
338async fn remove_peer_from_runtime<P, R, F>(
339    runtime: &MeshRuntimeState<P>,
340    shared_router: Option<&Arc<MeshRouter<R, F>>>,
341    peer_id: PeerId,
342    reason: &str,
343) where
344    P: MeshSession + Send + Sync + 'static,
345    R: SignalingTransport + 'static,
346    F: PeerLinkFactory + 'static,
347{
348    let peer_key = peer_id.to_string();
349    info!("Peer {} {} - removing from pool", peer_id.short(), reason);
350    let removed = {
351        let mut peers = runtime.peers.write().await;
352        peers.remove(&peer_key)
353    };
354    runtime.clear_peer_hash_get(&peer_key).await;
355    if let Some(entry) = removed {
356        if entry.state == ConnectionState::Connected {
357            runtime
358                .connected_count
359                .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
360        }
361        if let Some(peer) = entry.peer {
362            let _ = peer.close().await;
363        }
364    }
365    if let Some(shared_router) = shared_router {
366        if let Some(channel) = shared_router.remove_peer(&peer_key).await {
367            channel.close().await;
368        }
369    }
370}
371
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result as AnyResult;
    use async_trait::async_trait;
    use nostr_sdk::nostr::{EventBuilder, Filter, Kind};
    use std::collections::BTreeSet;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Instant;

    use crate::runtime_peer::{MeshPeerEntry, PeerDirection, PeerTransport};
    use crate::types::{PeerHTLConfig, PeerPool};

    /// Minimal in-memory session: reports a fixed connection status and records `close()`.
    #[derive(Clone)]
    struct TestSession {
        connected: bool,
        close_delay: Duration,
        closed: Arc<AtomicBool>,
    }

    #[async_trait]
    impl MeshSession for TestSession {
        fn is_ready(&self) -> bool {
            true
        }

        fn is_connected(&self) -> bool {
            self.connected
        }

        fn htl_config(&self) -> PeerHTLConfig {
            PeerHTLConfig::from_flags(false, false)
        }

        async fn request(&self, _hash_hex: &str, _timeout: Duration) -> AnyResult<Option<Vec<u8>>> {
            Ok(None)
        }

        async fn query_nostr_events(
            &self,
            _filters: Vec<Filter>,
            _timeout: Duration,
        ) -> AnyResult<Vec<Event>> {
            Ok(Vec::new())
        }

        async fn send_mesh_frame_text(&self, _frame: &MeshNostrFrame) -> AnyResult<()> {
            Ok(())
        }

        async fn close(&self) -> AnyResult<()> {
            // Optional delay lets tests observe what happens while a close is in flight.
            if self.close_delay != Duration::ZERO {
                tokio::time::sleep(self.close_delay).await;
            }
            self.closed.store(true, Ordering::Relaxed);
            Ok(())
        }
    }

    /// Builds a session with the given status, close delay and shared `closed` flag.
    fn test_session(connected: bool, close_delay: Duration, closed: Arc<AtomicBool>) -> TestSession {
        TestSession {
            connected,
            close_delay,
            closed,
        }
    }

    /// Builds an outbound peer entry in `PeerPool::Other` with zeroed byte counters.
    fn peer_entry(
        peer_id: PeerId,
        state: ConnectionState,
        last_seen: Instant,
        peer: Option<TestSession>,
        transport: PeerTransport,
        signal_paths: BTreeSet<PeerSignalPath>,
    ) -> MeshPeerEntry<TestSession> {
        MeshPeerEntry {
            peer_id,
            direction: PeerDirection::Outbound,
            state,
            last_seen,
            peer,
            pool: PeerPool::Other,
            transport,
            signal_paths,
            bytes_sent: 0,
            bytes_received: 0,
        }
    }

    /// Inserts `entry` into the runtime peer map, keyed by its peer id string.
    async fn add_peer(runtime: &MeshRuntimeState<TestSession>, entry: MeshPeerEntry<TestSession>) {
        let key = entry.peer_id.to_string();
        runtime.peers.write().await.insert(key, entry);
    }

    #[test]
    fn can_track_source_peer_respects_optional_limits() {
        let peer_id = PeerId::new("peer-a".to_string());
        let peer_key = peer_id.to_string();
        let mut peers = HashMap::new();
        peers.insert(
            peer_key.clone(),
            peer_entry(
                peer_id,
                ConnectionState::Discovered,
                Instant::now(),
                None,
                PeerTransport::WebRtc,
                BTreeSet::from([PeerSignalPath::WifiAware]),
            ),
        );

        // No limit: any peer is accepted.
        assert!(can_track_source_peer("relay", "peer-b", &peers, None));
        // Limit of one: the already-tracked peer is accepted.
        assert!(can_track_source_peer(
            "wifi-aware",
            &peer_key,
            &peers,
            Some(1)
        ));
        // Limit of one: a second peer on the same path is rejected.
        assert!(!can_track_source_peer(
            "wifi-aware",
            "peer-b",
            &peers,
            Some(1),
        ));
    }

    #[tokio::test]
    async fn cleanup_stale_peers_removes_stale_entries_and_syncs_connected_count() {
        let runtime = MeshRuntimeState::<TestSession>::new();

        add_peer(
            &runtime,
            peer_entry(
                PeerId::new("peer-stale".to_string()),
                ConnectionState::Discovered,
                Instant::now() - Duration::from_secs(120),
                None,
                PeerTransport::WebRtc,
                BTreeSet::new(),
            ),
        )
        .await;

        let active_id = PeerId::new("peer-active".to_string());
        add_peer(
            &runtime,
            peer_entry(
                active_id.clone(),
                ConnectionState::Connecting,
                Instant::now(),
                Some(test_session(
                    true,
                    Duration::ZERO,
                    Arc::new(AtomicBool::new(false)),
                )),
                PeerTransport::WebRtc,
                BTreeSet::new(),
            ),
        )
        .await;

        cleanup_stale_peers(&runtime, Duration::from_secs(60)).await;

        let peers = runtime.peers.read().await;
        assert!(!peers.contains_key("peer-stale"));
        let active_state = peers.get(&active_id.to_string()).unwrap().state;
        assert_eq!(active_state, ConnectionState::Connected);
        let connected = runtime.connected_count.load(Ordering::Relaxed);
        assert_eq!(connected, 1);
    }

    #[tokio::test]
    async fn handle_peer_state_event_does_not_hold_peer_map_lock_while_closing() {
        let runtime = Arc::new(MeshRuntimeState::<TestSession>::new());
        let peer_id = PeerId::new("peer-a-pub".to_string());
        add_peer(
            &runtime,
            peer_entry(
                peer_id.clone(),
                ConnectionState::Connected,
                Instant::now(),
                Some(test_session(
                    false,
                    Duration::from_millis(200),
                    Arc::new(AtomicBool::new(false)),
                )),
                PeerTransport::Bluetooth,
                BTreeSet::from([PeerSignalPath::Bluetooth]),
            ),
        )
        .await;

        let runtime_for_task = runtime.clone();
        let peer_id_for_task = peer_id.clone();
        let cleanup_task = tokio::spawn(async move {
            handle_peer_state_event::<
                TestSession,
                crate::mock::MockRelayTransport,
                crate::mock::MockConnectionFactory,
            >(
                runtime_for_task.as_ref(),
                PeerStateEvent::Failed(peer_id_for_task),
                None,
            )
            .await;
        });

        // Give the spawned task time to reach the slow close().
        tokio::time::sleep(Duration::from_millis(20)).await;

        let read_len = async { runtime.peers.read().await.len() };
        let remaining = tokio::time::timeout(Duration::from_millis(50), read_len)
            .await
            .expect("peer map read should not block on close");

        assert_eq!(remaining, 0);
        cleanup_task.await.expect("cleanup task");
    }

    #[tokio::test]
    async fn forward_mesh_frame_from_runtime_sends_to_connected_peers() {
        let runtime = MeshRuntimeState::<TestSession>::new();
        let closed = Arc::new(AtomicBool::new(false));
        let peer_id = PeerId::new("peer-a".to_string());
        add_peer(
            &runtime,
            peer_entry(
                peer_id.clone(),
                ConnectionState::Connected,
                Instant::now(),
                Some(test_session(true, Duration::ZERO, closed.clone())),
                PeerTransport::WebRtc,
                BTreeSet::new(),
            ),
        )
        .await;

        let keys = Keys::generate();
        let event = EventBuilder::new(Kind::Custom(25050), "mesh", [])
            .to_event(&keys)
            .unwrap();
        let frame = MeshNostrFrame::new_event_with_id(event, "sender", "frame-1", 4);

        let forwarded = forward_mesh_frame_from_runtime(&runtime, &frame, None).await;
        assert_eq!(forwarded, 1);
        assert!(!closed.load(Ordering::Relaxed));
    }
}