// knafeh 1.1.0
//
// QUIC-based RPC library with Python bindings.
//! Integration tests for the Knafeh RPC framework.

use std::collections::HashMap;

use knafeh::codec::JsonCodec;
use knafeh::error::RpcStatusCode;
use knafeh::rpc::message::{RpcRequest, RpcResponse};
use knafeh::rpc::middleware::{LoggingInterceptor, MiddlewareStack};
use knafeh::rpc::router::MethodRouter;
use knafeh::rpc::service::{MethodDescriptor, MethodKind, Service};
use knafeh::rpc::stream::{RpcStreamRequest, RpcStreamResponse};
use knafeh::transport::connection::{
    extract_method_path, headers_to_metadata, metadata_to_headers, validate_metadata_key,
};
use knafeh::transport::tls::TlsConfig;
use knafeh::Server;

use async_trait::async_trait;

// ---------------------------------------------------------------------------
// Test service
// ---------------------------------------------------------------------------

/// Minimal test service exposing a single unary method (`echo` on service `echo`) that
/// answers with the request body unchanged.
///
/// The streaming entry points are intentionally stubbed with `unimplemented!()` and panic if
/// invoked; the tests in this file only exercise the unary path.
struct EchoService;

impl EchoService {
    /// Name this service reports to the router.
    const SERVICE_NAME: &'static str = "echo";
    /// Name of the only method this service advertises.
    const METHOD_NAME: &'static str = "echo";
}

#[async_trait]
impl Service for EchoService {
    fn name(&self) -> &str {
        Self::SERVICE_NAME
    }

    fn methods(&self) -> Vec<MethodDescriptor> {
        let echo = MethodDescriptor {
            name: Self::METHOD_NAME.to_owned(),
            kind: MethodKind::Unary,
        };
        vec![echo]
    }

    async fn call_unary(
        &self,
        _method: &str,
        request: RpcRequest,
    ) -> Result<RpcResponse, knafeh::KnafehError> {
        // Only one unary method is advertised, so the method name is not inspected.
        let payload = request.body;
        Ok(RpcResponse::ok(payload))
    }

    async fn call_server_stream(
        &self,
        _method: &str,
        _request: RpcRequest,
    ) -> Result<RpcStreamResponse, knafeh::KnafehError> {
        unimplemented!()
    }

    async fn call_client_stream(
        &self,
        _method: &str,
        _stream: RpcStreamRequest,
    ) -> Result<RpcResponse, knafeh::KnafehError> {
        unimplemented!()
    }

