use std::collections::HashMap;
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::time::Instant;
use tower_layer::Layer;
use tower_service::Service;
/// Inclusive upper bounds, in milliseconds, of the finite histogram buckets.
///
/// Every histogram in this file has one counter per bound plus one trailing
/// `+Inf` counter, so histogram arrays have `LATENCY_BUCKETS_MS.len() + 1` slots.
const LATENCY_BUCKETS_MS: [u64; 7] = [1, 5, 10, 50, 100, 500, 1_000];

/// Request, error and latency counters for a single method.
///
/// All fields are atomics updated with `Ordering::Relaxed`, so a shared
/// reference is enough to record observations.
#[derive(Debug)]
pub struct MethodMetrics {
    /// Number of observations recorded.
    requests_total: AtomicU64,
    /// Number of observations recorded with `is_error == true`.
    errors_total: AtomicU64,
    /// Cumulative bucket counters: slot `i` pairs with `LATENCY_BUCKETS_MS[i]`
    /// and the last slot is the `+Inf` bucket.
    latency_buckets: [AtomicU64; LATENCY_BUCKETS_MS.len() + 1],
}

impl MethodMetrics {
    /// Index of the `+Inf` slot; derived from the bounds so the two cannot drift apart.
    const INF_BUCKET: usize = LATENCY_BUCKETS_MS.len();

    /// Creates a metrics set with every counter at zero.
    fn new() -> Self {
        Self {
            requests_total: AtomicU64::new(0),
            errors_total: AtomicU64::new(0),
            latency_buckets: std::array::from_fn(|_| AtomicU64::new(0)),
        }
    }

    /// Records one observation.
    ///
    /// Increments the request counter, the error counter when `is_error` is
    /// set, every bucket whose bound is `>= duration_ms` (buckets are
    /// cumulative) and, unconditionally, the `+Inf` bucket.
    fn record(&self, duration_ms: u64, is_error: bool) {
        self.requests_total.fetch_add(1, Ordering::Relaxed);
        if is_error {
            self.errors_total.fetch_add(1, Ordering::Relaxed);
        }
        // `zip` stops at the shorter side (the bounds), so the `+Inf` slot is
        // only ever touched by the explicit increment below.
        for (slot, &bound) in self.latency_buckets.iter().zip(LATENCY_BUCKETS_MS.iter()) {
            if duration_ms <= bound {
                slot.fetch_add(1, Ordering::Relaxed);
            }
        }
        self.latency_buckets[Self::INF_BUCKET].fetch_add(1, Ordering::Relaxed);
    }

    /// Returns the number of observations recorded so far.
    fn requests(&self) -> u64 {
        self.requests_total.load(Ordering::Relaxed)
    }

    /// Returns the number of observations recorded as errors.
    fn errors(&self) -> u64 {
        self.errors_total.load(Ordering::Relaxed)
    }

    /// Returns the cumulative count stored in bucket slot `idx`.
    ///
    /// # Panics
    ///
    /// Panics if `idx` is greater than `LATENCY_BUCKETS_MS.len()`.
    fn bucket(&self, idx: usize) -> u64 {
        self.latency_buckets[idx].load(Ordering::Relaxed)
    }
}
/// Guard that decrements the wrapped counter by one when it is dropped.
///
/// Creating the guard does not touch the counter; whoever constructs it is
/// expected to have incremented the counter beforehand. The guard must be kept
/// alive for as long as the request is in flight: an unbound guard is dropped
/// at the end of the statement and undoes the increment straight away.
#[must_use = "dropping the guard immediately decrements the active-request counter"]
#[derive(Debug)]
pub struct ActiveRequestGuard<'a>(&'a AtomicU64);

