use async_trait::async_trait;
use futures_util::{Sink, Stream};
use std::fmt::Debug;
use tokio::net::TcpStream;
use tokio_tungstenite::{
client_async_tls_with_config, Connector as TlsConnector, MaybeTlsStream, WebSocketStream,
};
use tungstenite::client::IntoClientRequest;
use tungstenite::handshake::client::Response;
use tungstenite::protocol::WebSocketConfig;
use tungstenite::Message;
use crate::error::ConnectError;
use crate::transport::Transport;
pub type DefaultWsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
pub trait WsStream:
Stream<Item = Result<Message, tungstenite::Error>>
+ Sink<Message, Error = tungstenite::Error>
+ Unpin
+ Send
+ 'static
{
}
impl<T> WsStream for T where
T: Stream<Item = Result<Message, tungstenite::Error>>
+ Sink<Message, Error = tungstenite::Error>
+ Unpin
+ Send
+ 'static
{
}
/// Strategy for opening a WebSocket connection.
///
/// Implementors must be `Send + Sync + 'static`. On success they yield a
/// [`WsStream`] together with the server's handshake [`Response`].
#[async_trait]
pub trait Connector: 'static + Send + Sync {
    /// Stream type returned by [`Connector::connect`].
    type Stream: WsStream;

    /// Short identifier for this connector; `"connector"` unless overridden.
    fn name(&self) -> &'static str {
        "connector"
    }

    /// Connects to `uri`, returning the stream and the handshake response.
    ///
    /// # Errors
    ///
    /// Returns a [`ConnectError`] when the connection cannot be established.
    async fn connect(&self, uri: &str) -> Result<(Self::Stream, Response), ConnectError>;
}
/// The built-in [`Connector`]: opens a tokio TCP socket and performs the
/// WebSocket handshake on it, with optional protocol, socket and TLS settings.
#[derive(Clone)]
pub struct DefaultConnector {
    /// WebSocket protocol settings handed to the handshake (`None` = none given).
    ws_config: Option<WebSocketConfig>,
    /// When `true`, `TCP_NODELAY` is enabled on the socket before the handshake.
    disable_nagle: bool,
    /// TLS connector handed to the handshake (`None` = none given explicitly).
    tls_connector: Option<TlsConnector>,
}

