1use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13
14use crate::transport::{PeerLink, PeerLinkFactory, SignalingTransport, TransportError};
15use crate::types::{is_polite_peer, ClassifyRequest, PeerPool, PoolSettings, SignalingMessage};
16
17pub struct PeerEntry {
19 pub channel: Arc<dyn PeerLink>,
20 pub pool: PeerPool,
21}
22
23pub struct MeshRouter<R: SignalingTransport, F: PeerLinkFactory> {
34 peer_id: String,
36 transport: Arc<R>,
38 conn_factory: Arc<F>,
40 peers: RwLock<HashMap<String, PeerEntry>>,
42 pending_offers: RwLock<HashMap<String, ()>>,
44 pools: PoolSettings,
46 peer_roots: RwLock<HashMap<String, Vec<String>>>,
48 classifier_tx: Option<tokio::sync::mpsc::Sender<ClassifyRequest>>,
50 debug: bool,
52 hash_get_enabled: bool,
54}
55
56impl<R: SignalingTransport + 'static, F: PeerLinkFactory + 'static> MeshRouter<R, F> {
57 pub fn new(
59 peer_id: String,
60 transport: Arc<R>,
61 conn_factory: Arc<F>,
62 pools: PoolSettings,
63 debug: bool,
64 ) -> Self {
65 Self {
66 peer_id,
67 transport,
68 conn_factory,
69 peers: RwLock::new(HashMap::new()),
70 pending_offers: RwLock::new(HashMap::new()),
71 pools,
72 peer_roots: RwLock::new(HashMap::new()),
73 classifier_tx: None,
74 debug,
75 hash_get_enabled: true,
76 }
77 }
78
79 pub fn set_classifier(&mut self, tx: tokio::sync::mpsc::Sender<ClassifyRequest>) {
81 self.classifier_tx = Some(tx);
82 }
83
84 pub fn set_hash_get_enabled(&mut self, enabled: bool) {
85 self.hash_get_enabled = enabled;
86 }
87
88 pub fn peer_id(&self) -> &str {
90 &self.peer_id
91 }
92
93 pub async fn send_hello(&self, roots: Vec<String>) -> Result<(), TransportError> {
95 let msg = SignalingMessage::Hello {
96 peer_id: self.peer_id.clone(),
97 roots,
98 hash_get: self.hash_get_enabled,
99 };
100 self.transport.publish(msg).await
101 }
102
103 async fn count_pools(&self) -> (usize, usize) {
105 let peers = self.peers.read().await;
106 let mut follows = 0;
107 let mut other = 0;
108 for entry in peers.values() {
109 match entry.pool {
110 PeerPool::Follows => follows += 1,
111 PeerPool::Other => other += 1,
112 }
113 }
114 (follows, other)
115 }
116
117 async fn classify_peer(&self, pubkey: &str) -> PeerPool {
119 if let Some(ref tx) = self.classifier_tx {
120 let (response_tx, response_rx) = tokio::sync::oneshot::channel();
121 let request = ClassifyRequest {
122 pubkey: pubkey.to_string(),
123 response: response_tx,
124 };
125 if tx.send(request).await.is_ok() {
126 if let Ok(pool) = response_rx.await {
127 return pool;
128 }
129 }
130 }
131 PeerPool::Other
132 }
133
134 fn can_accept_peer(&self, pool: PeerPool, follows: usize, other: usize) -> bool {
136 match pool {
137 PeerPool::Follows => self.pools.follows.can_accept(follows),
138 PeerPool::Other => self.pools.other.can_accept(other),
139 }
140 }
141
142 fn pool_needs_peers(&self, pool: PeerPool, follows: usize, other: usize) -> bool {
144 match pool {
145 PeerPool::Follows => self.pools.follows.needs_peers(follows),
146 PeerPool::Other => self.pools.other.needs_peers(other),
147 }
148 }
149
150 pub async fn handle_message(&self, msg: SignalingMessage) -> Result<(), TransportError> {
154 match &msg {
155 SignalingMessage::Hello { peer_id, roots, .. } => {
156 self.handle_hello(peer_id, roots).await
157 }
158 SignalingMessage::Offer {
159 peer_id,
160 target_peer_id,
161 sdp,
162 } => {
163 if target_peer_id == &self.peer_id {
164 self.handle_offer(peer_id, sdp).await
165 } else {
166 Ok(()) }
168 }
169 SignalingMessage::Answer {
170 peer_id,
171 target_peer_id,
172 sdp,
173 } => {
174 if target_peer_id == &self.peer_id {
175 self.handle_answer(peer_id, sdp).await
176 } else {
177 Ok(()) }
179 }
180 SignalingMessage::Candidate {
181 peer_id,
182 target_peer_id,
183 candidate,
184 sdp_m_line_index,
185 sdp_mid,
186 } => {
187 if target_peer_id == &self.peer_id {
188 self.conn_factory
189 .handle_candidate(
190 peer_id,
191 crate::types::IceCandidate {
192 candidate: candidate.clone(),
193 sdp_m_line_index: *sdp_m_line_index,
194 sdp_mid: sdp_mid.clone(),
195 },
196 )
197 .await
198 } else {
199 Ok(())
200 }
201 }
202 SignalingMessage::Candidates {
203 peer_id,
204 target_peer_id,
205 candidates,
206 } => {
207 if target_peer_id == &self.peer_id {
208 self.conn_factory
209 .handle_candidates(peer_id, candidates.clone())
210 .await
211 } else {
212 Ok(())
213 }
214 }
215 }
216 }
217
218 async fn handle_hello(
223 &self,
224 from_peer_id: &str,
225 roots: &[String],
226 ) -> Result<(), TransportError> {
227 if from_peer_id == self.peer_id {
229 return Ok(());
230 }
231
232 let peer_pubkey = crate::types::PeerId::from_peer_string(from_peer_id)
233 .map(|peer_id| peer_id.pubkey)
234 .unwrap_or_else(|| from_peer_id.to_string());
235
236 let pool = self.classify_peer(&peer_pubkey).await;
238
239 let (follows_count, other_count) = self.count_pools().await;
241
242 if !self.can_accept_peer(pool, follows_count, other_count) {
243 if self.debug {
244 println!(
245 "[Signaling] Ignoring hello from {} - {:?} pool full",
246 from_peer_id, pool
247 );
248 }
249 return Ok(());
250 }
251
252 self.peer_roots
254 .write()
255 .await
256 .insert(from_peer_id.to_string(), roots.to_vec());
257
258 if self.pool_needs_peers(pool, follows_count, other_count) {
261 if self.peers.read().await.contains_key(from_peer_id) {
263 return Ok(());
264 }
265 if self.pending_offers.read().await.contains_key(from_peer_id) {
266 return Ok(());
267 }
268
269 if self.debug {
270 println!(
271 "[Signaling] Sending offer to {} (pool: {:?})",
272 from_peer_id, pool
273 );
274 }
275
276 self.pending_offers
278 .write()
279 .await
280 .insert(from_peer_id.to_string(), ());
281
282 let (channel, sdp) = self.conn_factory.create_offer(from_peer_id).await?;
284
285 self.peers
287 .write()
288 .await
289 .insert(from_peer_id.to_string(), PeerEntry { channel, pool });
290
291 let offer_msg = SignalingMessage::Offer {
293 peer_id: self.peer_id.clone(),
294 target_peer_id: from_peer_id.to_string(),
295 sdp,
296 };
297 self.transport.publish(offer_msg).await?;
298 }
299
300 Ok(())
301 }
302
303 async fn handle_offer(&self, from_peer_id: &str, sdp: &str) -> Result<(), TransportError> {
308 let peer_pubkey = crate::types::PeerId::from_peer_string(from_peer_id)
310 .map(|peer_id| peer_id.pubkey)
311 .unwrap_or_else(|| from_peer_id.to_string());
312
313 let pool = self.classify_peer(&peer_pubkey).await;
315 let (follows_count, other_count) = self.count_pools().await;
316
317 if !self.can_accept_peer(pool, follows_count, other_count) {
318 if self.debug {
319 println!(
320 "[Signaling] Ignoring offer from {} - {:?} pool full",
321 from_peer_id, pool
322 );
323 }
324 return Ok(());
325 }
326
327 let have_pending = self.pending_offers.read().await.contains_key(from_peer_id);
329 if have_pending {
330 let we_are_polite = is_polite_peer(&self.peer_id, from_peer_id);
332
333 if we_are_polite {
334 self.pending_offers.write().await.remove(from_peer_id);
337 self.peers.write().await.remove(from_peer_id);
338
339 if self.debug {
340 println!(
341 "[Signaling] Collision with {} - we're polite, accepting their offer",
342 from_peer_id
343 );
344 }
345 } else {
346 if self.debug {
348 println!(
349 "[Signaling] Collision with {} - we're impolite, ignoring their offer",
350 from_peer_id
351 );
352 }
353 return Ok(());
354 }
355 }
356
357 if self.peers.read().await.contains_key(from_peer_id) {
359 return Ok(());
360 }
361
362 if self.debug {
363 println!("[Signaling] Accepting offer from {}", from_peer_id);
364 }
365
366 let (channel, answer_sdp) = self.conn_factory.accept_offer(from_peer_id, sdp).await?;
368
369 self.peers
371 .write()
372 .await
373 .insert(from_peer_id.to_string(), PeerEntry { channel, pool });
374
375 let answer_msg = SignalingMessage::Answer {
377 peer_id: self.peer_id.clone(),
378 target_peer_id: from_peer_id.to_string(),
379 sdp: answer_sdp,
380 };
381 self.transport.publish(answer_msg).await?;
382
383 Ok(())
384 }
385
386 async fn handle_answer(&self, from_peer_id: &str, sdp: &str) -> Result<(), TransportError> {
388 if self.debug {
389 println!("[Signaling] Received answer from {}", from_peer_id);
390 }
391
392 let _channel = self.conn_factory.handle_answer(from_peer_id, sdp).await?;
394
395 Ok(())
399 }
400
401 pub async fn peer_count(&self) -> usize {
403 self.peers.read().await.len()
404 }
405
406 pub async fn peer_ids(&self) -> Vec<String> {
408 self.peers.read().await.keys().cloned().collect()
409 }
410
411 pub async fn get_channel(&self, peer_id: &str) -> Option<Arc<dyn PeerLink>> {
413 self.peers
414 .read()
415 .await
416 .get(peer_id)
417 .map(|e| e.channel.clone())
418 }
419
420 pub async fn remove_peer(&self, peer_id: &str) -> Option<Arc<dyn PeerLink>> {
422 self.pending_offers.write().await.remove(peer_id);
423 self.peer_roots.write().await.remove(peer_id);
424 let _ = self.conn_factory.remove_peer(peer_id).await;
425 self.peers
426 .write()
427 .await
428 .remove(peer_id)
429 .map(|entry| entry.channel)
430 }
431
432 pub async fn needs_peers(&self) -> bool {
434 let (follows, other) = self.count_pools().await;
435 self.pools.follows.needs_peers(follows) || self.pools.other.needs_peers(other)
436 }
437
438 pub async fn can_accept(&self) -> bool {
440 let (follows, other) = self.count_pools().await;
441 self.pools.follows.can_accept(follows) || self.pools.other.can_accept(other)
442 }
443}
444
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    use crate::types::{IceCandidate, PoolConfig, PoolSettings};

    /// Pool settings using the default limits for both pools.
    fn default_pools() -> PoolSettings {
        PoolSettings {
            follows: PoolConfig::default(),
            other: PoolConfig::default(),
        }
    }

    /// Error returned by factory operations these tests never expect to reach.
    fn unused_in_test<T>() -> Result<T, TransportError> {
        Err(TransportError::ConnectionFailed(
            "not used in this test".to_string(),
        ))
    }

    /// Transport that accepts every publish and never yields an incoming message.
    #[derive(Default)]
    struct NoopTransport;

    #[async_trait]
    impl SignalingTransport for NoopTransport {
        async fn connect(&self, _relays: &[String]) -> Result<(), TransportError> {
            Ok(())
        }

        async fn disconnect(&self) {}

        async fn publish(&self, _msg: SignalingMessage) -> Result<(), TransportError> {
            Ok(())
        }

        async fn recv(&self) -> Option<SignalingMessage> {
            None
        }

        fn try_recv(&self) -> Option<SignalingMessage> {
            None
        }

        fn peer_id(&self) -> &str {
            "local"
        }
    }

    /// Factory that records delivered ICE candidates and peer removals.
    #[derive(Default)]
    struct RecordingFactory {
        candidates: Mutex<Vec<(String, IceCandidate)>>,
        removed: Mutex<Vec<String>>,
    }

    #[async_trait]
    impl PeerLinkFactory for RecordingFactory {
        async fn create_offer(
            &self,
            _target_peer_id: &str,
        ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
            unused_in_test()
        }

        async fn accept_offer(
            &self,
            _from_peer_id: &str,
            _offer_sdp: &str,
        ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
            unused_in_test()
        }

        async fn handle_answer(
            &self,
            _target_peer_id: &str,
            _answer_sdp: &str,
        ) -> Result<Arc<dyn PeerLink>, TransportError> {
            unused_in_test()
        }

        async fn handle_candidate(
            &self,
            peer_id: &str,
            candidate: IceCandidate,
        ) -> Result<(), TransportError> {
            let mut seen = self.candidates.lock().await;
            seen.push((peer_id.to_owned(), candidate));
            Ok(())
        }

        async fn remove_peer(&self, peer_id: &str) -> Result<(), TransportError> {
            let mut removed = self.removed.lock().await;
            removed.push(peer_id.to_owned());
            Ok(())
        }
    }

    /// Builds a router for peer `"local"` that uses `factory` for its links.
    fn router_with(factory: Arc<RecordingFactory>) -> MeshRouter<NoopTransport, RecordingFactory> {
        MeshRouter::new(
            "local".to_string(),
            Arc::new(NoopTransport),
            factory,
            default_pools(),
            false,
        )
    }

    #[tokio::test]
    async fn routes_targeted_candidates_to_factory() {
        let factory = Arc::new(RecordingFactory::default());
        let router = router_with(Arc::clone(&factory));

        let msg = SignalingMessage::Candidate {
            peer_id: "remote:peer".to_string(),
            target_peer_id: "local".to_string(),
            candidate: "candidate:1".to_string(),
            sdp_m_line_index: Some(0),
            sdp_mid: Some("data".to_string()),
        };
        router
            .handle_message(msg)
            .await
            .expect("candidate should route");

        let seen = factory.candidates.lock().await;
        let recorded: Vec<(&str, &str)> = seen
            .iter()
            .map(|(from, ice)| (from.as_str(), ice.candidate.as_str()))
            .collect();
        assert_eq!(recorded, [("remote:peer", "candidate:1")]);
    }

    #[tokio::test]
    async fn remove_peer_cleans_factory_state() {
        let factory = Arc::new(RecordingFactory::default());
        let router = router_with(Arc::clone(&factory));

        assert!(router.remove_peer("remote:peer").await.is_none());

        let removed = factory.removed.lock().await;
        assert_eq!(*removed, ["remote:peer".to_string()]);
    }
}