knafeh 1.0.0

QUIC-based RPC library with Python bindings
Documentation
//! Raw QUIC transport using quinn — bypasses HTTP/3 entirely.
//!
//! Each RPC call opens a bidirectional QUIC stream, writes the request
//! frame, and reads the response frame. No HTTP headers, no pseudo-headers,
//! no content-type negotiation.
//!
//! Wire format: see [`super::quic_wire`].

use std::net::SocketAddr;
use std::sync::Arc;

use quinn::{ClientConfig, Endpoint, RecvStream, SendStream, ServerConfig};

use crate::codec::Codec;
use crate::error::{KnafehError, RpcStatusCode};
use crate::rpc::message::{Metadata, RpcRequest, RpcResponse, RpcStatus};
use crate::rpc::middleware::MiddlewareStack;
use crate::rpc::router::MethodRouter;
use crate::transport::quic_wire;

// ---------------------------------------------------------------------------
// Client
// ---------------------------------------------------------------------------

/// Client side of the raw-QUIC RPC transport.
///
/// Holds one established QUIC connection. Every [`QuicClient::call`] opens
/// its own new bidirectional stream on that connection, so calls never share
/// stream state.
pub struct QuicClient {
    /// Established connection on which per-call streams are opened.
    connection: quinn::Connection,
    /// Encodes request bodies and decodes the bodies of OK responses.
    codec: Arc<dyn Codec>,
    /// Hooks applied to every outgoing request and incoming response.
    middleware: Arc<MiddlewareStack>,
}

impl QuicClient {
    /// Connect to a server at `addr` without certificate verification.
    ///
    /// A new client [`Endpoint`] is bound to an ephemeral port on the
    /// unspecified address of the same IP family as `addr` (`0.0.0.0` for
    /// IPv4, `[::]` for IPv6), so both IPv4 and IPv6 servers are reachable.
    ///
    /// **Warning:** Skips TLS verification — use only for testing or
    /// trusted networks. For production, use a proper certificate verifier.
    ///
    /// # Errors
    ///
    /// Returns [`KnafehError::Tls`] if the rustls config cannot be turned
    /// into a QUIC client config. Returns [`KnafehError::Transport`] if the
    /// local endpoint cannot be bound or the connection/handshake fails.
    pub async fn connect_insecure(
        addr: SocketAddr,
        codec: Arc<dyn Codec>,
        middleware: Arc<MiddlewareStack>,
    ) -> Result<Self, KnafehError> {
        let crypto = rustls::ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(Arc::new(SkipVerification))
            .with_no_client_auth();

        let client_cfg = ClientConfig::new(Arc::new(
            quinn::crypto::rustls::QuicClientConfig::try_from(crypto)
                .map_err(|e| KnafehError::Tls(e.to_string()))?,
        ));

        // Bind on the target's address family: a socket bound to 0.0.0.0
        // cannot send datagrams to an IPv6 destination.
        let bind_addr = if addr.is_ipv6() {
            SocketAddr::from((std::net::Ipv6Addr::UNSPECIFIED, 0))
        } else {
            SocketAddr::from((std::net::Ipv4Addr::UNSPECIFIED, 0))
        };
        let mut endpoint =
            Endpoint::client(bind_addr).map_err(|e| KnafehError::Transport(e.to_string()))?;
        endpoint.set_default_client_config(client_cfg);

        // The server name only matters for SNI and certificate validation.
        // Validation is skipped here, so a fixed placeholder is sufficient.
        let connection = endpoint
            .connect(addr, "localhost")
            .map_err(|e| KnafehError::Transport(e.to_string()))?
            .await
            .map_err(|e| KnafehError::Transport(e.to_string()))?;

        Ok(Self {
            connection,
            codec,
            middleware,
        })
    }

