prometheus-hyper 0.2.1

Small Tokio/Hyper server that serves Prometheus metrics
Documentation
//! # Example
//! ```
//! use prometheus::{IntCounter, Opts, Registry};
//! use prometheus_hyper::{RegistryFn, Server};
//! use std::{error::Error, net::SocketAddr, sync::Arc, time::Duration};
//! use tokio::sync::Notify;
//!
//! pub struct CustomMetrics {
//!     pub foo: IntCounter,
//! }
//!
//! impl CustomMetrics {
//!     pub fn new() -> Result<(Self, RegistryFn), Box<dyn Error>> {
//!         let foo = IntCounter::with_opts(Opts::new("foo", "description"))?;
//!         let foo_clone = foo.clone();
//!         let f = |r: &Registry| r.register(Box::new(foo_clone));
//!         Ok((Self { foo }, Box::new(f)))
//!     }
//! }
//!
//! #[tokio::main(flavor = "current_thread")]
//! async fn main() -> std::result::Result<(), std::io::Error> {
//!     let registry = Arc::new(Registry::new());
//!     let shutdown = Arc::new(Notify::new());
//!     let shutdown_clone = Arc::clone(&shutdown);
//!     let (metrics, f) = CustomMetrics::new().expect("failed prometheus");
//!     f(&registry).expect("problem registering");
//!
//!     // Startup Server
//!     let jh = tokio::spawn(async move {
//!         Server::run(
//!             Arc::clone(&registry),
//!             SocketAddr::from(([0; 4], 8080)),
//!             shutdown_clone.notified(),
//!         )
//!         .await
//!     });
//!
//!     // Change Metrics
//!     metrics.foo.inc();
//!
//!     // Shutdown
//!     tokio::time::sleep(Duration::from_secs(5)).await;
//!     shutdown.notify_one();
//!     jh.await.unwrap()
//! }
//! ```
use bytes::Bytes;
use http_body_util::Full;
use hyper::{Request, Response, StatusCode, header, service::Service};
use hyper_util::rt::TokioIo;
use prometheus::{Encoder, Registry, TextEncoder};
use std::{convert::Infallible, future::Future, net::SocketAddr, ops::Deref, pin::Pin};
use tokio::net::TcpListener;
use tracing::{info, trace};

#[cfg(feature = "internal_metrics")]
use prometheus::{
    Histogram, IntCounter, IntGauge, register_histogram_with_registry, register_int_counter_with_registry,
    register_int_gauge_with_registry,
};

#[cfg(feature = "internal_metrics")]
use std::convert::TryInto;

/// Boxed, one-shot callback that registers a set of metrics into a
/// [`Registry`].
///
/// It receives the registry to register into and reports whether the
/// registration succeeded.
pub type RegistryFn = Box<dyn FnOnce(&Registry) -> prometheus::Result<()>>;

/// Metrics Server based on [`tokio`] and [`hyper`]
///
/// The type carries no state; it only namespaces [`Server::run`], which binds
/// the TCP listener and serves the registry's metrics until shutdown.
///
/// [`tokio`]: tokio
/// [`hyper`]: hyper
pub struct Server {}

