use std::collections::HashMap;
use knafeh::codec::JsonCodec;
use knafeh::error::RpcStatusCode;
use knafeh::rpc::message::{RpcRequest, RpcResponse};
use knafeh::rpc::middleware::{LoggingInterceptor, MiddlewareStack};
use knafeh::rpc::router::MethodRouter;
use knafeh::rpc::service::{MethodDescriptor, MethodKind, Service};
use knafeh::rpc::stream::{RpcStreamRequest, RpcStreamResponse};
use knafeh::transport::connection::{
extract_method_path, headers_to_metadata, metadata_to_headers, validate_metadata_key,
};
use knafeh::transport::tls::TlsConfig;
use knafeh::Server;
use async_trait::async_trait;
/// Minimal test service exposing one unary method, `echo`, which returns the
/// request body unchanged.
///
/// Only the unary path is implemented. The streaming entry points panic with a
/// descriptive message, so a test that mis-routes a call fails with a clear
/// explanation instead of a bare "not implemented".
#[derive(Debug, Clone, Copy)]
struct EchoService;

#[async_trait]
impl Service for EchoService {
    fn name(&self) -> &str {
        "echo"
    }

    /// Advertises the single unary `echo` method.
    fn methods(&self) -> Vec<MethodDescriptor> {
        vec![MethodDescriptor {
            name: "echo".to_string(),
            kind: MethodKind::Unary,
        }]
    }

    /// Wraps the request body in an OK response. The method name is ignored,
    /// so any unary call routed here is echoed.
    async fn call_unary(
        &self,
        _method: &str,
        request: RpcRequest,
    ) -> Result<RpcResponse, knafeh::KnafehError> {
        Ok(RpcResponse::ok(request.body))
    }

    async fn call_server_stream(
        &self,
        _method: &str,
        _request: RpcRequest,
    ) -> Result<RpcStreamResponse, knafeh::KnafehError> {
        unimplemented!("EchoService only supports unary calls (server-streaming requested)")
    }

    async fn call_client_stream(
        &self,
        _method: &str,
        _stream: RpcStreamRequest,
    ) -> Result<RpcResponse, knafeh::KnafehError> {
        unimplemented!("EchoService only supports unary calls (client-streaming requested)")
    }