    /// Make a unary RPC call over a fresh bidirectional QUIC stream.
    ///
    /// The call proceeds in these steps:
    /// 1. `body` is passed through the codec's `encode`.
    /// 2. Request middleware runs, then the frame is written and the send
    ///    side is finished.
    /// 3. The whole response frame is read, bounded by
    ///    [`quic_wire::MAX_MESSAGE_SIZE`].
    /// 4. Response middleware runs.
    /// 5. Only if the status is OK, the body is codec-decoded.
    ///
    /// Non-OK responses are still returned as `Ok(response)`, with their
    /// body left exactly as received.
    ///
    /// # Errors
    ///
    /// - Codec and middleware errors are propagated unchanged.
    /// - Failures to open, write, finish or read the stream become
    ///   [`KnafehError::Transport`].
    /// - A malformed response frame yields the error reported by
    ///   `quic_wire::decode_response`.
    pub async fn call(&self, method: &str, body: Vec<u8>) -> Result<RpcResponse, KnafehError> {
        let encoded_body = self.codec.encode(&body)?;

        let mut request = RpcRequest {
            method: method.to_string(),
            metadata: Metadata::new(),
            body: encoded_body,
        };
        self.middleware.apply_request(&mut request).await?;

        let (mut send, mut recv) = self
            .connection
            .open_bi()
            .await
            .map_err(|e| KnafehError::Transport(format!("open_bi failed: {e}")))?;

        let req_bytes =
            quic_wire::encode_request(&request.method, &request.body, &request.metadata);
        send.write_all(&req_bytes)
            .await
            .map_err(|e| KnafehError::Transport(format!("write failed: {e}")))?;
        // Finishing the send side tells the server the request frame is complete.
        send.finish()
            .map_err(|e| KnafehError::Transport(format!("finish failed: {e}")))?;

        let resp_bytes = recv
            .read_to_end(quic_wire::MAX_MESSAGE_SIZE)
            .await
            .map_err(|e| KnafehError::Transport(format!("read failed: {e}")))?;

        let (status_code, status_message, resp_body, metadata) =
            quic_wire::decode_response(&resp_bytes)?;

        let mut response = RpcResponse {
            status: RpcStatus {
                code: status_code,
                message: status_message,
            },
            metadata,
            body: resp_body,
        };

        self.middleware.apply_response(&mut response).await?;

        // Error bodies are not run through the codec; only OK payloads are decoded.
        if response.status.is_ok() {
            response.body = self.codec.decode(&response.body)?;
        }

        Ok(response)
    }
}

// ---------------------------------------------------------------------------
// Server
// ---------------------------------------------------------------------------

/// Server side of the raw-QUIC RPC transport.
///
/// Owns the bound QUIC [`Endpoint`] plus the shared state that every stream
/// handler needs. The shared pieces are reference-counted, so they can be
/// cloned cheaply into per-connection and per-stream tasks.
pub struct QuicServer {
    /// Bound endpoint from which incoming connections are accepted.
    endpoint: Endpoint,
    /// Dispatches decoded requests to their method handlers.
    router: Arc<MethodRouter>,
    /// Decodes request bodies and encodes response bodies.
    codec: Arc<dyn Codec>,
    /// Hooks applied to every incoming request and outgoing response.
    middleware: Arc<MiddlewareStack>,
}

impl QuicServer {
    /// Build and bind a QUIC server with the given TLS cert/key.
    pub fn bind(
        addr: SocketAddr,
        cert_pem: &str,
        key_pem: &str,
        router: MethodRouter,
        codec: Arc<dyn Codec>,
        middleware: MiddlewareStack,
    ) -> Result<Self, KnafehError> {
        let certs = rustls_pemfile::certs(&mut cert_pem.as_bytes())
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| KnafehError::Tls(format!("cert parse error: {e}")))?;

        let key = rustls_pemfile::private_key(&mut key_pem.as_bytes())
            .map_err(|e| KnafehError::Tls(format!("key parse error: {e}")))?
            .ok_or_else(|| KnafehError::Tls("no private key found".into()))?;

        let server_crypto = rustls::ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(certs, key)
            .map_err(|e| KnafehError::Tls(e.to_string()))?;

        let server_config = ServerConfig::with_crypto(Arc::new(
            quinn::crypto::rustls::QuicServerConfig::try_from(server_crypto)
                .map_err(|e| KnafehError::Tls(e.to_string()))?,
        ));

        let endpoint = Endpoint::server(server_config, addr)
            .map_err(|e| KnafehError::Transport(e.to_string()))?;

