pub mod framing;
pub mod handler;
pub mod service;
pub mod streaming;
pub use framing::parse_grpc_client_stream;
pub use handler::{GrpcHandler, GrpcHandlerResult, GrpcRequestData, GrpcResponseData, RpcMode};
pub use service::{GenericGrpcService, copy_metadata, is_grpc_request, parse_grpc_path};
pub use streaming::{MessageStream, StreamingRequest, StreamingResponse};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
/// Configuration options for the gRPC layer.
///
/// Every field carries a serde default, so any field may be omitted from the
/// serialized form and will fall back to the value documented below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrpcConfig {
/// Whether gRPC support is enabled. Defaults to `true`.
#[serde(default = "default_true")]
pub enabled: bool,
/// Maximum message size. Defaults to `4 * 1024 * 1024`.
// NOTE(review): presumably a byte count (4 MiB) - confirm where it is enforced.
#[serde(default = "default_max_message_size")]
pub max_message_size: usize,
/// Whether compression is enabled. Defaults to `true`.
#[serde(default = "default_true")]
pub enable_compression: bool,
/// Optional request timeout. Defaults to `None` when omitted.
// NOTE(review): the unit is not visible in this file (presumably seconds) - confirm.
#[serde(default)]
pub request_timeout: Option<u64>,
/// Maximum number of concurrent streams. Defaults to `100`.
#[serde(default = "default_max_concurrent_streams")]
pub max_concurrent_streams: u32,
/// Whether keepalive is enabled. Defaults to `true`.
#[serde(default = "default_true")]
pub enable_keepalive: bool,
/// Keepalive interval. Defaults to `75`.
// NOTE(review): presumably seconds - confirm against the code that consumes it.
#[serde(default = "default_keepalive_interval")]
pub keepalive_interval: u64,
/// Keepalive timeout. Defaults to `20`.
// NOTE(review): presumably seconds - confirm against the code that consumes it.
#[serde(default = "default_keepalive_timeout")]
pub keepalive_timeout: u64,
}
impl Default for GrpcConfig {
fn default() -> Self {
Self {
enabled: true,
max_message_size: default_max_message_size(),
enable_compression: true,
request_timeout: None,
max_concurrent_streams: default_max_concurrent_streams(),
enable_keepalive: true,
keepalive_interval: default_keepalive_interval(),
keepalive_timeout: default_keepalive_timeout(),
}
}
}
/// Serde default helper for boolean fields that should be `true` when omitted
/// (`enabled`, `enable_compression`, `enable_keepalive`).
const fn default_true() -> bool {
    true
}
/// Serde default helper for `GrpcConfig::max_message_size`: `4 * 1024 * 1024`.
const fn default_max_message_size() -> usize {
    const MIB: usize = 1024 * 1024;
    4 * MIB
}
/// Serde default helper for `GrpcConfig::max_concurrent_streams`.
const fn default_max_concurrent_streams() -> u32 {
    100
}
/// Serde default helper for `GrpcConfig::keepalive_interval`.
const fn default_keepalive_interval() -> u64 {
    75
}
/// Serde default helper for `GrpcConfig::keepalive_timeout`.
const fn default_keepalive_timeout() -> u64 {
    20
}
/// Value stored per registry key: the shared handler together with the RPC
/// mode it was registered with.
type GrpcHandlerEntry = (Arc<dyn GrpcHandler>, RpcMode);
/// Method-name key under which a service-wide handler is stored; lookups fall
/// back to this key when no entry exists for the requested method.
const WILDCARD_METHOD: &str = "*";
/// Registry of gRPC handlers keyed by `(service name, method name)`.
///
/// The handler map lives behind an `Arc`, so cloning the registry only bumps
/// a reference count; clones share the map until one of them registers a new
/// handler.
#[derive(Clone)]
pub struct GrpcRegistry {
// Map from `(service, method)` to the handler and its RPC mode. A method
// name of `WILDCARD_METHOD` marks a handler registered for the whole service.
handlers: Arc<HashMap<(String, String), GrpcHandlerEntry>>,
}
impl GrpcRegistry {
    /// Creates a registry with no handlers.
    pub fn new() -> Self {
        Self {
            handlers: Arc::new(HashMap::new()),
        }
    }

