1#![doc(html_logo_url = "https://raw.githubusercontent.com/al8n/memberlist/main/art/logo_72x72.png")]
3#![allow(clippy::type_complexity)]
4#![deny(missing_docs, warnings)]
5#![forbid(unsafe_code)]
6#![cfg_attr(docsrs, feature(doc_cfg))]
7#![cfg_attr(docsrs, allow(unused_attributes))]
8
9use std::{
10 net::SocketAddr,
11 sync::{
12 Arc,
13 atomic::{AtomicBool, AtomicUsize, Ordering},
14 },
15};
16
17use agnostic::{
18 AsyncSpawner, Runtime, RuntimeLite,
19 net::{Net, UdpSocket},
20};
21use atomic_refcell::AtomicRefCell;
22use futures::{StreamExt, stream::FuturesUnordered};
23use memberlist_core::proto::{Data, Payload, SmallVec};
24pub use memberlist_core::{
25 proto::{CIDRsPolicy, Label, LabelError},
26 transport::*,
27};
28
29mod options;
30pub use options::*;
31
32mod promised_processor;
33use promised_processor::*;
34mod packet_processor;
35use packet_processor::*;
36
37pub mod error;
39use error::*;
40
41pub mod stream_layer;
43use stream_layer::*;
44
45pub mod resolver {
47 #[cfg(feature = "dns")]
48 pub use nodecraft::resolver::dns;
49 pub use nodecraft::resolver::{address, socket_addr};
50}
51
52#[cfg(any(test, feature = "test"))]
54#[cfg_attr(docsrs, doc(cfg(feature = "test")))]
55pub mod tests;
56
/// Default size, in bytes, of the UDP receive buffer: 2 MiB (`2 * 1024 * 1024`).
const DEFAULT_UDP_RECV_BUF_SIZE: usize = 2 << 20;
60
/// [`NetTransport`] with its runtime parameter fixed to
/// `agnostic::tokio::TokioRuntime`.
///
/// Only available when the `tokio` feature is enabled.
#[cfg(feature = "tokio")]
pub type TokioNetTransport<I, A, S> = NetTransport<I, A, S, agnostic::tokio::TokioRuntime>;
64
/// [`NetTransport`] with its runtime parameter fixed to
/// `agnostic::smol::SmolRuntime`.
///
/// Only available when the `smol` feature is enabled.
#[cfg(feature = "smol")]
pub type SmolNetTransport<I, A, S> = NetTransport<I, A, S, agnostic::smol::SmolRuntime>;
68
69pub struct NetTransport<I, A, S, R>
71where
72 I: Id,
73 A: AddressResolver<ResolvedAddress = SocketAddr, Runtime = R>,
74 S: StreamLayer<Runtime = R>,
75 R: Runtime,
76{
77 opts: Arc<Options<I, A>>,
78 advertise_addr: A::ResolvedAddress,
79 local_addr: A::Address,
80 packet_rx: PacketSubscriber<A::ResolvedAddress, R::Instant>,
81 stream_rx: StreamSubscriber<A::ResolvedAddress, S::Stream>,
82 num_v4_sockets: usize,
83 v4_round_robin: AtomicUsize,
84 v4_sockets: AtomicRefCell<SmallVec<Arc<<<A::Runtime as Runtime>::Net as Net>::UdpSocket>>>,
85 num_v6_sockets: usize,
86 v6_round_robin: AtomicUsize,
87 v6_sockets: AtomicRefCell<SmallVec<Arc<<<A::Runtime as Runtime>::Net as Net>::UdpSocket>>>,
88 stream_layer: Arc<S>,
89 handles: AtomicRefCell<FuturesUnordered<<R::Spawner as AsyncSpawner>::JoinHandle<()>>>,
90 resolver: Arc<A>,
91 shutdown_tx: async_channel::Sender<()>,
92}
93
94impl<I, A, S, R> NetTransport<I, A, S, R>
95where
96 I: Id + Data + Send + Sync + 'static,
97 A: AddressResolver<ResolvedAddress = SocketAddr, Runtime = R>,
98 A::Address: Send + Sync + 'static,
99 A::ResolvedAddress: Data,
100 S: StreamLayer<Runtime = R>,
101 R: Runtime,
102{
103 fn find_advertise_addr_index(addrs: &[SocketAddr]) -> usize {
104 for (i, addr) in addrs.iter().enumerate() {
105 if !addr.ip().is_unspecified() {
106 return i;
107 }
108 }
109
110 0
111 }
112
113 fn next_socket(
114 &self,
115 addr: &A::ResolvedAddress,
116 ) -> Option<Arc<<<A::Runtime as Runtime>::Net as Net>::UdpSocket>> {
117 enum Kind {
118 V4(usize),
119 V6(usize),
120 }
121
122 let kind = if addr.is_ipv4() {
123 if self.num_v4_sockets == 0 {
126 let idx = self.v6_round_robin.fetch_add(1, Ordering::AcqRel) % self.num_v6_sockets;
127 Kind::V6(idx)
128 } else {
129 let idx = self.v4_round_robin.fetch_add(1, Ordering::AcqRel) % self.num_v4_sockets;
130 Kind::V4(idx)
131 }
132 } else if self.num_v6_sockets == 0 {
133 let idx = self.v4_round_robin.fetch_add(1, Ordering::AcqRel) % self.num_v4_sockets;
134 Kind::V4(idx)
135 } else {
136 let idx = self.v6_round_robin.fetch_add(1, Ordering::AcqRel) % self.num_v6_sockets;
137 Kind::V6(idx)
138 };
139
140 match kind {
143 Kind::V4(idx) => {
144 if let Ok(sockets) = self.v4_sockets.try_borrow() {
145 Some(sockets[idx].clone())
146 } else {
147 None
148 }
149 }
150 Kind::V6(idx) => {
151 if let Ok(sockets) = self.v6_sockets.try_borrow() {
152 Some(sockets[idx].clone())
153 } else {
154 None
155 }
156 }
157 }
158 }
159}
160
161impl<I, A, S, R> Transport for NetTransport<I, A, S, R>
162where
163 I: Id + Data + Send + Sync + 'static,
164 A: AddressResolver<ResolvedAddress = SocketAddr, Runtime = R>,
165 A::Address: Data + Send + Sync + 'static,
166 S: StreamLayer<Runtime = R>,
167 R: Runtime,
168{
169 type Error = NetTransportError<Self::Resolver>;
170
171 type Id = I;
172 type Address = A::Address;
173 type ResolvedAddress = SocketAddr;
174 type Resolver = A;
175
176 type Connection = S::Stream;
177
178 type Runtime = <Self::Resolver as AddressResolver>::Runtime;
179
180 type Options = NetTransportOptions<Self::Id, Self::Resolver, S>;
181
182 async fn new(transport_opts: Self::Options) -> Result<Self, Self::Error> {
183 let (resolver_opts, stream_layer_opts, opts) = transport_opts.into();
184 let resolver = Arc::new(
185 <A as AddressResolver>::new(resolver_opts)
186 .await
187 .map_err(NetTransportError::Resolver)?,
188 );
189
190 let stream_layer = Arc::new(<S as StreamLayer>::new(stream_layer_opts).await?);
191 let opts = Arc::new(opts);
192
193 if opts.bind_addresses.is_empty() {
196 return Err(NetTransportError::EmptyBindAddresses);
197 }
198
199 let (stream_tx, stream_rx) = promised_stream::<Self>();
200 let (packet_tx, packet_rx) = packet_stream::<Self>();
201 let (shutdown_tx, shutdown_rx) = async_channel::bounded(1);
202
203 let mut v4_promised_listeners = Vec::with_capacity(opts.bind_addresses.len());
204 let mut v4_sockets = Vec::with_capacity(opts.bind_addresses.len());
205 let mut v6_promised_listeners = Vec::with_capacity(opts.bind_addresses.len());
206 let mut v6_sockets = Vec::with_capacity(opts.bind_addresses.len());
207 let mut resolved_bind_address = SmallVec::new();
208
209 for addr in opts.bind_addresses.iter() {
210 let addr = resolver
211 .resolve(addr)
212 .await
213 .map_err(|e| NetTransportError::Resolve {
214 addr: addr.clone(),
215 err: e,
216 })?;
217 let bind_port = addr.port();
218
219 let (local_addr, ln) = if bind_port == 0 {
220 let mut retries = 0;
221 loop {
222 match stream_layer.bind(addr).await {
223 Ok(ln) => break (ln.local_addr(), ln),
224 Err(e) => {
225 if retries < 9 {
226 retries += 1;
227 continue;
228 }
229 return Err(NetTransportError::ListenPromised(addr, e));
230 }
231 }
232 }
233 } else {
234 match stream_layer.bind(addr).await {
235 Ok(ln) => (ln.local_addr(), ln),
236 Err(e) => return Err(NetTransportError::ListenPromised(addr, e)),
237 }
238 };
239
240 if local_addr.is_ipv4() {
241 v4_promised_listeners.push((Arc::new(ln), local_addr));
242 } else {
243 v6_promised_listeners.push((Arc::new(ln), local_addr));
244 }
245 let addr = if bind_port == 0 { local_addr } else { addr };
249 resolved_bind_address.push(addr);
250
251 let (local_addr, packet_socket) =
252 <<<A::Runtime as Runtime>::Net as Net>::UdpSocket as UdpSocket>::bind(addr)
253 .await
254 .map(|ln| (addr, ln))
255 .map_err(|e| NetTransportError::ListenPacket(addr, e))?;
256
257 set_udp_recv_buffer(&packet_socket, opts.recv_buffer_size)?;
258
259 if local_addr.is_ipv4() {
260 v4_sockets.push((Arc::new(packet_socket), local_addr));
261 } else {
262 v6_sockets.push((Arc::new(packet_socket), local_addr))
263 }
264 }
265
266 let expose_addr_index = Self::find_advertise_addr_index(&resolved_bind_address);
267 let advertise_addr = resolved_bind_address[expose_addr_index];
268 let self_addr = opts.bind_addresses[expose_addr_index].cheap_clone();
269 let shutdown = Arc::new(AtomicBool::new(false));
270 let handles = FuturesUnordered::new();
271 for ((promised_ln, promised_addr), (socket, socket_addr)) in v4_promised_listeners
275 .iter()
276 .zip(v4_sockets.iter())
277 .chain(v6_promised_listeners.iter().zip(v6_sockets.iter()))
278 {
279 let processor = PromisedProcessor::<A, Self, S> {
280 stream_tx: stream_tx.clone(),
281 ln: promised_ln.clone(),
282 shutdown_rx: shutdown_rx.clone(),
283 local_addr: *promised_addr,
284 };
285 handles.push(R::spawn(processor.run()));
286
287 let processor = PacketProcessor::<A, Self> {
288 packet_tx: packet_tx.clone(),
289 socket: socket.clone(),
290 local_addr: *socket_addr,
291 shutdown: shutdown.clone(),
292 #[cfg(feature = "metrics")]
293 metric_labels: opts.metric_labels.clone().unwrap_or_default(),
294 shutdown_rx: shutdown_rx.clone(),
295 };
296
297 handles.push(R::spawn(processor.run()));
298 }
299
300 let final_advertise_addr = if let Some(addr) = opts.advertise_address {
302 addr
303 } else if advertise_addr.ip().is_unspecified() {
304 let ip = getifs::private_addrs()
305 .map_err(|_| NetTransportError::NoPrivateIP)
306 .and_then(|ips| {
307 if let Some(ip) = ips.into_iter().next().map(|ip| ip.addr()) {
308 Ok(ip)
309 } else {
310 Err(NetTransportError::NoPrivateIP)
311 }
312 })?;
313 SocketAddr::new(ip, advertise_addr.port())
314 } else {
315 advertise_addr
316 };
317
318 Ok(Self {
338 advertise_addr: final_advertise_addr,
339 local_addr: self_addr,
340 opts,
341 packet_rx,
342 stream_rx,
343 handles: AtomicRefCell::new(handles),
344 num_v4_sockets: v4_sockets.len(),
345 v4_sockets: AtomicRefCell::new(v4_sockets.into_iter().map(|(ln, _)| ln).collect()),
346 v4_round_robin: AtomicUsize::new(0),
347 num_v6_sockets: v6_sockets.len(),
348 v6_sockets: AtomicRefCell::new(v6_sockets.into_iter().map(|(ln, _)| ln).collect()),
349 v6_round_robin: AtomicUsize::new(0),
350 stream_layer,
351 resolver,
352 shutdown_tx,
353 })
354 }
355
356 async fn resolve(
357 &self,
358 addr: &<Self::Resolver as AddressResolver>::Address,
359 ) -> Result<<Self::Resolver as AddressResolver>::ResolvedAddress, Self::Error> {
360 self
361 .resolver
362 .resolve(addr)
363 .await
364 .map_err(|e| Self::Error::Resolve {
365 addr: addr.cheap_clone(),
366 err: e,
367 })
368 }
369
370 #[inline]
371 fn local_id(&self) -> &Self::Id {
372 &self.opts.id
373 }
374
375 #[inline]
376 fn local_address(&self) -> &<Self::Resolver as AddressResolver>::Address {
377 &self.local_addr
378 }
379
380 #[inline]
381 fn advertise_address(&self) -> &<Self::Resolver as AddressResolver>::ResolvedAddress {
382 &self.advertise_addr
383 }
384
385 #[inline]
386 fn max_packet_size(&self) -> usize {
387 self.opts.max_packet_size
388 }
389
390 #[inline]
391 fn header_overhead(&self) -> usize {
392 0
393 }
394
395 fn blocked_address(
396 &self,
397 addr: &<Self::Resolver as AddressResolver>::ResolvedAddress,
398 ) -> Result<(), Self::Error> {
399 let ip = addr.ip();
400 if self.opts.cidrs_policy.is_blocked(&ip) {
401 Err(Self::Error::BlockedIp(ip))
402 } else {
403 Ok(())
404 }
405 }
406
407 async fn send_to(
408 &self,
409 addr: &<Self::Resolver as AddressResolver>::ResolvedAddress,
410 packets: Payload,
411 ) -> Result<(usize, <Self::Runtime as RuntimeLite>::Instant), Self::Error> {
412 let start = <Self::Runtime as RuntimeLite>::now();
413
414 let src = packets.as_slice();
415 match self.next_socket(addr) {
416 Some(skt) => skt
417 .send_to(src, addr)
418 .await
419 .map(|num| {
420 tracing::trace!(remote=%addr, total_bytes = %num, sent=?src, "memberlist_net.packet");
421 (num, start)
422 })
423 .map_err(Into::into),
424 None => {
425 tracing::error!("memberlist_net.packet: transport is being shutdown");
426 Err(
427 std::io::Error::new(
428 std::io::ErrorKind::ConnectionAborted,
429 "transport is being shutdown",
430 )
431 .into(),
432 )
433 }
434 }
435 }
436
437 async fn open(
438 &self,
439 addr: &<Self::Resolver as AddressResolver>::ResolvedAddress,
440 deadline: <Self::Runtime as RuntimeLite>::Instant,
441 ) -> Result<Self::Connection, Self::Error> {
442 let connector =
443 <Self::Runtime as RuntimeLite>::timeout_at(deadline, self.stream_layer.connect(*addr));
444 match connector.await {
445 Ok(Ok(conn)) => Ok(conn),
446 Ok(Err(e)) => Err(e.into()),
447 Err(e) => Err(Self::Error::Io(e.into())),
448 }
449 }
450
451 fn packet(
452 &self,
453 ) -> PacketSubscriber<
454 <Self::Resolver as AddressResolver>::ResolvedAddress,
455 <Self::Runtime as RuntimeLite>::Instant,
456 > {
457 self.packet_rx.clone()
458 }
459
460 fn stream(
461 &self,
462 ) -> StreamSubscriber<<Self::Resolver as AddressResolver>::ResolvedAddress, Self::Connection> {
463 self.stream_rx.clone()
464 }
465
466 fn packet_reliable(&self) -> bool {
467 false
468 }
469
470 fn packet_secure(&self) -> bool {
471 false
472 }
473
474 fn stream_secure(&self) -> bool {
475 S::is_secure()
476 }
477
478 async fn shutdown(&self) -> Result<(), Self::Error> {
479 if !self.shutdown_tx.close() {
480 return Ok(());
481 }
482
483 loop {
485 if let Ok(mut s) = self.v4_sockets.try_borrow_mut() {
486 s.clear();
487 break;
488 }
489 }
490
491 loop {
492 if let Ok(mut s) = self.v6_sockets.try_borrow_mut() {
493 s.clear();
494 break;
495 }
496 }
497
498 let mut handles = core::mem::take(&mut *self.handles.borrow_mut());
499 while handles.next().await.is_some() {}
500 Ok(())
501 }
502}
503
504impl<I, A, S, R> Drop for NetTransport<I, A, S, R>
505where
506 I: Id,
507 A: AddressResolver<ResolvedAddress = SocketAddr, Runtime = R>,
508 S: StreamLayer<Runtime = R>,
509 R: Runtime,
510{
511 fn drop(&mut self) {
512 self.shutdown_tx.close();
513 }
514}
515
516fn set_udp_recv_buffer<U>(udp: &U, mut size: usize) -> std::io::Result<()>
520where
521 U: agnostic::net::UdpSocket,
522{
523 let mut err = None;
524 while size > 0 {
525 match udp.set_recv_buffer_size(size) {
526 Ok(_) => return Ok(()),
527 Err(e) => err = Some(e),
528 }
529 size /= 2;
530 }
531
532 Err(
533 err.unwrap_or_else(|| std::io::Error::other("fail to set receive buffer size for UDP socket")),
534 )
535}