1#![doc(html_logo_url = "https://raw.githubusercontent.com/al8n/memberlist/main/art/logo_72x72.png")]
3#![allow(clippy::type_complexity)]
4#![forbid(unsafe_code)]
5#![deny(warnings, missing_docs)]
6#![cfg_attr(docsrs, feature(doc_cfg))]
7#![cfg_attr(docsrs, allow(unused_attributes))]
8
9use std::{
10 net::{IpAddr, SocketAddr},
11 sync::{
12 Arc,
13 atomic::{AtomicUsize, Ordering},
14 },
15 time::Duration,
16};
17
18use agnostic_lite::{AsyncSpawner, RuntimeLite, time::Instant};
19use atomic_refcell::AtomicRefCell;
20use crossbeam_skiplist::SkipMap;
21use futures::{StreamExt, stream::FuturesUnordered};
22use memberlist_core::proto::{Data, Payload, SmallVec};
23pub use memberlist_core::{
24 proto::{CIDRsPolicy, Label, LabelError, ProtoReader},
25 transport::*,
26};
27
28mod processor;
29use processor::*;
30
/// Test suite for this transport.
///
/// Compiled only for test builds or when the `test` feature is enabled.
#[cfg(any(test, feature = "test"))]
#[cfg_attr(docsrs, doc(cfg(feature = "test")))]
pub mod tests;
35
36mod error;
37pub use error::*;
38mod options;
39pub use options::*;
/// Stream-layer building blocks used by [`QuicTransport`].
// NOTE(review): presumably this is where the `StreamLayer` trait lives — confirm.
pub mod stream_layer;
42use stream_layer::*;
43
/// Hard cap on a single message: the largest value a `u32` can hold.
///
/// The transport's maximum packet size is the smaller of this value and the
/// stream layer's own limit.
const MAX_MESSAGE_SIZE: usize = u32::MAX as usize;
45
/// Leading byte written on a stream that carries a packet.
const PACKET_TAG: u8 = 254;
/// Leading byte written on a stream that carries a promised stream.
const STREAM_TAG: u8 = 255;

/// The kind of traffic carried by a stream, encoded as its first byte.
#[derive(Copy, Clone)]
#[repr(u8)]
enum StreamType {
  /// A promised stream (tagged with [`STREAM_TAG`]).
  Stream = STREAM_TAG,
  /// A packet (tagged with [`PACKET_TAG`]).
  Packet = PACKET_TAG,
}

impl TryFrom<u8> for StreamType {
  /// The unrecognized tag byte, handed back unchanged.
  type Error = u8;

