1mod task;
2mod utils;
3
4use core::task::{Context, Poll};
5use futures::{FutureExt, StreamExt};
6use futures_timer::Delay;
7use libp2p::core::transport::ListenerId;
8use libp2p::core::{Endpoint, Multiaddr};
9use libp2p::swarm::{
10 self, dummy::ConnectionHandler as DummyConnectionHandler, NetworkBehaviour,
11};
12use libp2p::swarm::{
13 ConnectionDenied, ConnectionId, ExpiredListenAddr, NewListenAddr, THandler, THandlerInEvent,
14 ToSwarm,
15};
16use libp2p::PeerId;
17use std::collections::hash_map::Entry;
18use std::task::Waker;
19use std::time::Duration;
20use task::{ForwardingError, NatCommands, NatResult, NatType};
21
22use std::collections::{HashMap, VecDeque};
23
// Building with neither the `tokio` nor the `async-std` feature enabled is a
// hard compile error; one of the two must be selected.
#[cfg(not(any(feature = "tokio", feature = "async-std")))]
compile_error!("Require tokio or async-std feature to be enabled");
26
/// Per-listener bookkeeping kept by [`Behaviour`].
#[derive(Debug)]
struct LocalListener {
    /// Local listen addresses of this listener. The value is the NAT type
    /// recorded once forwarding for that address was reported as enabled, and
    /// `None` while no mapping is recorded (initially, and after the task
    /// reports forwarding as disabled).
    pub addrs: HashMap<Multiaddr, Option<NatType>>,
    /// External addresses that were reported to the swarm as confirmed for
    /// this listener and must be expired again when forwarding stops.
    pub external_addrs: Vec<Multiaddr>,
    /// Timer that, when it fires, makes `poll` request forwarding again for
    /// every address in `addrs`; `None` means no renewal is scheduled.
    pub renewal: Option<Delay>,
}
33
/// `NetworkBehaviour` that asks a port-forwarding task (created in
/// `with_duration`) to map the ports of local listen addresses, and reports
/// the resulting external addresses to the swarm.
// The fully spelled-out type of `events` trips clippy's `type_complexity` lint.
#[allow(clippy::type_complexity)]
pub struct Behaviour {
    /// Swarm events queued for delivery from `poll`.
    events: VecDeque<swarm::ToSwarm<<Self as NetworkBehaviour>::ToSwarm, THandlerInEvent<Self>>>,
    /// Command channel into the port-forwarding task.
    nat_sender: futures::channel::mpsc::UnboundedSender<NatCommands>,
    /// Results (or errors) sent back by the port-forwarding task.
    event_receiver: futures::channel::mpsc::Receiver<Result<NatResult, ForwardingError>>,
    /// Tracked listeners, keyed by the swarm's listener id.
    local_listeners: HashMap<ListenerId, LocalListener>,
    /// Set by `disable`, cleared by `enable`.
    disabled: bool,
    /// Waker stored by `poll` before it returns `Pending`; `enable` and
    /// `disable` wake it.
    waker: Option<Waker>,
}
43
44impl Default for Behaviour {
45 fn default() -> Self {
46 Self::with_duration(Duration::from_secs(2 * 60))
47 }
48}
49
50impl Behaviour {
51 pub fn with_duration(duration: Duration) -> Self {
52 assert!(duration.as_secs() > 10);
53 let renewal = duration / 2;
54
55 let (nat_sender, result_rx) = task::port_forwarding_task(duration, renewal);
56 Self {
57 events: Default::default(),
58 nat_sender,
59 event_receiver: result_rx,
60 local_listeners: Default::default(),
61 disabled: false,
62 waker: None,
63 }
64 }
65
66 pub fn enable(&mut self) {
68 self.disabled = false;
69 for local_listener in self.local_listeners.values_mut() {
70 local_listener.renewal = Some(Delay::new(Duration::from_secs(10)));
71 }
72 if let Some(waker) = self.waker.take() {
73 waker.wake();
74 }
75 }
76
77 pub fn disable(&mut self) {
80 if self.disabled {
81 return;
82 }
83
84 self.disabled = true;
85
86 if self.external_addr().is_empty() {
88 return;
89 }
90
91 for (id, listener) in &mut self.local_listeners {
92 for (addr, nat_type) in &listener.addrs {
93 let Some(nat_type) = nat_type else {
94 continue;
95 };
96
97 let _ = self
98 .nat_sender
99 .clone()
100 .unbounded_send(NatCommands::DisableForwardPort(
101 *id,
102 addr.clone(),
103 *nat_type,
104 ));
105 }
106
107 for addr in listener.external_addrs.drain(..) {
110 self.events.push_back(ToSwarm::ExternalAddrExpired(addr));
111 }
112
113 listener.renewal = None;
114 }
115
116 if let Some(waker) = self.waker.take() {
117 waker.wake();
118 }
119 }
120
121 pub fn external_addr(&self) -> Vec<Multiaddr> {
123 self.local_listeners
124 .values()
125 .flat_map(|local| local.addrs.keys().cloned().collect::<Vec<_>>())
126 .collect::<Vec<_>>()
127 }
128}
129
130impl NetworkBehaviour for Behaviour {
131 type ConnectionHandler = DummyConnectionHandler;
132 type ToSwarm = void::Void;
133
134 fn handle_established_inbound_connection(
135 &mut self,
136 _: ConnectionId,
137 _: PeerId,
138 _: &Multiaddr,
139 _: &Multiaddr,
140 ) -> Result<THandler<Self>, ConnectionDenied> {
141 Ok(DummyConnectionHandler)
142 }
143
144 fn handle_established_outbound_connection(
145 &mut self,
146 _: ConnectionId,
147 _: PeerId,
148 _: &Multiaddr,
149 _: Endpoint,
150 ) -> Result<THandler<Self>, ConnectionDenied> {
151 Ok(DummyConnectionHandler)
152 }
153
154 fn on_connection_handler_event(
155 &mut self,
156 _: libp2p::PeerId,
157 _: swarm::ConnectionId,
158 _: swarm::THandlerOutEvent<Self>,
159 ) {
160 }
161
162 fn on_swarm_event(&mut self, event: swarm::FromSwarm) {
163 match event {
164 swarm::FromSwarm::NewListenAddr(NewListenAddr { listener_id, addr }) => {
165 if utils::multiaddr_to_socket_port(addr).is_none() {
167 return;
168 }
169
170 match self.local_listeners.entry(listener_id) {
171 Entry::Occupied(mut entry) => {
172 let listener = entry.get_mut();
173 if !listener.addrs.contains_key(addr) {
174 listener.addrs.insert(addr.clone(), None);
175 }
176 }
177 Entry::Vacant(entry) => {
178 entry.insert(LocalListener {
179 addrs: HashMap::from_iter([(addr.clone(), None)]),
180 external_addrs: vec![],
181 renewal: None,
182 });
183 }
184 };
185
186 let _ = self
187 .nat_sender
188 .clone()
189 .unbounded_send(NatCommands::ForwardPort(listener_id, addr.clone()));
190 }
191 swarm::FromSwarm::NewExternalAddrCandidate(_) => {}
192 swarm::FromSwarm::ExternalAddrExpired(_) => {}
193 swarm::FromSwarm::ExpiredListenAddr(ExpiredListenAddr { listener_id, addr }) => {
194 if let Entry::Occupied(mut entry) = self.local_listeners.entry(listener_id) {
195 let listener = entry.get_mut();
196
197 let list = &mut listener.addrs;
198
199 if !list.contains_key(addr) {
200 return;
201 }
202
203 let nat_type = list.remove(addr).flatten();
204
205 if let Some(nat_type) = nat_type {
206 let _ = self.nat_sender.clone().unbounded_send(
207 NatCommands::DisableForwardPort(listener_id, addr.clone(), nat_type),
208 );
209 }
210
211 if list.is_empty() {
212 entry.remove();
213 }
214 }
215 }
216 _ => {}
217 }
218 }
219
220 fn poll(
221 &mut self,
222 cx: &mut Context,
223 ) -> Poll<swarm::ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
224 if let Some(event) = self.events.pop_front() {
225 return Poll::Ready(event);
226 }
227
228 for (id, local_listener) in &mut self.local_listeners {
229 if let Some(renewal) = local_listener.renewal.as_mut() {
230 if let Poll::Ready(()) = renewal.poll_unpin(cx) {
231 for addr in local_listener.addrs.keys() {
232 let _ = self
233 .nat_sender
234 .clone()
235 .unbounded_send(NatCommands::ForwardPort(*id, addr.clone()));
236 }
237 renewal.reset(Duration::from_secs(30));
238 }
239 }
240 }
241
242 loop {
243 match self.event_receiver.poll_next_unpin(cx) {
244 Poll::Ready(Some(result)) => match result {
245 Ok(NatResult::PortForwardingEnabled {
246 listener_id,
247 local_addr,
248 addr,
249 nat_type,
250 timer,
251 }) => {
252 if let Entry::Occupied(mut entry) = self.local_listeners.entry(listener_id)
253 {
254 let listener = entry.get_mut();
255
256 listener
257 .addrs
258 .entry(local_addr)
259 .and_modify(|nty| *nty = Some(nat_type));
260
261 if !listener.external_addrs.contains(&addr) {
262 log::info!("Discovered {addr} as an external address.");
263 listener.external_addrs.push(addr.clone());
264 self.events
265 .push_back(ToSwarm::ExternalAddrConfirmed(addr.clone()));
266 }
267 listener.renewal = Some(timer);
268 }
269 }
270 Ok(NatResult::PortForwardingDisabled { listener_id }) => {
271 if let Entry::Occupied(mut entry) = self.local_listeners.entry(listener_id)
272 {
273 let listener = entry.get_mut();
274
275 listener.addrs.values_mut().for_each(|nty| *nty = None);
276
277 if !listener.external_addrs.is_empty() {
278 for addr in listener.external_addrs.drain(..) {
279 self.events
280 .push_back(ToSwarm::ExternalAddrExpired(addr.clone()));
281 }
282 }
283 }
284 }
285 Err(ForwardingError::InvalidAddress {
286 listener_id,
287 address,
288 }) => {
289 if let Entry::Occupied(mut entry) = self.local_listeners.entry(listener_id)
290 {
291 let listener = entry.get_mut();
292 log::debug!("Removing {address} from local listeners");
293 listener
294 .addrs
295 .retain(|local_addr, _| local_addr != &address);
296 listener.renewal = Some(Delay::new(Duration::from_secs(30)));
297 }
298 }
299 Err(ForwardingError::PortForwardingFailed { listener_id }) => {
300 if let Entry::Occupied(mut entry) = self.local_listeners.entry(listener_id)
301 {
302 log::error!("Failed performing port forwarding");
303 let listener = entry.get_mut();
304 if !listener.external_addrs.is_empty() {
305 for addr in listener.external_addrs.drain(..) {
306 self.events.push_back(ToSwarm::ExternalAddrExpired(addr));
307 }
308 }
309 listener.renewal = Some(Delay::new(Duration::from_secs(30)));
310 }
311 }
312 Err(ForwardingError::Any {
313 listener_id,
314 error: e,
315 }) => {
316 log::error!("Error: {e}");
317 if let Entry::Occupied(mut entry) = self.local_listeners.entry(listener_id)
318 {
319 let listener = entry.get_mut();
320 listener.renewal = Some(Delay::new(Duration::from_secs(30)));
321 }
322 }
323 },
324 Poll::Ready(None) => return Poll::Pending,
325 Poll::Pending => break,
326 }
327 }
328
329 self.waker = Some(cx.waker().clone());
330
331 Poll::Pending
332 }
333}