// rbdc/net/tls/mod.rs

1#![allow(dead_code)]
2
3use std::io;
4use std::ops::{Deref, DerefMut};
5use std::path::PathBuf;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use crate::rt::{AsyncRead, AsyncWrite, TlsStream};
10
11use crate::Error;
12use std::mem::replace;
13use rbs::err_protocol;
14
15/// X.509 Certificate input, either a file path or a PEM encoded inline certificate(s).
16#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
17pub enum CertificateInput {
18    /// PEM encoded certificate(s)
19    Inline(Vec<u8>),
20    /// Path to a file containing PEM encoded certificate(s)
21    File(PathBuf),
22}
23
24impl From<String> for CertificateInput {
25    fn from(value: String) -> Self {
26        let trimmed = value.trim();
27        // Some heuristics according to https://tools.ietf.org/html/rfc7468
28        if trimmed.starts_with("-----BEGIN CERTIFICATE-----")
29            && trimmed.contains("-----END CERTIFICATE-----")
30        {
31            CertificateInput::Inline(value.as_bytes().to_vec())
32        } else {
33            CertificateInput::File(PathBuf::from(value))
34        }
35    }
36}
37
38impl CertificateInput {
39    async fn data(&self) -> Result<Vec<u8>, std::io::Error> {
40        use crate::rt::fs;
41        match self {
42            CertificateInput::Inline(v) => Ok(v.clone()),
43            CertificateInput::File(path) => fs::read(path).await,
44        }
45    }
46}
47
48impl std::fmt::Display for CertificateInput {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        match self {
51            CertificateInput::Inline(v) => write!(f, "{}", String::from_utf8_lossy(v.as_slice())),
52            CertificateInput::File(path) => write!(f, "file: {}", path.display()),
53        }
54    }
55}
56
57#[cfg(feature = "tls-rustls")]
58mod rustls;
59
60pub enum MaybeTlsStream<S>
61where
62    S: AsyncRead + AsyncWrite + Unpin,
63{
64    Raw(S),
65    Tls(TlsStream<S>),
66    Upgrading,
67}
68
69impl<S> MaybeTlsStream<S>
70where
71    S: AsyncRead + AsyncWrite + Unpin,
72{
73    #[inline]
74    pub fn is_tls(&self) -> bool {
75        matches!(self, Self::Tls(_))
76    }
77
78    pub async fn upgrade(
79        &mut self,
80        host: &str,
81        accept_invalid_certs: bool,
82        accept_invalid_hostnames: bool,
83        root_cert_path: Option<&CertificateInput>,
84    ) -> Result<(), Error> {
85        let connector = configure_tls_connector(
86            accept_invalid_certs,
87            accept_invalid_hostnames,
88            root_cert_path,
89        )
90        .await?;
91
92        let stream = match replace(self, MaybeTlsStream::Upgrading) {
93            MaybeTlsStream::Raw(stream) => stream,
94
95            MaybeTlsStream::Tls(_) => {
96                // ignore upgrade, we are already a TLS connection
97                return Ok(());
98            }
99
100            MaybeTlsStream::Upgrading => {
101                // we previously failed to upgrade and now hold no connection
102                // this should only happen from an internal misuse of this method
103                return Err(Error::from(io::ErrorKind::ConnectionAborted.to_string()));
104            }
105        };
106
107        #[cfg(feature = "tls-rustls")]
108        let host = tokio_rustls::rustls::pki_types::ServerName::try_from(host.to_string())
109            .map_err(|err| Error::from(err.to_string()))?;
110
111        *self = MaybeTlsStream::Tls(connector.connect(host, stream).await.map_err(|err| err_protocol!("{}", err))?);
112
113        Ok(())
114    }
115}
116
/// Builds the async TLS connector for the native-tls backend.
///
/// - `accept_invalid_certs` and `accept_invalid_hostnames` are forwarded to the native-tls
///   builder's `danger_*` switches.
/// - `root_cert_path` is only consulted when certificate validation is enabled. In that
///   case its PEM data is loaded and added as an extra trusted root.
///
/// # Errors
///
/// Failures while reading or parsing the root certificate, or while building the
/// connector, are mapped through `err_protocol!`.
#[cfg(feature = "tls-native-tls")]
async fn configure_tls_connector(
    accept_invalid_certs: bool,
    accept_invalid_hostnames: bool,
    root_cert_path: Option<&CertificateInput>,
) -> Result<crate::rt::TlsConnector, Error> {
    use crate::rt::native_tls::{Certificate, TlsConnector};

    let mut tls_builder = TlsConnector::builder();
    tls_builder.danger_accept_invalid_certs(accept_invalid_certs);
    tls_builder.danger_accept_invalid_hostnames(accept_invalid_hostnames);

    // A custom root only matters when certificates are actually being verified.
    let extra_root = if accept_invalid_certs {
        None
    } else {
        root_cert_path
    };

    if let Some(ca) = extra_root {
        let pem = ca.data().await.map_err(|err| err_protocol!("{}", err))?;
        let root = Certificate::from_pem(&pem).map_err(|err| err_protocol!("{}", err))?;
        tls_builder.add_root_certificate(root);
    }

    let native_connector = tls_builder.build().map_err(|err| err_protocol!("{}", err))?;
    Ok(native_connector.into())
}
142
143#[cfg(feature = "tls-rustls")]
144use self::rustls::configure_tls_connector;
145
146impl<S> AsyncRead for MaybeTlsStream<S>
147where
148    S: Unpin + AsyncWrite + AsyncRead,
149{
150    fn poll_read(
151        mut self: Pin<&mut Self>,
152        cx: &mut Context<'_>,
153        buf: &mut super::PollReadBuf<'_>,
154    ) -> Poll<io::Result<super::PollReadOut>> {
155        match &mut *self {
156            MaybeTlsStream::Raw(s) => Pin::new(s).poll_read(cx, buf),
157            MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
158
159            MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
160        }
161    }
162}
163
164impl<S> AsyncWrite for MaybeTlsStream<S>
165where
166    S: Unpin + AsyncWrite + AsyncRead,
167{
168    fn poll_write(
169        mut self: Pin<&mut Self>,
170        cx: &mut Context<'_>,
171        buf: &[u8],
172    ) -> Poll<io::Result<usize>> {
173        match &mut *self {
174            MaybeTlsStream::Raw(s) => Pin::new(s).poll_write(cx, buf),
175            MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
176
177            MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
178        }
179    }
180
181    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
182        match &mut *self {
183            MaybeTlsStream::Raw(s) => Pin::new(s).poll_flush(cx),
184            MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),
185
186            MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
187        }
188    }
189
190    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
191        match &mut *self {
192            MaybeTlsStream::Raw(s) => Pin::new(s).poll_shutdown(cx),
193            MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
194
195            MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
196        }
197    }
198}
199
200impl<S> Deref for MaybeTlsStream<S>
201where
202    S: Unpin + AsyncWrite + AsyncRead,
203{
204    type Target = S;
205
206    fn deref(&self) -> &Self::Target {
207        match self {
208            MaybeTlsStream::Raw(s) => s,
209
210            #[cfg(feature = "tls-rustls")]
211            MaybeTlsStream::Tls(s) => s.get_ref().0,
212
213            #[cfg(feature = "tls-native-tls")]
214            MaybeTlsStream::Tls(s) => s.get_ref().get_ref().get_ref(),
215
216            MaybeTlsStream::Upgrading => {
217                panic!("{}", io::Error::from(io::ErrorKind::ConnectionAborted))
218            }
219        }
220    }
221}
222
223impl<S> DerefMut for MaybeTlsStream<S>
224where
225    S: Unpin + AsyncWrite + AsyncRead,
226{
227    fn deref_mut(&mut self) -> &mut Self::Target {
228        match self {
229            MaybeTlsStream::Raw(s) => s,
230
231            #[cfg(feature = "tls-rustls")]
232            MaybeTlsStream::Tls(s) => s.get_mut().0,
233
234            #[cfg(feature = "tls-native-tls")]
235            MaybeTlsStream::Tls(s) => s.get_mut().get_mut().get_mut(),
236
237            MaybeTlsStream::Upgrading => {
238                panic!("{}", io::Error::from(io::ErrorKind::ConnectionAborted))
239            }
240        }
241    }
242}