impl Debug for DefaultConnector {
    /// Prints the configuration; for the TLS connector only its presence is
    /// shown, not its contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            ws_config,
            disable_nagle,
            tls_connector,
        } = self;
        let mut out = f.debug_struct("DefaultConnector");
        out.field("ws_config", ws_config);
        out.field("disable_nagle", disable_nagle);
        out.field("tls_connector", &tls_connector.is_some());
        out.finish()
    }
}
impl DefaultConnector {
    /// Creates a connector with every option unset: no WebSocket config,
    /// `TCP_NODELAY` left untouched, and no explicit TLS connector.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            tls_connector: None,
            disable_nagle: false,
            ws_config: None,
        }
    }

    /// Sets the WebSocket protocol configuration passed to the handshake.
    #[must_use]
    pub const fn with_ws_config(mut self, config: WebSocketConfig) -> Self {
        self.ws_config = Some(config);
        self
    }

    /// Chooses whether `TCP_NODELAY` is enabled on each new socket. When
    /// `false`, the socket option is not touched at all.
    #[must_use]
    pub const fn with_nodelay(mut self, nodelay: bool) -> Self {
        self.disable_nagle = nodelay;
        self
    }

    /// Uses `connector` for TLS, replacing any TLS connector set earlier.
    #[must_use]
    pub fn with_tls_connector(self, connector: TlsConnector) -> Self {
        Self {
            tls_connector: Some(connector),
            ..self
        }
    }

    /// Latency-oriented preset: enables `TCP_NODELAY` and installs a WebSocket
    /// config with a message-size limit of `64 << 20` and a frame-size limit
    /// of `16 << 20`. No TLS connector is set.
    #[must_use]
    pub fn low_latency() -> Self {
        let ws_config = WebSocketConfig::default()
            .max_message_size(Some(64 << 20))
            .max_frame_size(Some(16 << 20));
        Self::new().with_nodelay(true).with_ws_config(ws_config)
    }

    /// Installs a native-tls connector that skips both certificate and
    /// hostname validation.
    ///
    /// **Insecure.** Intended for talking to servers whose certificates cannot
    /// be validated (e.g. self-signed in test setups). Replaces any TLS
    /// connector configured earlier.
    ///
    /// # Errors
    ///
    /// Returns [`ConnectError::Tls`] if the native-tls connector cannot be built.
    #[cfg(all(feature = "native-tls", not(feature = "__rustls-tls")))]
    pub fn danger_accept_invalid_certs(mut self) -> Result<Self, ConnectError> {
        let mut builder = native_tls::TlsConnector::builder();
        builder
            .danger_accept_invalid_certs(true)
            .danger_accept_invalid_hostnames(true);
        let insecure = builder.build().map_err(|e| {
            ConnectError::Tls(format!("Failed to build insecure TLS connector: {e}"))
        })?;
        self.tls_connector = Some(TlsConnector::NativeTls(insecure));
        Ok(self)
    }

    /// Installs a rustls client config whose verifier accepts every server
    /// certificate and every TLS 1.2/1.3 handshake signature without checking
    /// them, and which advertises `http/1.1` via ALPN.
    ///
    /// **Insecure.** Replaces any TLS connector configured earlier.
    ///
    /// NOTE(review): `ClientConfig::builder()` relies on rustls's process-wide
    /// default crypto provider — confirm one is available in this build.
    ///
    /// # Errors
    ///
    /// This variant always returns `Ok`; the `Result` keeps the signature
    /// identical to the native-tls variant.
    #[cfg(feature = "__rustls-tls")]
    pub fn danger_accept_invalid_certs(mut self) -> Result<Self, ConnectError> {
        use rustls::client::danger::{
            HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier,
        };
        use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
        use rustls::{ClientConfig, DigitallySignedStruct, SignatureScheme};
        use std::sync::Arc;

        /// Verifier that unconditionally reports success for every check.
        #[derive(Debug)]
        struct AcceptAnyServerCert;

        impl ServerCertVerifier for AcceptAnyServerCert {
            fn verify_server_cert(
                &self,
                _end_entity: &CertificateDer<'_>,
                _intermediates: &[CertificateDer<'_>],
                _server_name: &ServerName<'_>,
                _ocsp_response: &[u8],
                _now: UnixTime,
            ) -> Result<ServerCertVerified, rustls::Error> {
                Ok(ServerCertVerified::assertion())
            }

            fn verify_tls12_signature(
                &self,
                _message: &[u8],
                _cert: &CertificateDer<'_>,
                _dss: &DigitallySignedStruct,
            ) -> Result<HandshakeSignatureValid, rustls::Error> {
                Ok(HandshakeSignatureValid::assertion())
            }

            fn verify_tls13_signature(
                &self,
                _message: &[u8],
                _cert: &CertificateDer<'_>,
                _dss: &DigitallySignedStruct,
            ) -> Result<HandshakeSignatureValid, rustls::Error> {
                Ok(HandshakeSignatureValid::assertion())
            }

            // Schemes advertised to the server; order is preserved as listed.
            fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
                vec![
                    SignatureScheme::RSA_PKCS1_SHA256,
                    SignatureScheme::RSA_PKCS1_SHA384,
                    SignatureScheme::RSA_PKCS1_SHA512,
                    SignatureScheme::ECDSA_NISTP256_SHA256,
                    SignatureScheme::ECDSA_NISTP384_SHA384,
                    SignatureScheme::ECDSA_NISTP521_SHA512,
                    SignatureScheme::RSA_PSS_SHA256,
                    SignatureScheme::RSA_PSS_SHA384,
                    SignatureScheme::RSA_PSS_SHA512,
                    SignatureScheme::ED25519,
                ]
            }
        }

        let mut client_config = ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(Arc::new(AcceptAnyServerCert))
            .with_no_client_auth();
        client_config.alpn_protocols = vec![b"http/1.1".to_vec()];
        self.tls_connector = Some(TlsConnector::Rustls(Arc::new(client_config)));
        Ok(self)
    }

    /// Installs a native-tls connector that adds `cert` as a trusted root
    /// certificate.
    ///
    /// The connector is built from a fresh builder, so it replaces any TLS
    /// connector configured earlier — including one created by
    /// [`Self::with_client_identity`] or `danger_accept_invalid_certs`.
    ///
    /// # Errors
    ///
    /// Returns [`ConnectError::Tls`] if the native-tls connector cannot be built.
    #[cfg(feature = "native-tls")]
    pub fn with_custom_ca_cert(
        mut self,
        cert: native_tls::Certificate,
    ) -> Result<Self, ConnectError> {
        let mut builder = native_tls::TlsConnector::builder();
        builder.add_root_certificate(cert);
        let tls = builder.build().map_err(|e| {
            ConnectError::Tls(format!("Failed to build TLS connector with custom CA: {e}"))
        })?;
        self.tls_connector = Some(TlsConnector::NativeTls(tls));
        Ok(self)
    }

    /// Installs a native-tls connector that presents `identity` as the client
    /// certificate.
    ///
    /// The connector is built from a fresh builder, so it replaces any TLS
    /// connector configured earlier — including one created by
    /// [`Self::with_custom_ca_cert`] or `danger_accept_invalid_certs`.
    ///
    /// # Errors
    ///
    /// Returns [`ConnectError::Tls`] if the native-tls connector cannot be built.
    #[cfg(feature = "native-tls")]
    pub fn with_client_identity(
        mut self,
        identity: native_tls::Identity,
    ) -> Result<Self, ConnectError> {
        let mut builder = native_tls::TlsConnector::builder();
        builder.identity(identity);
        let tls = builder.build().map_err(|e| {
            ConnectError::Tls(format!(
                "Failed to build TLS connector with client identity: {e}"
            ))
        })?;
        self.tls_connector = Some(TlsConnector::NativeTls(tls));
        Ok(self)
    }
}
impl Default for DefaultConnector {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Connector for DefaultConnector {
    type Stream = DefaultWsStream;

