1mod iface;
22mod socket;
23mod timer;
24
25use std::{
26 cmp,
27 collections::{
28 hash_map::{Entry, HashMap},
29 VecDeque,
30 },
31 convert::Infallible,
32 fmt,
33 future::Future,
34 io,
35 net::IpAddr,
36 pin::Pin,
37 sync::{Arc, RwLock},
38 task::{Context, Poll},
39 time::Instant,
40};
41
42use futures::{channel::mpsc, Stream, StreamExt};
43use if_watch::IfEvent;
44use libp2p_core::{transport::PortUse, Endpoint, Multiaddr};
45use libp2p_identity::PeerId;
46use libp2p_swarm::{
47 behaviour::FromSwarm, dummy, ConnectionDenied, ConnectionId, ListenAddresses, NetworkBehaviour,
48 THandler, THandlerInEvent, THandlerOutEvent, ToSwarm,
49};
50use smallvec::SmallVec;
51
52use self::iface::InterfaceState;
53use crate::{
54 behaviour::{socket::AsyncSocket, timer::Builder},
55 Config,
56};
57
58pub trait Provider: 'static {
60 type Socket: AsyncSocket;
62 type Timer: Builder + Stream;
64 type Watcher: Stream<Item = std::io::Result<IfEvent>> + fmt::Debug + Unpin;
66
67 type TaskHandle: Abort;
68
69 fn new_watcher() -> Result<Self::Watcher, std::io::Error>;
71
72 #[track_caller]
73 fn spawn(task: impl Future<Output = ()> + Send + 'static) -> Self::TaskHandle;
74}
75
/// Handle to a spawned background task whose execution can be cancelled.
// `pub` is required because this trait bounds the public `Provider::TaskHandle` associated
// type. `unreachable_pub` is silenced because the trait itself is not reachable from the
// crate's public API, so the lint would otherwise fire on it.
#[allow(unreachable_pub)]
pub trait Abort {
    /// Requests cancellation of the task behind this handle, consuming the handle.
    fn abort(self);
}
80
/// Tokio-backed `Provider` together with the matching `Behaviour` type alias.
#[cfg(feature = "tokio")]
pub mod tokio {
    use std::future::Future;

    use if_watch::tokio::IfWatcher;
    use tokio::task::JoinHandle;

    use super::Provider;
    use crate::behaviour::{socket::tokio::TokioUdpSocket, timer::tokio::TokioTimer, Abort};

    /// Type-level marker selecting the tokio runtime; it has no values.
    #[doc(hidden)]
    pub enum Tokio {}

    impl Provider for Tokio {
        type Socket = TokioUdpSocket;
        type Timer = TokioTimer;
        type Watcher = IfWatcher;
        type TaskHandle = JoinHandle<()>;

        fn new_watcher() -> std::io::Result<Self::Watcher> {
            IfWatcher::new()
        }

        fn spawn(task: impl Future<Output = ()> + Send + 'static) -> Self::TaskHandle {
            // `tokio::spawn` is a re-export of this function.
            tokio::task::spawn(task)
        }
    }

    impl Abort for JoinHandle<()> {
        fn abort(self) {
            // Call the inherent `JoinHandle::abort(&self)` by path: a plain `self.abort()`
            // would resolve to this by-value trait method and recurse forever.
            JoinHandle::abort(&self)
        }
    }