impl Drop for ActiveRequestGuard<'_> {
    /// Subtracts one from the wrapped counter.
    fn drop(&mut self) {
        self.0.fetch_sub(1, Ordering::Relaxed);
    }
}
pub struct NetMetrics {
methods: Mutex<HashMap<String, Arc<MethodMetrics>>>,
total_requests: AtomicU64,
total_errors: AtomicU64,
pub active_requests: AtomicU64,
pub bytes_sent_total: AtomicU64,
pub bytes_received_total: AtomicU64,
pub rtt_histogram: [AtomicU64; 8],
}
impl NetMetrics {
pub fn new() -> Arc<Self> {
Arc::new(Self {
methods: Mutex::new(HashMap::new()),
total_requests: AtomicU64::new(0),
total_errors: AtomicU64::new(0),
active_requests: AtomicU64::new(0),
bytes_sent_total: AtomicU64::new(0),
bytes_received_total: AtomicU64::new(0),
rtt_histogram: [
AtomicU64::new(0),
AtomicU64::new(0),
AtomicU64::new(0),
AtomicU64::new(0),
AtomicU64::new(0),
AtomicU64::new(0),
AtomicU64::new(0),
AtomicU64::new(0),
],
})
}
pub fn enter_request(&self) -> ActiveRequestGuard<'_> {
self.active_requests.fetch_add(1, Ordering::Relaxed);
ActiveRequestGuard(&self.active_requests)
}
pub fn add_bytes_received(&self, bytes: u64) {
self.bytes_received_total
.fetch_add(bytes, Ordering::Relaxed);
}
pub fn add_bytes_sent(&self, bytes: u64) {
self.bytes_sent_total.fetch_add(bytes, Ordering::Relaxed);
}
pub fn record_rtt(&self, rtt_ms: u64) {
for (idx, &bound) in LATENCY_BUCKETS_MS.iter().enumerate() {
if rtt_ms <= bound {
self.rtt_histogram[idx].fetch_add(1, Ordering::Relaxed);
}
}
self.rtt_histogram[7].fetch_add(1, Ordering::Relaxed);
}
pub fn record_request(&self, method: &str, duration_ms: u64, is_error: bool) {
self.total_requests.fetch_add(1, Ordering::Relaxed);
if is_error {
self.total_errors.fetch_add(1, Ordering::Relaxed);
}
let method_metrics = {
let mut map = self.methods.lock().unwrap_or_else(|p| p.into_inner());
Arc::clone(
map.entry(method.to_owned())
.or_insert_with(|| Arc::new(MethodMetrics::new())),
)
};
method_metrics.record(duration_ms, is_error);
}
pub fn to_prometheus(&self) -> String {
let mut out = String::with_capacity(8192);
let total_req = self.total_requests.load(Ordering::Relaxed);
let total_err = self.total_errors.load(Ordering::Relaxed);
out.push_str("# HELP amaters_net_requests_total Total gRPC requests\n");
out.push_str("# TYPE amaters_net_requests_total counter\n");
out.push_str(&format!("amaters_net_requests_total {total_req}\n"));
out.push_str("# HELP amaters_net_errors_total Total gRPC errors\n");
out.push_str("# TYPE amaters_net_errors_total counter\n");
out.push_str(&format!("amaters_net_errors_total {total_err}\n"));
let active = self.active_requests.load(Ordering::Relaxed);
out.push_str("# HELP amaters_net_active_requests Currently active requests\n");
out.push_str("# TYPE amaters_net_active_requests gauge\n");
out.push_str(&format!("amaters_net_active_requests {active}\n"));
let bytes_sent = self.bytes_sent_total.load(Ordering::Relaxed);
out.push_str("# HELP amaters_net_bytes_sent_total Total bytes sent\n");
out.push_str("# TYPE amaters_net_bytes_sent_total counter\n");
out.push_str(&format!("amaters_net_bytes_sent_total {bytes_sent}\n"));
let bytes_recv = self.bytes_received_total.load(Ordering::Relaxed);
out.push_str("# HELP amaters_net_bytes_received_total Total bytes received\n");
out.push_str("# TYPE amaters_net_bytes_received_total counter\n");
out.push_str(&format!("amaters_net_bytes_received_total {bytes_recv}\n"));
out.push_str("# HELP amaters_net_rtt_bucket RTT histogram\n");
out.push_str("# TYPE amaters_net_rtt_bucket histogram\n");
for (idx, &bound) in LATENCY_BUCKETS_MS.iter().enumerate() {
let count = self.rtt_histogram[idx].load(Ordering::Relaxed);
out.push_str(&format!(
"amaters_net_rtt_bucket{{le=\"{bound}\"}} {count}\n"
));
}
let inf_count = self.rtt_histogram[7].load(Ordering::Relaxed);
out.push_str(&format!(
"amaters_net_rtt_bucket{{le=\"+Inf\"}} {inf_count}\n"
));
let map = self.methods.lock().unwrap_or_else(|p| p.into_inner());
let mut methods: Vec<(&String, &Arc<MethodMetrics>)> = map.iter().collect();
methods.sort_by_key(|(k, _)| k.as_str());
for (method, m) in &methods {
let label = format!("{{method=\"{method}\"}}");
out.push_str(&format!(
"amaters_net_method_requests_total{label} {}\n",
m.requests()
));
out.push_str(&format!(
"amaters_net_method_errors_total{label} {}\n",
m.errors()
));
out.push_str(&format!(
"# HELP amaters_net_request_duration_ms{label} Request latency histogram\n"
));
out.push_str("# TYPE amaters_net_request_duration_ms histogram\n");
for (idx, &bound) in LATENCY_BUCKETS_MS.iter().enumerate() {
out.push_str(&format!(
"amaters_net_request_duration_ms_bucket{{method=\"{method}\",le=\"{bound}\"}} {}\n",
m.bucket(idx)
));
}
out.push_str(&format!(
"amaters_net_request_duration_ms_bucket{{method=\"{method}\",le=\"+Inf\"}} {}\n",
m.bucket(7)
));
}
out
}
}
/// Axum handler for the metrics route.
///
/// Answers `200 OK` with the registry rendered by `NetMetrics::to_prometheus`
/// and the Prometheus text-format content type.
async fn metrics_handler(
    axum::extract::State(metrics): axum::extract::State<Arc<NetMetrics>>,
) -> (
    axum::http::StatusCode,
    [(axum::http::HeaderName, &'static str); 1],
    String,
) {
    let content_type = (
        axum::http::header::CONTENT_TYPE,
        "text/plain; version=0.0.4; charset=utf-8",
    );
    (
        axum::http::StatusCode::OK,
        [content_type],
        metrics.to_prometheus(),
    )
}
/// Spawns a Tokio task that serves `GET /metrics` on `addr`.
///
/// Binding happens inside the task. A bind failure is logged at error level
/// and the task ends; an error returned by the server is logged at warn
/// level. The returned handle resolves when the task finishes.
pub fn spawn_metrics_server(
    addr: SocketAddr,
    metrics: Arc<NetMetrics>,
) -> tokio::task::JoinHandle<()> {
    let router = axum::Router::new()
        .route("/metrics", axum::routing::get(metrics_handler))
        .with_state(metrics);

    tokio::spawn(async move {
        let listener = match tokio::net::TcpListener::bind(addr).await {
            Ok(bound) => bound,
            Err(e) => {
                tracing::error!("Failed to bind metrics server to {}: {}", addr, e);
                return;
            }
        };
        tracing::info!("Metrics server listening on {}", addr);
        if let Err(e) = axum::serve(listener, router).await {
            tracing::warn!("Metrics server error: {}", e);
        }
    })
}
/// Tower layer that wraps an inner service in a [`MetricsService`] sharing
/// one [`NetMetrics`] registry.
#[derive(Clone)]
pub struct MetricsLayer {
    /// Registry handed to every service this layer produces.
    metrics: Arc<NetMetrics>,
}

impl MetricsLayer {
    /// Creates a layer that reports into `metrics`.
    pub fn new(metrics: Arc<NetMetrics>) -> Self {
        MetricsLayer { metrics }
    }
}
impl<S> Layer<S> for MetricsLayer {
    type Service = MetricsService<S>;

    /// Wraps `inner` in a [`MetricsService`] holding another handle to this
    /// layer's registry.
    fn layer(&self, inner: S) -> Self::Service {
        let metrics = Arc::clone(&self.metrics);
        MetricsService { inner, metrics }
    }
}
/// Tower service that records request metrics around an inner service.
///
/// For every call it tracks the active-request gauge, request and response
/// body size hints, the elapsed time and whether the call failed, keyed by
/// the request's URI path.
#[derive(Clone)]
pub struct MetricsService<S> {
    /// The wrapped service.
    inner: S,
    /// Registry that receives the observations.
    metrics: Arc<NetMetrics>,
}

impl<S, ReqBody, ResBody> Service<http::Request<ReqBody>> for MetricsService<S>
where
    S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Send + 'static,
    ReqBody: http_body::Body + Send + 'static,
    ResBody: http_body::Body + Send + 'static,
{
    type Response = http::Response<ResBody>;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Delegates readiness to the inner service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Forwards `req` to the inner service and records metrics once it resolves.
    ///
    /// A call counts as an error when the inner service returns `Err`, when
    /// the response has a 4xx/5xx HTTP status, or when the response headers
    /// carry a `grpc-status` other than `0`. A status delivered only in
    /// trailers after a streamed body is not visible here and is not counted.
    fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
        let method = req.uri().path().to_owned();
        let req_bytes = body_len_hint(req.body());

        // Move the instance that `poll_ready` drove to readiness into the
        // future and leave a fresh clone behind for the next caller.
        let fresh = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, fresh);

        let metrics = Arc::clone(&self.metrics);
        let start = Instant::now();

        Box::pin(async move {
            metrics.add_bytes_received(req_bytes);
            // Held until the future completes or is dropped, so the gauge is
            // decremented on cancellation as well.
            let _guard = metrics.enter_request();

            let result = inner.call(req).await;

            // Saturate instead of truncating the u128 millisecond count.
            let elapsed_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
            let is_error = match &result {
                Ok(resp) => {
                    metrics.add_bytes_sent(body_len_hint(resp.body()));
                    response_is_error(resp)
                }
                Err(_) => true,
            };
            metrics.record_request(&method, elapsed_ms, is_error);
            metrics.record_rtt(elapsed_ms);
            result
        })
    }
}

