Skip to main content

rumqttc_core/
lib.rs

1#![doc = include_str!("../README.md")]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
4use std::future::Future;
5use std::io;
6use std::net::SocketAddr;
7use std::pin::Pin;
8use std::sync::Arc;
9
10#[cfg(feature = "use-rustls-no-provider")]
11use rustls_native_certs::load_native_certs;
12use tokio::net::{TcpSocket, TcpStream, lookup_host};
13#[cfg(feature = "use-native-tls")]
14use tokio_native_tls::native_tls::TlsConnector;
15#[cfg(feature = "use-rustls-no-provider")]
16use tokio_rustls::rustls::{ClientConfig, RootCertStore};
17
18#[cfg(feature = "proxy")]
19mod proxy;
20#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
21mod tls;
22#[cfg(feature = "websocket")]
23mod websockets;
24
25#[cfg(feature = "proxy")]
26pub use proxy::{Proxy, ProxyAuth, ProxyError, ProxyType};
27#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
28pub use tls::Error as TlsError;
29#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
30pub use tls::tls_connect;
31#[cfg(all(
32    feature = "websocket",
33    feature = "use-native-tls",
34    not(feature = "use-rustls-no-provider")
35))]
36pub use tls::websocket_tls_connector;
37#[cfg(all(
38    feature = "websocket",
39    feature = "use-rustls-no-provider",
40    not(feature = "use-native-tls")
41))]
42pub use tls::websocket_tls_connector;
43#[cfg(feature = "websocket")]
44pub use websockets::{UrlError, ValidationError, WsAdapter, split_url, validate_response_headers};
45
46#[cfg(not(feature = "websocket"))]
47pub trait AsyncReadWrite:
48    tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Sync + Unpin
49{
50}
51#[cfg(not(feature = "websocket"))]
52impl<T> AsyncReadWrite for T where
53    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Sync + Unpin
54{
55}
56
57#[cfg(feature = "websocket")]
58pub trait AsyncReadWrite: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin {}
59#[cfg(feature = "websocket")]
60impl<T> AsyncReadWrite for T where T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin {}
61
62pub type DynAsyncReadWrite = Box<dyn AsyncReadWrite>;
63
64/// Custom socket connector used to establish the underlying stream before optional proxy/TLS layers.
65pub type SocketConnector = Arc<
66    dyn Fn(
67            String,
68            NetworkOptions,
69        ) -> Pin<Box<dyn Future<Output = Result<DynAsyncReadWrite, io::Error>> + Send>>
70        + Send
71        + Sync,
72>;
73
/// TLS configuration method.
///
/// Which variants exist depends on the enabled TLS backend features
/// (`use-rustls-no-provider` and/or `use-native-tls`).
#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
#[derive(Clone, Debug)]
pub enum TlsConfiguration {
    /// rustls backend configured from raw certificate material.
    #[cfg(feature = "use-rustls-no-provider")]
    Simple {
        /// CA certificate bytes.
        ca: Vec<u8>,
        /// ALPN protocol identifiers, if any.
        alpn: Option<Vec<Vec<u8>>>,
        /// TLS client authentication material as a pair of byte buffers, if any
        /// (presumably certificate then private key; confirm against the `tls` module).
        client_auth: Option<(Vec<u8>, Vec<u8>)>,
    },
    /// native-tls backend configured from raw certificate material.
    #[cfg(feature = "use-native-tls")]
    SimpleNative {
        /// CA certificate bytes.
        ca: Vec<u8>,
        /// PKCS#12 archive in binary DER form, plus the password for that archive.
        client_auth: Option<(Vec<u8>, String)>,
    },
    /// Injected rustls `ClientConfig` for TLS, to allow more customisation.
    #[cfg(feature = "use-rustls-no-provider")]
    Rustls(Arc<ClientConfig>),
    /// Use the default native-tls configuration.
    #[cfg(feature = "use-native-tls")]
    Native,
    /// Injected native-tls `TlsConnector` for TLS, to allow more customisation.
    #[cfg(feature = "use-native-tls")]
    NativeConnector(TlsConnector),
}
105
#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
impl TlsConfiguration {
    /// Builds a rustls client configuration whose trust anchors are the
    /// platform's native root certificates, with no client authentication.
    ///
    /// # Panics
    ///
    /// Panics if the platform certificates cannot be loaded, or if any loaded
    /// certificate is rejected when added to the root store.
    #[cfg(feature = "use-rustls-no-provider")]
    #[must_use]
    pub fn default_rustls() -> Self {
        let mut roots = RootCertStore::empty();
        load_native_certs()
            .expect("could not load platform certs")
            .into_iter()
            .for_each(|cert| roots.add(cert).unwrap());

        let client_config = ClientConfig::builder()
            .with_root_certificates(roots)
            .with_no_client_auth();
        Self::Rustls(client_config.into())
    }

