/// Backend compiled when none of the `native`, `rustls-native-roots` or
/// `rustls-webpki-roots` features are enabled.
#[cfg(not(any(
    feature = "native",
    feature = "rustls-native-roots",
    feature = "rustls-webpki-roots"
)))]
mod r#impl {
    use super::{TlsContainer, TlsError};
    use crate::{
        connection::Connection,
        error::{ReceiveMessageError, ReceiveMessageErrorType},
    };
    use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, Connector};

    /// Stand-in connector type; this backend never stores a real connector.
    pub type TlsConnector = ();

    /// Creates a container that holds no connector.
    ///
    /// This backend never returns an error; the `Result` only mirrors the
    /// signature of the other backends.
    pub fn new() -> Result<TlsContainer, TlsError> {
        let container = TlsContainer { tls: None };

        Ok(container)
    }

    /// Opens a websocket connection to `url` with the given `config`.
    ///
    /// The container argument is ignored because there is no connector to
    /// hand to tokio-tungstenite. The handshake response that accompanies the
    /// stream is discarded.
    ///
    /// # Errors
    ///
    /// Any connection failure is wrapped in a [`ReceiveMessageError`] of kind
    /// [`ReceiveMessageErrorType::Reconnect`].
    pub async fn connect(
        url: &str,
        config: WebSocketConfig,
        _tls: &TlsContainer,
    ) -> Result<Connection, ReceiveMessageError> {
        match tokio_tungstenite::connect_async_with_config(url, Some(config)).await {
            Ok((stream, _response)) => Ok(stream),
            Err(source) => Err(ReceiveMessageError {
                kind: ReceiveMessageErrorType::Reconnect,
                source: Some(Box::new(source)),
            }),
        }
    }

    /// Always returns `None`: this backend has no connector to offer.
    pub fn connector(_container: &TlsContainer) -> Option<Connector> {
        None
    }
}
/// Backend compiled when `native` is enabled and neither rustls feature is.
#[cfg(all(
    feature = "native",
    not(any(feature = "rustls-native-roots", feature = "rustls-webpki-roots"))
))]
mod r#impl {
    pub use native_tls::TlsConnector;

    use super::{TlsContainer, TlsError, TlsErrorType};
    use crate::{
        connection::Connection,
        error::{ReceiveMessageError, ReceiveMessageErrorType},
    };
    use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, Connector};

    /// Builds a container around a freshly created `native_tls::TlsConnector`.
    ///
    /// # Errors
    ///
    /// Returns a [`TlsError`] of kind [`TlsErrorType::Loading`] when
    /// `TlsConnector::new` fails; the underlying error is kept as the source.
    pub fn new() -> Result<TlsContainer, TlsError> {
        match TlsConnector::new() {
            Ok(native_connector) => Ok(TlsContainer {
                tls: Some(native_connector),
            }),
            Err(err) => Err(TlsError {
                kind: TlsErrorType::Loading,
                source: Some(Box::new(err)),
            }),
        }
    }

    /// Opens a websocket connection to `url`, passing the container's
    /// connector (if any) to tokio-tungstenite.
    ///
    /// The handshake response that accompanies the stream is discarded.
    ///
    /// # Errors
    ///
    /// Any connection failure is wrapped in a [`ReceiveMessageError`] of kind
    /// [`ReceiveMessageErrorType::Reconnect`].
    pub async fn connect(
        url: &str,
        config: WebSocketConfig,
        tls: &TlsContainer,
    ) -> Result<Connection, ReceiveMessageError> {
        let tls_connector = tls.connector();
        let attempt =
            tokio_tungstenite::connect_async_tls_with_config(url, Some(config), tls_connector)
                .await;

        match attempt {
            Ok((stream, _response)) => Ok(stream),
            Err(source) => Err(ReceiveMessageError {
                kind: ReceiveMessageErrorType::Reconnect,
                source: Some(Box::new(source)),
            }),
        }
    }

    /// Returns a clone of the stored connector wrapped as
    /// `Connector::NativeTls`, or `None` when the container holds none.
    pub fn connector(container: &TlsContainer) -> Option<Connector> {
        let native = container.tls.as_ref()?;

        Some(Connector::NativeTls(native.clone()))
    }
}
/// Backend compiled when `rustls-native-roots` and/or `rustls-webpki-roots`
/// is enabled.
#[cfg(any(feature = "rustls-native-roots", feature = "rustls-webpki-roots"))]
mod r#impl {
    // `TlsErrorType` is referenced through `super::` below rather than
    // imported here, because it is only used when `rustls-native-roots` is on.
    use super::{TlsContainer, TlsError};
    use crate::{
        connection::Connection,
        error::{ReceiveMessageError, ReceiveMessageErrorType},
    };
    use rustls_tls::ClientConfig;
    use std::sync::Arc;
    use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, Connector};

    /// Shared rustls client configuration used as the connector.
    pub type TlsConnector = Arc<ClientConfig>;

    /// Builds a rustls client configuration from the enabled root sources.
    ///
    /// With `rustls-native-roots` the platform certificates are loaded and
    /// added one by one; with `rustls-webpki-roots` the bundled webpki trust
    /// anchors are added. When both features are enabled, both sets end up in
    /// the same root store.
    ///
    /// # Errors
    ///
    /// With `rustls-native-roots`, returns a [`TlsError`] of kind `Loading`
    /// when the platform certificates cannot be loaded or when any one of
    /// them is rejected by the root store.
    pub fn new() -> Result<TlsContainer, TlsError> {
        let mut roots = rustls_tls::RootCertStore::empty();

        #[cfg(feature = "rustls-native-roots")]
        {
            let certs = rustls_native_certs::load_native_certs().map_err(|err| TlsError {
                kind: super::TlsErrorType::Loading,
                source: Some(Box::new(err)),
            })?;

            // NOTE(review): a single certificate that the store rejects aborts
            // construction. If that proves too strict for real-world platform
            // stores, consider a lenient bulk-add API — confirm what the
            // rustls version in use offers.
            for cert in certs {
                roots
                    .add(&rustls_tls::Certificate(cert.0))
                    .map_err(|err| TlsError {
                        kind: super::TlsErrorType::Loading,
                        source: Some(Box::new(err)),
                    })?;
            }
        }

        #[cfg(feature = "rustls-webpki-roots")]
        {
            roots.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
                rustls_tls::OwnedTrustAnchor::from_subject_spki_name_constraints(
                    ta.subject,
                    ta.spki,
                    ta.name_constraints,
                )
            }));
        }

        let config = ClientConfig::builder()
            .with_safe_defaults()
            .with_root_certificates(roots)
            .with_no_client_auth();

        Ok(TlsContainer {
            tls: Some(Arc::new(config)),
        })
    }

    /// Opens a websocket connection to `url`, passing the container's
    /// connector (if any) to tokio-tungstenite.
    ///
    /// The handshake response that accompanies the stream is discarded.
    ///
    /// # Errors
    ///
    /// Any connection failure is wrapped in a [`ReceiveMessageError`] of kind
    /// [`ReceiveMessageErrorType::Reconnect`].
    pub async fn connect(
        url: &str,
        config: WebSocketConfig,
        tls: &TlsContainer,
    ) -> Result<Connection, ReceiveMessageError> {
        let (stream, _) =
            tokio_tungstenite::connect_async_tls_with_config(url, Some(config), tls.connector())
                .await
                .map_err(|source| ReceiveMessageError {
                    kind: ReceiveMessageErrorType::Reconnect,
                    source: Some(Box::new(source)),
                })?;

        Ok(stream)
    }

    /// Returns the stored configuration wrapped as `Connector::Rustls`, or
    /// `None` when the container holds none. Only the `Arc` is cloned.
    pub fn connector(container: &TlsContainer) -> Option<Connector> {
        container
            .tls
            .as_ref()
            .map(|tls| Connector::Rustls(Arc::clone(tls)))
    }
}
use r#impl::TlsConnector;
use std::{
error::Error,
fmt::{Debug, Display, Formatter, Result as FmtResult},
};
use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, Connector};
use crate::{connection::Connection, error::ReceiveMessageError};
/// Error produced while setting up the TLS connector.
#[derive(Debug)]
pub struct TlsError {
    /// Category of the failure.
    kind: TlsErrorType,
    /// Underlying error, when one is available.
    source: Option<Box<dyn Error + Send + Sync>>,
}

