1use bytes::Bytes;
2use dashmap::DashMap;
3use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
4use futures::StreamExt;
5
6use iroh::NodeId;
7use iroh_gossip::{
8 net::{Event, Gossip, GossipEvent, GossipReceiver, GossipSender},
9 proto::TopicId,
10};
11
12use serde::{Deserialize, Serialize};
13
14use std::sync::Arc;
15use thiserror::Error;
16
17use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
18use tokio::time::{Duration, Instant, sleep};
19use tracing::{debug, error, info, warn};
20
/// Identity record that each node signs and broadcasts on the discovery topic.
///
/// Encoded with `postcard` (see `SignedMessage::sign_and_encode`), which writes struct fields
/// in declaration order — the field order below is therefore part of the wire format and must
/// not be changed without breaking compatibility with existing peers.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Node {
    /// Human-readable name; receivers use it as the key of their neighbor map.
    pub name: String,
    /// Node id of the sender. Receivers drop messages whose `node_id` does not match the key
    /// that signed them.
    pub node_id: NodeId,
    /// Broadcast counter: `GossipDiscoverySender::gossip` starts from this value and advances it
    /// once per broadcast.
    pub count: u32,
}
27
/// Entry stored in a receiver's neighbor map for each discovered node.
#[derive(Debug, Clone)]
pub struct NodeInfo {
    /// Node id carried by the most recent accepted message for this entry's name.
    pub node_id: NodeId,
    /// When the most recent accepted message for this entry was processed (`tokio` clock).
    /// Entries older than the receiver's expiration timeout are removed by the cleanup logic.
    pub last_seen: Instant,
}
33
/// Wire envelope: a postcard-encoded [`Node`] plus the Ed25519 public key and signature over it.
///
/// Also postcard-encoded, so the field order is part of the wire format.
#[derive(Debug, Clone, Deserialize, Serialize)]
struct SignedMessage {
    /// Public key of the signer; receivers derive the expected `NodeId` from it.
    from: VerifyingKey,
    /// postcard-encoded `Node`; this exact byte string is what `signature` covers.
    data: Bytes,
    /// Signature over `data` produced with the secret key matching `from`.
    signature: Signature,
}
40
41impl SignedMessage {
42 pub fn sign_and_encode(secret_key: &SigningKey, node: &Node) -> Result<Bytes> {
43 let data: Bytes = postcard::to_stdvec(node)
44 .map_err(|e| GossipDiscoveryError::Serialization(e.to_string()))?
45 .into();
46 let signature = secret_key.sign(&data);
47 let from: VerifyingKey = secret_key.verifying_key();
48
49 let signed_message = Self {
50 from,
51 data,
52 signature,
53 };
54
55 let encoded = postcard::to_stdvec(&signed_message)
56 .map_err(|e| GossipDiscoveryError::Serialization(e.to_string()))?;
57 Ok(encoded.into())
58 }
59
60 pub fn verify_and_decode(bytes: &[u8]) -> Result<(VerifyingKey, Node)> {
61 let signed_message: Self = postcard::from_bytes(bytes)
62 .map_err(|e| GossipDiscoveryError::Deserialization(e.to_string()))?;
63 let key: VerifyingKey = signed_message.from;
64
65 key.verify(&signed_message.data, &signed_message.signature)
66 .map_err(|e| GossipDiscoveryError::SignatureVerification(e.to_string()))?;
67
68 let node: Node = postcard::from_bytes(&signed_message.data)
69 .map_err(|e| GossipDiscoveryError::Deserialization(e.to_string()))?;
70 Ok((signed_message.from, node))
71 }
72}
73
/// Errors produced by the gossip discovery service.
#[derive(Error, Debug)]
pub enum GossipDiscoveryError {
    /// Error from the underlying `iroh_gossip` layer (e.g. subscribing or joining peers).
    #[error("Gossip error: {0}")]
    Gossip(#[from] iroh_gossip::net::Error),
    /// A discovered peer could not be forwarded because the channel's receiving side is gone.
    #[error("Channel send error")]
    ChannelSend,
    /// postcard failed to encode a message; holds the encoder's error text.
    #[error("Serialization error: {0}")]
    Serialization(String),
    /// postcard failed to decode a message; holds the decoder's error text.
    #[error("Deserialization error: {0}")]
    Deserialization(String),
    /// A message's Ed25519 signature did not verify; holds the verifier's error text.
    #[error("Signature verification error: {0}")]
    SignatureVerification(String),
    /// A message claimed a node id other than the one derived from its signing key.
    // NOTE(review): not constructed anywhere in the visible code — the receiver logs and skips
    // such messages instead. Confirm whether other code returns this variant.
    #[error("NodeId mismatch: expected {expected}, got {actual}")]
    NodeIdMismatch { expected: NodeId, actual: NodeId },
}
89
90pub type Result<T> = std::result::Result<T, GossipDiscoveryError>;
91
/// Builder that subscribes to a gossip topic and produces a connected
/// [`GossipDiscoverySender`] / [`GossipDiscoveryReceiver`] pair.
pub struct GossipDiscoveryBuilder {
    /// How long a neighbor may stay silent before it is expired. `None` means the default
    /// (30 seconds) applied in `build_with_peers`.
    expiration_timeout: Option<Duration>,
}
95
96impl GossipDiscoveryBuilder {
97 pub fn new() -> Self {
98 Self {
99 expiration_timeout: None,
100 }
101 }
102
103 pub fn with_expiration_timeout(mut self, timeout: Duration) -> Self {
104 self.expiration_timeout = Some(timeout);
105 self
106 }
107
108 pub async fn build_with_peers(
109 self,
110 gossip: Gossip,
111 topic_id: TopicId,
112 peers: Vec<NodeId>,
113 endpoint: &iroh::Endpoint,
114 ) -> Result<(GossipDiscoverySender, GossipDiscoveryReceiver)> {
115 info!("Attempting to subscribe to gossip topic");
118 let (sender, receiver) = gossip.subscribe(topic_id, peers)?.split();
119 info!("Subscribed to gossip topic");
120
121 let (peer_tx, peer_rx) = tokio::sync::mpsc::unbounded_channel();
122 let neighbor_map = Arc::new(DashMap::new());
123
124 let node_secret = endpoint.secret_key();
127 let secret_key_bytes = node_secret.to_bytes();
128 let secret_key = SigningKey::from_bytes(&secret_key_bytes);
129 let discovery_sender = GossipDiscoverySender { peer_rx, sender, secret_key };
130
131 let expiration_timeout = self.expiration_timeout.unwrap_or(Duration::from_secs(30));
132
133 let discovery_receiver = GossipDiscoveryReceiver {
134 neighbor_map: Arc::clone(&neighbor_map),
135 peer_tx,
136 receiver,
137 expiration_timeout,
138 };
139
140 GossipDiscoveryReceiver::start_cleanup_task(neighbor_map, expiration_timeout);
142
143 Ok((discovery_sender, discovery_receiver))
144 }
145}
146
/// Sending half of the discovery service. It periodically broadcasts this node's signed
/// [`Node`] record and joins peers reported by the paired [`GossipDiscoveryReceiver`].
pub struct GossipDiscoverySender {
    /// Node ids of newly discovered peers, forwarded by `GossipDiscoveryReceiver::update_map`.
    pub peer_rx: UnboundedReceiver<NodeId>,
    /// Gossip handle used to broadcast records and join peers on the subscribed topic.
    pub sender: GossipSender,
    /// Key used to sign outgoing records; built from the endpoint's secret key in
    /// `GossipDiscoveryBuilder::build_with_peers`.
    pub secret_key: SigningKey,
}
152
153impl GossipDiscoverySender {
154 pub async fn add_peers(&mut self, peers: Vec<NodeId>) -> Result<()> {
156 if !peers.is_empty() {
157 info!(peer_count = peers.len(), "Adding external peers to gossip network");
158 self.sender.join_peers(peers).await?;
159 }
160 Ok(())
161 }
162
163 pub async fn add_peer(&mut self, peer: NodeId) -> Result<()> {
165 self.add_peers(vec![peer]).await
166 }
167
168 pub async fn gossip(&mut self, node: Node, update_rate: Duration) -> Result<()> {
169 let mut i = node.count;
170
171 loop {
172 match self.peer_rx.try_recv() {
174 Ok(peer) => {
175 info!(%peer, "Joining new peer");
176 if let Err(e) = self.sender.join_peers(vec![peer]).await {
177 error!(%e, "Failed to join peer");
178 }
179 }
180 Err(_) => {}
181 }
182
183 let update_node = Node {
184 name: node.name.clone(),
185 node_id: node.node_id,
186 count: i,
187 };
188
189 let bytes = SignedMessage::sign_and_encode(&self.secret_key, &update_node)?;
191
192 if let Err(e) = self.sender.broadcast(bytes).await {
193 error!(%e, "Failed to broadcast");
194 }
195
196 i += 1;
197 sleep(update_rate).await;
198 }
199 }
200}
201
/// Receiving half of the discovery service. It verifies incoming records and maintains the
/// neighbor map.
pub struct GossipDiscoveryReceiver {
    /// Known neighbors keyed by their self-reported name. The builder shares this map with the
    /// background cleanup task.
    pub neighbor_map: Arc<DashMap<String, NodeInfo>>,
    /// Forwards node ids of newly discovered peers to the paired [`GossipDiscoverySender`].
    pub peer_tx: UnboundedSender<NodeId>,
    /// Event stream of the subscribed gossip topic.
    pub receiver: GossipReceiver,
    /// Age after which a neighbor that has not been heard from is removed.
    pub expiration_timeout: Duration,
}
208
209impl GossipDiscoveryReceiver {
210 pub async fn update_map(&mut self) -> Result<()> {
211 while let Some(res) = self.receiver.next().await {
212 match res {
213 Ok(Event::Gossip(GossipEvent::Received(msg))) => {
214 let (verifying_key, value) = match SignedMessage::verify_and_decode(&msg.content) {
216 Ok(result) => result,
217 Err(e) => {
218 warn!(%e, "Failed to verify message signature, ignoring");
219 continue;
220 }
221 };
222
223 let expected_node_id = NodeId::from(verifying_key);
225 if value.node_id != expected_node_id {
226 warn!(
227 claimed_node_id = %value.node_id,
228 actual_node_id = %expected_node_id,
229 "NodeId spoofing attempt detected, ignoring message"
230 );
231 continue;
232 }
233
234 let is_new_peer = !self.neighbor_map.contains_key(&value.name);
235
236 if is_new_peer {
237 self.peer_tx
239 .send(value.node_id)
240 .map_err(|_| GossipDiscoveryError::ChannelSend)?;
241 info!(name = %value.name, node_id = %value.node_id, "Discovered new peer");
242 }
243
244 self.neighbor_map.insert(
245 value.name.clone(),
246 NodeInfo {
247 node_id: value.node_id,
248 last_seen: Instant::now(),
249 },
250 );
251 debug!(peer_count = self.neighbor_map.len(), "Address book updated");
252 }
253 Ok(_) => {}
254 Err(e) => {
255 error!(%e, "Error receiving gossip");
256 }
257 }
258 }
259 Ok(())
260 }
261
262 pub fn get_neighbors(&self) -> Vec<(String, NodeId)> {
263 self.neighbor_map
264 .iter()
265 .map(|entry| (entry.key().clone(), entry.value().node_id))
266 .collect()
267 }
268
269 pub fn cleanup_expired_nodes(&self) -> usize {
270 let now = Instant::now();
271 let mut expired_count = 0;
272
273 let expired_nodes: Vec<String> = self
275 .neighbor_map
276 .iter()
277 .filter_map(|entry| {
278 if now.duration_since(entry.value().last_seen) > self.expiration_timeout {
279 Some(entry.key().clone())
280 } else {
281 None
282 }
283 })
284 .collect();
285
286 for node_name in expired_nodes {
288 if let Some((_, node_info)) = self.neighbor_map.remove(&node_name) {
289 info!(name = %node_name, node_id = %node_info.node_id, "Expired node");
290 expired_count += 1;
291 }
292 }
293
294 expired_count
295 }
296
297 pub fn start_cleanup_task(
298 neighbor_map: Arc<DashMap<String, NodeInfo>>,
299 expiration_timeout: Duration,
300 ) {
301 let cleanup_interval = expiration_timeout / 3; tokio::spawn(async move {
304 loop {
305 sleep(cleanup_interval).await;
306
307 let now = Instant::now();
308 let mut expired_count = 0;
309
310 let expired_nodes: Vec<String> = neighbor_map
312 .iter()
313 .filter_map(|entry| {
314 if now.duration_since(entry.value().last_seen) > expiration_timeout {
315 Some(entry.key().clone())
316 } else {
317 None
318 }
319 })
320 .collect();
321
322 for node_name in expired_nodes {
324 if let Some((_, node_info)) = neighbor_map.remove(&node_name) {
325 info!(name = %node_name, node_id = %node_info.node_id, "Expired node");
326 expired_count += 1;
327 }
328 }
329
330 if expired_count > 0 {
331 info!(count = expired_count, "Cleaned up expired nodes");
332 }
333 }
334 });
335 }
336}