use std::{collections::HashMap, sync::Mutex};
use futures_util::StreamExt;
use poem::{endpoint::BoxEndpoint, IntoEndpoint};
use tokio::sync::watch::{Receiver, Sender};
use crate::{Code, Request, Response, Service, Status, Streaming};
#[allow(private_in_public, unreachable_pub)]
#[allow(clippy::derive_partial_eq_without_eq)]
mod proto {
include!(concat!(env!("OUT_DIR"), "/grpc.health.v1.rs"));
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServingStatus {
Serving,
NotServing,
}
impl ServingStatus {
fn to_proto(self) -> proto::health_check_response::ServingStatus {
use proto::health_check_response::ServingStatus::*;
match self {
ServingStatus::Serving => Serving,
ServingStatus::NotServing => NotServing,
}
}
}
type ServiceStatusMap = HashMap<String, ServingStatus>;
struct HealthService {
receiver: Receiver<ServiceStatusMap>,
}
pub struct HealthReporter {
state: Mutex<(ServiceStatusMap, Sender<ServiceStatusMap>)>,
}
impl HealthReporter {
fn set_status<S: Service>(&self, status: ServingStatus) {
let mut state = self.state.lock().unwrap();
state.0.insert(S::NAME.to_string(), status);
let _ = state.1.send(state.0.clone());
}
pub fn set_serving<S: Service>(&self) {
self.set_status::<S>(ServingStatus::Serving);
}
pub fn set_not_serving<S: Service>(&self) {
self.set_status::<S>(ServingStatus::NotServing);
}
pub fn clear_service_status<S: Service>(&self) {
let mut state = self.state.lock().unwrap();
state.0.remove(S::NAME);
let _ = state.1.send(state.0.clone());
}
}
#[poem::async_trait]
impl proto::Health for HealthService {
async fn check(
&self,
request: Request<proto::HealthCheckRequest>,
) -> Result<Response<proto::HealthCheckResponse>, Status> {
let service_status = self.receiver.borrow();
match service_status.get(&request.service) {
Some(status) => Ok(Response::new(proto::HealthCheckResponse {
status: status.to_proto().into(),
})),
None => Err(Status::new(Code::NotFound)
.with_message(format!("service `{}` not found", request.service))),
}
}
async fn watch(
&self,
request: Request<proto::HealthCheckRequest>,
) -> Result<Response<Streaming<proto::HealthCheckResponse>>, Status> {
let mut stream = tokio_stream::wrappers::WatchStream::new(self.receiver.clone());
let service_name = request.into_inner().service;
Ok(Response::new(Streaming::new(async_stream::try_stream! {
while let Some(service_status) = stream.next().await {
let res = service_status.get(&service_name);
let status = res.ok_or_else(|| Status::new(Code::NotFound).with_message(format!("service `{}` not found", service_name)))?
.to_proto()
.into();
yield proto::HealthCheckResponse { status };
}
})))
}
}
pub fn health_service() -> (
impl IntoEndpoint<Endpoint = BoxEndpoint<'static, poem::Response>> + Service,
HealthReporter,
) {
let (sender, receiver) = tokio::sync::watch::channel(Default::default());
(
proto::HealthServer::new(HealthService { receiver }),
HealthReporter {
state: Mutex::new((Default::default(), sender)),
},
)
}
#[cfg(test)]
mod tests {
use futures_util::StreamExt;
use super::*;
use crate::health::proto::Health;
fn create_service() -> (HealthService, HealthReporter) {
let (sender, receiver) = tokio::sync::watch::channel(Default::default());
(
HealthService { receiver },
HealthReporter {
state: Mutex::new((Default::default(), sender)),
},
)
}
#[tokio::test]
async fn check() {
let (service, reporter) = create_service();
let res = service
.check(Request::new(proto::HealthCheckRequest {
service: <proto::HealthServer<HealthService>>::NAME.to_string(),
}))
.await;
assert_eq!(res.unwrap_err().code(), Code::NotFound);
reporter.set_serving::<proto::HealthServer<HealthService>>();
let res = service
.check(Request::new(proto::HealthCheckRequest {
service: <proto::HealthServer<HealthService>>::NAME.to_string(),
}))
.await;
assert_eq!(
res.unwrap().into_inner(),
proto::HealthCheckResponse {
status: proto::health_check_response::ServingStatus::Serving.into()
}
);
reporter.set_not_serving::<proto::HealthServer<HealthService>>();
let res = service
.check(Request::new(proto::HealthCheckRequest {
service: <proto::HealthServer<HealthService>>::NAME.to_string(),
}))
.await;
assert_eq!(
res.unwrap().into_inner(),
proto::HealthCheckResponse {
status: proto::health_check_response::ServingStatus::NotServing.into()
}
);
reporter.clear_service_status::<proto::HealthServer<HealthService>>();
let res = service
.check(Request::new(proto::HealthCheckRequest {
service: <proto::HealthServer<HealthService>>::NAME.to_string(),
}))
.await;
assert_eq!(res.unwrap_err().code(), Code::NotFound);
}
#[tokio::test]
async fn watch() {
let (service, reporter) = create_service();
let mut stream = service
.watch(Request::new(proto::HealthCheckRequest {
service: <proto::HealthServer<HealthService>>::NAME.to_string(),
}))
.await
.unwrap();
assert_eq!(
stream.next().await.unwrap().unwrap_err().code(),
Code::NotFound
);
reporter.set_serving::<proto::HealthServer<HealthService>>();
let mut stream = service
.watch(Request::new(proto::HealthCheckRequest {
service: <proto::HealthServer<HealthService>>::NAME.to_string(),
}))
.await
.unwrap();
assert_eq!(
stream.next().await.unwrap().unwrap(),
proto::HealthCheckResponse {
status: proto::health_check_response::ServingStatus::Serving.into()
}
);
reporter.set_not_serving::<proto::HealthServer<HealthService>>();
assert_eq!(
stream.next().await.unwrap().unwrap(),
proto::HealthCheckResponse {
status: proto::health_check_response::ServingStatus::NotServing.into()
}
);
reporter.clear_service_status::<proto::HealthServer<HealthService>>();
assert_eq!(
stream.next().await.unwrap().unwrap_err().code(),
Code::NotFound
);
}
}