//! Crate documentation for `lirays` 0.1.2.
//!
//! Rust client for LiRAYS SCADA over WebSocket + Protobuf.

use std::sync::Arc;

use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
use rustls::{DigitallySignedStruct, Error as RustlsError, SignatureScheme};
use tokio_tungstenite::{
    Connector,
    tungstenite::{
        client::IntoClientRequest,
        http::{Request, header::AUTHORIZATION},
    },
};
use url::Url;

use crate::types::errors::ClientError;

/// Connection settings used to build the websocket request.
///
/// The fields are combined into a `ws(s)://host:port/ws` URL by
/// [`ConnectionOptions::ws_url`]; the optional PAT token is attached as an
/// `Authorization` header when the upgrade request is built.
///
/// # Example
/// ```rust
/// use lirays::ConnectionOptions;
///
/// let opts = ConnectionOptions::new("127.0.0.1", 8245, false, None);
/// assert_eq!(opts.ws_url().unwrap(), "ws://127.0.0.1:8245/ws");
/// ```
#[derive(Clone, Debug)]
pub struct ConnectionOptions {
    /// Server hostname or IP address.
    ///
    /// Inserted into the URL as the host component.
    pub host: String,
    /// Server TCP port.
    ///
    /// Inserted into the URL as the port component; a value that does not
    /// yield a parseable URL makes [`ConnectionOptions::ws_url`] return an
    /// error.
    pub port: i64,
    /// Enables `wss://` when `true`, otherwise uses `ws://`.
    pub tls: bool,
    /// Skips certificate validation when `true`.
    ///
    /// Only has an effect together with `tls`. Defaults to `false` in
    /// [`ConnectionOptions::new`].
    ///
    /// This is intended only for local development with self-signed certs.
    pub allow_insecure_tls: bool,
    /// Optional PAT token sent as `Authorization: Bearer <token>`.
    ///
    /// When `None`, no `Authorization` header is added to the request.
    pub pat_token: Option<String>,
}

impl ConnectionOptions {
    /// Builds connection options from a host, port, TLS flag and optional PAT token.
    ///
    /// Certificate verification starts out enabled (`allow_insecure_tls` is
    /// `false`); use [`Self::with_insecure_tls`] to change that.
    pub fn new(host: impl Into<String>, port: i64, tls: bool, pat_token: Option<String>) -> Self {
        let host = host.into();
        Self {
            host,
            port,
            tls,
            allow_insecure_tls: false,
            pat_token,
        }
    }

    /// Returns these options with the insecure-TLS flag set to `allow_insecure_tls`.
    ///
    /// Passing `true` requests that certificate verification be bypassed;
    /// passing `false` restores normal verification.
    pub fn with_insecure_tls(self, allow_insecure_tls: bool) -> Self {
        Self {
            allow_insecure_tls,
            ..self
        }
    }

    /// Produces the websocket URL, shaped as `ws(s)://host:port/ws`.
    ///
    /// # Errors
    /// Returns [`ClientError`] when host and port do not form a valid URL.
    pub fn ws_url(&self) -> Result<String, ClientError> {
        let Self {
            host, port, tls, ..
        } = self;
        build_ws_url(host, *port, *tls)
    }
}

/// Builds an HTTP upgrade request for websocket connection.
///
/// If `pat_token` is present, an `Authorization` header is attached.
pub(crate) fn build_ws_request(options: &ConnectionOptions) -> Result<Request<()>, ClientError> {
    if options.tls {
        ensure_rustls_crypto_provider();
    }

    let url = options.ws_url()?;
    let mut request = url
        .into_client_request()
        .map_err(|_| ClientError::InvalidInput("invalid host/port"))?;

    if let Some(token) = options.pat_token.as_deref() {
        let value = format!("Bearer {token}");
        let header_value = value
            .parse()
            .map_err(|_| ClientError::InvalidInput("invalid PAT token format"))?;
        request.headers_mut().insert(AUTHORIZATION, header_value);
    }

    Ok(request)
}

/// Returns a custom TLS connector when the options ask for one.
///
/// A connector is produced only when both `tls` and `allow_insecure_tls` are
/// set; it wraps a rustls client configuration whose certificate verifier is
/// [`InsecureServerCertVerifier`] and which uses no client authentication.
/// In every other case `None` is returned.
pub(crate) fn build_tls_connector(options: &ConnectionOptions) -> Option<Connector> {
    let bypass_verification = options.tls && options.allow_insecure_tls;
    if !bypass_verification {
        return None;
    }

    // The config builder needs a crypto provider to be installed first.
    ensure_rustls_crypto_provider();

    let verifier = Arc::new(InsecureServerCertVerifier);
    let tls_config = rustls::ClientConfig::builder()
        .dangerous()
        .with_custom_certificate_verifier(verifier)
        .with_no_client_auth();

    Some(Connector::Rustls(Arc::new(tls_config)))
}

