libp2p_nat/
lib.rs

1mod task;
2mod utils;
3
4use core::task::{Context, Poll};
5use futures::{FutureExt, StreamExt};
6use futures_timer::Delay;
7use libp2p::core::transport::ListenerId;
8use libp2p::core::{Endpoint, Multiaddr};
9use libp2p::swarm::{
10    self, dummy::ConnectionHandler as DummyConnectionHandler, NetworkBehaviour,
11};
12use libp2p::swarm::{
13    ConnectionDenied, ConnectionId, ExpiredListenAddr, NewListenAddr, THandler, THandlerInEvent,
14    ToSwarm,
15};
16use libp2p::PeerId;
17use std::collections::hash_map::Entry;
18use std::task::Waker;
19use std::time::Duration;
20use task::{ForwardingError, NatCommands, NatResult, NatType};
21
22use std::collections::{HashMap, VecDeque};
23
24#[cfg(not(any(feature = "tokio", feature = "async-std")))]
25compile_error!("Require tokio or async-std feature to be enabled");
26
27#[derive(Debug)]
28struct LocalListener {
29    pub addrs: HashMap<Multiaddr, Option<NatType>>,
30    pub external_addrs: Vec<Multiaddr>,
31    pub renewal: Option<Delay>,
32}
33
34#[allow(clippy::type_complexity)]
35pub struct Behaviour {
36    events: VecDeque<swarm::ToSwarm<<Self as NetworkBehaviour>::ToSwarm, THandlerInEvent<Self>>>,
37    nat_sender: futures::channel::mpsc::UnboundedSender<NatCommands>,
38    event_receiver: futures::channel::mpsc::Receiver<Result<NatResult, ForwardingError>>,
39    local_listeners: HashMap<ListenerId, LocalListener>,
40    disabled: bool,
41    waker: Option<Waker>,
42}
43
44impl Default for Behaviour {
45    fn default() -> Self {
46        Self::with_duration(Duration::from_secs(2 * 60))
47    }
48}
49
50impl Behaviour {
51    pub fn with_duration(duration: Duration) -> Self {
52        assert!(duration.as_secs() > 10);
53        let renewal = duration / 2;
54
55        let (nat_sender, result_rx) = task::port_forwarding_task(duration, renewal);
56        Self {
57            events: Default::default(),
58            nat_sender,
59            event_receiver: result_rx,
60            local_listeners: Default::default(),
61            disabled: false,
62            waker: None,
63        }
64    }
65
66    /// Enables port forwarding
67    pub fn enable(&mut self) {
68        self.disabled = false;
69        for local_listener in self.local_listeners.values_mut() {
70            local_listener.renewal = Some(Delay::new(Duration::from_secs(10)));
71        }
72        if let Some(waker) = self.waker.take() {
73            waker.wake();
74        }
75    }
76
77    /// Disable port forwarding
78    /// Note: This does not remove the current lease but instead will not allow them to be renewed
79    pub fn disable(&mut self) {
80        if self.disabled {
81            return;
82        }
83
84        self.disabled = true;
85
86        // No need to continue if there are no external addresses
87        if self.external_addr().is_empty() {
88            return;
89        }
90
91        for (id, listener) in &mut self.local_listeners {
92            for (addr, nat_type) in &listener.addrs {
93                let Some(nat_type) = nat_type else {
94                    continue;
95                };
96
97                let _ = self
98                    .nat_sender
99                    .clone()
100                    .unbounded_send(NatCommands::DisableForwardPort(
101                        *id,
102                        addr.clone(),
103                        *nat_type,
104                    ));
105            }
106
107            // Notify swarm about the external addresses expiring
108            // Regardless of if we successfully disable port forwarding in the background task
109            for addr in listener.external_addrs.drain(..) {
110                self.events.push_back(ToSwarm::ExternalAddrExpired(addr));
111            }
112
113            listener.renewal = None;
114        }
115
116        if let Some(waker) = self.waker.take() {
117            waker.wake();
118        }
119    }
120
121    /// Gets external addresses
122    pub fn external_addr(&self) -> Vec<Multiaddr> {
123        self.local_listeners
124            .values()
125            .flat_map(|local| local.addrs.keys().cloned().collect::<Vec<_>>())
126            .collect::<Vec<_>>()
127    }
128}
129
130impl NetworkBehaviour for Behaviour {
131    type ConnectionHandler = DummyConnectionHandler;
132    type ToSwarm = void::Void;
133
134    fn handle_established_inbound_connection(
135        &mut self,
136        _: ConnectionId,
137        _: PeerId,
138        _: &Multiaddr,
139        _: &Multiaddr,
140    ) -> Result<THandler<Self>, ConnectionDenied> {
141        Ok(DummyConnectionHandler)
142    }
143
144    fn handle_established_outbound_connection(
145        &mut self,
146        _: ConnectionId,
147        _: PeerId,
148        _: &Multiaddr,
149        _: Endpoint,
150    ) -> Result<THandler<Self>, ConnectionDenied> {
151        Ok(DummyConnectionHandler)
152    }
153
154    fn on_connection_handler_event(
155        &mut self,
156        _: libp2p::PeerId,
157        _: swarm::ConnectionId,
158        _: swarm::THandlerOutEvent<Self>,
159    ) {
160    }
161
162    fn on_swarm_event(&mut self, event: swarm::FromSwarm) {
163        match event {
164            swarm::FromSwarm::NewListenAddr(NewListenAddr { listener_id, addr }) => {
165                // Used to make sure we only obtain private ips
166                if utils::multiaddr_to_socket_port(addr).is_none() {
167                    return;
168                }
169
170                match self.local_listeners.entry(listener_id) {
171                    Entry::Occupied(mut entry) => {
172                        let listener = entry.get_mut();
173                        if !listener.addrs.contains_key(addr) {
174                            listener.addrs.insert(addr.clone(), None);
175                        }
176                    }
177                    Entry::Vacant(entry) => {
178                        entry.insert(LocalListener {
179                            addrs: HashMap::from_iter([(addr.clone(), None)]),
180                            external_addrs: vec![],
181                            renewal: None,
182                        });
183                    }
184                };
185
186                let _ = self
187                    .nat_sender
188                    .clone()
189                    .unbounded_send(NatCommands::ForwardPort(listener_id, addr.clone()));
190            }
191            swarm::FromSwarm::NewExternalAddrCandidate(_) => {}
192            swarm::FromSwarm::ExternalAddrExpired(_) => {}
193            swarm::FromSwarm::ExpiredListenAddr(ExpiredListenAddr { listener_id, addr }) => {
194                if let Entry::Occupied(mut entry) = self.local_listeners.entry(listener_id) {
195                    let listener = entry.get_mut();
196
197                    let list = &mut listener.addrs;
198
199                    if !list.contains_key(addr) {
200                        return;
201                    }
202
203                    let nat_type = list.remove(addr).flatten();
204
205                    if let Some(nat_type) = nat_type {
206                        let _ = self.nat_sender.clone().unbounded_send(
207                            NatCommands::DisableForwardPort(listener_id, addr.clone(), nat_type),
208                        );
209                    }
210
211                    if list.is_empty() {
212                        entry.remove();
213                    }
214                }
215            }
216            _ => {}
217        }
218    }
219
220    fn poll(
221        &mut self,
222        cx: &mut Context,
223    ) -> Poll<swarm::ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
224        if let Some(event) = self.events.pop_front() {
225            return Poll::Ready(event);
226        }
227
228        for (id, local_listener) in &mut self.local_listeners {
229            if let Some(renewal) = local_listener.renewal.as_mut() {
230                if let Poll::Ready(()) = renewal.poll_unpin(cx) {
231                    for addr in local_listener.addrs.keys() {
232                        let _ = self
233                            .nat_sender
234                            .clone()
235                            .unbounded_send(NatCommands::ForwardPort(*id, addr.clone()));
236                    }
237                    renewal.reset(Duration::from_secs(30));
238                }
239            }
240        }
241
242        loop {
243            match self.event_receiver.poll_next_unpin(cx) {
244                Poll::Ready(Some(result)) => match result {
245                    Ok(NatResult::PortForwardingEnabled {
246                        listener_id,
247                        local_addr,
248                        addr,
249                        nat_type,
250                        timer,
251                    }) => {
252                        if let Entry::Occupied(mut entry) = self.local_listeners.entry(listener_id)
253                        {
254                            let listener = entry.get_mut();
255
256                            listener
257                                .addrs
258                                .entry(local_addr)
259                                .and_modify(|nty| *nty = Some(nat_type));
260
261                            if !listener.external_addrs.contains(&addr) {
262                                log::info!("Discovered {addr} as an external address.");
263                                listener.external_addrs.push(addr.clone());
264                                self.events
265                                    .push_back(ToSwarm::ExternalAddrConfirmed(addr.clone()));
266                            }
267                            listener.renewal = Some(timer);
268                        }
269                    }
270                    Ok(NatResult::PortForwardingDisabled { listener_id }) => {
271                        if let Entry::Occupied(mut entry) = self.local_listeners.entry(listener_id)
272                        {
273                            let listener = entry.get_mut();
274
275                            listener.addrs.values_mut().for_each(|nty| *nty = None);
276
277                            if !listener.external_addrs.is_empty() {
278                                for addr in listener.external_addrs.drain(..) {
279                                    self.events
280                                        .push_back(ToSwarm::ExternalAddrExpired(addr.clone()));
281                                }
282                            }
283                        }
284                    }
285                    Err(ForwardingError::InvalidAddress {
286                        listener_id,
287                        address,
288                    }) => {
289                        if let Entry::Occupied(mut entry) = self.local_listeners.entry(listener_id)
290                        {
291                            let listener = entry.get_mut();
292                            log::debug!("Removing {address} from local listeners");
293                            listener
294                                .addrs
295                                .retain(|local_addr, _| local_addr != &address);
296                            listener.renewal = Some(Delay::new(Duration::from_secs(30)));
297                        }
298                    }
299                    Err(ForwardingError::PortForwardingFailed { listener_id }) => {
300                        if let Entry::Occupied(mut entry) = self.local_listeners.entry(listener_id)
301                        {
302                            log::error!("Failed performing port forwarding");
303                            let listener = entry.get_mut();
304                            if !listener.external_addrs.is_empty() {
305                                for addr in listener.external_addrs.drain(..) {
306                                    self.events.push_back(ToSwarm::ExternalAddrExpired(addr));
307                                }
308                            }
309                            listener.renewal = Some(Delay::new(Duration::from_secs(30)));
310                        }
311                    }
312                    Err(ForwardingError::Any {
313                        listener_id,
314                        error: e,
315                    }) => {
316                        log::error!("Error: {e}");
317                        if let Entry::Occupied(mut entry) = self.local_listeners.entry(listener_id)
318                        {
319                            let listener = entry.get_mut();
320                            listener.renewal = Some(Delay::new(Duration::from_secs(30)));
321                        }
322                    }
323                },
324                Poll::Ready(None) => return Poll::Pending,
325                Poll::Pending => break,
326            }
327        }
328
329        self.waker = Some(cx.waker().clone());
330
331        Poll::Pending
332    }
333}