memberlist_net/
lib.rs

1//! [`memberlist`](https://crates.io/crates/memberlist)'s [`Transport`] layer based on TCP and UDP.
2#![doc(html_logo_url = "https://raw.githubusercontent.com/al8n/memberlist/main/art/logo_72x72.png")]
3#![allow(clippy::type_complexity)]
4#![deny(missing_docs, warnings)]
5#![forbid(unsafe_code)]
6#![cfg_attr(docsrs, feature(doc_cfg))]
7#![cfg_attr(docsrs, allow(unused_attributes))]
8
9use std::{
10  net::SocketAddr,
11  sync::{
12    Arc,
13    atomic::{AtomicBool, AtomicUsize, Ordering},
14  },
15};
16
17use agnostic::{
18  AsyncSpawner, Runtime, RuntimeLite,
19  net::{Net, UdpSocket},
20};
21use atomic_refcell::AtomicRefCell;
22use futures::{StreamExt, stream::FuturesUnordered};
23use memberlist_core::proto::{Data, Payload, SmallVec};
24pub use memberlist_core::{
25  proto::{CIDRsPolicy, Label, LabelError},
26  transport::*,
27};
28
29mod options;
30pub use options::*;
31
32mod promised_processor;
33use promised_processor::*;
34mod packet_processor;
35use packet_processor::*;
36
37/// Errors for the net transport.
38pub mod error;
39use error::*;
40
/// Abstracts the [`StreamLayer`](crate::stream_layer::StreamLayer) used by [`NetTransport`].
42pub mod stream_layer;
43use stream_layer::*;
44
45/// Re-exports [`nodecraft`]'s address resolver.
46pub mod resolver {
47  #[cfg(feature = "dns")]
48  pub use nodecraft::resolver::dns;
49  pub use nodecraft::resolver::{address, socket_addr};
50}
51
52/// Exports unit tests.
53#[cfg(any(test, feature = "test"))]
54#[cfg_attr(docsrs, doc(cfg(feature = "test")))]
55pub mod tests;
56
57/// A large buffer size that we attempt to set UDP
58/// sockets to in order to handle a large volume of messages.
59const DEFAULT_UDP_RECV_BUF_SIZE: usize = 2 * 1024 * 1024;
60
/// [`NetTransport`] based on [`tokio`](https://crates.io/crates/tokio).
#[cfg(feature = "tokio")]
pub type TokioNetTransport<I, A, S> = NetTransport<I, A, S, agnostic::tokio::TokioRuntime>;