    /// Selects native-tls with its default settings ([`TlsConfiguration::Native`]).
    #[cfg(feature = "use-native-tls")]
    #[must_use]
    pub fn default_native() -> Self {
        TlsConfiguration::Native
    }
}
135
/// When rustls is the only compiled-in backend, the default configuration is
/// [`TlsConfiguration::default_rustls`] (and panics under the same conditions).
#[cfg(all(feature = "use-rustls-no-provider", not(feature = "use-native-tls")))]
impl Default for TlsConfiguration {
    fn default() -> Self {
        TlsConfiguration::default_rustls()
    }
}
142
/// When native-tls is the only compiled-in backend, the default configuration
/// is [`TlsConfiguration::default_native`].
#[cfg(all(feature = "use-native-tls", not(feature = "use-rustls-no-provider")))]
impl Default for TlsConfiguration {
    fn default() -> Self {
        TlsConfiguration::default_native()
    }
}
149
/// Wraps an already-built rustls [`ClientConfig`] as [`TlsConfiguration::Rustls`].
#[cfg(feature = "use-rustls-no-provider")]
impl From<ClientConfig> for TlsConfiguration {
    fn from(client_config: ClientConfig) -> Self {
        let shared = Arc::new(client_config);
        TlsConfiguration::Rustls(shared)
    }
}
156
/// Wraps an already-built native-tls [`TlsConnector`] as
/// [`TlsConfiguration::NativeConnector`].
#[cfg(feature = "use-native-tls")]
impl From<TlsConnector> for TlsConfiguration {
    fn from(tls_connector: TlsConnector) -> Self {
        Self::NativeConnector(tls_connector)
    }
}
163
/// Provides a way to configure low level network connection configurations.
///
/// Apart from the connection timeout, these options are applied to every TCP
/// socket created by [`default_socket_connect`]. The [`Default`] value is the
/// same as [`NetworkOptions::new`].
#[derive(Clone, Debug)]
pub struct NetworkOptions {
    /// Value passed to `TcpSocket::set_send_buffer_size`; `None` leaves the OS default.
    tcp_send_buffer_size: Option<u32>,
    /// Value passed to `TcpSocket::set_recv_buffer_size`; `None` leaves the OS default.
    tcp_recv_buffer_size: Option<u32>,
    /// Value passed to `TcpSocket::set_nodelay`.
    tcp_nodelay: bool,
    /// Connection timeout in seconds.
    conn_timeout: u64,
    /// Network device name passed to `TcpSocket::bind_device`, if any.
    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
    bind_device: Option<String>,
}

// Hand-written rather than derived: a derived `Default` would set
// `conn_timeout` to 0, disagreeing with the 5 second timeout that `new()`
// documents and uses.
impl Default for NetworkOptions {
    fn default() -> Self {
        Self::new()
    }
}

impl NetworkOptions {
    /// Connection timeout, in seconds, used by [`NetworkOptions::new`].
    const DEFAULT_CONNECTION_TIMEOUT_SECS: u64 = 5;

