Skip to main content

arc_malachitebft_discovery/
lib.rs

1use std::collections::{HashMap, HashSet};
2
3use tracing::{debug, error, info, warn};
4
5use malachitebft_metrics::Registry;
6
7use libp2p::core::SignedEnvelope;
8use libp2p::{identify, kad, request_response, swarm::ConnectionId, Multiaddr, PeerId, Swarm};
9
10mod behaviour;
11pub use behaviour::*;
12
13mod dial;
14use dial::DialData;
15
16pub mod config;
17pub use config::Config;
18
19mod controller;
20use controller::Controller;
21
22mod handlers;
23use handlers::selection::selector::Selector;
24
25mod metrics;
26use metrics::Metrics;
27
28mod rate_limiter;
29use rate_limiter::DiscoveryRateLimiter;
30
31mod request;
32
33pub mod util;
34
/// Phase the discovery process is currently in.
#[derive(Debug, PartialEq)]
enum State {
    /// Initial state when discovery is enabled with the Kademlia bootstrap protocol.
    Bootstrapping,
    /// Extending the peer set; the payload is the target number of peers.
    Extending(usize),
    /// No discovery work in progress (also the state used when discovery is
    /// disabled or no bootstrap nodes were provided).
    Idle,
}
41
/// Which side initiated a connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionDirection {
    /// We dialed the remote peer.
    Outbound,
    /// The remote peer dialed us.
    Inbound,
}

