1use anyhow::Result;
2use nostr_sdk::nostr::{Event, Keys, Kind};
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::sync::Mutex;
7use tracing::{debug, info};
8
9use crate::local_bus::SharedLocalNostrBus;
10use crate::mesh_session::{forward_mesh_frame_to_sessions, MeshSession};
11use crate::nostr::{decode_signaling_event, encode_signaling_event};
12use crate::runtime_peer::{
13 can_track_signal_path_peer, remember_peer_signal_path, ConnectionState, MeshPeerEntry,
14 PeerSignalPath,
15};
16use crate::runtime_state::MeshRuntimeState;
17use crate::signaling::MeshRouter;
18use crate::transport::{PeerLinkFactory, SignalingTransport};
19use crate::types::{MeshNostrFrame, PeerId, SignalingMessage, TimedSeenSet, MESH_DEFAULT_HTL};
20
21#[derive(Debug, Clone)]
22pub enum PeerStateEvent {
23 Connected(PeerId),
24 Failed(PeerId),
25 Disconnected(PeerId),
26}
27
28pub fn can_track_source_peer<P>(
29 source: &str,
30 peer_key: &str,
31 peers: &HashMap<String, MeshPeerEntry<P>>,
32 max_peers: Option<usize>,
33) -> bool {
34 match max_peers {
35 Some(max_peers) => can_track_signal_path_peer(
36 PeerSignalPath::from_source_name(source),
37 max_peers,
38 peer_key,
39 peers,
40 ),
41 None => true,
42 }
43}
44
45pub async fn forward_mesh_frame_from_runtime<P>(
46 runtime: &MeshRuntimeState<P>,
47 frame: &MeshNostrFrame,
48 exclude_peer_id: Option<&str>,
49) -> usize
50where
51 P: MeshSession + Clone + Send + Sync + 'static,
52{
53 let peers = runtime.peers.read().await;
54 let peer_refs: Vec<(String, Arc<dyn MeshSession>)> = peers
55 .values()
56 .filter(|entry| entry.state == ConnectionState::Connected)
57 .filter_map(|entry| {
58 entry.peer.as_ref().map(|peer| {
59 (
60 entry.peer_id.to_string(),
61 Arc::new(peer.clone()) as Arc<dyn MeshSession>,
62 )
63 })
64 })
65 .collect();
66 drop(peers);
67
68 forward_mesh_frame_to_sessions(peer_refs, frame, exclude_peer_id).await
69}
70
71pub async fn create_signaling_event(
72 keys: &Keys,
73 msg: &SignalingMessage,
74 signaling_kind: u64,
75) -> Result<Event> {
76 encode_signaling_event(
77 keys,
78 msg.peer_id(),
79 msg,
80 Kind::Ephemeral(signaling_kind as u16),
81 )
82 .map_err(|e| anyhow::anyhow!(e.to_string()))
83}
84
85pub async fn handle_signaling_event<P, R, F>(
86 signaling_enabled: bool,
87 my_peer_id: &PeerId,
88 keys: &Keys,
89 runtime: &MeshRuntimeState<P>,
90 source: &str,
91 source_max_peers: Option<usize>,
92 event: &Event,
93 shared_router: Option<&Arc<MeshRouter<R, F>>>,
94) -> Result<()>
95where
96 P: MeshSession + Send + Sync + 'static,
97 R: SignalingTransport + 'static,
98 F: PeerLinkFactory + 'static,
99{
100 if !signaling_enabled {
101 return Ok(());
102 }
103
104 let Some(msg) = decode_signaling_event(
105 event,
106 &my_peer_id.to_string(),
107 &keys.public_key().to_hex(),
108 keys,
109 ) else {
110 return Ok(());
111 };
112
113 handle_signaling_message(runtime, source, source_max_peers, msg, shared_router).await
114}
115
116pub async fn handle_signaling_message<P, R, F>(
117 runtime: &MeshRuntimeState<P>,
118 source: &str,
119 source_max_peers: Option<usize>,
120 msg: SignalingMessage,
121 shared_router: Option<&Arc<MeshRouter<R, F>>>,
122) -> Result<()>
123where
124 P: MeshSession + Send + Sync + 'static,
125 R: SignalingTransport + 'static,
126 F: PeerLinkFactory + 'static,
127{
128 let Some(shared_router) = shared_router else {
129 return Ok(());
130 };
131
132 if matches!(
133 msg,
134 SignalingMessage::Hello { .. } | SignalingMessage::Offer { .. }
135 ) {
136 let peers = runtime.peers.read().await;
137 if !can_track_source_peer(source, msg.peer_id(), &peers, source_max_peers) {
138 return Ok(());
139 }
140 }
141
142 debug!(
143 "Received {} from {} via {}",
144 msg.msg_type(),
145 msg.peer_id(),
146 source
147 );
148 let peer_id = msg.peer_id().to_string();
149 shared_router
150 .handle_message(msg)
151 .await
152 .map_err(|e| anyhow::anyhow!(e.to_string()))?;
153 remember_peer_signal_path(runtime.peers.as_ref(), &peer_id, source).await;
154
155 Ok(())
156}
157
158pub async fn dispatch_signaling_message<P, S>(
159 signaling_enabled: bool,
160 keys: &Keys,
161 my_peer_id: &PeerId,
162 runtime: &MeshRuntimeState<P>,
163 relay_transport: Option<&S>,
164 local_buses: &[SharedLocalNostrBus],
165 seen_frame_ids: &Arc<Mutex<TimedSeenSet>>,
166 seen_event_ids: &Arc<Mutex<TimedSeenSet>>,
167 msg: SignalingMessage,
168 signaling_kind: u64,
169) -> Result<()>
170where
171 P: MeshSession + Clone + Send + Sync + 'static,
172 S: SignalingTransport + Send + Sync + 'static,
173{
174 if !signaling_enabled {
175 debug!(
176 "Skipping signaling message {} because signaling is disabled",
177 msg.msg_type()
178 );
179 return Ok(());
180 }
181
182 if let Some(relay_transport) = relay_transport {
183 if let Err(err) = relay_transport.publish(msg.clone()).await {
184 debug!(
185 "Failed to publish signaling message {} via relay transport: {}",
186 msg.msg_type(),
187 err
188 );
189 }
190 }
191
192 let event = create_signaling_event(keys, &msg, signaling_kind).await?;
193
194 for bus in local_buses {
195 if let Err(err) = bus.broadcast_event(&event).await {
196 debug!(
197 "Failed to broadcast signaling event over {} ({}): {}",
198 bus.source_name(),
199 msg.msg_type(),
200 err
201 );
202 }
203 }
204
205 let mut frame = MeshNostrFrame::new_event(event, &my_peer_id.to_string(), MESH_DEFAULT_HTL);
206 if !mark_seen(seen_frame_ids, frame.frame_id.clone()).await {
207 runtime.record_mesh_duplicate_drop();
208 return Ok(());
209 }
210 if !mark_seen(seen_event_ids, frame.event().id.to_hex()).await {
211 runtime.record_mesh_duplicate_drop();
212 return Ok(());
213 }
214
215 frame.sender_peer_id = my_peer_id.to_string();
216 let forwarded = forward_mesh_frame_from_runtime(runtime, &frame, None).await;
217 if forwarded > 0 {
218 runtime.record_mesh_forwarded(forwarded as u64);
219 }
220
221 Ok(())
222}
223
224pub async fn handle_peer_state_event<P, R, F>(
225 runtime: &MeshRuntimeState<P>,
226 event: PeerStateEvent,
227 shared_router: Option<&Arc<MeshRouter<R, F>>>,
228) where
229 P: MeshSession + Send + Sync + 'static,
230 R: SignalingTransport + 'static,
231 F: PeerLinkFactory + 'static,
232{
233 match event {
234 PeerStateEvent::Connected(peer_id) => {
235 let peer_key = peer_id.to_string();
236 let mut emit_hello = false;
237 let mut peers = runtime.peers.write().await;
238 if let Some(entry) = peers.get_mut(&peer_key) {
239 if entry.state != ConnectionState::Connected {
240 info!("Peer {} connected (via state event)", peer_id.short());
241 entry.state = ConnectionState::Connected;
242 emit_hello = true;
243 runtime
244 .connected_count
245 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
246 }
247 }
248 drop(peers);
249 if emit_hello {
250 if let Some(shared_router) = shared_router {
251 let _ = shared_router.send_hello(Vec::new()).await;
252 }
253 }
254 }
255 PeerStateEvent::Failed(peer_id) => {
256 remove_peer_from_runtime(runtime, shared_router, peer_id, "connection failed").await;
257 }
258 PeerStateEvent::Disconnected(peer_id) => {
259 remove_peer_from_runtime(runtime, shared_router, peer_id, "disconnected").await;
260 }
261 }
262}
263
264pub async fn cleanup_stale_peers<P>(runtime: &MeshRuntimeState<P>, stale_timeout: Duration)
265where
266 P: MeshSession + Send + Sync + 'static,
267{
268 let mut peers = runtime.peers.write().await;
269 let mut connected_count = 0usize;
270 let mut to_remove = Vec::new();
271
272 for (key, entry) in peers.iter_mut() {
273 if let Some(ref peer) = entry.peer {
274 if peer.is_connected() {
275 if entry.state != ConnectionState::Connected {
276 info!(
277 "Peer {} is now connected (sync fallback)",
278 entry.peer_id.short()
279 );
280 entry.state = ConnectionState::Connected;
281 }
282 connected_count += 1;
283 } else if entry.state == ConnectionState::Connected {
284 info!(
285 "Removing disconnected peer {} after transport closed",
286 entry.peer_id.short()
287 );
288 to_remove.push(key.clone());
289 } else if entry.state == ConnectionState::Connecting
290 && entry.last_seen.elapsed() > stale_timeout
291 {
292 info!(
293 "Removing stale peer {} (stuck in Connecting for {:?})",
294 entry.peer_id.short(),
295 entry.last_seen.elapsed()
296 );
297 to_remove.push(key.clone());
298 }
299 } else if entry.state == ConnectionState::Discovered
300 && entry.last_seen.elapsed() > stale_timeout
301 {
302 debug!("Removing stale discovered peer {}", entry.peer_id.short());
303 to_remove.push(key.clone());
304 }
305 }
306
307 let mut removed_peers = Vec::new();
308 for key in to_remove {
309 if let Some(entry) = peers.remove(&key) {
310 removed_peers.push(entry);
311 }
312 }
313 drop(peers);
314
315 for entry in removed_peers {
316 if let Some(peer) = entry.peer {
317 let _ = peer.close().await;
318 }
319 }
320
321 runtime
322 .connected_count
323 .store(connected_count, std::sync::atomic::Ordering::Relaxed);
324}
325
326async fn mark_seen(seen: &Arc<Mutex<TimedSeenSet>>, id: String) -> bool {
327 let mut seen = seen.lock().await;
328 seen.insert_if_new(id)
329}
330
331async fn remove_peer_from_runtime<P, R, F>(
332 runtime: &MeshRuntimeState<P>,
333 shared_router: Option<&Arc<MeshRouter<R, F>>>,
334 peer_id: PeerId,
335 reason: &str,
336) where
337 P: MeshSession + Send + Sync + 'static,
338 R: SignalingTransport + 'static,
339 F: PeerLinkFactory + 'static,
340{
341 let peer_key = peer_id.to_string();
342 info!("Peer {} {} - removing from pool", peer_id.short(), reason);
343 let removed = {
344 let mut peers = runtime.peers.write().await;
345 peers.remove(&peer_key)
346 };
347 if let Some(entry) = removed {
348 if entry.state == ConnectionState::Connected {
349 runtime
350 .connected_count
351 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
352 }
353 if let Some(peer) = entry.peer {
354 let _ = peer.close().await;
355 }
356 }
357 if let Some(shared_router) = shared_router {
358 if let Some(channel) = shared_router.remove_peer(&peer_key).await {
359 channel.close().await;
360 }
361 }
362}
363
#[cfg(test)]
mod tests {
    // Unit tests for the runtime helpers in this module, driven by an in-memory
    // `TestSession` stand-in for a real mesh transport.
    use super::*;
    use anyhow::Result as AnyResult;
    use async_trait::async_trait;
    use nostr_sdk::nostr::{EventBuilder, Filter, Kind};
    use std::collections::BTreeSet;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Instant;