#[allow(dead_code)]
impl TlsError {
    /// Borrows the category of this error.
    #[must_use = "retrieving the type has no effect if left unused"]
    pub const fn kind(&self) -> &TlsErrorType {
        &self.kind
    }

    /// Consumes the error and yields the underlying source, if any.
    #[must_use = "consuming the error and retrieving the source has no effect if left unused"]
    pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
        let Self { source, .. } = self;

        source
    }

    /// Consumes the error and yields both its category and its source.
    #[must_use = "consuming the error into its parts has no effect if left unused"]
    pub fn into_parts(self) -> (TlsErrorType, Option<Box<dyn Error + Send + Sync>>) {
        let Self { kind, source } = self;

        (kind, source)
    }
}

impl Display for TlsError {
    /// Writes a fixed, human-readable message chosen by the error's kind.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        let message = match self.kind {
            TlsErrorType::Loading => "failed to load the tls connector or its certificates",
        };

        f.write_str(message)
    }
}

impl Error for TlsError {
    /// Exposes the stored source error, if there is one.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        self.source
            .as_deref()
            .map(|inner| inner as &(dyn Error + 'static))
    }
}

/// Category of a [`TlsError`].
#[derive(Debug)]
#[non_exhaustive]
pub enum TlsErrorType {
    /// The TLS connector or its certificates could not be loaded.
    // Not constructed by the backend built without any TLS feature.
    #[allow(unused)]
    Loading,
}
/// Holder for the connector produced by whichever `r#impl` backend is
/// compiled in.
#[derive(Clone)]
pub struct TlsContainer {
// The backend built without any TLS feature only ever stores `None` here and
// never reads it back, hence the lint allowance.
#[allow(unused)]
tls: Option<TlsConnector>,
}
/// Hand-written `Debug` so the connector is only printed for the backend
/// selected by `native` without any rustls feature.
impl Debug for TlsContainer {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        let mut builder = f.debug_struct("TlsContainer");

