use crate::grpc::GrpcHandler;
use crate::prelude::HandlerGetter;
use crate::prelude::Route;
use http::Method;
use std::sync::Arc;
use tonic::body::Body;
use tonic::codegen::Service;
use tonic::server::NamedService;
/// Mounts a gRPC service onto the router.
///
/// `S` is the concrete service type that ends up inside the produced
/// [`GrpcHandler`]. The blanket impl in this module implements the trait for
/// every qualifying service `S` with `S` itself as the type parameter.
pub trait GrpcRegister<S> {
    /// Consumes the service and pushes the route built by
    /// [`GrpcRegister::service`] onto `route`.
    fn register(self, route: &mut Route);

    /// Consumes the service and builds a standalone [`Route`] for it, rooted
    /// at the handler's gRPC path.
    fn service(self) -> Route;

    /// Consumes the service and wraps it in a [`GrpcHandler`].
    fn get_handler(self) -> GrpcHandler<S>;
}
/// Blanket implementation for any tower/tonic service that can be served over
/// HTTP and carries a gRPC service name via [`NamedService`].
///
/// The bounds require the service and its futures to be `Send + 'static` (and
/// the service also `Sync`) so the wrapped handler can be shared behind an
/// [`Arc`] and driven from any task.
impl<S> GrpcRegister<S> for S
where
    S: Service<http::Request<Body>, Response = http::Response<Body>> + NamedService,
    S: Clone + Send + Sync + 'static,
    S::Future: Send + 'static,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send,
{
    /// Wraps `self` in a [`GrpcHandler`] via [`GrpcHandler::new`].
    fn get_handler(self) -> GrpcHandler<S> {
        GrpcHandler::new(self)
    }

    /// Builds a route at the handler's path with a `<path:**>` wildcard child.
    ///
    /// Both `POST` and `GET` on that wildcard are dispatched to the same
    /// handler instance, which is shared through an [`Arc`] rather than
    /// duplicated.
    fn service(self) -> Route {
        let handler = self.get_handler();
        // Copy the path out first: the handler is moved into the `Arc` below,
        // so the borrow returned by `path()` cannot outlive this statement.
        let path = handler.path().to_string();
        let handler = Arc::new(handler);
        Route::new(path.as_str()).append(
            Route::new("<path:**>")
                .insert_handler(Method::POST, handler.clone())
                .insert_handler(Method::GET, handler),
        )
    }

    /// Pushes the route produced by [`GrpcRegister::service`] onto `route`.
    fn register(self, route: &mut Route) {
        route.push(self.service());
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::Future;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    /// Defines a minimal mock gRPC service named `$name` that answers every
    /// request with an empty `200 OK`.
    ///
    /// The mocks differ only in their `NamedService::NAME`, so they are
    /// generated from one definition instead of being copy-pasted.
    macro_rules! mock_service {
        ($ty:ident, $name:literal) => {
            #[derive(Clone)]
            struct $ty {
                _private: (),
            }

            impl $ty {
                fn new() -> Self {
                    Self { _private: () }
                }
            }

            impl NamedService for $ty {
                const NAME: &'static str = $name;
            }

            impl Service<http::Request<Body>> for $ty {
                type Response = http::Response<Body>;
                type Error = MockError;
                type Future =
                    Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

                fn poll_ready(
                    &mut self,
                    _cx: &mut Context<'_>,
                ) -> Poll<Result<(), Self::Error>> {
                    Poll::Ready(Ok(()))
                }

                fn call(&mut self, _req: http::Request<Body>) -> Self::Future {
                    Box::pin(async move {
                        Ok(http::Response::builder()
                            .status(http::StatusCode::OK)
                            .body(Body::empty())
                            .expect("static status and empty body always form a valid response"))
                    })
                }
            }
        };
    }

    mock_service!(MockGreeterService, "/mock.greeter/MockGreeter");
    mock_service!(MockUserService, "/mock.user.UserService");

    /// Error type for the mocks; never actually produced.
    #[derive(Debug)]
    struct MockError;

    impl std::fmt::Display for MockError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "Mock error")
        }
    }

    impl std::error::Error for MockError {}

    #[test]
    fn test_grpc_register_get_handler() {
        let handler = MockGreeterService::new().get_handler();
        assert_eq!(handler.path(), "/mock.greeter/MockGreeter");
    }

    #[test]
    fn test_grpc_register_service() {
        let route = MockGreeterService::new().service();
        assert_eq!(route.path, "mock.greeter");
    }

    #[test]
    fn test_grpc_register_service_structure() {
        let route = MockGreeterService::new().service();
        assert_eq!(route.path, "mock.greeter");
        assert!(!route.children.is_empty());
    }

    #[test]
    fn test_grpc_register_service_path_wildcard() {
        // The wildcard route is appended beneath the service path, so the
        // top-level segment stays the service package and gains children.
        let route = MockGreeterService::new().service();
        assert_eq!(route.path, "mock.greeter");
        assert!(!route.children.is_empty());
    }

    #[test]
    fn test_grpc_register_service_handlers() {
        // Attaching both the POST and GET handlers must not panic or
        // disturb the top-level path.
        let route = MockGreeterService::new().service();
        assert_eq!(route.path, "mock.greeter");
    }

    #[test]
    fn test_grpc_register_service_arc_handler() {
        // `service()` shares one handler between POST and GET through an
        // `Arc`; check that such sharing aliases rather than copies.
        let shared = Arc::new(MockGreeterService::new().get_handler());
        let alias = Arc::clone(&shared);
        assert_eq!(Arc::strong_count(&shared), 2);
        assert!(Arc::ptr_eq(&shared, &alias));
        drop(alias);
        assert_eq!(Arc::strong_count(&shared), 1);
    }

    #[test]
    fn test_grpc_register_service_path_conversion() {
        let handler = MockGreeterService::new().get_handler();
        let path_string = handler.path().to_string();
        assert_eq!(path_string, "/mock.greeter/MockGreeter");
        let route = Route::new(path_string.as_str());
        assert_eq!(route.path, "mock.greeter");
    }

    #[test]
    fn test_grpc_register_service_chaining() {
        let handler = MockGreeterService::new().get_handler();
        let path = handler.path().to_string();
        let route = Route::new(path.as_str()).append(Route::new("<path:**>"));
        assert_eq!(route.path, "mock.greeter");
        assert!(!route.children.is_empty());
    }

    #[test]
    fn test_grpc_register_register_to_route() {
        let mut base_route = Route::new("/api");
        MockGreeterService::new().register(&mut base_route);
        assert_eq!(base_route.path, "api");
        assert!(!base_route.children.is_empty());
    }

    #[test]
    fn test_grpc_register_multiple_services() {
        let route1 = MockGreeterService::new().service();
        let route2 = MockUserService::new().service();
        assert_ne!(route1.path, route2.path);
        assert_eq!(route1.path, "mock.greeter");
        assert_eq!(route2.path, "mock.user.UserService");
    }

    #[test]
    fn test_grpc_register_combine_routes() {
        let combined_route = Route::new("/api")
            .append(MockGreeterService::new().service())
            .append(MockUserService::new().service());
        assert_eq!(combined_route.path, "api");
        assert!(!combined_route.children.is_empty());
    }

    #[test]
    fn test_grpc_register_trait_bound() {
        fn assert_grpc_register<S: GrpcRegister<S>>() {}
        assert_grpc_register::<MockGreeterService>();
    }

    #[test]
    fn test_grpc_register_clone() {
        let service = MockGreeterService::new();
        let first = service.clone().get_handler();
        let second = service.get_handler();
        assert_eq!(first.path(), second.path());
    }

    #[test]
    fn test_grpc_register_named_service() {
        assert_eq!(MockGreeterService::NAME, "/mock.greeter/MockGreeter");
    }

    #[test]
    fn test_grpc_register_different_names() {
        assert_ne!(MockGreeterService::NAME, MockUserService::NAME);
    }

    #[test]
    fn test_grpc_register_empty_route() {
        // Smoke test: registering onto a root route built from "" must not panic.
        let mut empty_route = Route::new("");
        MockGreeterService::new().register(&mut empty_route);
    }

    #[test]
    fn test_grpc_register_nested_route() {
        let mut nested_route = Route::new("/api/v1/grpc");
        MockGreeterService::new().register(&mut nested_route);
        assert_eq!(nested_route.path, "api");
    }
}