    use crate::runtime_peer::{MeshPeerEntry, PeerDirection, PeerTransport};
    use crate::types::{PeerHTLConfig, PeerPool};

    /// Minimal `MeshSession` used by the tests.
    ///
    /// `connected` is what `is_connected()` reports. `close()` sleeps for
    /// `close_delay` (when non-zero) and then sets `closed` to `true`, so tests
    /// can observe whether and when a session was closed.
    #[derive(Clone)]
    struct TestSession {
        connected: bool,
        close_delay: Duration,
        closed: Arc<AtomicBool>,
    }

    #[async_trait]
    impl MeshSession for TestSession {
        fn is_ready(&self) -> bool {
            true
        }

        fn is_connected(&self) -> bool {
            self.connected
        }

        fn htl_config(&self) -> PeerHTLConfig {
            PeerHTLConfig::from_flags(false, false)
        }

        // Content requests always miss.
        async fn request(&self, _hash_hex: &str, _timeout: Duration) -> AnyResult<Option<Vec<u8>>> {
            Ok(None)
        }

        // Queries always return no events.
        async fn query_nostr_events(
            &self,
            _filters: Vec<Filter>,
            _timeout: Duration,
        ) -> AnyResult<Vec<Event>> {
            Ok(Vec::new())
        }

        // Sending always succeeds without doing anything.
        async fn send_mesh_frame_text(&self, _frame: &MeshNostrFrame) -> AnyResult<()> {
            Ok(())
        }

