1use std::collections::HashMap;
43use std::future::Future;
44use std::net::SocketAddr;
45use std::pin::Pin;
46use std::sync::atomic::{AtomicU64, Ordering};
47use std::sync::{Arc, Mutex};
48use std::task::{Context, Poll};
49use std::time::Instant;
50
51use tower_layer::Layer;
52use tower_service::Service;
53
/// Inclusive upper bounds, in milliseconds, of the finite histogram buckets.
/// An implicit `+Inf` bucket always follows the last entry.
const LATENCY_BUCKETS_MS: [u64; 7] = [1, 5, 10, 50, 100, 500, 1_000];

/// Per-method request/error counters plus a cumulative latency histogram.
///
/// `latency_buckets[i]` for `i < LATENCY_BUCKETS_MS.len()` counts observations
/// `<= LATENCY_BUCKETS_MS[i]`; the final slot is the `+Inf` bucket and counts
/// every observation. The array length is derived from `LATENCY_BUCKETS_MS`,
/// so adding or removing a bound cannot desynchronise the `+Inf` index.
pub struct MethodMetrics {
    requests_total: AtomicU64,
    errors_total: AtomicU64,
    latency_buckets: [AtomicU64; LATENCY_BUCKETS_MS.len() + 1],
}

impl MethodMetrics {
    /// Index of the `+Inf` slot in `latency_buckets`.
    const INF_BUCKET: usize = LATENCY_BUCKETS_MS.len();

    /// Creates a record with every counter set to zero.
    fn new() -> Self {
        Self {
            requests_total: AtomicU64::new(0),
            errors_total: AtomicU64::new(0),
            latency_buckets: std::array::from_fn(|_| AtomicU64::new(0)),
        }
    }

    /// Records one request that took `duration_ms` milliseconds.
    ///
    /// Buckets are cumulative (Prometheus `le` semantics): the observation is
    /// added to every finite bucket whose bound is `>= duration_ms`, and
    /// always to `+Inf`.
    fn record(&self, duration_ms: u64, is_error: bool) {
        self.requests_total.fetch_add(1, Ordering::Relaxed);
        if is_error {
            self.errors_total.fetch_add(1, Ordering::Relaxed);
        }
        for (&bound, counter) in LATENCY_BUCKETS_MS.iter().zip(&self.latency_buckets) {
            if duration_ms <= bound {
                counter.fetch_add(1, Ordering::Relaxed);
            }
        }
        self.latency_buckets[Self::INF_BUCKET].fetch_add(1, Ordering::Relaxed);
    }

    /// Total number of recorded requests.
    fn requests(&self) -> u64 {
        self.requests_total.load(Ordering::Relaxed)
    }

    /// Number of recorded requests flagged as errors.
    fn errors(&self) -> u64 {
        self.errors_total.load(Ordering::Relaxed)
    }

    /// Current value of histogram slot `idx` (the last slot is `+Inf`).
    ///
    /// # Panics
    ///
    /// Panics if `idx > LATENCY_BUCKETS_MS.len()`.
    fn bucket(&self, idx: usize) -> u64 {
        self.latency_buckets[idx].load(Ordering::Relaxed)
    }
}
120
/// RAII guard for the active-request gauge.
///
/// Holds a borrow of the gauge counter and decrements it by one when dropped,
/// including when dropped during unwinding. Binding it to `_` or discarding
/// it ends the "active" period immediately, hence `#[must_use]`.
#[must_use = "the active-request gauge is decremented as soon as the guard is dropped"]
pub struct ActiveRequestGuard<'a>(&'a AtomicU64);

