1use http;
2use http::Uri;
3use crate::{errors::WsError, protocol::Mode};
4
5pub fn get_scheme(uri: &http::Uri) -> Result<Mode, WsError> {
7 match uri.scheme_str().unwrap_or("ws").to_lowercase().as_str() {
8 "ws" => Ok(Mode::WS),
9 "wss" => Ok(Mode::WSS),
10 s => Err(WsError::InvalidUri(format!("unknown scheme {s}"))),
11 }
12}
13
14pub fn get_host(uri: &Uri) -> Result<&str, WsError> {
16 uri.host()
17 .ok_or_else(|| WsError::InvalidUri(format!("can not find host {}", uri)))
18}
19
#[cfg(feature = "sync")]
mod blocking {
    use crate::errors::WsError;
    use http;
    use std::net::TcpStream;

    use super::{get_host, get_scheme};

    /// Opens a blocking TCP connection to the host and port named by `uri`.
    ///
    /// When `uri` has no explicit port, the default port of its scheme's mode is used.
    ///
    /// # Errors
    /// Returns `WsError::InvalidUri` when the scheme is not `ws`/`wss` or the URI has
    /// no host, and `WsError::ConnectionFailed` when the TCP connect fails.
    pub fn tcp_connect(uri: &http::Uri) -> Result<TcpStream, WsError> {
        let mode = get_scheme(uri)?;
        let host = get_host(uri)?;
        let port = uri.port_u16().unwrap_or_else(|| mode.default_port());
        TcpStream::connect((host, port)).map_err(|e| {
            WsError::ConnectionFailed(format!("failed to create tcp connection {e}"))
        })
    }

    /// Performs a rustls client handshake for `host` over `stream`.
    ///
    /// The trust store is the webpki root set plus every certificate found in the
    /// PEM files listed in `certs`. Certificates that `add_parsable_certificates`
    /// cannot parse are dropped without an error.
    ///
    /// # Errors
    /// - `WsError::CertFileNotFound` if a certificate file cannot be opened.
    /// - `WsError::LoadCertFailed` if a certificate file is not valid PEM.
    /// - `WsError::ConnectionFailed` if the TLS handshake fails.
    #[cfg(feature = "sync_tls_rustls")]
    pub fn wrap_rustls<
        S: std::io::Read + std::io::Write + Sync + Send + std::fmt::Debug + 'static,
    >(
        stream: S,
        host: &str,
        certs: Vec<std::path::PathBuf>,
    ) -> Result<rustls_connector::TlsStream<S>, WsError> {
        use std::io::BufReader;

        let mut config = rustls_connector::RustlsConnectorConfig::new_with_webpki_roots_certs();
        let mut cert_data = vec![];
        for cert_path in certs.iter() {
            // `display()` keeps non-UTF-8 paths readable instead of collapsing them to "".
            let pem = std::fs::File::open(cert_path)
                .map_err(|_| WsError::CertFileNotFound(cert_path.display().to_string()))?;
            let mut reader = BufReader::new(pem);
            let der_certs = rustls_pemfile::certs(&mut reader)
                .map_err(|e| WsError::LoadCertFailed(e.to_string()))?;
            // Move the DER blobs instead of cloning them.
            cert_data.extend(der_certs);
        }
        config.add_parsable_certificates(&cert_data);
        let connector = config.connector_with_no_client_auth();
        let tls_stream = connector
            .connect(host, stream)
            .map_err(|e| WsError::ConnectionFailed(e.to_string()))?;
        tracing::debug!("tls connection established");
        Ok(tls_stream)
    }

    /// Loads one root certificate file for native-tls, accepting DER or PEM encoding.
    ///
    /// Returns `Ok(None)` after logging when the file cannot be read or parsed. A bad
    /// file is skipped rather than aborting the connection. Only a failure to open
    /// the file is reported as `WsError::CertFileNotFound`.
    #[cfg(feature = "sync_tls_native")]
    fn load_native_cert(
        cert_path: &std::path::Path,
    ) -> Result<Option<native_tls::Certificate>, WsError> {
        let mut file = std::fs::File::open(cert_path)
            .map_err(|_| WsError::CertFileNotFound(cert_path.display().to_string()))?;
        let mut data = vec![];
        if let Err(e) = std::io::Read::read_to_end(&mut file, &mut data) {
            tracing::error!("failed to read cert file {} {}", cert_path.display(), e);
            return Ok(None);
        }
        // DER is tried first to keep the previous behavior. PEM is accepted as well,
        // so the same cert file works with both the rustls and native-tls backends.
        match native_tls::Certificate::from_der(&data)
            .or_else(|_| native_tls::Certificate::from_pem(&data))
        {
            Ok(cert) => Ok(Some(cert)),
            Err(e) => {
                tracing::error!("invalid cert file {} {}", cert_path.display(), e);
                Ok(None)
            }
        }
    }