    /// Registers `handler` for `method_name` of `service_name`.
    ///
    /// An entry already stored under the same `(service, method)` pair is
    /// replaced. If the handler map is currently shared with clones of this
    /// registry, `Arc::make_mut` first gives this registry its own copy, so
    /// the other clones do not observe the new entry.
    pub fn register(
        &mut self,
        service_name: impl Into<String>,
        method_name: impl Into<String>,
        handler: Arc<dyn GrpcHandler>,
        rpc_mode: RpcMode,
    ) {
        let handlers = Arc::make_mut(&mut self.handlers);
        handlers.insert((service_name.into(), method_name.into()), (handler, rpc_mode));
    }

    /// Registers `handler` as the fallback for every method of `service_name`
    /// that has no method-specific entry.
    ///
    /// The handler is stored under the wildcard method name.
    pub fn register_service(
        &mut self,
        service_name: impl Into<String>,
        handler: Arc<dyn GrpcHandler>,
        rpc_mode: RpcMode,
    ) {
        self.register(service_name, WILDCARD_METHOD, handler, rpc_mode);
    }

    /// Looks up the handler and RPC mode for `service_name` / `method_name`.
    ///
    /// A method-specific entry takes precedence; otherwise the service-wide
    /// (wildcard) entry is returned. Returns `None` when neither exists.
    pub fn get(&self, service_name: &str, method_name: &str) -> Option<(Arc<dyn GrpcHandler>, RpcMode)> {
        // The map is keyed by owned strings, so an owned key is required for
        // the lookup. Build it once and reuse it for the wildcard fallback
        // instead of allocating a second pair of strings.
        let mut key = (service_name.to_owned(), method_name.to_owned());
        if let Some(entry) = self.handlers.get(&key) {
            return Some(entry.clone());
        }
        key.1.clear();
        key.1.push_str(WILDCARD_METHOD);
        self.handlers.get(&key).cloned()
    }

    /// Returns the distinct service names that have at least one entry,
    /// sorted ascending so the result is stable across calls and runs.
    pub fn service_names(&self) -> Vec<String> {
        // De-duplicate on borrowed names first so that only the unique names
        // are cloned into owned strings.
        let unique: HashSet<&str> = self
            .handlers
            .keys()
            .map(|(service_name, _)| service_name.as_str())
            .collect();
        let mut names: Vec<String> = unique.into_iter().map(str::to_owned).collect();
        names.sort_unstable();
        names
    }

    /// Returns the explicitly registered method names of `service_name`,
    /// sorted ascending.
    ///
    /// The wildcard entry created by [`GrpcRegistry::register_service`] is
    /// not a real method and is therefore excluded.
    pub fn method_names(&self, service_name: &str) -> Vec<String> {
        let mut methods: Vec<String> = self
            .handlers
            .keys()
            .filter(|(registered_service, method_name)| {
                registered_service == service_name && method_name.as_str() != WILDCARD_METHOD
            })
            .map(|(_, method_name)| method_name.clone())
            .collect();
        // HashMap iteration order is unspecified; sort for deterministic output.
        methods.sort_unstable();
        methods
    }

    /// Returns `true` if an entry exists for exactly this
    /// `(service_name, method_name)` pair.
    ///
    /// Unlike [`GrpcRegistry::get`], this does not fall back to the
    /// service-wide wildcard entry.
    pub fn contains(&self, service_name: &str, method_name: &str) -> bool {
        self.handlers
            .contains_key(&(service_name.to_owned(), method_name.to_owned()))
    }

    /// Returns `true` if any entry (method-specific or wildcard) is
    /// registered for `service_name`.
    pub fn contains_service(&self, service_name: &str) -> bool {
        self.handlers
            .keys()
            .any(|(registered_service, _)| registered_service == service_name)
    }

    /// Returns the number of registered entries, counting wildcard entries.
    pub fn len(&self) -> usize {
        self.handlers.len()
    }

