1use std::collections::{BTreeSet, HashMap};
2use std::sync::atomic::{AtomicUsize, Ordering};
3use std::sync::Arc;
4use std::time::Instant;
5
6use tokio::sync::RwLock;
7
8use crate::mesh_session::MeshSession;
9use crate::types::{PeerId, PeerPool, PoolSettings};
10
/// Maps a peer's pubkey to the [`PeerPool`] whose connection limit it is accounted against.
///
/// Shared (`Arc`) and thread-safe so every transport registrar can use the same classifier.
pub type PeerClassifier = Arc<dyn Fn(&str) -> PeerPool + Send + Sync>;
12
/// Physical transport a mesh peer session runs over.
///
/// Variant order is significant: the derived `Ord` sorts `WebRtc` before `Bluetooth`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum PeerTransport {
    WebRtc,
    Bluetooth,
}

impl PeerTransport {
    /// Lowercase identifier for this transport: `"webrtc"` or `"bluetooth"`.
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::WebRtc => "webrtc",
            Self::Bluetooth => "bluetooth",
        }
    }
}

impl std::fmt::Display for PeerTransport {
    /// Writes [`PeerTransport::as_str`] verbatim; width/fill flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = self.as_str();
        f.write_str(label)
    }
}
34
/// Channel over which a peer's signalling was received.
///
/// Variant order is significant: the derived `Ord` fixes iteration order in a `BTreeSet`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum PeerSignalPath {
    Relay,
    Multicast,
    WifiAware,
    Bluetooth,
}

impl PeerSignalPath {
    /// Stable lowercase name of this path (`"relay"`, `"multicast"`, `"wifi-aware"`,
    /// `"bluetooth"`); the inverse of [`PeerSignalPath::from_source_name`].
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Relay => "relay",
            Self::Multicast => "multicast",
            Self::WifiAware => "wifi-aware",
            Self::Bluetooth => "bluetooth",
        }
    }

    /// Maps a signalling source name to its path.
    ///
    /// Matching is exact and case-sensitive; any name that is not `"multicast"`,
    /// `"wifi-aware"` or `"bluetooth"` (including `"relay"`) is treated as [`Self::Relay`].
    pub fn from_source_name(source: &str) -> Self {
        const NAMED_PATHS: [(&str, PeerSignalPath); 3] = [
            ("multicast", PeerSignalPath::Multicast),
            ("wifi-aware", PeerSignalPath::WifiAware),
            ("bluetooth", PeerSignalPath::Bluetooth),
        ];
        NAMED_PATHS
            .iter()
            .find(|(name, _)| *name == source)
            .map(|&(_, path)| path)
            .unwrap_or(PeerSignalPath::Relay)
    }
}

impl std::fmt::Display for PeerSignalPath {
    /// Writes [`PeerSignalPath::as_str`] verbatim; width/fill flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(Self::as_str(*self))
    }
}
69
/// Direction of a peer connection: inbound or outbound.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PeerDirection {
    Inbound,
    Outbound,
}

impl std::fmt::Display for PeerDirection {
    /// Writes `"inbound"` or `"outbound"`; width/fill flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match *self {
            Self::Inbound => "inbound",
            Self::Outbound => "outbound",
        })
    }
}
85
/// Lifecycle state of a tracked peer entry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    Discovered,
    Connecting,
    Connected,
    Failed,
}

