use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::RwLock;

use crate::transport::{PeerLink, PeerLinkFactory, SignalingTransport, TransportError};
use crate::types::{is_polite_peer, ClassifyRequest, PeerPool, PoolSettings, SignalingMessage};
16
17pub struct PeerEntry {
19 pub channel: Arc<dyn PeerLink>,
20 pub pool: PeerPool,
21 pub hash_get: bool,
22}
23
24pub struct MeshRouter<R: SignalingTransport, F: PeerLinkFactory> {
35 peer_id: String,
37 transport: Arc<R>,
39 conn_factory: Arc<F>,
41 peers: RwLock<HashMap<String, PeerEntry>>,
43 pending_offers: RwLock<HashMap<String, ()>>,
45 pools: PoolSettings,
47 peer_roots: RwLock<HashMap<String, Vec<String>>>,
49 peer_hash_get: RwLock<HashMap<String, bool>>,
51 classifier_tx: Option<tokio::sync::mpsc::Sender<ClassifyRequest>>,
53 debug: bool,
55 hash_get_enabled: bool,
57}
58
59impl<R: SignalingTransport + 'static, F: PeerLinkFactory + 'static> MeshRouter<R, F> {
60 pub fn new(
62 peer_id: String,
63 transport: Arc<R>,
64 conn_factory: Arc<F>,
65 pools: PoolSettings,
66 debug: bool,
67 ) -> Self {
68 Self {
69 peer_id,
70 transport,
71 conn_factory,
72 peers: RwLock::new(HashMap::new()),
73 pending_offers: RwLock::new(HashMap::new()),
74 pools,
75 peer_roots: RwLock::new(HashMap::new()),
76 peer_hash_get: RwLock::new(HashMap::new()),
77 classifier_tx: None,
78 debug,
79 hash_get_enabled: true,
80 }
81 }
82
83 pub fn set_classifier(&mut self, tx: tokio::sync::mpsc::Sender<ClassifyRequest>) {
85 self.classifier_tx = Some(tx);
86 }
87
88 pub fn set_hash_get_enabled(&mut self, enabled: bool) {
89 self.hash_get_enabled = enabled;
90 }
91
92 pub async fn set_peer_hash_get(&self, peer_id: &str, enabled: bool) {
93 self.peer_hash_get
94 .write()
95 .await
96 .insert(peer_id.to_string(), enabled);
97 if let Some(entry) = self.peers.write().await.get_mut(peer_id) {
98 entry.hash_get = enabled;
99 }
100 }
101
102 pub async fn peer_supports_hash_get(&self, peer_id: &str) -> bool {
103 self.peer_hash_get
104 .read()
105 .await
106 .get(peer_id)
107 .copied()
108 .unwrap_or(true)
109 }
110
111 pub async fn hash_get_peer_ids(&self) -> Vec<String> {
112 let peers = self.peers.read().await;
113 let peer_hash_get = self.peer_hash_get.read().await;
114 peers
115 .keys()
116 .filter(|peer_id| peer_hash_get.get(*peer_id).copied().unwrap_or(true))
117 .cloned()
118 .collect()
119 }
120
121 pub fn peer_id(&self) -> &str {
123 &self.peer_id
124 }
125
126 pub async fn send_hello(&self, roots: Vec<String>) -> Result<(), TransportError> {
128 let msg = SignalingMessage::Hello {
129 peer_id: self.peer_id.clone(),
130 roots,
131 hash_get: self.hash_get_enabled,
132 };
133 self.transport.publish(msg).await
134 }
135
136 async fn count_pools(&self) -> (usize, usize) {
138 let peers = self.peers.read().await;
139 let mut follows = 0;
140 let mut other = 0;
141 for entry in peers.values() {
142 match entry.pool {
143 PeerPool::Follows => follows += 1,
144 PeerPool::Other => other += 1,
145 }
146 }
147 (follows, other)
148 }
149
150 async fn classify_peer(&self, pubkey: &str) -> PeerPool {
152 if let Some(ref tx) = self.classifier_tx {
153 let (response_tx, response_rx) = tokio::sync::oneshot::channel();
154 let request = ClassifyRequest {
155 pubkey: pubkey.to_string(),
156 response: response_tx,
157 };
158 if tx.send(request).await.is_ok() {
159 if let Ok(pool) = response_rx.await {
160 return pool;
161 }
162 }
163 }
164 PeerPool::Other
165 }
166
167 fn can_accept_peer(&self, pool: PeerPool, follows: usize, other: usize) -> bool {
169 match pool {
170 PeerPool::Follows => self.pools.follows.can_accept(follows),
171 PeerPool::Other => self.pools.other.can_accept(other),
172 }
173 }
174
175 fn pool_needs_peers(&self, pool: PeerPool, follows: usize, other: usize) -> bool {
177 match pool {
178 PeerPool::Follows => self.pools.follows.needs_peers(follows),
179 PeerPool::Other => self.pools.other.needs_peers(other),
180 }
181 }
182
183 pub async fn handle_message(&self, msg: SignalingMessage) -> Result<(), TransportError> {
187 match &msg {
188 SignalingMessage::Hello {
189 peer_id,
190 roots,
191 hash_get,
192 } => {
193 self.set_peer_hash_get(peer_id, *hash_get).await;
194 self.handle_hello(peer_id, roots, *hash_get).await
195 }
196 SignalingMessage::Offer {
197 peer_id,
198 target_peer_id,
199 sdp,
200 } => {
201 if target_peer_id == &self.peer_id {
202 self.handle_offer(peer_id, sdp).await
203 } else {
204 Ok(()) }
206 }
207 SignalingMessage::Answer {
208 peer_id,
209 target_peer_id,
210 sdp,
211 } => {
212 if target_peer_id == &self.peer_id {
213 self.handle_answer(peer_id, sdp).await
214 } else {
215 Ok(()) }
217 }
218 SignalingMessage::Candidate {
219 peer_id,
220 target_peer_id,
221 candidate,
222 sdp_m_line_index,
223 sdp_mid,
224 } => {
225 if target_peer_id == &self.peer_id {
226 self.conn_factory
227 .handle_candidate(
228 peer_id,
229 crate::types::IceCandidate {
230 candidate: candidate.clone(),
231 sdp_m_line_index: *sdp_m_line_index,
232 sdp_mid: sdp_mid.clone(),
233 },
234 )
235 .await
236 } else {
237 Ok(())
238 }
239 }
240 SignalingMessage::Candidates {
241 peer_id,
242 target_peer_id,
243 candidates,
244 } => {
245 if target_peer_id == &self.peer_id {
246 self.conn_factory
247 .handle_candidates(peer_id, candidates.clone())
248 .await
249 } else {
250 Ok(())
251 }
252 }
253 }
254 }
255
256 async fn handle_hello(
261 &self,
262 from_peer_id: &str,
263 roots: &[String],
264 hash_get: bool,
265 ) -> Result<(), TransportError> {
266 if from_peer_id == self.peer_id {
268 return Ok(());
269 }
270
271 self.peer_roots
272 .write()
273 .await
274 .insert(from_peer_id.to_string(), roots.to_vec());
275 if let Some(entry) = self.peers.write().await.get_mut(from_peer_id) {
276 entry.hash_get = hash_get;
277 return Ok(());
278 }
279
280 let peer_pubkey = crate::types::PeerId::from_peer_string(from_peer_id)
281 .map(|peer_id| peer_id.pubkey)
282 .unwrap_or_else(|| from_peer_id.to_string());
283
284 let pool = self.classify_peer(&peer_pubkey).await;
286
287 let (follows_count, other_count) = self.count_pools().await;
289
290 if !self.can_accept_peer(pool, follows_count, other_count) {
291 if self.debug {
292 println!(
293 "[Signaling] Ignoring hello from {} - {:?} pool full",
294 from_peer_id, pool
295 );
296 }
297 return Ok(());
298 }
299
300 if self.pool_needs_peers(pool, follows_count, other_count) {
303 if self.peers.read().await.contains_key(from_peer_id) {
305 return Ok(());
306 }
307 if self.pending_offers.read().await.contains_key(from_peer_id) {
308 return Ok(());
309 }
310
311 if self.debug {
312 println!(
313 "[Signaling] Sending offer to {} (pool: {:?})",
314 from_peer_id, pool
315 );
316 }
317
318 self.pending_offers
320 .write()
321 .await
322 .insert(from_peer_id.to_string(), ());
323
324 let (channel, sdp) = self.conn_factory.create_offer(from_peer_id).await?;
326
327 self.peers.write().await.insert(
329 from_peer_id.to_string(),
330 PeerEntry {
331 channel,
332 pool,
333 hash_get,
334 },
335 );
336
337 let offer_msg = SignalingMessage::Offer {
339 peer_id: self.peer_id.clone(),
340 target_peer_id: from_peer_id.to_string(),
341 sdp,
342 };
343 self.transport.publish(offer_msg).await?;
344 }
345
346 Ok(())
347 }
348
349 async fn handle_offer(&self, from_peer_id: &str, sdp: &str) -> Result<(), TransportError> {
354 let peer_pubkey = crate::types::PeerId::from_peer_string(from_peer_id)
356 .map(|peer_id| peer_id.pubkey)
357 .unwrap_or_else(|| from_peer_id.to_string());
358
359 let pool = self.classify_peer(&peer_pubkey).await;
361 let (follows_count, other_count) = self.count_pools().await;
362
363 if !self.can_accept_peer(pool, follows_count, other_count) {
364 if self.debug {
365 println!(
366 "[Signaling] Ignoring offer from {} - {:?} pool full",
367 from_peer_id, pool
368 );
369 }
370 return Ok(());
371 }
372
373 let have_pending = self.pending_offers.read().await.contains_key(from_peer_id);
375 if have_pending {
376 let we_are_polite = is_polite_peer(&self.peer_id, from_peer_id);
378
379 if we_are_polite {
380 self.pending_offers.write().await.remove(from_peer_id);
383 self.peers.write().await.remove(from_peer_id);
384
385 if self.debug {
386 println!(
387 "[Signaling] Collision with {} - we're polite, accepting their offer",
388 from_peer_id
389 );
390 }
391 } else {
392 if self.debug {
394 println!(
395 "[Signaling] Collision with {} - we're impolite, ignoring their offer",
396 from_peer_id
397 );
398 }
399 return Ok(());
400 }
401 }
402
403 if self.peers.read().await.contains_key(from_peer_id) {
405 return Ok(());
406 }
407
408 if self.debug {
409 println!("[Signaling] Accepting offer from {}", from_peer_id);
410 }
411
412 let (channel, answer_sdp) = self.conn_factory.accept_offer(from_peer_id, sdp).await?;
414 let hash_get = self.peer_supports_hash_get(from_peer_id).await;
415
416 self.peers.write().await.insert(
418 from_peer_id.to_string(),
419 PeerEntry {
420 channel,
421 pool,
422 hash_get,
423 },
424 );
425
426 let answer_msg = SignalingMessage::Answer {
428 peer_id: self.peer_id.clone(),
429 target_peer_id: from_peer_id.to_string(),
430 sdp: answer_sdp,
431 };
432 self.transport.publish(answer_msg).await?;
433
434 Ok(())
435 }
436
437 async fn handle_answer(&self, from_peer_id: &str, sdp: &str) -> Result<(), TransportError> {
439 if self.debug {
440 println!("[Signaling] Received answer from {}", from_peer_id);
441 }
442
443 let _channel = self.conn_factory.handle_answer(from_peer_id, sdp).await?;
445
446 Ok(())
450 }
451
452 pub async fn peer_count(&self) -> usize {
454 self.peers.read().await.len()
455 }
456
457 pub async fn peer_ids(&self) -> Vec<String> {
459 self.peers.read().await.keys().cloned().collect()
460 }
461
462 pub async fn get_channel(&self, peer_id: &str) -> Option<Arc<dyn PeerLink>> {
464 self.peers
465 .read()
466 .await
467 .get(peer_id)
468 .map(|e| e.channel.clone())
469 }
470
471 pub async fn remove_peer(&self, peer_id: &str) -> Option<Arc<dyn PeerLink>> {
473 self.pending_offers.write().await.remove(peer_id);
474 self.peer_roots.write().await.remove(peer_id);
475 let _ = self.conn_factory.remove_peer(peer_id).await;
476 self.peers
477 .write()
478 .await
479 .remove(peer_id)
480 .map(|entry| entry.channel)
481 }
482
483 pub async fn needs_peers(&self) -> bool {
485 let (follows, other) = self.count_pools().await;
486 self.pools.follows.needs_peers(follows) || self.pools.other.needs_peers(other)
487 }
488
489 pub async fn can_accept(&self) -> bool {
491 let (follows, other) = self.count_pools().await;
492 self.pools.follows.can_accept(follows) || self.pools.other.can_accept(other)
493 }
494}
495
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    use crate::types::{IceCandidate, PoolConfig, PoolSettings};

    /// Error returned by every factory method these tests never expect to be called.
    fn not_used<T>() -> Result<T, TransportError> {
        Err(TransportError::ConnectionFailed(
            "not used in this test".to_string(),
        ))
    }

    /// Pool settings using the default configuration for both pools.
    fn default_pools() -> PoolSettings {
        PoolSettings {
            follows: PoolConfig::default(),
            other: PoolConfig::default(),
        }
    }

    /// Builds a router identified as `"local"` over a no-op transport, with debug output off.
    fn router_with(factory: Arc<RecordingFactory>) -> MeshRouter<NoopTransport, RecordingFactory> {
        MeshRouter::new(
            "local".to_string(),
            Arc::new(NoopTransport),
            factory,
            default_pools(),
            false,
        )
    }

    /// Transport that accepts every publish and never delivers a message.
    #[derive(Default)]
    struct NoopTransport;

    #[async_trait]
    impl SignalingTransport for NoopTransport {
        async fn connect(&self, _relays: &[String]) -> Result<(), TransportError> {
            Ok(())
        }

        async fn disconnect(&self) {}

        async fn publish(&self, _msg: SignalingMessage) -> Result<(), TransportError> {
            Ok(())
        }

        async fn recv(&self) -> Option<SignalingMessage> {
            None
        }

        fn try_recv(&self) -> Option<SignalingMessage> {
            None
        }

        fn peer_id(&self) -> &str {
            "local"
        }
    }

    /// Factory that records routed candidates and peer removals; link creation always fails.
    #[derive(Default)]
    struct RecordingFactory {
        candidates: Mutex<Vec<(String, IceCandidate)>>,
        removed: Mutex<Vec<String>>,
    }

    #[async_trait]
    impl PeerLinkFactory for RecordingFactory {
        async fn create_offer(
            &self,
            _target_peer_id: &str,
        ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
            not_used()
        }

        async fn accept_offer(
            &self,
            _from_peer_id: &str,
            _offer_sdp: &str,
        ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
            not_used()
        }

        async fn handle_answer(
            &self,
            _target_peer_id: &str,
            _answer_sdp: &str,
        ) -> Result<Arc<dyn PeerLink>, TransportError> {
            not_used()
        }

        async fn handle_candidate(
            &self,
            peer_id: &str,
            candidate: IceCandidate,
        ) -> Result<(), TransportError> {
            let mut seen = self.candidates.lock().await;
            seen.push((peer_id.to_string(), candidate));
            Ok(())
        }

        async fn remove_peer(&self, peer_id: &str) -> Result<(), TransportError> {
            let mut log = self.removed.lock().await;
            log.push(peer_id.to_string());
            Ok(())
        }
    }

    #[tokio::test]
    async fn routes_targeted_candidates_to_factory() {
        let factory = Arc::new(RecordingFactory::default());
        let router = router_with(Arc::clone(&factory));

        let candidate_msg = SignalingMessage::Candidate {
            peer_id: "remote:peer".to_string(),
            target_peer_id: "local".to_string(),
            candidate: "candidate:1".to_string(),
            sdp_m_line_index: Some(0),
            sdp_mid: Some("data".to_string()),
        };
        router
            .handle_message(candidate_msg)
            .await
            .expect("candidate should route");

        let recorded: Vec<(String, String)> = factory
            .candidates
            .lock()
            .await
            .iter()
            .map(|(peer, ice)| (peer.clone(), ice.candidate.clone()))
            .collect();

        let expected = vec![("remote:peer".to_string(), "candidate:1".to_string())];
        assert_eq!(recorded, expected);
    }

    #[tokio::test]
    async fn remove_peer_cleans_factory_state() {
        let factory = Arc::new(RecordingFactory::default());
        let router = router_with(Arc::clone(&factory));

        assert!(router.remove_peer("remote:peer").await.is_none());

        let removed = factory.removed.lock().await;
        assert_eq!(removed.as_slice(), &["remote:peer".to_string()]);
    }
}