1use bytes::Bytes;
49use http_body_util::Full;
50use hyper::{header, service::Service, Request, Response, StatusCode};
51use hyper_util::rt::TokioIo;
52use prometheus::{Encoder, Registry, TextEncoder};
53use std::{convert::Infallible, future::Future, net::SocketAddr, ops::Deref, pin::Pin};
54use tokio::net::TcpListener;
55use tracing::{info, trace};
56
57#[cfg(feature = "internal_metrics")]
58use prometheus::{
59 register_histogram_with_registry, register_int_counter_with_registry, register_int_gauge_with_registry, Histogram,
60 IntCounter, IntGauge,
61};
62
63#[cfg(feature = "internal_metrics")]
64use std::convert::TryInto;
65
/// Boxed one-shot callback that is handed a shared reference to a [`Registry`] and
/// reports success or a [`prometheus::Error`].
///
/// The alias is not referenced by the code visible in this part of the file.
pub type RegistryFn = Box<dyn FnOnce(&Registry) -> Result<(), prometheus::Error>>;
68
69pub struct Server {}
74
75impl Server {
76 pub async fn run<S, F, R>(registry: R, addr: S, shutdown: F) -> Result<(), std::io::Error>
119 where
120 S: Into<SocketAddr>,
121 F: Future<Output = ()>,
122 R: Deref<Target = Registry> + Clone + Send + 'static,
123 {
124 let addr = addr.into();
125
126 #[cfg(feature = "internal_metrics")]
127 let durations = register_histogram_with_registry!(
128 "prometheus_exporter_request_duration_seconds",
129 "HTTP request durations in seconds",
130 registry
131 )
132 .unwrap();
133 #[cfg(feature = "internal_metrics")]
134 let requests = register_int_counter_with_registry!(
135 "prometheus_exporter_requests_total",
136 "HTTP requests received in metrics endpoint",
137 registry
138 )
139 .unwrap();
140 #[cfg(feature = "internal_metrics")]
141 let sizes = register_int_gauge_with_registry!(
142 "prometheus_exporter_response_size_bytes",
143 "HTTP response sizes in bytes",
144 registry
145 )
146 .unwrap();
147
148 info!("starting hyper server to serve metrics");
149
150 let service = MetricsService {
151 registry: registry.clone(),
152 #[cfg(feature = "internal_metrics")]
153 durations: durations.clone(),
154 #[cfg(feature = "internal_metrics")]
155 requests: requests.clone(),
156 #[cfg(feature = "internal_metrics")]
157 sizes: sizes.clone(),
158 };
159
160 let listener = TcpListener::bind(addr).await?;
161 let mut shutdown = core::pin::pin!(shutdown);
162 while let Some(conn) = tokio::select! {
163 _ = shutdown.as_mut() => None,
164 conn = listener.accept() => Some(conn),
165 } {
166 match conn {
167 Ok((tcp, _)) => {
168 let io = TokioIo::new(tcp);
169 let service_clone = service.clone();
170
171 tokio::task::spawn(async move {
172 use hyper::server::conn::http1;
173 let conn = http1::Builder::new().serve_connection(io, service_clone);
174
175 if let Err(e) = conn.await {
176 tracing::error!(?e, "error serving connection")
177 }
178 });
179 },
180 Err(e) => tracing::error!(?e, "error accepting new connection"),
181 }
182 }
183
184 #[cfg(feature = "internal_metrics")]
185 {
186 if let Err(e) = registry.unregister(Box::new(durations)) {
187 tracing::error!(?e, "could not unregister 'durations'");
188 };
189 if let Err(e) = registry.unregister(Box::new(requests)) {
190 tracing::error!(?e, "could not unregister 'requests'");
191 };
192 if let Err(e) = registry.unregister(Box::new(sizes)) {
193 tracing::error!(?e, "could not unregister 'sizes'");
194 };
195 }
196
197 Ok(())
198 }
199}
200
201#[cfg(feature = "internal_metrics")]
202#[derive(Debug, Clone)]
203struct MetricsService<R> {
204 registry: R,
205 durations: Histogram,
206 requests: IntCounter,
207 sizes: IntGauge,
208}
209
210#[cfg(not(feature = "internal_metrics"))]
211#[derive(Debug, Clone)]
212struct MetricsService<R> {
213 registry: R,
214}
215
216impl<R> Service<Request<hyper::body::Incoming>> for MetricsService<R>
217where
218 R: Deref<Target = Registry> + Clone + Send + 'static,
219{
220 type Error = Infallible;
221 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
222 type Response = Response<Full<Bytes>>;
223
224 fn call(&self, req: Request<hyper::body::Incoming>) -> Self::Future {
225 #[cfg(feature = "internal_metrics")]
226 let timer = self.durations.start_timer();
227
228 let (code, body) = if req.uri().path() == "/metrics" {
229 #[cfg(feature = "internal_metrics")]
230 self.requests.inc();
231
232 trace!("request");
233
234 let mf = self.registry.deref().gather();
235 let mut buffer = vec![];
236
237 let encoder = TextEncoder::new();
238 encoder.encode(&mf, &mut buffer).expect("write to vec cannot fail");
239
240 #[cfg(feature = "internal_metrics")]
241 if let Ok(size) = buffer.len().try_into() {
242 self.sizes.set(size);
243 }
244
245 (StatusCode::OK, Full::new(Bytes::from(buffer)))
246 } else {
247 trace!("wrong uri, return 404");
248 (StatusCode::NOT_FOUND, Full::new(Bytes::from("404 not found")))
249 };
250
251 let response = Response::builder()
252 .status(code)
253 .header(header::CONTENT_TYPE, "text/plain; charset=utf-8")
254 .body(body)
255 .unwrap();
256
257 #[cfg(feature = "internal_metrics")]
258 timer.observe_duration();
259
260 Box::pin(async { Ok::<Response<http_body_util::Full<bytes::Bytes>>, Infallible>(response) })
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use http_body_util::Empty;
    use hyper::Request;
    use std::{sync::Arc, time::Duration};
    use tokio::{net::TcpStream, sync::Notify};

    /// Sends `GET <path>` to the server listening on local `port` and returns the
    /// status code of the response.
    ///
    /// The client connects to the loopback address: the server binds the unspecified
    /// address (`0.0.0.0`), which is a valid bind target but not a portable connect
    /// target.
    async fn get_status(port: u16, path: &str) -> StatusCode {
        let stream = TcpStream::connect(SocketAddr::from(([127, 0, 0, 1], port)))
            .await
            .unwrap();
        let io = TokioIo::new(stream);
        let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
        // The connection future drives the socket I/O and has to be polled somewhere.
        tokio::task::spawn(async move {
            if let Err(err) = conn.await {
                println!("Connection failed: {:?}", err);
            }
        });

        let req = Request::builder()
            .method("GET")
            .uri(format!("http://localhost:{port}{path}"))
            .body(Empty::<Bytes>::new())
            .expect("request builder");

        sender
            .send_request(req)
            .await
            .expect("couldn't reach server")
            .status()
    }

    /// The server starts and stops cleanly with an `Arc<Registry>`.
    #[tokio::test]
    async fn test_create() {
        let shutdown = Arc::new(Notify::new());
        let registry = Arc::new(Registry::new());

        let shutdown_clone = Arc::clone(&shutdown);

        let r = tokio::spawn(async move {
            Server::run(
                Arc::clone(&registry),
                SocketAddr::from(([0; 4], 6001)),
                shutdown_clone.notified(),
            )
            .await
        });

        shutdown.notify_one();
        r.await.expect("tokio error").expect("prometheus_hyper server error");
    }

    /// The server starts and stops cleanly with the process-wide default registry.
    #[tokio::test]
    async fn test_default() {
        let shutdown = Arc::new(Notify::new());
        let registry = prometheus::default_registry();

        let shutdown_clone = Arc::clone(&shutdown);

        let r = tokio::spawn(async move {
            Server::run(registry, SocketAddr::from(([0; 4], 6002)), shutdown_clone.notified()).await
        });

        shutdown.notify_one();
        r.await.expect("tokio error").expect("prometheus_hyper server error");
    }

    /// `/metrics` answers with `200 OK`.
    #[tokio::test]
    async fn test_sample() {
        let shutdown = Arc::new(Notify::new());
        let registry = Arc::new(Registry::new());

        let shutdown_clone = Arc::clone(&shutdown);

        let r = tokio::spawn(async move {
            Server::run(
                Arc::clone(&registry),
                SocketAddr::from(([0; 4], 6003)),
                shutdown_clone.notified(),
            )
            .await
        });

        // Give the spawned server time to bind its listener.
        tokio::time::sleep(Duration::from_millis(500)).await;

        assert_eq!(get_status(6003, "/metrics").await, StatusCode::OK);

        shutdown.notify_one();
        r.await.expect("tokio error").expect("prometheus_hyper server error");
    }

    /// Any path other than `/metrics` answers with `404 Not Found`.
    #[tokio::test]
    async fn test_wrong_endpoint_sample() {
        let shutdown = Arc::new(Notify::new());
        let registry = Arc::new(Registry::new());

        let shutdown_clone = Arc::clone(&shutdown);

        let r = tokio::spawn(async move {
            Server::run(
                Arc::clone(&registry),
                SocketAddr::from(([0; 4], 6004)),
                shutdown_clone.notified(),
            )
            .await
        });

        // Give the spawned server time to bind its listener.
        tokio::time::sleep(Duration::from_millis(500)).await;

        assert_eq!(get_status(6004, "/foobar").await, StatusCode::NOT_FOUND);

        shutdown.notify_one();
        r.await.expect("tokio error").expect("prometheus_hyper server error");
    }
}