impl std::fmt::Display for ConnectionState {
    /// Writes the lowercase state name; width/fill flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match *self {
            Self::Discovered => "discovered",
            Self::Connecting => "connecting",
            Self::Connected => "connected",
            Self::Failed => "failed",
        };
        f.write_str(name)
    }
}
105
/// One peer tracked in the shared mesh peer map, together with its live session, if any.
///
/// `P` is the session handle type. [`TransportPeerRegistrar`] stores sessions here and closes
/// them when the entry is replaced or removed.
pub struct MeshPeerEntry<P> {
    /// Identity of the peer; registrations key the map by `peer_id.to_string()`.
    pub peer_id: PeerId,
    /// Whether the connection is inbound or outbound.
    pub direction: PeerDirection,
    /// Lifecycle state. Only `Connected` entries count toward the registrar's limits and the
    /// shared connected counter; `Failed` entries are ignored by `can_track_signal_path_peer`.
    pub state: ConnectionState,
    /// Set to `Instant::now()` when a session is registered in this file; presumably
    /// refreshed on activity elsewhere — TODO confirm.
    pub last_seen: Instant,
    /// Live session handle; `None` for entries without one (e.g. discovered-only peers).
    pub peer: Option<P>,
    /// Pool the peer is accounted against, as chosen by the registrar's `PeerClassifier`.
    pub pool: PeerPool,
    /// Transport this entry belongs to; registrars only treat entries of their own transport
    /// as same-transport duplicates.
    pub transport: PeerTransport,
    /// Signalling paths this peer has been seen on; extended by `remember_peer_signal_path`.
    pub signal_paths: BTreeSet<PeerSignalPath>,
    /// Traffic counters; initialised to 0 on registration (presumably updated by the
    /// session layer — not visible here).
    pub bytes_sent: u64,
    pub bytes_received: u64,
}
119
120pub async fn remember_peer_signal_path<P>(
121 peers: &RwLock<HashMap<String, MeshPeerEntry<P>>>,
122 peer_id: &str,
123 source: &str,
124) {
125 if let Some(entry) = peers.write().await.get_mut(peer_id) {
126 entry
127 .signal_paths
128 .insert(PeerSignalPath::from_source_name(source));
129 }
130}
131
132pub fn can_track_signal_path_peer<P>(
133 signal_path: PeerSignalPath,
134 max_peers: usize,
135 peer_key: &str,
136 peers: &HashMap<String, MeshPeerEntry<P>>,
137) -> bool {
138 if peers.contains_key(peer_key) {
139 return true;
140 }
141 if max_peers == 0 {
142 return false;
143 }
144 peers
145 .values()
146 .filter(|entry| {
147 entry.signal_paths.contains(&signal_path) && entry.state != ConnectionState::Failed
148 })
149 .count()
150 < max_peers
151}
152
153#[derive(Clone)]
155pub struct TransportPeerRegistrar<P> {
156 peers: Arc<RwLock<HashMap<String, MeshPeerEntry<P>>>>,
157 connected_count: Arc<AtomicUsize>,
158 peer_classifier: PeerClassifier,
159 pools: PoolSettings,
160 transport: PeerTransport,
161 signal_path: PeerSignalPath,
162 max_transport_peers: usize,
163}
164
165impl<P> TransportPeerRegistrar<P>
166where
167 P: MeshSession + Send + Sync + 'static,
168{
169 pub fn new(
170 peers: Arc<RwLock<HashMap<String, MeshPeerEntry<P>>>>,
171 connected_count: Arc<AtomicUsize>,
172 peer_classifier: PeerClassifier,
173 pools: PoolSettings,
174 transport: PeerTransport,
175 signal_path: PeerSignalPath,
176 max_transport_peers: usize,
177 ) -> Self {
178 Self {
179 peers,
180 connected_count,
181 peer_classifier,
182 pools,
183 transport,
184 signal_path,
185 max_transport_peers,
186 }
187 }
188
189 async fn pool_counts(&self) -> (usize, usize) {
190 let peers = self.peers.read().await;
191 let mut follows = 0usize;
192 let mut other = 0usize;
193 for entry in peers.values() {
194 if entry.state != ConnectionState::Connected {
195 continue;
196 }
197 match entry.pool {
198 PeerPool::Follows => follows += 1,
199 PeerPool::Other => other += 1,
200 }
201 }
202 (follows, other)
203 }
204
205 async fn transport_peer_count(&self, peer_key: &str) -> usize {
206 let peers = self.peers.read().await;
207 peers
208 .values()
209 .filter(|entry| entry.transport == self.transport)
210 .filter(|entry| entry.state == ConnectionState::Connected)
211 .filter(|entry| entry.peer_id.to_string() != peer_key)
212 .count()
213 }
214
215 pub async fn register_connected_peer(
216 &self,
217 peer_id: PeerId,
218 direction: PeerDirection,
219 peer: P,
220 ) -> bool {
221 let peer_key = peer_id.to_string();
222 let pool = (self.peer_classifier)(&peer_id.pubkey);
223 let (follows, other) = self.pool_counts().await;
224 let can_accept_pool = match pool {
225 PeerPool::Follows => follows < self.pools.follows.max_connections,
226 PeerPool::Other => other < self.pools.other.max_connections,
227 };
228 if !can_accept_pool {
229 return false;
230 }
231
232 if self.max_transport_peers == 0
233 || self.transport_peer_count(&peer_key).await >= self.max_transport_peers
234 {
235 return false;
236 }
237
238 let mut peers = self.peers.write().await;
239 let duplicate_keys = peers
240 .iter()
241 .filter(|(key, entry)| {
242 key.as_str() != peer_key
243 && entry.transport == self.transport
244 && entry.peer_id.pubkey == peer_id.pubkey
245 })
246 .map(|(key, _)| key.clone())
247 .collect::<Vec<_>>();
248 let was_connected = peers
249 .get(&peer_key)
250 .map(|entry| entry.state == ConnectionState::Connected)
251 .unwrap_or(false);
252 let replaced = peers.insert(
253 peer_key,
254 MeshPeerEntry {
255 peer_id,
256 direction,
257 state: ConnectionState::Connected,
258 last_seen: Instant::now(),
259 peer: Some(peer),
260 pool,
261 transport: self.transport,
262 signal_paths: BTreeSet::from([self.signal_path]),
263 bytes_sent: 0,
264 bytes_received: 0,
265 },
266 );
267 let removed_duplicates = duplicate_keys
268 .into_iter()
269 .filter_map(|key| peers.remove(&key))
270 .collect::<Vec<_>>();
271 drop(peers);
272
273 if let Some(previous) = replaced.and_then(|entry| entry.peer) {
274 let _ = previous.close().await;
275 }
276 for duplicate in &removed_duplicates {
277 if let Some(peer) = duplicate.peer.as_ref() {
278 let _ = peer.close().await;
279 }
280 }
281
282 let removed_connected_duplicates = removed_duplicates
283 .iter()
284 .filter(|entry| entry.state == ConnectionState::Connected)
285 .count() as isize;
286 let connected_delta =
287 1isize - if was_connected { 1 } else { 0 } - removed_connected_duplicates;
288 if connected_delta > 0 {
289 self.connected_count
290 .fetch_add(connected_delta as usize, Ordering::Relaxed);
291 } else if connected_delta < 0 {
292 self.connected_count
293 .fetch_sub((-connected_delta) as usize, Ordering::Relaxed);
294 }
295 true
296 }
297
298 pub async fn unregister_peer(&self, peer_id: &PeerId) {
299 let peer_key = peer_id.to_string();
300 let removed = self.peers.write().await.remove(&peer_key);
301 self.finish_unregister(removed).await;
302 }
303
304 pub async fn unregister_peer_if<F>(&self, peer_id: &PeerId, predicate: F)
305 where
306 F: FnOnce(&P) -> bool + Send,
307 {
308 let peer_key = peer_id.to_string();
309 let removed = {
310 let mut peers = self.peers.write().await;
311 let matches_current = peers
312 .get(&peer_key)
313 .and_then(|entry| entry.peer.as_ref())
314 .map(predicate)
315 .unwrap_or(false);
316 if matches_current {
317 peers.remove(&peer_key)
318 } else {
319 None
320 }
321 };
322 self.finish_unregister(removed).await;
323 }
324
325 async fn finish_unregister(&self, removed: Option<MeshPeerEntry<P>>) {
326 if let Some(entry) = removed {
327 if entry.state == ConnectionState::Connected {
328 self.connected_count.fetch_sub(1, Ordering::Relaxed);
329 }
330 if let Some(peer) = entry.peer {
331 let _ = peer.close().await;
332 }
333 }
334 }
335}
336
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use async_trait::async_trait;
    use nostr_sdk::nostr::{Event, Filter};
    use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering};
    use std::time::Duration;

    use crate::types::{MeshNostrFrame, PeerHTLConfig, PoolConfig};

    /// Minimal session double. Every query succeeds with an empty result, and `close` only
    /// sets the `closed` flag, so tests can observe whether the registrar closed the session.
    struct TestSession {
        closed: AtomicBool,
    }

    impl TestSession {
        fn new() -> Self {
            Self {
                closed: AtomicBool::new(false),
            }
        }

        /// Whether `close` has been called on this session.
        fn is_closed(&self) -> bool {
            self.closed.load(AtomicOrdering::Relaxed)
        }
    }

    // Implemented on `Arc<TestSession>` so a test can hand the registrar a clone while keeping
    // its own handle to check `is_closed` or compare identities with `Arc::ptr_eq`.
    #[async_trait]
    impl MeshSession for Arc<TestSession> {
        fn is_ready(&self) -> bool {
            true
        }

        fn is_connected(&self) -> bool {
            true
        }

        fn htl_config(&self) -> PeerHTLConfig {
            PeerHTLConfig::from_flags(false, false)
        }

        async fn request(&self, _hash_hex: &str, _timeout: Duration) -> Result<Option<Vec<u8>>> {
            Ok(None)
        }

        async fn query_nostr_events(
            &self,
            _filters: Vec<Filter>,
            _timeout: Duration,
        ) -> Result<Vec<Event>> {
            Ok(Vec::new())
        }

        async fn send_mesh_frame_text(&self, _frame: &MeshNostrFrame) -> Result<()> {
            Ok(())
        }

        async fn close(&self) -> Result<()> {
            self.closed.store(true, AtomicOrdering::Relaxed);
            Ok(())
        }
    }

    /// Both pools allow 4 connections, which is generous enough that the per-transport limit
    /// of 2 set in `test_registrar` is the tighter one.
    fn test_pools() -> PoolSettings {
        PoolSettings {
            follows: PoolConfig {
                max_connections: 4,
                satisfied_connections: 0,
            },
            other: PoolConfig {
                max_connections: 4,
                satisfied_connections: 0,
            },
        }
    }

    /// Builds a Bluetooth registrar (signal path Bluetooth, at most 2 transport peers, every
    /// peer classified into `PeerPool::Other`). Returns it with the shared peer map and
    /// connected counter it writes to, so tests can inspect both.
    fn test_registrar() -> (
        TransportPeerRegistrar<Arc<TestSession>>,
        Arc<RwLock<HashMap<String, MeshPeerEntry<Arc<TestSession>>>>>,
        Arc<AtomicUsize>,
    ) {
        let peers = Arc::new(RwLock::new(HashMap::new()));
        let connected_count = Arc::new(AtomicUsize::new(0));
        let registrar = TransportPeerRegistrar::new(
            peers.clone(),
            connected_count.clone(),
            Arc::new(|_| PeerPool::Other),
            test_pools(),
            PeerTransport::Bluetooth,
            PeerSignalPath::Bluetooth,
            2,
        );
        (registrar, peers, connected_count)
    }

    #[tokio::test]
    async fn register_connected_peer_closes_replaced_session() {
        let (registrar, _peers, _connected_count) = test_registrar();
        let peer_id = PeerId::new("peer-pub".to_string());
        let first = Arc::new(TestSession::new());
        let second = Arc::new(TestSession::new());

        assert!(
            registrar
                .register_connected_peer(peer_id.clone(), PeerDirection::Outbound, first.clone())
                .await
        );
        assert!(
            registrar
                .register_connected_peer(peer_id, PeerDirection::Outbound, second)
                .await
        );

        // Re-registering under the same key must close the session it displaced.
        assert!(first.is_closed());
    }

    #[tokio::test]
    async fn register_connected_peer_replaces_existing_transport_session_for_same_pubkey() {
        // NOTE(review): both ids are built from the same string, so unless `PeerId::to_string`
        // differs between instances this exercises same-key replacement rather than the
        // separate same-pubkey duplicate cleanup; a PeerId with a distinct key would cover it.
        let (registrar, peers, connected_count) = test_registrar();
        let first_peer_id = PeerId::new("peer-pub".to_string());
        let second_peer_id = PeerId::new("peer-pub".to_string());
        let first = Arc::new(TestSession::new());
        let second = Arc::new(TestSession::new());

        assert!(
            registrar
                .register_connected_peer(
                    first_peer_id.clone(),
                    PeerDirection::Outbound,
                    first.clone(),
                )
                .await
        );
        assert!(
            registrar
                .register_connected_peer(second_peer_id.clone(), PeerDirection::Outbound, second,)
                .await
        );

        assert!(first.is_closed());
        let peers = peers.read().await;
        assert!(peers.contains_key(&second_peer_id.to_string()));
        // Exactly one entry remains, and the connected counter was not double-incremented.
        assert_eq!(peers.len(), 1);
        assert_eq!(connected_count.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn unregister_peer_if_respects_current_predicate() {
        let (registrar, peers, connected_count) = test_registrar();
        let peer_id = PeerId::new("peer-pub".to_string());
        let session = Arc::new(TestSession::new());

        assert!(
            registrar
                .register_connected_peer(peer_id.clone(), PeerDirection::Outbound, session.clone(),)
                .await
        );
        // The predicate matches the stored session by identity, so the entry is removed.
        registrar
            .unregister_peer_if(&peer_id, |current| Arc::ptr_eq(current, &session))
            .await;

        assert!(session.is_closed());
        assert!(!peers.read().await.contains_key(&peer_id.to_string()));
        assert_eq!(connected_count.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn can_track_signal_path_peer_enforces_limit() {
        // One Discovered (non-Failed) entry already occupies the single WifiAware slot.
        let existing_peer = PeerId::new("peer-a".to_string());
        let existing_key = existing_peer.to_string();
        let mut peers = HashMap::new();
        peers.insert(
            existing_key.clone(),
            MeshPeerEntry::<Arc<TestSession>> {
                peer_id: existing_peer,
                direction: PeerDirection::Outbound,
                state: ConnectionState::Discovered,
                last_seen: Instant::now(),
                peer: None,
                pool: PeerPool::Other,
                transport: PeerTransport::WebRtc,
                signal_paths: BTreeSet::from([PeerSignalPath::WifiAware]),
                bytes_sent: 0,
                bytes_received: 0,
            },
        );

        // An already-tracked peer is always allowed.
        assert!(can_track_signal_path_peer(
            PeerSignalPath::WifiAware,
            1,
            &existing_key,
            &peers
        ));
        // A new peer on the full WifiAware path is rejected.
        assert!(!can_track_signal_path_peer(
            PeerSignalPath::WifiAware,
            1,
            "peer-b",
            &peers
        ));
        // The Relay path has no tracked peers, so a new one fits.
        assert!(can_track_signal_path_peer(
            PeerSignalPath::Relay,
            1,
            "peer-c",
            &peers
        ));
    }
}