    /// Performs a native-tls client handshake for `host` over `stream`.
    ///
    /// Each file in `certs` is added as an extra root certificate. Unreadable or
    /// unparsable files are logged and skipped.
    ///
    /// # Errors
    /// - `WsError::CertFileNotFound` if a certificate file cannot be opened.
    /// - `WsError::ConnectionFailed` if the connector cannot be built or the handshake fails.
    #[cfg(feature = "sync_tls_native")]
    pub fn wrap_native_tls<S: std::io::Read + std::io::Write>(
        stream: S,
        host: &str,
        certs: Vec<std::path::PathBuf>,
    ) -> Result<native_tls::TlsStream<S>, WsError> {
        let mut builder = native_tls::TlsConnector::builder();
        for cert_path in certs.iter() {
            if let Some(cert) = load_native_cert(cert_path)? {
                builder.add_root_certificate(cert);
            }
        }
        // A failing build is a recoverable error for the caller, not a panic.
        let connector = builder
            .build()
            .map_err(|e| WsError::ConnectionFailed(format!("failed to build tls connector {e}")))?;
        let tls_stream = connector.connect(host, stream).map_err(|e| {
            // `HandshakeError<S>: Display` needs `S: Debug`, which this API does not
            // require. Extract the reason from the variants directly instead.
            let reason = match e {
                native_tls::HandshakeError::Failure(err) => err.to_string(),
                native_tls::HandshakeError::WouldBlock(_) => {
                    "handshake would block".to_string()
                }
            };
            WsError::ConnectionFailed(format!("tls connect failed {reason}"))
        })?;
        tracing::debug!("tls connection established");
        Ok(tls_stream)
    }
}
131
132#[cfg(feature = "sync")]
133pub use blocking::*;
134
#[cfg(feature = "async")]
mod non_blocking {
    use http::Uri;
    use tokio::net::TcpStream;

    use crate::errors::WsError;

    use super::{get_host, get_scheme};

    /// Opens a tokio TCP connection to the host and port named by `uri`.
    ///
    /// When `uri` has no explicit port, the default port of its scheme's mode is used.
    ///
    /// # Errors
    /// Returns `WsError::InvalidUri` when the scheme is not `ws`/`wss` or the URI has
    /// no host, and `WsError::ConnectionFailed` when the TCP connect fails.
    pub async fn async_tcp_connect(uri: &Uri) -> Result<TcpStream, WsError> {
        let mode = get_scheme(uri)?;
        let host = get_host(uri)?;
        let port = uri.port_u16().unwrap_or_else(|| mode.default_port());

        TcpStream::connect((host, port))
            .await
            .map_err(|e| WsError::ConnectionFailed(format!("failed to create tcp connection {e}")))
    }

    /// Splits a rustls client stream into read/write halves via `tokio::io::split`.
    #[cfg(feature = "async_tls_rustls")]
    impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin> crate::codec::Split
        for tokio_rustls::client::TlsStream<S>
    {
        type R = tokio::io::ReadHalf<tokio_rustls::client::TlsStream<S>>;
        type W = tokio::io::WriteHalf<tokio_rustls::client::TlsStream<S>>;
        fn split(self) -> (Self::R, Self::W) {
            tokio::io::split(self)
        }
    }

    /// Splits a rustls server stream into read/write halves via `tokio::io::split`.
    #[cfg(feature = "async_tls_rustls")]
    impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin> crate::codec::Split
        for tokio_rustls::server::TlsStream<S>
    {
        type R = tokio::io::ReadHalf<tokio_rustls::server::TlsStream<S>>;
        type W = tokio::io::WriteHalf<tokio_rustls::server::TlsStream<S>>;
        fn split(self) -> (Self::R, Self::W) {
            tokio::io::split(self)
        }
    }

