Skip to main content

hashtree_network/
signaling.rs

1//! Shared signaling logic for peer discovery and connection management.
2//!
3//! This module contains the core signaling logic used by both the default
4//! production mesh transport stack and simulation. It handles:
5//! - Hello broadcasts and discovery
6//! - Pool management (follows vs other peers)
7//! - Tie-breaking for connection initiation
8//! - Offer/answer flow coordination
9
10use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13
14use crate::transport::{PeerLink, PeerLinkFactory, SignalingTransport, TransportError};
15use crate::types::{is_polite_peer, ClassifyRequest, PeerPool, PoolSettings, SignalingMessage};
16
17/// Peer entry with pool classification and channel
18pub struct PeerEntry {
19    pub channel: Arc<dyn PeerLink>,
20    pub pool: PeerPool,
21}
22
23/// Mesh router handles peer discovery and negotiated link establishment.
24///
25/// This is the shared routing logic between production transports and simulation.
26/// It uses traits for signaling transport and negotiated link factories so the
27/// same router can drive Nostr websockets, LAN buses, BLE, WebRTC, or mocks.
28///
29/// Uses the standard concurrent-offer "perfect negotiation" pattern:
30/// - Both peers can send offers when they discover each other
31/// - On collision (both sent offers), "polite" peer backs off and accepts incoming
32/// - This ensures connections form even when one peer is satisfied but can accept
33pub struct MeshRouter<R: SignalingTransport, F: PeerLinkFactory> {
34    /// Our peer ID (pubkey format)
35    peer_id: String,
36    /// Relay transport for signaling
37    transport: Arc<R>,
38    /// Link factory for creating negotiated peer links
39    conn_factory: Arc<F>,
40    /// Connected peers
41    peers: RwLock<HashMap<String, PeerEntry>>,
42    /// Pending outbound offers (we sent offer, waiting for answer)
43    pending_offers: RwLock<HashMap<String, ()>>,
44    /// Pool settings
45    pools: PoolSettings,
46    /// Known peer roots (for future use)
47    peer_roots: RwLock<HashMap<String, Vec<String>>>,
48    /// Classifier channel (optional)
49    classifier_tx: Option<tokio::sync::mpsc::Sender<ClassifyRequest>>,
50    /// Debug mode
51    debug: bool,
52    /// Whether local node accepts `hash_get` lookups.
53    hash_get_enabled: bool,
54}
55
56impl<R: SignalingTransport + 'static, F: PeerLinkFactory + 'static> MeshRouter<R, F> {
57    /// Create a new mesh router.
58    pub fn new(
59        peer_id: String,
60        transport: Arc<R>,
61        conn_factory: Arc<F>,
62        pools: PoolSettings,
63        debug: bool,
64    ) -> Self {
65        Self {
66            peer_id,
67            transport,
68            conn_factory,
69            peers: RwLock::new(HashMap::new()),
70            pending_offers: RwLock::new(HashMap::new()),
71            pools,
72            peer_roots: RwLock::new(HashMap::new()),
73            classifier_tx: None,
74            debug,
75            hash_get_enabled: true,
76        }
77    }
78
79    /// Set classifier for peer pool assignment
80    pub fn set_classifier(&mut self, tx: tokio::sync::mpsc::Sender<ClassifyRequest>) {
81        self.classifier_tx = Some(tx);
82    }
83
84    pub fn set_hash_get_enabled(&mut self, enabled: bool) {
85        self.hash_get_enabled = enabled;
86    }
87
88    /// Get our peer ID
89    pub fn peer_id(&self) -> &str {
90        &self.peer_id
91    }
92
93    /// Send hello broadcast
94    pub async fn send_hello(&self, roots: Vec<String>) -> Result<(), TransportError> {
95        let msg = SignalingMessage::Hello {
96            peer_id: self.peer_id.clone(),
97            roots,
98            hash_get: self.hash_get_enabled,
99        };
100        self.transport.publish(msg).await
101    }
102
103    /// Count peers by pool
104    async fn count_pools(&self) -> (usize, usize) {
105        let peers = self.peers.read().await;
106        let mut follows = 0;
107        let mut other = 0;
108        for entry in peers.values() {
109            match entry.pool {
110                PeerPool::Follows => follows += 1,
111                PeerPool::Other => other += 1,
112            }
113        }
114        (follows, other)
115    }
116
117    /// Classify a peer by pubkey
118    async fn classify_peer(&self, pubkey: &str) -> PeerPool {
119        if let Some(ref tx) = self.classifier_tx {
120            let (response_tx, response_rx) = tokio::sync::oneshot::channel();
121            let request = ClassifyRequest {
122                pubkey: pubkey.to_string(),
123                response: response_tx,
124            };
125            if tx.send(request).await.is_ok() {
126                if let Ok(pool) = response_rx.await {
127                    return pool;
128                }
129            }
130        }
131        PeerPool::Other
132    }
133
134    /// Check if we can accept a peer in a given pool
135    fn can_accept_peer(&self, pool: PeerPool, follows: usize, other: usize) -> bool {
136        match pool {
137            PeerPool::Follows => self.pools.follows.can_accept(follows),
138            PeerPool::Other => self.pools.other.can_accept(other),
139        }
140    }
141
142    /// Check if a pool needs more peers
143    fn pool_needs_peers(&self, pool: PeerPool, follows: usize, other: usize) -> bool {
144        match pool {
145            PeerPool::Follows => self.pools.follows.needs_peers(follows),
146            PeerPool::Other => self.pools.other.needs_peers(other),
147        }
148    }
149
150    /// Handle incoming signaling message
151    ///
152    /// This is the core signaling logic shared between production and simulation.
153    pub async fn handle_message(&self, msg: SignalingMessage) -> Result<(), TransportError> {
154        match &msg {
155            SignalingMessage::Hello { peer_id, roots, .. } => {
156                self.handle_hello(peer_id, roots).await
157            }
158            SignalingMessage::Offer {
159                peer_id,
160                target_peer_id,
161                sdp,
162            } => {
163                if target_peer_id == &self.peer_id {
164                    self.handle_offer(peer_id, sdp).await
165                } else {
166                    Ok(()) // Not for us
167                }
168            }
169            SignalingMessage::Answer {
170                peer_id,
171                target_peer_id,
172                sdp,
173            } => {
174                if target_peer_id == &self.peer_id {
175                    self.handle_answer(peer_id, sdp).await
176                } else {
177                    Ok(()) // Not for us
178                }
179            }
180            SignalingMessage::Candidate {
181                peer_id,
182                target_peer_id,
183                candidate,
184                sdp_m_line_index,
185                sdp_mid,
186            } => {
187                if target_peer_id == &self.peer_id {
188                    self.conn_factory
189                        .handle_candidate(
190                            peer_id,
191                            crate::types::IceCandidate {
192                                candidate: candidate.clone(),
193                                sdp_m_line_index: *sdp_m_line_index,
194                                sdp_mid: sdp_mid.clone(),
195                            },
196                        )
197                        .await
198                } else {
199                    Ok(())
200                }
201            }
202            SignalingMessage::Candidates {
203                peer_id,
204                target_peer_id,
205                candidates,
206            } => {
207                if target_peer_id == &self.peer_id {
208                    self.conn_factory
209                        .handle_candidates(peer_id, candidates.clone())
210                        .await
211                } else {
212                    Ok(())
213                }
214            }
215        }
216    }
217
218    /// Handle hello message using the shared concurrent-offer negotiation flow.
219    ///
220    /// With perfect negotiation, we send an offer if we need peers.
221    /// No tie-breaking here - collisions are handled in handle_offer.
222    async fn handle_hello(
223        &self,
224        from_peer_id: &str,
225        roots: &[String],
226    ) -> Result<(), TransportError> {
227        // Ignore our own hello
228        if from_peer_id == self.peer_id {
229            return Ok(());
230        }
231
232        let peer_pubkey = crate::types::PeerId::from_peer_string(from_peer_id)
233            .map(|peer_id| peer_id.pubkey)
234            .unwrap_or_else(|| from_peer_id.to_string());
235
236        // Classify the peer
237        let pool = self.classify_peer(&peer_pubkey).await;
238
239        // Check pool limits
240        let (follows_count, other_count) = self.count_pools().await;
241
242        if !self.can_accept_peer(pool, follows_count, other_count) {
243            if self.debug {
244                println!(
245                    "[Signaling] Ignoring hello from {} - {:?} pool full",
246                    from_peer_id, pool
247                );
248            }
249            return Ok(());
250        }
251
252        // Store peer roots
253        self.peer_roots
254            .write()
255            .await
256            .insert(from_peer_id.to_string(), roots.to_vec());
257
258        // Shared perfect negotiation: send offer if we NEED more peers
259        // Both sides may send offers - collision handled in handle_offer
260        if self.pool_needs_peers(pool, follows_count, other_count) {
261            // Check if already connected or pending
262            if self.peers.read().await.contains_key(from_peer_id) {
263                return Ok(());
264            }
265            if self.pending_offers.read().await.contains_key(from_peer_id) {
266                return Ok(());
267            }
268
269            if self.debug {
270                println!(
271                    "[Signaling] Sending offer to {} (pool: {:?})",
272                    from_peer_id, pool
273                );
274            }
275
276            // Mark as pending before creating offer
277            self.pending_offers
278                .write()
279                .await
280                .insert(from_peer_id.to_string(), ());
281
282            // Create offer
283            let (channel, sdp) = self.conn_factory.create_offer(from_peer_id).await?;
284
285            // Add peer (will be confirmed when we get answer)
286            self.peers
287                .write()
288                .await
289                .insert(from_peer_id.to_string(), PeerEntry { channel, pool });
290
291            // Send offer
292            let offer_msg = SignalingMessage::Offer {
293                peer_id: self.peer_id.clone(),
294                target_peer_id: from_peer_id.to_string(),
295                sdp,
296            };
297            self.transport.publish(offer_msg).await?;
298        }
299
300        Ok(())
301    }
302
303    /// Handle offer message in the shared concurrent-offer negotiation flow.
304    ///
305    /// Handles offer collision: if we also sent an offer to this peer,
306    /// the "polite" peer (lower ID) backs off and accepts the incoming offer.
307    async fn handle_offer(&self, from_peer_id: &str, sdp: &str) -> Result<(), TransportError> {
308        // Extract pubkey
309        let peer_pubkey = crate::types::PeerId::from_peer_string(from_peer_id)
310            .map(|peer_id| peer_id.pubkey)
311            .unwrap_or_else(|| from_peer_id.to_string());
312
313        // Classify and check limits
314        let pool = self.classify_peer(&peer_pubkey).await;
315        let (follows_count, other_count) = self.count_pools().await;
316
317        if !self.can_accept_peer(pool, follows_count, other_count) {
318            if self.debug {
319                println!(
320                    "[Signaling] Ignoring offer from {} - {:?} pool full",
321                    from_peer_id, pool
322                );
323            }
324            return Ok(());
325        }
326
327        // Check for offer collision (we also sent an offer to them)
328        let have_pending = self.pending_offers.read().await.contains_key(from_peer_id);
329        if have_pending {
330            // Collision! Use polite/impolite pattern
331            let we_are_polite = is_polite_peer(&self.peer_id, from_peer_id);
332
333            if we_are_polite {
334                // We're polite - back off, accept their offer
335                // Remove our pending offer and peer entry
336                self.pending_offers.write().await.remove(from_peer_id);
337                self.peers.write().await.remove(from_peer_id);
338
339                if self.debug {
340                    println!(
341                        "[Signaling] Collision with {} - we're polite, accepting their offer",
342                        from_peer_id
343                    );
344                }
345            } else {
346                // We're impolite - ignore their offer, wait for answer to ours
347                if self.debug {
348                    println!(
349                        "[Signaling] Collision with {} - we're impolite, ignoring their offer",
350                        from_peer_id
351                    );
352                }
353                return Ok(());
354            }
355        }
356
357        // Check if already connected (no collision case)
358        if self.peers.read().await.contains_key(from_peer_id) {
359            return Ok(());
360        }
361
362        if self.debug {
363            println!("[Signaling] Accepting offer from {}", from_peer_id);
364        }
365
366        // Accept offer
367        let (channel, answer_sdp) = self.conn_factory.accept_offer(from_peer_id, sdp).await?;
368
369        // Add peer
370        self.peers
371            .write()
372            .await
373            .insert(from_peer_id.to_string(), PeerEntry { channel, pool });
374
375        // Send answer
376        let answer_msg = SignalingMessage::Answer {
377            peer_id: self.peer_id.clone(),
378            target_peer_id: from_peer_id.to_string(),
379            sdp: answer_sdp,
380        };
381        self.transport.publish(answer_msg).await?;
382
383        Ok(())
384    }
385
386    /// Handle answer message
387    async fn handle_answer(&self, from_peer_id: &str, sdp: &str) -> Result<(), TransportError> {
388        if self.debug {
389            println!("[Signaling] Received answer from {}", from_peer_id);
390        }
391
392        // Complete connection
393        let _channel = self.conn_factory.handle_answer(from_peer_id, sdp).await?;
394
395        // Peer should already be in our map from when we sent the offer
396        // The channel returned here is the same one we stored
397
398        Ok(())
399    }
400
401    /// Get connected peer count
402    pub async fn peer_count(&self) -> usize {
403        self.peers.read().await.len()
404    }
405
406    /// Get peer IDs
407    pub async fn peer_ids(&self) -> Vec<String> {
408        self.peers.read().await.keys().cloned().collect()
409    }
410
411    /// Get a peer's channel
412    pub async fn get_channel(&self, peer_id: &str) -> Option<Arc<dyn PeerLink>> {
413        self.peers
414            .read()
415            .await
416            .get(peer_id)
417            .map(|e| e.channel.clone())
418    }
419
420    /// Remove a peer and any pending offer state.
421    pub async fn remove_peer(&self, peer_id: &str) -> Option<Arc<dyn PeerLink>> {
422        self.pending_offers.write().await.remove(peer_id);
423        self.peer_roots.write().await.remove(peer_id);
424        let _ = self.conn_factory.remove_peer(peer_id).await;
425        self.peers
426            .write()
427            .await
428            .remove(peer_id)
429            .map(|entry| entry.channel)
430    }
431
432    /// Check if we need more peers (below satisfied in any pool)
433    pub async fn needs_peers(&self) -> bool {
434        let (follows, other) = self.count_pools().await;
435        self.pools.follows.needs_peers(follows) || self.pools.other.needs_peers(other)
436    }
437
438    /// Check if we can accept more peers (below max in any pool)
439    pub async fn can_accept(&self) -> bool {
440        let (follows, other) = self.count_pools().await;
441        self.pools.follows.can_accept(follows) || self.pools.other.can_accept(other)
442    }
443}
444
445#[cfg(test)]
446mod tests {
447    use super::*;
448    use async_trait::async_trait;
449    use std::sync::Arc;
450    use tokio::sync::Mutex;
451
452    use crate::types::{IceCandidate, PoolConfig, PoolSettings};
453
454    #[derive(Default)]
455    struct NoopTransport;
456
457    #[async_trait]
458    impl SignalingTransport for NoopTransport {
459        async fn connect(&self, _relays: &[String]) -> Result<(), TransportError> {
460            Ok(())
461        }
462
463        async fn disconnect(&self) {}
464
465        async fn publish(&self, _msg: SignalingMessage) -> Result<(), TransportError> {
466            Ok(())
467        }
468
469        async fn recv(&self) -> Option<SignalingMessage> {
470            None
471        }
472
473        fn try_recv(&self) -> Option<SignalingMessage> {
474            None
475        }
476
477        fn peer_id(&self) -> &str {
478            "local"
479        }
480    }
481
482    #[derive(Default)]
483    struct RecordingFactory {
484        candidates: Mutex<Vec<(String, IceCandidate)>>,
485        removed: Mutex<Vec<String>>,
486    }
487
488    #[async_trait]
489    impl PeerLinkFactory for RecordingFactory {
490        async fn create_offer(
491            &self,
492            _target_peer_id: &str,
493        ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
494            Err(TransportError::ConnectionFailed(
495                "not used in this test".to_string(),
496            ))
497        }
498
499        async fn accept_offer(
500            &self,
501            _from_peer_id: &str,
502            _offer_sdp: &str,
503        ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
504            Err(TransportError::ConnectionFailed(
505                "not used in this test".to_string(),
506            ))
507        }
508
509        async fn handle_answer(
510            &self,
511            _target_peer_id: &str,
512            _answer_sdp: &str,
513        ) -> Result<Arc<dyn PeerLink>, TransportError> {
514            Err(TransportError::ConnectionFailed(
515                "not used in this test".to_string(),
516            ))
517        }
518
519        async fn handle_candidate(
520            &self,
521            peer_id: &str,
522            candidate: IceCandidate,
523        ) -> Result<(), TransportError> {
524            self.candidates
525                .lock()
526                .await
527                .push((peer_id.to_string(), candidate));
528            Ok(())
529        }
530
531        async fn remove_peer(&self, peer_id: &str) -> Result<(), TransportError> {
532            self.removed.lock().await.push(peer_id.to_string());
533            Ok(())
534        }
535    }
536
537    #[tokio::test]
538    async fn routes_targeted_candidates_to_factory() {
539        let router = MeshRouter::new(
540            "local".to_string(),
541            Arc::new(NoopTransport),
542            Arc::new(RecordingFactory::default()),
543            PoolSettings {
544                follows: PoolConfig::default(),
545                other: PoolConfig::default(),
546            },
547            false,
548        );
549
550        router
551            .handle_message(SignalingMessage::Candidate {
552                peer_id: "remote:peer".to_string(),
553                target_peer_id: "local".to_string(),
554                candidate: "candidate:1".to_string(),
555                sdp_m_line_index: Some(0),
556                sdp_mid: Some("data".to_string()),
557            })
558            .await
559            .expect("candidate should route");
560
561        let factory = router.conn_factory.clone();
562        let recorded = factory
563            .candidates
564            .lock()
565            .await
566            .iter()
567            .map(|(peer_id, candidate)| (peer_id.clone(), candidate.candidate.clone()))
568            .collect::<Vec<_>>();
569
570        assert_eq!(
571            recorded,
572            vec![("remote:peer".to_string(), "candidate:1".to_string())]
573        );
574    }
575
576    #[tokio::test]
577    async fn remove_peer_cleans_factory_state() {
578        let factory = Arc::new(RecordingFactory::default());
579        let router = MeshRouter::new(
580            "local".to_string(),
581            Arc::new(NoopTransport),
582            factory.clone(),
583            PoolSettings {
584                follows: PoolConfig::default(),
585                other: PoolConfig::default(),
586            },
587            false,
588        );
589
590        let removed = router.remove_peer("remote:peer").await;
591        assert!(removed.is_none());
592        assert_eq!(
593            factory.removed.lock().await.as_slice(),
594            &["remote:peer".to_string()]
595        );
596    }
597}