// hashtree_network/signaling.rs

//! Shared signaling logic for peer discovery and connection management.
//!
//! This module contains the core signaling logic used by both the default
//! production mesh transport stack and simulation. It handles:
//! - Hello broadcasts and discovery
//! - Pool management (follows vs other peers)
//! - Polite/impolite tie-breaking when both peers send offers at once
//! - Offer/answer flow coordination

10use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13
14use crate::transport::{PeerLink, PeerLinkFactory, SignalingTransport, TransportError};
15use crate::types::{is_polite_peer, ClassifyRequest, PeerPool, PoolSettings, SignalingMessage};
16
17/// Peer entry with pool classification and channel
18pub struct PeerEntry {
19    pub channel: Arc<dyn PeerLink>,
20    pub pool: PeerPool,
21    pub hash_get: bool,
22}
23
24/// Mesh router handles peer discovery and negotiated link establishment.
25///
26/// This is the shared routing logic between production transports and simulation.
27/// It uses traits for signaling transport and negotiated link factories so the
28/// same router can drive Nostr websockets, LAN buses, BLE, WebRTC, or mocks.
29///
30/// Uses the standard concurrent-offer "perfect negotiation" pattern:
31/// - Both peers can send offers when they discover each other
32/// - On collision (both sent offers), "polite" peer backs off and accepts incoming
33/// - This ensures connections form even when one peer is satisfied but can accept
34pub struct MeshRouter<R: SignalingTransport, F: PeerLinkFactory> {
35    /// Our peer ID (pubkey format)
36    peer_id: String,
37    /// Relay transport for signaling
38    transport: Arc<R>,
39    /// Link factory for creating negotiated peer links
40    conn_factory: Arc<F>,
41    /// Connected peers
42    peers: RwLock<HashMap<String, PeerEntry>>,
43    /// Pending outbound offers (we sent offer, waiting for answer)
44    pending_offers: RwLock<HashMap<String, ()>>,
45    /// Pool settings
46    pools: PoolSettings,
47    /// Known peer roots (for future use)
48    peer_roots: RwLock<HashMap<String, Vec<String>>>,
49    /// Last advertised `hash_get` capability per peer.
50    peer_hash_get: RwLock<HashMap<String, bool>>,
51    /// Classifier channel (optional)
52    classifier_tx: Option<tokio::sync::mpsc::Sender<ClassifyRequest>>,
53    /// Debug mode
54    debug: bool,
55    /// Whether local node accepts `hash_get` lookups.
56    hash_get_enabled: bool,
57}
58
59impl<R: SignalingTransport + 'static, F: PeerLinkFactory + 'static> MeshRouter<R, F> {
60    /// Create a new mesh router.
61    pub fn new(
62        peer_id: String,
63        transport: Arc<R>,
64        conn_factory: Arc<F>,
65        pools: PoolSettings,
66        debug: bool,
67    ) -> Self {
68        Self {
69            peer_id,
70            transport,
71            conn_factory,
72            peers: RwLock::new(HashMap::new()),
73            pending_offers: RwLock::new(HashMap::new()),
74            pools,
75            peer_roots: RwLock::new(HashMap::new()),
76            peer_hash_get: RwLock::new(HashMap::new()),
77            classifier_tx: None,
78            debug,
79            hash_get_enabled: true,
80        }
81    }
82
83    /// Set classifier for peer pool assignment
84    pub fn set_classifier(&mut self, tx: tokio::sync::mpsc::Sender<ClassifyRequest>) {
85        self.classifier_tx = Some(tx);
86    }
87
88    pub fn set_hash_get_enabled(&mut self, enabled: bool) {
89        self.hash_get_enabled = enabled;
90    }
91
92    pub async fn set_peer_hash_get(&self, peer_id: &str, enabled: bool) {
93        self.peer_hash_get
94            .write()
95            .await
96            .insert(peer_id.to_string(), enabled);
97        if let Some(entry) = self.peers.write().await.get_mut(peer_id) {
98            entry.hash_get = enabled;
99        }
100    }
101
102    pub async fn peer_supports_hash_get(&self, peer_id: &str) -> bool {
103        self.peer_hash_get
104            .read()
105            .await
106            .get(peer_id)
107            .copied()
108            .unwrap_or(true)
109    }
110
111    pub async fn hash_get_peer_ids(&self) -> Vec<String> {
112        let peers = self.peers.read().await;
113        let peer_hash_get = self.peer_hash_get.read().await;
114        peers
115            .keys()
116            .filter(|peer_id| peer_hash_get.get(*peer_id).copied().unwrap_or(true))
117            .cloned()
118            .collect()
119    }
120
121    /// Get our peer ID
122    pub fn peer_id(&self) -> &str {
123        &self.peer_id
124    }
125
126    /// Send hello broadcast
127    pub async fn send_hello(&self, roots: Vec<String>) -> Result<(), TransportError> {
128        let msg = SignalingMessage::Hello {
129            peer_id: self.peer_id.clone(),
130            roots,
131            hash_get: self.hash_get_enabled,
132        };
133        self.transport.publish(msg).await
134    }
135
136    /// Count peers by pool
137    async fn count_pools(&self) -> (usize, usize) {
138        let peers = self.peers.read().await;
139        let mut follows = 0;
140        let mut other = 0;
141        for entry in peers.values() {
142            match entry.pool {
143                PeerPool::Follows => follows += 1,
144                PeerPool::Other => other += 1,
145            }
146        }
147        (follows, other)
148    }
149
150    /// Classify a peer by pubkey
151    async fn classify_peer(&self, pubkey: &str) -> PeerPool {
152        if let Some(ref tx) = self.classifier_tx {
153            let (response_tx, response_rx) = tokio::sync::oneshot::channel();
154            let request = ClassifyRequest {
155                pubkey: pubkey.to_string(),
156                response: response_tx,
157            };
158            if tx.send(request).await.is_ok() {
159                if let Ok(pool) = response_rx.await {
160                    return pool;
161                }
162            }
163        }
164        PeerPool::Other
165    }
166
167    /// Check if we can accept a peer in a given pool
168    fn can_accept_peer(&self, pool: PeerPool, follows: usize, other: usize) -> bool {
169        match pool {
170            PeerPool::Follows => self.pools.follows.can_accept(follows),
171            PeerPool::Other => self.pools.other.can_accept(other),
172        }
173    }
174
175    /// Check if a pool needs more peers
176    fn pool_needs_peers(&self, pool: PeerPool, follows: usize, other: usize) -> bool {
177        match pool {
178            PeerPool::Follows => self.pools.follows.needs_peers(follows),
179            PeerPool::Other => self.pools.other.needs_peers(other),
180        }
181    }
182
183    /// Handle incoming signaling message
184    ///
185    /// This is the core signaling logic shared between production and simulation.
186    pub async fn handle_message(&self, msg: SignalingMessage) -> Result<(), TransportError> {
187        match &msg {
188            SignalingMessage::Hello {
189                peer_id,
190                roots,
191                hash_get,
192            } => {
193                self.set_peer_hash_get(peer_id, *hash_get).await;
194                self.handle_hello(peer_id, roots, *hash_get).await
195            }
196            SignalingMessage::Offer {
197                peer_id,
198                target_peer_id,
199                sdp,
200            } => {
201                if target_peer_id == &self.peer_id {
202                    self.handle_offer(peer_id, sdp).await
203                } else {
204                    Ok(()) // Not for us
205                }
206            }
207            SignalingMessage::Answer {
208                peer_id,
209                target_peer_id,
210                sdp,
211            } => {
212                if target_peer_id == &self.peer_id {
213                    self.handle_answer(peer_id, sdp).await
214                } else {
215                    Ok(()) // Not for us
216                }
217            }
218            SignalingMessage::Candidate {
219                peer_id,
220                target_peer_id,
221                candidate,
222                sdp_m_line_index,
223                sdp_mid,
224            } => {
225                if target_peer_id == &self.peer_id {
226                    self.conn_factory
227                        .handle_candidate(
228                            peer_id,
229                            crate::types::IceCandidate {
230                                candidate: candidate.clone(),
231                                sdp_m_line_index: *sdp_m_line_index,
232                                sdp_mid: sdp_mid.clone(),
233                            },
234                        )
235                        .await
236                } else {
237                    Ok(())
238                }
239            }
240            SignalingMessage::Candidates {
241                peer_id,
242                target_peer_id,
243                candidates,
244            } => {
245                if target_peer_id == &self.peer_id {
246                    self.conn_factory
247                        .handle_candidates(peer_id, candidates.clone())
248                        .await
249                } else {
250                    Ok(())
251                }
252            }
253        }
254    }
255
256    /// Handle hello message using the shared concurrent-offer negotiation flow.
257    ///
258    /// With perfect negotiation, we send an offer if we need peers.
259    /// No tie-breaking here - collisions are handled in handle_offer.
260    async fn handle_hello(
261        &self,
262        from_peer_id: &str,
263        roots: &[String],
264        hash_get: bool,
265    ) -> Result<(), TransportError> {
266        // Ignore our own hello
267        if from_peer_id == self.peer_id {
268            return Ok(());
269        }
270
271        self.peer_roots
272            .write()
273            .await
274            .insert(from_peer_id.to_string(), roots.to_vec());
275        if let Some(entry) = self.peers.write().await.get_mut(from_peer_id) {
276            entry.hash_get = hash_get;
277            return Ok(());
278        }
279
280        let peer_pubkey = crate::types::PeerId::from_peer_string(from_peer_id)
281            .map(|peer_id| peer_id.pubkey)
282            .unwrap_or_else(|| from_peer_id.to_string());
283
284        // Classify the peer
285        let pool = self.classify_peer(&peer_pubkey).await;
286
287        // Check pool limits
288        let (follows_count, other_count) = self.count_pools().await;
289
290        if !self.can_accept_peer(pool, follows_count, other_count) {
291            if self.debug {
292                println!(
293                    "[Signaling] Ignoring hello from {} - {:?} pool full",
294                    from_peer_id, pool
295                );
296            }
297            return Ok(());
298        }
299
300        // Shared perfect negotiation: send offer if we NEED more peers
301        // Both sides may send offers - collision handled in handle_offer
302        if self.pool_needs_peers(pool, follows_count, other_count) {
303            // Check if already connected or pending
304            if self.peers.read().await.contains_key(from_peer_id) {
305                return Ok(());
306            }
307            if self.pending_offers.read().await.contains_key(from_peer_id) {
308                return Ok(());
309            }
310
311            if self.debug {
312                println!(
313                    "[Signaling] Sending offer to {} (pool: {:?})",
314                    from_peer_id, pool
315                );
316            }
317
318            // Mark as pending before creating offer
319            self.pending_offers
320                .write()
321                .await
322                .insert(from_peer_id.to_string(), ());
323
324            // Create offer
325            let (channel, sdp) = self.conn_factory.create_offer(from_peer_id).await?;
326
327            // Add peer (will be confirmed when we get answer)
328            self.peers.write().await.insert(
329                from_peer_id.to_string(),
330                PeerEntry {
331                    channel,
332                    pool,
333                    hash_get,
334                },
335            );
336
337            // Send offer
338            let offer_msg = SignalingMessage::Offer {
339                peer_id: self.peer_id.clone(),
340                target_peer_id: from_peer_id.to_string(),
341                sdp,
342            };
343            self.transport.publish(offer_msg).await?;
344        }
345
346        Ok(())
347    }
348
349    /// Handle offer message in the shared concurrent-offer negotiation flow.
350    ///
351    /// Handles offer collision: if we also sent an offer to this peer,
352    /// the "polite" peer (lower ID) backs off and accepts the incoming offer.
353    async fn handle_offer(&self, from_peer_id: &str, sdp: &str) -> Result<(), TransportError> {
354        // Extract pubkey
355        let peer_pubkey = crate::types::PeerId::from_peer_string(from_peer_id)
356            .map(|peer_id| peer_id.pubkey)
357            .unwrap_or_else(|| from_peer_id.to_string());
358
359        // Classify and check limits
360        let pool = self.classify_peer(&peer_pubkey).await;
361        let (follows_count, other_count) = self.count_pools().await;
362
363        if !self.can_accept_peer(pool, follows_count, other_count) {
364            if self.debug {
365                println!(
366                    "[Signaling] Ignoring offer from {} - {:?} pool full",
367                    from_peer_id, pool
368                );
369            }
370            return Ok(());
371        }
372
373        // Check for offer collision (we also sent an offer to them)
374        let have_pending = self.pending_offers.read().await.contains_key(from_peer_id);
375        if have_pending {
376            // Collision! Use polite/impolite pattern
377            let we_are_polite = is_polite_peer(&self.peer_id, from_peer_id);
378
379            if we_are_polite {
380                // We're polite - back off, accept their offer
381                // Remove our pending offer and peer entry
382                self.pending_offers.write().await.remove(from_peer_id);
383                self.peers.write().await.remove(from_peer_id);
384
385                if self.debug {
386                    println!(
387                        "[Signaling] Collision with {} - we're polite, accepting their offer",
388                        from_peer_id
389                    );
390                }
391            } else {
392                // We're impolite - ignore their offer, wait for answer to ours
393                if self.debug {
394                    println!(
395                        "[Signaling] Collision with {} - we're impolite, ignoring their offer",
396                        from_peer_id
397                    );
398                }
399                return Ok(());
400            }
401        }
402
403        // Check if already connected (no collision case)
404        if self.peers.read().await.contains_key(from_peer_id) {
405            return Ok(());
406        }
407
408        if self.debug {
409            println!("[Signaling] Accepting offer from {}", from_peer_id);
410        }
411
412        // Accept offer
413        let (channel, answer_sdp) = self.conn_factory.accept_offer(from_peer_id, sdp).await?;
414        let hash_get = self.peer_supports_hash_get(from_peer_id).await;
415
416        // Add peer
417        self.peers.write().await.insert(
418            from_peer_id.to_string(),
419            PeerEntry {
420                channel,
421                pool,
422                hash_get,
423            },
424        );
425
426        // Send answer
427        let answer_msg = SignalingMessage::Answer {
428            peer_id: self.peer_id.clone(),
429            target_peer_id: from_peer_id.to_string(),
430            sdp: answer_sdp,
431        };
432        self.transport.publish(answer_msg).await?;
433
434        Ok(())
435    }
436
437    /// Handle answer message
438    async fn handle_answer(&self, from_peer_id: &str, sdp: &str) -> Result<(), TransportError> {
439        if self.debug {
440            println!("[Signaling] Received answer from {}", from_peer_id);
441        }
442
443        // Complete connection
444        let _channel = self.conn_factory.handle_answer(from_peer_id, sdp).await?;
445
446        // Peer should already be in our map from when we sent the offer
447        // The channel returned here is the same one we stored
448
449        Ok(())
450    }
451
452    /// Get connected peer count
453    pub async fn peer_count(&self) -> usize {
454        self.peers.read().await.len()
455    }
456
457    /// Get peer IDs
458    pub async fn peer_ids(&self) -> Vec<String> {
459        self.peers.read().await.keys().cloned().collect()
460    }
461
462    /// Get a peer's channel
463    pub async fn get_channel(&self, peer_id: &str) -> Option<Arc<dyn PeerLink>> {
464        self.peers
465            .read()
466            .await
467            .get(peer_id)
468            .map(|e| e.channel.clone())
469    }
470
471    /// Remove a peer and any pending offer state.
472    pub async fn remove_peer(&self, peer_id: &str) -> Option<Arc<dyn PeerLink>> {
473        self.pending_offers.write().await.remove(peer_id);
474        self.peer_roots.write().await.remove(peer_id);
475        let _ = self.conn_factory.remove_peer(peer_id).await;
476        self.peers
477            .write()
478            .await
479            .remove(peer_id)
480            .map(|entry| entry.channel)
481    }
482
483    /// Check if we need more peers (below satisfied in any pool)
484    pub async fn needs_peers(&self) -> bool {
485        let (follows, other) = self.count_pools().await;
486        self.pools.follows.needs_peers(follows) || self.pools.other.needs_peers(other)
487    }
488
489    /// Check if we can accept more peers (below max in any pool)
490    pub async fn can_accept(&self) -> bool {
491        let (follows, other) = self.count_pools().await;
492        self.pools.follows.can_accept(follows) || self.pools.other.can_accept(other)
493    }
494}
495
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    use crate::types::{IceCandidate, PoolConfig, PoolSettings};

    /// Signaling transport that accepts every publish and never yields a message.
    #[derive(Default)]
    struct NoopTransport;

    #[async_trait]
    impl SignalingTransport for NoopTransport {
        async fn connect(&self, _relays: &[String]) -> Result<(), TransportError> {
            Ok(())
        }

        async fn disconnect(&self) {}

        async fn publish(&self, _msg: SignalingMessage) -> Result<(), TransportError> {
            Ok(())
        }

        async fn recv(&self) -> Option<SignalingMessage> {
            None
        }

        fn try_recv(&self) -> Option<SignalingMessage> {
            None
        }

        fn peer_id(&self) -> &str {
            "local"
        }
    }

    /// Error returned by factory operations these tests never exercise.
    fn unused_in_test() -> TransportError {
        TransportError::ConnectionFailed("not used in this test".to_string())
    }

    /// Link factory that records candidate deliveries and peer removals.
    #[derive(Default)]
    struct RecordingFactory {
        candidates: Mutex<Vec<(String, IceCandidate)>>,
        removed: Mutex<Vec<String>>,
    }

    #[async_trait]
    impl PeerLinkFactory for RecordingFactory {
        async fn create_offer(
            &self,
            _target_peer_id: &str,
        ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
            Err(unused_in_test())
        }

        async fn accept_offer(
            &self,
            _from_peer_id: &str,
            _offer_sdp: &str,
        ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
            Err(unused_in_test())
        }

        async fn handle_answer(
            &self,
            _target_peer_id: &str,
            _answer_sdp: &str,
        ) -> Result<Arc<dyn PeerLink>, TransportError> {
            Err(unused_in_test())
        }

        async fn handle_candidate(
            &self,
            peer_id: &str,
            candidate: IceCandidate,
        ) -> Result<(), TransportError> {
            let mut seen = self.candidates.lock().await;
            seen.push((peer_id.to_owned(), candidate));
            Ok(())
        }

        async fn remove_peer(&self, peer_id: &str) -> Result<(), TransportError> {
            let mut removed = self.removed.lock().await;
            removed.push(peer_id.to_owned());
            Ok(())
        }
    }

    /// Router for peer "local" with default pool limits and debug output disabled.
    fn local_router(
        factory: Arc<RecordingFactory>,
    ) -> MeshRouter<NoopTransport, RecordingFactory> {
        let pools = PoolSettings {
            follows: PoolConfig::default(),
            other: PoolConfig::default(),
        };
        MeshRouter::new(
            "local".to_string(),
            Arc::new(NoopTransport),
            factory,
            pools,
            false,
        )
    }

    #[tokio::test]
    async fn routes_targeted_candidates_to_factory() {
        let factory = Arc::new(RecordingFactory::default());
        let router = local_router(Arc::clone(&factory));

        let message = SignalingMessage::Candidate {
            peer_id: "remote:peer".to_string(),
            target_peer_id: "local".to_string(),
            candidate: "candidate:1".to_string(),
            sdp_m_line_index: Some(0),
            sdp_mid: Some("data".to_string()),
        };
        router
            .handle_message(message)
            .await
            .expect("candidate should route");

        let seen = factory.candidates.lock().await;
        let routed: Vec<(String, String)> = seen
            .iter()
            .map(|(peer, ice)| (peer.clone(), ice.candidate.clone()))
            .collect();

        assert_eq!(
            routed,
            vec![("remote:peer".to_string(), "candidate:1".to_string())]
        );
    }

    #[tokio::test]
    async fn remove_peer_cleans_factory_state() {
        let factory = Arc::new(RecordingFactory::default());
        let router = local_router(Arc::clone(&factory));

        assert!(router.remove_peer("remote:peer").await.is_none());

        let removed = factory.removed.lock().await;
        assert_eq!(removed.as_slice(), &["remote:peer".to_string()]);
    }
}
648}