Skip to main content

hashtree_network/
runtime_control.rs

1use anyhow::Result;
2use nostr_sdk::nostr::{Event, Keys, Kind};
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::sync::Mutex;
7use tracing::{debug, info};
8
9use crate::local_bus::SharedLocalNostrBus;
10use crate::mesh_session::{forward_mesh_frame_to_sessions, MeshSession};
11use crate::nostr::{decode_signaling_event, encode_signaling_event};
12use crate::runtime_peer::{
13    can_track_signal_path_peer, remember_peer_signal_path, ConnectionState, MeshPeerEntry,
14    PeerSignalPath,
15};
16use crate::runtime_state::MeshRuntimeState;
17use crate::signaling::MeshRouter;
18use crate::transport::{PeerLinkFactory, SignalingTransport};
19use crate::types::{MeshNostrFrame, PeerId, SignalingMessage, TimedSeenSet, MESH_DEFAULT_HTL};
20
21#[derive(Debug, Clone)]
22pub enum PeerStateEvent {
23    Connected(PeerId),
24    SignalHints(PeerId, Vec<String>),
25    Failed(PeerId),
26    Disconnected(PeerId),
27}
28
29pub fn can_track_source_peer<P>(
30    source: &str,
31    peer_key: &str,
32    peers: &HashMap<String, MeshPeerEntry<P>>,
33    max_peers: Option<usize>,
34) -> bool {
35    match max_peers {
36        Some(max_peers) => can_track_signal_path_peer(
37            PeerSignalPath::from_source_name(source),
38            max_peers,
39            peer_key,
40            peers,
41        ),
42        None => true,
43    }
44}
45
46pub async fn forward_mesh_frame_from_runtime<P>(
47    runtime: &MeshRuntimeState<P>,
48    frame: &MeshNostrFrame,
49    exclude_peer_id: Option<&str>,
50) -> usize
51where
52    P: MeshSession + Clone + Send + Sync + 'static,
53{
54    let peers = runtime.peers.read().await;
55    let peer_refs: Vec<(String, Arc<dyn MeshSession>)> = peers
56        .values()
57        .filter(|entry| entry.state == ConnectionState::Connected)
58        .filter_map(|entry| {
59            entry.peer.as_ref().map(|peer| {
60                (
61                    entry.peer_id.to_string(),
62                    Arc::new(peer.clone()) as Arc<dyn MeshSession>,
63                )
64            })
65        })
66        .collect();
67    drop(peers);
68
69    forward_mesh_frame_to_sessions(peer_refs, frame, exclude_peer_id).await
70}
71
72pub async fn create_signaling_event(
73    keys: &Keys,
74    msg: &SignalingMessage,
75    signaling_kind: u64,
76) -> Result<Event> {
77    encode_signaling_event(
78        keys,
79        msg.peer_id(),
80        msg,
81        Kind::Ephemeral(signaling_kind as u16),
82    )
83    .map_err(|e| anyhow::anyhow!(e.to_string()))
84}
85
86pub async fn handle_signaling_event<P, R, F>(
87    signaling_enabled: bool,
88    my_peer_id: &PeerId,
89    keys: &Keys,
90    runtime: &MeshRuntimeState<P>,
91    source: &str,
92    source_max_peers: Option<usize>,
93    event: &Event,
94    shared_router: Option<&Arc<MeshRouter<R, F>>>,
95) -> Result<()>
96where
97    P: MeshSession + Send + Sync + 'static,
98    R: SignalingTransport + 'static,
99    F: PeerLinkFactory + 'static,
100{
101    if !signaling_enabled {
102        return Ok(());
103    }
104
105    let Some(msg) = decode_signaling_event(
106        event,
107        &my_peer_id.to_string(),
108        &keys.public_key().to_hex(),
109        keys,
110    ) else {
111        return Ok(());
112    };
113
114    handle_signaling_message(runtime, source, source_max_peers, msg, shared_router).await
115}
116
117pub async fn handle_signaling_message<P, R, F>(
118    runtime: &MeshRuntimeState<P>,
119    source: &str,
120    source_max_peers: Option<usize>,
121    msg: SignalingMessage,
122    shared_router: Option<&Arc<MeshRouter<R, F>>>,
123) -> Result<()>
124where
125    P: MeshSession + Send + Sync + 'static,
126    R: SignalingTransport + 'static,
127    F: PeerLinkFactory + 'static,
128{
129    let Some(shared_router) = shared_router else {
130        return Ok(());
131    };
132
133    if matches!(
134        msg,
135        SignalingMessage::Hello { .. } | SignalingMessage::Offer { .. }
136    ) {
137        let peers = runtime.peers.read().await;
138        if !can_track_source_peer(source, msg.peer_id(), &peers, source_max_peers) {
139            return Ok(());
140        }
141    }
142
143    debug!(
144        "Received {} from {} via {}",
145        msg.msg_type(),
146        msg.peer_id(),
147        source
148    );
149    let peer_id = msg.peer_id().to_string();
150    let peer_hash_get = match &msg {
151        SignalingMessage::Hello { hash_get, .. } => Some(*hash_get),
152        _ => None,
153    };
154    shared_router
155        .handle_message(msg)
156        .await
157        .map_err(|e| anyhow::anyhow!(e.to_string()))?;
158    if let Some(hash_get) = peer_hash_get {
159        runtime.set_peer_hash_get(&peer_id, hash_get).await;
160    }
161    remember_peer_signal_path(runtime.peers.as_ref(), &peer_id, source).await;
162
163    Ok(())
164}
165
166pub async fn dispatch_signaling_message<P, S>(
167    signaling_enabled: bool,
168    keys: &Keys,
169    my_peer_id: &PeerId,
170    runtime: &MeshRuntimeState<P>,
171    relay_transport: Option<&S>,
172    local_buses: &[SharedLocalNostrBus],
173    seen_frame_ids: &Arc<Mutex<TimedSeenSet>>,
174    seen_event_ids: &Arc<Mutex<TimedSeenSet>>,
175    msg: SignalingMessage,
176    signaling_kind: u64,
177) -> Result<()>
178where
179    P: MeshSession + Clone + Send + Sync + 'static,
180    S: SignalingTransport + Send + Sync + 'static,
181{
182    if !signaling_enabled {
183        debug!(
184            "Skipping signaling message {} because signaling is disabled",
185            msg.msg_type()
186        );
187        return Ok(());
188    }
189
190    let event = create_signaling_event(keys, &msg, signaling_kind).await?;
191
192    for bus in local_buses {
193        if let Err(err) = bus.broadcast_event(&event).await {
194            debug!(
195                "Failed to broadcast signaling event over {} ({}): {}",
196                bus.source_name(),
197                msg.msg_type(),
198                err
199            );
200        }
201    }
202
203    let mut frame = MeshNostrFrame::new_event(event, &my_peer_id.to_string(), MESH_DEFAULT_HTL);
204    if !mark_seen(seen_frame_ids, frame.frame_id.clone()).await {
205        runtime.record_mesh_duplicate_drop();
206        return Ok(());
207    }
208    if !mark_seen(seen_event_ids, frame.event().id.to_hex()).await {
209        runtime.record_mesh_duplicate_drop();
210        return Ok(());
211    }
212
213    frame.sender_peer_id = my_peer_id.to_string();
214    let forwarded = forward_mesh_frame_from_runtime(runtime, &frame, None).await;
215    if forwarded > 0 {
216        runtime.record_mesh_forwarded(forwarded as u64);
217    }
218
219    if let Some(relay_transport) = relay_transport {
220        match tokio::time::timeout(
221            Duration::from_millis(500),
222            relay_transport.publish(msg.clone()),
223        )
224        .await
225        {
226            Ok(Ok(())) => {}
227            Ok(Err(err)) => {
228                debug!(
229                    "Failed to publish signaling message {} via relay transport: {}",
230                    msg.msg_type(),
231                    err
232                );
233            }
234            Err(_) => {
235                debug!(
236                    "Timed out publishing signaling message {} via relay transport",
237                    msg.msg_type()
238                );
239            }
240        }
241    }
242
243    Ok(())
244}
245
246pub async fn handle_peer_state_event<P, R, F>(
247    runtime: &MeshRuntimeState<P>,
248    event: PeerStateEvent,
249    shared_router: Option<&Arc<MeshRouter<R, F>>>,
250) where
251    P: MeshSession + Send + Sync + 'static,
252    R: SignalingTransport + 'static,
253    F: PeerLinkFactory + 'static,
254{
255    match event {
256        PeerStateEvent::Connected(peer_id) => {
257            let peer_key = peer_id.to_string();
258            let mut emit_hello = false;
259            let mut peers = runtime.peers.write().await;
260            if let Some(entry) = peers.get_mut(&peer_key) {
261                if entry.state != ConnectionState::Connected {
262                    info!("Peer {} connected (via state event)", peer_id.short());
263                    entry.state = ConnectionState::Connected;
264                    emit_hello = true;
265                    runtime
266                        .connected_count
267                        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
268                }
269            }
270            drop(peers);
271            if emit_hello {
272                if let Some(shared_router) = shared_router {
273                    let _ = shared_router.send_hello(Vec::new()).await;
274                }
275            }
276        }
277        PeerStateEvent::SignalHints(peer_id, signal_urls) => {
278            runtime
279                .record_known_peer_signal_urls(&peer_id.to_string(), &signal_urls, "peer-link")
280                .await;
281        }
282        PeerStateEvent::Failed(peer_id) => {
283            remove_peer_from_runtime(runtime, shared_router, peer_id, "connection failed").await;
284        }
285        PeerStateEvent::Disconnected(peer_id) => {
286            remove_peer_from_runtime(runtime, shared_router, peer_id, "disconnected").await;
287        }
288    }
289}
290
291pub async fn cleanup_stale_peers<P>(runtime: &MeshRuntimeState<P>, stale_timeout: Duration)
292where
293    P: MeshSession + Send + Sync + 'static,
294{
295    let mut peers = runtime.peers.write().await;
296    let mut connected_count = 0usize;
297    let mut to_remove = Vec::new();
298
299    for (key, entry) in peers.iter_mut() {
300        if let Some(ref peer) = entry.peer {
301            if peer.is_connected() {
302                if entry.state != ConnectionState::Connected {
303                    info!(
304                        "Peer {} is now connected (sync fallback)",
305                        entry.peer_id.short()
306                    );
307                    entry.state = ConnectionState::Connected;
308                }
309                connected_count += 1;
310            } else if entry.state == ConnectionState::Connected {
311                info!(
312                    "Removing disconnected peer {} after transport closed",
313                    entry.peer_id.short()
314                );
315                to_remove.push(key.clone());
316            } else if entry.state == ConnectionState::Connecting
317                && entry.last_seen.elapsed() > stale_timeout
318            {
319                info!(
320                    "Removing stale peer {} (stuck in Connecting for {:?})",
321                    entry.peer_id.short(),
322                    entry.last_seen.elapsed()
323                );
324                to_remove.push(key.clone());
325            }
326        } else if entry.state == ConnectionState::Discovered
327            && entry.last_seen.elapsed() > stale_timeout
328        {
329            debug!("Removing stale discovered peer {}", entry.peer_id.short());
330            to_remove.push(key.clone());
331        }
332    }
333
334    let mut removed_peers = Vec::new();
335    for key in to_remove {
336        if let Some(entry) = peers.remove(&key) {
337            removed_peers.push(entry);
338        }
339    }
340    drop(peers);
341
342    for entry in removed_peers {
343        if let Some(peer) = entry.peer {
344            let _ = peer.close().await;
345        }
346    }
347
348    runtime
349        .connected_count
350        .store(connected_count, std::sync::atomic::Ordering::Relaxed);
351}
352
353async fn mark_seen(seen: &Arc<Mutex<TimedSeenSet>>, id: String) -> bool {
354    let mut seen = seen.lock().await;
355    seen.insert_if_new(id)
356}
357
358async fn remove_peer_from_runtime<P, R, F>(
359    runtime: &MeshRuntimeState<P>,
360    shared_router: Option<&Arc<MeshRouter<R, F>>>,
361    peer_id: PeerId,
362    reason: &str,
363) where
364    P: MeshSession + Send + Sync + 'static,
365    R: SignalingTransport + 'static,
366    F: PeerLinkFactory + 'static,
367{
368    let peer_key = peer_id.to_string();
369    info!("Peer {} {} - removing from pool", peer_id.short(), reason);
370    let removed = {
371        let mut peers = runtime.peers.write().await;
372        peers.remove(&peer_key)
373    };
374    runtime.clear_peer_hash_get(&peer_key).await;
375    if let Some(entry) = removed {
376        if entry.state == ConnectionState::Connected {
377            runtime
378                .connected_count
379                .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
380        }
381        if let Some(peer) = entry.peer {
382            let _ = peer.close().await;
383        }
384    }
385    if let Some(shared_router) = shared_router {
386        if let Some(channel) = shared_router.remove_peer(&peer_key).await {
387            channel.close().await;
388        }
389    }
390}
391
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result as AnyResult;
    use async_trait::async_trait;
    use nostr_sdk::nostr::{EventBuilder, Filter, Kind};
    use std::collections::BTreeSet;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Instant;

    use crate::runtime_peer::{MeshPeerEntry, PeerDirection, PeerTransport};
    use crate::types::{PeerHTLConfig, PeerPool};

    /// In-memory `MeshSession` whose connectivity is fixed at construction and
    /// whose `close()` can be artificially delayed; `closed` flips to `true`
    /// once `close()` finishes.
    #[derive(Clone)]
    struct TestSession {
        connected: bool,
        close_delay: Duration,
        closed: Arc<AtomicBool>,
    }

    impl TestSession {
        /// Builds a session together with a shared handle to its `closed` flag.
        fn new(connected: bool, close_delay: Duration) -> (Self, Arc<AtomicBool>) {
            let closed = Arc::new(AtomicBool::new(false));
            let session = Self {
                connected,
                close_delay,
                closed: Arc::clone(&closed),
            };
            (session, closed)
        }
    }

    #[async_trait]
    impl MeshSession for TestSession {
        fn is_ready(&self) -> bool {
            true
        }

        fn is_connected(&self) -> bool {
            self.connected
        }

        fn htl_config(&self) -> PeerHTLConfig {
            PeerHTLConfig::from_flags(false, false)
        }

        async fn request(&self, _hash_hex: &str, _timeout: Duration) -> AnyResult<Option<Vec<u8>>> {
            Ok(None)
        }

        async fn query_nostr_events(
            &self,
            _filters: Vec<Filter>,
            _timeout: Duration,
        ) -> AnyResult<Vec<Event>> {
            Ok(Vec::new())
        }

        async fn send_mesh_frame_text(&self, _frame: &MeshNostrFrame) -> AnyResult<()> {
            Ok(())
        }

        async fn close(&self) -> AnyResult<()> {
            if !self.close_delay.is_zero() {
                tokio::time::sleep(self.close_delay).await;
            }
            self.closed.store(true, Ordering::Relaxed);
            Ok(())
        }
    }

    /// Builds an outbound, `PeerPool::Other` entry with zeroed traffic counters.
    fn peer_entry(
        peer_id: PeerId,
        state: ConnectionState,
        last_seen: Instant,
        peer: Option<TestSession>,
        transport: PeerTransport,
        signal_paths: BTreeSet<PeerSignalPath>,
    ) -> MeshPeerEntry<TestSession> {
        MeshPeerEntry {
            peer_id,
            direction: PeerDirection::Outbound,
            state,
            last_seen,
            peer,
            pool: PeerPool::Other,
            transport,
            signal_paths,
            bytes_sent: 0,
            bytes_received: 0,
        }
    }

    #[test]
    fn can_track_source_peer_respects_optional_limits() {
        let peer_id = PeerId::new("peer-a".to_string());
        let peer_key = peer_id.to_string();
        let peers = HashMap::from([(
            peer_key.clone(),
            peer_entry(
                peer_id,
                ConnectionState::Discovered,
                Instant::now(),
                None,
                PeerTransport::WebRtc,
                BTreeSet::from([PeerSignalPath::WifiAware]),
            ),
        )]);

        assert!(can_track_source_peer("relay", "peer-b", &peers, None));
        assert!(can_track_source_peer(
            "wifi-aware",
            &peer_key,
            &peers,
            Some(1)
        ));
        assert!(!can_track_source_peer(
            "wifi-aware",
            "peer-b",
            &peers,
            Some(1),
        ));
    }

    #[tokio::test]
    async fn cleanup_stale_peers_removes_stale_entries_and_syncs_connected_count() {
        let runtime = MeshRuntimeState::<TestSession>::new();
        let stale_id = PeerId::new("peer-stale".to_string());
        let active_id = PeerId::new("peer-active".to_string());
        let (active_session, _active_closed) = TestSession::new(true, Duration::ZERO);

        {
            let mut peers = runtime.peers.write().await;
            peers.insert(
                stale_id.to_string(),
                peer_entry(
                    stale_id,
                    ConnectionState::Discovered,
                    Instant::now() - Duration::from_secs(120),
                    None,
                    PeerTransport::WebRtc,
                    BTreeSet::new(),
                ),
            );
            peers.insert(
                active_id.to_string(),
                peer_entry(
                    active_id.clone(),
                    ConnectionState::Connecting,
                    Instant::now(),
                    Some(active_session),
                    PeerTransport::WebRtc,
                    BTreeSet::new(),
                ),
            );
        }

        cleanup_stale_peers(&runtime, Duration::from_secs(60)).await;

        let peers = runtime.peers.read().await;
        assert!(!peers.contains_key("peer-stale"));
        assert_eq!(
            peers.get(&active_id.to_string()).unwrap().state,
            ConnectionState::Connected
        );
        assert_eq!(runtime.connected_count.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn handle_peer_state_event_does_not_hold_peer_map_lock_while_closing() {
        let runtime = Arc::new(MeshRuntimeState::<TestSession>::new());
        let peer_id = PeerId::new("peer-a-pub".to_string());
        let (slow_session, _slow_closed) = TestSession::new(false, Duration::from_millis(200));
        runtime.peers.write().await.insert(
            peer_id.to_string(),
            peer_entry(
                peer_id.clone(),
                ConnectionState::Connected,
                Instant::now(),
                Some(slow_session),
                PeerTransport::Bluetooth,
                BTreeSet::from([PeerSignalPath::Bluetooth]),
            ),
        );

        let task_runtime = Arc::clone(&runtime);
        let task_peer_id = peer_id.clone();
        let cleanup_task = tokio::spawn(async move {
            handle_peer_state_event::<
                TestSession,
                crate::mock::MockRelayTransport,
                crate::mock::MockConnectionFactory,
            >(
                task_runtime.as_ref(),
                PeerStateEvent::Failed(task_peer_id),
                None,
            )
            .await;
        });

        // Give the task time to remove the entry and start its slow close().
        tokio::time::sleep(Duration::from_millis(20)).await;

        let remaining = tokio::time::timeout(Duration::from_millis(50), async {
            runtime.peers.read().await.len()
        })
        .await
        .expect("peer map read should not block on close");

        assert_eq!(remaining, 0);
        cleanup_task.await.expect("cleanup task");
    }

    #[tokio::test]
    async fn forward_mesh_frame_from_runtime_sends_to_connected_peers() {
        let runtime = MeshRuntimeState::<TestSession>::new();
        let peer_id = PeerId::new("peer-a".to_string());
        let (session, closed) = TestSession::new(true, Duration::ZERO);
        runtime.peers.write().await.insert(
            peer_id.to_string(),
            peer_entry(
                peer_id.clone(),
                ConnectionState::Connected,
                Instant::now(),
                Some(session),
                PeerTransport::WebRtc,
                BTreeSet::new(),
            ),
        );

        let keys = Keys::generate();
        let event = EventBuilder::new(Kind::Custom(25050), "mesh", [])
            .to_event(&keys)
            .unwrap();
        let frame = MeshNostrFrame::new_event_with_id(event, "sender", "frame-1", 4);

        let forwarded = forward_mesh_frame_from_runtime(&runtime, &frame, None).await;
        assert_eq!(forwarded, 1);
        assert!(!closed.load(Ordering::Relaxed));
    }
}