use std::{convert::Infallible, future::Future, net::SocketAddr};
use tonic::server::NamedService;
use crate::core::{CoreError, Service, ServiceFuture, ShutdownToken};
/// Serves only the standard gRPC health-checking service on `addr` until the
/// `shutdown` future completes.
///
/// The health reporter created here is kept alive for the whole serve call,
/// but it is never updated, so the server reports whatever default statuses
/// `tonic_health::server::health_reporter` installs.
///
/// # Errors
///
/// Returns the `tonic::transport::Error` produced by the transport, e.g. when
/// serving on `addr` fails.
pub async fn serve_health_with_shutdown<F>(
    addr: SocketAddr,
    shutdown: F,
) -> Result<(), tonic::transport::Error>
where
    F: Future<Output = ()> + Send + 'static,
{
    let (_health_reporter, health_server) = tonic_health::server::health_reporter();
    let mut server = tonic::transport::Server::builder();
    let router = server.add_service(health_server);
    router.serve_with_shutdown(addr, shutdown).await
}
/// Hosts a single tonic gRPC service under a name on a fixed socket address.
///
/// The wrapped service sits in a mutex-guarded `Option` so it can be moved out
/// (leaving `None`) when the service is started.
pub struct TonicService<S> {
    name: String,
    addr: SocketAddr,
    service: std::sync::Mutex<Option<S>>,
}

impl<S> TonicService<S> {
    /// Wraps `service` under `name`; it will be served on `addr` once started.
    pub fn new(name: impl Into<String>, addr: SocketAddr, service: S) -> Self {
        let name: String = name.into();
        let slot = std::sync::Mutex::new(Some(service));
        Self {
            name,
            addr,
            service: slot,
        }
    }

    /// Returns the socket address this service is configured to serve on.
    pub fn addr(&self) -> SocketAddr {
        self.addr
    }
}
impl<S> Service for TonicService<S>
where
    S: tower::Service<
            http::Request<tonic::body::Body>,
            Response = http::Response<tonic::body::Body>,
            Error = Infallible,
        > + NamedService
        + Clone
        + Send
        + Sync
        + 'static,
    S::Future: Send + 'static,
{
    fn name(&self) -> &str {
        &self.name
    }

    /// Serves the wrapped gRPC service on `self.addr` until `shutdown` is cancelled.
    ///
    /// This is one-shot: the service is moved out of its slot on the first call,
    /// even if that call later fails, so any further call returns an
    /// "already started" error.
    ///
    /// # Errors
    ///
    /// Returns [`CoreError::Service`] if the service was already started, or if
    /// the tonic transport fails (for example when binding `self.addr` fails).
    /// The transport error's full `source()` chain is included in the message.
    fn start(&self, shutdown: ShutdownToken) -> ServiceFuture<'_> {
        // tonic's `transport::Error` `Display` only prints a generic kind
        // (e.g. "transport error"); the underlying cause is reachable only via
        // `source()`, so render the whole chain joined by ": ".
        fn error_with_sources(error: &dyn std::error::Error) -> String {
            let mut message = error.to_string();
            let mut source = error.source();
            while let Some(cause) = source {
                message.push_str(": ");
                message.push_str(&cause.to_string());
                source = cause.source();
            }
            message
        }

        Box::pin(async move {
            // The guard is a temporary dropped at the end of this statement, so it
            // is never held across the `.await` below. The lock only protects an
            // `Option::take`, which cannot leave the slot half-updated, so a
            // poisoned lock is recovered instead of panicking.
            let service = self
                .service
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .take()
                .ok_or_else(|| {
                    CoreError::Service(format!("service {} already started", self.name))
                })?;
            tonic::transport::Server::builder()
                .add_service(service)
                .serve_with_shutdown(self.addr, async move {
                    shutdown.cancelled().await;
                })
                .await
                .map_err(|error| {
                    CoreError::Service(format!(
                        "RPC service {} failed: {}",
                        self.name,
                        error_with_sources(&error)
                    ))
                })
        })
    }
}
/// Serves the standard gRPC health-checking service, under a name, on a fixed
/// socket address.
pub struct TonicHealthService {
    name: String,
    addr: SocketAddr,
}

impl TonicHealthService {
    /// Creates a health service registered under `name` that will serve on `addr`.
    pub fn new(name: impl Into<String>, addr: SocketAddr) -> Self {
        let name: String = name.into();
        Self { name, addr }
    }
}
impl Service for TonicHealthService {
    fn name(&self) -> &str {
        &self.name
    }

    /// Serves the gRPC health-checking service on `self.addr` until `shutdown`
    /// is cancelled.
    ///
    /// # Errors
    ///
    /// Returns [`CoreError::Service`] if the tonic transport fails (for example
    /// when binding `self.addr` fails). The transport error's full `source()`
    /// chain is included in the message.
    fn start(&self, shutdown: ShutdownToken) -> ServiceFuture<'_> {
        // tonic's `transport::Error` `Display` only prints a generic kind
        // (e.g. "transport error"); the underlying cause is reachable only via
        // `source()`, so render the whole chain joined by ": ".
        fn error_with_sources(error: &dyn std::error::Error) -> String {
            let mut message = error.to_string();
            let mut source = error.source();
            while let Some(cause) = source {
                message.push_str(": ");
                message.push_str(&cause.to_string());
                source = cause.source();
            }
            message
        }

        Box::pin(async move {
            serve_health_with_shutdown(self.addr, async move {
                shutdown.cancelled().await;
            })
            .await
            .map_err(|error| {
                CoreError::Service(format!(
                    "RPC service {} failed: {}",
                    self.name,
                    error_with_sources(&error)
                ))
            })
        })
    }
}