/// Composes and validates websocket URL from host/port/tls parts.
///
/// The result has the form `ws://host:port/ws` (or `wss://` when `tls` is
/// `true`), as normalized by the URL parser. A bare IPv6 literal such as
/// `::1` is wrapped in brackets so it can be used as a host.
///
/// # Errors
/// Returns `ClientError::InvalidInput("invalid host/port")` when:
/// - `port` is outside `0..=65535`,
/// - the composed string does not parse as a URL, or
/// - `host` contains URL delimiters (`/`, `?`, `#`, `@`, ...) that would make
///   the parsed URL carry a different port, path, query, fragment or
///   credentials than the ones requested.
fn build_ws_url(host: &str, port: i64, tls: bool) -> Result<String, ClientError> {
    let invalid = || ClientError::InvalidInput("invalid host/port");

    if !(0..=i64::from(u16::MAX)).contains(&port) {
        return Err(invalid());
    }

    let scheme = if tls { "wss" } else { "ws" };
    // IPv6 literals must be bracketed inside a URL authority.
    let base = if host.parse::<std::net::Ipv6Addr>().is_ok() {
        format!("{scheme}://[{host}]:{port}/ws")
    } else {
        format!("{scheme}://{host}:{port}/ws")
    };
    let url = Url::parse(&base).map_err(|_| invalid())?;

    // `host` is interpolated verbatim, so a value like `example.com/x` or
    // `user@example.com` still parses but no longer describes the requested
    // endpoint. Confirm the parsed URL is exactly `scheme://<host>:<port>/ws`.
    let targets_requested_endpoint = url.path() == "/ws"
        && url.query().is_none()
        && url.fragment().is_none()
        && url.username().is_empty()
        && url.password().is_none()
        && url.port_or_known_default().map(i64::from) == Some(port);
    if !targets_requested_endpoint {
        return Err(invalid());
    }

    Ok(url.into())
}

/// Installs the `ring` crypto provider as the rustls process default when no
/// default has been set yet.
///
/// rustls 0.23 requires a provider before any TLS operation.
fn ensure_rustls_crypto_provider() {
    if rustls::crypto::CryptoProvider::get_default().is_some() {
        return;
    }

    // The install result is deliberately ignored.
    let _ = rustls::crypto::ring::default_provider().install_default();
}

/// Certificate verifier that accepts every server certificate and handshake
/// signature without checking them.
///
/// Used only for the insecure-TLS connector built by `build_tls_connector`.
#[derive(Debug)]
struct InsecureServerCertVerifier;

impl ServerCertVerifier for InsecureServerCertVerifier {
    /// Accepts the presented certificate chain unconditionally; all arguments
    /// are ignored.
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &ServerName<'_>,
        _ocsp_response: &[u8],
        _now: UnixTime,
    ) -> Result<ServerCertVerified, RustlsError> {
        Ok(ServerCertVerified::assertion())
    }

    /// Accepts any TLS 1.2 handshake signature without inspecting it.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, RustlsError> {
        Ok(HandshakeSignatureValid::assertion())
    }

    /// Accepts any TLS 1.3 handshake signature without inspecting it.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, RustlsError> {
        Ok(HandshakeSignatureValid::assertion())
    }

    /// Reports the signature schemes of the installed default crypto provider,
    /// falling back to those of the `ring` provider when none is installed.
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        match rustls::crypto::CryptoProvider::get_default() {
            Some(installed) => installed
                .signature_verification_algorithms
                .supported_schemes(),
            None => rustls::crypto::ring::default_provider()
                .signature_verification_algorithms
                .supported_schemes(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Host shared by every test case.
    const HOST: &str = "127.0.0.1";
    /// Port shared by every test case.
    const PORT: i64 = 8245;

    /// Options for the test endpoint with the given TLS flag and no PAT token.
    fn anonymous_options(tls: bool) -> ConnectionOptions {
        ConnectionOptions::new(HOST, PORT, tls, None)
    }

    #[test]
    fn request_includes_authorization_header_when_pat_is_set() {
        let pat = Some("pat_test.token".to_string());
        let options = ConnectionOptions::new(HOST, PORT, false, pat);
        let request = build_ws_request(&options).expect("request should build");

        let header = request
            .headers()
            .get(AUTHORIZATION)
            .expect("authorization header missing");
        let auth = header.to_str().expect("header should be valid utf8");

        assert_eq!(auth, "Bearer pat_test.token");
    }

    #[test]
    fn request_omits_authorization_header_when_pat_is_not_set() {
        let request = build_ws_request(&anonymous_options(false)).expect("request should build");
        assert!(request.headers().get(AUTHORIZATION).is_none());
    }

    #[test]
    fn tls_request_installs_crypto_provider() {
        let _ = build_ws_request(&anonymous_options(true)).expect("tls request should build");
        assert!(rustls::crypto::CryptoProvider::get_default().is_some());
    }

    #[test]
    fn connection_options_default_to_secure_tls_verification() {
        assert!(!anonymous_options(true).allow_insecure_tls);
    }

    #[test]
    fn build_tls_connector_returns_none_by_default() {
        let connector = build_tls_connector(&anonymous_options(true));
        assert!(connector.is_none());
    }

    #[test]
    fn build_tls_connector_returns_rustls_when_insecure_tls_is_enabled() {
        let options = anonymous_options(true).with_insecure_tls(true);
        let connector = build_tls_connector(&options);
        assert!(matches!(connector, Some(Connector::Rustls(_))));
    }
}