memberlist_quic/
lib.rs

1//! [`memberlist`](https://crates.io/crates/memberlist)'s [`Transport`] layer based on QUIC.
2#![doc(html_logo_url = "https://raw.githubusercontent.com/al8n/memberlist/main/art/logo_72x72.png")]
3#![allow(clippy::type_complexity)]
4#![forbid(unsafe_code)]
5#![deny(warnings, missing_docs)]
6#![cfg_attr(docsrs, feature(doc_cfg))]
7#![cfg_attr(docsrs, allow(unused_attributes))]
8
9use std::{
10  net::{IpAddr, SocketAddr},
11  sync::{
12    Arc,
13    atomic::{AtomicUsize, Ordering},
14  },
15  time::Duration,
16};
17
18use agnostic_lite::{AsyncSpawner, RuntimeLite, time::Instant};
19use atomic_refcell::AtomicRefCell;
20use crossbeam_skiplist::SkipMap;
21use futures::{StreamExt, stream::FuturesUnordered};
22use memberlist_core::proto::{Data, Payload, SmallVec};
23pub use memberlist_core::{
24  proto::{CIDRsPolicy, Label, LabelError, ProtoReader},
25  transport::*,
26};
27
28mod processor;
29use processor::*;
30
31/// Exports unit tests.
32#[cfg(any(test, feature = "test"))]
33#[cfg_attr(docsrs, doc(cfg(feature = "test")))]
34pub mod tests;
35
36mod error;
37pub use error::*;
38mod options;
39pub use options::*;
/// Abstractions over the [`StreamLayer`](crate::stream_layer::StreamLayer) used by [`QuicTransport`].
41pub mod stream_layer;
42use stream_layer::*;
43
/// Largest message size the transport will ever report, capped at `u32::MAX`.
const MAX_MESSAGE_SIZE: usize = u32::MAX as usize;

/// Leading tag byte of a stream that carries a single packet (written by `send_to`).
const PACKET_TAG: u8 = 0xFE;
/// Leading tag byte of a stream used as a stream connection (written by `write`).
const STREAM_TAG: u8 = 0xFF;

/// The one-byte header that opens every QUIC stream produced by this transport.
///
/// The discriminants are the tag bytes themselves, so `tag as u8` yields the wire value.
#[derive(Copy, Clone)]
#[repr(u8)]
enum StreamType {
  Stream = STREAM_TAG,
  Packet = PACKET_TAG,
}

impl TryFrom<u8> for StreamType {
  /// The unrecognised byte is handed back unchanged.
  type Error = u8;

  fn try_from(tag: u8) -> Result<Self, Self::Error> {
    match tag {
      PACKET_TAG => Ok(Self::Packet),
      STREAM_TAG => Ok(Self::Stream),
      unknown => Err(unknown),
    }
  }
}
67
/// [`QuicTransport`] based on [`tokio`](https://crates.io/crates/tokio).
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
pub type TokioQuicTransport<I, A, S> = QuicTransport<I, A, S, agnostic_lite::tokio::TokioRuntime>;

