use compio_io::{AsyncRead, AsyncWrite};
use compio_net::{SocketOpts, TcpStream};
use compio_tls::{MaybeTlsStream, TlsConnector};
use tungstenite::{
Error,
client::{IntoClientRequest, uri_mode},
handshake::client::{Request, Response},
stream::Mode,
};
use crate::{Config, WebSocketStream, client_async_with_config};
/// Constructors for the default `compio_tls::TlsConnector` of each supported
/// TLS backend.
///
/// Each submodule is compiled only when its Cargo feature is enabled and
/// exposes a `new_connector` function.
mod encryption {
/// Connector construction backed by the `native-tls` crate.
#[cfg(feature = "native-tls")]
pub mod native_tls {
use compio_tls::{TlsConnector, native_tls};
use tungstenite::{Error, error::TlsError};
/// Builds a `TlsConnector` from a default-configured
/// `native_tls::TlsConnector`.
///
/// # Errors
///
/// Returns the `native_tls` construction error, converted through
/// `tungstenite::error::TlsError` into `tungstenite::Error`.
pub fn new_connector() -> Result<TlsConnector, Error> {
let native_connector = native_tls::TlsConnector::new().map_err(TlsError::from)?;
Ok(TlsConnector::from(native_connector))
}
}
/// Connector construction backed by `rustls`.
#[cfg(feature = "rustls")]
pub mod rustls {
use std::sync::Arc;
use compio_tls::{
TlsConnector,
rustls::{ClientConfig, RootCertStore},
};
use tungstenite::Error;
/// Builds a rustls `ClientConfig` with no client authentication, whose root
/// store is filled from whichever certificate-source features are enabled.
///
/// - `rustls-native-certs`: loads the platform's native root certificates.
///   Loading errors are logged as a warning. If no certificates are returned
///   and `webpki-roots` is not enabled, an `io::ErrorKind::NotFound` error is
///   returned.
/// - `webpki-roots`: adds the bundled `webpki_roots::TLS_SERVER_ROOTS`.
///
/// NOTE(review): with neither feature enabled the root store stays empty, and
/// when native certs are returned but none parse, no error is raised here.
/// Presumably server verification then fails at handshake time — confirm
/// whether these configurations are intended.
fn config_with_certs() -> Result<Arc<ClientConfig>, Error> {
// `mut` is only needed when at least one certificate-source feature is on.
#[allow(unused_mut)]
let mut root_store = RootCertStore::empty();
#[cfg(feature = "rustls-native-certs")]
{
let rustls_native_certs::CertificateResult { certs, errors, .. } =
rustls_native_certs::load_native_certs();
if !errors.is_empty() {
compio_log::warn!("native root CA certificate loading errors: {errors:?}");
}
// Without the webpki fallback an empty native store is a hard error.
#[cfg(not(feature = "webpki-roots"))]
if certs.is_empty() {
return Err(std::io::Error::new(
std::io::ErrorKind::NotFound,
format!("no native root CA certificates found (errors: {errors:?})"),
)
.into());
}
let total_number = certs.len();
// Unparsable certificates are skipped and only counted, not rejected.
let (number_added, number_ignored) = root_store.add_parsable_certificates(certs);
compio_log::debug!(
"Added {number_added}/{total_number} native root certificates (ignored \
{number_ignored})"
);
}
#[cfg(feature = "webpki-roots")]
{
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
}
Ok(Arc::new(
ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth(),
))
}
/// Builds a rustls `ClientConfig` that verifies server certificates through
/// `rustls_platform_verifier`, with no client authentication.
///
/// # Errors
///
/// Returns the verifier setup error, converted through
/// `tungstenite::error::TlsError`.
#[cfg(feature = "rustls-platform-verifier")]
fn config_with_platform_verifier() -> Result<Arc<ClientConfig>, Error> {
use rustls_platform_verifier::BuilderVerifierExt;
let config_result = ClientConfig::builder()
.with_platform_verifier()
.map_err(tungstenite::error::TlsError::from)?;
Ok(Arc::new(config_result.with_no_client_auth()))
}
/// Builds a rustls-backed `TlsConnector`.
///
/// With `rustls-platform-verifier`, the platform verifier is tried first. If
/// it fails, a warning is logged and `config_with_certs` is used instead.
/// Without that feature, `config_with_certs` is used directly.
///
/// # Errors
///
/// Returns any error from `config_with_certs`.
pub fn new_connector() -> Result<TlsConnector, Error> {
#[cfg(feature = "rustls-platform-verifier")]
{
let config = match config_with_platform_verifier() {
Ok(config_builder) => config_builder,
// Underscore prefix: presumably `_e` is unused when compio_log
// compiles its macros away — confirm.
Err(_e) => {
compio_log::warn!("Error creating platform verifier: {_e}");
config_with_certs()?
}
};
Ok(TlsConnector::from(config))
}
#[cfg(not(feature = "rustls-platform-verifier"))]
{
let config = config_with_certs()?;
Ok(TlsConnector::from(config))
}
}
}
}
async fn wrap_stream<S>(
socket: S,
domain: &str,
connector: Option<TlsConnector>,
mode: Mode,
) -> Result<MaybeTlsStream<S>, Error>
where
S: AsyncRead + AsyncWrite + 'static,
{
match mode {
Mode::Plain => Ok(MaybeTlsStream::new_plain(socket)),
Mode::Tls => {
let stream = {
let connector = if let Some(connector) = connector {
connector
} else {
#[cfg(feature = "native-tls")]
{
match encryption::native_tls::new_connector() {
Ok(c) => c,
Err(_e) => {
compio_log::warn!(
"Falling back to rustls TLS connector due to native-tls \
error: {}",
_e
);
#[cfg(feature = "rustls")]
{
encryption::rustls::new_connector()?
}
#[cfg(not(feature = "rustls"))]
{
return Err(_e);
}
}
}
}
#[cfg(all(feature = "rustls", not(feature = "native-tls")))]
{
encryption::rustls::new_connector()?
}
#[cfg(not(any(feature = "native-tls", feature = "rustls")))]
{
return Err(Error::Url(
tungstenite::error::UrlError::TlsFeatureNotEnabled,
));
}
};
connector.connect(domain, socket).await.map_err(Error::Io)?
};
Ok(MaybeTlsStream::new_tls(stream))
}
}
}
pub async fn client_async_tls<R, S>(
request: R,
stream: S,
) -> Result<(WebSocketStream<MaybeTlsStream<S>>, Response), Error>
where
R: IntoClientRequest,
S: AsyncRead + AsyncWrite + Unpin + 'static,
{
client_async_tls_with_config(request, stream, None, None).await
}
/// Like [`client_async_tls`], but with an explicit TLS connector and
/// WebSocket configuration.
///
/// The request's host (with IPv6 brackets removed) is used as the TLS server
/// name. If TLS is required and `connector` is `None`, a default connector for
/// the enabled TLS backend is created.
///
/// # Errors
///
/// Fails if the request cannot be converted, has no host, or uses an
/// unsupported scheme, or if the TLS or WebSocket handshake fails.
pub async fn client_async_tls_with_config<R, S>(
    request: R,
    stream: S,
    connector: Option<TlsConnector>,
    config: impl Into<Config>,
) -> Result<(WebSocketStream<MaybeTlsStream<S>>, Response), Error>
where
    R: IntoClientRequest,
    S: AsyncRead + AsyncWrite + Unpin + 'static,
{
    let request = request.into_client_request()?;
    // Evaluated left to right, so a missing host is reported before a bad scheme.
    let (server_name, mode) = (domain(&request)?, uri_mode(request.uri())?);
    let wrapped = wrap_stream(stream, server_name, connector, mode).await?;
    client_async_with_config(request, wrapped, config).await
}
/// Stream type produced by the `connect_async*` functions: a TCP stream that
/// is either plain or wrapped in TLS.
type ConnectStream = MaybeTlsStream<TcpStream>;
pub async fn connect_async<R>(
request: R,
) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
where
R: IntoClientRequest,
{
connect_async_with_config(request, None).await
}
pub async fn connect_async_with_config<R>(
request: R,
config: impl Into<Config>,
) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
where
R: IntoClientRequest,
{
connect_async_tls_with_config(request, config, None).await
}
pub async fn connect_async_tls_with_config<R>(
request: R,
config: impl Into<Config>,
connector: Option<TlsConnector>,
) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
where
R: IntoClientRequest,
{
let config = config.into();
let request: Request = request.into_client_request()?;
let domain = request
.uri()
.host()
.ok_or(Error::Url(tungstenite::error::UrlError::NoHostName))?;
let port = port(&request)?;
let opts = SocketOpts::new().nodelay(config.disable_nagle);
let socket = TcpStream::connect_with_options((domain, port), &opts)
.await
.map_err(Error::Io)?;
client_async_tls_with_config(request, socket, connector, config).await
}
#[inline]
fn port(request: &Request) -> Result<u16, Error> {
request
.uri()
.port_u16()
.or_else(|| match uri_mode(request.uri()).ok()? {
Mode::Plain => Some(80),
Mode::Tls => Some(443),
})
.ok_or(Error::Url(
tungstenite::error::UrlError::UnsupportedUrlScheme,
))
}
#[inline]
fn domain(request: &Request) -> Result<&str, Error> {
request
.uri()
.host()
.map(|host| {
if host.starts_with('[') && host.ends_with(']') {
&host[1..host.len() - 1]
} else {
host
}
})
.ok_or(tungstenite::Error::Url(
tungstenite::error::UrlError::NoHostName,
))
}