    /// Creates options with no socket buffer overrides, `TCP_NODELAY` off,
    /// a 5 second connection timeout and (where supported) no device binding.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            tcp_send_buffer_size: None,
            tcp_recv_buffer_size: None,
            tcp_nodelay: false,
            conn_timeout: Self::DEFAULT_CONNECTION_TIMEOUT_SECS,
            #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
            bind_device: None,
        }
    }

    /// Sets whether `TCP_NODELAY` is enabled on the socket.
    pub const fn set_tcp_nodelay(&mut self, nodelay: bool) {
        self.tcp_nodelay = nodelay;
    }

    /// Overrides the socket send buffer size.
    pub const fn set_tcp_send_buffer_size(&mut self, size: u32) {
        self.tcp_send_buffer_size = Some(size);
    }

    /// Overrides the socket receive buffer size.
    pub const fn set_tcp_recv_buffer_size(&mut self, size: u32) {
        self.tcp_recv_buffer_size = Some(size);
    }

    /// Sets the connection timeout in seconds.
    ///
    /// [`default_socket_connect`] does not enforce this value itself; it is
    /// read back through [`NetworkOptions::connection_timeout`] by callers.
    pub const fn set_connection_timeout(&mut self, timeout: u64) -> &mut Self {
        self.conn_timeout = timeout;
        self
    }

    /// Returns the connection timeout in seconds.
    #[must_use]
    pub const fn connection_timeout(&self) -> u64 {
        self.conn_timeout
    }

    /// Binds the connection to a specific network device by name.
    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
    #[cfg_attr(
        docsrs,
        doc(cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux")))
    )]
    pub fn set_bind_device(&mut self, bind_device: &str) -> &mut Self {
        self.bind_device = Some(bind_device.to_string());
        self
    }
}
223
224/// Default TCP socket connection logic used by the MQTT event loop.
225///
226/// This resolves the host, applies [`NetworkOptions`] on each candidate socket,
227/// and returns the first successful connection.
228///
229/// # Errors
230///
231/// Returns any DNS lookup, socket configuration, or connect error encountered.
232/// When multiple address candidates are available, the last connect error is
233/// returned if they all fail.
234pub async fn default_socket_connect(
235    host: String,
236    network_options: NetworkOptions,
237) -> io::Result<TcpStream> {
238    let addrs = lookup_host(host).await?;
239    let mut last_err = None;
240
241    for addr in addrs {
242        let socket = match addr {
243            SocketAddr::V4(_) => TcpSocket::new_v4()?,
244            SocketAddr::V6(_) => TcpSocket::new_v6()?,
245        };
246
247        socket.set_nodelay(network_options.tcp_nodelay)?;
248
249        if let Some(send_buff_size) = network_options.tcp_send_buffer_size {
250            socket.set_send_buffer_size(send_buff_size)?;
251        }
252        if let Some(recv_buffer_size) = network_options.tcp_recv_buffer_size {
253            socket.set_recv_buffer_size(recv_buffer_size)?;
254        }
255
256        #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
257        {
258            if let Some(bind_device) = &network_options.bind_device {
259                socket.bind_device(Some(bind_device.as_bytes()))?;
260            }
261        }
262
263        match socket.connect(addr).await {
264            Ok(s) => return Ok(s),
265            Err(e) => {
266                last_err = Some(e);
267            }
268        }
269    }
270
271    Err(last_err.unwrap_or_else(|| {
272        io::Error::new(
273            io::ErrorKind::InvalidInput,
274            "could not resolve to any address",
275        )
276    }))
277}
278
#[cfg(test)]
mod tests {
    // `TlsConfiguration` only exists when a TLS backend feature is enabled, so
    // it is imported inside each feature-gated test. A module-level `use` fails
    // to resolve in builds without TLS, and it is an unused import in builds
    // where neither test below is compiled.

    #[cfg(all(
        feature = "use-rustls-no-provider",
        any(feature = "use-rustls-aws-lc", feature = "use-rustls-ring")
    ))]
    #[test]
    fn default_rustls_returns_rustls_variant() {
        use super::TlsConfiguration;

        assert!(matches!(
            TlsConfiguration::default_rustls(),
            TlsConfiguration::Rustls(_)
        ));
    }

    #[cfg(feature = "use-native-tls")]
    #[test]
    fn default_native_returns_native_variant() {
        use super::TlsConfiguration;

        assert!(matches!(
            TlsConfiguration::default_native(),
            TlsConfiguration::Native
        ));
    }
}