    /// Parses `uri` and opens a TCP connection to its host and port.
    ///
    /// The port is the explicit one, or 443 for `wss` and 80 otherwise.
    /// `TCP_NODELAY` is enabled if configured. The WebSocket handshake is then
    /// run through `client_async_tls_with_config` with this connector's
    /// WebSocket config and TLS connector.
    ///
    /// # Errors
    ///
    /// - [`ConnectError::InvalidUri`] if `uri` cannot be turned into a client
    ///   request or has no host.
    /// - [`ConnectError::TcpConnect`] if the TCP connection fails; the message
    ///   starts with the `host:port` that was attempted.
    /// - [`ConnectError::Io`] if `TCP_NODELAY` cannot be set.
    /// - [`ConnectError::WebSocketUpgrade`] if the handshake fails.
    async fn connect(&self, uri: &str) -> Result<(Self::Stream, Response), ConnectError> {
        let request = uri
            .into_client_request()
            .map_err(|e| ConnectError::InvalidUri(format!("Failed to parse URI '{uri}': {e}")))?;

        // Include the URI, consistent with the parse error above.
        let host = request
            .uri()
            .host()
            .ok_or_else(|| ConnectError::InvalidUri(format!("No host in URI '{uri}'")))?;
        // An explicit port wins; otherwise fall back to the scheme's default.
        let port = request
            .uri()
            .port_u16()
            .unwrap_or_else(|| match request.uri().scheme_str() {
                Some("wss") => 443,
                _ => 80,
            });
        let addr = format!("{host}:{port}");

        let socket = TcpStream::connect(&addr).await.map_err(|e| {
            tracing::debug!(addr = %addr, error = ?e, "TCP connection failed");
            // Keep the target address in the returned error so callers can tell
            // which endpoint was unreachable without enabling debug logging.
            ConnectError::TcpConnect(format!("{addr}: {e}"))
        })?;

        if self.disable_nagle {
            socket
                .set_nodelay(true)
                .map_err(|e| ConnectError::Io(format!("Failed to set TCP_NODELAY: {e}")))?;
        }

        let (ws_stream, response) = client_async_tls_with_config(
            request,
            socket,
            self.ws_config,
            self.tls_connector.clone(),
        )
        .await
        .map_err(|e| {
            tracing::debug!(uri = %uri, error = ?e, "WebSocket connection failed");
            ConnectError::WebSocketUpgrade(e.to_string())
        })?;

        tracing::debug!(uri = %uri, "WebSocket connection established");
        Ok((ws_stream, response))
    }

