1use std::io::{Read, Write};
4
5#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
6use crate::error::{Error, UrlError};
7use crate::{
8 client::{client_with_config, uri_mode, IntoClientRequest},
9 error::Result,
10 handshake::{
11 client::{ClientHandshake, Response},
12 core::HandshakeError,
13 },
14 protocol::{config::WebSocketConfig, websocket::WebSocket},
15 stream::SimplifiedStream,
16};
17
/// Selects how `client_tls_with_config` wraps the underlying stream.
///
/// The enum controls whether `native-tls` or `rustls` is used to create the TLS connection.
/// TLS can also be disabled entirely with the [`Connector::Plain`] variant.
/// Which TLS variants exist depends on the enabled cargo features.
#[non_exhaustive]
// Opts this public type out of the `missing_debug_implementations` lint.
// NOTE(review): the reason Debug is not implemented is not visible here.
#[allow(missing_debug_implementations)]
pub enum Connector {
    /// No TLS: the stream is used as-is.
    ///
    /// A request whose URI requires TLS is rejected with `UrlError::TlsFeatureNotEnabled`.
    Plain,

    /// Wrap the stream with `native-tls`, using this connector for the handshake.
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls_crate::TlsConnector),

    /// Wrap the stream with `rustls`, using this shared client configuration.
    #[cfg(feature = "__rustls-tls")]
    Rustls(std::sync::Arc<rustls::ClientConfig>),
}
35
36mod encryption {
37 #[cfg(feature = "native-tls")]
38 pub mod native_tls {
39 use crate::{
40 error::{Error, Result, TlsError},
41 stream::{Mode, SimplifiedStream},
42 };
43 use native_tls_crate::{HandshakeError as TlsHandshakeError, TlsConnector};
44 use std::io::{Read, Write};
45
46 pub fn wrap_stream<S>(
47 socket: S,
48 domain: &str,
49 mode: Mode,
50 tls_connection: Option<TlsConnector>,
51 ) -> Result<SimplifiedStream<S>>
52 where
53 S: Read + Write,
54 {
55 match mode {
56 Mode::Plain => Ok(SimplifiedStream::Plain(socket)),
57 Mode::Tls => {
58 let try_connector = tls_connection.map_or_else(TlsConnector::new, Ok);
59 let connector = try_connector.map_err(TlsError::Native)?;
60 let connected = connector.connect(domain, socket);
61
62 match connected {
63 Err(e) => match e {
64 TlsHandshakeError::Failure(f) => Err(Error::Tls(f.into())),
65 TlsHandshakeError::WouldBlock(_) => {
66 panic!("Bug: TLS handshake not blocked")
67 }
68 },
69 Ok(s) => Ok(SimplifiedStream::NativeTls(s)),
70 }
71 }
72 }
73 }
74 }
75
76 #[cfg(feature = "__rustls-tls")]
77 pub mod rustls {
78 use crate::{
79 error::{Result, TlsError},
80 stream::{Mode, SimplifiedStream},
81 };
82 use rustls::{ClientConfig, ClientConnection, RootCertStore, StreamOwned};
83 use rustls_pki_types::ServerName;
84 use std::{
85 io::{Read, Write},
86 sync::Arc,
87 };
88
89 pub fn wrap_stream<S>(
90 socket: S,
91 domain: &str,
92 mode: Mode,
93 tls_connector: Option<Arc<ClientConfig>>,
94 ) -> Result<SimplifiedStream<S>>
95 where
96 S: Read + Write,
97 {
98 match mode {
99 Mode::Plain => Ok(SimplifiedStream::Plain(socket)),
100 Mode::Tls => {
101 let config = match tls_connector {
102 Some(config) => config,
103 None => {
104 #[allow(unused_mut)]
105 let mut root_store = RootCertStore::empty();
106
107 #[cfg(feature = "rustls-tls-native-roots")]
108 {
109 let rustls_native_certs::CertificateResult {
110 certs, errors, ..
111 } = rustls_native_certs::load_native_certs();
112
113 if certs.is_empty() {
115 return Err(std::io::Error::new(
116 std::io::ErrorKind::NotFound,
117 format!("No native root CA certificates found (errors: {errors:?})")
118 ).into());
119 }
120
121 }
124
125 #[cfg(feature = "rustls-tls-webpki-roots")]
126 {
127 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
128 }
129
130 Arc::new(
131 ClientConfig::builder()
132 .with_root_certificates(root_store)
133 .with_no_client_auth(),
134 )
135 }
136 };
137
138 let domain = ServerName::try_from(domain)
139 .map_err(|_| TlsError::InvalidDnsName)?
140 .to_owned();
141
142 let client = ClientConnection::new(config, domain).map_err(TlsError::Rustls)?;
143 let stream = StreamOwned::new(client, socket);
144
145 Ok(SimplifiedStream::Rustls(stream))
146 }
147 }
148 }
149 }
150
151 pub mod plain {
152 use crate::{
153 error::{Error, Result, UrlError},
154 stream::{Mode, SimplifiedStream},
155 };
156 use std::io::{Read, Write};
157
158 pub fn wrap_stream<S>(socket: S, mode: Mode) -> Result<SimplifiedStream<S>>
159 where
160 S: Read + Write,
161 {
162 match mode {
163 Mode::Plain => Ok(SimplifiedStream::Plain(socket)),
164 Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)),
165 }
166 }
167 }
168}
169
/// Error type of `client_tls` and `client_tls_with_config`.
///
/// It is a client-side handshake error over the stream produced by the TLS wrapping step,
/// which is a `SimplifiedStream<S>`.
type TlsErrorHandshake<S> = HandshakeError<ClientHandshake<SimplifiedStream<S>>>;
171
172pub fn client_tls<R, S>(
175 request: R,
176 stream: S,
177) -> Result<(WebSocket<SimplifiedStream<S>>, Response), TlsErrorHandshake<S>>
178where
179 R: IntoClientRequest,
180 S: Read + Write,
181{
182 client_tls_with_config(request, stream, None, None)
183}
184
185pub fn client_tls_with_config<R, S>(
191 request: R,
192 stream: S,
193 config: Option<WebSocketConfig>,
194 connector: Option<Connector>,
195) -> Result<(WebSocket<SimplifiedStream<S>>, Response), TlsErrorHandshake<S>>
196where
197 R: IntoClientRequest,
198 S: Read + Write,
199{
200 let request = request.into_client_request()?;
201
202 #[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
203 let domain = match request.uri().host() {
204 Some(d) => Ok(d.to_string()),
205 None => Err(Error::Url(UrlError::MissingHost)),
206 }?;
207
208 let mode = uri_mode(request.uri())?;
209
210 let stream = match connector {
211 Some(conn) => match conn {
212 #[cfg(feature = "native-tls")]
213 Connector::NativeTls(conn) => {
214 self::encryption::native_tls::wrap_stream(stream, &domain, mode, Some(conn))
215 }
216
217 #[cfg(feature = "__rustls-tls")]
218 Connector::Rustls(conn) => {
219 self::encryption::rustls::wrap_stream(stream, &domain, mode, Some(conn))
220 }
221
222 Connector::Plain => self::encryption::plain::wrap_stream(stream, mode),
223 },
224 None => {
225 #[cfg(feature = "native-tls")]
226 {
227 self::encryption::native_tls::wrap_stream(stream, &domain, mode, None)
228 }
229 #[cfg(all(feature = "__rustls-tls", not(feature = "native-tls")))]
230 {
231 self::encryption::rustls::wrap_stream(stream, &domain, mode, None)
232 }
233 #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
234 {
235 self::encryption::plain::wrap_stream(stream, mode)
236 }
237 }
238 }?;
239
240 client_with_config(request, stream, config)
241}