// knafeh 1.0.0
//
// QUIC-based RPC library with Python bindings
// Documentation
use std::collections::HashMap;
use std::sync::Arc;

use crate::error::{KnafehError, RpcStatusCode};
use crate::rpc::message::{RpcRequest, RpcResponse};
use crate::rpc::service::{MethodDescriptor, MethodKind, Service};
use crate::rpc::stream::{RpcStreamRequest, RpcStreamResponse};

/// Routes incoming RPC requests to the appropriate service and method.
///
/// The router keeps two indexes, both filled by `add_service`: one keyed by
/// service name and one keyed by the full `"service/method"` path. The
/// `route_*` methods consult only the path index.
pub struct MethodRouter {
    /// Map from service name → service implementation.
    services: HashMap<String, Arc<dyn Service>>,
    /// Map from `"service/method"` → (owning service, method descriptor).
    /// Storing the service next to the descriptor lets a call be dispatched
    /// with a single lookup on the request path.
    methods: HashMap<String, (Arc<dyn Service>, MethodDescriptor)>,
}

impl MethodRouter {
    pub fn new() -> Self {
        Self {
            services: HashMap::new(),
            methods: HashMap::new(),
        }
    }

    /// Register a service with the router.
    pub fn add_service(&mut self, service: Arc<dyn Service>) {
        let service_name = service.name().to_string();
        for method in service.methods() {
            let key = format!("{}/{}", service_name, method.name);
            self.methods.insert(key, (Arc::clone(&service), method));
        }
        self.services.insert(service_name, service);
    }

    /// Look up the method kind for a given path.
    pub fn lookup(&self, path: &str) -> Option<(&dyn Service, &MethodDescriptor)> {
        self.methods
            .get(path)
            .map(|(svc, desc)| (svc.as_ref(), desc))
    }

    /// Route and execute a unary RPC call.
    pub async fn route_unary(&self, request: RpcRequest) -> Result<RpcResponse, KnafehError> {
        let method_path = &request.method;
        let (service, descriptor) =
            self.methods
                .get(method_path)
                .ok_or_else(|| KnafehError::Service {
                    code: RpcStatusCode::NotFound,
                    message: format!("method not found: {method_path}"),
                })?;

        if descriptor.kind != MethodKind::Unary {
            return Err(KnafehError::Service {
                code: RpcStatusCode::Unimplemented,
                message: format!("method {method_path} is {:?}, not Unary", descriptor.kind),
            });
        }

        let method_name = &descriptor.name;
        service.call_unary(method_name, request).await
    }

    /// Route and execute a server-streaming RPC call.
    pub async fn route_server_stream(
        &self,
        request: RpcRequest,
    ) -> Result<RpcStreamResponse, KnafehError> {
        let method_path = &request.method;
        let (service, descriptor) =
            self.methods
                .get(method_path)
                .ok_or_else(|| KnafehError::Service {
                    code: RpcStatusCode::NotFound,
                    message: format!("method not found: {method_path}"),
                })?;

        if descriptor.kind != MethodKind::ServerStreaming {
            return Err(KnafehError::Service {
                code: RpcStatusCode::Unimplemented,
                message: format!(
                    "method {method_path} is {:?}, not ServerStreaming",
                    descriptor.kind
                ),
            });
        }

        let method_name = &descriptor.name;
        service.call_server_stream(method_name, request).await
    }

    /// Route and execute a client-streaming RPC call.
    pub async fn route_client_stream(
        &self,
        method_path: &str,
        stream: RpcStreamRequest,
    ) -> Result<RpcResponse, KnafehError> {
        let (service, descriptor) =
            self.methods
                .get(method_path)
                .ok_or_else(|| KnafehError::Service {
                    code: RpcStatusCode::NotFound,
                    message: format!("method not found: {method_path}"),
                })?;

        if descriptor.kind != MethodKind::ClientStreaming {
            return Err(KnafehError::Service {
                code: RpcStatusCode::Unimplemented,
                message: format!(
                    "method {method_path} is {:?}, not ClientStreaming",
                    descriptor.kind
                ),
            });
        }

        let method_name = &descriptor.name;
        service.call_client_stream(method_name, stream).await
    }

    /// Route and execute a bidirectional-streaming RPC call.
    pub async fn route_bidi_stream(
        &self,
        method_path: &str,
        stream: RpcStreamRequest,
    ) -> Result<RpcStreamResponse, KnafehError> {
        let (service, descriptor) =
            self.methods
                .get(method_path)
                .ok_or_else(|| KnafehError::Service {
                    code: RpcStatusCode::NotFound,
                    message: format!("method not found: {method_path}"),
                })?;

        if descriptor.kind != MethodKind::BidiStreaming {
            return Err(KnafehError::Service {
                code: RpcStatusCode::Unimplemented,
                message: format!(
                    "method {method_path} is {:?}, not BidiStreaming",
                    descriptor.kind
                ),
            });
        }

        let method_name = &descriptor.name;
        service.call_bidi_stream(method_name, stream).await
    }

    /// List all registered service names.
    pub fn service_names(&self) -> Vec<&str> {
        self.services.keys().map(|s| s.as_str()).collect()
    }

    /// List all registered method paths.
    pub fn method_paths(&self) -> Vec<&str> {
        self.methods.keys().map(|s| s.as_str()).collect()
    }
}

impl Default for MethodRouter {
    fn default() -> Self {
        Self::new()
    }
}