use std::sync::Arc;
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
use rustls::{DigitallySignedStruct, Error as RustlsError, SignatureScheme};
use tokio_tungstenite::{
Connector,
tungstenite::{
client::IntoClientRequest,
http::{Request, header::AUTHORIZATION},
},
};
use url::Url;
use crate::types::errors::ClientError;
/// Settings used to build the WebSocket handshake request and, optionally,
/// a TLS connector for it.
///
/// `Debug` is implemented by hand rather than derived so that the personal
/// access token never ends up in logs or panic messages.
#[derive(Clone)]
pub struct ConnectionOptions {
    /// Host name or address placed in the authority part of the WebSocket URL.
    pub host: String,
    /// Port placed in the authority part of the WebSocket URL.
    pub port: i64,
    /// Selects the `wss` scheme when `true` and `ws` otherwise.
    pub tls: bool,
    /// When `true` (together with `tls`), `build_tls_connector` produces a
    /// connector whose verifier accepts any server certificate.
    pub allow_insecure_tls: bool,
    /// Optional personal access token sent as `Authorization: Bearer <token>`.
    pub pat_token: Option<String>,
}

impl std::fmt::Debug for ConnectionOptions {
    /// Formats every field like the derived implementation would, except that
    /// a present `pat_token` is rendered as `Some("<redacted>")`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the Some/None distinction visible without revealing the secret.
        let redacted_token = self.pat_token.as_ref().map(|_| "<redacted>");
        f.debug_struct("ConnectionOptions")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("tls", &self.tls)
            .field("allow_insecure_tls", &self.allow_insecure_tls)
            .field("pat_token", &redacted_token)
            .finish()
    }
}
impl ConnectionOptions {
    /// Creates options for the given endpoint.
    ///
    /// `allow_insecure_tls` starts out as `false`; use
    /// [`ConnectionOptions::with_insecure_tls`] to change it.
    pub fn new(host: impl Into<String>, port: i64, tls: bool, pat_token: Option<String>) -> Self {
        let host = host.into();
        Self {
            host,
            port,
            tls,
            pat_token,
            allow_insecure_tls: false,
        }
    }

    /// Consumes the options and returns them with `allow_insecure_tls` set to
    /// the given flag; all other fields are carried over unchanged.
    pub fn with_insecure_tls(self, allow_insecure_tls: bool) -> Self {
        Self {
            allow_insecure_tls,
            ..self
        }
    }

