Skip to main content

rumqttc_core/
lib.rs

1#![doc = include_str!("../README.md")]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
4use std::future::Future;
5use std::io;
6use std::net::SocketAddr;
7use std::pin::Pin;
8use std::sync::Arc;
9
10#[cfg(feature = "use-rustls-no-provider")]
11use rustls_native_certs::load_native_certs;
12use tokio::net::{TcpSocket, TcpStream, lookup_host};
13#[cfg(feature = "use-native-tls")]
14use tokio_native_tls::native_tls::TlsConnector;
15#[cfg(feature = "use-rustls-no-provider")]
16use tokio_rustls::rustls::{ClientConfig, RootCertStore};
17
18#[cfg(feature = "proxy")]
19mod proxy;
20#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
21mod tls;
22#[cfg(feature = "websocket")]
23mod websockets;
24
25#[cfg(feature = "proxy")]
26pub use proxy::{Proxy, ProxyAuth, ProxyError, ProxyType};
27#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
28pub use tls::Error as TlsError;
29#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
30pub use tls::tls_connect;
31#[cfg(all(
32    feature = "websocket",
33    feature = "use-native-tls",
34    not(feature = "use-rustls-no-provider")
35))]
36pub use tls::websocket_tls_connector;
37#[cfg(all(
38    feature = "websocket",
39    feature = "use-rustls-no-provider",
40    not(feature = "use-native-tls")
41))]
42pub use tls::websocket_tls_connector;
43#[cfg(feature = "websocket")]
44pub use websockets::{UrlError, ValidationError, WsAdapter, split_url, validate_response_headers};
45
46#[cfg(not(feature = "websocket"))]
47pub trait AsyncReadWrite:
48    tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Sync + Unpin
49{
50}
51#[cfg(not(feature = "websocket"))]
52impl<T> AsyncReadWrite for T where
53    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Sync + Unpin
54{
55}
56
57#[cfg(feature = "websocket")]
58pub trait AsyncReadWrite: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin {}
59#[cfg(feature = "websocket")]
60impl<T> AsyncReadWrite for T where T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin {}
61
62pub type DynAsyncReadWrite = Box<dyn AsyncReadWrite>;
63
64/// Custom socket connector used to establish the underlying stream before optional proxy/TLS layers.
65pub type SocketConnector = Arc<
66    dyn Fn(
67            String,
68            NetworkOptions,
69        ) -> Pin<Box<dyn Future<Output = Result<DynAsyncReadWrite, io::Error>> + Send>>
70        + Send
71        + Sync,
72>;
73
74/// TLS configuration method
75#[derive(Clone, Debug)]
76#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
77pub enum TlsConfiguration {
78    #[cfg(feature = "use-rustls-no-provider")]
79    Simple {
80        /// ca certificate
81        ca: Vec<u8>,
82        /// alpn settings
83        alpn: Option<Vec<Vec<u8>>>,
84        /// tls `client_authentication`
85        client_auth: Option<(Vec<u8>, Vec<u8>)>,
86    },
87    #[cfg(feature = "use-native-tls")]
88    SimpleNative {
89        /// ca certificate
90        ca: Vec<u8>,
91        /// pkcs12 binary der and
92        /// password for use with der
93        client_auth: Option<(Vec<u8>, String)>,
94    },
95    #[cfg(feature = "use-rustls-no-provider")]
96    /// Injected rustls `ClientConfig` for TLS, to allow more customisation.
97    Rustls(Arc<ClientConfig>),
98    #[cfg(feature = "use-native-tls")]
99    /// Use default native-tls configuration
100    Native,
101    #[cfg(feature = "use-native-tls")]
102    /// Injected native-tls TlsConnector for TLS, to allow more customisation.
103    NativeConnector(TlsConnector),
104}
105
106#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
107impl TlsConfiguration {
108    #[cfg(feature = "use-rustls-no-provider")]
109    #[must_use]
110    /// Builds a rustls client configuration backed by the platform root store.
111    ///
112    /// # Panics
113    ///
114    /// Panics if loading native certificates fails or a certificate cannot be
115    /// inserted into the root store.
116    pub fn default_rustls() -> Self {
117        let mut root_cert_store = RootCertStore::empty();
118        for cert in load_native_certs().expect("could not load platform certs") {
119            root_cert_store.add(cert).unwrap();
120        }
121
122        let tls_config = ClientConfig::builder()
123            .with_root_certificates(root_cert_store)
124            .with_no_client_auth();
125
126        Self::Rustls(Arc::new(tls_config))
127    }
128
129    #[cfg(feature = "use-native-tls")]
130    #[must_use]
131    /// Builds a native-tls configuration from PEM CA bytes and optional
132    /// PKCS#12 client identity data.
133    pub fn simple_native(ca: Vec<u8>, client_auth: Option<(Vec<u8>, String)>) -> Self {
134        Self::SimpleNative { ca, client_auth }
135    }
136
137    #[cfg(feature = "use-native-tls")]
138    #[must_use]
139    pub fn default_native() -> Self {
140        Self::Native
141    }
142}
143
144#[cfg(all(feature = "use-rustls-no-provider", not(feature = "use-native-tls")))]
145impl Default for TlsConfiguration {
146    fn default() -> Self {
147        Self::default_rustls()
148    }
149}
150
151#[cfg(all(feature = "use-native-tls", not(feature = "use-rustls-no-provider")))]
152impl Default for TlsConfiguration {
153    fn default() -> Self {
154        Self::default_native()
155    }
156}
157
158#[cfg(feature = "use-rustls-no-provider")]
159impl From<ClientConfig> for TlsConfiguration {
160    fn from(config: ClientConfig) -> Self {
161        Self::Rustls(Arc::new(config))
162    }
163}
164
165#[cfg(feature = "use-native-tls")]
166impl From<TlsConnector> for TlsConfiguration {
167    fn from(connector: TlsConnector) -> Self {
168        TlsConfiguration::NativeConnector(connector)
169    }
170}
171
172/// Provides a way to configure low level network connection configurations
173#[derive(Clone, Debug, Default)]
174pub struct NetworkOptions {
175    tcp_send_buffer_size: Option<u32>,
176    tcp_recv_buffer_size: Option<u32>,
177    tcp_nodelay: bool,
178    conn_timeout: u64,
179    bind_addr: Option<SocketAddr>,
180    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
181    bind_device: Option<String>,
182}
183
184impl NetworkOptions {
185    #[must_use]
186    pub const fn new() -> Self {
187        Self {
188            tcp_send_buffer_size: None,
189            tcp_recv_buffer_size: None,
190            tcp_nodelay: false,
191            conn_timeout: 5,
192            bind_addr: None,
193            #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
194            bind_device: None,
195        }
196    }
197
198    pub const fn set_tcp_nodelay(&mut self, nodelay: bool) {
199        self.tcp_nodelay = nodelay;
200    }
201
202    pub const fn set_tcp_send_buffer_size(&mut self, size: u32) {
203        self.tcp_send_buffer_size = Some(size);
204    }
205
206    pub const fn set_tcp_recv_buffer_size(&mut self, size: u32) {
207        self.tcp_recv_buffer_size = Some(size);
208    }
209
210    /// set connection timeout in secs
211    pub const fn set_connection_timeout(&mut self, timeout: u64) -> &mut Self {
212        self.conn_timeout = timeout;
213        self
214    }
215
216    /// get timeout in secs
217    #[must_use]
218    pub const fn connection_timeout(&self) -> u64 {
219        self.conn_timeout
220    }
221
222    /// Bind a connection to a specific local socket address.
223    ///
224    /// When the address uses a fixed nonzero port, the default multi-address
225    /// dialer avoids overlapping attempts to prevent `AddrInUse`, which means
226    /// same-family fallback attempts are no longer staggered in parallel.
227    ///
228    /// In that mode, an earlier candidate keeps the fixed port until it
229    /// completes or the overall connect timeout expires. This preserves source
230    /// port stability, but gives up happy-eyeballs-style fallback across
231    /// same-family addresses.
232    pub const fn set_bind_addr(&mut self, bind_addr: SocketAddr) -> &mut Self {
233        self.bind_addr = Some(bind_addr);
234        self
235    }
236
237    #[must_use]
238    pub const fn bind_addr(&self) -> Option<SocketAddr> {
239        self.bind_addr
240    }
241
242    /// bind connection to a specific network device by name
243    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
244    #[cfg_attr(
245        docsrs,
246        doc(cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux")))
247    )]
248    pub fn set_bind_device(&mut self, bind_device: &str) -> &mut Self {
249        self.bind_device = Some(bind_device.to_string());
250        self
251    }
252}
253
254fn configure_tcp_socket(socket: &TcpSocket, network_options: &NetworkOptions) -> io::Result<()> {
255    socket.set_nodelay(network_options.tcp_nodelay)?;
256
257    if let Some(send_buff_size) = network_options.tcp_send_buffer_size {
258        socket.set_send_buffer_size(send_buff_size)?;
259    }
260    if let Some(recv_buffer_size) = network_options.tcp_recv_buffer_size {
261        socket.set_recv_buffer_size(recv_buffer_size)?;
262    }
263
264    if let Some(bind_addr) = network_options.bind_addr {
265        socket.bind(bind_addr)?;
266    }
267
268    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
269    {
270        if let Some(bind_device) = &network_options.bind_device {
271            socket.bind_device(Some(bind_device.as_bytes()))?;
272        }
273    }
274
275    Ok(())
276}
277
278/// Connects a single resolved socket address using the provided [`NetworkOptions`].
279///
280/// This is the per-address building block used by the default sequential dialer and by callers
281/// that want to apply a custom scheduling policy across multiple resolved addresses.
282///
283/// # Errors
284///
285/// Returns any socket construction, socket configuration, or connect error encountered.
286pub async fn connect_socket_addr(
287    addr: SocketAddr,
288    network_options: NetworkOptions,
289) -> io::Result<TcpStream> {
290    let socket = match addr {
291        SocketAddr::V4(_) => TcpSocket::new_v4()?,
292        SocketAddr::V6(_) => TcpSocket::new_v6()?,
293    };
294
295    configure_tcp_socket(&socket, &network_options)?;
296    socket.connect(addr).await
297}
298
299/// Default TCP socket connection logic used by the MQTT event loop.
300///
301/// This resolves the host, applies [`NetworkOptions`] on each candidate socket,
302/// and returns the first successful connection.
303///
304/// # Errors
305///
306/// Returns any DNS lookup, socket configuration, or connect error encountered.
307/// When multiple address candidates are available, the last connect error is
308/// returned if they all fail.
309pub async fn default_socket_connect(
310    host: String,
311    network_options: NetworkOptions,
312) -> io::Result<TcpStream> {
313    let addrs = lookup_host(host).await?;
314    let mut last_err = None;
315
316    for addr in addrs {
317        match connect_socket_addr(addr, network_options.clone()).await {
318            Ok(stream) => return Ok(stream),
319            Err(err) => {
320                last_err = Some(err);
321            }
322        }
323    }
324
325    Err(last_err.unwrap_or_else(|| {
326        io::Error::new(
327            io::ErrorKind::InvalidInput,
328            "could not resolve to any address",
329        )
330    }))
331}
332
333#[cfg(test)]
334mod tests {
335    #[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
336    use super::TlsConfiguration;
337    use super::{NetworkOptions, connect_socket_addr, default_socket_connect};
338    use std::io;
339    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
340    use tokio::net::TcpListener;
341
342    #[cfg(all(
343        feature = "use-rustls-no-provider",
344        any(feature = "use-rustls-aws-lc", feature = "use-rustls-ring")
345    ))]
346    #[test]
347    fn default_rustls_returns_rustls_variant() {
348        assert!(matches!(
349            TlsConfiguration::default_rustls(),
350            TlsConfiguration::Rustls(_)
351        ));
352    }
353
354    #[cfg(feature = "use-native-tls")]
355    #[test]
356    fn default_native_returns_native_variant() {
357        assert!(matches!(
358            TlsConfiguration::default_native(),
359            TlsConfiguration::Native
360        ));
361    }
362
363    #[cfg(feature = "use-native-tls")]
364    #[test]
365    fn simple_native_returns_simple_native_variant() {
366        let config = TlsConfiguration::simple_native(
367            Vec::from("Test CA"),
368            Some((vec![1, 2, 3], String::from("secret"))),
369        );
370
371        assert!(matches!(
372            config,
373            TlsConfiguration::SimpleNative {
374                ca,
375                client_auth: Some((identity, password))
376            } if ca == b"Test CA" && identity == vec![1, 2, 3] && password == "secret"
377        ));
378    }
379
380    #[tokio::test]
381    async fn connect_socket_addr_succeeds_with_ipv4_bind_addr() {
382        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
383        let listener_addr = listener.local_addr().unwrap();
384
385        let accept = tokio::spawn(async move {
386            let (stream, peer_addr) = listener.accept().await.unwrap();
387            drop(stream);
388            peer_addr
389        });
390
391        let mut network_options = NetworkOptions::new();
392        network_options.set_bind_addr(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)));
393
394        let stream = connect_socket_addr(listener_addr, network_options)
395            .await
396            .unwrap();
397        let local_addr = stream.local_addr().unwrap();
398        assert!(local_addr.ip().is_loopback());
399        drop(stream);
400
401        let peer_addr = accept.await.unwrap();
402        assert_eq!(peer_addr.ip(), local_addr.ip());
403    }
404
405    #[tokio::test]
406    async fn connect_socket_addr_returns_error_for_mismatched_bind_addr_family() {
407        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
408        let listener_addr = listener.local_addr().unwrap();
409
410        let mut network_options = NetworkOptions::new();
411        network_options.set_bind_addr(SocketAddr::V6(SocketAddrV6::new(
412            Ipv6Addr::LOCALHOST,
413            0,
414            0,
415            0,
416        )));
417
418        let err = connect_socket_addr(listener_addr, network_options)
419            .await
420            .unwrap_err();
421        assert_ne!(err.kind(), io::ErrorKind::WouldBlock);
422    }
423
424    #[tokio::test]
425    async fn connect_socket_addr_succeeds_with_ipv6_bind_addr() {
426        let listener = match TcpListener::bind((Ipv6Addr::LOCALHOST, 0)).await {
427            Ok(listener) => listener,
428            Err(_) => return,
429        };
430        let listener_addr = listener.local_addr().unwrap();
431
432        let accept = tokio::spawn(async move {
433            let (stream, peer_addr) = listener.accept().await.unwrap();
434            drop(stream);
435            peer_addr
436        });
437
438        let mut network_options = NetworkOptions::new();
439        network_options.set_bind_addr(SocketAddr::V6(SocketAddrV6::new(
440            Ipv6Addr::LOCALHOST,
441            0,
442            0,
443            0,
444        )));
445
446        let stream = connect_socket_addr(listener_addr, network_options)
447            .await
448            .unwrap();
449        let local_addr = stream.local_addr().unwrap();
450        assert_eq!(local_addr.ip(), IpAddr::V6(Ipv6Addr::LOCALHOST));
451        drop(stream);
452
453        let peer_addr = accept.await.unwrap();
454        assert_eq!(peer_addr.ip(), local_addr.ip());
455    }
456
457    #[tokio::test]
458    async fn default_socket_connect_still_connects_without_bind_addr() {
459        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
460        let addr = listener.local_addr().unwrap();
461
462        let accept = tokio::spawn(async move {
463            let (stream, _) = listener.accept().await.unwrap();
464            drop(stream);
465        });
466
467        let stream = default_socket_connect(addr.to_string(), NetworkOptions::new())
468            .await
469            .unwrap();
470        assert!(stream.local_addr().unwrap().ip().is_loopback());
471        drop(stream);
472        accept.await.unwrap();
473    }
474
475    #[test]
476    fn bind_addr_returns_configured_socket_addr() {
477        let bind_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 1883));
478        let mut network_options = NetworkOptions::new();
479        network_options.set_bind_addr(bind_addr);
480
481        assert_eq!(network_options.bind_addr(), Some(bind_addr));
482    }
483}