    /// The mDNS behaviour running on the tokio runtime.
    pub type Behaviour = super::Behaviour<Tokio>;
}
118
119#[derive(Debug)]
122pub struct Behaviour<P>
123where
124 P: Provider,
125{
126 config: Config,
128
129 if_watch: P::Watcher,
131
132 if_tasks: HashMap<IpAddr, P::TaskHandle>,
134
135 query_response_receiver: mpsc::Receiver<(PeerId, Multiaddr, Instant)>,
136 query_response_sender: mpsc::Sender<(PeerId, Multiaddr, Instant)>,
137
138 discovered_nodes: SmallVec<[(PeerId, Multiaddr, Instant); 8]>,
143
144 closest_expiration: Option<P::Timer>,
148
149 listen_addresses: Arc<RwLock<ListenAddresses>>,
155
156 local_peer_id: PeerId,
157
158 pending_events: VecDeque<ToSwarm<Event, Infallible>>,
160}
161
162impl<P> Behaviour<P>
163where
164 P: Provider,
165{
166 pub fn new(config: Config, local_peer_id: PeerId) -> io::Result<Self> {
168 let (tx, rx) = mpsc::channel(10); Ok(Self {
171 config,
172 if_watch: P::new_watcher()?,
173 if_tasks: Default::default(),
174 query_response_receiver: rx,
175 query_response_sender: tx,
176 discovered_nodes: Default::default(),
177 closest_expiration: Default::default(),
178 listen_addresses: Default::default(),
179 local_peer_id,
180 pending_events: Default::default(),
181 })
182 }
183
184 #[deprecated(note = "Use `discovered_nodes` iterator instead.")]
186 pub fn has_node(&self, peer_id: &PeerId) -> bool {
187 self.discovered_nodes().any(|p| p == peer_id)
188 }
189
190 pub fn discovered_nodes(&self) -> impl ExactSizeIterator<Item = &PeerId> {
192 self.discovered_nodes.iter().map(|(p, _, _)| p)
193 }
194
195 #[deprecated(note = "Unused API. Will be removed in the next release.")]
197 pub fn expire_node(&mut self, peer_id: &PeerId) {
198 let now = Instant::now();
199 for (peer, _addr, expires) in &mut self.discovered_nodes {
200 if peer == peer_id {
201 *expires = now;
202 }
203 }
204 self.closest_expiration = Some(P::Timer::at(now));
205 }
206}
207
208impl<P> NetworkBehaviour for Behaviour<P>
209where
210 P: Provider,
211{
212 type ConnectionHandler = dummy::ConnectionHandler;
213 type ToSwarm = Event;
214
215 fn handle_established_inbound_connection(
216 &mut self,
217 _: ConnectionId,
218 _: PeerId,
219 _: &Multiaddr,
220 _: &Multiaddr,
221 ) -> Result<THandler<Self>, ConnectionDenied> {
222 Ok(dummy::ConnectionHandler)
223 }
224
225 fn handle_pending_outbound_connection(
226 &mut self,
227 _connection_id: ConnectionId,
228 maybe_peer: Option<PeerId>,
229 _addresses: &[Multiaddr],
230 _effective_role: Endpoint,
231 ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
232 let Some(peer_id) = maybe_peer else {
233 return Ok(vec![]);
234 };
235
236 Ok(self
237 .discovered_nodes
238 .iter()
239 .filter(|(peer, _, _)| peer == &peer_id)
240 .map(|(_, addr, _)| addr.clone())
241 .collect())
242 }
243
244 fn handle_established_outbound_connection(
245 &mut self,
246 _: ConnectionId,
247 _: PeerId,
248 _: &Multiaddr,
249 _: Endpoint,
250 _: PortUse,
251 ) -> Result<THandler<Self>, ConnectionDenied> {
252 Ok(dummy::ConnectionHandler)
253 }
254
255 fn on_connection_handler_event(
256 &mut self,
257 _: PeerId,
258 _: ConnectionId,
259 ev: THandlerOutEvent<Self>,
260 ) {
261 libp2p_core::util::unreachable(ev)
262 }
263
264 fn on_swarm_event(&mut self, event: FromSwarm) {
265 self.listen_addresses
266 .write()
267 .unwrap_or_else(|e| e.into_inner())
268 .on_swarm_event(&event);
269 }
270
271 #[tracing::instrument(level = "trace", name = "NetworkBehaviour::poll", skip(self, cx))]
272 fn poll(
273 &mut self,
274 cx: &mut Context<'_>,
275 ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
276 loop {
277 if let Some(event) = self.pending_events.pop_front() {
279 return Poll::Ready(event);
280 }
281
282 while let Poll::Ready(Some(event)) = Pin::new(&mut self.if_watch).poll_next(cx) {
284 match event {
285 Ok(IfEvent::Up(inet)) => {
286 let addr = inet.addr();
287 if addr.is_loopback() {
288 continue;
289 }
290 if addr.is_ipv4() && self.config.enable_ipv6
291 || addr.is_ipv6() && !self.config.enable_ipv6
292 {
293 continue;
294 }
295 if let Entry::Vacant(e) = self.if_tasks.entry(addr) {
296 match InterfaceState::<P::Socket, P::Timer>::new(
297 addr,
298 self.config.clone(),
299 self.local_peer_id,
300 self.listen_addresses.clone(),
301 self.query_response_sender.clone(),
302 ) {
303 Ok(iface_state) => {
304 e.insert(P::spawn(iface_state));
305 }
306 Err(err) => {
307 tracing::error!("failed to create `InterfaceState`: {}", err)
308 }
309 }
310 }
311 }
312 Ok(IfEvent::Down(inet)) => {
313 if let Some(handle) = self.if_tasks.remove(&inet.addr()) {
314 tracing::info!(instance=%inet.addr(), "dropping instance");
315
316 handle.abort();
317 }
318 }
319 Err(err) => tracing::error!("if watch returned an error: {}", err),
320 }
321 }
322 let mut discovered = Vec::new();
324
325 while let Poll::Ready(Some((peer, addr, expiration))) =
326 self.query_response_receiver.poll_next_unpin(cx)
327 {
328 if let Some((_, _, cur_expires)) = self
329 .discovered_nodes
330 .iter_mut()
331 .find(|(p, a, _)| *p == peer && *a == addr)
332 {
333 *cur_expires = cmp::max(*cur_expires, expiration);
334 } else {
335 tracing::info!(%peer, address=%addr, "discovered peer on address");
336 self.discovered_nodes.push((peer, addr.clone(), expiration));
337 discovered.push((peer, addr.clone()));
338
339 self.pending_events
340 .push_back(ToSwarm::NewExternalAddrOfPeer {
341 peer_id: peer,
342 address: addr,
343 });
344 }
345 }
346
347 if !discovered.is_empty() {
348 let event = Event::Discovered(discovered);
349 self.pending_events
352 .push_front(ToSwarm::GenerateEvent(event));
353 continue;
354 }
355 let now = Instant::now();
357 let mut closest_expiration = None;
358 let mut expired = Vec::new();
359 self.discovered_nodes.retain(|(peer, addr, expiration)| {
360 if *expiration <= now {
361 tracing::info!(%peer, address=%addr, "expired peer on address");
362 expired.push((*peer, addr.clone()));
363 return false;
364 }
365 closest_expiration =
366 Some(closest_expiration.unwrap_or(*expiration).min(*expiration));
367 true
368 });
369 if !expired.is_empty() {
370 let event = Event::Expired(expired);
371 self.pending_events.push_back(ToSwarm::GenerateEvent(event));
372 continue;
373 }
374 if let Some(closest_expiration) = closest_expiration {
375 let mut timer = P::Timer::at(closest_expiration);
376 let _ = Pin::new(&mut timer).poll_next(cx);
377
378 self.closest_expiration = Some(timer);
379 }
380
381 return Poll::Pending;
382 }
383 }
384}
385
386#[derive(Debug, Clone)]
388pub enum Event {
389 Discovered(Vec<(PeerId, Multiaddr)>),
391
392 Expired(Vec<(PeerId, Multiaddr)>),
397}