// amaters_net/metrics_layer.rs

//! Metrics middleware for the AmateRS network layer.
//!
//! Provides lock-free, per-method request counters and latency histograms
//! using `AtomicU64` — following the same hand-rolled pattern as
//! [`amaters_core::metrics::CoreMetrics`].  No external `metrics-rs` crate is
//! required.
//!
//! # Structure
//!
//! - [`MethodMetrics`]: per-method counters and histogram bucket counters.
//! - [`NetMetrics`]: registry keyed by gRPC method name plus global totals.
//! - [`MetricsLayer`]: Tower [`Layer`] factory.
//! - [`MetricsService<S>`]: Tower [`Service`] wrapping `S`; records timing on
//!   every call.
//! - [`ActiveRequestGuard`]: RAII guard that decrements `active_requests` on
//!   drop — ensures the gauge is decremented even when the inner service panics.
//!
//! # New metrics (v0.2.1)
//!
//! - `active_requests` (gauge) — incremented on request entry, decremented via
//!   [`ActiveRequestGuard`] on exit.
//! - `bytes_sent_total` (counter) — accumulated from response body size hints.
//! - `bytes_received_total` (counter) — accumulated from request body size hints.
//! - `rtt_histogram` (histogram) — same 7-bucket + `+Inf` scheme as
//!   `latency_buckets`, keyed by end-to-end round-trip time.
//!
//! # Prometheus text format
//!
//! Call [`NetMetrics::to_prometheus`] to obtain an OpenMetrics/Prometheus text
//! snapshot, e.g.
//!
//! ```text
//! # HELP amaters_net_requests_total Total gRPC requests
//! # TYPE amaters_net_requests_total counter
//! amaters_net_requests_total 42
//! amaters_net_errors_total 3
//! # HELP amaters_net_active_requests Currently active requests
//! # TYPE amaters_net_active_requests gauge
//! amaters_net_active_requests 0
//! ```
41
42use std::collections::HashMap;
43use std::future::Future;
44use std::net::SocketAddr;
45use std::pin::Pin;
46use std::sync::atomic::{AtomicU64, Ordering};
47use std::sync::{Arc, Mutex};
48use std::task::{Context, Poll};
49use std::time::Instant;
50
51use tower_layer::Layer;
52use tower_service::Service;
53
54// ─── Constants ───────────────────────────────────────────────────────────────
55
/// Inclusive upper bounds, in milliseconds, of the finite histogram buckets.
///
/// Seven finite boundaries are listed here. Histograms built on this table
/// carry one additional slot — the implicit `+Inf` bucket — which counts every
/// observation regardless of its value.
const LATENCY_BUCKETS_MS: [u64; 7] = [1, 5, 10, 50, 100, 500, 1_000];
61
62// ─── MethodMetrics ────────────────────────────────────────────────────────────
63
/// Per-method request/error counters plus a fixed-size cumulative latency
/// histogram.
///
/// Every field is an [`AtomicU64`], so a shared reference is all that is needed
/// to record an observation.
#[derive(Debug)]
pub struct MethodMetrics {
    /// Number of observations recorded for this method.
    requests_total: AtomicU64,
    /// Number of those observations that were flagged as errors.
    errors_total: AtomicU64,
    /// 8 buckets: 7 finite upper bounds (`LATENCY_BUCKETS_MS`) + `+Inf`.
    latency_buckets: [AtomicU64; 8],
}
71
72impl MethodMetrics {
73    fn new() -> Self {
74        Self {
75            requests_total: AtomicU64::new(0),
76            errors_total: AtomicU64::new(0),
77            // Array-of-AtomicU64 cannot derive Default; initialise manually.
78            latency_buckets: [
79                AtomicU64::new(0),
80                AtomicU64::new(0),
81                AtomicU64::new(0),
82                AtomicU64::new(0),
83                AtomicU64::new(0),
84                AtomicU64::new(0),
85                AtomicU64::new(0),
86                AtomicU64::new(0),
87            ],
88        }
89    }
90
91    /// Record a single observation of `duration_ms`.
92    fn record(&self, duration_ms: u64, is_error: bool) {
93        self.requests_total.fetch_add(1, Ordering::Relaxed);
94        if is_error {
95            self.errors_total.fetch_add(1, Ordering::Relaxed);
96        }
97        // Cumulative histogram: increment every bucket whose upper bound is
98        // >= the observed value, plus the +Inf bucket.
99        for (idx, &bound) in LATENCY_BUCKETS_MS.iter().enumerate() {
100            if duration_ms <= bound {
101                self.latency_buckets[idx].fetch_add(1, Ordering::Relaxed);
102            }
103        }
104        // The +Inf bucket always counts every observation.
105        self.latency_buckets[7].fetch_add(1, Ordering::Relaxed);
106    }
107
108    fn requests(&self) -> u64 {
109        self.requests_total.load(Ordering::Relaxed)
110    }
111
112    fn errors(&self) -> u64 {
113        self.errors_total.load(Ordering::Relaxed)
114    }
115
116    fn bucket(&self, idx: usize) -> u64 {
117        self.latency_buckets[idx].load(Ordering::Relaxed)
118    }
119}
120
121// ─── ActiveRequestGuard ───────────────────────────────────────────────────────
122
/// RAII guard that decrements `active_requests` when dropped.
///
/// The decrement lives in [`Drop`], so it also runs while a panic unwinds
/// through the scope that owns the guard. The guard must therefore be bound to
/// a named variable for as long as the request is in flight: a discarded guard
/// (`let _ = …` or an unused expression statement) is dropped on the spot and
/// undoes the increment immediately.
#[must_use = "the gauge is decremented as soon as the guard is dropped"]
#[derive(Debug)]
pub struct ActiveRequestGuard<'a>(&'a AtomicU64);