        Ok(Self {
            endpoint,
            router: Arc::new(router),
            codec,
            middleware: Arc::new(middleware),
        })
    }

    /// The local address the server is bound to.
    pub fn local_addr(&self) -> Result<SocketAddr, KnafehError> {
        self.endpoint
            .local_addr()
            .map_err(|e| KnafehError::Transport(e.to_string()))
    }

    /// Accept connections and handle RPC calls.
    pub async fn serve(&self) -> Result<(), KnafehError> {
        while let Some(incoming) = self.endpoint.accept().await {
            let conn = incoming
                .await
                .map_err(|e| KnafehError::Transport(format!("accept failed: {e}")))?;

            let router = Arc::clone(&self.router);
            let codec = Arc::clone(&self.codec);
            let middleware = Arc::clone(&self.middleware);

            tokio::spawn(async move {
                loop {
                    let stream = conn.accept_bi().await;
                    match stream {
                        Ok((send, recv)) => {
                            let router = Arc::clone(&router);
                            let codec = Arc::clone(&codec);
                            let middleware = Arc::clone(&middleware);
                            tokio::spawn(async move {
                                if let Err(e) =
                                    handle_stream(send, recv, &router, codec.as_ref(), &middleware)
                                        .await
                                {
                                    tracing::warn!(error = %e, "stream handler error");
                                }
                            });
                        }
                        Err(quinn::ConnectionError::ApplicationClosed(_)) => break,
                        Err(e) => {
                            tracing::warn!(error = %e, "accept_bi error");
                            break;
                        }
                    }
                }
            });
        }
        Ok(())
    }
}

async fn handle_stream(
    mut send: SendStream,
    mut recv: RecvStream,
    router: &MethodRouter,
    codec: &dyn Codec,
    middleware: &MiddlewareStack,
) -> Result<(), KnafehError> {
    let req_bytes = recv
        .read_to_end(super::quic_wire::MAX_MESSAGE_SIZE)
        .await
        .map_err(|e| KnafehError::Transport(format!("read failed: {e}")))?;

    let (method, raw_body, metadata) = quic_wire::decode_request(&req_bytes)?;

    let body = codec.decode(&raw_body)?;

    let mut request = RpcRequest {
        method,
        metadata,
        body,
    };
    middleware.apply_request(&mut request).await?;

    let mut response = match router.route_unary(request).await {
        Ok(resp) => resp,
        Err(KnafehError::Service { code, message }) => RpcResponse::error(code, message),
        Err(e) => RpcResponse::error(RpcStatusCode::Internal, e.to_string()),
    };

    middleware.apply_response(&mut response).await?;

    let encoded_body = codec.encode(&response.body)?;

    let resp_bytes = quic_wire::encode_response(
        response.status.code,
        &response.status.message,
        &encoded_body,
        &response.metadata,
    );
    send.write_all(&resp_bytes)
        .await
        .map_err(|e| KnafehError::Transport(format!("write failed: {e}")))?;
    send.finish()
        .map_err(|e| KnafehError::Transport(format!("finish failed: {e}")))?;

    Ok(())
}

// ---------------------------------------------------------------------------
// TLS skip verification (for testing)
// ---------------------------------------------------------------------------

/// Certificate verifier that accepts every server certificate and every
/// handshake signature without checking them.
///
/// Used by [`QuicClient::connect_insecure`]. Never use it where the server's
/// identity matters.
#[derive(Debug)]
struct SkipVerification;

impl rustls::client::danger::ServerCertVerifier for SkipVerification {
    fn verify_server_cert(
        &self,
        _: &rustls::pki_types::CertificateDer,
        _: &[rustls::pki_types::CertificateDer],
        _: &rustls::pki_types::ServerName,
        _: &[u8],
        _: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }

    fn verify_tls12_signature(
        &self,
        _: &[u8],
        _: &rustls::pki_types::CertificateDer,
        _: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    fn verify_tls13_signature(
        &self,
        _: &[u8],
        _: &rustls::pki_types::CertificateDer,
        _: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        vec![
            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
            rustls::SignatureScheme::RSA_PSS_SHA256,
            rustls::SignatureScheme::RSA_PSS_SHA384,
            rustls::SignatureScheme::RSA_PSS_SHA512,
            rustls::SignatureScheme::ED25519,
        ]
    }
}