// hashtree_network/signaling.rs
1//! Shared signaling logic for peer discovery and connection management.
2//!
3//! This module contains the core signaling logic used by both the default
4//! production mesh transport stack and simulation. It handles:
5//! - Hello broadcasts and discovery
6//! - Pool management (follows vs other peers)
7//! - Tie-breaking for connection initiation
8//! - Offer/answer flow coordination
9
10use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13
14use crate::transport::{PeerLink, PeerLinkFactory, SignalingTransport, TransportError};
15use crate::types::{is_polite_peer, ClassifyRequest, PeerPool, PoolSettings, SignalingMessage};
16
/// Peer entry with pool classification and channel
pub struct PeerEntry {
    /// Active negotiated link to this peer.
    pub channel: Arc<dyn PeerLink>,
    /// Which pool this peer counts against (follows vs other) for limit checks.
    pub pool: PeerPool,
}
22
/// Mesh router handles peer discovery and negotiated link establishment.
///
/// This is the shared routing logic between production transports and simulation.
/// It uses traits for signaling transport and negotiated link factories so the
/// same router can drive Nostr websockets, LAN buses, BLE, WebRTC, or mocks.
///
/// Uses the standard concurrent-offer "perfect negotiation" pattern:
/// - Both peers can send offers when they discover each other
/// - On collision (both sent offers), "polite" peer backs off and accepts incoming
/// - This ensures connections form even when one peer is satisfied but can accept
pub struct MeshRouter<R: SignalingTransport, F: PeerLinkFactory> {
    /// Our peer ID (pubkey format)
    peer_id: String,
    /// Relay transport for signaling (publish/receive of `SignalingMessage`s)
    transport: Arc<R>,
    /// Link factory for creating negotiated peer links
    conn_factory: Arc<F>,
    /// Connected peers, keyed by peer ID string
    peers: RwLock<HashMap<String, PeerEntry>>,
    /// Pending outbound offers (we sent offer, waiting for answer);
    /// used as a set — the `()` value carries no data
    pending_offers: RwLock<HashMap<String, ()>>,
    /// Pool settings (per-pool accept/need thresholds)
    pools: PoolSettings,
    /// Known peer roots from hello broadcasts (for future use)
    peer_roots: RwLock<HashMap<String, Vec<String>>>,
    /// Classifier channel (optional); when absent peers default to `Other`
    classifier_tx: Option<tokio::sync::mpsc::Sender<ClassifyRequest>>,
    /// Debug mode — enables stdout tracing of signaling decisions
    debug: bool,
}
53
54impl<R: SignalingTransport + 'static, F: PeerLinkFactory + 'static> MeshRouter<R, F> {
55    /// Create a new mesh router.
56    pub fn new(
57        peer_id: String,
58        transport: Arc<R>,
59        conn_factory: Arc<F>,
60        pools: PoolSettings,
61        debug: bool,
62    ) -> Self {
63        Self {
64            peer_id,
65            transport,
66            conn_factory,
67            peers: RwLock::new(HashMap::new()),
68            pending_offers: RwLock::new(HashMap::new()),
69            pools,
70            peer_roots: RwLock::new(HashMap::new()),
71            classifier_tx: None,
72            debug,
73        }
74    }
75
76    /// Set classifier for peer pool assignment
77    pub fn set_classifier(&mut self, tx: tokio::sync::mpsc::Sender<ClassifyRequest>) {
78        self.classifier_tx = Some(tx);
79    }
80
81    /// Get our peer ID
82    pub fn peer_id(&self) -> &str {
83        &self.peer_id
84    }
85
86    /// Send hello broadcast
87    pub async fn send_hello(&self, roots: Vec<String>) -> Result<(), TransportError> {
88        let msg = SignalingMessage::Hello {
89            peer_id: self.peer_id.clone(),
90            roots,
91        };
92        self.transport.publish(msg).await
93    }
94
95    /// Count peers by pool
96    async fn count_pools(&self) -> (usize, usize) {
97        let peers = self.peers.read().await;
98        let mut follows = 0;
99        let mut other = 0;
100        for entry in peers.values() {
101            match entry.pool {
102                PeerPool::Follows => follows += 1,
103                PeerPool::Other => other += 1,
104            }
105        }
106        (follows, other)
107    }
108
109    /// Classify a peer by pubkey
110    async fn classify_peer(&self, pubkey: &str) -> PeerPool {
111        if let Some(ref tx) = self.classifier_tx {
112            let (response_tx, response_rx) = tokio::sync::oneshot::channel();
113            let request = ClassifyRequest {
114                pubkey: pubkey.to_string(),
115                response: response_tx,
116            };
117            if tx.send(request).await.is_ok() {
118                if let Ok(pool) = response_rx.await {
119                    return pool;
120                }
121            }
122        }
123        PeerPool::Other
124    }
125
126    /// Check if we can accept a peer in a given pool
127    fn can_accept_peer(&self, pool: PeerPool, follows: usize, other: usize) -> bool {
128        match pool {
129            PeerPool::Follows => self.pools.follows.can_accept(follows),
130            PeerPool::Other => self.pools.other.can_accept(other),
131        }
132    }
133
134    /// Check if a pool needs more peers
135    fn pool_needs_peers(&self, pool: PeerPool, follows: usize, other: usize) -> bool {
136        match pool {
137            PeerPool::Follows => self.pools.follows.needs_peers(follows),
138            PeerPool::Other => self.pools.other.needs_peers(other),
139        }
140    }
141
142    /// Handle incoming signaling message
143    ///
144    /// This is the core signaling logic shared between production and simulation.
145    pub async fn handle_message(&self, msg: SignalingMessage) -> Result<(), TransportError> {
146        match &msg {
147            SignalingMessage::Hello { peer_id, roots } => self.handle_hello(peer_id, roots).await,
148            SignalingMessage::Offer {
149                peer_id,
150                target_peer_id,
151                sdp,
152            } => {
153                if target_peer_id == &self.peer_id {
154                    self.handle_offer(peer_id, sdp).await
155                } else {
156                    Ok(()) // Not for us
157                }
158            }
159            SignalingMessage::Answer {
160                peer_id,
161                target_peer_id,
162                sdp,
163            } => {
164                if target_peer_id == &self.peer_id {
165                    self.handle_answer(peer_id, sdp).await
166                } else {
167                    Ok(()) // Not for us
168                }
169            }
170            SignalingMessage::Candidate {
171                peer_id,
172                target_peer_id,
173                candidate,
174                sdp_m_line_index,
175                sdp_mid,
176            } => {
177                if target_peer_id == &self.peer_id {
178                    self.conn_factory
179                        .handle_candidate(
180                            peer_id,
181                            crate::types::IceCandidate {
182                                candidate: candidate.clone(),
183                                sdp_m_line_index: *sdp_m_line_index,
184                                sdp_mid: sdp_mid.clone(),
185                            },
186                        )
187                        .await
188                } else {
189                    Ok(())
190                }
191            }
192            SignalingMessage::Candidates {
193                peer_id,
194                target_peer_id,
195                candidates,
196            } => {
197                if target_peer_id == &self.peer_id {
198                    self.conn_factory
199                        .handle_candidates(peer_id, candidates.clone())
200                        .await
201                } else {
202                    Ok(())
203                }
204            }
205        }
206    }
207
208    /// Handle hello message using the shared concurrent-offer negotiation flow.
209    ///
210    /// With perfect negotiation, we send an offer if we need peers.
211    /// No tie-breaking here - collisions are handled in handle_offer.
212    async fn handle_hello(
213        &self,
214        from_peer_id: &str,
215        roots: &[String],
216    ) -> Result<(), TransportError> {
217        // Ignore our own hello
218        if from_peer_id == self.peer_id {
219            return Ok(());
220        }
221
222        let peer_pubkey = crate::types::PeerId::from_peer_string(from_peer_id)
223            .map(|peer_id| peer_id.pubkey)
224            .unwrap_or_else(|| from_peer_id.to_string());
225
226        // Classify the peer
227        let pool = self.classify_peer(&peer_pubkey).await;
228
229        // Check pool limits
230        let (follows_count, other_count) = self.count_pools().await;
231
232        if !self.can_accept_peer(pool, follows_count, other_count) {
233            if self.debug {
234                println!(
235                    "[Signaling] Ignoring hello from {} - {:?} pool full",
236                    from_peer_id, pool
237                );
238            }
239            return Ok(());
240        }
241
242        // Store peer roots
243        self.peer_roots
244            .write()
245            .await
246            .insert(from_peer_id.to_string(), roots.to_vec());
247
248        // Shared perfect negotiation: send offer if we NEED more peers
249        // Both sides may send offers - collision handled in handle_offer
250        if self.pool_needs_peers(pool, follows_count, other_count) {
251            // Check if already connected or pending
252            if self.peers.read().await.contains_key(from_peer_id) {
253                return Ok(());
254            }
255            if self.pending_offers.read().await.contains_key(from_peer_id) {
256                return Ok(());
257            }
258
259            if self.debug {
260                println!(
261                    "[Signaling] Sending offer to {} (pool: {:?})",
262                    from_peer_id, pool
263                );
264            }
265
266            // Mark as pending before creating offer
267            self.pending_offers
268                .write()
269                .await
270                .insert(from_peer_id.to_string(), ());
271
272            // Create offer
273            let (channel, sdp) = self.conn_factory.create_offer(from_peer_id).await?;
274
275            // Add peer (will be confirmed when we get answer)
276            self.peers
277                .write()
278                .await
279                .insert(from_peer_id.to_string(), PeerEntry { channel, pool });
280
281            // Send offer
282            let offer_msg = SignalingMessage::Offer {
283                peer_id: self.peer_id.clone(),
284                target_peer_id: from_peer_id.to_string(),
285                sdp,
286            };
287            self.transport.publish(offer_msg).await?;
288        }
289
290        Ok(())
291    }
292
293    /// Handle offer message in the shared concurrent-offer negotiation flow.
294    ///
295    /// Handles offer collision: if we also sent an offer to this peer,
296    /// the "polite" peer (lower ID) backs off and accepts the incoming offer.
297    async fn handle_offer(&self, from_peer_id: &str, sdp: &str) -> Result<(), TransportError> {
298        // Extract pubkey
299        let peer_pubkey = crate::types::PeerId::from_peer_string(from_peer_id)
300            .map(|peer_id| peer_id.pubkey)
301            .unwrap_or_else(|| from_peer_id.to_string());
302
303        // Classify and check limits
304        let pool = self.classify_peer(&peer_pubkey).await;
305        let (follows_count, other_count) = self.count_pools().await;
306
307        if !self.can_accept_peer(pool, follows_count, other_count) {
308            if self.debug {
309                println!(
310                    "[Signaling] Ignoring offer from {} - {:?} pool full",
311                    from_peer_id, pool
312                );
313            }
314            return Ok(());
315        }
316
317        // Check for offer collision (we also sent an offer to them)
318        let have_pending = self.pending_offers.read().await.contains_key(from_peer_id);
319        if have_pending {
320            // Collision! Use polite/impolite pattern
321            let we_are_polite = is_polite_peer(&self.peer_id, from_peer_id);
322
323            if we_are_polite {
324                // We're polite - back off, accept their offer
325                // Remove our pending offer and peer entry
326                self.pending_offers.write().await.remove(from_peer_id);
327                self.peers.write().await.remove(from_peer_id);
328
329                if self.debug {
330                    println!(
331                        "[Signaling] Collision with {} - we're polite, accepting their offer",
332                        from_peer_id
333                    );
334                }
335            } else {
336                // We're impolite - ignore their offer, wait for answer to ours
337                if self.debug {
338                    println!(
339                        "[Signaling] Collision with {} - we're impolite, ignoring their offer",
340                        from_peer_id
341                    );
342                }
343                return Ok(());
344            }
345        }
346
347        // Check if already connected (no collision case)
348        if self.peers.read().await.contains_key(from_peer_id) {
349            return Ok(());
350        }
351
352        if self.debug {
353            println!("[Signaling] Accepting offer from {}", from_peer_id);
354        }
355
356        // Accept offer
357        let (channel, answer_sdp) = self.conn_factory.accept_offer(from_peer_id, sdp).await?;
358
359        // Add peer
360        self.peers
361            .write()
362            .await
363            .insert(from_peer_id.to_string(), PeerEntry { channel, pool });
364
365        // Send answer
366        let answer_msg = SignalingMessage::Answer {
367            peer_id: self.peer_id.clone(),
368            target_peer_id: from_peer_id.to_string(),
369            sdp: answer_sdp,
370        };
371        self.transport.publish(answer_msg).await?;
372
373        Ok(())
374    }
375
376    /// Handle answer message
377    async fn handle_answer(&self, from_peer_id: &str, sdp: &str) -> Result<(), TransportError> {
378        if self.debug {
379            println!("[Signaling] Received answer from {}", from_peer_id);
380        }
381
382        // Complete connection
383        let _channel = self.conn_factory.handle_answer(from_peer_id, sdp).await?;
384
385        // Peer should already be in our map from when we sent the offer
386        // The channel returned here is the same one we stored
387
388        Ok(())
389    }
390
391    /// Get connected peer count
392    pub async fn peer_count(&self) -> usize {
393        self.peers.read().await.len()
394    }
395
396    /// Get peer IDs
397    pub async fn peer_ids(&self) -> Vec<String> {
398        self.peers.read().await.keys().cloned().collect()
399    }
400
401    /// Get a peer's channel
402    pub async fn get_channel(&self, peer_id: &str) -> Option<Arc<dyn PeerLink>> {
403        self.peers
404            .read()
405            .await
406            .get(peer_id)
407            .map(|e| e.channel.clone())
408    }
409
410    /// Remove a peer and any pending offer state.
411    pub async fn remove_peer(&self, peer_id: &str) -> Option<Arc<dyn PeerLink>> {
412        self.pending_offers.write().await.remove(peer_id);
413        self.peer_roots.write().await.remove(peer_id);
414        let _ = self.conn_factory.remove_peer(peer_id).await;
415        self.peers
416            .write()
417            .await
418            .remove(peer_id)
419            .map(|entry| entry.channel)
420    }
421
422    /// Check if we need more peers (below satisfied in any pool)
423    pub async fn needs_peers(&self) -> bool {
424        let (follows, other) = self.count_pools().await;
425        self.pools.follows.needs_peers(follows) || self.pools.other.needs_peers(other)
426    }
427
428    /// Check if we can accept more peers (below max in any pool)
429    pub async fn can_accept(&self) -> bool {
430        let (follows, other) = self.count_pools().await;
431        self.pools.follows.can_accept(follows) || self.pools.other.can_accept(other)
432    }
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    use crate::types::{IceCandidate, PoolConfig, PoolSettings};

    /// Pool settings with default limits for both pools.
    fn default_pools() -> PoolSettings {
        PoolSettings {
            follows: PoolConfig::default(),
            other: PoolConfig::default(),
        }
    }

    /// Signaling transport stub: publishes succeed, receives yield nothing.
    #[derive(Default)]
    struct NoopTransport;

    #[async_trait]
    impl SignalingTransport for NoopTransport {
        async fn connect(&self, _relays: &[String]) -> Result<(), TransportError> {
            Ok(())
        }

        async fn disconnect(&self) {}

        async fn publish(&self, _msg: SignalingMessage) -> Result<(), TransportError> {
            Ok(())
        }

        async fn recv(&self) -> Option<SignalingMessage> {
            None
        }

        fn try_recv(&self) -> Option<SignalingMessage> {
            None
        }

        fn peer_id(&self) -> &str {
            "local"
        }
    }

    /// Link factory stub that records candidate/removal calls and refuses
    /// to actually negotiate a link.
    #[derive(Default)]
    struct RecordingFactory {
        candidates: Mutex<Vec<(String, IceCandidate)>>,
        removed: Mutex<Vec<String>>,
    }

    #[async_trait]
    impl PeerLinkFactory for RecordingFactory {
        async fn create_offer(
            &self,
            _target_peer_id: &str,
        ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
            Err(TransportError::ConnectionFailed(
                "not used in this test".to_string(),
            ))
        }

        async fn accept_offer(
            &self,
            _from_peer_id: &str,
            _offer_sdp: &str,
        ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
            Err(TransportError::ConnectionFailed(
                "not used in this test".to_string(),
            ))
        }

        async fn handle_answer(
            &self,
            _target_peer_id: &str,
            _answer_sdp: &str,
        ) -> Result<Arc<dyn PeerLink>, TransportError> {
            Err(TransportError::ConnectionFailed(
                "not used in this test".to_string(),
            ))
        }

        async fn handle_candidate(
            &self,
            peer_id: &str,
            candidate: IceCandidate,
        ) -> Result<(), TransportError> {
            self.candidates
                .lock()
                .await
                .push((peer_id.to_string(), candidate));
            Ok(())
        }

        async fn remove_peer(&self, peer_id: &str) -> Result<(), TransportError> {
            self.removed.lock().await.push(peer_id.to_string());
            Ok(())
        }
    }

    #[tokio::test]
    async fn routes_targeted_candidates_to_factory() {
        let factory = Arc::new(RecordingFactory::default());
        let router = MeshRouter::new(
            "local".to_string(),
            Arc::new(NoopTransport),
            factory.clone(),
            default_pools(),
            false,
        );

        let msg = SignalingMessage::Candidate {
            peer_id: "remote:peer".to_string(),
            target_peer_id: "local".to_string(),
            candidate: "candidate:1".to_string(),
            sdp_m_line_index: Some(0),
            sdp_mid: Some("data".to_string()),
        };
        router
            .handle_message(msg)
            .await
            .expect("candidate should route");

        // The targeted candidate must land in the factory, attributed to
        // the sending peer.
        let recorded = factory.candidates.lock().await;
        assert_eq!(recorded.len(), 1);
        assert_eq!(recorded[0].0, "remote:peer");
        assert_eq!(recorded[0].1.candidate, "candidate:1");
    }

    #[tokio::test]
    async fn remove_peer_cleans_factory_state() {
        let factory = Arc::new(RecordingFactory::default());
        let router = MeshRouter::new(
            "local".to_string(),
            Arc::new(NoopTransport),
            factory.clone(),
            default_pools(),
            false,
        );

        // Unknown peer: no channel to return, but the factory is still told
        // to clean up.
        assert!(router.remove_peer("remote:peer").await.is_none());
        assert_eq!(
            factory.removed.lock().await.as_slice(),
            &["remote:peer".to_string()]
        );
    }
}