use std::{convert::Infallible, future::Future, net::SocketAddr};
use tonic::server::NamedService;
use crate::core::{CoreError, Service, ServiceFuture, ShutdownToken};
/// Serves the standard gRPC health-checking service from `tonic_health` on `addr` until
/// `shutdown` resolves.
///
/// The status handle created alongside the health server is held until the server future
/// completes but is never exposed, so callers cannot change the reported status through it.
///
/// # Errors
///
/// Propagates any [`tonic::transport::Error`] returned while serving.
pub async fn serve_health_with_shutdown<F>(
    addr: SocketAddr,
    shutdown: F,
) -> Result<(), tonic::transport::Error>
where
    F: Future<Output = ()> + Send + 'static,
{
    // Bound to an underscore-prefixed name (not `_`) so the handle is not dropped
    // immediately but lives until this function's future finishes.
    let (_status_handle, health_server) = tonic_health::server::health_reporter();
    let router = tonic::transport::Server::builder().add_service(health_server);
    router.serve_with_shutdown(addr, shutdown).await
}
/// Wraps a tonic gRPC service `S` together with the name and address it is served under.
pub struct TonicService<S> {
    /// Name returned by `Service::name` and included in error messages.
    name: String,
    /// Address the gRPC server is served on.
    addr: SocketAddr,
    /// The wrapped service, held in an `Option` so `start` can move it out exactly once;
    /// the mutex makes that move possible through `&self`.
    service: std::sync::Mutex<Option<S>>,
}

impl<S> TonicService<S> {
    /// Creates a wrapper named `name` that will serve `service` on `addr` once started.
    pub fn new(name: impl Into<String>, addr: SocketAddr, service: S) -> Self {
        let name = name.into();
        let slot = std::sync::Mutex::new(Some(service));
        Self {
            name,
            addr,
            service: slot,
        }
    }

    /// Returns the address this service is configured to serve on.
    pub fn addr(&self) -> SocketAddr {
        self.addr
    }
}
impl<S> Service for TonicService<S>
where
    S: tower::Service<
            http::Request<tonic::body::Body>,
            Response = http::Response<tonic::body::Body>,
            Error = Infallible,
        > + NamedService
        + Clone
        + Send
        + Sync
        + 'static,
    S::Future: Send + 'static,
{
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Moves the wrapped service out of its slot and serves it on `addr` until `shutdown`
    /// is cancelled.
    ///
    /// The slot is emptied before serving begins, so any later call, including one made
    /// after a failed first attempt, returns a `CoreError::Service` ("already started")
    /// instead of serving again. Transport errors are also wrapped in `CoreError::Service`.
    fn start(&self, shutdown: ShutdownToken) -> ServiceFuture<'_> {
        Box::pin(async move {
            // The guard is a temporary of this statement, so the std mutex is released
            // before the first `.await` below.
            let slot = self.service.lock().expect("tonic service mutex").take();
            let Some(service) = slot else {
                return Err(CoreError::Service(format!(
                    "service {} already started",
                    self.name
                )));
            };

            let until_cancelled = async move {
                shutdown.cancelled().await;
            };
            let served = tonic::transport::Server::builder()
                .add_service(service)
                .serve_with_shutdown(self.addr, until_cancelled)
                .await;
            served.map_err(|error| {
                CoreError::Service(format!("RPC service {} failed: {error}", self.name))
            })
        })
    }
}
/// A `Service` that serves only the gRPC health-checking service via
/// `serve_health_with_shutdown`.
pub struct TonicHealthService {
    /// Name returned by `Service::name` and included in error messages.
    name: String,
    /// Address the health server is served on.
    addr: SocketAddr,
}

impl TonicHealthService {
    /// Creates a health service named `name` that will serve on `addr` once started.
    pub fn new(name: impl Into<String>, addr: SocketAddr) -> Self {
        Self {
            name: name.into(),
            addr,
        }
    }

    /// Returns the address this health service is configured to serve on.
    ///
    /// Mirrors `TonicService::addr` so both service wrappers expose their address the
    /// same way.
    pub fn addr(&self) -> SocketAddr {
        self.addr
    }
}
impl Service for TonicHealthService {
fn name(&self) -> &str {
&self.name
}
fn start(&self, shutdown: ShutdownToken) -> ServiceFuture<'_> {
Box::pin(async move {
serve_health_with_shutdown(self.addr, async move {
shutdown.cancelled().await;
})
.await
.map_err(|error| {
CoreError::Service(format!("RPC service {} failed: {error}", self.name))
})
})
}
}
/// Builds a [`crate::rpc::RpcUnaryResilienceLayer`] from an RPC server configuration,
/// optionally wired to a metrics registry when the `observability` feature is enabled.
#[cfg(feature = "resil")]
#[derive(Debug, Clone)]
pub struct RpcServerLayerStack {
    /// Server configuration; its `name` and `resilience` fields seed the layer.
    config: crate::rpc::RpcServerConfig,
    /// Metrics registry attached to the resilience layer by `into_layer`, if set.
    #[cfg(feature = "observability")]
    metrics: Option<crate::observability::MetricsRegistry>,
}

#[cfg(feature = "resil")]
impl RpcServerLayerStack {
    /// Creates a layer stack from an explicit server configuration with no metrics attached.
    pub fn new(config: crate::rpc::RpcServerConfig) -> Self {
        Self {
            config,
            #[cfg(feature = "observability")]
            metrics: None,
        }
    }

    /// Creates a layer stack from `RpcServerConfig::production_defaults(name, addr)`.
    pub fn production_defaults(name: impl Into<String>, addr: SocketAddr) -> Self {
        Self::new(crate::rpc::RpcServerConfig::production_defaults(name, addr))
    }

    /// Deprecated alias of [`Self::production_defaults`].
    // NOTE(review): this body references nothing deprecated, so the `allow` below looks
    // removable; confirm before dropping it.
    #[allow(deprecated)]
    #[deprecated(note = "use production_defaults instead")]
    pub fn go_zero_defaults(name: impl Into<String>, addr: SocketAddr) -> Self {
        Self::production_defaults(name, addr)
    }

    /// Attaches a metrics registry that `into_layer` will hand to the resilience layer.
    #[cfg(feature = "observability")]
    #[must_use]
    pub fn with_metrics(mut self, metrics: crate::observability::MetricsRegistry) -> Self {
        self.metrics = Some(metrics);
        self
    }

    /// Consumes the stack and builds the unary resilience layer.
    ///
    /// The inner `RpcResilienceLayer` is created from `config.name` and
    /// `config.resilience`. With the `observability` feature enabled and a registry
    /// supplied via `with_metrics`, the registry is attached as well.
    pub fn into_layer(self) -> crate::rpc::RpcUnaryResilienceLayer {
        let resilience = crate::rpc::RpcResilienceLayer::new(
            self.config.name.clone(),
            self.config.resilience.clone(),
        );
        // Only the metrics step differs between feature sets, so it is the only cfg-gated
        // statement; shadowing (instead of `let mut`) avoids an unused-mut warning when the
        // feature is off, and the shared construction above cannot drift between builds.
        #[cfg(feature = "observability")]
        let resilience = match self.metrics {
            Some(metrics) => resilience.with_metrics(metrics),
            None => resilience,
        };
        crate::rpc::RpcUnaryResilienceLayer::new(resilience)
    }
}