use crate::pools::{Connection, Pool};
use futures_util::{pin_mut, stream, StreamExt};
use std::{hash::Hash, sync::Arc, time::Duration};
use tokio::sync::mpsc::error::SendError;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tonic::{Request, Response, Status};
use tonic_health::proto::{
health_check_response::ServingStatus,
health_server::{Health as GrpcService, HealthServer},
};
pub use tonic_health::proto::{HealthCheckRequest, HealthCheckResponse};
/// gRPC health-check service state.
///
/// Holds the shared connection pool that the health checks probe. Cloning is
/// implemented by hand for this type, so `P` itself is not required to be
/// `Clone`.
pub struct Health<P>
where
    P: Pool,
    P::Key: Hash + Eq + Default + Clone,
{
    /// Shared handle to the connection pool probed by the `postgres` check.
    pool: Arc<P>,
    /// Transaction pool probed by the `transaction` check; only present when
    /// the `transaction` feature is enabled.
    #[cfg(feature = "transaction")]
    transactions: crate::pools::transaction::Pool<P>,
}
/// Manual `Clone` so that only the handles are duplicated and `P` does not
/// need to implement `Clone` itself.
impl<P> Clone for Health<P>
where
    P: Pool,
    P::Key: Hash + Eq + Default + Clone,
{
    /// Returns a new handle that shares the same underlying pool.
    fn clone(&self) -> Self {
        // Bump the reference count instead of copying the pool.
        let pool = Arc::clone(&self.pool);
        #[cfg(feature = "transaction")]
        let transactions = self.transactions.clone();

        Self {
            pool,
            #[cfg(feature = "transaction")]
            transactions,
        }
    }
}
impl<P> Health<P>
where
    P: Pool + 'static,
    P::Key: Hash + Eq + Default + Clone + Send + Sync,
    <P::Connection as Connection>::Error: Send + Sync,
{
    /// Builds a health service around a shared connection pool.
    ///
    /// When the `transaction` feature is enabled, a transaction pool is also
    /// constructed from a clone of the same `Arc`.
    #[tracing::instrument(skip(pool))]
    fn new(pool: Arc<P>) -> Self {
        Self {
            #[cfg(feature = "transaction")]
            transactions: crate::pools::transaction::Pool::new(Arc::clone(&pool)),
            pool,
        }
    }

    /// Probes the connection pool by fetching a connection for `key` and
    /// running `SELECT 1` on it.
    ///
    /// # Errors
    ///
    /// Any failure (acquiring the connection or running the query) is returned
    /// as `Status::unavailable` carrying the underlying error's message.
    async fn check_postgres_service(&self, key: P::Key) -> Result<(), Status> {
        let connection = self
            .pool
            .get_connection(key)
            .await
            .map_err(|error| Status::unavailable(error.to_string()))?;

        connection
            .query("SELECT 1", &[])
            .await
            .map_err(|error| Status::unavailable(error.to_string()))?;

        Ok(())
    }

    /// Probes the transaction pool: begins a transaction for `key`, runs
    /// `SELECT 1` inside it, then rolls the transaction back.
    ///
    /// Once `begin` has succeeded the rollback is always attempted, even when
    /// fetching the transaction or running the query fails, so a failed probe
    /// does not leave its transaction open.
    ///
    /// # Errors
    ///
    /// Every failure is returned as `Status::unavailable` carrying the
    /// underlying error's message. If both the probe and the rollback fail,
    /// the probe's error is the one reported.
    #[cfg(feature = "transaction")]
    async fn check_transaction_service(&self, key: P::Key) -> Result<(), Status> {
        let id = self
            .transactions
            .begin(key.clone())
            .await
            .map_err(|error| Status::unavailable(error.to_string()))?;

        // Run the probe in its own async block so that an early `?` exit only
        // leaves the block, not the whole function; the rollback below must
        // still run. The transaction handle is also dropped at the end of the
        // block, before the rollback is issued.
        let probe = async {
            let transaction_key = crate::pools::transaction::Key::new(key.clone(), id);
            let transaction = self
                .transactions
                .get_connection(transaction_key)
                .await
                .map_err(|error| Status::unavailable(error.to_string()))?;

            transaction
                .query("SELECT 1", &[])
                .await
                .map_err(|error| Status::unavailable(error.to_string()))?;

            Ok::<(), Status>(())
        }
        .await;

        let rollback = self
            .transactions
            .rollback(id, key)
            .await
            .map_err(|error| Status::unavailable(error.to_string()));

        // Report the probe failure first; otherwise surface the rollback result.
        probe.and(rollback)
    }
}
#[tonic::async_trait]
impl<P> GrpcService for Health<P>
where
P: Pool + 'static,
P::Key: Hash + Eq + Default + Clone + Send + Sync,
<P::Connection as Connection>::Error: Send + Sync,
{
#[tracing::instrument(
skip(self, request),
fields(service = request.get_ref().service),
err
)]
async fn check(
&self,
request: Request<HealthCheckRequest>,
) -> Result<Response<HealthCheckResponse>, Status> {
tracing::debug!("Performing health check");
let key = P::Key::default();
match request.into_inner().service.to_lowercase().as_str() {
"" => {
#[cfg(feature = "transaction")]
self.check_transaction_service(key.clone()).await?;
self.check_postgres_service(key).await?;
}
"postgres" => self.check_postgres_service(key).await?,
#[cfg(feature = "transaction")]
"transaction" => self.check_transaction_service(key).await?,
service => {
return Err(Status::not_found(format!(
"Service '{}' does not exist",
service
)))
}
};
Ok(Response::new(HealthCheckResponse {
status: ServingStatus::Serving.into(),
}))
}
type WatchStream = UnboundedReceiverStream<Result<HealthCheckResponse, Status>>;
#[tracing::instrument(
skip(self, request),
fields(service = request.get_ref().service),
err
)]
async fn watch(
&self,
request: Request<HealthCheckRequest>,
) -> Result<Response<Self::WatchStream>, Status> {
tracing::debug!("Streaming health checks");
let health_service = self.clone();
let request = request.into_inner();
let count = 1;
let watch_stream = stream::unfold(count, move |count| {
let health_service = health_service.clone();
let request = Request::new(request.clone());
async move {
let response = health_service
.check(request)
.await
.map(|response| response.into_inner());
Some((response, count + 1))
}
});
let (transmitter, receiver) = tokio::sync::mpsc::unbounded_channel();
tokio::spawn(async move {
pin_mut!(watch_stream);
while let Some(response) = watch_stream.next().await {
transmitter.send(response)?;
tokio::time::sleep(Duration::from_secs(1)).await;
}
Ok::<_, SendError<_>>(())
});
Ok(Response::new(UnboundedReceiverStream::new(receiver)))
}
}
/// Creates the gRPC health server for the given connection pool.
///
/// The pool is wrapped in a [`Health`] service, which is then handed to
/// `HealthServer::new`.
pub fn new<P>(pool: Arc<P>) -> HealthServer<Health<P>>
where
    P: Pool + 'static,
    P::Key: Hash + Eq + Default + Clone,
{
    let service = Health::new(pool);
    HealthServer::new(service)
}