xidl-jsonrpc 0.28.0

An IDL codegen.
Documentation
use std::collections::HashMap;
#[cfg(any(feature = "tokio-tls", feature = "tokio-websocket"))]
use std::fs::File;
#[cfg(any(feature = "tokio-tls", feature = "tokio-websocket"))]
use std::io::BufReader;
#[cfg(any(feature = "tokio-tls", feature = "tokio-websocket"))]
use std::sync::Arc;

#[cfg(any(feature = "tokio-tls", feature = "tokio-websocket"))]
use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
#[cfg(any(feature = "tokio-tls", feature = "tokio-websocket"))]
use rustls::{ClientConfig, RootCertStore, ServerConfig};
#[cfg(any(feature = "tokio-tls", feature = "tokio-websocket"))]
use tokio_rustls::{TlsAcceptor, TlsConnector};
use url::Url;

/// Converts any displayable error into a [`std::io::Error`].
///
/// Only the `Display` rendering of `err` is carried over as the message of
/// the new I/O error; the original error value is not retained as a source.
pub(crate) fn io_other<E>(err: E) -> std::io::Error
where
    E: std::fmt::Display,
{
    let message = err.to_string();
    std::io::Error::other(message)
}

#[cfg(any(feature = "tokio-tls", feature = "tokio-websocket"))]
/// Formats `host` and `port` as a `host:port` socket address string.
///
/// A bare IPv6 literal (recognised by containing `:`) is wrapped in square
/// brackets so the port separator stays unambiguous. A host that is already
/// bracketed, e.g. the form `Url::host_str` yields for IPv6 hosts, is left
/// untouched rather than being wrapped a second time.
pub(crate) fn socket_bind_addr(host: &str, port: u16) -> String {
    let already_bracketed = host.starts_with('[') && host.ends_with(']');
    // Only raw IPv6 literals need brackets; hostnames and IPv4 never contain ':'.
    if host.contains(':') && !already_bracketed {
        format!("[{host}]:{port}")
    } else {
        format!("{host}:{port}")
    }
}

/// A parsed transport endpoint URL together with its query parameters.
pub(crate) struct TransportUrl {
    /// The parsed endpoint URL.
    url: Url,
    /// Query-string parameters of `url`, stored as owned key/value pairs.
    params: HashMap<String, String>,
}

impl TransportUrl {
    pub(crate) fn parse(endpoint: &str, expected_schemes: &[&str]) -> std::io::Result<Self> {
        let url = Url::parse(endpoint).map_err(io_other)?;
        if !expected_schemes
            .iter()
            .any(|scheme| *scheme == url.scheme())
        {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!("unsupported scheme: {}", url.scheme()),
            ));
        }
        let params = url
            .query_pairs()
            .map(|(k, v)| (k.to_string(), v.to_string()))
            .collect();
        Ok(Self { url, params })
    }

    #[cfg(any(feature = "tokio-tls", feature = "tokio-websocket"))]
    pub(crate) fn scheme(&self) -> &str {
        self.url.scheme()
    }

    #[cfg(feature = "tokio-websocket")]
    pub(crate) fn as_str(&self) -> &str {
        self.url.as_str()
    }

    pub(crate) fn host_port(&self) -> std::io::Result<(String, u16)> {
        let host = self.url.host_str().ok_or_else(|| {
            std::io::Error::new(std::io::ErrorKind::InvalidInput, "endpoint missing host")
        })?;
        let port = self.url.port_or_known_default().ok_or_else(|| {
            std::io::Error::new(std::io::ErrorKind::InvalidInput, "endpoint missing port")
        })?;
        Ok((host.to_string(), port))
    }

    #[cfg(any(feature = "tokio-tls", feature = "tokio-websocket"))]
    pub(crate) fn bind_addr(&self) -> std::io::Result<String> {
        let (host, port) = self.host_port()?;
        Ok(socket_bind_addr(&host, port))
    }

    #[cfg(any(feature = "tokio-tls", feature = "tokio-websocket"))]
    pub(crate) fn required_param(&self, key: &str, env: &str) -> std::io::Result<String> {
        self.params
            .get(key)
            .cloned()
            .or_else(|| std::env::var(env).ok())
            .ok_or_else(|| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    format!("missing tls parameter `{key}` (or env `{env}`)"),
                )
            })
    }

    #[cfg(feature = "tokio-tls")]
    pub(crate) fn optional_param(&self, key: &str) -> Option<&str> {
        self.params.get(key).map(String::as_str)
    }

    #[cfg(feature = "quic-s2n")]
    pub(crate) fn param_or_env_or(&self, key: &str, env: &str, default: &str) -> String {
        self.params
            .get(key)
            .cloned()
            .or_else(|| std::env::var(env).ok())
            .unwrap_or_else(|| default.to_string())
    }
}