        async fn close(&self) -> AnyResult<()> {
            // Optional delay lets tests check that callers do not hold the peer
            // map lock while awaiting `close()`.
            if !self.close_delay.is_zero() {
                tokio::time::sleep(self.close_delay).await;
            }
            self.closed.store(true, Ordering::Relaxed);
            Ok(())
        }
    }

    /// With no limit every peer is accepted. With a limit of one and a single
    /// tracked WifiAware peer, that peer is still accepted but a new peer on
    /// the same source is rejected. This assumes "wifi-aware" maps to
    /// `PeerSignalPath::WifiAware`, as the assertions imply.
    #[test]
    fn can_track_source_peer_respects_optional_limits() {
        let peer_id = PeerId::new("peer-a".to_string());
        let peer_key = peer_id.to_string();
        let mut peers = HashMap::new();
        peers.insert(
            peer_key.clone(),
            MeshPeerEntry {
                peer_id,
                direction: PeerDirection::Outbound,
                state: ConnectionState::Discovered,
                last_seen: Instant::now(),
                peer: None::<TestSession>,
                pool: PeerPool::Other,
                transport: PeerTransport::WebRtc,
                signal_paths: BTreeSet::from([PeerSignalPath::WifiAware]),
                bytes_sent: 0,
                bytes_received: 0,
            },
        );

        assert!(can_track_source_peer("relay", "peer-b", &peers, None));
        assert!(can_track_source_peer(
            "wifi-aware",
            &peer_key,
            &peers,
            Some(1)
        ));
        assert!(!can_track_source_peer(
            "wifi-aware",
            "peer-b",
            &peers,
            Some(1),
        ));
    }

