prometheus_hyper/
lib.rs

//! # Example
2//! ```
3//! use prometheus::{IntCounter, Opts, Registry};
4//! use prometheus_hyper::{RegistryFn, Server};
5//! use std::{error::Error, net::SocketAddr, sync::Arc, time::Duration};
6//! use tokio::sync::Notify;
7//!
8//! pub struct CustomMetrics {
9//!     pub foo: IntCounter,
10//! }
11//!
12//! impl CustomMetrics {
13//!     pub fn new() -> Result<(Self, RegistryFn), Box<dyn Error>> {
14//!         let foo = IntCounter::with_opts(Opts::new("foo", "description"))?;
15//!         let foo_clone = foo.clone();
16//!         let f = |r: &Registry| r.register(Box::new(foo_clone));
17//!         Ok((Self { foo }, Box::new(f)))
18//!     }
19//! }
20//!
21//! #[tokio::main(flavor = "current_thread")]
22//! async fn main() -> std::result::Result<(), std::io::Error> {
23//!     let registry = Arc::new(Registry::new());
24//!     let shutdown = Arc::new(Notify::new());
25//!     let shutdown_clone = Arc::clone(&shutdown);
26//!     let (metrics, f) = CustomMetrics::new().expect("failed prometheus");
27//!     f(&registry).expect("problem registering");
28//!
29//!     // Startup Server
30//!     let jh = tokio::spawn(async move {
31//!         Server::run(
32//!             Arc::clone(&registry),
33//!             SocketAddr::from(([0; 4], 8080)),
34//!             shutdown_clone.notified(),
35//!         )
36//!         .await
37//!     });
38//!
39//!     // Change Metrics
40//!     metrics.foo.inc();
41//!
42//!     // Shutdown
43//!     tokio::time::sleep(Duration::from_secs(5)).await;
44//!     shutdown.notify_one();
45//!     jh.await.unwrap()
46//! }
47//! ```
48use bytes::Bytes;
49use http_body_util::Full;
50use hyper::{header, service::Service, Request, Response, StatusCode};
51use hyper_util::rt::TokioIo;
52use prometheus::{Encoder, Registry, TextEncoder};
53use std::{convert::Infallible, future::Future, net::SocketAddr, ops::Deref, pin::Pin};
54use tokio::net::TcpListener;
55use tracing::{info, trace};
56
57#[cfg(feature = "internal_metrics")]
58use prometheus::{
59    register_histogram_with_registry, register_int_counter_with_registry, register_int_gauge_with_registry, Histogram,
60    IntCounter, IntGauge,
61};
62
63#[cfg(feature = "internal_metrics")]
64use std::convert::TryInto;
65
66/// Helper fn to register metrics
67pub type RegistryFn = Box<dyn FnOnce(&Registry) -> Result<(), prometheus::Error>>;
68
/// Metrics Server based on [`tokio`] and [`hyper`]
///
/// The type carries no state; start serving with [`Server::run`].
///
/// [`tokio`]: tokio
/// [`hyper`]: hyper
#[derive(Debug, Default, Clone, Copy)]
pub struct Server {}
74
75impl Server {
76    /// Create and run the metrics Server
77    ///
78    /// # Arguments
79    /// * `registry` - provide the [`Registry`] you are also registering your
80    ///   metric types to.
81    /// * `addr` - `host:ip` to tcp listen on.
82    /// * `shutdown` - a [`Future`], once this completes the server will start
83    ///   to shut down. You can use a [`signal`] or [`Notify`] for clean
84    ///   shutdown or [`pending`] to newer shutdown.
85    /// # Result
86    /// * [`std::io::Error`] is thrown when listening on addr fails. All other
87    ///   causes are handled internally, logged and ignored
88    ///
89    /// # Examples
90    /// ```
91    /// use prometheus::Registry;
92    /// use prometheus_hyper::Server;
93    /// use std::{net::SocketAddr, sync::Arc};
94    /// # #[tokio::main(flavor = "current_thread")]
95    /// # async fn main() {
96    ///
97    /// let registry = Arc::new(Registry::new());
98    ///
99    /// // Start Server endlessly
100    /// tokio::spawn(async move {
101    ///     Server::run(
102    ///         Arc::clone(&registry),
103    ///         SocketAddr::from(([0; 4], 8080)),
104    ///         futures_util::future::pending(),
105    ///     )
106    ///     .await
107    /// });
108    /// # }
109    /// ```
110    /// [`Registry`]: prometheus::Registry
111    /// [`Future`]: std::future::Future
112    /// [`pending`]: https://docs.rs/futures-util/latest/futures_util/future/fn.pending.html
113    /// [`hyper::Error`]: hyper::Error
114    /// [`signal`]: tokio::signal
115    /// [`Notify`]: tokio::sync::Notify
116    /// [`tokio`]: tokio
117    /// [`hyper`]: hyper
118    pub async fn run<S, F, R>(registry: R, addr: S, shutdown: F) -> Result<(), std::io::Error>
119    where
120        S: Into<SocketAddr>,
121        F: Future<Output = ()>,
122        R: Deref<Target = Registry> + Clone + Send + 'static,
123    {
124        let addr = addr.into();
125
126        #[cfg(feature = "internal_metrics")]
127        let durations = register_histogram_with_registry!(
128            "prometheus_exporter_request_duration_seconds",
129            "HTTP request durations in seconds",
130            registry
131        )
132        .unwrap();
133        #[cfg(feature = "internal_metrics")]
134        let requests = register_int_counter_with_registry!(
135            "prometheus_exporter_requests_total",
136            "HTTP requests received in metrics endpoint",
137            registry
138        )
139        .unwrap();
140        #[cfg(feature = "internal_metrics")]
141        let sizes = register_int_gauge_with_registry!(
142            "prometheus_exporter_response_size_bytes",
143            "HTTP response sizes in bytes",
144            registry
145        )
146        .unwrap();
147
148        info!("starting hyper server to serve metrics");
149
150        let service = MetricsService {
151            registry: registry.clone(),
152            #[cfg(feature = "internal_metrics")]
153            durations: durations.clone(),
154            #[cfg(feature = "internal_metrics")]
155            requests: requests.clone(),
156            #[cfg(feature = "internal_metrics")]
157            sizes: sizes.clone(),
158        };
159
160        let listener = TcpListener::bind(addr).await?;
161        let mut shutdown = core::pin::pin!(shutdown);
162        while let Some(conn) = tokio::select! {
163            _ = shutdown.as_mut() => None,
164            conn = listener.accept() => Some(conn),
165        } {
166            match conn {
167                Ok((tcp, _)) => {
168                    let io = TokioIo::new(tcp);
169                    let service_clone = service.clone();
170
171                    tokio::task::spawn(async move {
172                        use hyper::server::conn::http1;
173                        let conn = http1::Builder::new().serve_connection(io, service_clone);
174
175                        if let Err(e) = conn.await {
176                            tracing::error!(?e, "error serving connection")
177                        }
178                    });
179                },
180                Err(e) => tracing::error!(?e, "error accepting new connection"),
181            }
182        }
183
184        #[cfg(feature = "internal_metrics")]
185        {
186            if let Err(e) = registry.unregister(Box::new(durations)) {
187                tracing::error!(?e, "could not unregister 'durations'");
188            };
189            if let Err(e) = registry.unregister(Box::new(requests)) {
190                tracing::error!(?e, "could not unregister 'requests'");
191            };
192            if let Err(e) = registry.unregister(Box::new(sizes)) {
193                tracing::error!(?e, "could not unregister 'sizes'");
194            };
195        }
196
197        Ok(())
198    }
199}
200
/// Per-connection [`Service`] that answers metric scrapes from `registry`.
///
/// The server clones one instance per accepted connection. With the
/// `internal_metrics` feature it also carries handles to the server's own
/// request metrics.
#[derive(Debug, Clone)]
struct MetricsService<R> {
    /// Registry whose metrics are gathered and encoded for `/metrics`.
    registry: R,
    /// Time spent answering each request.
    #[cfg(feature = "internal_metrics")]
    durations: Histogram,
    /// Number of requests to `/metrics`.
    #[cfg(feature = "internal_metrics")]
    requests: IntCounter,
    /// Size in bytes of the most recently encoded metrics body.
    #[cfg(feature = "internal_metrics")]
    sizes: IntGauge,
}
215
216impl<R> Service<Request<hyper::body::Incoming>> for MetricsService<R>
217where
218    R: Deref<Target = Registry> + Clone + Send + 'static,
219{
220    type Error = Infallible;
221    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
222    type Response = Response<Full<Bytes>>;
223
224    fn call(&self, req: Request<hyper::body::Incoming>) -> Self::Future {
225        #[cfg(feature = "internal_metrics")]
226        let timer = self.durations.start_timer();
227
228        let (code, body) = if req.uri().path() == "/metrics" {
229            #[cfg(feature = "internal_metrics")]
230            self.requests.inc();
231
232            trace!("request");
233
234            let mf = self.registry.deref().gather();
235            let mut buffer = vec![];
236
237            let encoder = TextEncoder::new();
238            encoder.encode(&mf, &mut buffer).expect("write to vec cannot fail");
239
240            #[cfg(feature = "internal_metrics")]
241            if let Ok(size) = buffer.len().try_into() {
242                self.sizes.set(size);
243            }
244
245            (StatusCode::OK, Full::new(Bytes::from(buffer)))
246        } else {
247            trace!("wrong uri, return 404");
248            (StatusCode::NOT_FOUND, Full::new(Bytes::from("404 not found")))
249        };
250
251        let response = Response::builder()
252            .status(code)
253            .header(header::CONTENT_TYPE, "text/plain; charset=utf-8")
254            .body(body)
255            .unwrap();
256
257        #[cfg(feature = "internal_metrics")]
258        timer.observe_duration();
259
260        Box::pin(async { Ok::<Response<http_body_util::Full<bytes::Bytes>>, Infallible>(response) })
261    }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use http_body_util::Empty;
    use hyper::Request;
    use std::{sync::Arc, time::Duration};
    use tokio::{net::TcpStream, sync::Notify};

    /// Connect to a test server listening on `port`.
    ///
    /// The servers bind `0.0.0.0`, which is not a valid destination address on
    /// every platform (Windows rejects it), so connect via loopback instead.
    /// The server starts on a separately spawned task, so keep retrying for up
    /// to ~2 seconds instead of relying on one fixed sleep.
    async fn connect(port: u16) -> TcpStream {
        let addr = SocketAddr::from(([127, 0, 0, 1], port));
        for _ in 0..100 {
            if let Ok(stream) = TcpStream::connect(addr).await {
                return stream;
            }
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
        TcpStream::connect(addr).await.expect("metrics server is not listening")
    }

    /// Send `GET <path>` to the test server on `port` and return the response status.
    async fn get_status(port: u16, path: &str) -> StatusCode {
        let io = TokioIo::new(connect(port).await);
        let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
        tokio::task::spawn(async move {
            if let Err(err) = conn.await {
                println!("Connection failed: {:?}", err);
            }
        });

        let req = Request::builder()
            .method("GET")
            .uri(format!("http://localhost:{}{}", port, path))
            .body(Empty::<Bytes>::new())
            .expect("request builder");

        sender.send_request(req).await.expect("couldn't reach server").status()
    }

    #[tokio::test]
    async fn test_create() {
        let shutdown = Arc::new(Notify::new());
        let registry = Arc::new(Registry::new());

        let shutdown_clone = Arc::clone(&shutdown);

        let r = tokio::spawn(async move {
            Server::run(
                Arc::clone(&registry),
                SocketAddr::from(([0; 4], 6001)),
                shutdown_clone.notified(),
            )
            .await
        });

        shutdown.notify_one();
        r.await.expect("tokio error").expect("prometheus_hyper server error");
    }

    #[tokio::test]
    async fn test_default() {
        let shutdown = Arc::new(Notify::new());
        let registry = prometheus::default_registry();

        let shutdown_clone = Arc::clone(&shutdown);

        let r = tokio::spawn(async move {
            Server::run(registry, SocketAddr::from(([0; 4], 6002)), shutdown_clone.notified()).await
        });

        shutdown.notify_one();
        r.await.expect("tokio error").expect("prometheus_hyper server error");
    }

    #[tokio::test]
    async fn test_sample() {
        let shutdown = Arc::new(Notify::new());
        let registry = Arc::new(Registry::new());

        let shutdown_clone = Arc::clone(&shutdown);

        let r = tokio::spawn(async move {
            Server::run(
                Arc::clone(&registry),
                SocketAddr::from(([0; 4], 6003)),
                shutdown_clone.notified(),
            )
            .await
        });

        assert_eq!(get_status(6003, "/metrics").await, StatusCode::OK);

        shutdown.notify_one();
        r.await.expect("tokio error").expect("prometheus_hyper server error");
    }

    #[tokio::test]
    async fn test_wrong_endpoint_sample() {
        let shutdown = Arc::new(Notify::new());
        let registry = Arc::new(Registry::new());

        let shutdown_clone = Arc::clone(&shutdown);

        let r = tokio::spawn(async move {
            Server::run(
                Arc::clone(&registry),
                SocketAddr::from(([0; 4], 6004)),
                shutdown_clone.notified(),
            )
            .await
        });

        assert_eq!(get_status(6004, "/foobar").await, StatusCode::NOT_FOUND);

        shutdown.notify_one();
        r.await.expect("tokio error").expect("prometheus_hyper server error");
    }
}