    /// Returns `true` if no handler has been registered.
    pub fn is_empty(&self) -> bool {
        self.handlers.is_empty()
    }
}
impl Default for GrpcRegistry {
fn default() -> Self {
Self::new()
}
}
// Unit tests for `GrpcConfig` defaults/serialization and for `GrpcRegistry`
// registration, lookup, and wildcard fallback.
#[cfg(test)]
mod tests {
use super::*;
use crate::grpc::handler::{GrpcHandler, GrpcHandlerResult, GrpcRequestData};
use std::future::Future;
use std::pin::Pin;
// Minimal handler used as a registry value: it ignores the request and
// resolves to a response with an empty payload and empty metadata.
struct TestHandler;
impl GrpcHandler for TestHandler {
fn call(&self, _request: GrpcRequestData) -> Pin<Box<dyn Future<Output = GrpcHandlerResult> + Send>> {
Box::pin(async {
Ok(GrpcResponseData {
payload: bytes::Bytes::new(),
metadata: tonic::metadata::MetadataMap::new(),
})
})
}
fn service_name(&self) -> &'static str {
"test.Service"
}
}
// `GrpcConfig::default()` must expose the documented default for every field.
#[test]
fn test_grpc_config_default() {
let config = GrpcConfig::default();
assert!(config.enabled);
assert_eq!(config.max_message_size, 4 * 1024 * 1024);
assert!(config.enable_compression);
assert!(config.request_timeout.is_none());
assert_eq!(config.max_concurrent_streams, 100);
assert!(config.enable_keepalive);
assert_eq!(config.keepalive_interval, 75);
assert_eq!(config.keepalive_timeout, 20);
}
// A JSON round trip preserves the configuration (spot-checked on three fields).
#[test]
fn test_grpc_config_serialization() {
let config = GrpcConfig::default();
let json = serde_json::to_string(&config).unwrap();
let deserialized: GrpcConfig = serde_json::from_str(&json).unwrap();
assert_eq!(config.enabled, deserialized.enabled);
assert_eq!(config.max_message_size, deserialized.max_message_size);
assert_eq!(config.enable_compression, deserialized.enable_compression);
}
// A freshly created registry holds no entries.
#[test]
fn test_grpc_registry_new() {
let registry = GrpcRegistry::new();
assert!(registry.is_empty());
assert_eq!(registry.len(), 0);
}
// Registering one handler makes the registry non-empty and the pair discoverable.
#[test]
fn test_grpc_registry_register() {
let mut registry = GrpcRegistry::new();
let handler = Arc::new(TestHandler);
registry.register("test.Service", "TestMethod", handler, RpcMode::Unary);
assert!(!registry.is_empty());
assert_eq!(registry.len(), 1);
assert!(registry.contains("test.Service", "TestMethod"));
}
// `get` returns both the handler and the RPC mode it was registered with.
#[test]
fn test_grpc_registry_get() {
let mut registry = GrpcRegistry::new();
let handler = Arc::new(TestHandler);
registry.register("test.Service", "TestMethod", handler, RpcMode::Unary);
let retrieved = registry.get("test.Service", "TestMethod");
assert!(retrieved.is_some());
let (handler, rpc_mode) = retrieved.unwrap();
assert_eq!(handler.service_name(), "test.Service");
assert_eq!(rpc_mode, RpcMode::Unary);
}
// Looking up an unregistered pair yields `None`.
#[test]
fn test_grpc_registry_get_nonexistent() {
let registry = GrpcRegistry::new();
let result = registry.get("nonexistent.Service", "MissingMethod");
assert!(result.is_none());
}
// `service_names` lists every registered service; the result is sorted here
// before comparison so the test does not depend on the returned order.
#[test]
fn test_grpc_registry_service_names() {
let mut registry = GrpcRegistry::new();
registry.register("service1", "Method1", Arc::new(TestHandler), RpcMode::Unary);
registry.register("service2", "Method2", Arc::new(TestHandler), RpcMode::ServerStreaming);
registry.register("service3", "Method3", Arc::new(TestHandler), RpcMode::Unary);
let mut names = registry.service_names();
names.sort();
assert_eq!(names, vec!["service1", "service2", "service3"]);
}
// `contains` matches on the service name as well as the method name.
#[test]
fn test_grpc_registry_contains() {
let mut registry = GrpcRegistry::new();
registry.register("test.Service", "TestMethod", Arc::new(TestHandler), RpcMode::Unary);
assert!(registry.contains("test.Service", "TestMethod"));
assert!(!registry.contains("other.Service", "TestMethod"));
}
// Entries for different services coexist in the same registry.
#[test]
fn test_grpc_registry_multiple_services() {
let mut registry = GrpcRegistry::new();
registry.register("user.Service", "GetUser", Arc::new(TestHandler), RpcMode::Unary);
registry.register(
"post.Service",
"ListPosts",
Arc::new(TestHandler),
RpcMode::ServerStreaming,
);
assert_eq!(registry.len(), 2);
assert!(registry.contains("user.Service", "GetUser"));
assert!(registry.contains("post.Service", "ListPosts"));
}
// A clone sees the entries that were registered before it was made.
#[test]
fn test_grpc_registry_clone() {
let mut registry = GrpcRegistry::new();
registry.register("test.Service", "TestMethod", Arc::new(TestHandler), RpcMode::Unary);
let cloned = registry.clone();
assert_eq!(cloned.len(), 1);
assert!(cloned.contains("test.Service", "TestMethod"));
}
// `Default` produces an empty registry.
#[test]
fn test_grpc_registry_default() {
let registry = GrpcRegistry::default();
assert!(registry.is_empty());
}
// Each of the four RPC modes is stored and returned unchanged by `get`.
#[test]
fn test_grpc_registry_rpc_mode_storage() {
let mut registry = GrpcRegistry::new();
registry.register("unary.Service", "UnaryMethod", Arc::new(TestHandler), RpcMode::Unary);
registry.register(
"server_stream.Service",
"StreamMethod",
Arc::new(TestHandler),
RpcMode::ServerStreaming,
);
registry.register(
"client_stream.Service",
"UploadMethod",
Arc::new(TestHandler),
RpcMode::ClientStreaming,
);
registry.register(
"bidi.Service",
"ChatMethod",
Arc::new(TestHandler),
RpcMode::BidirectionalStreaming,
);
let (_, mode) = registry.get("unary.Service", "UnaryMethod").unwrap();
assert_eq!(mode, RpcMode::Unary);
let (_, mode) = registry.get("server_stream.Service", "StreamMethod").unwrap();
assert_eq!(mode, RpcMode::ServerStreaming);
let (_, mode) = registry.get("client_stream.Service", "UploadMethod").unwrap();
assert_eq!(mode, RpcMode::ClientStreaming);
let (_, mode) = registry.get("bidi.Service", "ChatMethod").unwrap();
assert_eq!(mode, RpcMode::BidirectionalStreaming);
}
// A service-wide registration answers lookups for any method name, yet it
// contributes no entry to `method_names`.
#[test]
fn test_grpc_registry_service_fallback() {
let mut registry = GrpcRegistry::new();
registry.register_service("test.Service", Arc::new(TestHandler), RpcMode::Unary);
assert!(registry.contains_service("test.Service"));
assert!(registry.get("test.Service", "AnyMethod").is_some());
assert!(registry.method_names("test.Service").is_empty());
}
// When both a wildcard and a method-specific handler exist, `get` prefers the
// method-specific one and uses the wildcard only for other methods. The two
// registrations use different RPC modes so the result identifies which won.
#[test]
fn test_grpc_registry_prefers_method_specific_handler() {
struct MethodSpecificHandler;
impl GrpcHandler for MethodSpecificHandler {
fn call(&self, _request: GrpcRequestData) -> Pin<Box<dyn Future<Output = GrpcHandlerResult> + Send>> {
Box::pin(async {
Ok(GrpcResponseData {
payload: bytes::Bytes::from("method-specific"),
metadata: tonic::metadata::MetadataMap::new(),
})
})
}
fn service_name(&self) -> &str {
"test.Service"
}
}
let mut registry = GrpcRegistry::new();
registry.register_service("test.Service", Arc::new(TestHandler), RpcMode::Unary);
registry.register(
"test.Service",
"GetThing",
Arc::new(MethodSpecificHandler),
RpcMode::ServerStreaming,
);
let (_, mode) = registry.get("test.Service", "GetThing").unwrap();
assert_eq!(mode, RpcMode::ServerStreaming);
let (_, fallback_mode) = registry.get("test.Service", "OtherThing").unwrap();
assert_eq!(fallback_mode, RpcMode::Unary);
}
}