    /// A session-less `Discovered` entry last seen 120s ago is removed with a
    /// 60s timeout. A `Connecting` entry whose session reports connected is
    /// promoted to `Connected`, and `connected_count` is set to 1.
    #[tokio::test]
    async fn cleanup_stale_peers_removes_stale_entries_and_syncs_connected_count() {
        let runtime = MeshRuntimeState::<TestSession>::new();
        let stale_id = PeerId::new("peer-stale".to_string());
        runtime.peers.write().await.insert(
            stale_id.to_string(),
            MeshPeerEntry {
                peer_id: stale_id,
                direction: PeerDirection::Outbound,
                state: ConnectionState::Discovered,
                last_seen: Instant::now() - Duration::from_secs(120),
                peer: None,
                pool: PeerPool::Other,
                transport: PeerTransport::WebRtc,
                signal_paths: BTreeSet::new(),
                bytes_sent: 0,
                bytes_received: 0,
            },
        );

        let active_id = PeerId::new("peer-active".to_string());
        runtime.peers.write().await.insert(
            active_id.to_string(),
            MeshPeerEntry {
                peer_id: active_id.clone(),
                direction: PeerDirection::Outbound,
                state: ConnectionState::Connecting,
                last_seen: Instant::now(),
                peer: Some(TestSession {
                    connected: true,
                    close_delay: Duration::ZERO,
                    closed: Arc::new(AtomicBool::new(false)),
                }),
                pool: PeerPool::Other,
                transport: PeerTransport::WebRtc,
                signal_paths: BTreeSet::new(),
                bytes_sent: 0,
                bytes_received: 0,
            },
        );

        cleanup_stale_peers(&runtime, Duration::from_secs(60)).await;

        let peers = runtime.peers.read().await;
        assert!(!peers.contains_key("peer-stale"));
        assert_eq!(
            peers.get(&active_id.to_string()).unwrap().state,
            ConnectionState::Connected
        );
        assert_eq!(
            runtime
                .connected_count
                .load(std::sync::atomic::Ordering::Relaxed),
            1
        );
    }

