Skip to main content

hashtree_network/
signaling.rs

1//! Shared signaling logic for peer discovery and connection management.
2//!
3//! This module contains the core signaling logic used by both the default
4//! production mesh transport stack and simulation. It handles:
5//! - Hello broadcasts and discovery
6//! - Pool management (follows vs other peers)
7//! - Tie-breaking for connection initiation
8//! - Offer/answer flow coordination
9
10use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13
14use crate::transport::{PeerLink, PeerLinkFactory, SignalingTransport, TransportError};
15use crate::types::{is_polite_peer, ClassifyRequest, PeerPool, PoolSettings, SignalingMessage};
16
17/// Peer entry with pool classification and channel
18pub struct PeerEntry {
19    pub channel: Arc<dyn PeerLink>,
20    pub pool: PeerPool,
21    pub hash_get: bool,
22}
23
24/// Mesh router handles peer discovery and negotiated link establishment.
25///
26/// This is the shared routing logic between production transports and simulation.
27/// It uses traits for signaling transport and negotiated link factories so the
28/// same router can drive Nostr websockets, LAN buses, BLE, WebRTC, or mocks.
29///
30/// Uses the standard concurrent-offer "perfect negotiation" pattern:
31/// - Both peers can send offers when they discover each other
32/// - On collision (both sent offers), "polite" peer backs off and accepts incoming
33/// - This ensures connections form even when one peer is satisfied but can accept
34pub struct MeshRouter<R: SignalingTransport, F: PeerLinkFactory> {
35    /// Our peer ID (pubkey format)
36    peer_id: String,
37    /// Relay transport for signaling
38    transport: Arc<R>,
39    /// Link factory for creating negotiated peer links
40    conn_factory: Arc<F>,
41    /// Connected peers
42    peers: RwLock<HashMap<String, PeerEntry>>,
43    /// Pending outbound offers (we sent offer, waiting for answer)
44    pending_offers: RwLock<HashMap<String, ()>>,
45    /// Pool settings
46    pools: PoolSettings,
47    /// Known peer roots (for future use)
48    peer_roots: RwLock<HashMap<String, Vec<String>>>,
49    /// Last advertised `hash_get` capability per peer.
50    peer_hash_get: RwLock<HashMap<String, bool>>,
51    /// Classifier channel (optional)
52    classifier_tx: Option<tokio::sync::mpsc::Sender<ClassifyRequest>>,
53    /// Debug mode
54    debug: bool,
55    /// Whether local node accepts `hash_get` lookups.
56    hash_get_enabled: bool,
57}
58
59impl<R: SignalingTransport + 'static, F: PeerLinkFactory + 'static> MeshRouter<R, F> {
60    /// Create a new mesh router.
61    pub fn new(
62        peer_id: String,
63        transport: Arc<R>,
64        conn_factory: Arc<F>,
65        pools: PoolSettings,
66        debug: bool,
67    ) -> Self {
68        Self {
69            peer_id,
70            transport,
71            conn_factory,
72            peers: RwLock::new(HashMap::new()),
73            pending_offers: RwLock::new(HashMap::new()),
74            pools,
75            peer_roots: RwLock::new(HashMap::new()),
76            peer_hash_get: RwLock::new(HashMap::new()),
77            classifier_tx: None,
78            debug,
79            hash_get_enabled: true,
80        }
81    }
82
83    /// Set classifier for peer pool assignment
84    pub fn set_classifier(&mut self, tx: tokio::sync::mpsc::Sender<ClassifyRequest>) {
85        self.classifier_tx = Some(tx);
86    }
87
88    pub fn set_hash_get_enabled(&mut self, enabled: bool) {
89        self.hash_get_enabled = enabled;
90    }
91
92    pub async fn set_peer_hash_get(&self, peer_id: &str, enabled: bool) {
93        self.peer_hash_get
94            .write()
95            .await
96            .insert(peer_id.to_string(), enabled);
97        if let Some(entry) = self.peers.write().await.get_mut(peer_id) {
98            entry.hash_get = enabled;
99        }
100    }
101
102    pub async fn peer_supports_hash_get(&self, peer_id: &str) -> bool {
103        self.peer_hash_get
104            .read()
105            .await
106            .get(peer_id)
107            .copied()
108            .unwrap_or(true)
109    }
110
111    pub async fn hash_get_peer_ids(&self) -> Vec<String> {
112        let peers = self.peers.read().await;
113        let peer_hash_get = self.peer_hash_get.read().await;
114        peers
115            .keys()
116            .filter(|peer_id| peer_hash_get.get(*peer_id).copied().unwrap_or(true))
117            .cloned()
118            .collect()
119    }
120
121    /// Get our peer ID
122    pub fn peer_id(&self) -> &str {
123        &self.peer_id
124    }
125
126    /// Send hello broadcast
127    pub async fn send_hello(&self, roots: Vec<String>) -> Result<(), TransportError> {
128        let msg = SignalingMessage::Hello {
129            peer_id: self.peer_id.clone(),
130            roots,
131            hash_get: self.hash_get_enabled,
132        };
133        self.transport.publish(msg).await
134    }
135
136    /// Count peers by pool
137    async fn count_pools(&self) -> (usize, usize) {
138        let peers = self.peers.read().await;
139        let mut follows = 0;
140        let mut other = 0;
141        for entry in peers.values() {
142            match entry.pool {
143                PeerPool::Follows => follows += 1,
144                PeerPool::Other => other += 1,
145            }
146        }
147        (follows, other)
148    }
149
150    /// Classify a peer by pubkey
151    async fn classify_peer(&self, pubkey: &str) -> PeerPool {
152        if let Some(ref tx) = self.classifier_tx {
153            let (response_tx, response_rx) = tokio::sync::oneshot::channel();
154            let request = ClassifyRequest {
155                pubkey: pubkey.to_string(),
156                response: response_tx,
157            };
158            if tx.send(request).await.is_ok() {
159                if let Ok(pool) = response_rx.await {
160                    return pool;
161                }
162            }
163        }
164        PeerPool::Other
165    }
166
167    /// Check if we can accept a peer in a given pool
168    fn can_accept_peer(&self, pool: PeerPool, follows: usize, other: usize) -> bool {
169        match pool {
170            PeerPool::Follows => self.pools.follows.can_accept(follows),
171            PeerPool::Other => self.pools.other.can_accept(other),
172        }
173    }
174
175    /// Check if a pool needs more peers
176    fn pool_needs_peers(&self, pool: PeerPool, follows: usize, other: usize) -> bool {
177        match pool {
178            PeerPool::Follows => self.pools.follows.needs_peers(follows),
179            PeerPool::Other => self.pools.other.needs_peers(other),
180        }
181    }
182
183    /// Handle incoming signaling message
184    ///
185    /// This is the core signaling logic shared between production and simulation.
186    pub async fn handle_message(&self, msg: SignalingMessage) -> Result<(), TransportError> {
187        match &msg {
188            SignalingMessage::Hello {
189                peer_id,
190                roots,
191                hash_get,
192                ..
193            } => {
194                self.set_peer_hash_get(peer_id, *hash_get).await;
195                self.handle_hello(peer_id, roots, *hash_get).await
196            }
197            SignalingMessage::Offer {
198                peer_id,
199                target_peer_id,
200                sdp,
201            } => {
202                if target_peer_id == &self.peer_id {
203                    self.handle_offer(peer_id, sdp).await
204                } else {
205                    Ok(()) // Not for us
206                }
207            }
208            SignalingMessage::Answer {
209                peer_id,
210                target_peer_id,
211                sdp,
212            } => {
213                if target_peer_id == &self.peer_id {
214                    self.handle_answer(peer_id, sdp).await
215                } else {
216                    Ok(()) // Not for us
217                }
218            }
219            SignalingMessage::Candidate {
220                peer_id,
221                target_peer_id,
222                candidate,
223                sdp_m_line_index,
224                sdp_mid,
225            } => {
226                if target_peer_id == &self.peer_id {
227                    self.conn_factory
228                        .handle_candidate(
229                            peer_id,
230                            crate::types::IceCandidate {
231                                candidate: candidate.clone(),
232                                sdp_m_line_index: *sdp_m_line_index,
233                                sdp_mid: sdp_mid.clone(),
234                            },
235                        )
236                        .await
237                } else {
238                    Ok(())
239                }
240            }
241            SignalingMessage::Candidates {
242                peer_id,
243                target_peer_id,
244                candidates,
245            } => {
246                if target_peer_id == &self.peer_id {
247                    self.conn_factory
248                        .handle_candidates(peer_id, candidates.clone())
249                        .await
250                } else {
251                    Ok(())
252                }
253            }
254        }
255    }
256
257    /// Handle hello message using the shared concurrent-offer negotiation flow.
258    ///
259    /// With perfect negotiation, we send an offer if we need peers.
260    /// No tie-breaking here - collisions are handled in handle_offer.
261    async fn handle_hello(
262        &self,
263        from_peer_id: &str,
264        roots: &[String],
265        hash_get: bool,
266    ) -> Result<(), TransportError> {
267        // Ignore our own hello
268        if from_peer_id == self.peer_id {
269            return Ok(());
270        }
271
272        self.peer_roots
273            .write()
274            .await
275            .insert(from_peer_id.to_string(), roots.to_vec());
276        if let Some(entry) = self.peers.write().await.get_mut(from_peer_id) {
277            entry.hash_get = hash_get;
278            return Ok(());
279        }
280
281        let peer_pubkey = crate::types::PeerId::from_peer_string(from_peer_id)
282            .map(|peer_id| peer_id.pubkey)
283            .unwrap_or_else(|| from_peer_id.to_string());
284
285        // Classify the peer
286        let pool = self.classify_peer(&peer_pubkey).await;
287
288        // Check pool limits
289        let (follows_count, other_count) = self.count_pools().await;
290
291        if !self.can_accept_peer(pool, follows_count, other_count) {
292            if self.debug {
293                println!(
294                    "[Signaling] Ignoring hello from {} - {:?} pool full",
295                    from_peer_id, pool
296                );
297            }
298            return Ok(());
299        }
300
301        // Shared perfect negotiation: send offer if we NEED more peers
302        // Both sides may send offers - collision handled in handle_offer
303        if self.pool_needs_peers(pool, follows_count, other_count) {
304            // Check if already connected or pending
305            if self.peers.read().await.contains_key(from_peer_id) {
306                return Ok(());
307            }
308            if self.pending_offers.read().await.contains_key(from_peer_id) {
309                return Ok(());
310            }
311
312            if self.debug {
313                println!(
314                    "[Signaling] Sending offer to {} (pool: {:?})",
315                    from_peer_id, pool
316                );
317            }
318
319            // Mark as pending before creating offer
320            self.pending_offers
321                .write()
322                .await
323                .insert(from_peer_id.to_string(), ());
324
325            // Create offer
326            let (channel, sdp) = self.conn_factory.create_offer(from_peer_id).await?;
327
328            // Add peer (will be confirmed when we get answer)
329            self.peers.write().await.insert(
330                from_peer_id.to_string(),
331                PeerEntry {
332                    channel,
333                    pool,
334                    hash_get,
335                },
336            );
337
338            // Send offer
339            let offer_msg = SignalingMessage::Offer {
340                peer_id: self.peer_id.clone(),
341                target_peer_id: from_peer_id.to_string(),
342                sdp,
343            };
344            self.transport.publish(offer_msg).await?;
345        }
346
347        Ok(())
348    }
349
350    /// Handle offer message in the shared concurrent-offer negotiation flow.
351    ///
352    /// Handles offer collision: if we also sent an offer to this peer,
353    /// the "polite" peer (lower ID) backs off and accepts the incoming offer.
354    async fn handle_offer(&self, from_peer_id: &str, sdp: &str) -> Result<(), TransportError> {
355        // Extract pubkey
356        let peer_pubkey = crate::types::PeerId::from_peer_string(from_peer_id)
357            .map(|peer_id| peer_id.pubkey)
358            .unwrap_or_else(|| from_peer_id.to_string());
359
360        // Classify and check limits
361        let pool = self.classify_peer(&peer_pubkey).await;
362        let (follows_count, other_count) = self.count_pools().await;
363
364        if !self.can_accept_peer(pool, follows_count, other_count) {
365            if self.debug {
366                println!(
367                    "[Signaling] Ignoring offer from {} - {:?} pool full",
368                    from_peer_id, pool
369                );
370            }
371            return Ok(());
372        }
373
374        // Check for offer collision (we also sent an offer to them)
375        let have_pending = self.pending_offers.read().await.contains_key(from_peer_id);
376        if have_pending {
377            // Collision! Use polite/impolite pattern
378            let we_are_polite = is_polite_peer(&self.peer_id, from_peer_id);
379
380            if we_are_polite {
381                // We're polite - back off, accept their offer
382                // Remove our pending offer and peer entry
383                self.pending_offers.write().await.remove(from_peer_id);
384                self.peers.write().await.remove(from_peer_id);
385
386                if self.debug {
387                    println!(
388                        "[Signaling] Collision with {} - we're polite, accepting their offer",
389                        from_peer_id
390                    );
391                }
392            } else {
393                // We're impolite - ignore their offer, wait for answer to ours
394                if self.debug {
395                    println!(
396                        "[Signaling] Collision with {} - we're impolite, ignoring their offer",
397                        from_peer_id
398                    );
399                }
400                return Ok(());
401            }
402        }
403
404        // Check if already connected (no collision case)
405        if self.peers.read().await.contains_key(from_peer_id) {
406            return Ok(());
407        }
408
409        if self.debug {
410            println!("[Signaling] Accepting offer from {}", from_peer_id);
411        }
412
413        // Accept offer
414        let (channel, answer_sdp) = self.conn_factory.accept_offer(from_peer_id, sdp).await?;
415        let hash_get = self.peer_supports_hash_get(from_peer_id).await;
416
417        // Add peer
418        self.peers.write().await.insert(
419            from_peer_id.to_string(),
420            PeerEntry {
421                channel,
422                pool,
423                hash_get,
424            },
425        );
426
427        // Send answer
428        let answer_msg = SignalingMessage::Answer {
429            peer_id: self.peer_id.clone(),
430            target_peer_id: from_peer_id.to_string(),
431            sdp: answer_sdp,
432        };
433        self.transport.publish(answer_msg).await?;
434
435        Ok(())
436    }
437
438    /// Handle answer message
439    async fn handle_answer(&self, from_peer_id: &str, sdp: &str) -> Result<(), TransportError> {
440        if self.debug {
441            println!("[Signaling] Received answer from {}", from_peer_id);
442        }
443
444        // Complete connection
445        let _channel = self.conn_factory.handle_answer(from_peer_id, sdp).await?;
446
447        // Peer should already be in our map from when we sent the offer
448        // The channel returned here is the same one we stored
449
450        Ok(())
451    }
452
453    /// Get connected peer count
454    pub async fn peer_count(&self) -> usize {
455        self.peers.read().await.len()
456    }
457
458    /// Get peer IDs
459    pub async fn peer_ids(&self) -> Vec<String> {
460        let mut peer_ids = self.peers.read().await.keys().cloned().collect::<Vec<_>>();
461        peer_ids.sort();
462        peer_ids
463    }
464
465    /// Get a peer's channel
466    pub async fn get_channel(&self, peer_id: &str) -> Option<Arc<dyn PeerLink>> {
467        self.peers
468            .read()
469            .await
470            .get(peer_id)
471            .map(|e| e.channel.clone())
472    }
473
474    /// Remove a peer and any pending offer state.
475    pub async fn remove_peer(&self, peer_id: &str) -> Option<Arc<dyn PeerLink>> {
476        self.pending_offers.write().await.remove(peer_id);
477        self.peer_roots.write().await.remove(peer_id);
478        let _ = self.conn_factory.remove_peer(peer_id).await;
479        self.peers
480            .write()
481            .await
482            .remove(peer_id)
483            .map(|entry| entry.channel)
484    }
485
486    /// Check if we need more peers (below satisfied in any pool)
487    pub async fn needs_peers(&self) -> bool {
488        let (follows, other) = self.count_pools().await;
489        self.pools.follows.needs_peers(follows) || self.pools.other.needs_peers(other)
490    }
491
492    /// Check if we can accept more peers (below max in any pool)
493    pub async fn can_accept(&self) -> bool {
494        let (follows, other) = self.count_pools().await;
495        self.pools.follows.can_accept(follows) || self.pools.other.can_accept(other)
496    }
497}
498
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    use crate::types::{IceCandidate, PoolConfig, PoolSettings};

    /// Signaling transport that accepts every publish and never yields a message.
    #[derive(Default)]
    struct NoopTransport;

    #[async_trait]
    impl SignalingTransport for NoopTransport {
        async fn connect(&self, _relays: &[String]) -> Result<(), TransportError> {
            Ok(())
        }

        async fn disconnect(&self) {}

        async fn publish(&self, _msg: SignalingMessage) -> Result<(), TransportError> {
            Ok(())
        }

        async fn recv(&self) -> Option<SignalingMessage> {
            None
        }

        fn try_recv(&self) -> Option<SignalingMessage> {
            None
        }

        fn peer_id(&self) -> &str {
            "local"
        }
    }

    /// Link factory that records routed candidates and removed peers.
    /// The negotiation entry points are not exercised here and always fail.
    #[derive(Default)]
    struct RecordingFactory {
        candidates: Mutex<Vec<(String, IceCandidate)>>,
        removed: Mutex<Vec<String>>,
    }

    /// Failure returned by the factory methods these tests never drive.
    fn unused<T>() -> Result<T, TransportError> {
        Err(TransportError::ConnectionFailed(
            "not used in this test".to_string(),
        ))
    }

    #[async_trait]
    impl PeerLinkFactory for RecordingFactory {
        async fn create_offer(
            &self,
            _target_peer_id: &str,
        ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
            unused()
        }

        async fn accept_offer(
            &self,
            _from_peer_id: &str,
            _offer_sdp: &str,
        ) -> Result<(Arc<dyn PeerLink>, String), TransportError> {
            unused()
        }

        async fn handle_answer(
            &self,
            _target_peer_id: &str,
            _answer_sdp: &str,
        ) -> Result<Arc<dyn PeerLink>, TransportError> {
            unused()
        }

        async fn handle_candidate(
            &self,
            peer_id: &str,
            candidate: IceCandidate,
        ) -> Result<(), TransportError> {
            let mut seen = self.candidates.lock().await;
            seen.push((peer_id.to_string(), candidate));
            Ok(())
        }

        async fn remove_peer(&self, peer_id: &str) -> Result<(), TransportError> {
            let mut gone = self.removed.lock().await;
            gone.push(peer_id.to_string());
            Ok(())
        }
    }

    /// Build a router with id "local", default pool limits and debug off.
    fn test_router(
        factory: Arc<RecordingFactory>,
    ) -> MeshRouter<NoopTransport, RecordingFactory> {
        let pools = PoolSettings {
            follows: PoolConfig::default(),
            other: PoolConfig::default(),
        };
        MeshRouter::new(
            "local".to_string(),
            Arc::new(NoopTransport),
            factory,
            pools,
            false,
        )
    }

    #[tokio::test]
    async fn routes_targeted_candidates_to_factory() {
        let factory = Arc::new(RecordingFactory::default());
        let router = test_router(Arc::clone(&factory));

        let message = SignalingMessage::Candidate {
            peer_id: "remote:peer".to_string(),
            target_peer_id: "local".to_string(),
            candidate: "candidate:1".to_string(),
            sdp_m_line_index: Some(0),
            sdp_mid: Some("data".to_string()),
        };
        router
            .handle_message(message)
            .await
            .expect("candidate should route");

        let seen = factory.candidates.lock().await;
        let recorded: Vec<(String, String)> = seen
            .iter()
            .map(|(peer, ice)| (peer.clone(), ice.candidate.clone()))
            .collect();

        let expected = vec![("remote:peer".to_string(), "candidate:1".to_string())];
        assert_eq!(recorded, expected);
    }

    #[tokio::test]
    async fn remove_peer_cleans_factory_state() {
        let factory = Arc::new(RecordingFactory::default());
        let router = test_router(Arc::clone(&factory));

        let channel = router.remove_peer("remote:peer").await;
        assert!(channel.is_none());

        let gone = factory.removed.lock().await;
        assert_eq!(gone.as_slice(), &["remote:peer".to_string()]);
    }
}