/// Reads every PEM-encoded certificate from the file at `path`.
///
/// # Errors
///
/// Returns an error if the file cannot be opened, if any PEM entry fails to
/// parse, or (as `InvalidInput`) if the file contains no certificates.
#[cfg(any(feature = "tokio-tls", feature = "tokio-websocket"))]
fn load_certs(path: &str) -> std::io::Result<Vec<CertificateDer<'static>>> {
    let mut reader = BufReader::new(File::open(path)?);

    // Stop at the first entry that fails to parse.
    let mut certs = Vec::new();
    for entry in rustls_pemfile::certs(&mut reader) {
        certs.push(entry.map_err(io_other)?);
    }

    if certs.is_empty() {
        Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("no certificates found in {path}"),
        ))
    } else {
        Ok(certs)
    }
}

/// Reads the first private key found in the PEM file at `path`.
///
/// # Errors
///
/// Returns an error if the file cannot be opened, if PEM parsing fails, or
/// (as `InvalidInput`) if the file contains no private key.
#[cfg(any(feature = "tokio-tls", feature = "tokio-websocket"))]
fn load_private_key(path: &str) -> std::io::Result<PrivateKeyDer<'static>> {
    let mut reader = BufReader::new(File::open(path)?);

    rustls_pemfile::private_key(&mut reader)
        .map_err(io_other)?
        .ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!("no private key found in {path}"),
            )
        })
}

/// Builds a TLS acceptor from a PEM certificate chain and private key.
///
/// The server configuration is created without client authentication and
/// with the single certificate/key pair loaded from the given paths.
///
/// # Errors
///
/// Returns an error if either file cannot be loaded or if the TLS library
/// rejects the certificate/key pair.
#[cfg(any(feature = "tokio-tls", feature = "tokio-websocket"))]
pub(crate) fn build_server_acceptor(
    cert_path: &str,
    key_path: &str,
) -> std::io::Result<TlsAcceptor> {
    let cert_chain = load_certs(cert_path)?;
    let private_key = load_private_key(key_path)?;

    let builder = ServerConfig::builder().with_no_client_auth();
    let server_config = builder
        .with_single_cert(cert_chain, private_key)
        .map_err(io_other)?;

    let acceptor = TlsAcceptor::from(Arc::new(server_config));
    Ok(acceptor)
}

/// Builds a shared TLS client configuration whose root store holds only the
/// certificates found in the PEM file at `ca_path`.
///
/// The configuration is created without client authentication.
///
/// # Errors
///
/// Returns an error if the CA file cannot be loaded or if any certificate is
/// rejected when added to the root store.
#[cfg(any(feature = "tokio-tls", feature = "tokio-websocket"))]
pub(crate) fn build_client_config(ca_path: &str) -> std::io::Result<Arc<ClientConfig>> {
    let trusted = load_certs(ca_path)?;

    let mut roots = RootCertStore::empty();
    trusted
        .into_iter()
        .try_for_each(|cert| roots.add(cert).map_err(io_other))?;

    let client_config = ClientConfig::builder()
        .with_root_certificates(roots)
        .with_no_client_auth();
    Ok(Arc::new(client_config))
}

/// Builds a TLS connector from the client configuration produced by
/// `build_client_config` for `ca_path`.
///
/// # Errors
///
/// Propagates any error from `build_client_config`.
#[cfg(any(feature = "tokio-tls", feature = "tokio-websocket"))]
pub(crate) fn build_client_connector(ca_path: &str) -> std::io::Result<TlsConnector> {
    build_client_config(ca_path).map(TlsConnector::from)
}

/// Resolves the TLS server name to present for a connection.
///
/// `override_name`, when given, takes precedence over `host`. A bracketed
/// IPv6 literal such as `[::1]` (the form `Url::host_str` yields for IPv6
/// hosts) has its brackets removed first, because the bracketed text is
/// neither a valid DNS name nor a parseable IP address. Brackets are only
/// removed when the inner text parses as an IPv6 address; any other input is
/// passed through unchanged.
///
/// # Errors
///
/// Returns an error if the resulting name is not accepted by
/// `ServerName::try_from`.
#[cfg(any(feature = "tokio-tls", feature = "tokio-websocket"))]
pub(crate) fn server_name(
    host: &str,
    override_name: Option<&str>,
) -> std::io::Result<ServerName<'static>> {
    let requested = override_name.unwrap_or(host);

    // Unwrap `[addr]` only when `addr` really is an IPv6 literal, so that
    // malformed bracketed input is still rejected as before.
    let bare_ipv6 = requested
        .strip_prefix('[')
        .and_then(|rest| rest.strip_suffix(']'))
        .filter(|inner| inner.parse::<std::net::Ipv6Addr>().is_ok());

    let name = bare_ipv6.unwrap_or(requested).to_string();
    ServerName::try_from(name).map_err(io_other)
}