    async fn call_bidi_stream(
        &self,
        _method: &str,
        _stream: RpcStreamRequest,
    ) -> Result<RpcStreamResponse, knafeh::KnafehError> {
        unimplemented!()
    }
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

/// Encoding then decoding a valid JSON payload with `JsonCodec` yields the original bytes.
#[test]
fn test_codec_json_roundtrip() {
    use knafeh::codec::Codec;

    let payload = br#"{"hello":"world"}"#;
    let codec = JsonCodec::new();

    let wire = codec.encode(payload).unwrap();
    let restored = codec.decode(&wire).unwrap();

    assert_eq!(restored, payload);
}

/// `JsonCodec` refuses bytes that are not JSON in both the encode and decode directions.
#[test]
fn test_codec_json_invalid() {
    use knafeh::codec::Codec;

    let garbage = b"not json";
    let codec = JsonCodec::new();

    let encode_outcome = codec.encode(garbage);
    assert!(encode_outcome.is_err());

    let decode_outcome = codec.decode(garbage);
    assert!(decode_outcome.is_err());
}

/// A unary call to `echo/echo` is routed to `EchoService`, which responds `Ok` with the
/// request body unchanged.
#[tokio::test]
async fn test_router_unary() {
    let service = std::sync::Arc::new(EchoService);
    let mut router = MethodRouter::new();
    router.add_service(service);

    let call = RpcRequest::new("echo/echo", b"hello".to_vec());
    let reply = router.route_unary(call).await.unwrap();

    assert_eq!(reply.body, b"hello");
    assert_eq!(reply.status.code, RpcStatusCode::Ok);
}

/// A router with no registered services returns an error for any method path.
#[tokio::test]
async fn test_router_not_found() {
    let empty_router = MethodRouter::new();

    let outcome = empty_router
        .route_unary(RpcRequest::new("unknown/method", b"hello".to_vec()))
        .await;

    assert!(outcome.is_err());
}

/// A middleware stack holding only `LoggingInterceptor` processes a request and a response
/// without returning an error.
#[tokio::test]
async fn test_middleware_stack() {
    let logger = std::sync::Arc::new(LoggingInterceptor);
    let mut stack = MiddlewareStack::new();
    stack.add(logger);

    let empty_json = b"{}";

    let mut req = RpcRequest::new("test/method", empty_json.to_vec());
    stack.apply_request(&mut req).await.unwrap();

    let mut resp = RpcResponse::ok(empty_json.to_vec());
    stack.apply_response(&mut resp).await.unwrap();
}

/// Converting user metadata to transport headers and back preserves the user's entries.
#[test]
fn test_metadata_header_roundtrip() {
    let mut metadata = HashMap::new();
    metadata.insert("user-key".to_string(), "value1".to_string());

    let headers = metadata_to_headers(&metadata).unwrap();

    // Rebuild the headers as a Vec of owned (key, value) pairs, the shape that
    // `headers_to_metadata` is handed here.
    let pairs: Vec<_> = headers
        .iter()
        .map(|(k, v)| (k.clone(), v.clone()))
        .collect();
    let recovered = headers_to_metadata(&pairs);

    // Compare `Option`s rather than `unwrap()`ing: if the key is lost in the round trip the
    // failure shows a left/right diff instead of an opaque "unwrap on None" panic.
    assert_eq!(recovered.get("user-key"), Some(&"value1".to_string()));
}

/// `metadata_to_headers` refuses metadata that uses the reserved `x-rpc-status` key.
#[test]
fn test_metadata_header_conversion_rejects_invalid_keys() {
    let metadata = HashMap::from([("x-rpc-status".to_string(), "0".to_string())]);

    let conversion = metadata_to_headers(&metadata);
    assert!(conversion.is_err());
}

/// `validate_metadata_key` rejects pseudo-headers and framework-reserved names, and accepts
/// an ordinary key such as `trace-id`.
#[test]
fn test_metadata_rejects_reserved_headers() {
    const RESERVED: [&str; 6] = [
        ":path",
        "x-rpc-method-kind",
        "x-rpc-status",
        "method-kind",
        "status",
        "status-message",
    ];

    for key in RESERVED {
        assert!(
            validate_metadata_key(key).is_err(),
            "{key:?} should be rejected"
        );
    }

    assert!(validate_metadata_key("trace-id").is_ok());
}

/// `extract_method_path` strips the leading slash from `/service/method`, and rejects a
/// single-segment path such as `/invalid`.
#[test]
fn test_extract_method_path() {
    let parsed = extract_method_path(b"/greeter/say_hello").unwrap();
    assert_eq!(parsed, "greeter/say_hello");

    let rejected = extract_method_path(b"/invalid");
    assert!(rejected.is_err());
}

/// A request built with `RpcRequest::new(...).with_metadata(...)` splits its path into
/// service and method names and carries the added metadata entry.
#[test]
fn test_request_builder() {
    let req = RpcRequest::new("svc/method", b"body".to_vec())
        .with_metadata("key", "value");

    assert_eq!(req.service_name(), Some("svc"));
    assert_eq!(req.method_name(), Some("method"));

    let expected_value = "value".to_string();
    assert_eq!(req.metadata.get("key"), Some(&expected_value));
}

/// `RpcResponse::error` records the given code and message and leaves the body empty.
#[test]
fn test_response_error() {
    let resp = RpcResponse::error(RpcStatusCode::NotFound, "not found");
    let status = &resp.status;

    assert_eq!(status.code, RpcStatusCode::NotFound);
    assert_eq!(status.message, "not found");
    assert!(resp.body.is_empty());
}

/// `RpcStatusCode::from_u8` converts every value in `0..=16` without panicking, and maps the
/// out-of-range value 255 to `Unknown`.
#[test]
fn test_rpc_status_code_roundtrip() {
    // Debug-format each converted code only to prove neither conversion nor formatting panics.
    (0..=16)
        .map(RpcStatusCode::from_u8)
        .for_each(|status| drop(format!("{status:?}")));

    // An unmapped wire value falls back to `Unknown`.
    assert_eq!(RpcStatusCode::from_u8(255), RpcStatusCode::Unknown);
}

/// Building a server with a client-side TLS config fails, with an error stating that server
/// TLS requires a certificate and key.
#[test]
fn test_server_rejects_client_tls_config() {
    let result = Server::builder()
        .tls(TlsConfig::client_insecure())
        .add_service(EchoService)
        .build();

    // Match instead of `unwrap_err()` so the built server type is not required to be `Debug`.
    let err = match result {
        Ok(_) => panic!("expected server TLS validation to fail"),
        Err(err) => err,
    };

    // Include the actual error text so a wording change or an unrelated failure is
    // immediately visible in the test output.
    let message = err.to_string();
    assert!(
        message.contains("server TLS requires cert + key"),
        "unexpected error message: {message}"
    );
}