Skip to main content

hashtree_network/
signaling.rs

//! Shared signaling logic for peer discovery and connection management.
//!
//! This module contains the core signaling logic used by both the default
//! production mesh transport stack and simulation. It handles:
//! - Hello broadcasts and discovery
//! - Pool management (follows vs other peers)
//! - Polite/impolite tie-breaking when both peers send offers to each other
//! - Offer/answer flow coordination

10use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13
14use crate::transport::{PeerLink, PeerLinkFactory, SignalingTransport, TransportError};
15use crate::types::{is_polite_peer, ClassifyRequest, PeerPool, PoolSettings, SignalingMessage};
16
17/// Peer entry with pool classification and channel
18pub struct PeerEntry {
19    pub channel: Arc<dyn PeerLink>,
20    pub pool: PeerPool,
21    pub hash_get: bool,
22}
23
24/// Mesh router handles peer discovery and negotiated link establishment.
25///
26/// This is the shared routing logic between production transports and simulation.
27/// It uses traits for signaling transport and negotiated link factories so the
28/// same router can drive Nostr websockets, LAN buses, BLE, WebRTC, or mocks.
29///
30/// Uses the standard concurrent-offer "perfect negotiation" pattern:
31/// - Both peers can send offers when they discover each other
32/// - On collision (both sent offers), "polite" peer backs off and accepts incoming
33/// - This ensures connections form even when one peer is satisfied but can accept
34pub struct MeshRouter<R: SignalingTransport, F: PeerLinkFactory> {
35    /// Our peer ID (pubkey format)
36    peer_id: String,
37    /// Relay transport for signaling
38    transport: Arc<R>,
39    /// Link factory for creating negotiated peer links
40    conn_factory: Arc<F>,
41    /// Connected peers
42    peers: RwLock<HashMap<String, PeerEntry>>,
43    /// Pending outbound offers (we sent offer, waiting for answer)
44    pending_offers: RwLock<HashMap<String, ()>>,
45    /// Pool settings
46    pools: PoolSettings,
47    /// Known peer roots (for future use)
48    peer_roots: RwLock<HashMap<String, Vec<String>>>,
49    /// Last advertised `hash_get` capability per peer.
50    peer_hash_get: RwLock<HashMap<String, bool>>,
51    /// Classifier channel (optional)
52    classifier_tx: Option<tokio::sync::mpsc::Sender<ClassifyRequest>>,
53    /// Debug mode
54    debug: bool,
55    /// Whether local node accepts `hash_get` lookups.
56    hash_get_enabled: bool,
57}
58
59impl<R: SignalingTransport + 'static, F: PeerLinkFactory + 'static> MeshRouter<R, F> {
60    /// Create a new mesh router.
61    pub fn new(
62        peer_id: String,
63        transport: Arc<R>,
64        conn_factory: Arc<F>,
65        pools: PoolSettings,
66        debug: bool,
67    ) -> Self {
68        Self {
69            peer_id,
70            transport,
71            conn_factory,
72            peers: RwLock::new(HashMap::new()),
73            pending_offers: RwLock::new(HashMap::new()),
74            pools,
75            peer_roots: RwLock::new(HashMap::new()),
76            peer_hash_get: RwLock::new(HashMap::new()),
77            classifier_tx: None,
78            debug,
79            hash_get_enabled: true,
80        }
81    }
82
83    /// Set classifier for peer pool assignment
84    pub fn set_classifier(&mut self, tx: tokio::sync::mpsc::Sender<ClassifyRequest>) {
85        self.classifier_tx = Some(tx);
86    }
87
88    pub fn set_hash_get_enabled(&mut self, enabled: bool) {
89        self.hash_get_enabled = enabled;
90    }
91
92    pub async fn set_peer_hash_get(&self, peer_id: &str, enabled: bool) {
93        self.peer_hash_get
94            .write()
95            .await
96            .insert(peer_id.to_string(), enabled);
97        if let Some(entry) = self.peers.write().await.get_mut(peer_id) {
98            entry.hash_get = enabled;
99        }
100    }
101
102    pub async fn peer_supports_hash_get(&self, peer_id: &str) -> bool {
103        self.peer_hash_get
104            .read()
105            .await
106            .get(peer_id)
107            .copied()
108            .unwrap_or(true)
109    }
110
111    pub async fn hash_get_peer_ids(&self) -> Vec<String> {
112        let peers = self.peers.read().await;
113        let peer_hash_get = self.peer_hash_get.read().await;
114        peers
115            .keys()
116            .filter(|peer_id| peer_hash_get.get(*peer_id).copied().unwrap_or(true))
117            .cloned()
118            .collect()
119    }
120
121    /// Get our peer ID
122    pub fn peer_id(&self) -> &str {
123        &self.peer_id
124    }
125
126    /// Send hello broadcast
127    pub async fn send_hello(&self, roots: Vec<String>) -> Result<(), TransportError> {
128        let msg = SignalingMessage::Hello {
129            peer_id: self.peer_id.clone(),
130            roots,
131            hash_get: self.hash_get_enabled,
132        };
133        self.transport.publish(msg).await
134    }
135
136    /// Count peers by pool
137    async fn count_pools(&self) -> (usize, usize) {
138        let peers = self.peers.read().await;
139        let mut follows = 0;
140        let mut other = 0;
141        for entry in peers.values() {
142            match entry.pool {
143                PeerPool::Follows => follows += 1,
144                PeerPool::Other => other += 1,
145            }
146        }
147        (follows, other)
148    }
149
150    /// Classify a peer by pubkey
151    async fn classify_peer(&self, pubkey: &str) -> PeerPool {
152        if let Some(ref tx) = self.classifier_tx {
153            let (response_tx, response_rx) = tokio::sync::oneshot::channel();
154            let request = ClassifyRequest {
155                pubkey: pubkey.to_string(),
156                response: response_tx,
157            };
158            if tx.send(request).await.is_ok() {
159                if let Ok(pool) = response_rx.await {
160                    return pool;
161                }
162            }
163        }
164        PeerPool::Other
165    }
166
167    /// Check if we can accept a peer in a given pool
168    fn can_accept_peer(&self, pool: PeerPool, follows: usize, other: usize) -> bool {
169        match pool {
170            PeerPool::Follows => self.pools.follows.can_accept(follows),
171            PeerPool::Other => self.pools.other.can_accept(other),
172        }
173    }
174
175    /// Check if a pool needs more peers
176    fn pool_needs_peers(&self, pool: PeerPool, follows: usize, other: usize) -> bool {
177        match pool {
178            PeerPool::Follows => self.pools.follows.needs_peers(follows),
179            PeerPool::Other => self.pools.other.needs_peers(other),
180        }
181    }
182
183    /// Handle incoming signaling message
184    ///
185    /// This is the core signaling logic shared between production and simulation.
186    pub async fn handle_message(&self, msg: SignalingMessage) -> Result<(), TransportError> {
187        match &msg {
188            SignalingMessage::Hello {
189                peer_id,
190                roots,
191                hash_get,
192                ..
193            } => {
194                self.set_peer_hash_get(peer_id, *hash_get).await;
195                self.handle_hello(peer_id, roots, *hash_get).await
196            }
197            SignalingMessage::Offer {
198                peer_id,
199                target_peer_id,
200                sdp,
201            } => {
202                if target_peer_id == &self.peer_id {
203                    self.handle_offer(peer_id, sdp).await
204                } else {
205                    Ok(()) // Not for us
206                }
207            }
208            SignalingMessage::Answer {
209                peer_id,
210                target_peer_id,
211                sdp,
212            } => {
213                if target_peer_id == &self.peer_id {
214                    self.handle_answer(peer_id, sdp).await
215                } else {
216                    Ok(()) // Not for us
217                }
218            }
219            SignalingMessage::Candidate {
220                peer_id,
221                target_peer_id,
222                candidate,
223                sdp_m_line_index,
224                sdp_mid,
225            } => {
226                if target_peer_id == &self.peer_id {
227                    self.conn_factory
228                        .handle_candidate(
229                            peer_id,
230                            crate::types::IceCandidate {
231                                candidate: candidate.clone(),
232                                sdp_m_line_index: *sdp_m_line_index,
233                                sdp_mid: sdp_mid.clone(),
234                            },
235                        )
236                        .await
237                } else {
238                    Ok(())
239                }
240            }
241            SignalingMessage::Candidates {
242                peer_id,
243                target_peer_id,
244                candidates,
245            } => {
246                if target_peer_id == &self.peer_id {
247                    self.conn_factory
248                        .handle_candidates(peer_id, candidates.clone())
249                        .await
250                } else {
251                    Ok(())
252                }
253            }
254        }
255    }
256
257    /// Handle hello message using the shared concurrent-offer negotiation flow.
258    ///
259    /// With perfect negotiation, we send an offer if we need peers.
260    /// No tie-breaking here - collisions are handled in handle_offer.
261    async fn handle_hello(
262        &self,
263        from_peer_id: &str,
264        roots: &[String],
265        hash_get: bool,
266    ) -> Result<(), TransportError> {
267        // Ignore our own hello
268        if from_peer_id == self.peer_id {
269            return Ok(());
270        }
271
272        self.peer_roots
273            .write()
274            .await
275            .insert(from_peer_id.to_string(), roots.to_vec());
276        if let Some(entry) = self.peers.write().await.get_mut(from_peer_id) {
277            entry.hash_get = hash_get;
278            return Ok(());
279        }
280
281        let peer_pubkey = crate::types::PeerId::from_peer_string(from_peer_id)
282            .map(|peer_id| peer_id.pubkey)
283            .unwrap_or_else(|| from_peer_id.to_string());
284
285        // Classify the peer
286        let pool = self.classify_peer(&peer_pubkey).await;
287
288        // Check pool limits
289        let (follows_count, other_count) = self.count_pools().await;
290
291        if !self.can_accept_peer(pool, follows_count, other_count) {
292            if self.debug {
293                println!(
294                    "[Signaling] Ignoring hello from {} - {:?} pool full",
295                    from_peer_id, pool
296                );
297            }
298            return Ok(());
299        }
300
301        // Shared perfect negotiation: send offer if we NEED more peers
302        // Both sides may send offers - collision handled in handle_offer
303        if self.pool_needs_peers(pool, follows_count, other_count) {
304            // Check if already connected or pending
305            if self.peers.read().await.contains_key(from_peer_id) {
306                return Ok(());
307            }
308            if self.pending_offers.read().await.contains_key(from_peer_id) {
309                return Ok(());
310            }
311
312            if self.debug {
313                println!(
314                    "[Signaling] Sending offer to {} (pool: {:?})",
315                    from_peer_id, pool
316                );
317            }
318
319            // Mark as pending before creating offer
320            self.pending_offers
321                .write()
322                .await
323                .insert(from_peer_id.to_string(), ());
324
325            // Create offer
326            let (channel, sdp) = self.conn_factory.create_offer(from_peer_id).await?;
327
328            // Add peer (will be confirmed when we get answer)
329            self.peers.write().await.insert(
330                from_peer_id.to_string(),
331                PeerEntry {
332                    channel,
333                    pool,
334                    hash_get,
335                },
336            );
337
338            // Send offer
339            let offer_msg = SignalingMessage::Offer {
340                peer_id: self.peer_id.clone(),
341                target_peer_id: from_peer_id.to_string(),
342                sdp,
343            };
344            self.transport.publish(offer_msg).await?;
345        }
346
347        Ok(())
348    }
349
350    /// Handle offer message in the shared concurrent-offer negotiation flow.
351    ///
352    /// Handles offer collision: if we also sent an offer to this peer,
353    /// the "polite" peer (lower ID) backs off and accepts the incoming offer.
354    async fn handle_offer(&self, from_peer_id: &str, sdp: &str) -> Result<(), TransportError> {
355        // Extract pubkey
356        let peer_pubkey = crate::types::PeerId::from_peer_string(from_peer_id)
357            .map(|peer_id| peer_id.pubkey)
358            .unwrap_or_else(|| from_peer_id.to_string());
359
360        // Classify and check limits
361        let pool = self.classify_peer(&peer_pubkey).await;
362        let (follows_count, other_count) = self.count_pools().await;
363
364        if !self.can_accept_peer(pool, follows_count, other_count) {
365            if self.debug {
366                println!(
367                    "[Signaling] Ignoring offer from {} - {:?} pool full",
368                    from_peer_id, pool
369                );
370            }
371            return Ok(());
372        }
373
374        // Check for offer collision (we also sent an offer to them)
375        let have_pending = self.pending_offers.read().await.contains_key(from_peer_id);
376        if have_pending {
377            // Collision! Use polite/impolite pattern
378            let we_are_polite = is_polite_peer(&self.peer_id, from_peer_id);
379
380            if we_are_polite {
381                // We're polite - back off, accept their offer
382                // Remove our pending offer and peer entry
383                self.pending_offers.write().await.remove(from_peer_id);
384                self.peers.write().await.remove(from_peer_id);
385
386                if self.debug {
387                    println!(
388                        "[Signaling] Collision with {} - we're polite, accepting their offer",
389                        from_peer_id
390                    );
391                }
392            } else {
393                // We're impolite - ignore their offer, wait for answer to ours
394                if self.debug {
395                    println!(
396                        "[Signaling] Collision with {} - we're impolite, ignoring their offer",
397                        from_peer_id
398                    );
399                }
400                return Ok(());
401            }
402        }
403
404        // Check if already connected (no collision case)
405        if self.peers.read().await.contains_key(from_peer_id) {
406            return Ok(());
407        }
408
409        if self.debug {
410            println!("[Signaling] Accepting offer from {}", from_peer_id);
411        }
412
413        // Accept offer
414        let (channel, answer_sdp) = self.conn_factory.accept_offer(from_peer_id, sdp).await?;
415        let hash_get = self.peer_supports_hash_get(from_peer_id).await;
416
417        // Add peer
418        self.peers.write().await.insert(
419            from_peer_id.to_string(),
420            PeerEntry {
421                channel,
422                pool,
423                hash_get,
424            },
425        );
426
427        // Send answer
428        let answer_msg = SignalingMessage::Answer {
429            peer_id: self.peer_id.clone(),
430            target_peer_id: from_peer_id.to_string(),
431            sdp: answer_sdp,
432        };
433        self.transport.publish(answer_msg).await?;
434
435        Ok(())
436    }
437
438    /// Handle answer message
439    async fn handle_answer(&self, from_peer_id: &str, sdp: &str) -> Result<(), TransportError> {
440        if self.debug {
441            println!("[Signaling] Received answer from {}", from_peer_id);
442        }
443
444        // Complete connection
445        let _channel = self.conn_factory.handle_answer(from_peer_id, sdp).await?;
446
447        // Peer should already be in our map from when we sent the offer
448        // The channel returned here is the same one we stored
449
450        Ok(())
451    }
452
453    /// Get connected peer count
454    pub async fn peer_count(&self) -> usize {
455        self.peers.read().await.len()
456    }
457
458    /// Get peer IDs
459    pub async fn peer_ids(&self) -> Vec<String> {
460        self.peers.read().await.keys().cloned().collect()
461    }
462
463    /// Get a peer's channel
464    pub async fn get_channel(&self, peer_id: &str) -> Option<Arc<dyn PeerLink>> {
465        self.peers
466            .read()
467            .await
468            .get(peer_id)
469            .map(|e| e.channel.clone())
470    }
471
472    /// Remove a peer and any pending offer state.
473    pub async fn remove_peer(&self, peer_id: &str) -> Option<Arc<dyn PeerLink>> {
474        self.pending_offers.write().await.remove(peer_id);
475        self.peer_roots.write().await.remove(peer_id);
476        let _ = self.conn_factory.remove_peer(peer_id).await;
477        self.peers
478            .write()
479            .await
480            .remove(peer_id)
481            .map(|entry| entry.channel)
482    }
483
484    /// Check if we need more peers (below satisfied in any pool)
485    pub async fn needs_peers(&self) -> bool {
486        let (follows, other) = self.count_pools().await;
487        self.pools.follows.needs_peers(follows) || self.pools.other.needs_peers(other)
488    }
489
490    /// Check if we can accept more peers (below max in any pool)
491    pub async fn can_accept(&self) -> bool {
492        let (follows, other) = self.count_pools().await;
493        self.pools.follows.can_accept(follows) || self.pools.other.can_accept(other)
494    }
495}
496
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    use crate::types::{IceCandidate, PoolConfig, PoolSettings};

    /// Signaling transport that accepts every publish and never yields a message.
    #[derive(Default)]
    struct NoopTransport;

    #[async_trait]
    impl SignalingTransport for NoopTransport {
        async fn connect(&self, _relays: &[String]) -> Result<(), TransportError> {
            Ok(())
        }

        async fn disconnect(&self) {}

        async fn publish(&self, _msg: SignalingMessage) -> Result<(), TransportError> {
            Ok(())
        }

        async fn recv(&self) -> Option<SignalingMessage> {
            None
        }

        fn try_recv(&self) -> Option<SignalingMessage> {
            None
        }

        fn peer_id(&self) -> &str {
            "local"
        }
    }

    /// Link factory that records delivered candidates and removed peers.
    /// Its offer/answer methods always fail; these tests never reach them.
    #[derive(Default)]
    struct RecordingFactory {
        candidates: Mutex<Vec<(String, IceCandidate)>>,
        removed: Mutex<Vec<String>>,
    }

    #[async_trait]
    impl PeerLinkFactory for RecordingFactory {
        async fn create_offer(
            &self,
            _target_peer_id: &str,
        ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
            Err(TransportError::ConnectionFailed(
                "not used in this test".to_string(),
            ))
        }

        async fn accept_offer(
            &self,
            _from_peer_id: &str,
            _offer_sdp: &str,
        ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
            Err(TransportError::ConnectionFailed(
                "not used in this test".to_string(),
            ))
        }

        async fn handle_answer(
            &self,
            _target_peer_id: &str,
            _answer_sdp: &str,
        ) -> Result<Arc<dyn PeerLink>, TransportError> {
            Err(TransportError::ConnectionFailed(
                "not used in this test".to_string(),
            ))
        }

        async fn handle_candidate(
            &self,
            peer_id: &str,
            candidate: IceCandidate,
        ) -> Result<(), TransportError> {
            let mut log = self.candidates.lock().await;
            log.push((peer_id.to_string(), candidate));
            Ok(())
        }

        async fn remove_peer(&self, peer_id: &str) -> Result<(), TransportError> {
            let mut log = self.removed.lock().await;
            log.push(peer_id.to_string());
            Ok(())
        }
    }

    /// Router for peer `local` over a no-op transport, with default pool limits
    /// and debug output off.
    fn router_with(
        factory: Arc<RecordingFactory>,
    ) -> MeshRouter<NoopTransport, RecordingFactory> {
        let pools = PoolSettings {
            follows: PoolConfig::default(),
            other: PoolConfig::default(),
        };
        MeshRouter::new(
            "local".to_string(),
            Arc::new(NoopTransport),
            factory,
            pools,
            false,
        )
    }

    #[tokio::test]
    async fn routes_targeted_candidates_to_factory() {
        let factory = Arc::new(RecordingFactory::default());
        let router = router_with(Arc::clone(&factory));

        let message = SignalingMessage::Candidate {
            peer_id: "remote:peer".to_string(),
            target_peer_id: "local".to_string(),
            candidate: "candidate:1".to_string(),
            sdp_m_line_index: Some(0),
            sdp_mid: Some("data".to_string()),
        };
        router
            .handle_message(message)
            .await
            .expect("candidate should route");

        let log = factory.candidates.lock().await;
        let recorded: Vec<(String, String)> = log
            .iter()
            .map(|(from, ice)| (from.clone(), ice.candidate.clone()))
            .collect();

        assert_eq!(
            recorded,
            vec![("remote:peer".to_string(), "candidate:1".to_string())]
        );
    }

    #[tokio::test]
    async fn remove_peer_cleans_factory_state() {
        let factory = Arc::new(RecordingFactory::default());
        let router = router_with(Arc::clone(&factory));

        let previous_link = router.remove_peer("remote:peer").await;
        assert!(previous_link.is_none());

        let removed = factory.removed.lock().await;
        assert_eq!(removed.as_slice(), &["remote:peer".to_string()]);
    }
}