/// [`NetTransport`] based on [`smol`](https://crates.io/crates/smol).
#[cfg(feature = "smol")]
pub type SmolNetTransport<I, A, S> = NetTransport<I, A, S, agnostic::smol::SmolRuntime>;
68
/// The net transport based on TCP/TLS and UDP
///
/// Streams are accepted/opened through the [`StreamLayer`] `S`, and packets are
/// sent/received over plain UDP sockets (one per bind address).
pub struct NetTransport<I, A, S, R>
where
  I: Id,
  A: AddressResolver<ResolvedAddress = SocketAddr, Runtime = R>,
  S: StreamLayer<Runtime = R>,
  R: Runtime,
{
  // Transport options shared with the constructor.
  opts: Arc<Options<I, A>>,
  // Address advertised to other members (explicit option, or derived from the bind addresses).
  advertise_addr: A::ResolvedAddress,
  // The configured (unresolved) bind address chosen as this node's local address.
  local_addr: A::Address,
  // Subscriber side of the channel fed by the packet processors.
  packet_rx: PacketSubscriber<A::ResolvedAddress, R::Instant>,
  // Subscriber side of the channel fed by the promised (stream) processors.
  stream_rx: StreamSubscriber<A::ResolvedAddress, S::Stream>,
  // Number of IPv4 UDP sockets bound at construction time (not updated on shutdown).
  num_v4_sockets: usize,
  // Round-robin counter used to pick the next IPv4 socket in `next_socket`.
  v4_round_robin: AtomicUsize,
  // IPv4 UDP sockets; cleared by `shutdown`.
  v4_sockets: AtomicRefCell<SmallVec<Arc<<<A::Runtime as Runtime>::Net as Net>::UdpSocket>>>,
  // Number of IPv6 UDP sockets bound at construction time (not updated on shutdown).
  num_v6_sockets: usize,
  // Round-robin counter used to pick the next IPv6 socket in `next_socket`.
  v6_round_robin: AtomicUsize,
  // IPv6 UDP sockets; cleared by `shutdown`.
  v6_sockets: AtomicRefCell<SmallVec<Arc<<<A::Runtime as Runtime>::Net as Net>::UdpSocket>>>,
  // Stream layer used to bind listeners and open outgoing connections.
  stream_layer: Arc<S>,
  // Join handles of the spawned promised/packet processor tasks; drained by `shutdown`.
  handles: AtomicRefCell<FuturesUnordered<<R::Spawner as AsyncSpawner>::JoinHandle<()>>>,
  // Resolver used to turn `A::Address` into socket addresses.
  resolver: Arc<A>,
  // Closed (never sent on) to signal the processors to stop; see `shutdown` and `Drop`.
  shutdown_tx: async_channel::Sender<()>,
}
93
94impl<I, A, S, R> NetTransport<I, A, S, R>
95where
96  I: Id + Data + Send + Sync + 'static,
97  A: AddressResolver<ResolvedAddress = SocketAddr, Runtime = R>,
98  A::Address: Send + Sync + 'static,
99  A::ResolvedAddress: Data,
100  S: StreamLayer<Runtime = R>,
101  R: Runtime,
102{
103  fn find_advertise_addr_index(addrs: &[SocketAddr]) -> usize {
104    for (i, addr) in addrs.iter().enumerate() {
105      if !addr.ip().is_unspecified() {
106        return i;
107      }
108    }
109
110    0
111  }
112
113  fn next_socket(
114    &self,
115    addr: &A::ResolvedAddress,
116  ) -> Option<Arc<<<A::Runtime as Runtime>::Net as Net>::UdpSocket>> {
117    enum Kind {
118      V4(usize),
119      V6(usize),
120    }
121
122    let kind = if addr.is_ipv4() {
123      // if there's no v4 sockets, we assume remote addr can accept both v4 and v6
124      // give a try on v6
125      if self.num_v4_sockets == 0 {
126        let idx = self.v6_round_robin.fetch_add(1, Ordering::AcqRel) % self.num_v6_sockets;
127        Kind::V6(idx)
128      } else {
129        let idx = self.v4_round_robin.fetch_add(1, Ordering::AcqRel) % self.num_v4_sockets;
130        Kind::V4(idx)
131      }
132    } else if self.num_v6_sockets == 0 {
133      let idx = self.v4_round_robin.fetch_add(1, Ordering::AcqRel) % self.num_v4_sockets;
134      Kind::V4(idx)
135    } else {
136      let idx = self.v6_round_robin.fetch_add(1, Ordering::AcqRel) % self.num_v6_sockets;
137      Kind::V6(idx)
138    };
139
140    // if we failed to borrow, it means that this transport is being shut down.
141
142    match kind {
143      Kind::V4(idx) => {
144        if let Ok(sockets) = self.v4_sockets.try_borrow() {
145          Some(sockets[idx].clone())
146        } else {
147          None
148        }
149      }
150      Kind::V6(idx) => {
151        if let Ok(sockets) = self.v6_sockets.try_borrow() {
152          Some(sockets[idx].clone())
153        } else {
154          None
155        }
156      }
157    }
158  }
159}
160
161impl<I, A, S, R> Transport for NetTransport<I, A, S, R>
162where
163  I: Id + Data + Send + Sync + 'static,
164  A: AddressResolver<ResolvedAddress = SocketAddr, Runtime = R>,
165  A::Address: Data + Send + Sync + 'static,
166  S: StreamLayer<Runtime = R>,
167  R: Runtime,
168{
169  type Error = NetTransportError<Self::Resolver>;
170
171  type Id = I;
172  type Address = A::Address;
173  type ResolvedAddress = SocketAddr;
174  type Resolver = A;
175
176  type Connection = S::Stream;
177
178  type Runtime = <Self::Resolver as AddressResolver>::Runtime;
179
180  type Options = NetTransportOptions<Self::Id, Self::Resolver, S>;
181
182  async fn new(transport_opts: Self::Options) -> Result<Self, Self::Error> {
183    let (resolver_opts, stream_layer_opts, opts) = transport_opts.into();
184    let resolver = Arc::new(
185      <A as AddressResolver>::new(resolver_opts)
186        .await
187        .map_err(NetTransportError::Resolver)?,
188    );
189
190    let stream_layer = Arc::new(<S as StreamLayer>::new(stream_layer_opts).await?);
191    let opts = Arc::new(opts);
192
193    // If we reject the empty list outright we can assume that there's at
194    // least one listener of each type later during operation.
195    if opts.bind_addresses.is_empty() {
196      return Err(NetTransportError::EmptyBindAddresses);
197    }
198
199    let (stream_tx, stream_rx) = promised_stream::<Self>();
200    let (packet_tx, packet_rx) = packet_stream::<Self>();
201    let (shutdown_tx, shutdown_rx) = async_channel::bounded(1);
202
203    let mut v4_promised_listeners = Vec::with_capacity(opts.bind_addresses.len());
204    let mut v4_sockets = Vec::with_capacity(opts.bind_addresses.len());
205    let mut v6_promised_listeners = Vec::with_capacity(opts.bind_addresses.len());
206    let mut v6_sockets = Vec::with_capacity(opts.bind_addresses.len());
207    let mut resolved_bind_address = SmallVec::new();
208
209    for addr in opts.bind_addresses.iter() {
210      let addr = resolver
211        .resolve(addr)
212        .await
213        .map_err(|e| NetTransportError::Resolve {
214          addr: addr.clone(),
215          err: e,
216        })?;
217      let bind_port = addr.port();
218
219      let (local_addr, ln) = if bind_port == 0 {
220        let mut retries = 0;
221        loop {
222          match stream_layer.bind(addr).await {
223            Ok(ln) => break (ln.local_addr(), ln),
224            Err(e) => {
225              if retries < 9 {
226                retries += 1;
227                continue;
228              }
229              return Err(NetTransportError::ListenPromised(addr, e));
230            }
231          }
232        }
233      } else {
234        match stream_layer.bind(addr).await {
235          Ok(ln) => (ln.local_addr(), ln),
236          Err(e) => return Err(NetTransportError::ListenPromised(addr, e)),
237        }
238      };
239
240      if local_addr.is_ipv4() {
241        v4_promised_listeners.push((Arc::new(ln), local_addr));
242      } else {
243        v6_promised_listeners.push((Arc::new(ln), local_addr));
244      }
245      // If the config port given was zero, use the first TCP listener
246      // to pick an available port and then apply that to everything
247      // else.
248      let addr = if bind_port == 0 { local_addr } else { addr };
249      resolved_bind_address.push(addr);
250
251      let (local_addr, packet_socket) =
252        <<<A::Runtime as Runtime>::Net as Net>::UdpSocket as UdpSocket>::bind(addr)
253          .await
254          .map(|ln| (addr, ln))
255          .map_err(|e| NetTransportError::ListenPacket(addr, e))?;
256
257      set_udp_recv_buffer(&packet_socket, opts.recv_buffer_size)?;
258
259      if local_addr.is_ipv4() {
260        v4_sockets.push((Arc::new(packet_socket), local_addr));
261      } else {
262        v6_sockets.push((Arc::new(packet_socket), local_addr))
263      }
264    }
265
266    let expose_addr_index = Self::find_advertise_addr_index(&resolved_bind_address);
267    let advertise_addr = resolved_bind_address[expose_addr_index];
268    let self_addr = opts.bind_addresses[expose_addr_index].cheap_clone();
269    let shutdown = Arc::new(AtomicBool::new(false));
270    let handles = FuturesUnordered::new();
271    // Fire them up start that we've been able to create them all.
272    // keep the first tcp and udp listener, gossip protocol, we made sure there's at least one
273    // udp and tcp listener can
274    for ((promised_ln, promised_addr), (socket, socket_addr)) in v4_promised_listeners
275      .iter()
276      .zip(v4_sockets.iter())
277      .chain(v6_promised_listeners.iter().zip(v6_sockets.iter()))
278    {
279      let processor = PromisedProcessor::<A, Self, S> {
280        stream_tx: stream_tx.clone(),
281        ln: promised_ln.clone(),
282        shutdown_rx: shutdown_rx.clone(),
283        local_addr: *promised_addr,
284      };
285      handles.push(R::spawn(processor.run()));
286
287      let processor = PacketProcessor::<A, Self> {
288        packet_tx: packet_tx.clone(),
289        socket: socket.clone(),
290        local_addr: *socket_addr,
291        shutdown: shutdown.clone(),
292        #[cfg(feature = "metrics")]
293        metric_labels: opts.metric_labels.clone().unwrap_or_default(),
294        shutdown_rx: shutdown_rx.clone(),
295      };
296
297      handles.push(R::spawn(processor.run()));
298    }
299
300    // find final advertise address
301    let final_advertise_addr = if let Some(addr) = opts.advertise_address {
302      addr
303    } else if advertise_addr.ip().is_unspecified() {
304      let ip = getifs::private_addrs()
305        .map_err(|_| NetTransportError::NoPrivateIP)
306        .and_then(|ips| {
307          if let Some(ip) = ips.into_iter().next().map(|ip| ip.addr()) {
308            Ok(ip)
309          } else {
310            Err(NetTransportError::NoPrivateIP)
311          }
312        })?;
313      SocketAddr::new(ip, advertise_addr.port())
314    } else {
315      advertise_addr
316    };
317
318    // if final_advertise_addr.is_global_ip() {
319    //   #[cfg(feature = "encryption")]
320    //   if S::is_secure()
321    //     && (encryptor.is_none() || opts.encryption_algo.is_none() || !opts.gossip_verify_outgoing)
322    //   {
323    //     tracing::warn!(advertise_addr=%final_advertise_addr, "memberlist_net: binding to public address without enabling encryption for packet stream layer!");
324    //   }
325
326    //   #[cfg(feature = "encryption")]
327    //   if !S::is_secure()
328    //     && (encryptor.is_none() || opts.encryption_algo.is_none() || !opts.gossip_verify_outgoing)
329    //   {
330    //     tracing::warn!(advertise_addr=%final_advertise_addr, "memberlist_net: binding to public address without enabling encryption for stream layer!");
331    //   }
332
333    //   #[cfg(not(feature = "encryption"))]
334    //   tracing::warn!(advertise_addr=%final_advertise_addr, "memberlist_net: binding to public address without enabling encryption for stream layer!");
335    // }
336
337    Ok(Self {
338      advertise_addr: final_advertise_addr,
339      local_addr: self_addr,
340      opts,
341      packet_rx,
342      stream_rx,
343      handles: AtomicRefCell::new(handles),
344      num_v4_sockets: v4_sockets.len(),
345      v4_sockets: AtomicRefCell::new(v4_sockets.into_iter().map(|(ln, _)| ln).collect()),
346      v4_round_robin: AtomicUsize::new(0),
347      num_v6_sockets: v6_sockets.len(),
348      v6_sockets: AtomicRefCell::new(v6_sockets.into_iter().map(|(ln, _)| ln).collect()),
349      v6_round_robin: AtomicUsize::new(0),
350      stream_layer,
351      resolver,
352      shutdown_tx,
353    })
354  }
355
356  async fn resolve(
357    &self,
358    addr: &<Self::Resolver as AddressResolver>::Address,
359  ) -> Result<<Self::Resolver as AddressResolver>::ResolvedAddress, Self::Error> {
360    self
361      .resolver
362      .resolve(addr)
363      .await
364      .map_err(|e| Self::Error::Resolve {
365        addr: addr.cheap_clone(),
366        err: e,
367      })
368  }
369
370  #[inline]
371  fn local_id(&self) -> &Self::Id {
372    &self.opts.id
373  }
374
375  #[inline]
376  fn local_address(&self) -> &<Self::Resolver as AddressResolver>::Address {
377    &self.local_addr
378  }
379
380  #[inline]
381  fn advertise_address(&self) -> &<Self::Resolver as AddressResolver>::ResolvedAddress {
382    &self.advertise_addr
383  }
384
385  #[inline]
386  fn max_packet_size(&self) -> usize {
387    self.opts.max_packet_size
388  }
389
390  #[inline]
391  fn header_overhead(&self) -> usize {
392    0
393  }
394
395  fn blocked_address(
396    &self,
397    addr: &<Self::Resolver as AddressResolver>::ResolvedAddress,
398  ) -> Result<(), Self::Error> {
399    let ip = addr.ip();
400    if self.opts.cidrs_policy.is_blocked(&ip) {
401      Err(Self::Error::BlockedIp(ip))
402    } else {
403      Ok(())
404    }
405  }
406
407  async fn send_to(
408    &self,
409    addr: &<Self::Resolver as AddressResolver>::ResolvedAddress,
410    packets: Payload,
411  ) -> Result<(usize, <Self::Runtime as RuntimeLite>::Instant), Self::Error> {
412    let start = <Self::Runtime as RuntimeLite>::now();
413
414    let src = packets.as_slice();
415    match self.next_socket(addr) {
416      Some(skt) => skt
417        .send_to(src, addr)
418        .await
419        .map(|num| {
420          tracing::trace!(remote=%addr, total_bytes = %num, sent=?src, "memberlist_net.packet");
421          (num, start)
422        })
423        .map_err(Into::into),
424      None => {
425        tracing::error!("memberlist_net.packet: transport is being shutdown");
426        Err(
427          std::io::Error::new(
428            std::io::ErrorKind::ConnectionAborted,
429            "transport is being shutdown",
430          )
431          .into(),
432        )
433      }
434    }
435  }
436
437  async fn open(
438    &self,
439    addr: &<Self::Resolver as AddressResolver>::ResolvedAddress,
440    deadline: <Self::Runtime as RuntimeLite>::Instant,
441  ) -> Result<Self::Connection, Self::Error> {
442    let connector =
443      <Self::Runtime as RuntimeLite>::timeout_at(deadline, self.stream_layer.connect(*addr));
444    match connector.await {
445      Ok(Ok(conn)) => Ok(conn),
446      Ok(Err(e)) => Err(e.into()),
447      Err(e) => Err(Self::Error::Io(e.into())),
448    }
449  }
450
451  fn packet(
452    &self,
453  ) -> PacketSubscriber<
454    <Self::Resolver as AddressResolver>::ResolvedAddress,
455    <Self::Runtime as RuntimeLite>::Instant,
456  > {
457    self.packet_rx.clone()
458  }
459
460  fn stream(
461    &self,
462  ) -> StreamSubscriber<<Self::Resolver as AddressResolver>::ResolvedAddress, Self::Connection> {
463    self.stream_rx.clone()
464  }
465
466  fn packet_reliable(&self) -> bool {
467    false
468  }
469
470  fn packet_secure(&self) -> bool {
471    false
472  }
473
474  fn stream_secure(&self) -> bool {
475    S::is_secure()
476  }
477
478  async fn shutdown(&self) -> Result<(), Self::Error> {
479    if !self.shutdown_tx.close() {
480      return Ok(());
481    }
482
483    // clear all udp sockets
484    loop {
485      if let Ok(mut s) = self.v4_sockets.try_borrow_mut() {
486        s.clear();
487        break;
488      }
489    }
490
491    loop {
492      if let Ok(mut s) = self.v6_sockets.try_borrow_mut() {
493        s.clear();
494        break;
495      }
496    }
497
498    let mut handles = core::mem::take(&mut *self.handles.borrow_mut());
499    while handles.next().await.is_some() {}
500    Ok(())
501  }
502}
503
504impl<I, A, S, R> Drop for NetTransport<I, A, S, R>
505where
506  I: Id,
507  A: AddressResolver<ResolvedAddress = SocketAddr, Runtime = R>,
508  S: StreamLayer<Runtime = R>,
509  R: Runtime,
510{
511  fn drop(&mut self) {
512    self.shutdown_tx.close();
513  }
514}
515
516// Resize the UDP receive window. The function
517// attempts to set the read buffer to `udpRecvBuf` but backs off until
518// the read buffer can be set.
519fn set_udp_recv_buffer<U>(udp: &U, mut size: usize) -> std::io::Result<()>
520where
521  U: agnostic::net::UdpSocket,
522{
523  let mut err = None;
524  while size > 0 {
525    match udp.set_recv_buffer_size(size) {
526      Ok(_) => return Ok(()),
527      Err(e) => err = Some(e),
528    }
529    size /= 2;
530  }
531
532  Err(
533    err.unwrap_or_else(|| std::io::Error::other("fail to set receive buffer size for UDP socket")),
534  )
535}