xidl-jsonrpc 0.9.0

An IDL codegen.
Documentation
use std::collections::HashMap;
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::sync::{Arc, LazyLock};

use dashmap::DashMap;
use s2n_quic::client::Connect;
use tokio::sync::{Mutex, mpsc};

use super::{Listener, Stream};

/// Boxed, type-erased stream: the form in which this module hands streams to
/// callers (`QuicListener` queues them, `connect_quic` returns them).
type DynStream = Box<dyn Stream + Unpin + Send + 'static>;

/// Certificate location used when neither the `cert` parameter nor `XIDL_QUIC_CERT` is set.
const DEFAULT_CERT: &str = "cert.pem";
/// Private-key location used when neither the `key` parameter nor `XIDL_QUIC_KEY` is set.
const DEFAULT_KEY: &str = "key.pem";
/// CA certificate location used when neither the `ca` parameter nor `XIDL_QUIC_CA` is set.
const DEFAULT_CA: &str = "cert.pem";
/// Server name used when neither the `server_name` parameter nor
/// `XIDL_QUIC_SERVER_NAME` is set.
const DEFAULT_SERVER_NAME: &str = "localhost";

/// A parsed endpoint string of the form `[quic://]<address>[?key=value&...]`.
struct QuicEndpoint {
    /// Everything between the optional `quic://` prefix and the first `?`.
    addr: String,
    /// Query parameters; when a key repeats, the last occurrence wins.
    params: HashMap<String, String>,
}

impl QuicEndpoint {
    /// Splits `endpoint` into its address and query parameters.
    ///
    /// The `quic://` prefix is optional. Empty query segments and segments with
    /// an empty key are ignored; a key without `=` maps to an empty value.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` when the address part is empty.
    fn parse(endpoint: &str) -> std::io::Result<Self> {
        let stripped = endpoint.strip_prefix("quic://").unwrap_or(endpoint);
        let (address, query) = match stripped.split_once('?') {
            Some((address, query)) => (address, Some(query)),
            None => (stripped, None),
        };
        if address.is_empty() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "missing quic address",
            ));
        }

        let params: HashMap<String, String> = query
            .into_iter()
            .flat_map(|query| query.split('&'))
            .filter_map(|segment| {
                // A segment without `=` is a bare key with an empty value.
                let (name, value) = segment.split_once('=').unwrap_or((segment, ""));
                (!name.is_empty()).then(|| (name.to_string(), value.to_string()))
            })
            .collect();

        Ok(Self {
            addr: address.to_string(),
            params,
        })
    }

    /// Looks a setting up in order of precedence: query parameter, then
    /// environment variable, then the built-in default.
    fn setting(&self, param: &str, env_var: &str, fallback: &str) -> String {
        if let Some(value) = self.params.get(param) {
            return value.clone();
        }
        match std::env::var(env_var) {
            Ok(value) => value,
            Err(_) => fallback.to_string(),
        }
    }

    /// Certificate location: `cert` parameter, `XIDL_QUIC_CERT`, or `DEFAULT_CERT`.
    fn cert_path(&self) -> String {
        self.setting("cert", "XIDL_QUIC_CERT", DEFAULT_CERT)
    }

    /// Private-key location: `key` parameter, `XIDL_QUIC_KEY`, or `DEFAULT_KEY`.
    fn key_path(&self) -> String {
        self.setting("key", "XIDL_QUIC_KEY", DEFAULT_KEY)
    }

    /// CA certificate location: `ca` parameter, `XIDL_QUIC_CA`, or `DEFAULT_CA`.
    fn ca_path(&self) -> String {
        self.setting("ca", "XIDL_QUIC_CA", DEFAULT_CA)
    }

    /// Server name: `server_name` parameter, `XIDL_QUIC_SERVER_NAME`, or
    /// `DEFAULT_SERVER_NAME`.
    fn server_name(&self) -> String {
        self.setting("server_name", "XIDL_QUIC_SERVER_NAME", DEFAULT_SERVER_NAME)
    }
}

/// Converts any displayable error into an `std::io::Error` of kind `Other`,
/// keeping only the rendered message.
fn io_other<E>(err: E) -> std::io::Error
where
    E: std::fmt::Display,
{
    let message = err.to_string();
    std::io::Error::other(message)
}

/// Listener that yields QUIC bidirectional streams delivered through an
/// internal channel by a background task.
pub struct QuicListener {
    /// Receiving half of the channel; each item is a boxed stream plus the
    /// peer address. Wrapped in an async mutex so `accept` can take `&self`.
    rx: Mutex<mpsc::UnboundedReceiver<(DynStream, SocketAddr)>>,
    /// Join handle of the background accept task spawned in `bind`.
    _accept_task: tokio::task::JoinHandle<()>,
}