impl Server {
    /// Create and run the metrics Server.
    ///
    /// The server answers `/metrics` with the text exposition of `registry`
    /// and every other path with `404 Not Found`. It keeps accepting
    /// connections until `shutdown` completes; connections that were already
    /// accepted keep being served by their own spawned tasks.
    ///
    /// # Arguments
    /// * `registry` - provide the [`Registry`] you are also registering your
    ///   metric types to.
    /// * `addr` - socket address (`ip:port`) to listen on for TCP connections.
    /// * `shutdown` - a [`Future`]; once it completes the server stops
    ///   accepting new connections and returns. You can use a [`signal`] or
    ///   [`Notify`] for a clean shutdown, or [`pending`] to never shut down.
    ///
    /// # Errors
    /// * [`std::io::Error`] is returned when binding to `addr` fails. All other
    ///   errors (accepting or serving individual connections) are logged and
    ///   ignored.
    ///
    /// # Panics
    /// With the `internal_metrics` feature enabled, panics if the server's own
    /// metrics cannot be registered in `registry` (for example because another
    /// server is currently running against the same registry).
    ///
    /// # Examples
    /// ```
    /// use prometheus::Registry;
    /// use prometheus_hyper::Server;
    /// use std::{net::SocketAddr, sync::Arc};
    /// # #[tokio::main(flavor = "current_thread")]
    /// # async fn main() {
    ///
    /// let registry = Arc::new(Registry::new());
    ///
    /// // Start Server endlessly
    /// tokio::spawn(async move {
    ///     Server::run(
    ///         Arc::clone(&registry),
    ///         SocketAddr::from(([0; 4], 8080)),
    ///         std::future::pending(),
    ///     )
    ///     .await
    /// });
    /// # }
    /// ```
    /// [`Registry`]: prometheus::Registry
    /// [`Future`]: std::future::Future
    /// [`pending`]: std::future::pending
    /// [`signal`]: tokio::signal
    /// [`Notify`]: tokio::sync::Notify
    pub async fn run<S, F, R>(registry: R, addr: S, shutdown: F) -> Result<(), std::io::Error>
    where
        S: Into<SocketAddr>,
        F: Future<Output = ()>,
        R: Deref<Target = Registry> + Clone + Send + 'static,
    {
        let addr = addr.into();

        // Bind before touching the registry: a bind failure returns early via
        // `?`, and at that point nothing may be registered yet. Otherwise the
        // internal metrics would never be unregistered and every later `run`
        // against the same registry would panic on re-registration.
        let listener = TcpListener::bind(addr).await?;

        #[cfg(feature = "internal_metrics")]
        let durations = register_histogram_with_registry!(
            "prometheus_exporter_request_duration_seconds",
            "HTTP request durations in seconds",
            registry
        )
        .unwrap();
        #[cfg(feature = "internal_metrics")]
        let requests = register_int_counter_with_registry!(
            "prometheus_exporter_requests_total",
            "HTTP requests received in metrics endpoint",
            registry
        )
        .unwrap();
        #[cfg(feature = "internal_metrics")]
        let sizes = register_int_gauge_with_registry!(
            "prometheus_exporter_response_size_bytes",
            "HTTP response sizes in bytes",
            registry
        )
        .unwrap();

        info!("starting hyper server to serve metrics");

        // Template service; each accepted connection gets its own clone.
        let service = MetricsService {
            registry: registry.clone(),
            #[cfg(feature = "internal_metrics")]
            durations: durations.clone(),
            #[cfg(feature = "internal_metrics")]
            requests: requests.clone(),
            #[cfg(feature = "internal_metrics")]
            sizes: sizes.clone(),
        };

        let mut shutdown = core::pin::pin!(shutdown);
        // Race the shutdown signal against the next accept; the loop ends as
        // soon as `shutdown` completes.
        while let Some(conn) = tokio::select! {
            _ = shutdown.as_mut() => None,
            conn = listener.accept() => Some(conn),
        } {
            match conn {
                Ok((tcp, _)) => {
                    let io = TokioIo::new(tcp);
                    let service_clone = service.clone();

                    tokio::task::spawn(async move {
                        use hyper::server::conn::http1;
                        let conn = http1::Builder::new().serve_connection(io, service_clone);

                        if let Err(e) = conn.await {
                            tracing::error!(?e, "error serving connection")
                        }
                    });
                },
                Err(e) => tracing::error!(?e, "error accepting new connection"),
            }
        }

        // Remove our own metrics again so the registry can be reused by a
        // later `run`.
        #[cfg(feature = "internal_metrics")]
        {
            if let Err(e) = registry.unregister(Box::new(durations)) {
                tracing::error!(?e, "could not unregister 'durations'");
            };
            if let Err(e) = registry.unregister(Box::new(requests)) {
                tracing::error!(?e, "could not unregister 'requests'");
            };
            if let Err(e) = registry.unregister(Box::new(sizes)) {
                tracing::error!(?e, "could not unregister 'sizes'");
            };
        }

        Ok(())
    }
}

/// `hyper` service that renders the registry on `/metrics`.
///
/// [`Server::run`] clones one instance per accepted connection. The
/// instrumentation fields only exist when the `internal_metrics` feature is
/// enabled.
#[derive(Debug, Clone)]
struct MetricsService<R> {
    /// Registry gathered and encoded on every `/metrics` request.
    registry:  R,
    /// Time spent building each response.
    #[cfg(feature = "internal_metrics")]
    durations: Histogram,
    /// Number of `/metrics` requests served.
    #[cfg(feature = "internal_metrics")]
    requests:  IntCounter,
    /// Size in bytes of the most recently encoded metrics body.
    #[cfg(feature = "internal_metrics")]
    sizes:     IntGauge,
}