  /// Decodes a tag byte; any value other than the two known tags is rejected.
  fn try_from(value: u8) -> Result<Self, Self::Error> {
    match value {
      STREAM_TAG => Ok(Self::Stream),
      PACKET_TAG => Ok(Self::Packet),
      unknown => Err(unknown),
    }
  }
}
67
/// [`QuicTransport`] fixed to the tokio runtime (`agnostic_lite::tokio::TokioRuntime`).
///
/// Available only when the `tokio` feature is enabled.
#[cfg(feature = "tokio")]
pub type TokioQuicTransport<I, A, S> = QuicTransport<I, A, S, agnostic_lite::tokio::TokioRuntime>;
71
/// [`QuicTransport`] fixed to the smol runtime (`agnostic_lite::smol::SmolRuntime`).
///
/// Available only when the `smol` feature is enabled.
#[cfg(feature = "smol")]
pub type SmolQuicTransport<I, A, S> = QuicTransport<I, A, S, agnostic_lite::smol::SmolRuntime>;
75
/// A [`Transport`] implementation that sends both packets and promised streams
/// over streams provided by a [`StreamLayer`].
///
/// Type parameters:
/// - `I`: node identifier.
/// - `A`: address resolver; it must resolve to [`SocketAddr`] and run on `R`.
/// - `S`: stream layer used to bind, accept and connect; it must run on `R`.
/// - `R`: async runtime.
pub struct QuicTransport<I, A, S, R>
where
  I: Id + Send + Sync + 'static,
  A: AddressResolver<ResolvedAddress = SocketAddr, Runtime = R>,
  S: StreamLayer<Runtime = R>,
  R: RuntimeLite,
{
  /// Options the transport was built with (id, timeout, CIDR policy, ...).
  opts: Options<I, A>,
  /// Resolved address reported by `advertise_address`.
  advertise_addr: A::ResolvedAddress,
  /// Unresolved bind address reported by `local_address`.
  local_addr: A::Address,
  /// Subscriber handed out (cloned) by `packet`.
  packet_rx: PacketSubscriber<A::ResolvedAddress, R::Instant>,
  /// Subscriber handed out (cloned) by `stream`.
  stream_rx: StreamSubscriber<A::ResolvedAddress, S::Stream>,
  // Not read again after construction in this file.
  // NOTE(review): presumably kept so the layer lives as long as the transport — confirm.
  #[allow(dead_code)]
  stream_layer: S,
  /// Outgoing connections keyed by remote address, each stored with the
  /// instant it was inserted.
  connection_pool: Arc<SkipMap<SocketAddr, (R::Instant, S::Connection)>>,
  /// Round-robin counter used to pick the next IPv4 connector.
  v4_round_robin: AtomicUsize,
  /// Connectors whose local address is IPv4.
  v4_connectors: SmallVec<S::Connector>,
  /// Round-robin counter used to pick the next IPv6 connector.
  v6_round_robin: AtomicUsize,
  /// Connectors whose local address is not IPv4.
  v6_connectors: SmallVec<S::Connector>,
  /// Join handles of the spawned background tasks; drained and awaited by `shutdown`.
  handles: AtomicRefCell<FuturesUnordered<<R::Spawner as AsyncSpawner>::JoinHandle<()>>>,
  /// Resolver used by `resolve`.
  resolver: A,
  /// Closed on shutdown and on drop to signal the background tasks.
  shutdown_tx: async_channel::Sender<()>,
  /// Value reported by `max_packet_size`.
  max_packet_size: usize,
}
101
102impl<I, A, S, R> QuicTransport<I, A, S, R>
103where
104 I: Id + Data + Send + Sync + 'static,
105 A: AddressResolver<ResolvedAddress = SocketAddr, Runtime = R>,
106 A::Address: Send + Sync + 'static,
107 A::ResolvedAddress: Data,
108 S: StreamLayer<Runtime = R>,
109 R: RuntimeLite,
110{
111 async fn new_in(
112 resolver: A,
113 stream_layer: S,
114 opts: Options<I, A>,
115 ) -> Result<Self, QuicTransportError<A>> {
116 if opts.bind_addresses.is_empty() {
119 return Err(QuicTransportError::EmptyBindAddresses);
120 }
121
122 let (stream_tx, stream_rx) = promised_stream::<Self>();
123 let (packet_tx, packet_rx) = packet_stream::<Self>();
124 let (shutdown_tx, shutdown_rx) = async_channel::bounded(1);
125
126 let mut v4_connectors = SmallVec::with_capacity(opts.bind_addresses.len());
127 let mut v6_connectors = SmallVec::with_capacity(opts.bind_addresses.len());
128 let mut v4_acceptors = SmallVec::with_capacity(opts.bind_addresses.len());
129 let mut v6_acceptors = SmallVec::with_capacity(opts.bind_addresses.len());
130 let mut resolved_bind_address = SmallVec::new();
131
132 for addr in opts.bind_addresses.iter() {
133 let addr = resolver
134 .resolve(addr)
135 .await
136 .map_err(|e| QuicTransportError::Resolve {
137 addr: addr.cheap_clone(),
138 err: e,
139 })?;
140
141 let bind_port = addr.port();
142
143 let (local_addr, acceptor, connector) = if bind_port == 0 {
144 let mut retries = 0;
145 loop {
146 match stream_layer.bind(addr).await {
147 Ok(res) => break res,
148 Err(e) => {
149 if retries < 9 {
150 retries += 1;
151 continue;
152 }
153 return Err(QuicTransportError::Listen(addr, e));
154 }
155 }
156 }
157 } else {
158 match stream_layer.bind(addr).await {
159 Ok(res) => res,
160 Err(e) => return Err(QuicTransportError::Listen(addr, e)),
161 }
162 };
163
164 if local_addr.is_ipv4() {
165 v4_acceptors.push((local_addr, acceptor));
166 v4_connectors.push(connector);
167 } else {
168 v6_acceptors.push((local_addr, acceptor));
169 v6_connectors.push(connector);
170 }
171 let addr = if bind_port == 0 { local_addr } else { addr };
175 resolved_bind_address.push(addr);
176 }
177
178 let expose_addr_index = Self::find_advertise_addr_index(&resolved_bind_address);
179 let advertise_addr = resolved_bind_address[expose_addr_index];
180 let self_addr = opts.bind_addresses[expose_addr_index].cheap_clone();
181 let handles = FuturesUnordered::new();
182
183 for (local_addr, acceptor) in v4_acceptors.into_iter().chain(v6_acceptors.into_iter()) {
187 let processor = Processor::<A, Self, S> {
188 acceptor,
189 packet_tx: packet_tx.clone(),
190 stream_tx: stream_tx.clone(),
191 local_addr,
192 timeout: opts.timeout,
193 shutdown_rx: shutdown_rx.clone(),
194 #[cfg(feature = "metrics")]
195 metric_labels: opts.metric_labels.clone().unwrap_or_default(),
196 };
197
198 handles.push(R::spawn(processor.run()));
199 }
200
201 let final_advertise_addr = if let Some(addr) = opts.advertise_address {
203 addr
204 } else if advertise_addr.ip().is_unspecified() {
205 let ip = getifs::private_addrs()
206 .map_err(|_| QuicTransportError::NoPrivateIP)
207 .and_then(|ips| {
208 if let Some(ip) = ips.into_iter().next().map(|ip| ip.addr()) {
209 Ok(ip)
210 } else {
211 Err(QuicTransportError::NoPrivateIP)
212 }
213 })?;
214 SocketAddr::new(ip, advertise_addr.port())
215 } else {
216 advertise_addr
217 };
218
219 let connection_pool = Arc::new(SkipMap::new());
220 let interval = <A::Runtime as RuntimeLite>::interval(opts.connection_pool_cleanup_period);
221 let pool = connection_pool.clone();
222 let shutdown_rx = shutdown_rx.clone();
223 handles.push(R::spawn(Self::connection_pool_cleaner(
224 pool,
225 interval,
226 shutdown_rx,
227 opts.connection_ttl.unwrap_or(Duration::ZERO),
228 )));
229
230 Ok(Self {
231 advertise_addr: final_advertise_addr,
232 connection_pool,
233 local_addr: self_addr,
234 max_packet_size: MAX_MESSAGE_SIZE.min(stream_layer.max_stream_data()),
235 opts,
236 packet_rx,
237 stream_rx,
238 handles: AtomicRefCell::new(handles),
239 v4_connectors,
240 v6_connectors,
241 v4_round_robin: AtomicUsize::new(0),
242 v6_round_robin: AtomicUsize::new(0),
243 stream_layer,
244 resolver,
245 shutdown_tx,
246 })
247 }
248
249 fn find_advertise_addr_index(addrs: &[SocketAddr]) -> usize {
250 for (i, addr) in addrs.iter().enumerate() {
251 if !addr.ip().is_unspecified() {
252 return i;
253 }
254 }
255
256 0
257 }
258
259 async fn connection_pool_cleaner(
260 pool: Arc<SkipMap<SocketAddr, (R::Instant, S::Connection)>>,
261 mut interval: impl agnostic_lite::time::AsyncInterval,
262 shutdown_rx: async_channel::Receiver<()>,
263 max_conn_idle: Duration,
264 ) {
265 loop {
266 let fut1 = shutdown_rx.recv();
267 let fut2 = async {
268 interval.next().await;
269
270 for ent in pool.iter() {
271 let (deadline, conn) = ent.value();
272 if max_conn_idle == Duration::ZERO {
273 if conn.is_closed().await {
274 let _ = conn.close().await;
275 ent.remove();
276 }
277 continue;
278 }
279
280 if deadline.elapsed() >= max_conn_idle || conn.is_closed().await {
281 let _ = conn.close().await;
282 ent.remove();
283 }
284 }
285 };
286
287 futures::pin_mut!(fut1, fut2);
288 match futures::future::select(fut1, fut2).await {
289 futures::future::Either::Left(_) => break,
290 futures::future::Either::Right(_) => {}
291 }
292 }
293 }
294}
295
296impl<I, A, S, R> QuicTransport<I, A, S, R>
297where
298 I: Id + Send + Sync + 'static,
299 A: AddressResolver<ResolvedAddress = SocketAddr, Runtime = R>,
300 S: StreamLayer<Runtime = R>,
301 R: RuntimeLite,
302{
303 fn next_connector(&self, addr: &A::ResolvedAddress) -> &S::Connector {
304 if addr.is_ipv4() {
305 if self.v4_connectors.is_empty() {
308 let idx = self.v6_round_robin.fetch_add(1, Ordering::AcqRel) % self.v6_connectors.len();
309 &self.v6_connectors[idx]
310 } else {
311 let idx = self.v4_round_robin.fetch_add(1, Ordering::AcqRel) % self.v4_connectors.len();
312 &self.v4_connectors[idx]
313 }
314 } else if self.v6_connectors.is_empty() {
315 let idx = self.v4_round_robin.fetch_add(1, Ordering::AcqRel) % self.v4_connectors.len();
316 &self.v4_connectors[idx]
317 } else {
318 let idx = self.v6_round_robin.fetch_add(1, Ordering::AcqRel) % self.v6_connectors.len();
319 &self.v6_connectors[idx]
320 }
321 }
322
323 async fn fetch_stream(
324 &self,
325 addr: SocketAddr,
326 timeout: Option<R::Instant>,
327 ) -> Result<S::Stream, QuicTransportError<A>> {
328 if let Some(ent) = self.connection_pool.get(&addr) {
329 let (_, connection) = ent.value();
330 if !connection.is_closed().await {
331 if let Some(timeout) = timeout {
332 return R::timeout_at(timeout, connection.open_bi())
333 .await
334 .map_err(|e| QuicTransportError::Io(e.into()))?
335 .map(|(stream, _)| stream)
336 .map_err(Into::into);
337 } else {
338 return connection
339 .open_bi()
340 .await
341 .map(|(s, _)| s)
342 .map_err(Into::into);
343 }
344 }
345 }
346
347 let connector = self.next_connector(&addr);
348 let connection = connector.connect(addr).await?;
349 connection
350 .open_bi()
351 .await
352 .map(|(s, _)| {
353 self
354 .connection_pool
355 .insert(addr, (Instant::now(), connection));
356 s
357 })
358 .map_err(Into::into)
359 }
360}
361
362impl<I, A, S, R> Transport for QuicTransport<I, A, S, R>
363where
364 I: Id + Data + Send + Sync + 'static,
365 A: AddressResolver<ResolvedAddress = SocketAddr, Runtime = R>,
366 A::Address: Send + Sync + 'static,
367 A::ResolvedAddress: Data,
368 S: StreamLayer<Runtime = R>,
369 R: RuntimeLite,
370{
  /// Error returned by the fallible operations of this transport.
  type Error = QuicTransportError<A>;

  /// Node identifier type.
  type Id = I;

  /// Unresolved address type, as defined by the resolver.
  type Address = A::Address;
  /// Resolved address type; always a socket address for this transport.
  type ResolvedAddress = SocketAddr;
  /// Resolver that turns an `Address` into a `ResolvedAddress`.
  type Resolver = A;

  /// Connection type exposed to callers: a stream of the stream layer.
  type Connection = S::Stream;

  /// Async runtime; taken from the resolver.
  type Runtime = A::Runtime;

  /// Bundle of options consumed by `new`.
  type Options = QuicTransportOptions<I, A, S>;
384
385 async fn new(transport_opts: Self::Options) -> Result<Self, Self::Error> {
386 let (resolver_options, stream_layer_options, opts) = transport_opts.into();
387 let resolver = <A as AddressResolver>::new(resolver_options)
388 .await
389 .map_err(Self::Error::Resolver)?;
390
391 let stream_layer = S::new(stream_layer_options).await?;
392 Self::new_in(resolver, stream_layer, opts).await
393 }
394
395 async fn resolve(
396 &self,
397 addr: &<Self::Resolver as AddressResolver>::Address,
398 ) -> Result<<Self::Resolver as AddressResolver>::ResolvedAddress, Self::Error> {
399 self
400 .resolver
401 .resolve(addr)
402 .await
403 .map_err(|e| Self::Error::Resolve {
404 addr: addr.cheap_clone(),
405 err: e,
406 })
407 }
408
  /// Returns the id configured in the transport options.
  #[inline]
  fn local_id(&self) -> &Self::Id {
    &self.opts.id
  }
413
  /// Returns the unresolved bind address that was selected as the local
  /// address when the transport was built.
  #[inline]
  fn local_address(&self) -> &<Self::Resolver as AddressResolver>::Address {
    &self.local_addr
  }
418
  /// Returns the resolved address determined as the advertise address when the
  /// transport was built.
  #[inline]
  fn advertise_address(&self) -> &<Self::Resolver as AddressResolver>::ResolvedAddress {
    &self.advertise_addr
  }
423
  /// Returns the maximum packet size fixed at construction: the smaller of
  /// `MAX_MESSAGE_SIZE` and the stream layer's `max_stream_data()`.
  #[inline]
  fn max_packet_size(&self) -> usize {
    self.max_packet_size
  }
428
  /// Returns the number of header bytes this transport needs: one, the tag
  /// byte that `write` and `send_to` store in the first header position.
  #[inline]
  fn header_overhead(&self) -> usize {
    1
  }
433
434 fn blocked_address(
435 &self,
436 addr: &<Self::Resolver as AddressResolver>::ResolvedAddress,
437 ) -> Result<(), Self::Error> {
438 let ip = addr.ip();
439 if self.opts.cidrs_policy.is_blocked(&ip) {
440 Err(Self::Error::BlockedIp(ip))
441 } else {
442 Ok(())
443 }
444 }
445
446 async fn read(
447 &self,
448 from: &Self::ResolvedAddress,
449 conn: &mut <Self::Connection as Connection>::Reader,
450 ) -> Result<usize, Self::Error> {
451 let mut buf = [0; 1];
452 conn.read_exact(&mut buf).await?;
453 match StreamType::try_from(buf[0]) {
454 Ok(StreamType::Stream) => Ok(1),
455 Ok(StreamType::Packet) => Err(QuicTransportError::Io(std::io::Error::new(
456 std::io::ErrorKind::InvalidData,
457 format!("receive an unexpected packet stream from {from}"),
458 ))),
459 Err(tag) => Err(QuicTransportError::Io(std::io::Error::new(
460 std::io::ErrorKind::InvalidData,
461 format!(
462 "receive a stream from {from} with invalid type value: {}",
463 tag
464 ),
465 ))),
466 }
467 }
468
469 async fn write(
470 &self,
471 conn: &mut <Self::Connection as Connection>::Writer,
472 mut src: Payload,
473 ) -> Result<usize, Self::Error> {
474 use memberlist_core::proto::ProtoWriter;
475
476 let header = src.header_mut();
477 if header.is_empty() {
478 return Err(QuicTransportError::custom(
479 "not enough space for header".into(),
480 ));
481 }
482 header[0] = StreamType::Stream as u8;
483 let ttl = self
484 .opts
485 .timeout
486 .map(|ttl| <Self::Runtime as RuntimeLite>::now() + ttl);
487
488 let src = src.as_slice();
489 tracing::trace!(
490 total_bytes = %src.len(),
491 sent = ?src,
492 "memberlist_quic.stream"
493 );
494
495 match ttl {
496 None => {
497 conn.write_all(src).await?;
498 conn.flush().await.map_err(Into::into).map(|_| src.len())
499 }
500 Some(ttl) => R::timeout_at(ttl, async {
501 conn.write_all(src).await?;
502 conn.flush().await.map(|_| src.len())
503 })
504 .await
505 .map_err(std::io::Error::from)?
506 .map_err(Into::into),
507 }
508 }
509
510 async fn send_to(
511 &self,
512 addr: &Self::ResolvedAddress,
513 mut src: Payload,
514 ) -> Result<(usize, <Self::Runtime as RuntimeLite>::Instant), Self::Error> {
515 let start = <Self::Runtime as RuntimeLite>::now();
516 let ttl = self.opts.timeout.map(|ttl| start + ttl);
517 let mut stream = self.fetch_stream(*addr, ttl).await?;
518 let header = src.header_mut();
519 if header.is_empty() {
520 return Err(QuicTransportError::custom(
521 "not enough space for header".into(),
522 ));
523 }
524 header[0] = StreamType::Packet as u8;
525
526 let src = src.as_slice();
527 tracing::trace!(
528 total_bytes = %src.len(),
529 sent = ?src,
530 "memberlist_quic.packet"
531 );
532
533 match ttl {
534 None => {
535 stream.write_all(src).await?;
536 stream.flush().await?;
537
538 Ok((src.len(), start))
539 }
540 Some(ttl) => R::timeout_at(ttl, async {
541 stream.write_all(src).await?;
542 stream.flush().await.map(|_| (src.len(), start))
543 })
544 .await
545 .map_err(std::io::Error::from)?
546 .map_err(Into::into),
547 }
548 }
549
  /// Opens a stream to `addr`, giving up at `deadline`.
  ///
  /// Delegates to `fetch_stream` with the deadline set.
  async fn open(
    &self,
    addr: &<Self::Resolver as AddressResolver>::ResolvedAddress,
    deadline: R::Instant,
  ) -> Result<Self::Connection, Self::Error> {
    self.fetch_stream(*addr, Some(deadline)).await
  }
557
  /// Returns a clone of the packet subscriber created with the transport.
  fn packet(
    &self,
  ) -> PacketSubscriber<<Self::Resolver as AddressResolver>::ResolvedAddress, R::Instant> {
    self.packet_rx.clone()
  }
563
  /// Returns a clone of the stream subscriber created with the transport.
  fn stream(
    &self,
  ) -> StreamSubscriber<<Self::Resolver as AddressResolver>::ResolvedAddress, Self::Connection> {
    self.stream_rx.clone()
  }
569
  /// Reports packets as reliable; always `true` for this transport.
  #[inline]
  fn packet_reliable(&self) -> bool {
    true
  }
574
  /// Reports packets as secure; always `true` for this transport.
  #[inline]
  fn packet_secure(&self) -> bool {
    true
  }
579
  /// Reports streams as secure; always `true` for this transport.
  #[inline]
  fn stream_secure(&self) -> bool {
    true
  }
584
585 async fn shutdown(&self) -> Result<(), Self::Error> {
586 if !self.shutdown_tx.close() {
587 return Ok(());
588 }
589
590 for conn in self.connection_pool.iter() {
591 let (_, conn) = conn.value();
592 let addr = conn.local_addr();
593 if let Err(e) = conn.close().await {
594 tracing::error!(err = %e, local_addr=%addr, "memberlist.transport.quic: failed to close connection");
595 }
596 }
597
598 for connector in self.v4_connectors.iter().chain(self.v6_connectors.iter()) {
599 let addr = connector.local_addr();
600 if let Err(e) = connector.close().await.map_err(Self::Error::from) {
601 tracing::error!(err = %e, local_addr=%addr, "memberlist.transport.quic: failed to close connector");
602 }
603 }
604
605 let mut handles = core::mem::take(&mut *self.handles.borrow_mut());
607 while let Some(res) = handles.next().await {
608 match res {
609 Ok(()) => {}
610 Err(e) => {
611 tracing::error!(err = %e, "memberlist.transport.quic: failed to wait listener task finish");
612 }
613 }
614 }
615 Ok(())
616 }
617}
618
/// Signals the background tasks when the transport is dropped.
impl<I, A, S, R> Drop for QuicTransport<I, A, S, R>
where
  I: Id + Send + Sync + 'static,
  A: AddressResolver<ResolvedAddress = SocketAddr, Runtime = R>,
  S: StreamLayer<Runtime = R>,
  R: RuntimeLite,
{
  /// Closes the shutdown channel; unlike `shutdown`, nothing is awaited here.
  fn drop(&mut self) {
    // The return value (whether this call did the closing) is not needed.
    self.shutdown_tx.close();
  }
}
629}