1use anyhow::Result;
2use nostr_sdk::nostr::{Event, Keys, Kind};
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::sync::Mutex;
7use tracing::{debug, info};
8
9use crate::local_bus::SharedLocalNostrBus;
10use crate::mesh_session::{forward_mesh_frame_to_sessions, MeshSession};
11use crate::nostr::{decode_signaling_event, encode_signaling_event};
12use crate::runtime_peer::{
13 can_track_signal_path_peer, remember_peer_signal_path, ConnectionState, MeshPeerEntry,
14 PeerSignalPath,
15};
16use crate::runtime_state::MeshRuntimeState;
17use crate::signaling::MeshRouter;
18use crate::transport::{PeerLinkFactory, SignalingTransport};
19use crate::types::{MeshNostrFrame, PeerId, SignalingMessage, TimedSeenSet, MESH_DEFAULT_HTL};
20
21#[derive(Debug, Clone)]
22pub enum PeerStateEvent {
23 Connected(PeerId),
24 Failed(PeerId),
25 Disconnected(PeerId),
26}
27
28pub fn can_track_source_peer<P>(
29 source: &str,
30 peer_key: &str,
31 peers: &HashMap<String, MeshPeerEntry<P>>,
32 max_peers: Option<usize>,
33) -> bool {
34 match max_peers {
35 Some(max_peers) => can_track_signal_path_peer(
36 PeerSignalPath::from_source_name(source),
37 max_peers,
38 peer_key,
39 peers,
40 ),
41 None => true,
42 }
43}
44
45pub async fn forward_mesh_frame_from_runtime<P>(
46 runtime: &MeshRuntimeState<P>,
47 frame: &MeshNostrFrame,
48 exclude_peer_id: Option<&str>,
49) -> usize
50where
51 P: MeshSession + Clone + Send + Sync + 'static,
52{
53 let peers = runtime.peers.read().await;
54 let peer_refs: Vec<(String, Arc<dyn MeshSession>)> = peers
55 .values()
56 .filter(|entry| entry.state == ConnectionState::Connected)
57 .filter_map(|entry| {
58 entry.peer.as_ref().map(|peer| {
59 (
60 entry.peer_id.to_string(),
61 Arc::new(peer.clone()) as Arc<dyn MeshSession>,
62 )
63 })
64 })
65 .collect();
66 drop(peers);
67
68 forward_mesh_frame_to_sessions(peer_refs, frame, exclude_peer_id).await
69}
70
71pub async fn create_signaling_event(
72 keys: &Keys,
73 msg: &SignalingMessage,
74 signaling_kind: u64,
75) -> Result<Event> {
76 encode_signaling_event(
77 keys,
78 msg.peer_id(),
79 msg,
80 Kind::Ephemeral(signaling_kind as u16),
81 )
82 .map_err(|e| anyhow::anyhow!(e.to_string()))
83}
84
85pub async fn handle_signaling_event<P, R, F>(
86 signaling_enabled: bool,
87 my_peer_id: &PeerId,
88 keys: &Keys,
89 runtime: &MeshRuntimeState<P>,
90 source: &str,
91 source_max_peers: Option<usize>,
92 event: &Event,
93 shared_router: Option<&Arc<MeshRouter<R, F>>>,
94) -> Result<()>
95where
96 P: MeshSession + Send + Sync + 'static,
97 R: SignalingTransport + 'static,
98 F: PeerLinkFactory + 'static,
99{
100 if !signaling_enabled {
101 return Ok(());
102 }
103
104 let Some(msg) = decode_signaling_event(
105 event,
106 &my_peer_id.to_string(),
107 &keys.public_key().to_hex(),
108 keys,
109 ) else {
110 return Ok(());
111 };
112
113 handle_signaling_message(runtime, source, source_max_peers, msg, shared_router).await
114}
115
116pub async fn handle_signaling_message<P, R, F>(
117 runtime: &MeshRuntimeState<P>,
118 source: &str,
119 source_max_peers: Option<usize>,
120 msg: SignalingMessage,
121 shared_router: Option<&Arc<MeshRouter<R, F>>>,
122) -> Result<()>
123where
124 P: MeshSession + Send + Sync + 'static,
125 R: SignalingTransport + 'static,
126 F: PeerLinkFactory + 'static,
127{
128 let Some(shared_router) = shared_router else {
129 return Ok(());
130 };
131
132 if matches!(
133 msg,
134 SignalingMessage::Hello { .. } | SignalingMessage::Offer { .. }
135 ) {
136 let peers = runtime.peers.read().await;
137 if !can_track_source_peer(source, msg.peer_id(), &peers, source_max_peers) {
138 return Ok(());
139 }
140 }
141
142 debug!(
143 "Received {} from {} via {}",
144 msg.msg_type(),
145 msg.peer_id(),
146 source
147 );
148 let peer_id = msg.peer_id().to_string();
149 let peer_hash_get = match &msg {
150 SignalingMessage::Hello { hash_get, .. } => Some(*hash_get),
151 _ => None,
152 };
153 shared_router
154 .handle_message(msg)
155 .await
156 .map_err(|e| anyhow::anyhow!(e.to_string()))?;
157 if let Some(hash_get) = peer_hash_get {
158 runtime.set_peer_hash_get(&peer_id, hash_get).await;
159 }
160 remember_peer_signal_path(runtime.peers.as_ref(), &peer_id, source).await;
161
162 Ok(())
163}
164
165pub async fn dispatch_signaling_message<P, S>(
166 signaling_enabled: bool,
167 keys: &Keys,
168 my_peer_id: &PeerId,
169 runtime: &MeshRuntimeState<P>,
170 relay_transport: Option<&S>,
171 local_buses: &[SharedLocalNostrBus],
172 seen_frame_ids: &Arc<Mutex<TimedSeenSet>>,
173 seen_event_ids: &Arc<Mutex<TimedSeenSet>>,
174 msg: SignalingMessage,
175 signaling_kind: u64,
176) -> Result<()>
177where
178 P: MeshSession + Clone + Send + Sync + 'static,
179 S: SignalingTransport + Send + Sync + 'static,
180{
181 if !signaling_enabled {
182 debug!(
183 "Skipping signaling message {} because signaling is disabled",
184 msg.msg_type()
185 );
186 return Ok(());
187 }
188
189 if let Some(relay_transport) = relay_transport {
190 if let Err(err) = relay_transport.publish(msg.clone()).await {
191 debug!(
192 "Failed to publish signaling message {} via relay transport: {}",
193 msg.msg_type(),
194 err
195 );
196 }
197 }
198
199 let event = create_signaling_event(keys, &msg, signaling_kind).await?;
200
201 for bus in local_buses {
202 if let Err(err) = bus.broadcast_event(&event).await {
203 debug!(
204 "Failed to broadcast signaling event over {} ({}): {}",
205 bus.source_name(),
206 msg.msg_type(),
207 err
208 );
209 }
210 }
211
212 let mut frame = MeshNostrFrame::new_event(event, &my_peer_id.to_string(), MESH_DEFAULT_HTL);
213 if !mark_seen(seen_frame_ids, frame.frame_id.clone()).await {
214 runtime.record_mesh_duplicate_drop();
215 return Ok(());
216 }
217 if !mark_seen(seen_event_ids, frame.event().id.to_hex()).await {
218 runtime.record_mesh_duplicate_drop();
219 return Ok(());
220 }
221
222 frame.sender_peer_id = my_peer_id.to_string();
223 let forwarded = forward_mesh_frame_from_runtime(runtime, &frame, None).await;
224 if forwarded > 0 {
225 runtime.record_mesh_forwarded(forwarded as u64);
226 }
227
228 Ok(())
229}
230
231pub async fn handle_peer_state_event<P, R, F>(
232 runtime: &MeshRuntimeState<P>,
233 event: PeerStateEvent,
234 shared_router: Option<&Arc<MeshRouter<R, F>>>,
235) where
236 P: MeshSession + Send + Sync + 'static,
237 R: SignalingTransport + 'static,
238 F: PeerLinkFactory + 'static,
239{
240 match event {
241 PeerStateEvent::Connected(peer_id) => {
242 let peer_key = peer_id.to_string();
243 let mut emit_hello = false;
244 let mut peers = runtime.peers.write().await;
245 if let Some(entry) = peers.get_mut(&peer_key) {
246 if entry.state != ConnectionState::Connected {
247 info!("Peer {} connected (via state event)", peer_id.short());
248 entry.state = ConnectionState::Connected;
249 emit_hello = true;
250 runtime
251 .connected_count
252 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
253 }
254 }
255 drop(peers);
256 if emit_hello {
257 if let Some(shared_router) = shared_router {
258 let _ = shared_router.send_hello(Vec::new()).await;
259 }
260 }
261 }
262 PeerStateEvent::Failed(peer_id) => {
263 remove_peer_from_runtime(runtime, shared_router, peer_id, "connection failed").await;
264 }
265 PeerStateEvent::Disconnected(peer_id) => {
266 remove_peer_from_runtime(runtime, shared_router, peer_id, "disconnected").await;
267 }
268 }
269}
270
271pub async fn cleanup_stale_peers<P>(runtime: &MeshRuntimeState<P>, stale_timeout: Duration)
272where
273 P: MeshSession + Send + Sync + 'static,
274{
275 let mut peers = runtime.peers.write().await;
276 let mut connected_count = 0usize;
277 let mut to_remove = Vec::new();
278
279 for (key, entry) in peers.iter_mut() {
280 if let Some(ref peer) = entry.peer {
281 if peer.is_connected() {
282 if entry.state != ConnectionState::Connected {
283 info!(
284 "Peer {} is now connected (sync fallback)",
285 entry.peer_id.short()
286 );
287 entry.state = ConnectionState::Connected;
288 }
289 connected_count += 1;
290 } else if entry.state == ConnectionState::Connected {
291 info!(
292 "Removing disconnected peer {} after transport closed",
293 entry.peer_id.short()
294 );
295 to_remove.push(key.clone());
296 } else if entry.state == ConnectionState::Connecting
297 && entry.last_seen.elapsed() > stale_timeout
298 {
299 info!(
300 "Removing stale peer {} (stuck in Connecting for {:?})",
301 entry.peer_id.short(),
302 entry.last_seen.elapsed()
303 );
304 to_remove.push(key.clone());
305 }
306 } else if entry.state == ConnectionState::Discovered
307 && entry.last_seen.elapsed() > stale_timeout
308 {
309 debug!("Removing stale discovered peer {}", entry.peer_id.short());
310 to_remove.push(key.clone());
311 }
312 }
313
314 let mut removed_peers = Vec::new();
315 for key in to_remove {
316 if let Some(entry) = peers.remove(&key) {
317 removed_peers.push(entry);
318 }
319 }
320 drop(peers);
321
322 for entry in removed_peers {
323 if let Some(peer) = entry.peer {
324 let _ = peer.close().await;
325 }
326 }
327
328 runtime
329 .connected_count
330 .store(connected_count, std::sync::atomic::Ordering::Relaxed);
331}
332
333async fn mark_seen(seen: &Arc<Mutex<TimedSeenSet>>, id: String) -> bool {
334 let mut seen = seen.lock().await;
335 seen.insert_if_new(id)
336}
337
338async fn remove_peer_from_runtime<P, R, F>(
339 runtime: &MeshRuntimeState<P>,
340 shared_router: Option<&Arc<MeshRouter<R, F>>>,
341 peer_id: PeerId,
342 reason: &str,
343) where
344 P: MeshSession + Send + Sync + 'static,
345 R: SignalingTransport + 'static,
346 F: PeerLinkFactory + 'static,
347{
348 let peer_key = peer_id.to_string();
349 info!("Peer {} {} - removing from pool", peer_id.short(), reason);
350 let removed = {
351 let mut peers = runtime.peers.write().await;
352 peers.remove(&peer_key)
353 };
354 runtime.clear_peer_hash_get(&peer_key).await;
355 if let Some(entry) = removed {
356 if entry.state == ConnectionState::Connected {
357 runtime
358 .connected_count
359 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
360 }
361 if let Some(peer) = entry.peer {
362 let _ = peer.close().await;
363 }
364 }
365 if let Some(shared_router) = shared_router {
366 if let Some(channel) = shared_router.remove_peer(&peer_key).await {
367 channel.close().await;
368 }
369 }
370}
371
// Unit tests for the runtime signaling helpers. They use `TestSession`, an
// in-memory stand-in for a real peer session.
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result as AnyResult;
    use async_trait::async_trait;
    use nostr_sdk::nostr::{EventBuilder, Filter, Kind};
    use std::collections::BTreeSet;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Instant;

    use crate::runtime_peer::{MeshPeerEntry, PeerDirection, PeerTransport};
    use crate::types::{PeerHTLConfig, PeerPool};

    // Minimal `MeshSession` test double. Connectivity is fixed at
    // construction. `close` optionally sleeps for `close_delay` and then sets
    // `closed`, so tests can observe both slow closes and whether a close
    // happened at all.
    #[derive(Clone)]
    struct TestSession {
        connected: bool,
        close_delay: Duration,
        closed: Arc<AtomicBool>,
    }

    #[async_trait]
    impl MeshSession for TestSession {
        fn is_ready(&self) -> bool {
            true
        }

        fn is_connected(&self) -> bool {
            self.connected
        }

        fn htl_config(&self) -> PeerHTLConfig {
            PeerHTLConfig::from_flags(false, false)
        }

        // Data-plane calls are inert: no content, no events, sends succeed.
        async fn request(&self, _hash_hex: &str, _timeout: Duration) -> AnyResult<Option<Vec<u8>>> {
            Ok(None)
        }

        async fn query_nostr_events(
            &self,
            _filters: Vec<Filter>,
            _timeout: Duration,
        ) -> AnyResult<Vec<Event>> {
            Ok(Vec::new())
        }

        async fn send_mesh_frame_text(&self, _frame: &MeshNostrFrame) -> AnyResult<()> {
            Ok(())
        }

        async fn close(&self) -> AnyResult<()> {
            if !self.close_delay.is_zero() {
                tokio::time::sleep(self.close_delay).await;
            }
            self.closed.store(true, Ordering::Relaxed);
            Ok(())
        }
    }

    // "peer-a" is already tracked on the WifiAware path. With no cap any peer
    // is accepted. With a cap of 1 on "wifi-aware", the already-tracked peer is
    // still accepted but a new peer ("peer-b") is rejected.
    #[test]
    fn can_track_source_peer_respects_optional_limits() {
        let peer_id = PeerId::new("peer-a".to_string());
        let peer_key = peer_id.to_string();
        let mut peers = HashMap::new();
        peers.insert(
            peer_key.clone(),
            MeshPeerEntry {
                peer_id,
                direction: PeerDirection::Outbound,
                state: ConnectionState::Discovered,
                last_seen: Instant::now(),
                peer: None::<TestSession>,
                pool: PeerPool::Other,
                transport: PeerTransport::WebRtc,
                signal_paths: BTreeSet::from([PeerSignalPath::WifiAware]),
                bytes_sent: 0,
                bytes_received: 0,
            },
        );

        assert!(can_track_source_peer("relay", "peer-b", &peers, None));
        assert!(can_track_source_peer(
            "wifi-aware",
            &peer_key,
            &peers,
            Some(1)
        ));
        assert!(!can_track_source_peer(
            "wifi-aware",
            "peer-b",
            &peers,
            Some(1),
        ));
    }

    // A session-less `Discovered` entry last seen 120s ago (timeout 60s) is
    // evicted. A `Connecting` entry whose session reports connected is promoted
    // to `Connected` and is the only one counted in `connected_count`.
    #[tokio::test]
    async fn cleanup_stale_peers_removes_stale_entries_and_syncs_connected_count() {
        let runtime = MeshRuntimeState::<TestSession>::new();
        let stale_id = PeerId::new("peer-stale".to_string());
        runtime.peers.write().await.insert(
            stale_id.to_string(),
            MeshPeerEntry {
                peer_id: stale_id,
                direction: PeerDirection::Outbound,
                state: ConnectionState::Discovered,
                // NOTE(review): `Instant - Duration` panics if the monotonic
                // clock is younger than 120s on some platforms.
                last_seen: Instant::now() - Duration::from_secs(120),
                peer: None,
                pool: PeerPool::Other,
                transport: PeerTransport::WebRtc,
                signal_paths: BTreeSet::new(),
                bytes_sent: 0,
                bytes_received: 0,
            },
        );

        let active_id = PeerId::new("peer-active".to_string());
        runtime.peers.write().await.insert(
            active_id.to_string(),
            MeshPeerEntry {
                peer_id: active_id.clone(),
                direction: PeerDirection::Outbound,
                state: ConnectionState::Connecting,
                last_seen: Instant::now(),
                peer: Some(TestSession {
                    connected: true,
                    close_delay: Duration::ZERO,
                    closed: Arc::new(AtomicBool::new(false)),
                }),
                pool: PeerPool::Other,
                transport: PeerTransport::WebRtc,
                signal_paths: BTreeSet::new(),
                bytes_sent: 0,
                bytes_received: 0,
            },
        );

        cleanup_stale_peers(&runtime, Duration::from_secs(60)).await;

        let peers = runtime.peers.read().await;
        assert!(!peers.contains_key("peer-stale"));
        assert_eq!(
            peers.get(&active_id.to_string()).unwrap().state,
            ConnectionState::Connected
        );
        assert_eq!(
            runtime
                .connected_count
                .load(std::sync::atomic::Ordering::Relaxed),
            1
        );
    }

    // Timing-based check that removal releases the peer-map lock before
    // closing the session. The session's close takes 200ms. After letting the
    // cleanup task start (20ms), a read of the map must complete within 50ms,
    // which would be impossible if the write lock were held across `close()`.
    #[tokio::test]
    async fn handle_peer_state_event_does_not_hold_peer_map_lock_while_closing() {
        let runtime = Arc::new(MeshRuntimeState::<TestSession>::new());
        let peer_id = PeerId::new("peer-a-pub".to_string());
        runtime.peers.write().await.insert(
            peer_id.to_string(),
            MeshPeerEntry {
                peer_id: peer_id.clone(),
                direction: PeerDirection::Outbound,
                state: ConnectionState::Connected,
                last_seen: Instant::now(),
                peer: Some(TestSession {
                    connected: false,
                    close_delay: Duration::from_millis(200),
                    closed: Arc::new(AtomicBool::new(false)),
                }),
                pool: PeerPool::Other,
                transport: PeerTransport::Bluetooth,
                signal_paths: BTreeSet::from([PeerSignalPath::Bluetooth]),
                bytes_sent: 0,
                bytes_received: 0,
            },
        );

        let runtime_for_task = runtime.clone();
        let peer_id_for_task = peer_id.clone();
        let cleanup_task = tokio::spawn(async move {
            // No router is passed; the router type parameters only need to be
            // nameable, so the crate's mock types are used.
            handle_peer_state_event::<
                TestSession,
                crate::mock::MockRelayTransport,
                crate::mock::MockConnectionFactory,
            >(
                runtime_for_task.as_ref(),
                PeerStateEvent::Failed(peer_id_for_task),
                None,
            )
            .await;
        });

        tokio::time::sleep(Duration::from_millis(20)).await;

        let remaining = tokio::time::timeout(Duration::from_millis(50), async {
            runtime.peers.read().await.len()
        })
        .await
        .expect("peer map read should not block on close");

        assert_eq!(remaining, 0);
        cleanup_task.await.expect("cleanup task");
    }

    // One `Connected` peer with a session: the frame is forwarded to exactly
    // that peer, and forwarding never closes the session.
    #[tokio::test]
    async fn forward_mesh_frame_from_runtime_sends_to_connected_peers() {
        let runtime = MeshRuntimeState::<TestSession>::new();
        let closed = Arc::new(AtomicBool::new(false));
        let peer_id = PeerId::new("peer-a".to_string());
        runtime.peers.write().await.insert(
            peer_id.to_string(),
            MeshPeerEntry {
                peer_id: peer_id.clone(),
                direction: PeerDirection::Outbound,
                state: ConnectionState::Connected,
                last_seen: Instant::now(),
                peer: Some(TestSession {
                    connected: true,
                    close_delay: Duration::ZERO,
                    closed: closed.clone(),
                }),
                pool: PeerPool::Other,
                transport: PeerTransport::WebRtc,
                signal_paths: BTreeSet::new(),
                bytes_sent: 0,
                bytes_received: 0,
            },
        );
        let keys = Keys::generate();
        let event = EventBuilder::new(Kind::Custom(25050), "mesh", [])
            .to_event(&keys)
            .unwrap();
        let frame = MeshNostrFrame::new_event_with_id(event, "sender", "frame-1", 4);

        let forwarded = forward_mesh_frame_from_runtime(&runtime, &frame, None).await;
        assert_eq!(forwarded, 1);
        assert!(!closed.load(Ordering::Relaxed));
    }
}