1use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13
14use crate::transport::{PeerLink, PeerLinkFactory, SignalingTransport, TransportError};
15use crate::types::{is_polite_peer, ClassifyRequest, PeerPool, PoolSettings, SignalingMessage};
16
17pub struct PeerEntry {
19 pub channel: Arc<dyn PeerLink>,
20 pub pool: PeerPool,
21}
22
23pub struct MeshRouter<R: SignalingTransport, F: PeerLinkFactory> {
34 peer_id: String,
36 transport: Arc<R>,
38 conn_factory: Arc<F>,
40 peers: RwLock<HashMap<String, PeerEntry>>,
42 pending_offers: RwLock<HashMap<String, ()>>,
44 pools: PoolSettings,
46 peer_roots: RwLock<HashMap<String, Vec<String>>>,
48 classifier_tx: Option<tokio::sync::mpsc::Sender<ClassifyRequest>>,
50 debug: bool,
52}
53
54impl<R: SignalingTransport + 'static, F: PeerLinkFactory + 'static> MeshRouter<R, F> {
55 pub fn new(
57 peer_id: String,
58 transport: Arc<R>,
59 conn_factory: Arc<F>,
60 pools: PoolSettings,
61 debug: bool,
62 ) -> Self {
63 Self {
64 peer_id,
65 transport,
66 conn_factory,
67 peers: RwLock::new(HashMap::new()),
68 pending_offers: RwLock::new(HashMap::new()),
69 pools,
70 peer_roots: RwLock::new(HashMap::new()),
71 classifier_tx: None,
72 debug,
73 }
74 }
75
76 pub fn set_classifier(&mut self, tx: tokio::sync::mpsc::Sender<ClassifyRequest>) {
78 self.classifier_tx = Some(tx);
79 }
80
81 pub fn peer_id(&self) -> &str {
83 &self.peer_id
84 }
85
86 pub async fn send_hello(&self, roots: Vec<String>) -> Result<(), TransportError> {
88 let msg = SignalingMessage::Hello {
89 peer_id: self.peer_id.clone(),
90 roots,
91 };
92 self.transport.publish(msg).await
93 }
94
95 async fn count_pools(&self) -> (usize, usize) {
97 let peers = self.peers.read().await;
98 let mut follows = 0;
99 let mut other = 0;
100 for entry in peers.values() {
101 match entry.pool {
102 PeerPool::Follows => follows += 1,
103 PeerPool::Other => other += 1,
104 }
105 }
106 (follows, other)
107 }
108
109 async fn classify_peer(&self, pubkey: &str) -> PeerPool {
111 if let Some(ref tx) = self.classifier_tx {
112 let (response_tx, response_rx) = tokio::sync::oneshot::channel();
113 let request = ClassifyRequest {
114 pubkey: pubkey.to_string(),
115 response: response_tx,
116 };
117 if tx.send(request).await.is_ok() {
118 if let Ok(pool) = response_rx.await {
119 return pool;
120 }
121 }
122 }
123 PeerPool::Other
124 }
125
126 fn can_accept_peer(&self, pool: PeerPool, follows: usize, other: usize) -> bool {
128 match pool {
129 PeerPool::Follows => self.pools.follows.can_accept(follows),
130 PeerPool::Other => self.pools.other.can_accept(other),
131 }
132 }
133
134 fn pool_needs_peers(&self, pool: PeerPool, follows: usize, other: usize) -> bool {
136 match pool {
137 PeerPool::Follows => self.pools.follows.needs_peers(follows),
138 PeerPool::Other => self.pools.other.needs_peers(other),
139 }
140 }
141
142 pub async fn handle_message(&self, msg: SignalingMessage) -> Result<(), TransportError> {
146 match &msg {
147 SignalingMessage::Hello { peer_id, roots } => self.handle_hello(peer_id, roots).await,
148 SignalingMessage::Offer {
149 peer_id,
150 target_peer_id,
151 sdp,
152 } => {
153 if target_peer_id == &self.peer_id {
154 self.handle_offer(peer_id, sdp).await
155 } else {
156 Ok(()) }
158 }
159 SignalingMessage::Answer {
160 peer_id,
161 target_peer_id,
162 sdp,
163 } => {
164 if target_peer_id == &self.peer_id {
165 self.handle_answer(peer_id, sdp).await
166 } else {
167 Ok(()) }
169 }
170 SignalingMessage::Candidate {
171 peer_id,
172 target_peer_id,
173 candidate,
174 sdp_m_line_index,
175 sdp_mid,
176 } => {
177 if target_peer_id == &self.peer_id {
178 self.conn_factory
179 .handle_candidate(
180 peer_id,
181 crate::types::IceCandidate {
182 candidate: candidate.clone(),
183 sdp_m_line_index: *sdp_m_line_index,
184 sdp_mid: sdp_mid.clone(),
185 },
186 )
187 .await
188 } else {
189 Ok(())
190 }
191 }
192 SignalingMessage::Candidates {
193 peer_id,
194 target_peer_id,
195 candidates,
196 } => {
197 if target_peer_id == &self.peer_id {
198 self.conn_factory
199 .handle_candidates(peer_id, candidates.clone())
200 .await
201 } else {
202 Ok(())
203 }
204 }
205 }
206 }
207
208 async fn handle_hello(
213 &self,
214 from_peer_id: &str,
215 roots: &[String],
216 ) -> Result<(), TransportError> {
217 if from_peer_id == self.peer_id {
219 return Ok(());
220 }
221
222 let peer_pubkey = crate::types::PeerId::from_peer_string(from_peer_id)
223 .map(|peer_id| peer_id.pubkey)
224 .unwrap_or_else(|| from_peer_id.to_string());
225
226 let pool = self.classify_peer(&peer_pubkey).await;
228
229 let (follows_count, other_count) = self.count_pools().await;
231
232 if !self.can_accept_peer(pool, follows_count, other_count) {
233 if self.debug {
234 println!(
235 "[Signaling] Ignoring hello from {} - {:?} pool full",
236 from_peer_id, pool
237 );
238 }
239 return Ok(());
240 }
241
242 self.peer_roots
244 .write()
245 .await
246 .insert(from_peer_id.to_string(), roots.to_vec());
247
248 if self.pool_needs_peers(pool, follows_count, other_count) {
251 if self.peers.read().await.contains_key(from_peer_id) {
253 return Ok(());
254 }
255 if self.pending_offers.read().await.contains_key(from_peer_id) {
256 return Ok(());
257 }
258
259 if self.debug {
260 println!(
261 "[Signaling] Sending offer to {} (pool: {:?})",
262 from_peer_id, pool
263 );
264 }
265
266 self.pending_offers
268 .write()
269 .await
270 .insert(from_peer_id.to_string(), ());
271
272 let (channel, sdp) = self.conn_factory.create_offer(from_peer_id).await?;
274
275 self.peers
277 .write()
278 .await
279 .insert(from_peer_id.to_string(), PeerEntry { channel, pool });
280
281 let offer_msg = SignalingMessage::Offer {
283 peer_id: self.peer_id.clone(),
284 target_peer_id: from_peer_id.to_string(),
285 sdp,
286 };
287 self.transport.publish(offer_msg).await?;
288 }
289
290 Ok(())
291 }
292
293 async fn handle_offer(&self, from_peer_id: &str, sdp: &str) -> Result<(), TransportError> {
298 let peer_pubkey = crate::types::PeerId::from_peer_string(from_peer_id)
300 .map(|peer_id| peer_id.pubkey)
301 .unwrap_or_else(|| from_peer_id.to_string());
302
303 let pool = self.classify_peer(&peer_pubkey).await;
305 let (follows_count, other_count) = self.count_pools().await;
306
307 if !self.can_accept_peer(pool, follows_count, other_count) {
308 if self.debug {
309 println!(
310 "[Signaling] Ignoring offer from {} - {:?} pool full",
311 from_peer_id, pool
312 );
313 }
314 return Ok(());
315 }
316
317 let have_pending = self.pending_offers.read().await.contains_key(from_peer_id);
319 if have_pending {
320 let we_are_polite = is_polite_peer(&self.peer_id, from_peer_id);
322
323 if we_are_polite {
324 self.pending_offers.write().await.remove(from_peer_id);
327 self.peers.write().await.remove(from_peer_id);
328
329 if self.debug {
330 println!(
331 "[Signaling] Collision with {} - we're polite, accepting their offer",
332 from_peer_id
333 );
334 }
335 } else {
336 if self.debug {
338 println!(
339 "[Signaling] Collision with {} - we're impolite, ignoring their offer",
340 from_peer_id
341 );
342 }
343 return Ok(());
344 }
345 }
346
347 if self.peers.read().await.contains_key(from_peer_id) {
349 return Ok(());
350 }
351
352 if self.debug {
353 println!("[Signaling] Accepting offer from {}", from_peer_id);
354 }
355
356 let (channel, answer_sdp) = self.conn_factory.accept_offer(from_peer_id, sdp).await?;
358
359 self.peers
361 .write()
362 .await
363 .insert(from_peer_id.to_string(), PeerEntry { channel, pool });
364
365 let answer_msg = SignalingMessage::Answer {
367 peer_id: self.peer_id.clone(),
368 target_peer_id: from_peer_id.to_string(),
369 sdp: answer_sdp,
370 };
371 self.transport.publish(answer_msg).await?;
372
373 Ok(())
374 }
375
376 async fn handle_answer(&self, from_peer_id: &str, sdp: &str) -> Result<(), TransportError> {
378 if self.debug {
379 println!("[Signaling] Received answer from {}", from_peer_id);
380 }
381
382 let _channel = self.conn_factory.handle_answer(from_peer_id, sdp).await?;
384
385 Ok(())
389 }
390
391 pub async fn peer_count(&self) -> usize {
393 self.peers.read().await.len()
394 }
395
396 pub async fn peer_ids(&self) -> Vec<String> {
398 self.peers.read().await.keys().cloned().collect()
399 }
400
401 pub async fn get_channel(&self, peer_id: &str) -> Option<Arc<dyn PeerLink>> {
403 self.peers
404 .read()
405 .await
406 .get(peer_id)
407 .map(|e| e.channel.clone())
408 }
409
410 pub async fn remove_peer(&self, peer_id: &str) -> Option<Arc<dyn PeerLink>> {
412 self.pending_offers.write().await.remove(peer_id);
413 self.peer_roots.write().await.remove(peer_id);
414 let _ = self.conn_factory.remove_peer(peer_id).await;
415 self.peers
416 .write()
417 .await
418 .remove(peer_id)
419 .map(|entry| entry.channel)
420 }
421
422 pub async fn needs_peers(&self) -> bool {
424 let (follows, other) = self.count_pools().await;
425 self.pools.follows.needs_peers(follows) || self.pools.other.needs_peers(other)
426 }
427
428 pub async fn can_accept(&self) -> bool {
430 let (follows, other) = self.count_pools().await;
431 self.pools.follows.can_accept(follows) || self.pools.other.can_accept(other)
432 }
433}
434
435#[cfg(test)]
436mod tests {
437 use super::*;
438 use async_trait::async_trait;
439 use std::sync::Arc;
440 use tokio::sync::Mutex;
441
442 use crate::types::{IceCandidate, PoolConfig, PoolSettings};
443
444 #[derive(Default)]
445 struct NoopTransport;
446
447 #[async_trait]
448 impl SignalingTransport for NoopTransport {
449 async fn connect(&self, _relays: &[String]) -> Result<(), TransportError> {
450 Ok(())
451 }
452
453 async fn disconnect(&self) {}
454
455 async fn publish(&self, _msg: SignalingMessage) -> Result<(), TransportError> {
456 Ok(())
457 }
458
459 async fn recv(&self) -> Option<SignalingMessage> {
460 None
461 }
462
463 fn try_recv(&self) -> Option<SignalingMessage> {
464 None
465 }
466
467 fn peer_id(&self) -> &str {
468 "local"
469 }
470 }
471
472 #[derive(Default)]
473 struct RecordingFactory {
474 candidates: Mutex<Vec<(String, IceCandidate)>>,
475 removed: Mutex<Vec<String>>,
476 }
477
478 #[async_trait]
479 impl PeerLinkFactory for RecordingFactory {
480 async fn create_offer(
481 &self,
482 _target_peer_id: &str,
483 ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
484 Err(TransportError::ConnectionFailed(
485 "not used in this test".to_string(),
486 ))
487 }
488
489 async fn accept_offer(
490 &self,
491 _from_peer_id: &str,
492 _offer_sdp: &str,
493 ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
494 Err(TransportError::ConnectionFailed(
495 "not used in this test".to_string(),
496 ))
497 }
498
499 async fn handle_answer(
500 &self,
501 _target_peer_id: &str,
502 _answer_sdp: &str,
503 ) -> Result<Arc<dyn PeerLink>, TransportError> {
504 Err(TransportError::ConnectionFailed(
505 "not used in this test".to_string(),
506 ))
507 }
508
509 async fn handle_candidate(
510 &self,
511 peer_id: &str,
512 candidate: IceCandidate,
513 ) -> Result<(), TransportError> {
514 self.candidates
515 .lock()
516 .await
517 .push((peer_id.to_string(), candidate));
518 Ok(())
519 }
520
521 async fn remove_peer(&self, peer_id: &str) -> Result<(), TransportError> {
522 self.removed.lock().await.push(peer_id.to_string());
523 Ok(())
524 }
525 }
526
527 #[tokio::test]
528 async fn routes_targeted_candidates_to_factory() {
529 let router = MeshRouter::new(
530 "local".to_string(),
531 Arc::new(NoopTransport),
532 Arc::new(RecordingFactory::default()),
533 PoolSettings {
534 follows: PoolConfig::default(),
535 other: PoolConfig::default(),
536 },
537 false,
538 );
539
540 router
541 .handle_message(SignalingMessage::Candidate {
542 peer_id: "remote:peer".to_string(),
543 target_peer_id: "local".to_string(),
544 candidate: "candidate:1".to_string(),
545 sdp_m_line_index: Some(0),
546 sdp_mid: Some("data".to_string()),
547 })
548 .await
549 .expect("candidate should route");
550
551 let factory = router.conn_factory.clone();
552 let recorded = factory
553 .candidates
554 .lock()
555 .await
556 .iter()
557 .map(|(peer_id, candidate)| (peer_id.clone(), candidate.candidate.clone()))
558 .collect::<Vec<_>>();
559
560 assert_eq!(
561 recorded,
562 vec![("remote:peer".to_string(), "candidate:1".to_string())]
563 );
564 }
565
566 #[tokio::test]
567 async fn remove_peer_cleans_factory_state() {
568 let factory = Arc::new(RecordingFactory::default());
569 let router = MeshRouter::new(
570 "local".to_string(),
571 Arc::new(NoopTransport),
572 factory.clone(),
573 PoolSettings {
574 follows: PoolConfig::default(),
575 other: PoolConfig::default(),
576 },
577 false,
578 );
579
580 let removed = router.remove_peer("remote:peer").await;
581 assert!(removed.is_none());
582 assert_eq!(
583 factory.removed.lock().await.as_slice(),
584 &["remote:peer".to_string()]
585 );
586 }
587}