    /// Returns the WebSocket URL derived from `host`, `port` and `tls`.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `build_ws_url`.
    pub fn ws_url(&self) -> Result<String, ClientError> {
        let Self {
            host, port, tls, ..
        } = self;
        build_ws_url(host, *port, *tls)
    }
}
pub(crate) fn build_ws_request(options: &ConnectionOptions) -> Result<Request<()>, ClientError> {
if options.tls {
ensure_rustls_crypto_provider();
}
let url = options.ws_url()?;
let mut request = url
.into_client_request()
.map_err(|_| ClientError::InvalidInput("invalid host/port"))?;
if let Some(token) = options.pat_token.as_deref() {
let value = format!("Bearer {token}");
let header_value = value
.parse()
.map_err(|_| ClientError::InvalidInput("invalid PAT token format"))?;
request.headers_mut().insert(AUTHORIZATION, header_value);
}
Ok(request)
}
/// Returns a custom TLS connector only when insecure TLS was requested.
///
/// If both `options.tls` and `options.allow_insecure_tls` are set, the result
/// is a rustls connector whose certificate verifier is
/// `InsecureServerCertVerifier`. In every other case `None` is returned.
pub(crate) fn build_tls_connector(options: &ConnectionOptions) -> Option<Connector> {
    let wants_insecure_tls = options.tls && options.allow_insecure_tls;
    if !wants_insecure_tls {
        return None;
    }

    // Ensure a crypto provider is installed before the rustls config is built.
    ensure_rustls_crypto_provider();

    let verifier = Arc::new(InsecureServerCertVerifier);
    let client_config = rustls::ClientConfig::builder()
        .dangerous()
        .with_custom_certificate_verifier(verifier)
        .with_no_client_auth();

    Some(Connector::Rustls(Arc::new(client_config)))
}
/// Builds the `/ws` endpoint URL for `host` and `port`.
///
/// The scheme is `wss` when `tls` is `true` and `ws` otherwise. A `host` that
/// is a bare IPv6 literal (for example `::1`) is wrapped in square brackets,
/// as required inside a URL authority; any other host is used as given.
///
/// # Errors
///
/// Returns `ClientError::InvalidInput` when the assembled text does not parse
/// as a URL.
fn build_ws_url(host: &str, port: i64, tls: bool) -> Result<String, ClientError> {
    let scheme = if tls { "wss" } else { "ws" };

    // Without brackets the colons of an IPv6 literal would be read as the
    // host/port separator and the URL would fail to parse.
    let base = if host.parse::<std::net::Ipv6Addr>().is_ok() {
        format!("{scheme}://[{host}]:{port}/ws")
    } else {
        format!("{scheme}://{host}:{port}/ws")
    };

    Url::parse(&base)
        .map(String::from)
        .map_err(|_| ClientError::InvalidInput("invalid host/port"))
}
/// Installs the `ring` provider as the process default rustls crypto provider
/// unless a default is already present.
fn ensure_rustls_crypto_provider() {
    if rustls::crypto::CryptoProvider::get_default().is_some() {
        return;
    }

    // The outcome is deliberately discarded; presumably a failure only means
    // another caller installed a provider in the meantime (TODO confirm).
    let _ = rustls::crypto::ring::default_provider().install_default();
}
/// Certificate verifier that performs no verification at all.
///
/// Its `ServerCertVerifier` implementation accepts every server certificate
/// and every handshake signature. It is installed only by
/// `build_tls_connector`, and only when both `tls` and `allow_insecure_tls`
/// are set on the connection options.
#[derive(Debug)]
struct InsecureServerCertVerifier;
impl ServerCertVerifier for InsecureServerCertVerifier {
    /// Accepts the presented certificate chain unconditionally; none of the
    /// arguments are inspected.
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &ServerName<'_>,
        _ocsp_response: &[u8],
        _now: UnixTime,
    ) -> Result<ServerCertVerified, RustlsError> {
        Ok(ServerCertVerified::assertion())
    }

    /// Accepts any TLS 1.2 handshake signature without checking it.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, RustlsError> {
        Ok(HandshakeSignatureValid::assertion())
    }

    /// Accepts any TLS 1.3 handshake signature without checking it.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, RustlsError> {
        Ok(HandshakeSignatureValid::assertion())
    }

    /// Reports the signature schemes of the installed default crypto
    /// provider, or those of the `ring` provider when none is installed.
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        match rustls::crypto::CryptoProvider::get_default() {
            Some(installed) => installed
                .signature_verification_algorithms
                .supported_schemes(),
            None => rustls::crypto::ring::default_provider()
                .signature_verification_algorithms
                .supported_schemes(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A configured token must be sent as a bearer `Authorization` header.
    #[test]
    fn request_includes_authorization_header_when_pat_is_set() {
        let pat = Some(String::from("pat_test.token"));
        let opts = ConnectionOptions::new("127.0.0.1", 8245, false, pat);

        let req = build_ws_request(&opts).expect("request should build");
        let header = req
            .headers()
            .get(AUTHORIZATION)
            .expect("authorization header missing");
        let rendered = header.to_str().expect("header should be valid utf8");

        assert_eq!(rendered, "Bearer pat_test.token");
    }

    /// Without a token the request carries no `Authorization` header.
    #[test]
    fn request_omits_authorization_header_when_pat_is_not_set() {
        let opts = ConnectionOptions::new("127.0.0.1", 8245, false, None);

        let req = build_ws_request(&opts).expect("request should build");

        assert!(!req.headers().contains_key(AUTHORIZATION));
    }

    /// Building a TLS request leaves a default crypto provider installed.
    #[test]
    fn tls_request_installs_crypto_provider() {
        let opts = ConnectionOptions::new("127.0.0.1", 8245, true, None);

        let _request = build_ws_request(&opts).expect("tls request should build");

        let provider = rustls::crypto::CryptoProvider::get_default();
        assert!(provider.is_some());
    }

    /// Freshly created options keep certificate verification enabled.
    #[test]
    fn connection_options_default_to_secure_tls_verification() {
        let ConnectionOptions {
            allow_insecure_tls, ..
        } = ConnectionOptions::new("127.0.0.1", 8245, true, None);

        assert!(!allow_insecure_tls);
    }

    /// No custom connector is produced unless insecure TLS is requested.
    #[test]
    fn build_tls_connector_returns_none_by_default() {
        let opts = ConnectionOptions::new("127.0.0.1", 8245, true, None);

        let connector = build_tls_connector(&opts);

        assert!(connector.is_none());
    }

    /// Requesting insecure TLS yields a rustls-backed connector.
    #[test]
    fn build_tls_connector_returns_rustls_when_insecure_tls_is_enabled() {
        let opts = ConnectionOptions::new("127.0.0.1", 8245, true, None).with_insecure_tls(true);

        let connector = build_tls_connector(&opts);

        assert!(matches!(connector, Some(Connector::Rustls(_))));
    }
}