knafeh 1.0.0

QUIC-based RPC library with Python bindings
Documentation
use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;

use crate::error::KnafehError;
use crate::rpc::message::{RpcRequest, RpcResponse};

/// A hook that observes and may rewrite RPC traffic around a service handler.
///
/// When several interceptors are registered, their `on_request` hooks fire in
/// registration order before the handler runs. Their `on_response` hooks fire
/// afterwards in the opposite order, so the last-registered interceptor sees
/// the response first.
///
/// Both hooks have no-op default implementations, so an implementor only needs
/// to override the side it cares about.
#[async_trait]
pub trait Interceptor: Send + Sync + 'static {
    /// Invoked ahead of dispatching `request` to the service handler.
    ///
    /// The request may be mutated in place; returning `Err` short-circuits the
    /// call. The default implementation leaves the request untouched and
    /// returns `Ok(())`.
    async fn on_request(&self, request: &mut RpcRequest) -> Result<(), KnafehError> {
        _ = request;
        Ok(())
    }

    /// Invoked once the service handler has produced `response`.
    ///
    /// The response may be mutated in place. The default implementation leaves
    /// the response untouched and returns `Ok(())`.
    async fn on_response(&self, response: &mut RpcResponse) -> Result<(), KnafehError> {
        _ = response;
        Ok(())
    }
}

/// An ordered stack of interceptors applied to every RPC call.
///
/// Cloning a stack is cheap: it copies the list of [`Arc`] handles, not the
/// interceptors themselves.
#[derive(Clone, Default)]
pub struct MiddlewareStack {
    interceptors: Vec<Arc<dyn Interceptor>>,
}

impl MiddlewareStack {
    /// Creates an empty stack.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `interceptor` to the end of the stack.
    ///
    /// Its `on_request` hook runs after those of every previously added
    /// interceptor. Its `on_response` hook runs before theirs.
    pub fn add(&mut self, interceptor: Arc<dyn Interceptor>) {
        self.interceptors.push(interceptor);
    }

    /// Runs every interceptor's `on_request` hook in insertion order.
    ///
    /// An empty stack is a no-op that returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by an interceptor. Interceptors after
    /// the failing one are not invoked.
    pub async fn apply_request(&self, request: &mut RpcRequest) -> Result<(), KnafehError> {
        for interceptor in &self.interceptors {
            interceptor.on_request(request).await?;
        }
        Ok(())
    }

    /// Runs every interceptor's `on_response` hook in reverse insertion order.
    ///
    /// An empty stack is a no-op that returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by an interceptor. Interceptors after
    /// the failing one (in reverse order) are not invoked.
    pub async fn apply_response(&self, response: &mut RpcResponse) -> Result<(), KnafehError> {
        for interceptor in self.interceptors.iter().rev() {
            interceptor.on_response(response).await?;
        }
        Ok(())
    }

    /// Returns the number of registered interceptors.
    #[must_use]
    pub fn len(&self) -> usize {
        self.interceptors.len()
    }

    /// Returns `true` when no interceptors are registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.interceptors.is_empty()
    }
}

// `dyn Interceptor` is not `Debug`, so report only how many interceptors are
// registered.
impl std::fmt::Debug for MiddlewareStack {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MiddlewareStack")
            .field("len", &self.interceptors.len())
            .finish()
    }
}

// ---------------------------------------------------------------------------
// Built-in interceptors
// ---------------------------------------------------------------------------

/// Interceptor that logs RPC calls via the `tracing` crate.
///
/// It is stateless. Construct it directly as `LoggingInterceptor` or through
/// [`Default`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct LoggingInterceptor;

#[async_trait]
impl Interceptor for LoggingInterceptor {
    /// Emits an `info`-level "RPC request" event with the method name
    /// (Display-formatted) and the request body length. Always returns `Ok(())`.
    async fn on_request(&self, request: &mut RpcRequest) -> Result<(), KnafehError> {
        let body_len = request.body.len();
        tracing::info!(method = %request.method, body_len, "RPC request");
        Ok(())
    }

    /// Emits an `info`-level "RPC response" event with the status code
    /// (Debug-formatted) and the response body length. Always returns `Ok(())`.
    async fn on_response(&self, response: &mut RpcResponse) -> Result<(), KnafehError> {
        let status = &response.status.code;
        let body_len = response.body.len();
        tracing::info!(?status, body_len, "RPC response");
        Ok(())
    }
}

/// Interceptor that attaches a per-call timeout to each request.
///
/// It does not enforce the timeout itself. It records the timeout in the
/// request metadata under `x-rpc-timeout-ms`, and the handler or transport
/// that reads that key is expected to enforce it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TimeoutInterceptor {
    /// Timeout advertised for every request passing through this interceptor.
    pub duration: Duration,
}

impl TimeoutInterceptor {
    /// Creates an interceptor that advertises `duration` as the call timeout.
    #[must_use]
    pub const fn new(duration: Duration) -> Self {
        Self { duration }
    }
}

#[async_trait]
impl Interceptor for TimeoutInterceptor {
    async fn on_request(&self, request: &mut RpcRequest) -> Result<(), KnafehError> {
        // Store the deadline in metadata so the handler/transport can enforce it.
        request.metadata.insert(
            "x-rpc-timeout-ms".to_string(),
            self.duration.as_millis().to_string(),
        );
        Ok(())
    }
}