impl QuicListener {
    /// Starts a QUIC server on the address described by `endpoint` and begins
    /// accepting connections in the background.
    ///
    /// `endpoint` has the form `[quic://]<address>[?cert=<path>&key=<path>]`.
    /// Every bidirectional stream opened by any peer is queued, together with
    /// the peer address, for retrieval through `Listener::accept`.
    ///
    /// # Errors
    ///
    /// Returns an error when the endpoint has no address, or when s2n-quic
    /// rejects the TLS material, the I/O address, or fails to start.
    ///
    /// # Panics
    ///
    /// Calls `tokio::spawn`, which panics outside of a Tokio runtime.
    pub fn bind(endpoint: &str) -> std::io::Result<Self> {
        let cfg = QuicEndpoint::parse(endpoint)?;
        let cert = cfg.cert_path();
        let key = cfg.key_path();
        // Hand s2n-quic `Path`s so the PEM files are loaded from disk; a bare
        // `&str` is interpreted as in-memory PEM content, not as a file name.
        let tls = (
            std::path::Path::new(&cert),
            std::path::Path::new(&key),
        );
        let mut server = s2n_quic::Server::builder()
            .with_tls(tls)
            .map_err(io_other)?
            .with_io(cfg.addr.as_str())
            .map_err(io_other)?
            .start()
            .map_err(io_other)?;

        let (tx, rx) = mpsc::unbounded_channel::<(DynStream, SocketAddr)>();
        let task = tokio::spawn(async move {
            while let Some(mut connection) = server.accept().await {
                let tx = tx.clone();
                // One task per connection so a slow peer cannot stall the
                // accept loop.
                tokio::spawn(async move {
                    // Fall back to a placeholder address if the peer address
                    // cannot be determined.
                    let peer = connection
                        .remote_addr()
                        .unwrap_or_else(|_| SocketAddr::from(([127, 0, 0, 1], 0)));
                    while let Ok(Some(stream)) = connection.accept_bidirectional_stream().await {
                        // A send error means the receiver is gone; stop
                        // serving this connection.
                        if tx.send((Box::new(stream), peer)).is_err() {
                            break;
                        }
                    }
                });
            }
        });
        Ok(Self {
            rx: Mutex::new(rx),
            _accept_task: task,
        })
    }
}

impl Drop for QuicListener {
    /// Stops the background accept loop.
    ///
    /// Dropping a `JoinHandle` only detaches its task, so without this the
    /// accept loop (and the `Server` it owns) would keep running after the
    /// listener is gone.
    fn drop(&mut self) {
        self._accept_task.abort();
    }
}

#[async_trait::async_trait]
impl Listener for QuicListener {
    /// Waits for the next queued stream and returns it with the peer address.
    ///
    /// # Errors
    ///
    /// Returns `BrokenPipe` once the channel is closed, i.e. when no sender
    /// remains and nothing is left in the queue.
    async fn accept(&self) -> std::io::Result<(DynStream, SocketAddr)> {
        let mut receiver = self.rx.lock().await;
        match receiver.recv().await {
            Some(accepted) => Ok(accepted),
            None => Err(std::io::Error::new(
                std::io::ErrorKind::BrokenPipe,
                "quic listener closed",
            )),
        }
    }
}

/// A client endpoint bundled with the connection that was opened through it.
struct QuicClientConnection {
    /// Stored next to the connection and never read directly.
    // NOTE(review): presumably retained so the client endpoint stays alive as
    // long as the connection does; confirm against s2n-quic's drop semantics.
    _client: s2n_quic::Client,
    /// The connection, behind an async mutex that is locked while a new
    /// stream is opened.
    connection: Mutex<s2n_quic::connection::Connection>,
}

/// Maps the endpoint string passed to `connect_quic` to its shared connection.
type ConnectionCache = DashMap<String, Arc<QuicClientConnection>>;
/// Process-wide connection cache, initialised lazily on first access.
static CONNECTIONS: LazyLock<ConnectionCache> = LazyLock::new(DashMap::new);

/// Opens a bidirectional QUIC stream to `endpoint`, reusing a cached
/// connection when one exists.
///
/// `endpoint` has the form `[quic://]<ip:port>[?ca=<path>&server_name=<name>]`
/// and the full string is the cache key. If opening a stream on the cached
/// connection fails, that connection is evicted and a fresh one is
/// established.
///
/// # Errors
///
/// Returns an error when the endpoint cannot be parsed, its address is not a
/// valid socket address, the client cannot be configured or started, the
/// connection attempt fails, or the stream cannot be opened on a new
/// connection.
pub async fn connect_quic(endpoint: &str) -> std::io::Result<DynStream> {
    let key = endpoint.to_string();

    // Clone the `Arc` out of the map so the DashMap shard guard is released
    // before any `.await`; holding it across a suspension point can block
    // other tasks that touch the same shard.
    let cached = CONNECTIONS.get(&key).map(|entry| Arc::clone(entry.value()));
    if let Some(shared) = cached {
        let reopened = {
            let mut connection = shared.connection.lock().await;
            connection.open_bidirectional_stream().await
        };
        match reopened {
            Ok(stream) => return Ok(Box::new(stream)),
            Err(_) => {
                // The cached connection is unusable. Evict it, but only if the
                // slot still holds this exact connection, then reconnect below
                // instead of failing on every later call.
                CONNECTIONS.remove_if(&key, |_, current| Arc::ptr_eq(current, &shared));
            }
        }
    }

    let cfg = QuicEndpoint::parse(endpoint)?;
    let ca = cfg.ca_path();
    let client = s2n_quic::Client::builder()
        // Pass a `Path` so the CA file is read from disk; a bare `&str` is
        // interpreted as in-memory PEM content, not as a file name.
        .with_tls(std::path::Path::new(&ca))
        .map_err(io_other)?
        // Bind to an ephemeral local port on all IPv4 interfaces.
        .with_io(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
        .map_err(io_other)?
        .start()
        .map_err(io_other)?;
    let connect = Connect::new(cfg.addr.parse::<SocketAddr>().map_err(io_other)?)
        .with_server_name(cfg.server_name());
    let mut connection = client.connect(connect).await.map_err(io_other)?;

    // Open the first stream before publishing the connection, so a connection
    // that cannot open streams is never cached.
    let stream = connection
        .open_bidirectional_stream()
        .await
        .map_err(io_other)?;

    CONNECTIONS.insert(
        key,
        Arc::new(QuicClientConnection {
            _client: client,
            connection: Mutex::new(connection),
        }),
    );
    Ok(Box::new(stream))
}