1use anyhow::Result;
2use nostr_sdk::nostr::{Event, Keys, Kind};
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::sync::Mutex;
7use tracing::{debug, info};
8
9use crate::local_bus::SharedLocalNostrBus;
10use crate::mesh_session::{forward_mesh_frame_to_sessions, MeshSession};
11use crate::nostr::{decode_signaling_event, encode_signaling_event};
12use crate::runtime_peer::{
13 can_track_signal_path_peer, remember_peer_signal_path, ConnectionState, MeshPeerEntry,
14 PeerSignalPath,
15};
16use crate::runtime_state::MeshRuntimeState;
17use crate::signaling::MeshRouter;
18use crate::transport::{PeerLinkFactory, SignalingTransport};
19use crate::types::{MeshNostrFrame, PeerId, SignalingMessage, TimedSeenSet, MESH_DEFAULT_HTL};
20
/// Peer lifecycle notifications applied to the runtime by [`handle_peer_state_event`].
#[derive(Debug, Clone)]
pub enum PeerStateEvent {
    /// The peer's link reports it is connected; an existing, not-yet-connected
    /// entry is promoted to `ConnectionState::Connected`.
    Connected(PeerId),
    /// Signaling URL hints advertised by the peer; recorded on the runtime under
    /// the `"peer-link"` source.
    SignalHints(PeerId, Vec<String>),
    /// The connection attempt failed; the peer is removed from the pool.
    Failed(PeerId),
    /// The peer went away; it is removed from the pool.
    Disconnected(PeerId),
}
28
29pub fn can_track_source_peer<P>(
30 source: &str,
31 peer_key: &str,
32 peers: &HashMap<String, MeshPeerEntry<P>>,
33 max_peers: Option<usize>,
34) -> bool {
35 match max_peers {
36 Some(max_peers) => can_track_signal_path_peer(
37 PeerSignalPath::from_source_name(source),
38 max_peers,
39 peer_key,
40 peers,
41 ),
42 None => true,
43 }
44}
45
46pub async fn forward_mesh_frame_from_runtime<P>(
47 runtime: &MeshRuntimeState<P>,
48 frame: &MeshNostrFrame,
49 exclude_peer_id: Option<&str>,
50) -> usize
51where
52 P: MeshSession + Clone + Send + Sync + 'static,
53{
54 let peers = runtime.peers.read().await;
55 let peer_refs: Vec<(String, Arc<dyn MeshSession>)> = peers
56 .values()
57 .filter(|entry| entry.state == ConnectionState::Connected)
58 .filter_map(|entry| {
59 entry.peer.as_ref().map(|peer| {
60 (
61 entry.peer_id.to_string(),
62 Arc::new(peer.clone()) as Arc<dyn MeshSession>,
63 )
64 })
65 })
66 .collect();
67 drop(peers);
68
69 forward_mesh_frame_to_sessions(peer_refs, frame, exclude_peer_id).await
70}
71
72pub async fn create_signaling_event(
73 keys: &Keys,
74 msg: &SignalingMessage,
75 signaling_kind: u64,
76) -> Result<Event> {
77 encode_signaling_event(
78 keys,
79 msg.peer_id(),
80 msg,
81 Kind::Ephemeral(signaling_kind as u16),
82 )
83 .map_err(|e| anyhow::anyhow!(e.to_string()))
84}
85
86#[allow(clippy::too_many_arguments)]
87pub async fn handle_signaling_event<P, R, F>(
88 signaling_enabled: bool,
89 my_peer_id: &PeerId,
90 keys: &Keys,
91 runtime: &MeshRuntimeState<P>,
92 source: &str,
93 source_max_peers: Option<usize>,
94 event: &Event,
95 shared_router: Option<&Arc<MeshRouter<R, F>>>,
96) -> Result<()>
97where
98 P: MeshSession + Send + Sync + 'static,
99 R: SignalingTransport + 'static,
100 F: PeerLinkFactory + 'static,
101{
102 if !signaling_enabled {
103 return Ok(());
104 }
105
106 let Some(msg) = decode_signaling_event(
107 event,
108 &my_peer_id.to_string(),
109 &keys.public_key().to_hex(),
110 keys,
111 ) else {
112 return Ok(());
113 };
114
115 handle_signaling_message(runtime, source, source_max_peers, msg, shared_router).await
116}
117
118pub async fn handle_signaling_message<P, R, F>(
119 runtime: &MeshRuntimeState<P>,
120 source: &str,
121 source_max_peers: Option<usize>,
122 msg: SignalingMessage,
123 shared_router: Option<&Arc<MeshRouter<R, F>>>,
124) -> Result<()>
125where
126 P: MeshSession + Send + Sync + 'static,
127 R: SignalingTransport + 'static,
128 F: PeerLinkFactory + 'static,
129{
130 let Some(shared_router) = shared_router else {
131 return Ok(());
132 };
133
134 if matches!(
135 msg,
136 SignalingMessage::Hello { .. } | SignalingMessage::Offer { .. }
137 ) {
138 let peers = runtime.peers.read().await;
139 if !can_track_source_peer(source, msg.peer_id(), &peers, source_max_peers) {
140 return Ok(());
141 }
142 }
143
144 debug!(
145 "Received {} from {} via {}",
146 msg.msg_type(),
147 msg.peer_id(),
148 source
149 );
150 let peer_id = msg.peer_id().to_string();
151 let peer_hash_get = match &msg {
152 SignalingMessage::Hello { hash_get, .. } => Some(*hash_get),
153 _ => None,
154 };
155 shared_router
156 .handle_message(msg)
157 .await
158 .map_err(|e| anyhow::anyhow!(e.to_string()))?;
159 if let Some(hash_get) = peer_hash_get {
160 runtime.set_peer_hash_get(&peer_id, hash_get).await;
161 }
162 remember_peer_signal_path(runtime.peers.as_ref(), &peer_id, source).await;
163
164 Ok(())
165}
166
167#[allow(clippy::too_many_arguments)]
168pub async fn dispatch_signaling_message<P, S>(
169 signaling_enabled: bool,
170 keys: &Keys,
171 my_peer_id: &PeerId,
172 runtime: &MeshRuntimeState<P>,
173 relay_transport: Option<&S>,
174 local_buses: &[SharedLocalNostrBus],
175 seen_frame_ids: &Arc<Mutex<TimedSeenSet>>,
176 seen_event_ids: &Arc<Mutex<TimedSeenSet>>,
177 msg: SignalingMessage,
178 signaling_kind: u64,
179) -> Result<()>
180where
181 P: MeshSession + Clone + Send + Sync + 'static,
182 S: SignalingTransport + Send + Sync + 'static,
183{
184 if !signaling_enabled {
185 debug!(
186 "Skipping signaling message {} because signaling is disabled",
187 msg.msg_type()
188 );
189 return Ok(());
190 }
191
192 let event = create_signaling_event(keys, &msg, signaling_kind).await?;
193
194 for bus in local_buses {
195 if let Err(err) = bus.broadcast_event(&event).await {
196 debug!(
197 "Failed to broadcast signaling event over {} ({}): {}",
198 bus.source_name(),
199 msg.msg_type(),
200 err
201 );
202 }
203 }
204
205 let mut frame = MeshNostrFrame::new_event(event, &my_peer_id.to_string(), MESH_DEFAULT_HTL);
206 if !mark_seen(seen_frame_ids, frame.frame_id.clone()).await {
207 runtime.record_mesh_duplicate_drop();
208 return Ok(());
209 }
210 if !mark_seen(seen_event_ids, frame.event().id.to_hex()).await {
211 runtime.record_mesh_duplicate_drop();
212 return Ok(());
213 }
214
215 frame.sender_peer_id = my_peer_id.to_string();
216 let forwarded = forward_mesh_frame_from_runtime(runtime, &frame, None).await;
217 if forwarded > 0 {
218 runtime.record_mesh_forwarded(forwarded as u64);
219 }
220
221 if let Some(relay_transport) = relay_transport {
222 match tokio::time::timeout(
223 Duration::from_millis(500),
224 relay_transport.publish(msg.clone()),
225 )
226 .await
227 {
228 Ok(Ok(())) => {}
229 Ok(Err(err)) => {
230 debug!(
231 "Failed to publish signaling message {} via relay transport: {}",
232 msg.msg_type(),
233 err
234 );
235 }
236 Err(_) => {
237 debug!(
238 "Timed out publishing signaling message {} via relay transport",
239 msg.msg_type()
240 );
241 }
242 }
243 }
244
245 Ok(())
246}
247
248pub async fn handle_peer_state_event<P, R, F>(
249 runtime: &MeshRuntimeState<P>,
250 event: PeerStateEvent,
251 shared_router: Option<&Arc<MeshRouter<R, F>>>,
252) where
253 P: MeshSession + Send + Sync + 'static,
254 R: SignalingTransport + 'static,
255 F: PeerLinkFactory + 'static,
256{
257 match event {
258 PeerStateEvent::Connected(peer_id) => {
259 let peer_key = peer_id.to_string();
260 let mut emit_hello = false;
261 let mut peers = runtime.peers.write().await;
262 if let Some(entry) = peers.get_mut(&peer_key) {
263 if entry.state != ConnectionState::Connected {
264 info!("Peer {} connected (via state event)", peer_id.short());
265 entry.state = ConnectionState::Connected;
266 emit_hello = true;
267 runtime
268 .connected_count
269 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
270 }
271 }
272 drop(peers);
273 if emit_hello {
274 if let Some(shared_router) = shared_router {
275 let _ = shared_router.send_hello(Vec::new()).await;
276 }
277 }
278 }
279 PeerStateEvent::SignalHints(peer_id, signal_urls) => {
280 runtime
281 .record_known_peer_signal_urls(&peer_id.to_string(), &signal_urls, "peer-link")
282 .await;
283 }
284 PeerStateEvent::Failed(peer_id) => {
285 remove_peer_from_runtime(runtime, shared_router, peer_id, "connection failed").await;
286 }
287 PeerStateEvent::Disconnected(peer_id) => {
288 remove_peer_from_runtime(runtime, shared_router, peer_id, "disconnected").await;
289 }
290 }
291}
292
293pub async fn cleanup_stale_peers<P>(runtime: &MeshRuntimeState<P>, stale_timeout: Duration)
294where
295 P: MeshSession + Send + Sync + 'static,
296{
297 let mut peers = runtime.peers.write().await;
298 let mut connected_count = 0usize;
299 let mut to_remove = Vec::new();
300
301 for (key, entry) in peers.iter_mut() {
302 if let Some(ref peer) = entry.peer {
303 if peer.is_connected() {
304 if entry.state != ConnectionState::Connected {
305 info!(
306 "Peer {} is now connected (sync fallback)",
307 entry.peer_id.short()
308 );
309 entry.state = ConnectionState::Connected;
310 }
311 connected_count += 1;
312 } else if entry.state == ConnectionState::Connected {
313 info!(
314 "Removing disconnected peer {} after transport closed",
315 entry.peer_id.short()
316 );
317 to_remove.push(key.clone());
318 } else if entry.state == ConnectionState::Connecting
319 && entry.last_seen.elapsed() > stale_timeout
320 {
321 info!(
322 "Removing stale peer {} (stuck in Connecting for {:?})",
323 entry.peer_id.short(),
324 entry.last_seen.elapsed()
325 );
326 to_remove.push(key.clone());
327 }
328 } else if entry.state == ConnectionState::Discovered
329 && entry.last_seen.elapsed() > stale_timeout
330 {
331 debug!("Removing stale discovered peer {}", entry.peer_id.short());
332 to_remove.push(key.clone());
333 }
334 }
335
336 let mut removed_peers = Vec::new();
337 for key in to_remove {
338 if let Some(entry) = peers.remove(&key) {
339 removed_peers.push(entry);
340 }
341 }
342 drop(peers);
343
344 for entry in removed_peers {
345 if let Some(peer) = entry.peer {
346 let _ = peer.close().await;
347 }
348 }
349
350 runtime
351 .connected_count
352 .store(connected_count, std::sync::atomic::Ordering::Relaxed);
353}
354
355async fn mark_seen(seen: &Arc<Mutex<TimedSeenSet>>, id: String) -> bool {
356 let mut seen = seen.lock().await;
357 seen.insert_if_new(id)
358}
359
360async fn remove_peer_from_runtime<P, R, F>(
361 runtime: &MeshRuntimeState<P>,
362 shared_router: Option<&Arc<MeshRouter<R, F>>>,
363 peer_id: PeerId,
364 reason: &str,
365) where
366 P: MeshSession + Send + Sync + 'static,
367 R: SignalingTransport + 'static,
368 F: PeerLinkFactory + 'static,
369{
370 let peer_key = peer_id.to_string();
371 info!("Peer {} {} - removing from pool", peer_id.short(), reason);
372 let removed = {
373 let mut peers = runtime.peers.write().await;
374 peers.remove(&peer_key)
375 };
376 runtime.clear_peer_hash_get(&peer_key).await;
377 if let Some(entry) = removed {
378 if entry.state == ConnectionState::Connected {
379 runtime
380 .connected_count
381 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
382 }
383 if let Some(peer) = entry.peer {
384 let _ = peer.close().await;
385 }
386 }
387 if let Some(shared_router) = shared_router {
388 if let Some(channel) = shared_router.remove_peer(&peer_key).await {
389 channel.close().await;
390 }
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result as AnyResult;
    use async_trait::async_trait;
    use nostr_sdk::nostr::{EventBuilder, Filter, Kind};
    use std::collections::BTreeSet;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Instant;

    use crate::runtime_peer::{MeshPeerEntry, PeerDirection, PeerTransport};
    use crate::types::{PeerHTLConfig, PeerPool};

    // Minimal `MeshSession` test double.
    // - `connected` is what `is_connected()` reports.
    // - `close()` sleeps for `close_delay` (when non-zero) and then sets `closed`, so
    //   tests can both slow down closing and observe whether it happened.
    #[derive(Clone)]
    struct TestSession {
        connected: bool,
        close_delay: Duration,
        closed: Arc<AtomicBool>,
    }

    #[async_trait]
    impl MeshSession for TestSession {
        fn is_ready(&self) -> bool {
            true
        }

        fn is_connected(&self) -> bool {
            self.connected
        }

        fn htl_config(&self) -> PeerHTLConfig {
            PeerHTLConfig::from_flags(false, false)
        }

        async fn request(&self, _hash_hex: &str, _timeout: Duration) -> AnyResult<Option<Vec<u8>>> {
            Ok(None)
        }

        async fn query_nostr_events(
            &self,
            _filters: Vec<Filter>,
            _timeout: Duration,
        ) -> AnyResult<Vec<Event>> {
            Ok(Vec::new())
        }

        async fn send_mesh_frame_text(&self, _frame: &MeshNostrFrame) -> AnyResult<()> {
            Ok(())
        }

        async fn close(&self) -> AnyResult<()> {
            if !self.close_delay.is_zero() {
                tokio::time::sleep(self.close_delay).await;
            }
            self.closed.store(true, Ordering::Relaxed);
            Ok(())
        }
    }

    // Without a limit every peer is accepted. With a limit of 1 and one Wi-Fi Aware peer
    // already tracked, the same peer is still accepted but a new one is rejected.
    // NOTE(review): the limited cases depend on `can_track_signal_path_peer` semantics.
    #[test]
    fn can_track_source_peer_respects_optional_limits() {
        let peer_id = PeerId::new("peer-a".to_string());
        let peer_key = peer_id.to_string();
        let mut peers = HashMap::new();
        peers.insert(
            peer_key.clone(),
            MeshPeerEntry {
                peer_id,
                direction: PeerDirection::Outbound,
                state: ConnectionState::Discovered,
                last_seen: Instant::now(),
                peer: None::<TestSession>,
                pool: PeerPool::Other,
                transport: PeerTransport::WebRtc,
                signal_paths: BTreeSet::from([PeerSignalPath::WifiAware]),
                bytes_sent: 0,
                bytes_received: 0,
            },
        );

        assert!(can_track_source_peer("relay", "peer-b", &peers, None));
        assert!(can_track_source_peer(
            "wifi-aware",
            &peer_key,
            &peers,
            Some(1)
        ));
        assert!(!can_track_source_peer(
            "wifi-aware",
            "peer-b",
            &peers,
            Some(1),
        ));
    }

    // A session-less `Discovered` entry older than the 60 s timeout is evicted. A
    // `Connecting` entry whose session reports connected is promoted to `Connected`, and
    // the counter is resynced to 1.
    #[tokio::test]
    async fn cleanup_stale_peers_removes_stale_entries_and_syncs_connected_count() {
        let runtime = MeshRuntimeState::<TestSession>::new();
        let stale_id = PeerId::new("peer-stale".to_string());
        runtime.peers.write().await.insert(
            stale_id.to_string(),
            MeshPeerEntry {
                peer_id: stale_id,
                direction: PeerDirection::Outbound,
                state: ConnectionState::Discovered,
                last_seen: Instant::now() - Duration::from_secs(120),
                peer: None,
                pool: PeerPool::Other,
                transport: PeerTransport::WebRtc,
                signal_paths: BTreeSet::new(),
                bytes_sent: 0,
                bytes_received: 0,
            },
        );

        let active_id = PeerId::new("peer-active".to_string());
        runtime.peers.write().await.insert(
            active_id.to_string(),
            MeshPeerEntry {
                peer_id: active_id.clone(),
                direction: PeerDirection::Outbound,
                state: ConnectionState::Connecting,
                last_seen: Instant::now(),
                peer: Some(TestSession {
                    connected: true,
                    close_delay: Duration::ZERO,
                    closed: Arc::new(AtomicBool::new(false)),
                }),
                pool: PeerPool::Other,
                transport: PeerTransport::WebRtc,
                signal_paths: BTreeSet::new(),
                bytes_sent: 0,
                bytes_received: 0,
            },
        );

        cleanup_stale_peers(&runtime, Duration::from_secs(60)).await;

        let peers = runtime.peers.read().await;
        assert!(!peers.contains_key("peer-stale"));
        assert_eq!(
            peers.get(&active_id.to_string()).unwrap().state,
            ConnectionState::Connected
        );
        assert_eq!(
            runtime
                .connected_count
                .load(std::sync::atomic::Ordering::Relaxed),
            1
        );
    }

    // The session's `close()` takes 200 ms. After 20 ms, a read of the peer map must
    // succeed within 50 ms and find the entry already gone. This shows removal does not
    // keep the map locked while the session closes.
    #[tokio::test]
    async fn handle_peer_state_event_does_not_hold_peer_map_lock_while_closing() {
        let runtime = Arc::new(MeshRuntimeState::<TestSession>::new());
        let peer_id = PeerId::new("peer-a-pub".to_string());
        runtime.peers.write().await.insert(
            peer_id.to_string(),
            MeshPeerEntry {
                peer_id: peer_id.clone(),
                direction: PeerDirection::Outbound,
                state: ConnectionState::Connected,
                last_seen: Instant::now(),
                peer: Some(TestSession {
                    connected: false,
                    close_delay: Duration::from_millis(200),
                    closed: Arc::new(AtomicBool::new(false)),
                }),
                pool: PeerPool::Other,
                transport: PeerTransport::Bluetooth,
                signal_paths: BTreeSet::from([PeerSignalPath::Bluetooth]),
                bytes_sent: 0,
                bytes_received: 0,
            },
        );

        let runtime_for_task = runtime.clone();
        let peer_id_for_task = peer_id.clone();
        let cleanup_task = tokio::spawn(async move {
            handle_peer_state_event::<
                TestSession,
                crate::mock::MockRelayTransport,
                crate::mock::MockConnectionFactory,
            >(
                runtime_for_task.as_ref(),
                PeerStateEvent::Failed(peer_id_for_task),
                None,
            )
            .await;
        });

        tokio::time::sleep(Duration::from_millis(20)).await;

        let remaining = tokio::time::timeout(Duration::from_millis(50), async {
            runtime.peers.read().await.len()
        })
        .await
        .expect("peer map read should not block on close");

        assert_eq!(remaining, 0);
        cleanup_task.await.expect("cleanup task");
    }

    // With one connected peer, a frame is forwarded exactly once, and forwarding does
    // not close the peer's session.
    #[tokio::test]
    async fn forward_mesh_frame_from_runtime_sends_to_connected_peers() {
        let runtime = MeshRuntimeState::<TestSession>::new();
        let closed = Arc::new(AtomicBool::new(false));
        let peer_id = PeerId::new("peer-a".to_string());
        runtime.peers.write().await.insert(
            peer_id.to_string(),
            MeshPeerEntry {
                peer_id: peer_id.clone(),
                direction: PeerDirection::Outbound,
                state: ConnectionState::Connected,
                last_seen: Instant::now(),
                peer: Some(TestSession {
                    connected: true,
                    close_delay: Duration::ZERO,
                    closed: closed.clone(),
                }),
                pool: PeerPool::Other,
                transport: PeerTransport::WebRtc,
                signal_paths: BTreeSet::new(),
                bytes_sent: 0,
                bytes_received: 0,
            },
        );
        let keys = Keys::generate();
        let event = EventBuilder::new(Kind::Custom(25050), "mesh", [])
            .to_event(&keys)
            .unwrap();
        let frame = MeshNostrFrame::new_event_with_id(event, "sender", "frame-1", 4);

        let forwarded = forward_mesh_frame_from_runtime(&runtime, &frame, None).await;
        assert_eq!(forwarded, 1);
        assert!(!closed.load(Ordering::Relaxed));
    }
}