impl<R> Service<Request<hyper::body::Incoming>> for MetricsService<R>
where
    R: Deref<Target = Registry> + Clone + Send + 'static,
{
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
    type Response = Response<Full<Bytes>>;

    fn call(&self, req: Request<hyper::body::Incoming>) -> Self::Future {
        #[cfg(feature = "internal_metrics")]
        let timer = self.durations.start_timer();

        let (code, body) = if req.uri().path() == "/metrics" {
            #[cfg(feature = "internal_metrics")]
            self.requests.inc();

            trace!("request");

            let mf = self.registry.deref().gather();
            let mut buffer = vec![];

            let encoder = TextEncoder::new();
            encoder.encode(&mf, &mut buffer).expect("write to vec cannot fail");

            #[cfg(feature = "internal_metrics")]
            if let Ok(size) = buffer.len().try_into() {
                self.sizes.set(size);
            }

            (StatusCode::OK, Full::new(Bytes::from(buffer)))
        } else {
            trace!("wrong uri, return 404");
            (StatusCode::NOT_FOUND, Full::new(Bytes::from("404 not found")))
        };

        let response = Response::builder()
            .status(code)
            .header(header::CONTENT_TYPE, "text/plain; charset=utf-8")
            .body(body)
            .unwrap();

        #[cfg(feature = "internal_metrics")]
        timer.observe_duration();

        Box::pin(async { Ok::<Response<http_body_util::Full<bytes::Bytes>>, Infallible>(response) })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use http_body_util::Empty;
    use hyper::Request;
    use std::{sync::Arc, time::Duration};
    use tokio::{net::TcpStream, sync::Notify};

    /// Loopback address for the test servers.
    ///
    /// Binding and connecting on `127.0.0.1` instead of the unspecified
    /// address `0.0.0.0` keeps the tests portable (connecting *to* `0.0.0.0`
    /// is rejected on some platforms) and keeps the test ports off external
    /// interfaces.
    fn local(port: u16) -> SocketAddr {
        SocketAddr::from(([127, 0, 0, 1], port))
    }

    /// Sends `GET path` to the test server on `port` and returns the status.
    ///
    /// The server is spawned concurrently, so connecting is retried for up to
    /// roughly one second instead of relying on a fixed sleep.
    async fn get_status(port: u16, path: &str) -> StatusCode {
        let addr = local(port);
        let mut attempts = 0;
        let stream = loop {
            match TcpStream::connect(addr).await {
                Ok(stream) => break stream,
                Err(e) => {
                    attempts += 1;
                    assert!(attempts < 50, "could not connect to {addr}: {e}");
                    tokio::time::sleep(Duration::from_millis(20)).await;
                },
            }
        };

        let (mut sender, conn) = hyper::client::conn::http1::handshake(TokioIo::new(stream))
            .await
            .expect("http1 handshake");
        tokio::task::spawn(async move {
            if let Err(err) = conn.await {
                println!("Connection failed: {:?}", err);
            }
        });

        // Origin-form request target plus an explicit `Host` header, as
        // expected by hyper's low-level HTTP/1 client.
        let req = Request::builder()
            .method("GET")
            .uri(path)
            .header(hyper::header::HOST, addr.to_string())
            .body(Empty::<Bytes>::new())
            .expect("request builder");

        sender.send_request(req).await.expect("couldn't reach server").status()
    }

    #[tokio::test]
    async fn test_create() {
        let shutdown = Arc::new(Notify::new());
        let registry = Arc::new(Registry::new());
        let shutdown_clone = Arc::clone(&shutdown);

        let r = tokio::spawn(async move {
            Server::run(registry, local(6001), shutdown_clone.notified()).await
        });

        shutdown.notify_one();
        r.await.expect("tokio error").expect("prometheus_hyper server error");
    }

    #[tokio::test]
    async fn test_default() {
        let shutdown = Arc::new(Notify::new());
        let registry = prometheus::default_registry();
        let shutdown_clone = Arc::clone(&shutdown);

        let r = tokio::spawn(async move {
            Server::run(registry, local(6002), shutdown_clone.notified()).await
        });

        shutdown.notify_one();
        r.await.expect("tokio error").expect("prometheus_hyper server error");
    }

    #[tokio::test]
    async fn test_sample() {
        let shutdown = Arc::new(Notify::new());
        let registry = Arc::new(Registry::new());
        let shutdown_clone = Arc::clone(&shutdown);

        let r = tokio::spawn(async move {
            Server::run(registry, local(6003), shutdown_clone.notified()).await
        });

        assert_eq!(get_status(6003, "/metrics").await, StatusCode::OK);

        shutdown.notify_one();
        r.await.expect("tokio error").expect("prometheus_hyper server error");
    }

    #[tokio::test]
    async fn test_wrong_endpoint_sample() {
        let shutdown = Arc::new(Notify::new());
        let registry = Arc::new(Registry::new());
        let shutdown_clone = Arc::clone(&shutdown);

        let r = tokio::spawn(async move {
            Server::run(registry, local(6004), shutdown_clone.notified()).await
        });

        assert_eq!(get_status(6004, "/foobar").await, StatusCode::NOT_FOUND);

        shutdown.notify_one();
        r.await.expect("tokio error").expect("prometheus_hyper server error");
    }
}