    /// Identifies this connector as `"default"`.
    fn name(&self) -> &'static str {
        "default"
    }
}
/// Connector configuration built around a caller-supplied [`Transport`].
///
/// Stores the transport and an optional WebSocket configuration. Both can be
/// read back through [`TransportConnector::transport`] and
/// [`TransportConnector::ws_config`].
#[derive(Debug, Clone)]
pub struct TransportConnector<T: Transport> {
    /// The wrapped transport.
    transport: T,
    /// Optional WebSocket protocol configuration; unset until
    /// [`TransportConnector::with_ws_config`] is called.
    ws_config: Option<WebSocketConfig>,
}

impl<T: Transport> TransportConnector<T> {
    /// Wraps `transport` with no WebSocket configuration.
    #[must_use]
    pub const fn new(transport: T) -> Self {
        Self {
            transport,
            ws_config: None,
        }
    }

    /// Sets the WebSocket protocol configuration, replacing any previous one.
    #[must_use]
    pub const fn with_ws_config(mut self, config: WebSocketConfig) -> Self {
        self.ws_config = Some(config);
        self
    }

    /// Returns a reference to the wrapped transport.
    #[must_use]
    pub const fn transport(&self) -> &T {
        &self.transport
    }

    /// Returns the configured WebSocket settings, if any.
    #[must_use]
    pub const fn ws_config(&self) -> Option<&WebSocketConfig> {
        self.ws_config.as_ref()
    }
}
/// Owned summary of an established connection's handshake.
#[derive(Debug, Clone)]
pub struct ConnectionInfo {
    /// The URI passed to [`ConnectionInfo::from_response`].
    pub uri: String,
    /// HTTP status code of the handshake response.
    pub response_status: u16,
    /// Handshake response headers as `(name, value)` pairs, one entry per
    /// item yielded by the header map's iterator.
    pub response_headers: Vec<(String, String)>,
}

impl ConnectionInfo {
    /// Captures `uri` plus the status and headers of the handshake `response`.
    ///
    /// Header values are decoded from their raw bytes as UTF-8, with invalid
    /// sequences replaced by U+FFFD. Values containing non-ASCII bytes are
    /// therefore kept, not silently turned into empty strings.
    #[must_use]
    pub fn from_response(uri: &str, response: &Response) -> Self {
        let response_headers = response
            .headers()
            .iter()
            .map(|(name, value)| {
                (
                    name.as_str().to_owned(),
                    String::from_utf8_lossy(value.as_bytes()).into_owned(),
                )
            })
            .collect();
        Self {
            uri: uri.to_owned(),
            response_status: response.status().as_u16(),
            response_headers,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built connector identifies itself as `"default"`.
    #[test]
    fn test_default_connector_creation() {
        assert_eq!(DefaultConnector::new().name(), "default");
    }

    /// The low-latency preset turns on `TCP_NODELAY` and carries a WebSocket config.
    #[test]
    fn test_low_latency_connector() {
        let DefaultConnector {
            disable_nagle,
            ws_config,
            ..
        } = DefaultConnector::low_latency();
        assert!(disable_nagle);
        assert!(ws_config.is_some());
    }
}