    /// A `Failed` event removes the peer while its session's `close()` takes
    /// 200ms. A read of the peer map issued 20ms in must complete within 50ms
    /// and see an empty map, which shows the lock is not held during `close()`.
    #[tokio::test]
    async fn handle_peer_state_event_does_not_hold_peer_map_lock_while_closing() {
        let runtime = Arc::new(MeshRuntimeState::<TestSession>::new());
        let peer_id = PeerId::new("peer-a-pub".to_string());
        runtime.peers.write().await.insert(
            peer_id.to_string(),
            MeshPeerEntry {
                peer_id: peer_id.clone(),
                direction: PeerDirection::Outbound,
                state: ConnectionState::Connected,
                last_seen: Instant::now(),
                peer: Some(TestSession {
                    connected: false,
                    close_delay: Duration::from_millis(200),
                    closed: Arc::new(AtomicBool::new(false)),
                }),
                pool: PeerPool::Other,
                transport: PeerTransport::Bluetooth,
                signal_paths: BTreeSet::from([PeerSignalPath::Bluetooth]),
                bytes_sent: 0,
                bytes_received: 0,
            },
        );

        let runtime_for_task = runtime.clone();
        let peer_id_for_task = peer_id.clone();
        // Turbofish is needed because `None` leaves the router's type parameters
        // otherwise uninferred.
        let cleanup_task = tokio::spawn(async move {
            handle_peer_state_event::<
                TestSession,
                crate::mock::MockRelayTransport,
                crate::mock::MockConnectionFactory,
            >(
                runtime_for_task.as_ref(),
                PeerStateEvent::Failed(peer_id_for_task),
                None,
            )
            .await;
        });

        // Give the spawned task time to remove the entry and start closing.
        tokio::time::sleep(Duration::from_millis(20)).await;

        let remaining = tokio::time::timeout(Duration::from_millis(50), async {
            runtime.peers.read().await.len()
        })
        .await
        .expect("peer map read should not block on close");

        assert_eq!(remaining, 0);
        cleanup_task.await.expect("cleanup task");
    }

    /// Forwarding a frame with one `Connected` peer reports one recipient and
    /// does not close that peer's session.
    #[tokio::test]
    async fn forward_mesh_frame_from_runtime_sends_to_connected_peers() {
        let runtime = MeshRuntimeState::<TestSession>::new();
        let closed = Arc::new(AtomicBool::new(false));
        let peer_id = PeerId::new("peer-a".to_string());
        runtime.peers.write().await.insert(
            peer_id.to_string(),
            MeshPeerEntry {
                peer_id: peer_id.clone(),
                direction: PeerDirection::Outbound,
                state: ConnectionState::Connected,
                last_seen: Instant::now(),
                peer: Some(TestSession {
                    connected: true,
                    close_delay: Duration::ZERO,
                    closed: closed.clone(),
                }),
                pool: PeerPool::Other,
                transport: PeerTransport::WebRtc,
                signal_paths: BTreeSet::new(),
                bytes_sent: 0,
                bytes_received: 0,
            },
        );
        let keys = Keys::generate();
        let event = EventBuilder::new(Kind::Custom(25050), "mesh", [])
            .to_event(&keys)
            .unwrap();
        let frame = MeshNostrFrame::new_event_with_id(event, "sender", "frame-1", 4);

        let forwarded = forward_mesh_frame_from_runtime(&runtime, &frame, None).await;
        assert_eq!(forwarded, 1);
        assert!(!closed.load(Ordering::Relaxed));
    }
}