/// [`QuicTransport`] based on [`smol`](https://crates.io/crates/smol).
#[cfg(feature = "smol")]
#[cfg_attr(docsrs, doc(cfg(feature = "smol")))]
pub type SmolQuicTransport<I, A, S> = QuicTransport<I, A, S, agnostic_lite::smol::SmolRuntime>;
75
/// A [`Transport`] implementation based on QUIC
///
/// Every outgoing message travels on its own bidirectional stream whose first byte
/// tags it as either a packet or a stream connection. Outgoing connections are
/// pooled per remote address and reused while they are open.
pub struct QuicTransport<I, A, S, R>
where
  I: Id + Send + Sync + 'static,
  A: AddressResolver<ResolvedAddress = SocketAddr, Runtime = R>,
  S: StreamLayer<Runtime = R>,
  R: RuntimeLite,
{
  /// Options the transport was built with; read for the local id, the I/O timeout
  /// and the CIDR policy.
  opts: Options<I, A>,
  /// Address advertised to peers: the configured advertise address, or — when the
  /// chosen bind address is unspecified — a private interface IP with the bound port.
  advertise_addr: A::ResolvedAddress,
  /// The configured (unresolved) bind address whose resolved form was chosen as the
  /// advertise address.
  local_addr: A::Address,
  /// Receiving side of the packet channel whose sender is handed to the acceptor
  /// processors; cloned out by `packet()`.
  packet_rx: PacketSubscriber<A::ResolvedAddress, R::Instant>,
  /// Receiving side of the stream channel whose sender is handed to the acceptor
  /// processors; cloned out by `stream()`.
  stream_rx: StreamSubscriber<A::ResolvedAddress, S::Stream>,
  // Never read after construction (hence the lint allowance); presumably stored so the
  // stream layer lives as long as the transport — confirm.
  #[allow(dead_code)]
  stream_layer: S,
  /// Outgoing connections keyed by remote address, each paired with the instant it was
  /// inserted. Shared with the background pool-cleaner task.
  connection_pool: Arc<SkipMap<SocketAddr, (R::Instant, S::Connection)>>,
  /// Round-robin cursor over `v4_connectors`.
  v4_round_robin: AtomicUsize,
  /// Connectors created for bind addresses that resolved to IPv4.
  v4_connectors: SmallVec<S::Connector>,
  /// Round-robin cursor over `v6_connectors`.
  v6_round_robin: AtomicUsize,
  /// Connectors created for bind addresses that resolved to IPv6.
  v6_connectors: SmallVec<S::Connector>,
  /// Join handles of the background tasks (acceptor processors and pool cleaner);
  /// drained and awaited by `shutdown`.
  handles: AtomicRefCell<FuturesUnordered<<R::Spawner as AsyncSpawner>::JoinHandle<()>>>,
  /// Resolver used to turn `A::Address` values into socket addresses.
  resolver: A,
  /// Closing this channel signals the background tasks to stop; closed by `shutdown`
  /// and on drop.
  shutdown_tx: async_channel::Sender<()>,
  /// Value reported by `max_packet_size()`: the smaller of `MAX_MESSAGE_SIZE` and the
  /// stream layer's `max_stream_data()`.
  max_packet_size: usize,
}
101
102impl<I, A, S, R> QuicTransport<I, A, S, R>
103where
104  I: Id + Data + Send + Sync + 'static,
105  A: AddressResolver<ResolvedAddress = SocketAddr, Runtime = R>,
106  A::Address: Send + Sync + 'static,
107  A::ResolvedAddress: Data,
108  S: StreamLayer<Runtime = R>,
109  R: RuntimeLite,
110{
111  async fn new_in(
112    resolver: A,
113    stream_layer: S,
114    opts: Options<I, A>,
115  ) -> Result<Self, QuicTransportError<A>> {
116    // If we reject the empty list outright we can assume that there's at
117    // least one listener of each type later during operation.
118    if opts.bind_addresses.is_empty() {
119      return Err(QuicTransportError::EmptyBindAddresses);
120    }
121
122    let (stream_tx, stream_rx) = promised_stream::<Self>();
123    let (packet_tx, packet_rx) = packet_stream::<Self>();
124    let (shutdown_tx, shutdown_rx) = async_channel::bounded(1);
125
126    let mut v4_connectors = SmallVec::with_capacity(opts.bind_addresses.len());
127    let mut v6_connectors = SmallVec::with_capacity(opts.bind_addresses.len());
128    let mut v4_acceptors = SmallVec::with_capacity(opts.bind_addresses.len());
129    let mut v6_acceptors = SmallVec::with_capacity(opts.bind_addresses.len());
130    let mut resolved_bind_address = SmallVec::new();
131
132    for addr in opts.bind_addresses.iter() {
133      let addr = resolver
134        .resolve(addr)
135        .await
136        .map_err(|e| QuicTransportError::Resolve {
137          addr: addr.cheap_clone(),
138          err: e,
139        })?;
140
141      let bind_port = addr.port();
142
143      let (local_addr, acceptor, connector) = if bind_port == 0 {
144        let mut retries = 0;
145        loop {
146          match stream_layer.bind(addr).await {
147            Ok(res) => break res,
148            Err(e) => {
149              if retries < 9 {
150                retries += 1;
151                continue;
152              }
153              return Err(QuicTransportError::Listen(addr, e));
154            }
155          }
156        }
157      } else {
158        match stream_layer.bind(addr).await {
159          Ok(res) => res,
160          Err(e) => return Err(QuicTransportError::Listen(addr, e)),
161        }
162      };
163
164      if local_addr.is_ipv4() {
165        v4_acceptors.push((local_addr, acceptor));
166        v4_connectors.push(connector);
167      } else {
168        v6_acceptors.push((local_addr, acceptor));
169        v6_connectors.push(connector);
170      }
171      // If the config port given was zero, use the first TCP listener
172      // to pick an available port and then apply that to everything
173      // else.
174      let addr = if bind_port == 0 { local_addr } else { addr };
175      resolved_bind_address.push(addr);
176    }
177
178    let expose_addr_index = Self::find_advertise_addr_index(&resolved_bind_address);
179    let advertise_addr = resolved_bind_address[expose_addr_index];
180    let self_addr = opts.bind_addresses[expose_addr_index].cheap_clone();
181    let handles = FuturesUnordered::new();
182
183    // Fire them up start that we've been able to create them all.
184    // keep the first tcp and udp listener, gossip protocol, we made sure there's at least one
185    // udp and tcp listener can
186    for (local_addr, acceptor) in v4_acceptors.into_iter().chain(v6_acceptors.into_iter()) {
187      let processor = Processor::<A, Self, S> {
188        acceptor,
189        packet_tx: packet_tx.clone(),
190        stream_tx: stream_tx.clone(),
191        local_addr,
192        timeout: opts.timeout,
193        shutdown_rx: shutdown_rx.clone(),
194        #[cfg(feature = "metrics")]
195        metric_labels: opts.metric_labels.clone().unwrap_or_default(),
196      };
197
198      handles.push(R::spawn(processor.run()));
199    }
200
201    // find final advertise address
202    let final_advertise_addr = if let Some(addr) = opts.advertise_address {
203      addr
204    } else if advertise_addr.ip().is_unspecified() {
205      let ip = getifs::private_addrs()
206        .map_err(|_| QuicTransportError::NoPrivateIP)
207        .and_then(|ips| {
208          if let Some(ip) = ips.into_iter().next().map(|ip| ip.addr()) {
209            Ok(ip)
210          } else {
211            Err(QuicTransportError::NoPrivateIP)
212          }
213        })?;
214      SocketAddr::new(ip, advertise_addr.port())
215    } else {
216      advertise_addr
217    };
218
219    let connection_pool = Arc::new(SkipMap::new());
220    let interval = <A::Runtime as RuntimeLite>::interval(opts.connection_pool_cleanup_period);
221    let pool = connection_pool.clone();
222    let shutdown_rx = shutdown_rx.clone();
223    handles.push(R::spawn(Self::connection_pool_cleaner(
224      pool,
225      interval,
226      shutdown_rx,
227      opts.connection_ttl.unwrap_or(Duration::ZERO),
228    )));
229
230    Ok(Self {
231      advertise_addr: final_advertise_addr,
232      connection_pool,
233      local_addr: self_addr,
234      max_packet_size: MAX_MESSAGE_SIZE.min(stream_layer.max_stream_data()),
235      opts,
236      packet_rx,
237      stream_rx,
238      handles: AtomicRefCell::new(handles),
239      v4_connectors,
240      v6_connectors,
241      v4_round_robin: AtomicUsize::new(0),
242      v6_round_robin: AtomicUsize::new(0),
243      stream_layer,
244      resolver,
245      shutdown_tx,
246    })
247  }
248
249  fn find_advertise_addr_index(addrs: &[SocketAddr]) -> usize {
250    for (i, addr) in addrs.iter().enumerate() {
251      if !addr.ip().is_unspecified() {
252        return i;
253      }
254    }
255
256    0
257  }
258
259  async fn connection_pool_cleaner(
260    pool: Arc<SkipMap<SocketAddr, (R::Instant, S::Connection)>>,
261    mut interval: impl agnostic_lite::time::AsyncInterval,
262    shutdown_rx: async_channel::Receiver<()>,
263    max_conn_idle: Duration,
264  ) {
265    loop {
266      let fut1 = shutdown_rx.recv();
267      let fut2 = async {
268        interval.next().await;
269
270        for ent in pool.iter() {
271          let (deadline, conn) = ent.value();
272          if max_conn_idle == Duration::ZERO {
273            if conn.is_closed().await {
274              let _ = conn.close().await;
275              ent.remove();
276            }
277            continue;
278          }
279
280          if deadline.elapsed() >= max_conn_idle || conn.is_closed().await {
281            let _ = conn.close().await;
282            ent.remove();
283          }
284        }
285      };
286
287      futures::pin_mut!(fut1, fut2);
288      match futures::future::select(fut1, fut2).await {
289        futures::future::Either::Left(_) => break,
290        futures::future::Either::Right(_) => {}
291      }
292    }
293  }
294}
295
296impl<I, A, S, R> QuicTransport<I, A, S, R>
297where
298  I: Id + Send + Sync + 'static,
299  A: AddressResolver<ResolvedAddress = SocketAddr, Runtime = R>,
300  S: StreamLayer<Runtime = R>,
301  R: RuntimeLite,
302{
303  fn next_connector(&self, addr: &A::ResolvedAddress) -> &S::Connector {
304    if addr.is_ipv4() {
305      // if there's no v4 sockets, we assume remote addr can accept both v4 and v6
306      // give a try on v6
307      if self.v4_connectors.is_empty() {
308        let idx = self.v6_round_robin.fetch_add(1, Ordering::AcqRel) % self.v6_connectors.len();
309        &self.v6_connectors[idx]
310      } else {
311        let idx = self.v4_round_robin.fetch_add(1, Ordering::AcqRel) % self.v4_connectors.len();
312        &self.v4_connectors[idx]
313      }
314    } else if self.v6_connectors.is_empty() {
315      let idx = self.v4_round_robin.fetch_add(1, Ordering::AcqRel) % self.v4_connectors.len();
316      &self.v4_connectors[idx]
317    } else {
318      let idx = self.v6_round_robin.fetch_add(1, Ordering::AcqRel) % self.v6_connectors.len();
319      &self.v6_connectors[idx]
320    }
321  }
322
323  async fn fetch_stream(
324    &self,
325    addr: SocketAddr,
326    timeout: Option<R::Instant>,
327  ) -> Result<S::Stream, QuicTransportError<A>> {
328    if let Some(ent) = self.connection_pool.get(&addr) {
329      let (_, connection) = ent.value();
330      if !connection.is_closed().await {
331        if let Some(timeout) = timeout {
332          return R::timeout_at(timeout, connection.open_bi())
333            .await
334            .map_err(|e| QuicTransportError::Io(e.into()))?
335            .map(|(stream, _)| stream)
336            .map_err(Into::into);
337        } else {
338          return connection
339            .open_bi()
340            .await
341            .map(|(s, _)| s)
342            .map_err(Into::into);
343        }
344      }
345    }
346
347    let connector = self.next_connector(&addr);
348    let connection = connector.connect(addr).await?;
349    connection
350      .open_bi()
351      .await
352      .map(|(s, _)| {
353        self
354          .connection_pool
355          .insert(addr, (Instant::now(), connection));
356        s
357      })
358      .map_err(Into::into)
359  }
360}
361
362impl<I, A, S, R> Transport for QuicTransport<I, A, S, R>
363where
364  I: Id + Data + Send + Sync + 'static,
365  A: AddressResolver<ResolvedAddress = SocketAddr, Runtime = R>,
366  A::Address: Send + Sync + 'static,
367  A::ResolvedAddress: Data,
368  S: StreamLayer<Runtime = R>,
369  R: RuntimeLite,
370{
371  type Error = QuicTransportError<A>;
372
373  type Id = I;
374
375  type Address = A::Address;
376  type ResolvedAddress = SocketAddr;
377  type Resolver = A;
378
379  type Connection = S::Stream;
380
381  type Runtime = A::Runtime;
382
383  type Options = QuicTransportOptions<I, A, S>;
384
385  async fn new(transport_opts: Self::Options) -> Result<Self, Self::Error> {
386    let (resolver_options, stream_layer_options, opts) = transport_opts.into();
387    let resolver = <A as AddressResolver>::new(resolver_options)
388      .await
389      .map_err(Self::Error::Resolver)?;
390
391    let stream_layer = S::new(stream_layer_options).await?;
392    Self::new_in(resolver, stream_layer, opts).await
393  }
394
395  async fn resolve(
396    &self,
397    addr: &<Self::Resolver as AddressResolver>::Address,
398  ) -> Result<<Self::Resolver as AddressResolver>::ResolvedAddress, Self::Error> {
399    self
400      .resolver
401      .resolve(addr)
402      .await
403      .map_err(|e| Self::Error::Resolve {
404        addr: addr.cheap_clone(),
405        err: e,
406      })
407  }
408
409  #[inline]
410  fn local_id(&self) -> &Self::Id {
411    &self.opts.id
412  }
413
414  #[inline]
415  fn local_address(&self) -> &<Self::Resolver as AddressResolver>::Address {
416    &self.local_addr
417  }
418
419  #[inline]
420  fn advertise_address(&self) -> &<Self::Resolver as AddressResolver>::ResolvedAddress {
421    &self.advertise_addr
422  }
423
424  #[inline]
425  fn max_packet_size(&self) -> usize {
426    self.max_packet_size
427  }
428
429  #[inline]
430  fn header_overhead(&self) -> usize {
431    1
432  }
433
434  fn blocked_address(
435    &self,
436    addr: &<Self::Resolver as AddressResolver>::ResolvedAddress,
437  ) -> Result<(), Self::Error> {
438    let ip = addr.ip();
439    if self.opts.cidrs_policy.is_blocked(&ip) {
440      Err(Self::Error::BlockedIp(ip))
441    } else {
442      Ok(())
443    }
444  }
445
446  async fn read(
447    &self,
448    from: &Self::ResolvedAddress,
449    conn: &mut <Self::Connection as Connection>::Reader,
450  ) -> Result<usize, Self::Error> {
451    let mut buf = [0; 1];
452    conn.read_exact(&mut buf).await?;
453    match StreamType::try_from(buf[0]) {
454      Ok(StreamType::Stream) => Ok(1),
455      Ok(StreamType::Packet) => Err(QuicTransportError::Io(std::io::Error::new(
456        std::io::ErrorKind::InvalidData,
457        format!("receive an unexpected packet stream from {from}"),
458      ))),
459      Err(tag) => Err(QuicTransportError::Io(std::io::Error::new(
460        std::io::ErrorKind::InvalidData,
461        format!(
462          "receive a stream from {from} with invalid type value: {}",
463          tag
464        ),
465      ))),
466    }
467  }
468
469  async fn write(
470    &self,
471    conn: &mut <Self::Connection as Connection>::Writer,
472    mut src: Payload,
473  ) -> Result<usize, Self::Error> {
474    use memberlist_core::proto::ProtoWriter;
475
476    let header = src.header_mut();
477    if header.is_empty() {
478      return Err(QuicTransportError::custom(
479        "not enough space for header".into(),
480      ));
481    }
482    header[0] = StreamType::Stream as u8;
483    let ttl = self
484      .opts
485      .timeout
486      .map(|ttl| <Self::Runtime as RuntimeLite>::now() + ttl);
487
488    let src = src.as_slice();
489    tracing::trace!(
490      total_bytes = %src.len(),
491      sent = ?src,
492      "memberlist_quic.stream"
493    );
494
495    match ttl {
496      None => {
497        conn.write_all(src).await?;
498        conn.flush().await.map_err(Into::into).map(|_| src.len())
499      }
500      Some(ttl) => R::timeout_at(ttl, async {
501        conn.write_all(src).await?;
502        conn.flush().await.map(|_| src.len())
503      })
504      .await
505      .map_err(std::io::Error::from)?
506      .map_err(Into::into),
507    }
508  }
509
510  async fn send_to(
511    &self,
512    addr: &Self::ResolvedAddress,
513    mut src: Payload,
514  ) -> Result<(usize, <Self::Runtime as RuntimeLite>::Instant), Self::Error> {
515    let start = <Self::Runtime as RuntimeLite>::now();
516    let ttl = self.opts.timeout.map(|ttl| start + ttl);
517    let mut stream = self.fetch_stream(*addr, ttl).await?;
518    let header = src.header_mut();
519    if header.is_empty() {
520      return Err(QuicTransportError::custom(
521        "not enough space for header".into(),
522      ));
523    }
524    header[0] = StreamType::Packet as u8;
525
526    let src = src.as_slice();
527    tracing::trace!(
528      total_bytes = %src.len(),
529      sent = ?src,
530      "memberlist_quic.packet"
531    );
532
533    match ttl {
534      None => {
535        stream.write_all(src).await?;
536        stream.flush().await?;
537
538        Ok((src.len(), start))
539      }
540      Some(ttl) => R::timeout_at(ttl, async {
541        stream.write_all(src).await?;
542        stream.flush().await.map(|_| (src.len(), start))
543      })
544      .await
545      .map_err(std::io::Error::from)?
546      .map_err(Into::into),
547    }
548  }
549
550  async fn open(
551    &self,
552    addr: &<Self::Resolver as AddressResolver>::ResolvedAddress,
553    deadline: R::Instant,
554  ) -> Result<Self::Connection, Self::Error> {
555    self.fetch_stream(*addr, Some(deadline)).await
556  }
557
558  fn packet(
559    &self,
560  ) -> PacketSubscriber<<Self::Resolver as AddressResolver>::ResolvedAddress, R::Instant> {
561    self.packet_rx.clone()
562  }
563
564  fn stream(
565    &self,
566  ) -> StreamSubscriber<<Self::Resolver as AddressResolver>::ResolvedAddress, Self::Connection> {
567    self.stream_rx.clone()
568  }
569
570  #[inline]
571  fn packet_reliable(&self) -> bool {
572    true
573  }
574
575  #[inline]
576  fn packet_secure(&self) -> bool {
577    true
578  }
579
580  #[inline]
581  fn stream_secure(&self) -> bool {
582    true
583  }
584
585  async fn shutdown(&self) -> Result<(), Self::Error> {
586    if !self.shutdown_tx.close() {
587      return Ok(());
588    }
589
590    for conn in self.connection_pool.iter() {
591      let (_, conn) = conn.value();
592      let addr = conn.local_addr();
593      if let Err(e) = conn.close().await {
594        tracing::error!(err = %e, local_addr=%addr, "memberlist.transport.quic: failed to close connection");
595      }
596    }
597
598    for connector in self.v4_connectors.iter().chain(self.v6_connectors.iter()) {
599      let addr = connector.local_addr();
600      if let Err(e) = connector.close().await.map_err(Self::Error::from) {
601        tracing::error!(err = %e, local_addr=%addr, "memberlist.transport.quic: failed to close connector");
602      }
603    }
604
605    // Block until all the listener threads have died.
606    let mut handles = core::mem::take(&mut *self.handles.borrow_mut());
607    while let Some(res) = handles.next().await {
608      match res {
609        Ok(()) => {}
610        Err(e) => {
611          tracing::error!(err = %e, "memberlist.transport.quic: failed to wait listener task finish");
612        }
613      }
614    }
615    Ok(())
616  }
617}
618
619impl<I, A, S, R> Drop for QuicTransport<I, A, S, R>
620where
621  I: Id + Send + Sync + 'static,
622  A: AddressResolver<ResolvedAddress = SocketAddr, Runtime = R>,
623  S: StreamLayer<Runtime = R>,
624  R: RuntimeLite,
625{
626  fn drop(&mut self) {
627    self.shutdown_tx.close();
628  }
629}