impl Drop for ActiveRequestGuard<'_> {
    /// Take one unit off the gauge this guard was created for.
    fn drop(&mut self) {
        self.0.fetch_sub(1, Ordering::Relaxed);
    }
}
133
134// ─── NetMetrics ───────────────────────────────────────────────────────────────
135
136/// Network metrics registry.
137///
138/// Tracks per-method counters and global totals.  All atomic fields are
139/// updated with `Ordering::Relaxed` which is sufficient for monotonically
140/// increasing counters observed by a single scraper.
141pub struct NetMetrics {
142    /// Per-method metrics, keyed by gRPC method path.
143    methods: Mutex<HashMap<String, Arc<MethodMetrics>>>,
144    /// Global request counter across all methods.
145    total_requests: AtomicU64,
146    /// Global error counter across all methods.
147    total_errors: AtomicU64,
148    /// Currently-in-flight request gauge.
149    pub active_requests: AtomicU64,
150    /// Total bytes sent (response bodies).
151    pub bytes_sent_total: AtomicU64,
152    /// Total bytes received (request bodies).
153    pub bytes_received_total: AtomicU64,
154    /// RTT histogram — 8 buckets: 7 finite bounds + `+Inf`.
155    pub rtt_histogram: [AtomicU64; 8],
156}
157
158impl NetMetrics {
159    /// Create a new empty registry wrapped in an `Arc`.
160    pub fn new() -> Arc<Self> {
161        Arc::new(Self {
162            methods: Mutex::new(HashMap::new()),
163            total_requests: AtomicU64::new(0),
164            total_errors: AtomicU64::new(0),
165            active_requests: AtomicU64::new(0),
166            bytes_sent_total: AtomicU64::new(0),
167            bytes_received_total: AtomicU64::new(0),
168            rtt_histogram: [
169                AtomicU64::new(0),
170                AtomicU64::new(0),
171                AtomicU64::new(0),
172                AtomicU64::new(0),
173                AtomicU64::new(0),
174                AtomicU64::new(0),
175                AtomicU64::new(0),
176                AtomicU64::new(0),
177            ],
178        })
179    }
180
181    /// Increment the active-requests gauge and return a guard that decrements it.
182    ///
183    /// The guard must be kept alive until the request completes.
184    pub fn enter_request(&self) -> ActiveRequestGuard<'_> {
185        self.active_requests.fetch_add(1, Ordering::Relaxed);
186        ActiveRequestGuard(&self.active_requests)
187    }
188
189    /// Record bytes received (request body size).
190    pub fn add_bytes_received(&self, bytes: u64) {
191        self.bytes_received_total
192            .fetch_add(bytes, Ordering::Relaxed);
193    }
194
195    /// Record bytes sent (response body size).
196    pub fn add_bytes_sent(&self, bytes: u64) {
197        self.bytes_sent_total.fetch_add(bytes, Ordering::Relaxed);
198    }
199
200    /// Record a single RTT observation (milliseconds) in the histogram.
201    pub fn record_rtt(&self, rtt_ms: u64) {
202        for (idx, &bound) in LATENCY_BUCKETS_MS.iter().enumerate() {
203            if rtt_ms <= bound {
204                self.rtt_histogram[idx].fetch_add(1, Ordering::Relaxed);
205            }
206        }
207        // +Inf always increments.
208        self.rtt_histogram[7].fetch_add(1, Ordering::Relaxed);
209    }
210
211    /// Record a request for `method` with the given `duration_ms`.
212    ///
213    /// Creates a per-method entry the first time a method name is seen.
214    pub fn record_request(&self, method: &str, duration_ms: u64, is_error: bool) {
215        self.total_requests.fetch_add(1, Ordering::Relaxed);
216        if is_error {
217            self.total_errors.fetch_add(1, Ordering::Relaxed);
218        }
219
220        let method_metrics = {
221            let mut map = self.methods.lock().unwrap_or_else(|p| p.into_inner());
222            Arc::clone(
223                map.entry(method.to_owned())
224                    .or_insert_with(|| Arc::new(MethodMetrics::new())),
225            )
226        };
227        method_metrics.record(duration_ms, is_error);
228    }
229
230    /// Render all metrics in Prometheus text format.
231    ///
232    /// The output format follows the OpenMetrics / Prometheus exposition
233    /// format, consistent with [`amaters_core::metrics::CoreMetrics::to_prometheus`].
234    pub fn to_prometheus(&self) -> String {
235        let mut out = String::with_capacity(8192);
236
237        // ── Global counters ──────────────────────────────────────────────────
238        let total_req = self.total_requests.load(Ordering::Relaxed);
239        let total_err = self.total_errors.load(Ordering::Relaxed);
240
241        out.push_str("# HELP amaters_net_requests_total Total gRPC requests\n");
242        out.push_str("# TYPE amaters_net_requests_total counter\n");
243        out.push_str(&format!("amaters_net_requests_total {total_req}\n"));
244
245        out.push_str("# HELP amaters_net_errors_total Total gRPC errors\n");
246        out.push_str("# TYPE amaters_net_errors_total counter\n");
247        out.push_str(&format!("amaters_net_errors_total {total_err}\n"));
248
249        // ── Active requests gauge ────────────────────────────────────────────
250        let active = self.active_requests.load(Ordering::Relaxed);
251        out.push_str("# HELP amaters_net_active_requests Currently active requests\n");
252        out.push_str("# TYPE amaters_net_active_requests gauge\n");
253        out.push_str(&format!("amaters_net_active_requests {active}\n"));
254
255        // ── Byte counters ────────────────────────────────────────────────────
256        let bytes_sent = self.bytes_sent_total.load(Ordering::Relaxed);
257        out.push_str("# HELP amaters_net_bytes_sent_total Total bytes sent\n");
258        out.push_str("# TYPE amaters_net_bytes_sent_total counter\n");
259        out.push_str(&format!("amaters_net_bytes_sent_total {bytes_sent}\n"));
260
261        let bytes_recv = self.bytes_received_total.load(Ordering::Relaxed);
262        out.push_str("# HELP amaters_net_bytes_received_total Total bytes received\n");
263        out.push_str("# TYPE amaters_net_bytes_received_total counter\n");
264        out.push_str(&format!("amaters_net_bytes_received_total {bytes_recv}\n"));
265
266        // ── RTT histogram ────────────────────────────────────────────────────
267        out.push_str("# HELP amaters_net_rtt_bucket RTT histogram\n");
268        out.push_str("# TYPE amaters_net_rtt_bucket histogram\n");
269        for (idx, &bound) in LATENCY_BUCKETS_MS.iter().enumerate() {
270            let count = self.rtt_histogram[idx].load(Ordering::Relaxed);
271            out.push_str(&format!(
272                "amaters_net_rtt_bucket{{le=\"{bound}\"}} {count}\n"
273            ));
274        }
275        let inf_count = self.rtt_histogram[7].load(Ordering::Relaxed);
276        out.push_str(&format!(
277            "amaters_net_rtt_bucket{{le=\"+Inf\"}} {inf_count}\n"
278        ));
279
280        // ── Per-method metrics ───────────────────────────────────────────────
281        let map = self.methods.lock().unwrap_or_else(|p| p.into_inner());
282
283        let mut methods: Vec<(&String, &Arc<MethodMetrics>)> = map.iter().collect();
284        // Sort for deterministic output in tests.
285        methods.sort_by_key(|(k, _)| k.as_str());
286
287        for (method, m) in &methods {
288            let label = format!("{{method=\"{method}\"}}");
289
290            out.push_str(&format!(
291                "amaters_net_method_requests_total{label} {}\n",
292                m.requests()
293            ));
294            out.push_str(&format!(
295                "amaters_net_method_errors_total{label} {}\n",
296                m.errors()
297            ));
298
299            // Histogram buckets
300            out.push_str(&format!(
301                "# HELP amaters_net_request_duration_ms{label} Request latency histogram\n"
302            ));
303            out.push_str("# TYPE amaters_net_request_duration_ms histogram\n");
304            for (idx, &bound) in LATENCY_BUCKETS_MS.iter().enumerate() {
305                out.push_str(&format!(
306                    "amaters_net_request_duration_ms_bucket{{method=\"{method}\",le=\"{bound}\"}} {}\n",
307                    m.bucket(idx)
308                ));
309            }
310            out.push_str(&format!(
311                "amaters_net_request_duration_ms_bucket{{method=\"{method}\",le=\"+Inf\"}} {}\n",
312                m.bucket(7)
313            ));
314        }
315
316        out
317    }
318}
319
320// ─── Prometheus HTTP server ───────────────────────────────────────────────────
321
322/// axum handler: serialise the current metrics snapshot as Prometheus text.
323async fn metrics_handler(
324    axum::extract::State(metrics): axum::extract::State<Arc<NetMetrics>>,
325) -> (
326    axum::http::StatusCode,
327    [(axum::http::HeaderName, &'static str); 1],
328    String,
329) {
330    let body = metrics.to_prometheus();
331    (
332        axum::http::StatusCode::OK,
333        [(
334            axum::http::header::CONTENT_TYPE,
335            "text/plain; version=0.0.4; charset=utf-8",
336        )],
337        body,
338    )
339}
340
341/// Spawn a background task that serves Prometheus-format metrics at `GET /metrics`.
342///
343/// The server binds to `addr` and runs until the tokio runtime shuts down, or the
344/// returned [`tokio::task::JoinHandle`] is explicitly aborted via
345/// [`JoinHandle::abort`](tokio::task::JoinHandle::abort).  Dropping the handle does
346/// **not** stop the task.
347///
348/// # Example
349///
350/// ```rust,no_run
351/// use std::net::SocketAddr;
352/// use amaters_net::metrics_layer::{NetMetrics, spawn_metrics_server};
353///
354/// # #[tokio::main]
355/// # async fn main() {
356/// let metrics = NetMetrics::new();
357/// let addr: SocketAddr = "127.0.0.1:9090".parse().expect("valid addr");
358/// let _handle = spawn_metrics_server(addr, metrics);
359/// // handle.abort() stops the server
360/// # }
361/// ```
362pub fn spawn_metrics_server(
363    addr: SocketAddr,
364    metrics: Arc<NetMetrics>,
365) -> tokio::task::JoinHandle<()> {
366    let app = axum::Router::new()
367        .route("/metrics", axum::routing::get(metrics_handler))
368        .with_state(metrics);
369
370    tokio::spawn(async move {
371        match tokio::net::TcpListener::bind(addr).await {
372            Ok(listener) => {
373                tracing::info!("Metrics server listening on {}", addr);
374                if let Err(e) = axum::serve(listener, app).await {
375                    tracing::warn!("Metrics server error: {}", e);
376                }
377            }
378            Err(e) => {
379                tracing::error!("Failed to bind metrics server to {}: {}", addr, e);
380            }
381        }
382    })
383}
384
385// ─── MetricsLayer ─────────────────────────────────────────────────────────────
386
387/// Tower [`Layer`] that wraps a service with metrics recording.
388#[derive(Clone)]
389pub struct MetricsLayer {
390    metrics: Arc<NetMetrics>,
391}
392
393impl MetricsLayer {
394    /// Create a new layer backed by the given [`NetMetrics`] registry.
395    pub fn new(metrics: Arc<NetMetrics>) -> Self {
396        Self { metrics }
397    }
398}
399
400impl<S> Layer<S> for MetricsLayer {
401    type Service = MetricsService<S>;
402
403    fn layer(&self, inner: S) -> Self::Service {
404        MetricsService {
405            inner,
406            metrics: Arc::clone(&self.metrics),
407        }
408    }
409}
410
411// ─── MetricsService ───────────────────────────────────────────────────────────
412
413/// Tower [`Service`] that records per-request timing metrics.
414#[derive(Clone)]
415pub struct MetricsService<S> {
416    inner: S,
417    metrics: Arc<NetMetrics>,
418}
419
420impl<S, ReqBody, ResBody> Service<http::Request<ReqBody>> for MetricsService<S>
421where
422    S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>> + Clone + Send + 'static,
423    S::Future: Send + 'static,
424    S::Error: Send + 'static,
425    ReqBody: http_body::Body + Send + 'static,
426    ResBody: http_body::Body + Send + 'static,
427{
428    type Response = http::Response<ResBody>;
429    type Error = S::Error;
430    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
431
432    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
433        self.inner.poll_ready(cx)
434    }
435
436    fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
437        // Extract method name from the URI path (gRPC convention: /package.Service/Method).
438        let method = req.uri().path().to_owned();
439
440        // Capture request body size hint before moving the request.
441        let req_bytes = req
442            .body()
443            .size_hint()
444            .exact()
445            .unwrap_or_else(|| req.body().size_hint().lower());
446
447        let mut inner = self.inner.clone();
448        std::mem::swap(&mut self.inner, &mut inner);
449
450        let metrics = Arc::clone(&self.metrics);
451        let start = Instant::now();
452
453        Box::pin(async move {
454            // Increment active-requests gauge; guard decrements on drop.
455            metrics.add_bytes_received(req_bytes);
456            let _guard = metrics.enter_request();
457
458            let result = inner.call(req).await;
459            let elapsed_ms = start.elapsed().as_millis() as u64;
460            let is_error = result.is_err();
461
462            // Record response bytes from size hint.
463            if let Ok(ref resp) = result {
464                let resp_bytes = resp
465                    .body()
466                    .size_hint()
467                    .exact()
468                    .unwrap_or_else(|| resp.body().size_hint().lower());
469                metrics.add_bytes_sent(resp_bytes);
470            }
471
472            metrics.record_request(&method, elapsed_ms, is_error);
473            metrics.record_rtt(elapsed_ms);
474            // _guard drops here, decrementing active_requests.
475            result
476        })
477    }
478}
479
480// ─── Tests ────────────────────────────────────────────────────────────────────
481
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;
    use tower_service::Service as _;

    // ── Simple inner service ──────────────────────────────────────────────────

    /// Inner service that always succeeds with an empty response body.
    #[derive(Clone)]
    struct OkService;

    impl Service<http::Request<String>> for OkService {
        type Response = http::Response<String>;
        type Error = Infallible;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: http::Request<String>) -> Self::Future {
            Box::pin(async { Ok(http::Response::new(String::new())) })
        }
    }

    /// Build an empty-bodied request whose URI is `path`.
    fn make_req(path: &str) -> http::Request<String> {
        http::Request::builder()
            .uri(path)
            .body(String::new())
            .expect("request builder should not fail")
    }

    // ─────────────────────────────────────────────────────────────────────────

    #[tokio::test]
    async fn test_metrics_counter_increments() {
        let metrics = NetMetrics::new();
        let layer = MetricsLayer::new(Arc::clone(&metrics));
        let mut svc = layer.layer(OkService);

        for _ in 0..3 {
            svc.call(make_req("/amaters.AqlService/ExecuteQuery"))
                .await
                .expect("service call should not error");
        }

        assert_eq!(
            metrics.total_requests.load(Ordering::Relaxed),
            3,
            "total_requests should be 3 after 3 calls"
        );
    }

    // Nothing is awaited here, so a plain synchronous test is enough.
    #[test]
    fn test_metrics_latency_histogram_records() {
        let metrics = NetMetrics::new();

        // Directly exercise record_request with a known 10 ms duration.
        metrics.record_request("/test/Method", 10, false);

        let map = metrics
            .methods
            .lock()
            .expect("mutex should not be poisoned");
        let m = map
            .get("/test/Method")
            .expect("method entry should exist after recording");

        // Bucket index 2 = le=10ms; 10ms should fall exactly on the boundary.
        assert_eq!(
            m.bucket(2),
            1,
            "le=10ms bucket should be 1 for a 10ms observation"
        );
        // The +Inf bucket (index 7) must always be 1.
        assert_eq!(m.bucket(7), 1, "+Inf bucket should be 1");
        // Bucket index 0 = le=1ms should be 0.
        assert_eq!(
            m.bucket(0),
            0,
            "le=1ms bucket should be 0 for a 10ms observation"
        );
    }

    #[test]
    fn test_metrics_prometheus_text_format() {
        let metrics = NetMetrics::new();
        metrics.record_request("/amaters.AqlService/ExecuteQuery", 5, false);
        metrics.record_request("/amaters.AqlService/ExecuteQuery", 50, false);
        metrics.record_request("/amaters.AqlService/ExecuteQuery", 200, true);

        let prom = metrics.to_prometheus();

        assert!(
            prom.contains("amaters_net_requests_total"),
            "output must contain amaters_net_requests_total"
        );
        assert!(
            prom.contains("amaters_net_errors_total"),
            "output must contain amaters_net_errors_total"
        );
        assert!(
            prom.contains("amaters_net_method_requests_total"),
            "output must contain per-method counter"
        );
        // Validate a specific counter value.
        assert!(
            prom.contains("amaters_net_requests_total 3"),
            "total requests should be 3"
        );
        assert!(
            prom.contains("amaters_net_errors_total 1"),
            "total errors should be 1"
        );
    }

    /// Verify that `MetricsService` correctly wraps a service via the layer.
    #[tokio::test]
    async fn test_metrics_layer_wraps_service() {
        let metrics = NetMetrics::new();
        let layer = MetricsLayer::new(Arc::clone(&metrics));
        let mut svc = layer.layer(OkService);

        svc.call(make_req("/pkg.Svc/Method"))
            .await
            .expect("should succeed");

        let prom = metrics.to_prometheus();
        assert!(
            prom.contains("/pkg.Svc/Method"),
            "method should appear in Prometheus output"
        );
    }

    /// Verify that latency bucket boundaries are correct.
    #[test]
    fn test_latency_bucket_boundaries() {
        let m = MethodMetrics::new();

        // A 1ms observation should land in bucket 0 (le=1) and above.
        m.record(1, false);
        assert_eq!(m.bucket(0), 1, "le=1 should catch 1ms");
        assert_eq!(m.bucket(7), 1, "+Inf must always count");

        // A 0ms observation should land in all buckets.
        m.record(0, false);
        for i in 0..8 {
            assert_eq!(
                m.bucket(i),
                2,
                "all buckets should be 2 after recording 0ms and 1ms (bucket={i})"
            );
        }
    }

    /// Verify that error tracking works correctly.
    #[test]
    fn test_metrics_error_counting() {
        let m = MethodMetrics::new();
        m.record(10, true);
        m.record(20, false);
        m.record(30, true);
        assert_eq!(m.requests(), 3);
        assert_eq!(m.errors(), 2);
    }

    // ── New metric tests (Item 3) ─────────────────────────────────────────────

    /// Active-requests gauge increments when a request enters the service.
    #[test]
    fn test_active_requests_gauge_increments_during_request() {
        let metrics = NetMetrics::new();
        // Enter a request manually to inspect the gauge.
        let guard = metrics.enter_request();
        assert_eq!(
            metrics.active_requests.load(Ordering::Relaxed),
            1,
            "active_requests should be 1 after entering"
        );
        drop(guard);
    }

    /// Active-requests gauge decrements when the request completes.
    #[test]
    fn test_active_requests_gauge_decrements_on_completion() {
        let metrics = NetMetrics::new();
        {
            let _guard = metrics.enter_request();
            assert_eq!(metrics.active_requests.load(Ordering::Relaxed), 1);
        }
        // Guard dropped.
        assert_eq!(
            metrics.active_requests.load(Ordering::Relaxed),
            0,
            "active_requests should be 0 after guard is dropped"
        );
    }

    /// bytes_sent_total counter accumulates correctly.
    #[test]
    fn test_bytes_sent_counter_records() {
        let metrics = NetMetrics::new();
        metrics.add_bytes_sent(100);
        metrics.add_bytes_sent(200);
        assert_eq!(
            metrics.bytes_sent_total.load(Ordering::Relaxed),
            300,
            "bytes_sent_total should be 300"
        );
    }

    /// bytes_received_total counter accumulates correctly.
    #[test]
    fn test_bytes_received_counter_records() {
        let metrics = NetMetrics::new();
        metrics.add_bytes_received(512);
        metrics.add_bytes_received(512);
        assert_eq!(
            metrics.bytes_received_total.load(Ordering::Relaxed),
            1024,
            "bytes_received_total should be 1024"
        );
    }

    /// RTT histogram records observations into the right buckets.
    #[test]
    fn test_rtt_histogram_records() {
        let metrics = NetMetrics::new();
        // 5 ms → lands in buckets for le=5, le=10, ..., le=1000, +Inf
        metrics.record_rtt(5);

        // Bucket 0 = le=1ms; 5ms should NOT be in it.
        assert_eq!(
            metrics.rtt_histogram[0].load(Ordering::Relaxed),
            0,
            "le=1 bucket should be 0 for 5ms observation"
        );
        // Bucket 1 = le=5ms; 5ms should land here.
        assert_eq!(
            metrics.rtt_histogram[1].load(Ordering::Relaxed),
            1,
            "le=5 bucket should be 1 for 5ms observation"
        );
        // +Inf (index 7) always increments.
        assert_eq!(
            metrics.rtt_histogram[7].load(Ordering::Relaxed),
            1,
            "+Inf bucket should be 1"
        );
    }

    /// Prometheus output includes all four new metric families.
    #[test]
    fn test_prometheus_output_includes_new_metrics() {
        let metrics = NetMetrics::new();
        metrics.add_bytes_sent(42);
        metrics.add_bytes_received(24);
        metrics.record_rtt(10);
        // Bind the guard to a name: `let _ = …` would drop it on the spot and
        // the gauge would already be back at 0 when the snapshot is taken.
        let _guard = metrics.enter_request();

        let prom = metrics.to_prometheus();

        assert!(
            prom.contains("amaters_net_active_requests 1"),
            "active_requests should be 1 while the guard is alive"
        );
        assert!(
            prom.contains("amaters_net_bytes_sent_total"),
            "output must contain bytes_sent_total"
        );
        assert!(
            prom.contains("amaters_net_bytes_received_total"),
            "output must contain bytes_received_total"
        );
        assert!(
            prom.contains("amaters_net_rtt_bucket"),
            "output must contain rtt_bucket"
        );
        assert!(
            prom.contains("amaters_net_bytes_sent_total 42"),
            "bytes_sent_total should be 42"
        );
        assert!(
            prom.contains("amaters_net_bytes_received_total 24"),
            "bytes_received_total should be 24"
        );
    }

    /// The guard decrements the gauge even when the owning scope is left by a
    /// panic rather than a normal return.
    #[test]
    fn test_active_requests_exception_safe() {
        let metrics = NetMetrics::new();
        let mut seen_while_active = 0;

        // The observed value is carried out of the closure and checked
        // afterwards, so a wrong value cannot hide behind the expected panic.
        let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let _guard = metrics.enter_request();
            seen_while_active = metrics.active_requests.load(Ordering::Relaxed);
            panic!("simulated handler panic");
        }));

        assert!(outcome.is_err(), "the closure is expected to panic");
        assert_eq!(
            seen_while_active, 1,
            "active_requests should be 1 while the guard is alive"
        );
        assert_eq!(
            metrics.active_requests.load(Ordering::Relaxed),
            0,
            "active_requests must be 0 after drop, even on early exit"
        );
    }

    // ── Prometheus HTTP endpoint tests ────────────────────────────────────────

    /// Serve the metrics router on an ephemeral port and verify `GET /metrics`
    /// returns HTTP 200 with `Content-Type: text/plain`.
    ///
    /// This test uses a raw `TcpStream` so no extra HTTP client dependency is
    /// needed.
    #[tokio::test]
    async fn test_prometheus_endpoint_returns_200() {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        let metrics = NetMetrics::new();
        // Port 0 → OS assigns an ephemeral port.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("should bind to ephemeral port");
        let addr = listener
            .local_addr()
            .expect("should have local addr after bind");

        // Hand the already-bound listener to a custom task so we control the addr.
        let app = axum::Router::new()
            .route("/metrics", axum::routing::get(metrics_handler))
            .with_state(Arc::clone(&metrics));
        let server = tokio::spawn(async move {
            if let Err(e) = axum::serve(listener, app).await {
                tracing::warn!("test metrics server error: {}", e);
            }
        });

        // No sleep needed: the socket is already listening, so the connection
        // is queued by the OS until the server task accepts it.
        let mut stream = tokio::net::TcpStream::connect(addr)
            .await
            .expect("should connect to metrics server");

        stream
            .write_all(b"GET /metrics HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
            .await
            .expect("should write request");

        let mut response = Vec::new();
        stream
            .read_to_end(&mut response)
            .await
            .expect("should read response");

        let response_str = String::from_utf8_lossy(&response);
        assert!(
            response_str.starts_with("HTTP/1.1 200"),
            "expected HTTP 200, got: {}",
            &response_str[..response_str.find('\r').unwrap_or(response_str.len())]
        );
        assert!(
            response_str.contains("text/plain"),
            "expected text/plain Content-Type"
        );

        server.abort();
    }

    /// Unit test (no network): verify that `to_prometheus()` output contains
    /// the mandatory metric families expected by Prometheus scrapers.
    #[test]
    fn test_prometheus_metrics_format_contains_required_families() {
        let metrics = NetMetrics::new();
        metrics.record_request("/amaters.AqlService/Query", 10, false);
        metrics.add_bytes_sent(1024);
        // Named binding keeps the guard alive, so the gauge really is 1 below.
        let _guard = metrics.enter_request();

        let prom = metrics.to_prometheus();

        assert!(
            prom.contains("amaters_net_requests_total"),
            "must contain amaters_net_requests_total"
        );
        assert!(
            prom.contains("amaters_net_active_requests 1"),
            "must report 1 active request while the guard is alive"
        );
        assert!(
            prom.contains("amaters_net_requests_total 1"),
            "must report exactly 1 request after one recording"
        );
    }
}
869            prom.contains("amaters_net_requests_total 1"),
870            "must report exactly 1 request after one recording"
871        );
872    }
873}