1use crate::{
4 error::{ChaincraftError, Result},
5 network::{PeerId, PeerInfo},
6};
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet};
9use std::net::SocketAddr;
10use std::sync::Arc;
11use tokio::sync::RwLock;
12use tokio::time::{Duration, Instant};
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub enum DiscoveryMessage {
17 Announce {
19 node_id: PeerId,
20 socket_addr: SocketAddr,
21 timestamp: u64,
22 },
23 PeerRequest {
25 requester_id: PeerId,
26 max_peers: usize,
27 },
28 PeerResponse { peers: Vec<PeerAnnouncement> },
30 Ping { sender_id: PeerId, timestamp: u64 },
32 Pong {
34 responder_id: PeerId,
35 timestamp: u64,
36 },
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct PeerAnnouncement {
42 pub node_id: PeerId,
43 pub socket_addr: SocketAddr,
44 pub last_seen: u64,
45 pub announced_at: u64,
46}
47
/// Tuning parameters for [`DiscoveryManager`].
///
/// Derives `PartialEq`/`Eq` so configurations can be compared (e.g. in tests or when
/// detecting a reload that changed nothing).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DiscoveryConfig {
    /// Maximum number of entries in the known-peer table; when exceeded, the
    /// least-recently-seen peers are evicted.
    pub max_peers: usize,
    /// Interval between pings. NOTE(review): not read by the discovery code in this
    /// file; presumably seconds like the other intervals — confirm with callers.
    pub ping_interval: u64,
    /// Seconds without a `last_seen` refresh after which a peer is dropped by
    /// `cleanup_old_peers`.
    pub peer_timeout: u64,
    /// Minimum number of seconds between two self-announcements (`should_announce`).
    pub announce_interval: u64,
    /// When `false`, `should_announce` always returns `false`.
    pub enabled: bool,
}
62
63impl Default for DiscoveryConfig {
64 fn default() -> Self {
65 Self {
66 max_peers: 50,
67 ping_interval: 30,
68 peer_timeout: 120,
69 announce_interval: 60,
70 enabled: true,
71 }
72 }
73}
74
75pub struct DiscoveryManager {
77 node_id: PeerId,
79 socket_addr: SocketAddr,
81 peers: Arc<RwLock<HashMap<PeerId, PeerAnnouncement>>>,
83 connected_peers: Arc<RwLock<HashSet<PeerId>>>,
85 config: DiscoveryConfig,
87 last_announce: Arc<RwLock<Option<Instant>>>,
89}
90
91impl DiscoveryManager {
92 pub fn new(node_id: PeerId, socket_addr: SocketAddr, config: DiscoveryConfig) -> Self {
94 Self {
95 node_id,
96 socket_addr,
97 peers: Arc::new(RwLock::new(HashMap::new())),
98 connected_peers: Arc::new(RwLock::new(HashSet::new())),
99 config,
100 last_announce: Arc::new(RwLock::new(None)),
101 }
102 }
103
104 pub async fn add_peer(&self, peer_info: PeerInfo) -> Result<()> {
106 let now = std::time::SystemTime::now()
107 .duration_since(std::time::UNIX_EPOCH)
108 .unwrap()
109 .as_secs();
110
111 let announcement = PeerAnnouncement {
112 node_id: peer_info.id.clone(),
113 socket_addr: peer_info.address,
114 last_seen: now,
115 announced_at: now,
116 };
117
118 let mut peers = self.peers.write().await;
119 peers.insert(peer_info.id, announcement);
120
121 if peers.len() > self.config.max_peers {
123 let mut peer_last_seen: Vec<(PeerId, u64)> = peers
124 .iter()
125 .map(|(id, ann)| (id.clone(), ann.last_seen))
126 .collect();
127
128 peer_last_seen.sort_by_key(|(_, last_seen)| *last_seen);
129
130 let oldest_peers: Vec<PeerId> = peer_last_seen
131 .into_iter()
132 .take(peers.len() - self.config.max_peers)
133 .map(|(id, _)| id)
134 .collect();
135
136 for peer_id in oldest_peers {
137 peers.remove(&peer_id);
138 }
139 }
140
141 Ok(())
142 }
143
144 pub async fn remove_peer(&self, peer_id: &PeerId) -> Result<()> {
146 let mut peers = self.peers.write().await;
147 peers.remove(peer_id);
148
149 let mut connected = self.connected_peers.write().await;
150 connected.remove(peer_id);
151
152 Ok(())
153 }
154
155 pub async fn mark_connected(&self, peer_id: &PeerId) -> Result<()> {
157 let mut connected = self.connected_peers.write().await;
158 connected.insert(peer_id.clone());
159
160 let now = std::time::SystemTime::now()
162 .duration_since(std::time::UNIX_EPOCH)
163 .unwrap()
164 .as_secs();
165
166 let mut peers = self.peers.write().await;
167 if let Some(peer) = peers.get_mut(peer_id) {
168 peer.last_seen = now;
169 }
170
171 Ok(())
172 }
173
174 pub async fn mark_disconnected(&self, peer_id: &PeerId) -> Result<()> {
176 let mut connected = self.connected_peers.write().await;
177 connected.remove(peer_id);
178 Ok(())
179 }
180
181 pub async fn get_peers(&self) -> Vec<PeerAnnouncement> {
183 let peers = self.peers.read().await;
184 peers.values().cloned().collect()
185 }
186
187 pub async fn get_connected_peers(&self) -> Vec<PeerId> {
189 let connected = self.connected_peers.read().await;
190 connected.iter().cloned().collect()
191 }
192
193 pub async fn get_peers_for_discovery(
195 &self,
196 requester_id: &PeerId,
197 max_peers: usize,
198 ) -> Vec<PeerAnnouncement> {
199 let peers = self.peers.read().await;
200 let connected = self.connected_peers.read().await;
201
202 peers
203 .values()
204 .filter(|peer| &peer.node_id != requester_id && !connected.contains(&peer.node_id))
205 .take(max_peers)
206 .cloned()
207 .collect()
208 }
209
210 pub async fn handle_message(
212 &self,
213 message: DiscoveryMessage,
214 sender_addr: SocketAddr,
215 ) -> Result<Option<DiscoveryMessage>> {
216 match message {
217 DiscoveryMessage::Announce {
218 node_id,
219 socket_addr,
220 timestamp: _,
221 } => {
222 let peer_info = PeerInfo::new(node_id, socket_addr);
224 self.add_peer(peer_info).await?;
225 Ok(None)
226 },
227
228 DiscoveryMessage::PeerRequest {
229 requester_id,
230 max_peers,
231 } => {
232 let peers = self.get_peers_for_discovery(&requester_id, max_peers).await;
234 Ok(Some(DiscoveryMessage::PeerResponse { peers }))
235 },
236
237 DiscoveryMessage::PeerResponse { peers } => {
238 for peer_announcement in peers {
240 let peer_info =
241 PeerInfo::new(peer_announcement.node_id, peer_announcement.socket_addr);
242 self.add_peer(peer_info).await?;
243 }
244 Ok(None)
245 },
246
247 DiscoveryMessage::Ping {
248 sender_id,
249 timestamp: _,
250 } => {
251 let now = std::time::SystemTime::now()
253 .duration_since(std::time::UNIX_EPOCH)
254 .unwrap()
255 .as_secs();
256 Ok(Some(DiscoveryMessage::Pong {
257 responder_id: self.node_id.clone(),
258 timestamp: now,
259 }))
260 },
261
262 DiscoveryMessage::Pong {
263 responder_id,
264 timestamp: _,
265 } => {
266 self.mark_connected(&responder_id).await?;
268 Ok(None)
269 },
270 }
271 }
272
273 pub fn create_announcement(&self) -> DiscoveryMessage {
275 let now = std::time::SystemTime::now()
276 .duration_since(std::time::UNIX_EPOCH)
277 .unwrap()
278 .as_secs();
279
280 DiscoveryMessage::Announce {
281 node_id: self.node_id.clone(),
282 socket_addr: self.socket_addr,
283 timestamp: now,
284 }
285 }
286
287 pub fn create_peer_request(&self, max_peers: usize) -> DiscoveryMessage {
289 DiscoveryMessage::PeerRequest {
290 requester_id: self.node_id.clone(),
291 max_peers,
292 }
293 }
294
295 pub fn create_ping(&self) -> DiscoveryMessage {
297 let now = std::time::SystemTime::now()
298 .duration_since(std::time::UNIX_EPOCH)
299 .unwrap()
300 .as_secs();
301
302 DiscoveryMessage::Ping {
303 sender_id: self.node_id.clone(),
304 timestamp: now,
305 }
306 }
307
308 pub async fn should_announce(&self) -> bool {
310 if !self.config.enabled {
311 return false;
312 }
313
314 let last_announce = self.last_announce.read().await;
315 match *last_announce {
316 None => true,
317 Some(last) => {
318 let elapsed = last.elapsed();
319 elapsed >= Duration::from_secs(self.config.announce_interval)
320 },
321 }
322 }
323
324 pub async fn update_last_announce(&self) {
326 let mut last_announce = self.last_announce.write().await;
327 *last_announce = Some(Instant::now());
328 }
329
330 pub async fn cleanup_old_peers(&self) -> Result<()> {
332 let now = std::time::SystemTime::now()
333 .duration_since(std::time::UNIX_EPOCH)
334 .unwrap()
335 .as_secs();
336
337 let mut peers = self.peers.write().await;
338 let mut connected = self.connected_peers.write().await;
339
340 let timeout_threshold = now - self.config.peer_timeout;
341 let old_peers: Vec<PeerId> = peers
342 .iter()
343 .filter(|(_, peer)| peer.last_seen < timeout_threshold)
344 .map(|(id, _)| id.clone())
345 .collect();
346
347 for peer_id in old_peers {
348 peers.remove(&peer_id);
349 connected.remove(&peer_id);
350 }
351
352 Ok(())
353 }
354
355 pub async fn get_stats(&self) -> DiscoveryStats {
357 let peers = self.peers.read().await;
358 let connected = self.connected_peers.read().await;
359
360 DiscoveryStats {
361 total_known_peers: peers.len(),
362 connected_peers: connected.len(),
363 max_peers: self.config.max_peers,
364 }
365 }
366}
367
/// Point-in-time counters describing the discovery tables.
///
/// Derives `PartialEq`/`Eq` so snapshots can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DiscoveryStats {
    /// Number of entries in the known-peer table.
    pub total_known_peers: usize,
    /// Number of ids in the connected set.
    pub connected_peers: usize,
    /// Configured capacity of the known-peer table.
    pub max_peers: usize,
}
375
/// Extension trait: consume a collection and get its elements back sorted by a key.
///
/// NOTE(review): no caller of `sorted_by_key` is visible in this file; if none exists
/// elsewhere in the module, this trait and its `Vec` impl are dead code and can go.
trait SortedByKey<T> {
    /// Sorts `self` by the key that `f` extracts from each element and returns the result.
    fn sorted_by_key<K: Ord, F: FnMut(&T) -> K>(self, f: F) -> Vec<T>;
}
383
384impl<T> SortedByKey<T> for Vec<T> {
385 fn sorted_by_key<K, F>(mut self, f: F) -> Vec<T>
386 where
387 F: FnMut(&T) -> K,
388 K: Ord,
389 {
390 self.sort_by_key(f);
391 self
392 }
393}