    /// Performs an async rustls client handshake for `host` over `stream`.
    ///
    /// The trust store is the webpki root set plus a trust anchor for every
    /// certificate found in the PEM files listed in `certs`.
    ///
    /// # Errors
    /// - `WsError::CertFileNotFound` if a certificate file cannot be opened.
    /// - `WsError::LoadCertFailed` if a file is not valid PEM or a certificate cannot
    ///   be converted into a trust anchor.
    /// - `WsError::TlsDnsFailed` if `host` is not a valid TLS server name.
    /// - `WsError::ConnectionFailed` if the TLS handshake fails.
    #[cfg(feature = "async_tls_rustls")]
    pub async fn async_wrap_rustls<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin>(
        stream: S,
        host: &str,
        certs: Vec<std::path::PathBuf>,
    ) -> Result<tokio_rustls::client::TlsStream<S>, WsError> {
        use std::io::BufReader;

        let mut root_store = rustls_connector::rustls::RootCertStore::empty();
        root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
            rustls_connector::rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
                ta.subject,
                ta.spki,
                ta.name_constraints,
            )
        }));
        let mut trust_anchors = vec![];
        // NOTE(review): cert files are read with blocking std::fs inside an async fn.
        // This is presumably fine for small local files but would stall the executor
        // on slow storage.
        for cert_path in certs.iter() {
            // `display()` keeps non-UTF-8 paths readable instead of collapsing them to "".
            let pem = std::fs::File::open(cert_path)
                .map_err(|_| WsError::CertFileNotFound(cert_path.display().to_string()))?;
            let mut reader = BufReader::new(pem);
            let der_certs = rustls_pemfile::certs(&mut reader)
                .map_err(|e| WsError::LoadCertFailed(e.to_string()))?;
            for item in der_certs {
                let ta = webpki::TrustAnchor::try_from_cert_der(&item[..])
                    .map_err(|e| WsError::LoadCertFailed(e.to_string()))?;
                let anchor =
                    rustls_connector::rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
                        ta.subject,
                        ta.spki,
                        ta.name_constraints,
                    );
                trust_anchors.push(anchor);
            }
        }
        root_store.add_server_trust_anchors(trust_anchors.into_iter());
        let config = rustls_connector::rustls::ClientConfig::builder()
            .with_safe_defaults()
            .with_root_certificates(root_store)
            .with_no_client_auth();
        let domain = tokio_rustls::rustls::ServerName::try_from(host)
            .map_err(|e| WsError::TlsDnsFailed(e.to_string()))?;
        let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
        let tls_stream = connector
            .connect(domain, stream)
            .await
            .map_err(|e| WsError::ConnectionFailed(e.to_string()))?;
        tracing::debug!("tls connection established");
        Ok(tls_stream)
    }

    /// Splits a native-tls stream into read/write halves via `tokio::io::split`.
    #[cfg(feature = "async_tls_native")]
    impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin> crate::codec::Split
        for tokio_native_tls::TlsStream<S>
    {
        type R = tokio::io::ReadHalf<tokio_native_tls::TlsStream<S>>;
        type W = tokio::io::WriteHalf<tokio_native_tls::TlsStream<S>>;
        fn split(self) -> (Self::R, Self::W) {
            tokio::io::split(self)
        }
    }

    /// Loads one root certificate file for native-tls, accepting DER or PEM encoding.
    ///
    /// Returns `Ok(None)` after logging when the file cannot be read or parsed. A bad
    /// file is skipped rather than aborting the connection. Only a failure to open
    /// the file is reported as `WsError::CertFileNotFound`.
    #[cfg(feature = "async_tls_native")]
    fn load_native_cert(
        cert_path: &std::path::Path,
    ) -> Result<Option<native_tls::Certificate>, WsError> {
        let mut file = std::fs::File::open(cert_path)
            .map_err(|_| WsError::CertFileNotFound(cert_path.display().to_string()))?;
        let mut data = vec![];
        if let Err(e) = std::io::Read::read_to_end(&mut file, &mut data) {
            tracing::error!("failed to read cert file {} {}", cert_path.display(), e);
            return Ok(None);
        }
        // DER is tried first to keep the previous behavior. PEM is accepted as well,
        // so the same cert file works with both the rustls and native-tls backends.
        match native_tls::Certificate::from_der(&data)
            .or_else(|_| native_tls::Certificate::from_pem(&data))
        {
            Ok(cert) => Ok(Some(cert)),
            Err(e) => {
                tracing::error!("invalid cert file {} {}", cert_path.display(), e);
                Ok(None)
            }
        }
    }

    /// Performs an async native-tls client handshake for `host` over `stream`.
    ///
    /// Each file in `certs` is added as an extra root certificate. Unreadable or
    /// unparsable files are logged and skipped.
    ///
    /// # Errors
    /// - `WsError::CertFileNotFound` if a certificate file cannot be opened.
    /// - `WsError::ConnectionFailed` if the connector cannot be built or the handshake fails.
    #[cfg(feature = "async_tls_native")]
    pub async fn async_wrap_native_tls<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin>(
        stream: S,
        host: &str,
        certs: Vec<std::path::PathBuf>,
    ) -> Result<tokio_native_tls::TlsStream<S>, WsError> {
        let mut builder = native_tls::TlsConnector::builder();
        for cert_path in certs.iter() {
            if let Some(cert) = load_native_cert(cert_path)? {
                builder.add_root_certificate(cert);
            }
        }
        // A failing build is a recoverable error for the caller, not a panic.
        let connector = builder
            .build()
            .map_err(|e| WsError::ConnectionFailed(format!("failed to build tls connector {e}")))?;
        let connector = tokio_native_tls::TlsConnector::from(connector);
        let tls_stream = connector
            .connect(host, stream)
            .await
            .map_err(|e| WsError::ConnectionFailed(e.to_string()))?;
        tracing::debug!("tls connection established");
        Ok(tls_stream)
    }
}
286
287#[cfg(feature = "async")]
288pub use non_blocking::*;