Skip to main content

rumqttc_core/
lib.rs

1#![doc = include_str!("../README.md")]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
4use std::future::Future;
5use std::io;
6use std::net::SocketAddr;
7use std::pin::Pin;
8use std::sync::Arc;
9
10#[cfg(feature = "use-rustls-no-provider")]
11use rustls_native_certs::load_native_certs;
12use tokio::net::{TcpSocket, TcpStream, lookup_host};
13#[cfg(feature = "use-native-tls")]
14use tokio_native_tls::native_tls::TlsConnector;
15#[cfg(feature = "use-rustls-no-provider")]
16use tokio_rustls::rustls::{ClientConfig, RootCertStore};
17
18#[cfg(feature = "proxy")]
19mod proxy;
20mod scheduler;
21#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
22mod tls;
23#[cfg(feature = "websocket")]
24mod websockets;
25
26#[cfg(feature = "proxy")]
27pub use proxy::{Proxy, ProxyAuth, ProxyError, ProxyType};
28pub use scheduler::{OutboundScheduler, RequestClass, RequestReadiness, ScheduledRequest};
29#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
30pub use tls::Error as TlsError;
31#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
32pub use tls::tls_connect;
33#[cfg(all(
34    feature = "websocket",
35    feature = "use-native-tls",
36    not(feature = "use-rustls-no-provider")
37))]
38pub use tls::websocket_tls_connector;
39#[cfg(all(
40    feature = "websocket",
41    feature = "use-rustls-no-provider",
42    not(feature = "use-native-tls")
43))]
44pub use tls::websocket_tls_connector;
45#[cfg(feature = "websocket")]
46pub use websockets::{UrlError, ValidationError, WsAdapter, split_url, validate_response_headers};
47
48#[cfg(not(feature = "websocket"))]
49pub trait AsyncReadWrite:
50    tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Sync + Unpin
51{
52}
53#[cfg(not(feature = "websocket"))]
54impl<T> AsyncReadWrite for T where
55    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Sync + Unpin
56{
57}
58
59#[cfg(feature = "websocket")]
60pub trait AsyncReadWrite: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin {}
61#[cfg(feature = "websocket")]
62impl<T> AsyncReadWrite for T where T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin {}
63
64pub type DynAsyncReadWrite = Box<dyn AsyncReadWrite>;
65
66/// Custom socket connector used to establish the underlying stream before optional proxy/TLS layers.
67pub type SocketConnector = Arc<
68    dyn Fn(
69            String,
70            NetworkOptions,
71        ) -> Pin<Box<dyn Future<Output = Result<DynAsyncReadWrite, io::Error>> + Send>>
72        + Send
73        + Sync,
74>;
75
76/// TLS configuration method
77#[derive(Clone, Debug)]
78#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
79pub enum TlsConfiguration {
80    #[cfg(feature = "use-rustls-no-provider")]
81    Simple {
82        /// ca certificate
83        ca: Vec<u8>,
84        /// alpn settings
85        alpn: Option<Vec<Vec<u8>>>,
86        /// tls `client_authentication`
87        client_auth: Option<(Vec<u8>, Vec<u8>)>,
88    },
89    #[cfg(feature = "use-native-tls")]
90    SimpleNative {
91        /// ca certificate
92        ca: Vec<u8>,
93        /// pkcs12 binary der and
94        /// password for use with der
95        client_auth: Option<(Vec<u8>, String)>,
96    },
97    #[cfg(feature = "use-rustls-no-provider")]
98    /// Injected rustls `ClientConfig` for TLS, to allow more customisation.
99    Rustls(Arc<ClientConfig>),
100    #[cfg(feature = "use-native-tls")]
101    /// Use default native-tls configuration
102    Native,
103    #[cfg(feature = "use-native-tls")]
104    /// Injected native-tls TlsConnector for TLS, to allow more customisation.
105    NativeConnector(TlsConnector),
106}
107
108#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
109impl TlsConfiguration {
110    #[cfg(feature = "use-rustls-no-provider")]
111    #[must_use]
112    /// Builds a rustls client configuration backed by the platform root store.
113    ///
114    /// # Panics
115    ///
116    /// Panics if loading native certificates fails or a certificate cannot be
117    /// inserted into the root store.
118    pub fn default_rustls() -> Self {
119        let mut root_cert_store = RootCertStore::empty();
120        for cert in load_native_certs().expect("could not load platform certs") {
121            root_cert_store.add(cert).unwrap();
122        }
123
124        let tls_config = ClientConfig::builder()
125            .with_root_certificates(root_cert_store)
126            .with_no_client_auth();
127
128        Self::Rustls(Arc::new(tls_config))
129    }
130
131    #[cfg(feature = "use-native-tls")]
132    #[must_use]
133    /// Builds a native-tls configuration from PEM CA bytes and optional
134    /// PKCS#12 client identity data.
135    pub fn simple_native(ca: Vec<u8>, client_auth: Option<(Vec<u8>, String)>) -> Self {
136        Self::SimpleNative { ca, client_auth }
137    }
138
139    #[cfg(feature = "use-native-tls")]
140    #[must_use]
141    pub fn default_native() -> Self {
142        Self::Native
143    }
144}
145
146#[cfg(all(feature = "use-rustls-no-provider", not(feature = "use-native-tls")))]
147impl Default for TlsConfiguration {
148    fn default() -> Self {
149        Self::default_rustls()
150    }
151}
152
153#[cfg(all(feature = "use-native-tls", not(feature = "use-rustls-no-provider")))]
154impl Default for TlsConfiguration {
155    fn default() -> Self {
156        Self::default_native()
157    }
158}
159
160#[cfg(feature = "use-rustls-no-provider")]
161impl From<ClientConfig> for TlsConfiguration {
162    fn from(config: ClientConfig) -> Self {
163        Self::Rustls(Arc::new(config))
164    }
165}
166
167#[cfg(feature = "use-native-tls")]
168impl From<TlsConnector> for TlsConfiguration {
169    fn from(connector: TlsConnector) -> Self {
170        TlsConfiguration::NativeConnector(connector)
171    }
172}
173
174/// Provides a way to configure low level network connection configurations
175#[derive(Clone, Debug, Default)]
176pub struct NetworkOptions {
177    tcp_send_buffer_size: Option<u32>,
178    tcp_recv_buffer_size: Option<u32>,
179    tcp_nodelay: bool,
180    conn_timeout: u64,
181    bind_addr: Option<SocketAddr>,
182    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
183    bind_device: Option<String>,
184}
185
186impl NetworkOptions {
187    #[must_use]
188    pub const fn new() -> Self {
189        Self {
190            tcp_send_buffer_size: None,
191            tcp_recv_buffer_size: None,
192            tcp_nodelay: false,
193            conn_timeout: 5,
194            bind_addr: None,
195            #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
196            bind_device: None,
197        }
198    }
199
200    pub const fn set_tcp_nodelay(&mut self, nodelay: bool) {
201        self.tcp_nodelay = nodelay;
202    }
203
204    pub const fn set_tcp_send_buffer_size(&mut self, size: u32) {
205        self.tcp_send_buffer_size = Some(size);
206    }
207
208    pub const fn set_tcp_recv_buffer_size(&mut self, size: u32) {
209        self.tcp_recv_buffer_size = Some(size);
210    }
211
212    /// set connection timeout in secs
213    pub const fn set_connection_timeout(&mut self, timeout: u64) -> &mut Self {
214        self.conn_timeout = timeout;
215        self
216    }
217
218    /// get timeout in secs
219    #[must_use]
220    pub const fn connection_timeout(&self) -> u64 {
221        self.conn_timeout
222    }
223
224    /// Bind a connection to a specific local socket address.
225    ///
226    /// When the address uses a fixed nonzero port, the default multi-address
227    /// dialer avoids overlapping attempts to prevent `AddrInUse`, which means
228    /// same-family fallback attempts are no longer staggered in parallel.
229    ///
230    /// In that mode, an earlier candidate keeps the fixed port until it
231    /// completes or the overall connect timeout expires. This preserves source
232    /// port stability, but gives up happy-eyeballs-style fallback across
233    /// same-family addresses.
234    pub const fn set_bind_addr(&mut self, bind_addr: SocketAddr) -> &mut Self {
235        self.bind_addr = Some(bind_addr);
236        self
237    }
238
239    #[must_use]
240    pub const fn bind_addr(&self) -> Option<SocketAddr> {
241        self.bind_addr
242    }
243
244    /// bind connection to a specific network device by name
245    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
246    #[cfg_attr(
247        docsrs,
248        doc(cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux")))
249    )]
250    pub fn set_bind_device(&mut self, bind_device: &str) -> &mut Self {
251        self.bind_device = Some(bind_device.to_string());
252        self
253    }
254}
255
256fn configure_tcp_socket(socket: &TcpSocket, network_options: &NetworkOptions) -> io::Result<()> {
257    socket.set_nodelay(network_options.tcp_nodelay)?;
258
259    if let Some(send_buff_size) = network_options.tcp_send_buffer_size {
260        socket.set_send_buffer_size(send_buff_size)?;
261    }
262    if let Some(recv_buffer_size) = network_options.tcp_recv_buffer_size {
263        socket.set_recv_buffer_size(recv_buffer_size)?;
264    }
265
266    if let Some(bind_addr) = network_options.bind_addr {
267        socket.bind(bind_addr)?;
268    }
269
270    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
271    {
272        if let Some(bind_device) = &network_options.bind_device {
273            socket.bind_device(Some(bind_device.as_bytes()))?;
274        }
275    }
276
277    Ok(())
278}
279
280/// Connects a single resolved socket address using the provided [`NetworkOptions`].
281///
282/// This is the per-address building block used by the default sequential dialer and by callers
283/// that want to apply a custom scheduling policy across multiple resolved addresses.
284///
285/// # Errors
286///
287/// Returns any socket construction, socket configuration, or connect error encountered.
288pub async fn connect_socket_addr(
289    addr: SocketAddr,
290    network_options: NetworkOptions,
291) -> io::Result<TcpStream> {
292    let socket = match addr {
293        SocketAddr::V4(_) => TcpSocket::new_v4()?,
294        SocketAddr::V6(_) => TcpSocket::new_v6()?,
295    };
296
297    configure_tcp_socket(&socket, &network_options)?;
298    socket.connect(addr).await
299}
300
301/// Default TCP socket connection logic used by the MQTT event loop.
302///
303/// This resolves the host, applies [`NetworkOptions`] on each candidate socket,
304/// and returns the first successful connection.
305///
306/// # Errors
307///
308/// Returns any DNS lookup, socket configuration, or connect error encountered.
309/// When multiple address candidates are available, the last connect error is
310/// returned if they all fail.
311pub async fn default_socket_connect(
312    host: String,
313    network_options: NetworkOptions,
314) -> io::Result<TcpStream> {
315    let addrs = lookup_host(host).await?;
316    let mut last_err = None;
317
318    for addr in addrs {
319        match connect_socket_addr(addr, network_options.clone()).await {
320            Ok(stream) => return Ok(stream),
321            Err(err) => {
322                last_err = Some(err);
323            }
324        }
325    }
326
327    Err(last_err.unwrap_or_else(|| {
328        io::Error::new(
329            io::ErrorKind::InvalidInput,
330            "could not resolve to any address",
331        )
332    }))
333}
334
335#[cfg(test)]
336mod tests {
337    #[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
338    use super::TlsConfiguration;
339    use super::{NetworkOptions, connect_socket_addr, default_socket_connect};
340    use std::io;
341    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
342    use tokio::net::TcpListener;
343
344    #[cfg(all(
345        feature = "use-rustls-no-provider",
346        any(feature = "use-rustls-aws-lc", feature = "use-rustls-ring")
347    ))]
348    #[test]
349    fn default_rustls_returns_rustls_variant() {
350        assert!(matches!(
351            TlsConfiguration::default_rustls(),
352            TlsConfiguration::Rustls(_)
353        ));
354    }
355
356    #[cfg(feature = "use-native-tls")]
357    #[test]
358    fn default_native_returns_native_variant() {
359        assert!(matches!(
360            TlsConfiguration::default_native(),
361            TlsConfiguration::Native
362        ));
363    }
364
365    #[cfg(feature = "use-native-tls")]
366    #[test]
367    fn simple_native_returns_simple_native_variant() {
368        let config = TlsConfiguration::simple_native(
369            Vec::from("Test CA"),
370            Some((vec![1, 2, 3], String::from("secret"))),
371        );
372
373        assert!(matches!(
374            config,
375            TlsConfiguration::SimpleNative {
376                ca,
377                client_auth: Some((identity, password))
378            } if ca == b"Test CA" && identity == vec![1, 2, 3] && password == "secret"
379        ));
380    }
381
382    #[tokio::test]
383    async fn connect_socket_addr_succeeds_with_ipv4_bind_addr() {
384        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
385        let listener_addr = listener.local_addr().unwrap();
386
387        let accept = tokio::spawn(async move {
388            let (stream, peer_addr) = listener.accept().await.unwrap();
389            drop(stream);
390            peer_addr
391        });
392
393        let mut network_options = NetworkOptions::new();
394        network_options.set_bind_addr(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)));
395
396        let stream = connect_socket_addr(listener_addr, network_options)
397            .await
398            .unwrap();
399        let local_addr = stream.local_addr().unwrap();
400        assert!(local_addr.ip().is_loopback());
401        drop(stream);
402
403        let peer_addr = accept.await.unwrap();
404        assert_eq!(peer_addr.ip(), local_addr.ip());
405    }
406
407    #[tokio::test]
408    async fn connect_socket_addr_returns_error_for_mismatched_bind_addr_family() {
409        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
410        let listener_addr = listener.local_addr().unwrap();
411
412        let mut network_options = NetworkOptions::new();
413        network_options.set_bind_addr(SocketAddr::V6(SocketAddrV6::new(
414            Ipv6Addr::LOCALHOST,
415            0,
416            0,
417            0,
418        )));
419
420        let err = connect_socket_addr(listener_addr, network_options)
421            .await
422            .unwrap_err();
423        assert_ne!(err.kind(), io::ErrorKind::WouldBlock);
424    }
425
426    #[tokio::test]
427    async fn connect_socket_addr_succeeds_with_ipv6_bind_addr() {
428        let listener = match TcpListener::bind((Ipv6Addr::LOCALHOST, 0)).await {
429            Ok(listener) => listener,
430            Err(_) => return,
431        };
432        let listener_addr = listener.local_addr().unwrap();
433
434        let accept = tokio::spawn(async move {
435            let (stream, peer_addr) = listener.accept().await.unwrap();
436            drop(stream);
437            peer_addr
438        });
439
440        let mut network_options = NetworkOptions::new();
441        network_options.set_bind_addr(SocketAddr::V6(SocketAddrV6::new(
442            Ipv6Addr::LOCALHOST,
443            0,
444            0,
445            0,
446        )));
447
448        let stream = connect_socket_addr(listener_addr, network_options)
449            .await
450            .unwrap();
451        let local_addr = stream.local_addr().unwrap();
452        assert_eq!(local_addr.ip(), IpAddr::V6(Ipv6Addr::LOCALHOST));
453        drop(stream);
454
455        let peer_addr = accept.await.unwrap();
456        assert_eq!(peer_addr.ip(), local_addr.ip());
457    }
458
459    #[tokio::test]
460    async fn default_socket_connect_still_connects_without_bind_addr() {
461        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
462        let addr = listener.local_addr().unwrap();
463
464        let accept = tokio::spawn(async move {
465            let (stream, _) = listener.accept().await.unwrap();
466            drop(stream);
467        });
468
469        let stream = default_socket_connect(addr.to_string(), NetworkOptions::new())
470            .await
471            .unwrap();
472        assert!(stream.local_addr().unwrap().ip().is_loopback());
473        drop(stream);
474        accept.await.unwrap();
475    }
476
477    #[test]
478    fn bind_addr_returns_configured_socket_addr() {
479        let bind_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 1883));
480        let mut network_options = NetworkOptions::new();
481        network_options.set_bind_addr(bind_addr);
482
483        assert_eq!(network_options.bind_addr(), Some(bind_addr));
484    }
485}