1use anyhow::Result;
2use nostr_sdk::nostr::{Event, Keys, Kind};
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::sync::Mutex;
7use tracing::{debug, info};
8
9use crate::local_bus::SharedLocalNostrBus;
10use crate::mesh_session::{forward_mesh_frame_to_sessions, MeshSession};
11use crate::nostr::{decode_signaling_event, encode_signaling_event};
12use crate::runtime_peer::{
13 can_track_signal_path_peer, remember_peer_signal_path, ConnectionState, MeshPeerEntry,
14 PeerSignalPath,
15};
16use crate::runtime_state::MeshRuntimeState;
17use crate::signaling::MeshRouter;
18use crate::transport::{PeerLinkFactory, SignalingTransport};
19use crate::types::{MeshNostrFrame, PeerId, SignalingMessage, TimedSeenSet, MESH_DEFAULT_HTL};
20
21#[derive(Debug, Clone)]
22pub enum PeerStateEvent {
23 Connected(PeerId),
24 SignalHints(PeerId, Vec<String>),
25 Failed(PeerId),
26 Disconnected(PeerId),
27}
28
29pub fn can_track_source_peer<P>(
30 source: &str,
31 peer_key: &str,
32 peers: &HashMap<String, MeshPeerEntry<P>>,
33 max_peers: Option<usize>,
34) -> bool {
35 match max_peers {
36 Some(max_peers) => can_track_signal_path_peer(
37 PeerSignalPath::from_source_name(source),
38 max_peers,
39 peer_key,
40 peers,
41 ),
42 None => true,
43 }
44}
45
46pub async fn forward_mesh_frame_from_runtime<P>(
47 runtime: &MeshRuntimeState<P>,
48 frame: &MeshNostrFrame,
49 exclude_peer_id: Option<&str>,
50) -> usize
51where
52 P: MeshSession + Clone + Send + Sync + 'static,
53{
54 let peers = runtime.peers.read().await;
55 let peer_refs: Vec<(String, Arc<dyn MeshSession>)> = peers
56 .values()
57 .filter(|entry| entry.state == ConnectionState::Connected)
58 .filter_map(|entry| {
59 entry.peer.as_ref().map(|peer| {
60 (
61 entry.peer_id.to_string(),
62 Arc::new(peer.clone()) as Arc<dyn MeshSession>,
63 )
64 })
65 })
66 .collect();
67 drop(peers);
68
69 forward_mesh_frame_to_sessions(peer_refs, frame, exclude_peer_id).await
70}
71
72pub async fn create_signaling_event(
73 keys: &Keys,
74 msg: &SignalingMessage,
75 signaling_kind: u64,
76) -> Result<Event> {
77 encode_signaling_event(
78 keys,
79 msg.peer_id(),
80 msg,
81 Kind::Ephemeral(signaling_kind as u16),
82 )
83 .map_err(|e| anyhow::anyhow!(e.to_string()))
84}
85
86pub async fn handle_signaling_event<P, R, F>(
87 signaling_enabled: bool,
88 my_peer_id: &PeerId,
89 keys: &Keys,
90 runtime: &MeshRuntimeState<P>,
91 source: &str,
92 source_max_peers: Option<usize>,
93 event: &Event,
94 shared_router: Option<&Arc<MeshRouter<R, F>>>,
95) -> Result<()>
96where
97 P: MeshSession + Send + Sync + 'static,
98 R: SignalingTransport + 'static,
99 F: PeerLinkFactory + 'static,
100{
101 if !signaling_enabled {
102 return Ok(());
103 }
104
105 let Some(msg) = decode_signaling_event(
106 event,
107 &my_peer_id.to_string(),
108 &keys.public_key().to_hex(),
109 keys,
110 ) else {
111 return Ok(());
112 };
113
114 handle_signaling_message(runtime, source, source_max_peers, msg, shared_router).await
115}
116
117pub async fn handle_signaling_message<P, R, F>(
118 runtime: &MeshRuntimeState<P>,
119 source: &str,
120 source_max_peers: Option<usize>,
121 msg: SignalingMessage,
122 shared_router: Option<&Arc<MeshRouter<R, F>>>,
123) -> Result<()>
124where
125 P: MeshSession + Send + Sync + 'static,
126 R: SignalingTransport + 'static,
127 F: PeerLinkFactory + 'static,
128{
129 let Some(shared_router) = shared_router else {
130 return Ok(());
131 };
132
133 if matches!(
134 msg,
135 SignalingMessage::Hello { .. } | SignalingMessage::Offer { .. }
136 ) {
137 let peers = runtime.peers.read().await;
138 if !can_track_source_peer(source, msg.peer_id(), &peers, source_max_peers) {
139 return Ok(());
140 }
141 }
142
143 debug!(
144 "Received {} from {} via {}",
145 msg.msg_type(),
146 msg.peer_id(),
147 source
148 );
149 let peer_id = msg.peer_id().to_string();
150 let peer_hash_get = match &msg {
151 SignalingMessage::Hello { hash_get, .. } => Some(*hash_get),
152 _ => None,
153 };
154 shared_router
155 .handle_message(msg)
156 .await
157 .map_err(|e| anyhow::anyhow!(e.to_string()))?;
158 if let Some(hash_get) = peer_hash_get {
159 runtime.set_peer_hash_get(&peer_id, hash_get).await;
160 }
161 remember_peer_signal_path(runtime.peers.as_ref(), &peer_id, source).await;
162
163 Ok(())
164}
165
166pub async fn dispatch_signaling_message<P, S>(
167 signaling_enabled: bool,
168 keys: &Keys,
169 my_peer_id: &PeerId,
170 runtime: &MeshRuntimeState<P>,
171 relay_transport: Option<&S>,
172 local_buses: &[SharedLocalNostrBus],
173 seen_frame_ids: &Arc<Mutex<TimedSeenSet>>,
174 seen_event_ids: &Arc<Mutex<TimedSeenSet>>,
175 msg: SignalingMessage,
176 signaling_kind: u64,
177) -> Result<()>
178where
179 P: MeshSession + Clone + Send + Sync + 'static,
180 S: SignalingTransport + Send + Sync + 'static,
181{
182 if !signaling_enabled {
183 debug!(
184 "Skipping signaling message {} because signaling is disabled",
185 msg.msg_type()
186 );
187 return Ok(());
188 }
189
190 let event = create_signaling_event(keys, &msg, signaling_kind).await?;
191
192 for bus in local_buses {
193 if let Err(err) = bus.broadcast_event(&event).await {
194 debug!(
195 "Failed to broadcast signaling event over {} ({}): {}",
196 bus.source_name(),
197 msg.msg_type(),
198 err
199 );
200 }
201 }
202
203 let mut frame = MeshNostrFrame::new_event(event, &my_peer_id.to_string(), MESH_DEFAULT_HTL);
204 if !mark_seen(seen_frame_ids, frame.frame_id.clone()).await {
205 runtime.record_mesh_duplicate_drop();
206 return Ok(());
207 }
208 if !mark_seen(seen_event_ids, frame.event().id.to_hex()).await {
209 runtime.record_mesh_duplicate_drop();
210 return Ok(());
211 }
212
213 frame.sender_peer_id = my_peer_id.to_string();
214 let forwarded = forward_mesh_frame_from_runtime(runtime, &frame, None).await;
215 if forwarded > 0 {
216 runtime.record_mesh_forwarded(forwarded as u64);
217 }
218
219 if let Some(relay_transport) = relay_transport {
220 match tokio::time::timeout(
221 Duration::from_millis(500),
222 relay_transport.publish(msg.clone()),
223 )
224 .await
225 {
226 Ok(Ok(())) => {}
227 Ok(Err(err)) => {
228 debug!(
229 "Failed to publish signaling message {} via relay transport: {}",
230 msg.msg_type(),
231 err
232 );
233 }
234 Err(_) => {
235 debug!(
236 "Timed out publishing signaling message {} via relay transport",
237 msg.msg_type()
238 );
239 }
240 }
241 }
242
243 Ok(())
244}
245
246pub async fn handle_peer_state_event<P, R, F>(
247 runtime: &MeshRuntimeState<P>,
248 event: PeerStateEvent,
249 shared_router: Option<&Arc<MeshRouter<R, F>>>,
250) where
251 P: MeshSession + Send + Sync + 'static,
252 R: SignalingTransport + 'static,
253 F: PeerLinkFactory + 'static,
254{
255 match event {
256 PeerStateEvent::Connected(peer_id) => {
257 let peer_key = peer_id.to_string();
258 let mut emit_hello = false;
259 let mut peers = runtime.peers.write().await;
260 if let Some(entry) = peers.get_mut(&peer_key) {
261 if entry.state != ConnectionState::Connected {
262 info!("Peer {} connected (via state event)", peer_id.short());
263 entry.state = ConnectionState::Connected;
264 emit_hello = true;
265 runtime
266 .connected_count
267 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
268 }
269 }
270 drop(peers);
271 if emit_hello {
272 if let Some(shared_router) = shared_router {
273 let _ = shared_router.send_hello(Vec::new()).await;
274 }
275 }
276 }
277 PeerStateEvent::SignalHints(peer_id, signal_urls) => {
278 runtime
279 .record_known_peer_signal_urls(&peer_id.to_string(), &signal_urls, "peer-link")
280 .await;
281 }
282 PeerStateEvent::Failed(peer_id) => {
283 remove_peer_from_runtime(runtime, shared_router, peer_id, "connection failed").await;
284 }
285 PeerStateEvent::Disconnected(peer_id) => {
286 remove_peer_from_runtime(runtime, shared_router, peer_id, "disconnected").await;
287 }
288 }
289}
290
291pub async fn cleanup_stale_peers<P>(runtime: &MeshRuntimeState<P>, stale_timeout: Duration)
292where
293 P: MeshSession + Send + Sync + 'static,
294{
295 let mut peers = runtime.peers.write().await;
296 let mut connected_count = 0usize;
297 let mut to_remove = Vec::new();
298
299 for (key, entry) in peers.iter_mut() {
300 if let Some(ref peer) = entry.peer {
301 if peer.is_connected() {
302 if entry.state != ConnectionState::Connected {
303 info!(
304 "Peer {} is now connected (sync fallback)",
305 entry.peer_id.short()
306 );
307 entry.state = ConnectionState::Connected;
308 }
309 connected_count += 1;
310 } else if entry.state == ConnectionState::Connected {
311 info!(
312 "Removing disconnected peer {} after transport closed",
313 entry.peer_id.short()
314 );
315 to_remove.push(key.clone());
316 } else if entry.state == ConnectionState::Connecting
317 && entry.last_seen.elapsed() > stale_timeout
318 {
319 info!(
320 "Removing stale peer {} (stuck in Connecting for {:?})",
321 entry.peer_id.short(),
322 entry.last_seen.elapsed()
323 );
324 to_remove.push(key.clone());
325 }
326 } else if entry.state == ConnectionState::Discovered
327 && entry.last_seen.elapsed() > stale_timeout
328 {
329 debug!("Removing stale discovered peer {}", entry.peer_id.short());
330 to_remove.push(key.clone());
331 }
332 }
333
334 let mut removed_peers = Vec::new();
335 for key in to_remove {
336 if let Some(entry) = peers.remove(&key) {
337 removed_peers.push(entry);
338 }
339 }
340 drop(peers);
341
342 for entry in removed_peers {
343 if let Some(peer) = entry.peer {
344 let _ = peer.close().await;
345 }
346 }
347
348 runtime
349 .connected_count
350 .store(connected_count, std::sync::atomic::Ordering::Relaxed);
351}
352
353async fn mark_seen(seen: &Arc<Mutex<TimedSeenSet>>, id: String) -> bool {
354 let mut seen = seen.lock().await;
355 seen.insert_if_new(id)
356}
357
358async fn remove_peer_from_runtime<P, R, F>(
359 runtime: &MeshRuntimeState<P>,
360 shared_router: Option<&Arc<MeshRouter<R, F>>>,
361 peer_id: PeerId,
362 reason: &str,
363) where
364 P: MeshSession + Send + Sync + 'static,
365 R: SignalingTransport + 'static,
366 F: PeerLinkFactory + 'static,
367{
368 let peer_key = peer_id.to_string();
369 info!("Peer {} {} - removing from pool", peer_id.short(), reason);
370 let removed = {
371 let mut peers = runtime.peers.write().await;
372 peers.remove(&peer_key)
373 };
374 runtime.clear_peer_hash_get(&peer_key).await;
375 if let Some(entry) = removed {
376 if entry.state == ConnectionState::Connected {
377 runtime
378 .connected_count
379 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
380 }
381 if let Some(peer) = entry.peer {
382 let _ = peer.close().await;
383 }
384 }
385 if let Some(shared_router) = shared_router {
386 if let Some(channel) = shared_router.remove_peer(&peer_key).await {
387 channel.close().await;
388 }
389 }
390}
391
// Unit tests for the peer-runtime helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result as AnyResult;
    use async_trait::async_trait;
    use nostr_sdk::nostr::{EventBuilder, Filter, Kind};
    use std::collections::BTreeSet;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Instant;

    use crate::runtime_peer::{MeshPeerEntry, PeerDirection, PeerTransport};
    use crate::types::{PeerHTLConfig, PeerPool};

    /// Minimal `MeshSession` stub.
    ///
    /// `connected` is returned verbatim from `is_connected()`. `close()` sleeps
    /// for `close_delay` (when non-zero) and then sets `closed`, so tests can
    /// both simulate slow teardown and observe whether `close()` ran.
    #[derive(Clone)]
    struct TestSession {
        connected: bool,
        close_delay: Duration,
        closed: Arc<AtomicBool>,
    }

    #[async_trait]
    impl MeshSession for TestSession {
        fn is_ready(&self) -> bool {
            true
        }

        fn is_connected(&self) -> bool {
            self.connected
        }

        fn htl_config(&self) -> PeerHTLConfig {
            PeerHTLConfig::from_flags(false, false)
        }

        async fn request(&self, _hash_hex: &str, _timeout: Duration) -> AnyResult<Option<Vec<u8>>> {
            Ok(None)
        }

        async fn query_nostr_events(
            &self,
            _filters: Vec<Filter>,
            _timeout: Duration,
        ) -> AnyResult<Vec<Event>> {
            Ok(Vec::new())
        }

        async fn send_mesh_frame_text(&self, _frame: &MeshNostrFrame) -> AnyResult<()> {
            Ok(())
        }

        async fn close(&self) -> AnyResult<()> {
            // Optional delay lets tests check that callers do not hold locks across close().
            if !self.close_delay.is_zero() {
                tokio::time::sleep(self.close_delay).await;
            }
            self.closed.store(true, Ordering::Relaxed);
            Ok(())
        }
    }

    // With no limit every peer is trackable. With a limit of 1 on the
    // wifi-aware path, the peer already recorded on that path remains
    // trackable, but a different peer ("peer-b") is rejected.
    #[test]
    fn can_track_source_peer_respects_optional_limits() {
        let peer_id = PeerId::new("peer-a".to_string());
        let peer_key = peer_id.to_string();
        let mut peers = HashMap::new();
        peers.insert(
            peer_key.clone(),
            MeshPeerEntry {
                peer_id,
                direction: PeerDirection::Outbound,
                state: ConnectionState::Discovered,
                last_seen: Instant::now(),
                peer: None::<TestSession>,
                pool: PeerPool::Other,
                transport: PeerTransport::WebRtc,
                signal_paths: BTreeSet::from([PeerSignalPath::WifiAware]),
                bytes_sent: 0,
                bytes_received: 0,
            },
        );

        assert!(can_track_source_peer("relay", "peer-b", &peers, None));
        assert!(can_track_source_peer(
            "wifi-aware",
            &peer_key,
            &peers,
            Some(1)
        ));
        assert!(!can_track_source_peer(
            "wifi-aware",
            "peer-b",
            &peers,
            Some(1),
        ));
    }

    // A session-less Discovered entry last seen 120 s ago is removed under a
    // 60 s timeout. A Connecting entry whose session reports connected is
    // promoted to Connected, and connected_count is resynchronized to 1.
    #[tokio::test]
    async fn cleanup_stale_peers_removes_stale_entries_and_syncs_connected_count() {
        let runtime = MeshRuntimeState::<TestSession>::new();
        let stale_id = PeerId::new("peer-stale".to_string());
        runtime.peers.write().await.insert(
            stale_id.to_string(),
            MeshPeerEntry {
                peer_id: stale_id,
                direction: PeerDirection::Outbound,
                state: ConnectionState::Discovered,
                last_seen: Instant::now() - Duration::from_secs(120),
                peer: None,
                pool: PeerPool::Other,
                transport: PeerTransport::WebRtc,
                signal_paths: BTreeSet::new(),
                bytes_sent: 0,
                bytes_received: 0,
            },
        );

        let active_id = PeerId::new("peer-active".to_string());
        runtime.peers.write().await.insert(
            active_id.to_string(),
            MeshPeerEntry {
                peer_id: active_id.clone(),
                direction: PeerDirection::Outbound,
                state: ConnectionState::Connecting,
                last_seen: Instant::now(),
                peer: Some(TestSession {
                    connected: true,
                    close_delay: Duration::ZERO,
                    closed: Arc::new(AtomicBool::new(false)),
                }),
                pool: PeerPool::Other,
                transport: PeerTransport::WebRtc,
                signal_paths: BTreeSet::new(),
                bytes_sent: 0,
                bytes_received: 0,
            },
        );

        cleanup_stale_peers(&runtime, Duration::from_secs(60)).await;

        let peers = runtime.peers.read().await;
        assert!(!peers.contains_key("peer-stale"));
        assert_eq!(
            peers.get(&active_id.to_string()).unwrap().state,
            ConnectionState::Connected
        );
        assert_eq!(
            runtime
                .connected_count
                .load(std::sync::atomic::Ordering::Relaxed),
            1
        );
    }

    // The peer's close() sleeps for 200 ms. About 20 ms after the Failed event
    // is dispatched, reading the peer map must finish within 50 ms and show
    // the entry already removed. This demonstrates that the write lock is not
    // held across close().
    #[tokio::test]
    async fn handle_peer_state_event_does_not_hold_peer_map_lock_while_closing() {
        let runtime = Arc::new(MeshRuntimeState::<TestSession>::new());
        let peer_id = PeerId::new("peer-a-pub".to_string());
        runtime.peers.write().await.insert(
            peer_id.to_string(),
            MeshPeerEntry {
                peer_id: peer_id.clone(),
                direction: PeerDirection::Outbound,
                state: ConnectionState::Connected,
                last_seen: Instant::now(),
                peer: Some(TestSession {
                    connected: false,
                    close_delay: Duration::from_millis(200),
                    closed: Arc::new(AtomicBool::new(false)),
                }),
                pool: PeerPool::Other,
                transport: PeerTransport::Bluetooth,
                signal_paths: BTreeSet::from([PeerSignalPath::Bluetooth]),
                bytes_sent: 0,
                bytes_received: 0,
            },
        );

        let runtime_for_task = runtime.clone();
        let peer_id_for_task = peer_id.clone();
        let cleanup_task = tokio::spawn(async move {
            handle_peer_state_event::<
                TestSession,
                crate::mock::MockRelayTransport,
                crate::mock::MockConnectionFactory,
            >(
                runtime_for_task.as_ref(),
                PeerStateEvent::Failed(peer_id_for_task),
                None,
            )
            .await;
        });

        // Give the spawned task time to remove the entry and enter the slow close().
        tokio::time::sleep(Duration::from_millis(20)).await;

        let remaining = tokio::time::timeout(Duration::from_millis(50), async {
            runtime.peers.read().await.len()
        })
        .await
        .expect("peer map read should not block on close");

        assert_eq!(remaining, 0);
        cleanup_task.await.expect("cleanup task");
    }

    // A single Connected peer with a session receives the frame (count 1),
    // and forwarding does not close that session.
    #[tokio::test]
    async fn forward_mesh_frame_from_runtime_sends_to_connected_peers() {
        let runtime = MeshRuntimeState::<TestSession>::new();
        let closed = Arc::new(AtomicBool::new(false));
        let peer_id = PeerId::new("peer-a".to_string());
        runtime.peers.write().await.insert(
            peer_id.to_string(),
            MeshPeerEntry {
                peer_id: peer_id.clone(),
                direction: PeerDirection::Outbound,
                state: ConnectionState::Connected,
                last_seen: Instant::now(),
                peer: Some(TestSession {
                    connected: true,
                    close_delay: Duration::ZERO,
                    closed: closed.clone(),
                }),
                pool: PeerPool::Other,
                transport: PeerTransport::WebRtc,
                signal_paths: BTreeSet::new(),
                bytes_sent: 0,
                bytes_received: 0,
            },
        );
        let keys = Keys::generate();
        let event = EventBuilder::new(Kind::Custom(25050), "mesh", [])
            .to_event(&keys)
            .unwrap();
        let frame = MeshNostrFrame::new_event_with_id(event, "sender", "frame-1", 4);

        let forwarded = forward_mesh_frame_from_runtime(&runtime, &frame, None).await;
        assert_eq!(forwarded, 1);
        assert!(!closed.load(Ordering::Relaxed));
    }
}