    async fn call_bidi_stream(
        &self,
        _method: &str,
        _stream: RpcStreamRequest,
    ) -> Result<RpcStreamResponse, knafeh::KnafehError> {
        unimplemented!("EchoService only supports unary calls (bidi-streaming requested)")
    }
}
/// Encoding and then decoding a valid JSON payload yields the original bytes.
#[test]
fn test_codec_json_roundtrip() {
    use knafeh::codec::Codec;

    let original = br#"{"hello":"world"}"#;
    let json = JsonCodec::new();

    let wire = json.encode(original).unwrap();
    let restored = json.decode(&wire).unwrap();

    assert_eq!(restored, original);
}
/// Bytes that are not valid JSON are rejected by both `encode` and `decode`.
#[test]
fn test_codec_json_invalid() {
    use knafeh::codec::Codec;

    let json = JsonCodec::new();

    let encode_outcome = json.encode(b"not json");
    assert!(encode_outcome.is_err());

    let decode_outcome = json.decode(b"not json");
    assert!(decode_outcome.is_err());
}
/// A unary call addressed to `echo/echo` is dispatched to `EchoService` and
/// returns the same body with an OK status.
#[tokio::test]
async fn test_router_unary() {
    use std::sync::Arc;

    let mut router = MethodRouter::new();
    router.add_service(Arc::new(EchoService));

    let reply = router
        .route_unary(RpcRequest::new("echo/echo", b"hello".to_vec()))
        .await
        .unwrap();

    assert_eq!(reply.body, b"hello");
    assert_eq!(reply.status.code, RpcStatusCode::Ok);
}
/// Routing a call on a router with no registered services returns an error.
#[tokio::test]
async fn test_router_not_found() {
    let empty_router = MethodRouter::new();

    let outcome = empty_router
        .route_unary(RpcRequest::new("unknown/method", b"hello".to_vec()))
        .await;

    assert!(outcome.is_err());
}
#[tokio::test]
async fn test_middleware_stack() {
let mut stack = MiddlewareStack::new();
stack.add(std::sync::Arc::new(LoggingInterceptor));
let mut request = RpcRequest::new("test/method", b"{}".to_vec());
stack.apply_request(&mut request).await.unwrap();
let mut response = RpcResponse::ok(b"{}".to_vec());
stack.apply_response(&mut response).await.unwrap();
}
/// User metadata converted to transport headers and back keeps its values.
#[test]
fn test_metadata_header_roundtrip() {
    let metadata = HashMap::from([("user-key".to_string(), "value1".to_string())]);

    let headers = metadata_to_headers(&metadata).unwrap();
    let header_pairs: Vec<_> = headers
        .iter()
        .map(|(name, value)| (name.clone(), value.clone()))
        .collect();

    let recovered = headers_to_metadata(&header_pairs);
    assert_eq!(recovered.get("user-key").unwrap(), "value1");
}
/// Metadata that uses a reserved key (`x-rpc-status`) cannot be converted into headers.
#[test]
fn test_metadata_header_conversion_rejects_invalid_keys() {
    let reserved = HashMap::from([("x-rpc-status".to_string(), "0".to_string())]);

    let conversion = metadata_to_headers(&reserved);

    assert!(conversion.is_err());
}
/// Reserved header names are rejected as metadata keys, while an ordinary key
/// is accepted. Reserved names include `:path`, the `x-rpc-*` control headers,
/// and their unprefixed forms.
#[test]
fn test_metadata_rejects_reserved_headers() {
    const RESERVED: [&str; 6] = [
        ":path",
        "x-rpc-method-kind",
        "x-rpc-status",
        "method-kind",
        "status",
        "status-message",
    ];

    // Table-driven so a failure names the exact key that slipped through.
    for key in RESERVED {
        assert!(
            validate_metadata_key(key).is_err(),
            "reserved key {key:?} was accepted as metadata"
        );
    }

    let allowed = "trace-id";
    assert!(
        validate_metadata_key(allowed).is_ok(),
        "ordinary key {allowed:?} was rejected"
    );
}
/// `extract_method_path` turns `/service/method` into `service/method` and
/// rejects a path that has only one segment.
#[test]
fn test_extract_method_path() {
    let parsed = extract_method_path(b"/greeter/say_hello").unwrap();
    assert_eq!(parsed, "greeter/say_hello");

    let single_segment = extract_method_path(b"/invalid");
    assert!(single_segment.is_err());
}
/// The request builder splits `svc/method` into service and method names and
/// records metadata attached via `with_metadata`.
#[test]
fn test_request_builder() {
    let req = RpcRequest::new("svc/method", b"body".to_vec()).with_metadata("key", "value");

    assert_eq!(req.service_name(), Some("svc"));
    assert_eq!(req.method_name(), Some("method"));
    // Compare as `&str` so no `String` has to be allocated just for the check.
    assert_eq!(req.metadata.get("key").map(String::as_str), Some("value"));
}
/// `RpcResponse::error` carries the given status code and message and has an empty body.
#[test]
fn test_response_error() {
    let response = RpcResponse::error(RpcStatusCode::NotFound, "not found");

    let status = &response.status;
    assert_eq!(status.code, RpcStatusCode::NotFound);
    assert_eq!(status.message, "not found");

    assert!(response.body.is_empty());
}
/// Every wire code in `0..=16` decodes to its own distinct `RpcStatusCode`
/// variant, and an out-of-range code falls back to `Unknown`.
///
/// NOTE(review): assumes `from_u8` is injective over `0..=16`, which the
/// "roundtrip" name implies. Confirm against the enum definition.
#[test]
fn test_rpc_status_code_roundtrip() {
    use std::collections::HashSet;

    // The enum is only known to be `Debug + PartialEq` here, so distinctness is
    // checked through the `Debug` rendering rather than by hashing the variants.
    let rendered: HashSet<String> = (0..=16)
        .map(|code| format!("{:?}", RpcStatusCode::from_u8(code)))
        .collect();
    assert_eq!(
        rendered.len(),
        17,
        "codes 0..=16 should map to 17 distinct variants, got {rendered:?}"
    );

    assert_eq!(RpcStatusCode::from_u8(255), RpcStatusCode::Unknown);
}
/// Building a server with a client-side TLS config (no certificate and key)
/// fails with a descriptive validation error.
#[test]
fn test_server_rejects_client_tls_config() {
    let result = Server::builder()
        .tls(TlsConfig::client_insecure())
        .add_service(EchoService)
        .build();

    // `match` rather than `Result::expect_err`, which would require the built
    // server type to implement `Debug`.
    let err = match result {
        Ok(_) => panic!("expected server TLS validation to fail"),
        Err(err) => err,
    };

    let message = err.to_string();
    assert!(
        message.contains("server TLS requires cert + key"),
        "unexpected TLS validation error: {message}"
    );
}