/// Returns the body's exact size when the hint provides one, otherwise the
/// hint's lower bound.
fn body_len_hint<B: http_body::Body>(body: &B) -> u64 {
    let hint = body.size_hint();
    hint.exact().unwrap_or_else(|| hint.lower())
}

/// Reports whether a response that was produced successfully still
/// represents a failed call: a 4xx/5xx HTTP status, or a `grpc-status`
/// header whose value is not `0`.
fn response_is_error<B>(resp: &http::Response<B>) -> bool {
    let status = resp.status();
    if status.is_client_error() || status.is_server_error() {
        return true;
    }
    resp.headers()
        .get("grpc-status")
        .map_or(false, |value| value.as_bytes() != b"0")
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;
    use tower_service::Service as _;

    /// Inner service stub that always answers with an empty successful response.
    #[derive(Clone)]
    struct OkService;

    impl Service<http::Request<String>> for OkService {
        type Response = http::Response<String>;
        type Error = Infallible;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: http::Request<String>) -> Self::Future {
            Box::pin(async { Ok(http::Response::new(String::new())) })
        }
    }

    /// Builds a request with an empty body for the given URI path.
    fn make_req(path: &str) -> http::Request<String> {
        http::Request::builder()
            .uri(path)
            .body(String::new())
            .expect("request builder should not fail")
    }

    /// Reads the active-request gauge.
    fn active(metrics: &NetMetrics) -> u64 {
        metrics.active_requests.load(Ordering::Relaxed)
    }

    #[tokio::test]
    async fn test_metrics_counter_increments() {
        let metrics = NetMetrics::new();
        let mut svc = MetricsLayer::new(Arc::clone(&metrics)).layer(OkService);

        for _ in 0..3 {
            svc.call(make_req("/amaters.AqlService/ExecuteQuery"))
                .await
                .expect("service call should not error");
        }

        assert_eq!(
            metrics.total_requests.load(Ordering::Relaxed),
            3,
            "total_requests should be 3 after 3 calls"
        );
    }

    #[tokio::test]
    async fn test_metrics_latency_histogram_records() {
        let metrics = NetMetrics::new();
        metrics.record_request("/test/Method", 10, false);

        let map = metrics
            .methods
            .lock()
            .expect("mutex should not be poisoned");
        let m = map
            .get("/test/Method")
            .expect("method entry should exist after recording");

        assert_eq!(
            m.bucket(2),
            1,
            "le=10ms bucket should be 1 for a 10ms observation"
        );
        assert_eq!(m.bucket(7), 1, "+Inf bucket should be 1");
        assert_eq!(
            m.bucket(0),
            0,
            "le=1ms bucket should be 0 for a 10ms observation"
        );
    }

    #[tokio::test]
    async fn test_metrics_prometheus_text_format() {
        let metrics = NetMetrics::new();
        for (duration_ms, is_error) in [(5, false), (50, false), (200, true)] {
            metrics.record_request("/amaters.AqlService/ExecuteQuery", duration_ms, is_error);
        }

        let prom = metrics.to_prometheus();

        assert!(
            prom.contains("amaters_net_requests_total"),
            "output must contain amaters_net_requests_total"
        );
        assert!(
            prom.contains("amaters_net_errors_total"),
            "output must contain amaters_net_errors_total"
        );
        assert!(
            prom.contains("amaters_net_method_requests_total"),
            "output must contain per-method counter"
        );
        assert!(
            prom.contains("amaters_net_requests_total 3"),
            "total requests should be 3"
        );
        assert!(
            prom.contains("amaters_net_errors_total 1"),
            "total errors should be 1"
        );
    }

    #[tokio::test]
    async fn test_metrics_layer_wraps_service() {
        let metrics = NetMetrics::new();
        let mut svc = MetricsLayer::new(Arc::clone(&metrics)).layer(OkService);

        svc.call(make_req("/pkg.Svc/Method"))
            .await
            .expect("should succeed");

        let prom = metrics.to_prometheus();
        assert!(
            prom.contains("/pkg.Svc/Method"),
            "method should appear in Prometheus output"
        );
    }

    #[test]
    fn test_latency_bucket_boundaries() {
        let m = MethodMetrics::new();

        m.record(1, false);
        assert_eq!(m.bucket(0), 1, "le=1 should catch 1ms");
        assert_eq!(m.bucket(7), 1, "+Inf must always count");

        // 0ms and 1ms both fall into every cumulative bucket.
        m.record(0, false);
        for i in 0..8 {
            assert_eq!(
                m.bucket(i),
                2u64,
                "all buckets should be 2 after recording 0ms and 1ms (bucket={i})"
            );
        }
    }

    #[test]
    fn test_metrics_error_counting() {
        let m = MethodMetrics::new();
        m.record(10, true);
        m.record(20, false);
        m.record(30, true);

        assert_eq!(m.requests(), 3);
        assert_eq!(m.errors(), 2);
    }

    #[tokio::test]
    async fn test_active_requests_gauge_increments_during_request() {
        let metrics = NetMetrics::new();
        let guard = metrics.enter_request();
        assert_eq!(
            active(&metrics),
            1,
            "active_requests should be 1 after entering"
        );
        drop(guard);
    }

    #[tokio::test]
    async fn test_active_requests_gauge_decrements_on_completion() {
        let metrics = NetMetrics::new();
        {
            let _guard = metrics.enter_request();
            assert_eq!(active(&metrics), 1);
        }
        assert_eq!(
            active(&metrics),
            0,
            "active_requests should be 0 after guard is dropped"
        );
    }

    #[test]
    fn test_bytes_sent_counter_records() {
        let metrics = NetMetrics::new();
        metrics.add_bytes_sent(100);
        metrics.add_bytes_sent(200);
        assert_eq!(
            metrics.bytes_sent_total.load(Ordering::Relaxed),
            300,
            "bytes_sent_total should be 300"
        );
    }

    #[test]
    fn test_bytes_received_counter_records() {
        let metrics = NetMetrics::new();
        metrics.add_bytes_received(512);
        metrics.add_bytes_received(512);
        assert_eq!(
            metrics.bytes_received_total.load(Ordering::Relaxed),
            1024,
            "bytes_received_total should be 1024"
        );
    }

    #[test]
    fn test_rtt_histogram_records() {
        let metrics = NetMetrics::new();
        metrics.record_rtt(5);

        let slot = |idx: usize| metrics.rtt_histogram[idx].load(Ordering::Relaxed);
        assert_eq!(slot(0), 0, "le=1 bucket should be 0 for 5ms observation");
        assert_eq!(slot(1), 1, "le=5 bucket should be 1 for 5ms observation");
        assert_eq!(slot(7), 1, "+Inf bucket should be 1");
    }

    #[test]
    fn test_prometheus_output_includes_new_metrics() {
        let metrics = NetMetrics::new();
        metrics.add_bytes_sent(42);
        metrics.add_bytes_received(24);
        metrics.record_rtt(10);
        // Bind the guard to a name: `let _ = ...` would drop it at once and
        // the gauge would already be back at 0 when the text is rendered.
        let _guard = metrics.enter_request();

        let prom = metrics.to_prometheus();

        assert!(
            prom.contains("amaters_net_active_requests"),
            "output must contain active_requests"
        );
        assert!(
            prom.contains("amaters_net_active_requests 1"),
            "active_requests should be 1 while the guard is alive"
        );
        assert!(
            prom.contains("amaters_net_bytes_sent_total"),
            "output must contain bytes_sent_total"
        );
        assert!(
            prom.contains("amaters_net_bytes_received_total"),
            "output must contain bytes_received_total"
        );
        assert!(
            prom.contains("amaters_net_rtt_bucket"),
            "output must contain rtt_bucket"
        );
        assert!(
            prom.contains("amaters_net_bytes_sent_total 42"),
            "bytes_sent_total should be 42"
        );
        assert!(
            prom.contains("amaters_net_bytes_received_total 24"),
            "bytes_received_total should be 24"
        );
    }

    #[test]
    fn test_active_requests_exception_safe() {
        let metrics = NetMetrics::new();
        {
            let guard = metrics.enter_request();
            assert_eq!(active(&metrics), 1);
            drop(guard);
        }
        assert_eq!(
            active(&metrics),
            0,
            "active_requests must be 0 after drop, even on early exit"
        );
    }

    #[tokio::test]
    async fn test_prometheus_endpoint_returns_200() {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        let metrics = NetMetrics::new();
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("should bind to ephemeral port");
        let addr = listener
            .local_addr()
            .expect("should have local addr after bind");

        let app = axum::Router::new()
            .route("/metrics", axum::routing::get(metrics_handler))
            .with_state(Arc::clone(&metrics));
        let _handle = tokio::spawn(async move {
            if let Err(e) = axum::serve(listener, app).await {
                tracing::warn!("test metrics server error: {}", e);
            }
        });

        // No start-up delay is needed: the listener was bound before the
        // server task was spawned, so the connection is queued until accepted.
        let mut stream = tokio::net::TcpStream::connect(addr)
            .await
            .expect("should connect to metrics server");
        stream
            .write_all(b"GET /metrics HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
            .await
            .expect("should write request");

        let mut response = Vec::new();
        stream
            .read_to_end(&mut response)
            .await
            .expect("should read response");
        let response_str = String::from_utf8_lossy(&response);
        let status_line = response_str.split('\r').next().unwrap_or("");

        assert!(
            response_str.starts_with("HTTP/1.1 200"),
            "expected HTTP 200, got: {}",
            status_line
        );
        assert!(
            response_str.contains("text/plain"),
            "expected text/plain Content-Type"
        );
    }

    #[test]
    fn test_prometheus_metrics_format_contains_required_families() {
        let metrics = NetMetrics::new();
        metrics.record_request("/amaters.AqlService/Query", 10, false);
        metrics.add_bytes_sent(1024);
        // Keep the guard alive so the exported gauge reflects an in-flight request.
        let _guard = metrics.enter_request();

        let prom = metrics.to_prometheus();

        assert!(
            prom.contains("amaters_net_requests_total"),
            "must contain amaters_net_requests_total"
        );
        assert!(
            prom.contains("amaters_net_active_requests"),
            "must contain amaters_net_active_requests"
        );
        assert!(
            prom.contains("amaters_net_active_requests 1"),
            "must report 1 active request while the guard is alive"
        );
        assert!(
            prom.contains("amaters_net_requests_total 1"),
            "must report exactly 1 request after one recording"
        );
    }
}