use std::convert::Infallible;
use std::future::Future;
use std::marker::PhantomData;
use std::net::SocketAddr;
use std::sync::Arc;
use hyper_util::rt::{TokioExecutor, TokioIo};
use tokio::net::TcpListener;
use typeway_core::ApiSpec;
use typeway_grpc::health::HealthService;
use typeway_grpc::reflection::ReflectionService;
use typeway_grpc::service::{ApiToServiceDescriptor, GrpcServiceDescriptor};
use typeway_grpc::CollectRpcs;
use crate::body::BoxBody;
use crate::router::{Router, RouterService};
/// Builder/runner for a server that exposes one [`Router`] over both REST
/// and gRPC on the same listener.
///
/// `A` is the API specification type. It is carried only at the type level
/// (see `_api`) and is used to derive the reflection data, the gRPC service
/// descriptor and the gRPC spec.
pub struct GrpcServer<A: ApiSpec> {
/// Shared router; handed to both the REST service and the gRPC router when
/// the multiplexer is built.
router: Arc<Router>,
/// gRPC service name; logged at startup as `{package}.{service_name}`.
service_name: String,
/// gRPC package name used together with `service_name`.
package: String,
/// Reflection service built from `A` in `new`.
reflection: ReflectionService,
/// Health service; a clone is handed out by `health_service`.
health: HealthService,
/// Whether reflection is enabled; `new` sets this to `true`.
reflection_enabled: bool,
/// Pretty-printed JSON of the gRPC spec; `None` until one of the
/// `with_grpc_docs*` builders is called.
grpc_spec_json: Option<Arc<String>>,
/// Generated HTML docs page; `None` until one of the `with_grpc_docs*`
/// builders is called.
grpc_docs_html: Option<Arc<String>>,
/// Proto transcoder; `None` until `with_proto_binary` is called. Only
/// present when the `grpc-proto-binary` feature is enabled.
#[cfg(feature = "grpc-proto-binary")]
transcoder: Option<Arc<typeway_grpc::ProtoTranscoder>>,
/// Zero-sized marker tying the server to its API type `A`.
_api: PhantomData<A>,
}
impl<A: ApiSpec + CollectRpcs> GrpcServer<A> {
pub(crate) fn new(router: Arc<Router>, service_name: String, package: String) -> Self {
let reflection = ReflectionService::from_api::<A>(&service_name, &package);
let health = HealthService::new();
GrpcServer {
router,
service_name,
package,
reflection,
health,
reflection_enabled: true,
grpc_spec_json: None,
grpc_docs_html: None,
#[cfg(feature = "grpc-proto-binary")]
transcoder: None,
_api: PhantomData,
}
}
/// Installs a state injector on the router.
///
/// The injector inserts a fresh clone of `state` into whatever container the
/// router passes to it (presumably the request extensions — confirm against
/// `Router::set_state_injector`).
pub fn with_state<T: Clone + Send + Sync + 'static>(self, state: T) -> Self {
    let router = &self.router;
    // The closure stays inline so its parameter type is inferred from the
    // injector type expected by the router.
    router.set_state_injector(Arc::new(move |extensions| {
        extensions.insert(state.clone());
    }));
    self
}
pub fn with_reflection(mut self, enabled: bool) -> Self {
self.reflection_enabled = enabled;
self
}
/// Returns a clone of the server's [`HealthService`].
pub fn health_service(&self) -> HealthService {
    let handle = &self.health;
    handle.clone()
}
/// Sets `prefix` as the router's prefix and returns the server.
pub fn nest(self, prefix: &str) -> Self {
    let router = &self.router;
    router.set_prefix(prefix);
    self
}
/// Forwards `max` to the router as its maximum body size (presumably in
/// bytes — confirm against `Router::set_max_body_size`) and returns the
/// server.
pub fn max_body_size(self, max: usize) -> Self {
    let router = &self.router;
    router.set_max_body_size(max);
    self
}
/// Generates the gRPC spec for `A` and stores both its pretty-printed JSON
/// and the rendered HTML docs page on the server.
///
/// # Panics
///
/// Panics if the spec cannot be serialized to JSON.
pub fn with_grpc_docs(mut self) -> Self {
    use typeway_grpc::spec::ApiToGrpcSpec;

    let spec = A::grpc_spec(&self.service_name, &self.package);
    let spec_json = serde_json::to_string_pretty(&spec).expect("spec serialization");
    let docs_html = typeway_grpc::docs_page::generate_docs_html(&spec);

    self.grpc_docs_html = Some(Arc::new(docs_html));
    self.grpc_spec_json = Some(Arc::new(spec_json));
    self
}
/// Like `with_grpc_docs`, but builds the spec with the supplied handler
/// documentation (`docs`) via `grpc_spec_with_docs`.
///
/// # Panics
///
/// Panics if the spec cannot be serialized to JSON.
pub fn with_grpc_docs_with_handler_docs(mut self, docs: &[typeway_core::HandlerDoc]) -> Self {
    use typeway_grpc::spec::ApiToGrpcSpec;

    let documented_spec = A::grpc_spec_with_docs(&self.service_name, &self.package, docs);
    let spec_json =
        serde_json::to_string_pretty(&documented_spec).expect("spec serialization");
    let docs_html = typeway_grpc::docs_page::generate_docs_html(&documented_spec);

    self.grpc_docs_html = Some(Arc::new(docs_html));
    self.grpc_spec_json = Some(Arc::new(spec_json));
    self
}
/// Builds a `ProtoTranscoder` from the gRPC spec of `A` and stores it on the
/// server. Only available with the `grpc-proto-binary` feature.
#[cfg(feature = "grpc-proto-binary")]
pub fn with_proto_binary(mut self) -> Self {
    use typeway_grpc::spec::ApiToGrpcSpec;

    let transcoder =
        typeway_grpc::ProtoTranscoder::new(A::grpc_spec(&self.service_name, &self.package));
    self.transcoder = Some(Arc::new(transcoder));
    self
}
/// Binds `addr`, logs the startup banner and serves until an error occurs.
///
/// This delegates to `serve_with_shutdown` with a shutdown future that never
/// completes.
///
/// # Errors
///
/// Returns the bind error, or whatever `serve_with_shutdown` returns.
pub async fn serve(
    self,
    addr: SocketAddr,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let bound = TcpListener::bind(addr).await?;

    tracing::info!("Listening on http://{addr} (REST + gRPC)");
    tracing::info!(" gRPC service: {}.{}", self.package, self.service_name);
    if self.reflection_enabled {
        tracing::info!(" gRPC reflection: enabled");
    }
    tracing::info!(" gRPC health check: enabled");

    // A future that never resolves: the server runs until accept fails.
    let never = std::future::pending::<()>();
    self.serve_with_shutdown(bound, never).await
}
/// Accepts connections on `listener` and serves each one on its own task
/// until `shutdown` completes.
///
/// Every accepted stream is served by a clone of the REST + gRPC
/// multiplexer. Errors on an individual connection are logged at debug level
/// and do not affect the accept loop.
///
/// When `shutdown` resolves the accept loop stops and `Ok(())` is returned.
/// Connection tasks are spawned detached and are not awaited here.
///
/// # Errors
///
/// Returns the error from `accept()` unless it is a per-connection error
/// (`ConnectionAborted`, `ConnectionReset`, `ConnectionRefused`), which is
/// logged and skipped so one bad peer cannot take the whole server down.
pub async fn serve_with_shutdown(
    self,
    listener: TcpListener,
    shutdown: impl Future<Output = ()> + Send,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let multiplexer = self.build_multiplexer();
    tokio::pin!(shutdown);
    loop {
        tokio::select! {
            accepted = listener.accept() => {
                let (stream, _) = match accepted {
                    Ok(conn) => conn,
                    // The peer went away while the connection was still
                    // queued; the listener itself is fine, so keep going.
                    Err(e) if matches!(
                        e.kind(),
                        std::io::ErrorKind::ConnectionAborted
                            | std::io::ErrorKind::ConnectionReset
                            | std::io::ErrorKind::ConnectionRefused
                    ) => {
                        tracing::debug!("Accept failed for a single connection: {e}");
                        continue;
                    }
                    Err(e) => return Err(e.into()),
                };
                let io = TokioIo::new(stream);
                let svc = multiplexer.clone();
                let hyper_svc = hyper_util::service::TowerToHyperService::new(svc);
                tokio::task::spawn(async move {
                    if let Err(e) = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
                        .serve_connection(io, hyper_svc)
                        .await
                    {
                        tracing::debug!("Connection closed: {e}");
                    }
                });
            }
            () = &mut shutdown => {
                tracing::info!("Shutting down gracefully...");
                return Ok(());
            }
        }
    }
}
/// Binds `addr` and serves REST + gRPC with `direct_handlers` registered on
/// the gRPC router. Only available with the `protobuf` feature.
///
/// Each `(path, handler)` pair is added to the gRPC router before serving.
/// There is no shutdown signal: the accept loop runs until accept fails.
/// Errors on an individual connection are logged at debug level, matching
/// `serve_with_shutdown`.
///
/// # Errors
///
/// Returns the bind error, or the error from `accept()` unless it is a
/// per-connection error (`ConnectionAborted`, `ConnectionReset`,
/// `ConnectionRefused`), which is logged and skipped.
#[cfg(feature = "protobuf")]
pub async fn serve_with_direct_handlers(
    self,
    addr: SocketAddr,
    direct_handlers: Vec<(String, crate::grpc_direct::DirectHandler)>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let listener = TcpListener::bind(addr).await?;
    tracing::info!("Listening on http://{addr} (REST + gRPC)");
    let multiplexer = self.build_multiplexer_with_directs(direct_handlers);
    loop {
        let (stream, _) = match listener.accept().await {
            Ok(conn) => conn,
            // The peer went away while the connection was still queued; the
            // listener itself is fine, so keep accepting.
            Err(e) if matches!(
                e.kind(),
                std::io::ErrorKind::ConnectionAborted
                    | std::io::ErrorKind::ConnectionReset
                    | std::io::ErrorKind::ConnectionRefused
            ) => {
                tracing::debug!("Accept failed for a single connection: {e}");
                continue;
            }
            Err(e) => return Err(e.into()),
        };
        let io = TokioIo::new(stream);
        let svc = multiplexer.clone();
        let hyper_svc = hyper_util::service::TowerToHyperService::new(svc);
        tokio::task::spawn(async move {
            if let Err(e) = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
                .serve_connection(io, hyper_svc)
                .await
            {
                tracing::debug!("Connection closed: {e}");
            }
        });
    }
}
/// Returns the gRPC service descriptor derived from `A` for this server's
/// service name and package.
pub fn service_descriptor(&self) -> GrpcServiceDescriptor {
    let name = &self.service_name;
    let package = &self.package;
    A::service_descriptor(name, package)
}
/// Builds the REST + gRPC multiplexer and wraps it in `layer`, producing a
/// [`LayeredGrpcServer`] that serves the wrapped service.
pub fn layer<L>(self, layer: L) -> LayeredGrpcServer<A, L::Service>
where
    L: tower_layer::Layer<crate::grpc_dispatch::GrpcMultiplexer>,
    L::Service: tower_service::Service<
            http::Request<hyper::body::Incoming>,
            Response = http::Response<BoxBody>,
            Error = Infallible,
        > + Clone
        + Send
        + 'static,
    <L::Service as tower_service::Service<http::Request<hyper::body::Incoming>>>::Future:
        Send + 'static,
{
    let inner = self.build_multiplexer();
    let service = layer.layer(inner);
    LayeredGrpcServer {
        service,
        _api: PhantomData,
    }
}
/// Consumes the server and assembles the multiplexer, registering every
/// `(path, handler)` pair from `direct_handlers` on the gRPC router first.
/// Only available with the `protobuf` feature.
#[cfg(feature = "protobuf")]
fn build_multiplexer_with_directs(
    self,
    direct_handlers: Vec<(String, crate::grpc_direct::DirectHandler)>,
) -> crate::grpc_dispatch::GrpcMultiplexer {
    let Self {
        router,
        service_name,
        package,
        reflection,
        health,
        reflection_enabled,
        grpc_spec_json,
        grpc_docs_html,
        #[cfg(feature = "grpc-proto-binary")]
        transcoder,
        _api: _,
    } = self;

    let descriptor = A::service_descriptor(&service_name, &package);
    let mut grpc_router = crate::grpc_dispatch::GrpcRouter::from_router(&router, &descriptor);
    direct_handlers.into_iter().for_each(|(path, handler)| {
        grpc_router.add_direct_handler(path, handler);
    });

    crate::grpc_dispatch::GrpcMultiplexer {
        rest: RouterService::new(router),
        grpc_router: Arc::new(grpc_router),
        reflection: Arc::new(reflection),
        health,
        reflection_enabled,
        grpc_spec_json,
        grpc_docs_html,
        #[cfg(feature = "grpc-proto-binary")]
        transcoder,
    }
}
fn build_multiplexer(self) -> crate::grpc_dispatch::GrpcMultiplexer {
let descriptor = A::service_descriptor(&self.service_name, &self.package);
let grpc_router = crate::grpc_dispatch::GrpcRouter::from_router(&self.router, &descriptor);
crate::grpc_dispatch::GrpcMultiplexer {
rest: RouterService::new(self.router),
grpc_router: Arc::new(grpc_router),
reflection: Arc::new(self.reflection),
health: self.health,
reflection_enabled: self.reflection_enabled,
grpc_spec_json: self.grpc_spec_json,
grpc_docs_html: self.grpc_docs_html,
#[cfg(feature = "grpc-proto-binary")]
transcoder: self.transcoder,
}
}
}
/// A REST + gRPC server whose multiplexer has been wrapped in one or more
/// tower layers.
///
/// Produced by `GrpcServer::layer`; further layers can be stacked with
/// `LayeredGrpcServer::layer`.
pub struct LayeredGrpcServer<A: ApiSpec, S> {
/// The (possibly multiply) layered service that handles every request.
service: S,
/// Zero-sized marker tying the server to its API type `A`.
_api: PhantomData<A>,
}
impl<A, S> LayeredGrpcServer<A, S>
where
A: ApiSpec + CollectRpcs,
S: tower_service::Service<
http::Request<hyper::body::Incoming>,
Response = http::Response<BoxBody>,
Error = Infallible,
> + Clone
+ Send
+ 'static,
S::Future: Send + 'static,
{
/// Binds `addr`, logs the listening address and serves until an error
/// occurs.
///
/// This delegates to `serve_with_shutdown` with a shutdown future that never
/// completes.
///
/// # Errors
///
/// Returns the bind error, or whatever `serve_with_shutdown` returns.
pub async fn serve(
    self,
    addr: SocketAddr,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let bound = TcpListener::bind(addr).await?;
    tracing::info!("Listening on http://{addr} (REST + gRPC, layered)");

    // A future that never resolves: the server runs until accept fails.
    let never = std::future::pending::<()>();
    self.serve_with_shutdown(bound, never).await
}
/// Accepts connections on `listener` and serves each one on its own task
/// with a clone of the layered service until `shutdown` completes.
///
/// Errors on an individual connection are logged at debug level and do not
/// affect the accept loop.
///
/// When `shutdown` resolves the accept loop stops and `Ok(())` is returned.
/// Connection tasks are spawned detached and are not awaited here.
///
/// # Errors
///
/// Returns the error from `accept()` unless it is a per-connection error
/// (`ConnectionAborted`, `ConnectionReset`, `ConnectionRefused`), which is
/// logged and skipped so one bad peer cannot take the whole server down.
pub async fn serve_with_shutdown(
    self,
    listener: TcpListener,
    shutdown: impl Future<Output = ()> + Send,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let svc = self.service;
    tokio::pin!(shutdown);
    loop {
        tokio::select! {
            accepted = listener.accept() => {
                let (stream, _) = match accepted {
                    Ok(conn) => conn,
                    // The peer went away while the connection was still
                    // queued; the listener itself is fine, so keep going.
                    Err(e) if matches!(
                        e.kind(),
                        std::io::ErrorKind::ConnectionAborted
                            | std::io::ErrorKind::ConnectionReset
                            | std::io::ErrorKind::ConnectionRefused
                    ) => {
                        tracing::debug!("Accept failed for a single connection: {e}");
                        continue;
                    }
                    Err(e) => return Err(e.into()),
                };
                let io = TokioIo::new(stream);
                let svc = svc.clone();
                let hyper_svc = hyper_util::service::TowerToHyperService::new(svc);
                tokio::task::spawn(async move {
                    if let Err(e) = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
                        .serve_connection(io, hyper_svc)
                        .await
                    {
                        tracing::debug!("Connection closed: {e}");
                    }
                });
            }
            () = &mut shutdown => {
                tracing::info!("Shutting down gracefully...");
                return Ok(());
            }
        }
    }
}
/// Wraps the current service in another `layer`, returning a new
/// [`LayeredGrpcServer`] around the result.
pub fn layer<L>(self, layer: L) -> LayeredGrpcServer<A, L::Service>
where
    L: tower_layer::Layer<S>,
    L::Service: tower_service::Service<
            http::Request<hyper::body::Incoming>,
            Response = http::Response<BoxBody>,
            Error = Infallible,
        > + Clone
        + Send
        + 'static,
    <L::Service as tower_service::Service<http::Request<hyper::body::Incoming>>>::Future:
        Send + 'static,
{
    let wrapped = layer.layer(self.service);
    LayeredGrpcServer {
        service: wrapped,
        _api: PhantomData,
    }
}
}
/// Convenience constructor: builds a [`GrpcServer`] for API `A` from
/// borrowed service and package names, which are copied into owned strings.
pub(crate) fn make_grpc_server<A: ApiSpec + CollectRpcs>(
    router: Arc<Router>,
    service_name: &str,
    package: &str,
) -> GrpcServer<A> {
    let owned_service_name = service_name.to_owned();
    let owned_package = package.to_owned();
    GrpcServer::<A>::new(router, owned_service_name, owned_package)
}
use typeway_grpc::{EndpointToRpc, RpcMethod};
/// A `Protected` endpoint maps to the same RPC method as the endpoint it
/// wraps; the `Auth` parameter plays no part in the mapping.
impl<Auth, E> EndpointToRpc for crate::auth::Protected<Auth, E>
where
    E: EndpointToRpc,
{
    fn to_rpc() -> RpcMethod {
        <E as EndpointToRpc>::to_rpc()
    }
}
/// A `Validated` endpoint maps to the same RPC method as the endpoint it
/// wraps; the validator type `V` plays no part in the mapping.
impl<V, E> EndpointToRpc for crate::typed::Validated<V, E>
where
    V: Send + Sync + 'static,
    E: EndpointToRpc,
{
    fn to_rpc() -> RpcMethod {
        <E as EndpointToRpc>::to_rpc()
    }
}
/// `Protected<Auth, E>` is `GrpcReady` whenever the wrapped endpoint `E` is.
impl<Auth, E> typeway_grpc::GrpcReady for crate::auth::Protected<Auth, E> where
    E: typeway_grpc::GrpcReady
{
}
/// `Validated<V, E>` is `GrpcReady` whenever the wrapped endpoint `E` is.
impl<V, E> typeway_grpc::GrpcReady for crate::typed::Validated<V, E>
where
    V: Send + Sync + 'static,
    E: typeway_grpc::GrpcReady,
{
}
use crate::handler_for::BindableEndpoint;
/// A `ServerStream<E>` binds exactly like `E`: method, pattern and match
/// function are all taken from the wrapped endpoint.
impl<E> BindableEndpoint for typeway_grpc::streaming::ServerStream<E>
where
    E: BindableEndpoint,
{
    fn method() -> http::Method {
        <E as BindableEndpoint>::method()
    }

    fn pattern() -> String {
        <E as BindableEndpoint>::pattern()
    }

    fn match_fn() -> crate::router::MatchFn {
        <E as BindableEndpoint>::match_fn()
    }
}
/// A `ClientStream<E>` binds exactly like `E`: method, pattern and match
/// function are all taken from the wrapped endpoint.
impl<E> BindableEndpoint for typeway_grpc::streaming::ClientStream<E>
where
    E: BindableEndpoint,
{
    fn method() -> http::Method {
        <E as BindableEndpoint>::method()
    }

    fn pattern() -> String {
        <E as BindableEndpoint>::pattern()
    }

    fn match_fn() -> crate::router::MatchFn {
        <E as BindableEndpoint>::match_fn()
    }
}
/// A `BidirectionalStream<E>` binds exactly like `E`: method, pattern and
/// match function are all taken from the wrapped endpoint.
impl<E> BindableEndpoint for typeway_grpc::streaming::BidirectionalStream<E>
where
    E: BindableEndpoint,
{
    fn method() -> http::Method {
        <E as BindableEndpoint>::method()
    }

    fn pattern() -> String {
        <E as BindableEndpoint>::pattern()
    }

    fn match_fn() -> crate::router::MatchFn {
        <E as BindableEndpoint>::match_fn()
    }
}