use std::convert::Infallible;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use bytes::Bytes;
use http_body_util::Full;
use hyper::body::Incoming;
use hyper::service::Service;
use hyper::{Request, Response, StatusCode};
use metrics_exporter_prometheus::PrometheusHandle;
/// Shareable readiness probe used by the `/ready` endpoint.
///
/// Each invocation produces a fresh boxed future. `Ok(())` means "ready"; an
/// `Err(reason)` carries a human-readable explanation of why the service is
/// not ready. The closure and its futures must be usable across threads.
pub type ReadyCheck = Arc<
    dyn Fn() -> Pin<Box<dyn Future<Output = Result<(), String>> + Send>> + Sync + Send,
>;
/// Service wrapper that answers a few operational endpoints itself and hands
/// every other request to `inner`.
///
/// The intercepted routes are `GET`/`HEAD` on `/health`, `/ready` and
/// `/metrics`. They are matched on the URI path before `inner` sees the
/// request.
///
/// NOTE(review): if `inner` is an S3 endpoint, requests whose path is exactly
/// one of these (e.g. a path-style listing of a bucket named `health`) are
/// shadowed by this router. Confirm this is acceptable for deployments.
#[derive(Clone)]
pub struct HealthRouter<S> {
    /// Service that receives every request not handled by this router.
    pub inner: S,
    /// Optional async probe awaited on each `/ready` request; when `None`,
    /// `/ready` always answers `200 OK`.
    pub ready_check: Option<ReadyCheck>,
    /// Prometheus handle rendered on `/metrics`; when `None`, `/metrics`
    /// answers `503 Service Unavailable`.
    pub metrics_handle: Option<PrometheusHandle>,
}
impl<S> HealthRouter<S> {
    /// Wraps `inner`, optionally installing a readiness probe.
    ///
    /// No metrics handle is configured; attach one with
    /// [`HealthRouter::with_metrics`].
    pub fn new(inner: S, ready_check: Option<ReadyCheck>) -> Self {
        let metrics_handle = None;
        Self {
            metrics_handle,
            ready_check,
            inner,
        }
    }

    /// Returns this router with `handle` installed as the source for
    /// `/metrics`, replacing any previously configured handle.
    #[must_use]
    pub fn with_metrics(self, handle: PrometheusHandle) -> Self {
        Self {
            metrics_handle: Some(handle),
            ..self
        }
    }
}
/// Body type used for every response produced by this module.
type RespBody = s3s::Body;

/// Builds a `text/plain; charset=utf-8` response whose body is a `'static`
/// string, with an explicit `content-length` header.
///
/// # Panics
/// If the response builder rejects its inputs. With these fixed header names
/// and values that is not expected to happen.
fn make_text_response(status: StatusCode, body: &'static str) -> Response<RespBody> {
    let payload = Bytes::from_static(body.as_bytes());
    let content_length = payload.len().to_string();
    // `Full`'s error type is `Infallible`; the empty match converts it to
    // whatever error type `s3s::Body::http_body` expects.
    let stream = s3s::Body::http_body(Full::new(payload).map_err(|never| match never {}));

    Response::builder()
        .status(status)
        .header("content-type", "text/plain; charset=utf-8")
        .header("content-length", content_length)
        .body(stream)
        .expect("static response")
}
fn make_owned_text_response(
status: StatusCode,
content_type: &'static str,
body: String,
) -> Response<RespBody> {
let bytes = Bytes::from(body.into_bytes());
Response::builder()
.status(status)
.header("content-type", content_type)
.header("content-length", bytes.len().to_string())
.body(s3s::Body::http_body(
Full::new(bytes).map_err(|never| match never {}),
))
.expect("owned response")
}
impl<S> Service<Request<Incoming>> for HealthRouter<S>
where
S: Service<Request<Incoming>, Response = Response<s3s::Body>, Error = s3s::HttpError>
+ Clone
+ Send
+ 'static,
S::Future: Send + 'static,
{
type Response = Response<RespBody>;
type Error = s3s::HttpError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn call(&self, req: Request<Incoming>) -> Self::Future {
let path = req.uri().path();
match (req.method(), path) {
(&hyper::Method::GET, "/health") | (&hyper::Method::HEAD, "/health") => {
Box::pin(async { Ok(make_text_response(StatusCode::OK, "ok\n")) })
}
(&hyper::Method::GET, "/metrics") | (&hyper::Method::HEAD, "/metrics") => {
let handle = self.metrics_handle.clone();
Box::pin(async move {
match handle {
Some(h) => {
let body = h.render();
Ok(make_owned_text_response(
StatusCode::OK,
"text/plain; version=0.0.4; charset=utf-8",
body,
))
}
None => Ok(make_text_response(
StatusCode::SERVICE_UNAVAILABLE,
"metrics not configured\n",
)),
}
})
}
(&hyper::Method::GET, "/ready") | (&hyper::Method::HEAD, "/ready") => {
let check = self.ready_check.clone();
Box::pin(async move {
match check {
Some(f) => match f().await {
Ok(()) => Ok(make_text_response(StatusCode::OK, "ready\n")),
Err(reason) => {
tracing::warn!(%reason, "readiness check failed");
Ok(make_text_response(
StatusCode::SERVICE_UNAVAILABLE,
"not ready\n",
))
}
},
None => Ok(make_text_response(StatusCode::OK, "ready (no check)\n")),
}
})
}
_ => {
let inner = self.inner.clone();
Box::pin(async move { inner.call(req).await })
}
}
}
}
/// Local extension trait giving [`Full`] an error-mapping method.
///
/// The implementation for `Full<B>` delegates to
/// `http_body_util::BodyExt::map_err`. The type parameter `B` is not
/// referenced by the trait's method.
trait FullExt<B> {
    /// Wraps the body so that its (`Infallible`) error is converted with `f`.
    fn map_err<E, F>(self, f: F) -> http_body_util::combinators::MapErr<Self, F>
    where
        Self: Sized,
        F: FnMut(Infallible) -> E;
}
/// Implements [`FullExt`] for any `Full` body whose buffer type is a
/// `bytes::Buf`, by delegating to `http_body_util::BodyExt::map_err`.
impl<B: bytes::Buf> FullExt<B> for Full<B> {
    fn map_err<E, F>(self, f: F) -> http_body_util::combinators::MapErr<Self, F>
    where
        Self: Sized,
        F: FnMut(Infallible) -> E,
    {
        // Fully qualified so the call cannot resolve back to `FullExt::map_err`.
        <Self as http_body_util::BodyExt>::map_err(self, f)
    }
}