        // NOTE(review): presumably the other connector types are skipped
        // because they cannot be formatted usefully — confirm.
        #[cfg(all(
            feature = "native",
            not(any(feature = "rustls-native-roots", feature = "rustls-webpki-roots")),
        ))]
        {
            builder.field("tls", &self.tls);
        }

        builder.finish()
    }
}
impl TlsContainer {
/// Creates a container using whichever `r#impl` backend is compiled in.
///
/// # Errors
///
/// Propagates the [`TlsError`] returned by the backend's `new`.
pub fn new() -> Result<Self, TlsError> {
r#impl::new()
}
/// Opens a websocket connection to `url` with the given `config`,
/// delegating to the compiled-in backend's `connect`.
///
/// # Errors
///
/// Propagates the [`ReceiveMessageError`] returned by the backend.
pub async fn connect(
&self,
url: &str,
config: WebSocketConfig,
) -> Result<Connection, ReceiveMessageError> {
r#impl::connect(url, config, self).await
}
/// Returns the tokio-tungstenite connector built by the backend, or `None`
/// when the backend provides none.
// Within this file, the backend built without any TLS feature never calls
// this method, hence the lint allowance.
#[allow(unused)]
pub(crate) fn connector(&self) -> Option<Connector> {
r#impl::connector(self)
}
}
#[cfg(test)]
mod tests {
use super::TlsContainer;
use static_assertions::assert_impl_all;
use std::fmt::Debug;
// Static assertion that `TlsContainer` implements every listed trait under
// the feature set being tested.
assert_impl_all!(TlsContainer: Debug, Clone, Send, Sync);
}