impl Drop for ActiveRequestGuard<'_> {
    fn drop(&mut self) {
        self.0.fetch_sub(1, Ordering::Relaxed);
    }
}
133
134pub struct NetMetrics {
142 methods: Mutex<HashMap<String, Arc<MethodMetrics>>>,
144 total_requests: AtomicU64,
146 total_errors: AtomicU64,
148 pub active_requests: AtomicU64,
150 pub bytes_sent_total: AtomicU64,
152 pub bytes_received_total: AtomicU64,
154 pub rtt_histogram: [AtomicU64; 8],
156}
157
158impl NetMetrics {
159 pub fn new() -> Arc<Self> {
161 Arc::new(Self {
162 methods: Mutex::new(HashMap::new()),
163 total_requests: AtomicU64::new(0),
164 total_errors: AtomicU64::new(0),
165 active_requests: AtomicU64::new(0),
166 bytes_sent_total: AtomicU64::new(0),
167 bytes_received_total: AtomicU64::new(0),
168 rtt_histogram: [
169 AtomicU64::new(0),
170 AtomicU64::new(0),
171 AtomicU64::new(0),
172 AtomicU64::new(0),
173 AtomicU64::new(0),
174 AtomicU64::new(0),
175 AtomicU64::new(0),
176 AtomicU64::new(0),
177 ],
178 })
179 }
180
181 pub fn enter_request(&self) -> ActiveRequestGuard<'_> {
185 self.active_requests.fetch_add(1, Ordering::Relaxed);
186 ActiveRequestGuard(&self.active_requests)
187 }
188
189 pub fn add_bytes_received(&self, bytes: u64) {
191 self.bytes_received_total
192 .fetch_add(bytes, Ordering::Relaxed);
193 }
194
195 pub fn add_bytes_sent(&self, bytes: u64) {
197 self.bytes_sent_total.fetch_add(bytes, Ordering::Relaxed);
198 }
199
200 pub fn record_rtt(&self, rtt_ms: u64) {
202 for (idx, &bound) in LATENCY_BUCKETS_MS.iter().enumerate() {
203 if rtt_ms <= bound {
204 self.rtt_histogram[idx].fetch_add(1, Ordering::Relaxed);
205 }
206 }
207 self.rtt_histogram[7].fetch_add(1, Ordering::Relaxed);
209 }
210
211 pub fn record_request(&self, method: &str, duration_ms: u64, is_error: bool) {
215 self.total_requests.fetch_add(1, Ordering::Relaxed);
216 if is_error {
217 self.total_errors.fetch_add(1, Ordering::Relaxed);
218 }
219
220 let method_metrics = {
221 let mut map = self.methods.lock().unwrap_or_else(|p| p.into_inner());
222 Arc::clone(
223 map.entry(method.to_owned())
224 .or_insert_with(|| Arc::new(MethodMetrics::new())),
225 )
226 };
227 method_metrics.record(duration_ms, is_error);
228 }
229
230 pub fn to_prometheus(&self) -> String {
235 let mut out = String::with_capacity(8192);
236
237 let total_req = self.total_requests.load(Ordering::Relaxed);
239 let total_err = self.total_errors.load(Ordering::Relaxed);
240
241 out.push_str("# HELP amaters_net_requests_total Total gRPC requests\n");
242 out.push_str("# TYPE amaters_net_requests_total counter\n");
243 out.push_str(&format!("amaters_net_requests_total {total_req}\n"));
244
245 out.push_str("# HELP amaters_net_errors_total Total gRPC errors\n");
246 out.push_str("# TYPE amaters_net_errors_total counter\n");
247 out.push_str(&format!("amaters_net_errors_total {total_err}\n"));
248
249 let active = self.active_requests.load(Ordering::Relaxed);
251 out.push_str("# HELP amaters_net_active_requests Currently active requests\n");
252 out.push_str("# TYPE amaters_net_active_requests gauge\n");
253 out.push_str(&format!("amaters_net_active_requests {active}\n"));
254
255 let bytes_sent = self.bytes_sent_total.load(Ordering::Relaxed);
257 out.push_str("# HELP amaters_net_bytes_sent_total Total bytes sent\n");
258 out.push_str("# TYPE amaters_net_bytes_sent_total counter\n");
259 out.push_str(&format!("amaters_net_bytes_sent_total {bytes_sent}\n"));
260
261 let bytes_recv = self.bytes_received_total.load(Ordering::Relaxed);
262 out.push_str("# HELP amaters_net_bytes_received_total Total bytes received\n");
263 out.push_str("# TYPE amaters_net_bytes_received_total counter\n");
264 out.push_str(&format!("amaters_net_bytes_received_total {bytes_recv}\n"));
265
266 out.push_str("# HELP amaters_net_rtt_bucket RTT histogram\n");
268 out.push_str("# TYPE amaters_net_rtt_bucket histogram\n");
269 for (idx, &bound) in LATENCY_BUCKETS_MS.iter().enumerate() {
270 let count = self.rtt_histogram[idx].load(Ordering::Relaxed);
271 out.push_str(&format!(
272 "amaters_net_rtt_bucket{{le=\"{bound}\"}} {count}\n"
273 ));
274 }
275 let inf_count = self.rtt_histogram[7].load(Ordering::Relaxed);
276 out.push_str(&format!(
277 "amaters_net_rtt_bucket{{le=\"+Inf\"}} {inf_count}\n"
278 ));
279
280 let map = self.methods.lock().unwrap_or_else(|p| p.into_inner());
282
283 let mut methods: Vec<(&String, &Arc<MethodMetrics>)> = map.iter().collect();
284 methods.sort_by_key(|(k, _)| k.as_str());
286
287 for (method, m) in &methods {
288 let label = format!("{{method=\"{method}\"}}");
289
290 out.push_str(&format!(
291 "amaters_net_method_requests_total{label} {}\n",
292 m.requests()
293 ));
294 out.push_str(&format!(
295 "amaters_net_method_errors_total{label} {}\n",
296 m.errors()
297 ));
298
299 out.push_str(&format!(
301 "# HELP amaters_net_request_duration_ms{label} Request latency histogram\n"
302 ));
303 out.push_str("# TYPE amaters_net_request_duration_ms histogram\n");
304 for (idx, &bound) in LATENCY_BUCKETS_MS.iter().enumerate() {
305 out.push_str(&format!(
306 "amaters_net_request_duration_ms_bucket{{method=\"{method}\",le=\"{bound}\"}} {}\n",
307 m.bucket(idx)
308 ));
309 }
310 out.push_str(&format!(
311 "amaters_net_request_duration_ms_bucket{{method=\"{method}\",le=\"+Inf\"}} {}\n",
312 m.bucket(7)
313 ));
314 }
315
316 out
317 }
318}
319
320async fn metrics_handler(
324 axum::extract::State(metrics): axum::extract::State<Arc<NetMetrics>>,
325) -> (
326 axum::http::StatusCode,
327 [(axum::http::HeaderName, &'static str); 1],
328 String,
329) {
330 let body = metrics.to_prometheus();
331 (
332 axum::http::StatusCode::OK,
333 [(
334 axum::http::header::CONTENT_TYPE,
335 "text/plain; version=0.0.4; charset=utf-8",
336 )],
337 body,
338 )
339}
340
341pub fn spawn_metrics_server(
363 addr: SocketAddr,
364 metrics: Arc<NetMetrics>,
365) -> tokio::task::JoinHandle<()> {
366 let app = axum::Router::new()
367 .route("/metrics", axum::routing::get(metrics_handler))
368 .with_state(metrics);
369
370 tokio::spawn(async move {
371 match tokio::net::TcpListener::bind(addr).await {
372 Ok(listener) => {
373 tracing::info!("Metrics server listening on {}", addr);
374 if let Err(e) = axum::serve(listener, app).await {
375 tracing::warn!("Metrics server error: {}", e);
376 }
377 }
378 Err(e) => {
379 tracing::error!("Failed to bind metrics server to {}: {}", addr, e);
380 }
381 }
382 })
383}
384
385#[derive(Clone)]
389pub struct MetricsLayer {
390 metrics: Arc<NetMetrics>,
391}
392
393impl MetricsLayer {
394 pub fn new(metrics: Arc<NetMetrics>) -> Self {
396 Self { metrics }
397 }
398}
399
400impl<S> Layer<S> for MetricsLayer {
401 type Service = MetricsService<S>;
402
403 fn layer(&self, inner: S) -> Self::Service {
404 MetricsService {
405 inner,
406 metrics: Arc::clone(&self.metrics),
407 }
408 }
409}
410
411#[derive(Clone)]
415pub struct MetricsService<S> {
416 inner: S,
417 metrics: Arc<NetMetrics>,
418}
419
420impl<S, ReqBody, ResBody> Service<http::Request<ReqBody>> for MetricsService<S>
421where
422 S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>> + Clone + Send + 'static,
423 S::Future: Send + 'static,
424 S::Error: Send + 'static,
425 ReqBody: http_body::Body + Send + 'static,
426 ResBody: http_body::Body + Send + 'static,
427{
428 type Response = http::Response<ResBody>;
429 type Error = S::Error;
430 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
431
432 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
433 self.inner.poll_ready(cx)
434 }
435
436 fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
437 let method = req.uri().path().to_owned();
439
440 let req_bytes = req
442 .body()
443 .size_hint()
444 .exact()
445 .unwrap_or_else(|| req.body().size_hint().lower());
446
447 let mut inner = self.inner.clone();
448 std::mem::swap(&mut self.inner, &mut inner);
449
450 let metrics = Arc::clone(&self.metrics);
451 let start = Instant::now();
452
453 Box::pin(async move {
454 metrics.add_bytes_received(req_bytes);
456 let _guard = metrics.enter_request();
457
458 let result = inner.call(req).await;
459 let elapsed_ms = start.elapsed().as_millis() as u64;
460 let is_error = result.is_err();
461
462 if let Ok(ref resp) = result {
464 let resp_bytes = resp
465 .body()
466 .size_hint()
467 .exact()
468 .unwrap_or_else(|| resp.body().size_hint().lower());
469 metrics.add_bytes_sent(resp_bytes);
470 }
471
472 metrics.record_request(&method, elapsed_ms, is_error);
473 metrics.record_rtt(elapsed_ms);
474 result
476 })
477 }
478}
479
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;
    use tower_service::Service as _;

    /// Inner-service stub that always succeeds with an empty `String` body.
    #[derive(Clone)]
    struct OkService;

    impl Service<http::Request<String>> for OkService {
        type Response = http::Response<String>;
        type Error = Infallible;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: http::Request<String>) -> Self::Future {
            Box::pin(async { Ok(http::Response::new(String::new())) })
        }
    }

    /// Builds an empty-body request for `path`.
    fn make_req(path: &str) -> http::Request<String> {
        http::Request::builder()
            .uri(path)
            .body(String::new())
            .expect("request builder should not fail")
    }

    #[tokio::test]
    async fn test_metrics_counter_increments() {
        let metrics = NetMetrics::new();
        let layer = MetricsLayer::new(Arc::clone(&metrics));
        let mut svc = layer.layer(OkService);

        for _ in 0..3 {
            svc.call(make_req("/amaters.AqlService/ExecuteQuery"))
                .await
                .expect("service call should not error");
        }

        assert_eq!(
            metrics.total_requests.load(Ordering::Relaxed),
            3,
            "total_requests should be 3 after 3 calls"
        );
    }

    #[test]
    fn test_metrics_latency_histogram_records() {
        let metrics = NetMetrics::new();

        metrics.record_request("/test/Method", 10, false);

        let map = metrics
            .methods
            .lock()
            .expect("mutex should not be poisoned");
        let m = map
            .get("/test/Method")
            .expect("method entry should exist after recording");

        assert_eq!(
            m.bucket(2),
            1,
            "le=10ms bucket should be 1 for a 10ms observation"
        );
        assert_eq!(m.bucket(7), 1, "+Inf bucket should be 1");
        assert_eq!(
            m.bucket(0),
            0,
            "le=1ms bucket should be 0 for a 10ms observation"
        );
    }

    #[test]
    fn test_metrics_prometheus_text_format() {
        let metrics = NetMetrics::new();
        metrics.record_request("/amaters.AqlService/ExecuteQuery", 5, false);
        metrics.record_request("/amaters.AqlService/ExecuteQuery", 50, false);
        metrics.record_request("/amaters.AqlService/ExecuteQuery", 200, true);

        let prom = metrics.to_prometheus();

        assert!(
            prom.contains("amaters_net_requests_total"),
            "output must contain amaters_net_requests_total"
        );
        assert!(
            prom.contains("amaters_net_errors_total"),
            "output must contain amaters_net_errors_total"
        );
        assert!(
            prom.contains("amaters_net_method_requests_total"),
            "output must contain per-method counter"
        );
        assert!(
            prom.contains("amaters_net_requests_total 3"),
            "total requests should be 3"
        );
        assert!(
            prom.contains("amaters_net_errors_total 1"),
            "total errors should be 1"
        );
    }

    #[tokio::test]
    async fn test_metrics_layer_wraps_service() {
        let metrics = NetMetrics::new();
        let layer = MetricsLayer::new(Arc::clone(&metrics));
        let mut svc = layer.layer(OkService);

        svc.call(make_req("/pkg.Svc/Method"))
            .await
            .expect("should succeed");

        let prom = metrics.to_prometheus();
        assert!(
            prom.contains("/pkg.Svc/Method"),
            "method should appear in Prometheus output"
        );
    }

    #[test]
    fn test_latency_bucket_boundaries() {
        let m = MethodMetrics::new();

        m.record(1, false);
        assert_eq!(m.bucket(0), 1, "le=1 should catch 1ms");
        assert_eq!(m.bucket(7), 1, "+Inf must always count");

        m.record(0, false);
        for i in 0..8 {
            assert_eq!(
                m.bucket(i),
                2,
                "all buckets should be 2 after recording 0ms and 1ms (bucket={i})"
            );
        }
    }

    #[test]
    fn test_metrics_error_counting() {
        let m = MethodMetrics::new();
        m.record(10, true);
        m.record(20, false);
        m.record(30, true);
        assert_eq!(m.requests(), 3);
        assert_eq!(m.errors(), 2);
    }

    #[test]
    fn test_active_requests_gauge_increments_during_request() {
        let metrics = NetMetrics::new();
        let guard = metrics.enter_request();
        assert_eq!(
            metrics.active_requests.load(Ordering::Relaxed),
            1,
            "active_requests should be 1 after entering"
        );
        drop(guard);
    }

    #[test]
    fn test_active_requests_gauge_decrements_on_completion() {
        let metrics = NetMetrics::new();
        {
            let _guard = metrics.enter_request();
            assert_eq!(metrics.active_requests.load(Ordering::Relaxed), 1);
        }
        assert_eq!(
            metrics.active_requests.load(Ordering::Relaxed),
            0,
            "active_requests should be 0 after guard is dropped"
        );
    }

    #[test]
    fn test_bytes_sent_counter_records() {
        let metrics = NetMetrics::new();
        metrics.add_bytes_sent(100);
        metrics.add_bytes_sent(200);
        assert_eq!(
            metrics.bytes_sent_total.load(Ordering::Relaxed),
            300,
            "bytes_sent_total should be 300"
        );
    }

    #[test]
    fn test_bytes_received_counter_records() {
        let metrics = NetMetrics::new();
        metrics.add_bytes_received(512);
        metrics.add_bytes_received(512);
        assert_eq!(
            metrics.bytes_received_total.load(Ordering::Relaxed),
            1024,
            "bytes_received_total should be 1024"
        );
    }

    #[test]
    fn test_rtt_histogram_records() {
        let metrics = NetMetrics::new();
        metrics.record_rtt(5);

        assert_eq!(
            metrics.rtt_histogram[0].load(Ordering::Relaxed),
            0,
            "le=1 bucket should be 0 for 5ms observation"
        );
        assert_eq!(
            metrics.rtt_histogram[1].load(Ordering::Relaxed),
            1,
            "le=5 bucket should be 1 for 5ms observation"
        );
        assert_eq!(
            metrics.rtt_histogram[7].load(Ordering::Relaxed),
            1,
            "+Inf bucket should be 1"
        );
    }

    #[test]
    fn test_prometheus_output_includes_new_metrics() {
        let metrics = NetMetrics::new();
        metrics.add_bytes_sent(42);
        metrics.add_bytes_received(24);
        metrics.record_rtt(10);
        // Keep the guard alive while rendering so the gauge actually reads 1
        // (`let _ = ...` would drop it immediately).
        let _guard = metrics.enter_request();
        let prom = metrics.to_prometheus();

        assert!(
            prom.contains("amaters_net_active_requests"),
            "output must contain active_requests"
        );
        assert!(
            prom.contains("amaters_net_active_requests 1"),
            "active_requests should be 1 while the guard is held"
        );
        assert!(
            prom.contains("amaters_net_bytes_sent_total"),
            "output must contain bytes_sent_total"
        );
        assert!(
            prom.contains("amaters_net_bytes_received_total"),
            "output must contain bytes_received_total"
        );
        assert!(
            prom.contains("amaters_net_rtt_bucket"),
            "output must contain rtt_bucket"
        );
        assert!(
            prom.contains("amaters_net_bytes_sent_total 42"),
            "bytes_sent_total should be 42"
        );
        assert!(
            prom.contains("amaters_net_bytes_received_total 24"),
            "bytes_received_total should be 24"
        );
    }

    #[test]
    fn test_active_requests_exception_safe() {
        let metrics = NetMetrics::new();
        // Panic while the guard is alive; unwinding must still run its `Drop`.
        let outcome: std::thread::Result<()> =
            std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                let _guard = metrics.enter_request();
                assert_eq!(metrics.active_requests.load(Ordering::Relaxed), 1);
                panic!("simulated handler failure while a request is active");
            }));
        assert!(outcome.is_err(), "the closure should have panicked");
        assert_eq!(
            metrics.active_requests.load(Ordering::Relaxed),
            0,
            "active_requests must be 0 after drop, even on early exit"
        );
    }

    #[tokio::test]
    async fn test_prometheus_endpoint_returns_200() {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        let metrics = NetMetrics::new();
        // The listener is bound (and listening) before the server task is
        // spawned, so connecting below cannot race the server start-up.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("should bind to ephemeral port");
        let addr = listener
            .local_addr()
            .expect("should have local addr after bind");

        let app = axum::Router::new()
            .route("/metrics", axum::routing::get(metrics_handler))
            .with_state(Arc::clone(&metrics));
        let _handle = tokio::spawn(async move {
            if let Err(e) = axum::serve(listener, app).await {
                tracing::warn!("test metrics server error: {}", e);
            }
        });

        let mut stream = tokio::net::TcpStream::connect(addr)
            .await
            .expect("should connect to metrics server");

        stream
            .write_all(b"GET /metrics HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
            .await
            .expect("should write request");

        let mut response = Vec::new();
        stream
            .read_to_end(&mut response)
            .await
            .expect("should read response");

        let response_str = String::from_utf8_lossy(&response);
        assert!(
            response_str.starts_with("HTTP/1.1 200"),
            "expected HTTP 200, got: {}",
            &response_str[..response_str.find('\r').unwrap_or(response_str.len())]
        );
        assert!(
            response_str.contains("text/plain"),
            "expected text/plain Content-Type"
        );
        assert!(
            response_str.contains("amaters_net_requests_total"),
            "expected the rendered metrics in the response body"
        );
    }

    #[test]
    fn test_prometheus_metrics_format_contains_required_families() {
        let metrics = NetMetrics::new();
        metrics.record_request("/amaters.AqlService/Query", 10, false);
        metrics.add_bytes_sent(1024);
        let _guard = metrics.enter_request();
        let prom = metrics.to_prometheus();

        assert!(
            prom.contains("amaters_net_requests_total"),
            "must contain amaters_net_requests_total"
        );
        assert!(
            prom.contains("amaters_net_active_requests 1"),
            "must report 1 active request while the guard is held"
        );
        assert!(
            prom.contains("amaters_net_requests_total 1"),
            "must report exactly 1 request after one recording"
        );
    }
}