iroh_gossip_discovery/
lib.rs

1use bytes::Bytes;
2use dashmap::DashMap;
3use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
4use futures::StreamExt;
5
6use iroh::NodeId;
7use iroh_gossip::{
8    net::{Event, Gossip, GossipEvent, GossipReceiver, GossipSender},
9    proto::TopicId,
10};
11
12use serde::{Deserialize, Serialize};
13
14use std::sync::Arc;
15use thiserror::Error;
16
17use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
18use tokio::time::{Duration, Instant, sleep};
19use tracing::{debug, error, info, warn};
20
21#[derive(Debug, Clone, Deserialize, Serialize)]
22pub struct Node {
23    pub name: String,
24    pub node_id: NodeId,
25    pub count: u32,
26}
27
28#[derive(Debug, Clone)]
29pub struct NodeInfo {
30    pub node_id: NodeId,
31    pub last_seen: Instant,
32}
33
34#[derive(Debug, Clone, Deserialize, Serialize)]
35struct SignedMessage {
36    from: VerifyingKey,
37    data: Bytes,
38    signature: Signature,
39}
40
41impl SignedMessage {
42    pub fn sign_and_encode(secret_key: &SigningKey, node: &Node) -> Result<Bytes> {
43        let data: Bytes = postcard::to_stdvec(node)
44            .map_err(|e| GossipDiscoveryError::Serialization(e.to_string()))?
45            .into();
46        let signature = secret_key.sign(&data);
47        let from: VerifyingKey = secret_key.verifying_key();
48        
49        let signed_message = Self {
50            from,
51            data,
52            signature,
53        };
54        
55        let encoded = postcard::to_stdvec(&signed_message)
56            .map_err(|e| GossipDiscoveryError::Serialization(e.to_string()))?;
57        Ok(encoded.into())
58    }
59    
60    pub fn verify_and_decode(bytes: &[u8]) -> Result<(VerifyingKey, Node)> {
61        let signed_message: Self = postcard::from_bytes(bytes)
62            .map_err(|e| GossipDiscoveryError::Deserialization(e.to_string()))?;
63        let key: VerifyingKey = signed_message.from;
64        
65        key.verify(&signed_message.data, &signed_message.signature)
66            .map_err(|e| GossipDiscoveryError::SignatureVerification(e.to_string()))?;
67        
68        let node: Node = postcard::from_bytes(&signed_message.data)
69            .map_err(|e| GossipDiscoveryError::Deserialization(e.to_string()))?;
70        Ok((signed_message.from, node))
71    }
72}
73
74#[derive(Error, Debug)]
75pub enum GossipDiscoveryError {
76    #[error("Gossip error: {0}")]
77    Gossip(#[from] iroh_gossip::net::Error),
78    #[error("Channel send error")]
79    ChannelSend,
80    #[error("Serialization error: {0}")]
81    Serialization(String),
82    #[error("Deserialization error: {0}")]
83    Deserialization(String),
84    #[error("Signature verification error: {0}")]
85    SignatureVerification(String),
86    #[error("NodeId mismatch: expected {expected}, got {actual}")]
87    NodeIdMismatch { expected: NodeId, actual: NodeId },
88}
89
90pub type Result<T> = std::result::Result<T, GossipDiscoveryError>;
91
92pub struct GossipDiscoveryBuilder {
93    expiration_timeout: Option<Duration>,
94}
95
96impl GossipDiscoveryBuilder {
97    pub fn new() -> Self {
98        Self {
99            expiration_timeout: None,
100        }
101    }
102
103    pub fn with_expiration_timeout(mut self, timeout: Duration) -> Self {
104        self.expiration_timeout = Some(timeout);
105        self
106    }
107
108    pub async fn build_with_peers(
109        self,
110        gossip: Gossip,
111        topic_id: TopicId,
112        peers: Vec<NodeId>,
113        endpoint: &iroh::Endpoint,
114    ) -> Result<(GossipDiscoverySender, GossipDiscoveryReceiver)> {
115        // - First node (empty peers): use subscribe() only  
116        // - Other nodes (with peers): use subscribe_and_join()
117        info!("Attempting to subscribe to gossip topic");
118        let (sender, receiver) = gossip.subscribe(topic_id, peers)?.split();
119        info!("Subscribed to gossip topic");
120
121        let (peer_tx, peer_rx) = tokio::sync::mpsc::unbounded_channel();
122        let neighbor_map = Arc::new(DashMap::new());
123
124        // Derive a secret key from the endpoint's node secret key
125        // This ensures the signing key corresponds to the node's identity
126        let node_secret = endpoint.secret_key();
127        let secret_key_bytes = node_secret.to_bytes();
128        let secret_key = SigningKey::from_bytes(&secret_key_bytes);
129        let discovery_sender = GossipDiscoverySender { peer_rx, sender, secret_key };
130
131        let expiration_timeout = self.expiration_timeout.unwrap_or(Duration::from_secs(30));
132
133        let discovery_receiver = GossipDiscoveryReceiver {
134            neighbor_map: Arc::clone(&neighbor_map),
135            peer_tx,
136            receiver,
137            expiration_timeout,
138        };
139
140        // Start the cleanup task
141        GossipDiscoveryReceiver::start_cleanup_task(neighbor_map, expiration_timeout);
142
143        Ok((discovery_sender, discovery_receiver))
144    }
145}
146
147pub struct GossipDiscoverySender {
148    pub peer_rx: UnboundedReceiver<NodeId>,
149    pub sender: GossipSender,
150    pub secret_key: SigningKey,
151}
152
153impl GossipDiscoverySender {
154    /// Add external peers to the gossip network
155    pub async fn add_peers(&mut self, peers: Vec<NodeId>) -> Result<()> {
156        if !peers.is_empty() {
157            info!(peer_count = peers.len(), "Adding external peers to gossip network");
158            self.sender.join_peers(peers).await?;
159        }
160        Ok(())
161    }
162
163    /// Add a single external peer to the gossip network  
164    pub async fn add_peer(&mut self, peer: NodeId) -> Result<()> {
165        self.add_peers(vec![peer]).await
166    }
167
168    pub async fn gossip(&mut self, node: Node, update_rate: Duration) -> Result<()> {
169        let mut i = node.count;
170
171        loop {
172            // Check for new peers to join
173            match self.peer_rx.try_recv() {
174                Ok(peer) => {
175                    info!(%peer, "Joining new peer");
176                    if let Err(e) = self.sender.join_peers(vec![peer]).await {
177                        error!(%e, "Failed to join peer");
178                    }
179                }
180                Err(_) => {}
181            }
182
183            let update_node = Node {
184                name: node.name.clone(),
185                node_id: node.node_id,
186                count: i,
187            };
188
189            // Sign and encode the message
190            let bytes = SignedMessage::sign_and_encode(&self.secret_key, &update_node)?;
191
192            if let Err(e) = self.sender.broadcast(bytes).await {
193                error!(%e, "Failed to broadcast");
194            }
195
196            i += 1;
197            sleep(update_rate).await;
198        }
199    }
200}
201
202pub struct GossipDiscoveryReceiver {
203    pub neighbor_map: Arc<DashMap<String, NodeInfo>>,
204    pub peer_tx: UnboundedSender<NodeId>,
205    pub receiver: GossipReceiver,
206    pub expiration_timeout: Duration,
207}
208
209impl GossipDiscoveryReceiver {
210    pub async fn update_map(&mut self) -> Result<()> {
211        while let Some(res) = self.receiver.next().await {
212            match res {
213                Ok(Event::Gossip(GossipEvent::Received(msg))) => {
214                    // Verify and decode the signed message
215                    let (verifying_key, value) = match SignedMessage::verify_and_decode(&msg.content) {
216                        Ok(result) => result,
217                        Err(e) => {
218                            warn!(%e, "Failed to verify message signature, ignoring");
219                            continue;
220                        }
221                    };
222
223                    // Verify that the claimed node_id matches the public key
224                    let expected_node_id = NodeId::from(verifying_key);
225                    if value.node_id != expected_node_id {
226                        warn!(
227                            claimed_node_id = %value.node_id,
228                            actual_node_id = %expected_node_id,
229                            "NodeId spoofing attempt detected, ignoring message"
230                        );
231                        continue;
232                    }
233
234                    let is_new_peer = !self.neighbor_map.contains_key(&value.name);
235
236                    if is_new_peer {
237                        // Send new peer to sender for joining
238                        self.peer_tx
239                            .send(value.node_id)
240                            .map_err(|_| GossipDiscoveryError::ChannelSend)?;
241                        info!(name = %value.name, node_id = %value.node_id, "Discovered new peer");
242                    }
243
244                    self.neighbor_map.insert(
245                        value.name.clone(),
246                        NodeInfo {
247                            node_id: value.node_id,
248                            last_seen: Instant::now(),
249                        },
250                    );
251                    debug!(peer_count = self.neighbor_map.len(), "Address book updated");
252                }
253                Ok(_) => {}
254                Err(e) => {
255                    error!(%e, "Error receiving gossip");
256                }
257            }
258        }
259        Ok(())
260    }
261
262    pub fn get_neighbors(&self) -> Vec<(String, NodeId)> {
263        self.neighbor_map
264            .iter()
265            .map(|entry| (entry.key().clone(), entry.value().node_id))
266            .collect()
267    }
268
269    pub fn cleanup_expired_nodes(&self) -> usize {
270        let now = Instant::now();
271        let mut expired_count = 0;
272
273        // Collect expired node names first to avoid holding locks
274        let expired_nodes: Vec<String> = self
275            .neighbor_map
276            .iter()
277            .filter_map(|entry| {
278                if now.duration_since(entry.value().last_seen) > self.expiration_timeout {
279                    Some(entry.key().clone())
280                } else {
281                    None
282                }
283            })
284            .collect();
285
286        // Remove expired nodes
287        for node_name in expired_nodes {
288            if let Some((_, node_info)) = self.neighbor_map.remove(&node_name) {
289                info!(name = %node_name, node_id = %node_info.node_id, "Expired node");
290                expired_count += 1;
291            }
292        }
293
294        expired_count
295    }
296
297    pub fn start_cleanup_task(
298        neighbor_map: Arc<DashMap<String, NodeInfo>>,
299        expiration_timeout: Duration,
300    ) {
301        let cleanup_interval = expiration_timeout / 3; // Check every 1/3 of timeout period
302
303        tokio::spawn(async move {
304            loop {
305                sleep(cleanup_interval).await;
306
307                let now = Instant::now();
308                let mut expired_count = 0;
309
310                // Collect expired node names first to avoid holding locks
311                let expired_nodes: Vec<String> = neighbor_map
312                    .iter()
313                    .filter_map(|entry| {
314                        if now.duration_since(entry.value().last_seen) > expiration_timeout {
315                            Some(entry.key().clone())
316                        } else {
317                            None
318                        }
319                    })
320                    .collect();
321
322                // Remove expired nodes
323                for node_name in expired_nodes {
324                    if let Some((_, node_info)) = neighbor_map.remove(&node_name) {
325                        info!(name = %node_name, node_id = %node_info.node_id, "Expired node");
326                        expired_count += 1;
327                    }
328                }
329
330                if expired_count > 0 {
331                    info!(count = expired_count, "Cleaned up expired nodes");
332                }
333            }
334        });
335    }
336}