1#![allow(dead_code)]
2
3use std::io;
4use std::ops::{Deref, DerefMut};
5use std::path::PathBuf;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use crate::rt::{AsyncRead, AsyncWrite, TlsStream};
10
11use crate::Error;
12use std::mem::replace;
13use rbs::err_protocol;
14
/// Source of a root certificate that is added to the TLS connector's trust store
/// when upgrading a connection.
///
/// The bytes are expected to be PEM (the `From<String>` heuristic looks for
/// `-----BEGIN CERTIFICATE-----` markers and the native-tls path parses them with
/// `Certificate::from_pem`).
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub enum CertificateInput {
    /// Certificate bytes held directly in memory.
    Inline(Vec<u8>),
    /// Path to a file whose contents are read when the connector is configured.
    File(PathBuf),
}
23
24impl From<String> for CertificateInput {
25 fn from(value: String) -> Self {
26 let trimmed = value.trim();
27 if trimmed.starts_with("-----BEGIN CERTIFICATE-----")
29 && trimmed.contains("-----END CERTIFICATE-----")
30 {
31 CertificateInput::Inline(value.as_bytes().to_vec())
32 } else {
33 CertificateInput::File(PathBuf::from(value))
34 }
35 }
36}
37
38impl CertificateInput {
39 async fn data(&self) -> Result<Vec<u8>, std::io::Error> {
40 use crate::rt::fs;
41 match self {
42 CertificateInput::Inline(v) => Ok(v.clone()),
43 CertificateInput::File(path) => fs::read(path).await,
44 }
45 }
46}
47
48impl std::fmt::Display for CertificateInput {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 match self {
51 CertificateInput::Inline(v) => write!(f, "{}", String::from_utf8_lossy(v.as_slice())),
52 CertificateInput::File(path) => write!(f, "file: {}", path.display()),
53 }
54 }
55}
56
57#[cfg(feature = "tls-rustls")]
58mod rustls;
59
/// A connection stream that is either plaintext or wrapped in TLS.
///
/// `MaybeTlsStream::upgrade` moves a `Raw` stream into the `Tls` state; the `AsyncRead`,
/// `AsyncWrite`, `Deref`, and `DerefMut` impls forward to whichever stream is present.
pub enum MaybeTlsStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// The plain, unencrypted stream.
    Raw(S),
    /// The stream after a successful TLS handshake.
    Tls(TlsStream<S>),
    /// Placeholder while `upgrade` has taken ownership of the raw stream.
    ///
    /// A stream left in this state (e.g. after a failed handshake) is unusable:
    /// I/O calls return `ConnectionAborted` and dereferencing panics.
    Upgrading,
}
68
69impl<S> MaybeTlsStream<S>
70where
71 S: AsyncRead + AsyncWrite + Unpin,
72{
73 #[inline]
74 pub fn is_tls(&self) -> bool {
75 matches!(self, Self::Tls(_))
76 }
77
78 pub async fn upgrade(
79 &mut self,
80 host: &str,
81 accept_invalid_certs: bool,
82 accept_invalid_hostnames: bool,
83 root_cert_path: Option<&CertificateInput>,
84 ) -> Result<(), Error> {
85 let connector = configure_tls_connector(
86 accept_invalid_certs,
87 accept_invalid_hostnames,
88 root_cert_path,
89 )
90 .await?;
91
92 let stream = match replace(self, MaybeTlsStream::Upgrading) {
93 MaybeTlsStream::Raw(stream) => stream,
94
95 MaybeTlsStream::Tls(_) => {
96 return Ok(());
98 }
99
100 MaybeTlsStream::Upgrading => {
101 return Err(Error::from(io::ErrorKind::ConnectionAborted.to_string()));
104 }
105 };
106
107 #[cfg(feature = "tls-rustls")]
108 let host = tokio_rustls::rustls::pki_types::ServerName::try_from(host.to_string())
109 .map_err(|err| Error::from(err.to_string()))?;
110
111 *self = MaybeTlsStream::Tls(connector.connect(host, stream).await.map_err(|err| err_protocol!("{}", err))?);
112
113 Ok(())
114 }
115}
116
/// Builds a native-tls connector with the requested verification settings.
///
/// When certificates are being verified (`accept_invalid_certs == false`) and
/// `root_cert_path` is given, that certificate is loaded, parsed as PEM, and added as
/// an extra trusted root. Loading and parsing failures, as well as builder failures,
/// are mapped to protocol errors.
#[cfg(feature = "tls-native-tls")]
async fn configure_tls_connector(
    accept_invalid_certs: bool,
    accept_invalid_hostnames: bool,
    root_cert_path: Option<&CertificateInput>,
) -> Result<crate::rt::TlsConnector, Error> {
    use crate::rt::native_tls::{Certificate, TlsConnector};

    let mut builder = TlsConnector::builder();
    builder.danger_accept_invalid_certs(accept_invalid_certs);
    builder.danger_accept_invalid_hostnames(accept_invalid_hostnames);

    // A custom root only matters when certificates are validated, so it is only loaded
    // (and only allowed to fail) in that case.
    if let Some(ca) = root_cert_path.filter(|_| !accept_invalid_certs) {
        let pem = ca.data().await.map_err(|e| err_protocol!("{}", e))?;
        let root = Certificate::from_pem(&pem).map_err(|e| err_protocol!("{}", e))?;
        builder.add_root_certificate(root);
    }

    let native = builder.build().map_err(|e| err_protocol!("{}", e))?;
    Ok(native.into())
}
142
143#[cfg(feature = "tls-rustls")]
144use self::rustls::configure_tls_connector;
145
146impl<S> AsyncRead for MaybeTlsStream<S>
147where
148 S: Unpin + AsyncWrite + AsyncRead,
149{
150 fn poll_read(
151 mut self: Pin<&mut Self>,
152 cx: &mut Context<'_>,
153 buf: &mut super::PollReadBuf<'_>,
154 ) -> Poll<io::Result<super::PollReadOut>> {
155 match &mut *self {
156 MaybeTlsStream::Raw(s) => Pin::new(s).poll_read(cx, buf),
157 MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
158
159 MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
160 }
161 }
162}
163
164impl<S> AsyncWrite for MaybeTlsStream<S>
165where
166 S: Unpin + AsyncWrite + AsyncRead,
167{
168 fn poll_write(
169 mut self: Pin<&mut Self>,
170 cx: &mut Context<'_>,
171 buf: &[u8],
172 ) -> Poll<io::Result<usize>> {
173 match &mut *self {
174 MaybeTlsStream::Raw(s) => Pin::new(s).poll_write(cx, buf),
175 MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
176
177 MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
178 }
179 }
180
181 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
182 match &mut *self {
183 MaybeTlsStream::Raw(s) => Pin::new(s).poll_flush(cx),
184 MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),
185
186 MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
187 }
188 }
189
190 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
191 match &mut *self {
192 MaybeTlsStream::Raw(s) => Pin::new(s).poll_shutdown(cx),
193 MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
194
195 MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
196 }
197 }
198}
199
200impl<S> Deref for MaybeTlsStream<S>
201where
202 S: Unpin + AsyncWrite + AsyncRead,
203{
204 type Target = S;
205
206 fn deref(&self) -> &Self::Target {
207 match self {
208 MaybeTlsStream::Raw(s) => s,
209
210 #[cfg(feature = "tls-rustls")]
211 MaybeTlsStream::Tls(s) => s.get_ref().0,
212
213 #[cfg(feature = "tls-native-tls")]
214 MaybeTlsStream::Tls(s) => s.get_ref().get_ref().get_ref(),
215
216 MaybeTlsStream::Upgrading => {
217 panic!("{}", io::Error::from(io::ErrorKind::ConnectionAborted))
218 }
219 }
220 }
221}
222
223impl<S> DerefMut for MaybeTlsStream<S>
224where
225 S: Unpin + AsyncWrite + AsyncRead,
226{
227 fn deref_mut(&mut self) -> &mut Self::Target {
228 match self {
229 MaybeTlsStream::Raw(s) => s,
230
231 #[cfg(feature = "tls-rustls")]
232 MaybeTlsStream::Tls(s) => s.get_mut().0,
233
234 #[cfg(feature = "tls-native-tls")]
235 MaybeTlsStream::Tls(s) => s.get_mut().get_mut().get_mut(),
236
237 MaybeTlsStream::Upgrading => {
238 panic!("{}", io::Error::from(io::ErrorKind::ConnectionAborted))
239 }
240 }
241 }
242}