impl ConnectionDirection {
    /// Returns the lowercase textual name of the direction
    /// (`"outbound"` or `"inbound"`).
    pub fn as_str(&self) -> &'static str {
        match *self {
            ConnectionDirection::Inbound => "inbound",
            ConnectionDirection::Outbound => "outbound",
        }
    }
}
58
59/// Information about an established connection
60#[derive(Debug, Clone)]
61pub struct ConnectionInfo {
62    pub direction: ConnectionDirection,
63    pub remote_addr: Multiaddr,
64}
65
/// Status attached to each entry of the outbound peer map.
// NOTE(review): presumably `Pending` means a connect request is still awaiting
// an answer and `Confirmed` means the peer accepted it — confirm against the
// connect request/response handlers.
#[derive(Debug, PartialEq)]
enum OutboundState {
    Pending,
    Confirmed,
}
71
72#[derive(Debug)]
73pub struct Discovery<C>
74where
75    C: DiscoveryClient,
76{
77    config: Config,
78    state: State,
79
80    selector: Box<dyn Selector<C>>,
81
82    bootstrap_nodes: Vec<(Option<PeerId>, Vec<Multiaddr>)>,
83    discovered_peers: HashMap<PeerId, identify::Info>,
84    /// Signed peer records received from peers (cryptographically verified)
85    signed_peer_records: HashMap<PeerId, SignedEnvelope>,
86    active_connections: HashMap<PeerId, Vec<ConnectionId>>,
87    /// Track connection info (direction and remote address) per connection
88    pub connections: HashMap<ConnectionId, ConnectionInfo>,
89    outbound_peers: HashMap<PeerId, OutboundState>,
90    inbound_peers: HashSet<PeerId>,
91
92    /// Rate limiter for peers requests
93    rate_limiter: DiscoveryRateLimiter,
94
95    pub controller: Controller,
96    metrics: Metrics,
97}
98
99impl<C> Discovery<C>
100where
101    C: DiscoveryClient,
102{
103    pub fn new(config: Config, bootstrap_nodes: Vec<Multiaddr>, registry: &mut Registry) -> Self {
104        info!(
105            "Discovery is {}",
106            if config.enabled {
107                "enabled"
108            } else {
109                "disabled"
110            }
111        );
112
113        // Warn if discovery is enabled with persistent_peers_only
114        if config.enabled && config.persistent_peers_only {
115            warn!(
116                "Discovery is enabled with persistent_peers_only mode. \
117                 Discovered peers will be rejected unless they are in the persistent_peers list. \
118                 Consider disabling discovery for a pure persistent-peers-only setup."
119            );
120        }
121
122        let state = if config.enabled && bootstrap_nodes.is_empty() {
123            warn!("No bootstrap nodes provided");
124            info!("Discovery found 0 peers in 0ms");
125            State::Idle
126        } else if config.enabled {
127            match config.bootstrap_protocol {
128                config::BootstrapProtocol::Kademlia => {
129                    debug!("Using Kademlia bootstrap");
130
131                    State::Bootstrapping
132                }
133
134                config::BootstrapProtocol::Full => {
135                    debug!("Using full bootstrap");
136
137                    State::Extending(config.num_outbound_peers)
138                }
139            }
140        } else {
141            State::Idle
142        };
143
144        Self {
145            config,
146            state,
147
148            selector: Discovery::get_selector(
149                config.enabled,
150                config.bootstrap_protocol,
151                config.selector,
152            ),
153
154            bootstrap_nodes: bootstrap_nodes
155                .clone()
156                .into_iter()
157                .map(|addr| (None, vec![addr]))
158                .collect(),
159            discovered_peers: HashMap::new(),
160            signed_peer_records: HashMap::new(),
161            active_connections: HashMap::new(),
162            connections: HashMap::new(),
163            outbound_peers: HashMap::new(),
164            inbound_peers: HashSet::new(),
165
166            rate_limiter: DiscoveryRateLimiter::default(),
167
168            controller: Controller::new(),
169            metrics: Metrics::new(registry, !config.enabled || bootstrap_nodes.is_empty()),
170        }
171    }
172
173    pub fn is_enabled(&self) -> bool {
174        self.config.enabled
175    }
176
177    /// Check if a peer connection is outbound
178    pub fn is_outbound_peer(&self, peer_id: &PeerId) -> bool {
179        self.outbound_peers.contains_key(peer_id)
180    }
181
182    /// Check if a peer connection is inbound
183    pub fn is_inbound_peer(&self, peer_id: &PeerId) -> bool {
184        self.inbound_peers.contains(peer_id)
185    }
186
187    /// Check if a peer is a persistent peer (in the bootstrap_nodes list)
188    pub fn is_persistent_peer(&self, peer_id: &PeerId) -> bool {
189        // XXX: The assumption here is bootstrap_nodes is a list of persistent peers.
190        self.bootstrap_nodes
191            .iter()
192            .any(|(maybe_peer_id, _)| maybe_peer_id == &Some(*peer_id))
193    }
194
195    pub fn on_network_event(
196        &mut self,
197        swarm: &mut Swarm<C>,
198        network_event: behaviour::NetworkEvent,
199    ) {
200        match network_event {
201            behaviour::NetworkEvent::Kademlia(kad::Event::OutboundQueryProgressed {
202                result,
203                step,
204                ..
205            }) => match result {
206                kad::QueryResult::Bootstrap(Ok(_)) => {
207                    if step.last && self.state == State::Bootstrapping {
208                        debug!("Discovery bootstrap successful");
209
210                        self.handle_successful_bootstrap(swarm);
211                    }
212                }
213
214                kad::QueryResult::Bootstrap(Err(error)) => {
215                    error!("Discovery bootstrap failed: {error}");
216
217                    if self.state == State::Bootstrapping {
218                        self.handle_failed_bootstrap();
219                    }
220                }
221
222                _ => {}
223            },
224
225            behaviour::NetworkEvent::Kademlia(_) => {}
226
227            behaviour::NetworkEvent::RequestResponse(event) => {
228                match event {
229                    request_response::Event::Message {
230                        peer,
231                        connection_id,
232                        message:
233                            request_response::Message::Request {
234                                request, channel, ..
235                            },
236                    } => match request {
237                        behaviour::Request::Peers(signed_records) => {
238                            debug!(
239                                peer_id = %peer, %connection_id,
240                                count = signed_records.len(),
241                                "Received peers request"
242                            );
243
244                            self.handle_peers_request(swarm, peer, channel, signed_records);
245                        }
246
247                        behaviour::Request::Connect() => {
248                            debug!(peer_id = %peer, %connection_id, "Received connect request");
249
250                            self.handle_connect_request(swarm, channel, peer);
251                        }
252                    },
253
254                    request_response::Event::Message {
255                        peer,
256                        connection_id,
257                        message:
258                            request_response::Message::Response {
259                                response,
260                                request_id,
261                                ..
262                            },
263                    } => match response {
264                        behaviour::Response::Peers(signed_records) => {
265                            debug!(
266                                %peer, %connection_id,
267                                count = signed_records.len(),
268                                "Received peers response"
269                            );
270
271                            self.handle_peers_response(swarm, request_id, signed_records);
272                        }
273
274                        behaviour::Response::Connect(accepted) => {
275                            debug!(%peer, %connection_id, accepted, "Received connect response");
276
277                            self.handle_connect_response(swarm, request_id, peer, accepted);
278                        }
279                    },
280
281                    request_response::Event::OutboundFailure {
282                        peer,
283                        request_id,
284                        connection_id,
285                        error,
286                    } => {
287                        error!(%peer, %connection_id, "Outbound request to failed: {error}");
288
289                        if self.controller.peers_request.is_in_progress(&request_id) {
290                            self.handle_failed_peers_request(swarm, request_id);
291                        } else if self.controller.connect_request.is_in_progress(&request_id) {
292                            self.handle_failed_connect_request(swarm, request_id);
293                        } else {
294                            // This should not happen
295                            error!(%peer, %connection_id, "Unknown outbound request failure");
296                        }
297                    }
298
299                    _ => {}
300                }
301            }
302        }
303    }
304
305    /// Add a bootstrap node for persistent peer management
306    pub fn add_bootstrap_node(&mut self, addr: Multiaddr) {
307        // Check if this address already exists in bootstrap nodes
308        if self
309            .bootstrap_nodes
310            .iter()
311            .any(|(_, addrs)| addrs.contains(&addr))
312        {
313            info!("Bootstrap node already exists: {addr}");
314            return;
315        }
316
317        // Extract peer_id from multiaddr if present
318        let peer_id = addr.iter().find_map(|protocol| {
319            if let libp2p::multiaddr::Protocol::P2p(peer_id) = protocol {
320                Some(peer_id)
321            } else {
322                None
323            }
324        });
325
326        // Add to bootstrap_nodes list
327        self.bootstrap_nodes.push((peer_id, vec![addr]));
328
329        info!(
330            "Added bootstrap node, total: {}",
331            self.bootstrap_nodes.len()
332        );
333    }
334
335    /// Remove a bootstrap node for persistent peer management
336    pub fn remove_bootstrap_node(&mut self, addr: &Multiaddr) -> bool {
337        // Find matching bootstrap node by comparing addresses
338        let pos = self
339            .bootstrap_nodes
340            .iter()
341            .position(|(_, addrs)| addrs.iter().any(|a| a == addr));
342
343        if let Some(index) = pos {
344            self.bootstrap_nodes.remove(index);
345            info!(
346                "Removed bootstrap node, remaining: {}",
347                self.bootstrap_nodes.len()
348            );
349            true
350        } else {
351            warn!("Bootstrap node not found for removal: {}", addr);
352            false
353        }
354    }
355
356    /// Get the peer_id associated with a bootstrap node address.
357    ///
358    /// This is useful when the peer_id is discovered when we successfully connect, via the TLS/noise handshake
359    pub fn get_peer_id_for_addr(&self, addr: &Multiaddr) -> Option<PeerId> {
360        self.bootstrap_nodes
361            .iter()
362            .find(|(_, addrs)| addrs.iter().any(|a| a == addr))
363            .and_then(|(peer_id, _)| *peer_id)
364    }
365
366    /// Cancel any in-progress dial attempts for a given address and/or peer_id
367    ///
368    /// This is useful when removing a persistent peer to ensure we don't continue
369    /// trying to dial them after they've been removed.
370    pub fn cancel_dial_attempts(&mut self, addr: &Multiaddr, peer_id: Option<PeerId>) {
371        use controller::PeerData;
372
373        // Cancel dial attempts for the address
374        let addr_without_p2p = util::strip_peer_id_from_multiaddr(addr);
375        self.controller
376            .dial
377            .remove_done_on(&PeerData::Multiaddr(addr_without_p2p));
378
379        // Cancel dial attempts for the peer_id if present
380        if let Some(peer_id) = peer_id {
381            self.controller
382                .dial
383                .remove_done_on(&PeerData::PeerId(peer_id));
384        }
385    }
386}