1use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13
14use crate::transport::{PeerLink, PeerLinkFactory, SignalingTransport, TransportError};
15use crate::types::{is_polite_peer, ClassifyRequest, PeerPool, PoolSettings, SignalingMessage};
16
17pub struct PeerEntry {
19 pub channel: Arc<dyn PeerLink>,
20 pub pool: PeerPool,
21 pub hash_get: bool,
22}
23
24pub struct MeshRouter<R: SignalingTransport, F: PeerLinkFactory> {
35 peer_id: String,
37 transport: Arc<R>,
39 conn_factory: Arc<F>,
41 peers: RwLock<HashMap<String, PeerEntry>>,
43 pending_offers: RwLock<HashMap<String, ()>>,
45 pools: PoolSettings,
47 peer_roots: RwLock<HashMap<String, Vec<String>>>,
49 peer_hash_get: RwLock<HashMap<String, bool>>,
51 classifier_tx: Option<tokio::sync::mpsc::Sender<ClassifyRequest>>,
53 debug: bool,
55 hash_get_enabled: bool,
57}
58
59impl<R: SignalingTransport + 'static, F: PeerLinkFactory + 'static> MeshRouter<R, F> {
60 pub fn new(
62 peer_id: String,
63 transport: Arc<R>,
64 conn_factory: Arc<F>,
65 pools: PoolSettings,
66 debug: bool,
67 ) -> Self {
68 Self {
69 peer_id,
70 transport,
71 conn_factory,
72 peers: RwLock::new(HashMap::new()),
73 pending_offers: RwLock::new(HashMap::new()),
74 pools,
75 peer_roots: RwLock::new(HashMap::new()),
76 peer_hash_get: RwLock::new(HashMap::new()),
77 classifier_tx: None,
78 debug,
79 hash_get_enabled: true,
80 }
81 }
82
83 pub fn set_classifier(&mut self, tx: tokio::sync::mpsc::Sender<ClassifyRequest>) {
85 self.classifier_tx = Some(tx);
86 }
87
88 pub fn set_hash_get_enabled(&mut self, enabled: bool) {
89 self.hash_get_enabled = enabled;
90 }
91
92 pub async fn set_peer_hash_get(&self, peer_id: &str, enabled: bool) {
93 self.peer_hash_get
94 .write()
95 .await
96 .insert(peer_id.to_string(), enabled);
97 if let Some(entry) = self.peers.write().await.get_mut(peer_id) {
98 entry.hash_get = enabled;
99 }
100 }
101
102 pub async fn peer_supports_hash_get(&self, peer_id: &str) -> bool {
103 self.peer_hash_get
104 .read()
105 .await
106 .get(peer_id)
107 .copied()
108 .unwrap_or(true)
109 }
110
111 pub async fn hash_get_peer_ids(&self) -> Vec<String> {
112 let peers = self.peers.read().await;
113 let peer_hash_get = self.peer_hash_get.read().await;
114 peers
115 .keys()
116 .filter(|peer_id| peer_hash_get.get(*peer_id).copied().unwrap_or(true))
117 .cloned()
118 .collect()
119 }
120
121 pub fn peer_id(&self) -> &str {
123 &self.peer_id
124 }
125
126 pub async fn send_hello(&self, roots: Vec<String>) -> Result<(), TransportError> {
128 let msg = SignalingMessage::Hello {
129 peer_id: self.peer_id.clone(),
130 roots,
131 hash_get: self.hash_get_enabled,
132 };
133 self.transport.publish(msg).await
134 }
135
136 async fn count_pools(&self) -> (usize, usize) {
138 let peers = self.peers.read().await;
139 let mut follows = 0;
140 let mut other = 0;
141 for entry in peers.values() {
142 match entry.pool {
143 PeerPool::Follows => follows += 1,
144 PeerPool::Other => other += 1,
145 }
146 }
147 (follows, other)
148 }
149
150 async fn classify_peer(&self, pubkey: &str) -> PeerPool {
152 if let Some(ref tx) = self.classifier_tx {
153 let (response_tx, response_rx) = tokio::sync::oneshot::channel();
154 let request = ClassifyRequest {
155 pubkey: pubkey.to_string(),
156 response: response_tx,
157 };
158 if tx.send(request).await.is_ok() {
159 if let Ok(pool) = response_rx.await {
160 return pool;
161 }
162 }
163 }
164 PeerPool::Other
165 }
166
167 fn can_accept_peer(&self, pool: PeerPool, follows: usize, other: usize) -> bool {
169 match pool {
170 PeerPool::Follows => self.pools.follows.can_accept(follows),
171 PeerPool::Other => self.pools.other.can_accept(other),
172 }
173 }
174
175 fn pool_needs_peers(&self, pool: PeerPool, follows: usize, other: usize) -> bool {
177 match pool {
178 PeerPool::Follows => self.pools.follows.needs_peers(follows),
179 PeerPool::Other => self.pools.other.needs_peers(other),
180 }
181 }
182
183 pub async fn handle_message(&self, msg: SignalingMessage) -> Result<(), TransportError> {
187 match &msg {
188 SignalingMessage::Hello {
189 peer_id,
190 roots,
191 hash_get,
192 ..
193 } => {
194 self.set_peer_hash_get(peer_id, *hash_get).await;
195 self.handle_hello(peer_id, roots, *hash_get).await
196 }
197 SignalingMessage::Offer {
198 peer_id,
199 target_peer_id,
200 sdp,
201 } => {
202 if target_peer_id == &self.peer_id {
203 self.handle_offer(peer_id, sdp).await
204 } else {
205 Ok(()) }
207 }
208 SignalingMessage::Answer {
209 peer_id,
210 target_peer_id,
211 sdp,
212 } => {
213 if target_peer_id == &self.peer_id {
214 self.handle_answer(peer_id, sdp).await
215 } else {
216 Ok(()) }
218 }
219 SignalingMessage::Candidate {
220 peer_id,
221 target_peer_id,
222 candidate,
223 sdp_m_line_index,
224 sdp_mid,
225 } => {
226 if target_peer_id == &self.peer_id {
227 self.conn_factory
228 .handle_candidate(
229 peer_id,
230 crate::types::IceCandidate {
231 candidate: candidate.clone(),
232 sdp_m_line_index: *sdp_m_line_index,
233 sdp_mid: sdp_mid.clone(),
234 },
235 )
236 .await
237 } else {
238 Ok(())
239 }
240 }
241 SignalingMessage::Candidates {
242 peer_id,
243 target_peer_id,
244 candidates,
245 } => {
246 if target_peer_id == &self.peer_id {
247 self.conn_factory
248 .handle_candidates(peer_id, candidates.clone())
249 .await
250 } else {
251 Ok(())
252 }
253 }
254 }
255 }
256
257 async fn handle_hello(
262 &self,
263 from_peer_id: &str,
264 roots: &[String],
265 hash_get: bool,
266 ) -> Result<(), TransportError> {
267 if from_peer_id == self.peer_id {
269 return Ok(());
270 }
271
272 self.peer_roots
273 .write()
274 .await
275 .insert(from_peer_id.to_string(), roots.to_vec());
276 if let Some(entry) = self.peers.write().await.get_mut(from_peer_id) {
277 entry.hash_get = hash_get;
278 return Ok(());
279 }
280
281 let peer_pubkey = crate::types::PeerId::from_peer_string(from_peer_id)
282 .map(|peer_id| peer_id.pubkey)
283 .unwrap_or_else(|| from_peer_id.to_string());
284
285 let pool = self.classify_peer(&peer_pubkey).await;
287
288 let (follows_count, other_count) = self.count_pools().await;
290
291 if !self.can_accept_peer(pool, follows_count, other_count) {
292 if self.debug {
293 println!(
294 "[Signaling] Ignoring hello from {} - {:?} pool full",
295 from_peer_id, pool
296 );
297 }
298 return Ok(());
299 }
300
301 if self.pool_needs_peers(pool, follows_count, other_count) {
304 if self.peers.read().await.contains_key(from_peer_id) {
306 return Ok(());
307 }
308 if self.pending_offers.read().await.contains_key(from_peer_id) {
309 return Ok(());
310 }
311
312 if self.debug {
313 println!(
314 "[Signaling] Sending offer to {} (pool: {:?})",
315 from_peer_id, pool
316 );
317 }
318
319 self.pending_offers
321 .write()
322 .await
323 .insert(from_peer_id.to_string(), ());
324
325 let (channel, sdp) = self.conn_factory.create_offer(from_peer_id).await?;
327
328 self.peers.write().await.insert(
330 from_peer_id.to_string(),
331 PeerEntry {
332 channel,
333 pool,
334 hash_get,
335 },
336 );
337
338 let offer_msg = SignalingMessage::Offer {
340 peer_id: self.peer_id.clone(),
341 target_peer_id: from_peer_id.to_string(),
342 sdp,
343 };
344 self.transport.publish(offer_msg).await?;
345 }
346
347 Ok(())
348 }
349
350 async fn handle_offer(&self, from_peer_id: &str, sdp: &str) -> Result<(), TransportError> {
355 let peer_pubkey = crate::types::PeerId::from_peer_string(from_peer_id)
357 .map(|peer_id| peer_id.pubkey)
358 .unwrap_or_else(|| from_peer_id.to_string());
359
360 let pool = self.classify_peer(&peer_pubkey).await;
362 let (follows_count, other_count) = self.count_pools().await;
363
364 if !self.can_accept_peer(pool, follows_count, other_count) {
365 if self.debug {
366 println!(
367 "[Signaling] Ignoring offer from {} - {:?} pool full",
368 from_peer_id, pool
369 );
370 }
371 return Ok(());
372 }
373
374 let have_pending = self.pending_offers.read().await.contains_key(from_peer_id);
376 if have_pending {
377 let we_are_polite = is_polite_peer(&self.peer_id, from_peer_id);
379
380 if we_are_polite {
381 self.pending_offers.write().await.remove(from_peer_id);
384 self.peers.write().await.remove(from_peer_id);
385
386 if self.debug {
387 println!(
388 "[Signaling] Collision with {} - we're polite, accepting their offer",
389 from_peer_id
390 );
391 }
392 } else {
393 if self.debug {
395 println!(
396 "[Signaling] Collision with {} - we're impolite, ignoring their offer",
397 from_peer_id
398 );
399 }
400 return Ok(());
401 }
402 }
403
404 if self.peers.read().await.contains_key(from_peer_id) {
406 return Ok(());
407 }
408
409 if self.debug {
410 println!("[Signaling] Accepting offer from {}", from_peer_id);
411 }
412
413 let (channel, answer_sdp) = self.conn_factory.accept_offer(from_peer_id, sdp).await?;
415 let hash_get = self.peer_supports_hash_get(from_peer_id).await;
416
417 self.peers.write().await.insert(
419 from_peer_id.to_string(),
420 PeerEntry {
421 channel,
422 pool,
423 hash_get,
424 },
425 );
426
427 let answer_msg = SignalingMessage::Answer {
429 peer_id: self.peer_id.clone(),
430 target_peer_id: from_peer_id.to_string(),
431 sdp: answer_sdp,
432 };
433 self.transport.publish(answer_msg).await?;
434
435 Ok(())
436 }
437
438 async fn handle_answer(&self, from_peer_id: &str, sdp: &str) -> Result<(), TransportError> {
440 if self.debug {
441 println!("[Signaling] Received answer from {}", from_peer_id);
442 }
443
444 let _channel = self.conn_factory.handle_answer(from_peer_id, sdp).await?;
446
447 Ok(())
451 }
452
453 pub async fn peer_count(&self) -> usize {
455 self.peers.read().await.len()
456 }
457
458 pub async fn peer_ids(&self) -> Vec<String> {
460 let mut peer_ids = self.peers.read().await.keys().cloned().collect::<Vec<_>>();
461 peer_ids.sort();
462 peer_ids
463 }
464
465 pub async fn get_channel(&self, peer_id: &str) -> Option<Arc<dyn PeerLink>> {
467 self.peers
468 .read()
469 .await
470 .get(peer_id)
471 .map(|e| e.channel.clone())
472 }
473
474 pub async fn remove_peer(&self, peer_id: &str) -> Option<Arc<dyn PeerLink>> {
476 self.pending_offers.write().await.remove(peer_id);
477 self.peer_roots.write().await.remove(peer_id);
478 let _ = self.conn_factory.remove_peer(peer_id).await;
479 self.peers
480 .write()
481 .await
482 .remove(peer_id)
483 .map(|entry| entry.channel)
484 }
485
486 pub async fn needs_peers(&self) -> bool {
488 let (follows, other) = self.count_pools().await;
489 self.pools.follows.needs_peers(follows) || self.pools.other.needs_peers(other)
490 }
491
492 pub async fn can_accept(&self) -> bool {
494 let (follows, other) = self.count_pools().await;
495 self.pools.follows.can_accept(follows) || self.pools.other.can_accept(other)
496 }
497}
498
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    use crate::types::{IceCandidate, PoolConfig, PoolSettings};

    /// Transport stub: every publish succeeds and no message is ever received.
    #[derive(Default)]
    struct NoopTransport;

    #[async_trait]
    impl SignalingTransport for NoopTransport {
        async fn connect(&self, _relays: &[String]) -> Result<(), TransportError> {
            Ok(())
        }

        async fn disconnect(&self) {}

        async fn publish(&self, _msg: SignalingMessage) -> Result<(), TransportError> {
            Ok(())
        }

        async fn recv(&self) -> Option<SignalingMessage> {
            None
        }

        fn try_recv(&self) -> Option<SignalingMessage> {
            None
        }

        fn peer_id(&self) -> &str {
            "local"
        }
    }

    /// Factory stub that records candidate deliveries and peer removals.
    /// Offer/answer negotiation always fails, so a test that reaches it sees an error.
    #[derive(Default)]
    struct RecordingFactory {
        candidates: Mutex<Vec<(String, IceCandidate)>>,
        removed: Mutex<Vec<String>>,
    }

    #[async_trait]
    impl PeerLinkFactory for RecordingFactory {
        async fn create_offer(
            &self,
            _target_peer_id: &str,
        ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
            Err(TransportError::ConnectionFailed(
                "not used in this test".to_string(),
            ))
        }

        async fn accept_offer(
            &self,
            _from_peer_id: &str,
            _offer_sdp: &str,
        ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
            Err(TransportError::ConnectionFailed(
                "not used in this test".to_string(),
            ))
        }

        async fn handle_answer(
            &self,
            _target_peer_id: &str,
            _answer_sdp: &str,
        ) -> Result<Arc<dyn PeerLink>, TransportError> {
            Err(TransportError::ConnectionFailed(
                "not used in this test".to_string(),
            ))
        }

        async fn handle_candidate(
            &self,
            peer_id: &str,
            candidate: IceCandidate,
        ) -> Result<(), TransportError> {
            self.candidates
                .lock()
                .await
                .push((peer_id.to_string(), candidate));
            Ok(())
        }

        async fn remove_peer(&self, peer_id: &str) -> Result<(), TransportError> {
            self.removed.lock().await.push(peer_id.to_string());
            Ok(())
        }
    }

    /// Builds a router for peer `"local"` with default pool settings around `factory`.
    fn router_with(
        factory: Arc<RecordingFactory>,
    ) -> MeshRouter<NoopTransport, RecordingFactory> {
        MeshRouter::new(
            "local".to_string(),
            Arc::new(NoopTransport),
            factory,
            PoolSettings {
                follows: PoolConfig::default(),
                other: PoolConfig::default(),
            },
            false,
        )
    }

    #[tokio::test]
    async fn routes_targeted_candidates_to_factory() {
        let factory = Arc::new(RecordingFactory::default());
        let router = router_with(Arc::clone(&factory));

        router
            .handle_message(SignalingMessage::Candidate {
                peer_id: "remote:peer".to_string(),
                target_peer_id: "local".to_string(),
                candidate: "candidate:1".to_string(),
                sdp_m_line_index: Some(0),
                sdp_mid: Some("data".to_string()),
            })
            .await
            .expect("candidate should route");

        let recorded = factory
            .candidates
            .lock()
            .await
            .iter()
            .map(|(peer_id, candidate)| (peer_id.clone(), candidate.candidate.clone()))
            .collect::<Vec<_>>();

        assert_eq!(
            recorded,
            vec![("remote:peer".to_string(), "candidate:1".to_string())]
        );
    }

    #[tokio::test]
    async fn ignores_candidates_addressed_to_other_peers() {
        let factory = Arc::new(RecordingFactory::default());
        let router = router_with(Arc::clone(&factory));

        router
            .handle_message(SignalingMessage::Candidate {
                peer_id: "remote:peer".to_string(),
                target_peer_id: "someone-else".to_string(),
                candidate: "candidate:1".to_string(),
                sdp_m_line_index: Some(0),
                sdp_mid: Some("data".to_string()),
            })
            .await
            .expect("candidate for another peer should be ignored");

        assert!(factory.candidates.lock().await.is_empty());
    }

    #[tokio::test]
    async fn ignores_offers_addressed_to_other_peers() {
        let factory = Arc::new(RecordingFactory::default());
        let router = router_with(Arc::clone(&factory));

        // `RecordingFactory::accept_offer` always fails, so `Ok` proves the
        // offer never reached the factory.
        router
            .handle_message(SignalingMessage::Offer {
                peer_id: "remote:peer".to_string(),
                target_peer_id: "someone-else".to_string(),
                sdp: "v=0".to_string(),
            })
            .await
            .expect("offer for another peer should be ignored");

        assert_eq!(router.peer_count().await, 0);
    }

    #[tokio::test]
    async fn remove_peer_cleans_factory_state() {
        let factory = Arc::new(RecordingFactory::default());
        let router = router_with(Arc::clone(&factory));

        let removed = router.remove_peer("remote:peer").await;
        assert!(removed.is_none());
        assert_eq!(
            factory.removed.lock().await.as_slice(),
            &["remote:peer".to_string()]
        );
    }
}