use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::time::Instant;
use tower_layer::Layer;
use tower_service::Service;
/// Inclusive upper bounds, in milliseconds, of the finite latency histogram buckets.
///
/// Must stay sorted ascending: Prometheus requires `le` buckets in increasing order.
const LATENCY_BUCKETS_MS: [u64; 7] = [1, 5, 10, 50, 100, 500, 1_000];

/// Index of the `+Inf` slot in a method's bucket array; it follows the finite buckets.
const INF_BUCKET: usize = LATENCY_BUCKETS_MS.len();

/// Escapes a Prometheus label value (`\` -> `\\`, `"` -> `\"`, newline -> `\n`).
///
/// Method names can come from request paths, so they must be escaped before being spliced
/// into the exposition text.
fn escape_label_value(raw: &str) -> String {
    let mut escaped = String::with_capacity(raw.len());
    for ch in raw.chars() {
        match ch {
            '\\' => escaped.push_str("\\\\"),
            '"' => escaped.push_str("\\\""),
            '\n' => escaped.push_str("\\n"),
            other => escaped.push(other),
        }
    }
    escaped
}

/// Lock-free request counters and a cumulative latency histogram for one method.
///
/// All counters use `Ordering::Relaxed` and are updated independently, so a concurrent
/// reader may briefly see a request counted in one field but not yet in another.
pub struct MethodMetrics {
    requests_total: AtomicU64,
    errors_total: AtomicU64,
    /// Sum of all observed latencies in milliseconds (the histogram's `_sum` series).
    /// Wraps on overflow, like every `AtomicU64::fetch_add`.
    latency_sum_ms: AtomicU64,
    /// Cumulative bucket counts: slot `i < INF_BUCKET` counts observations
    /// `<= LATENCY_BUCKETS_MS[i]`; slot `INF_BUCKET` counts every observation.
    latency_buckets: [AtomicU64; INF_BUCKET + 1],
}

impl MethodMetrics {
    /// Creates a set of counters that all start at zero.
    fn new() -> Self {
        Self {
            requests_total: AtomicU64::new(0),
            errors_total: AtomicU64::new(0),
            latency_sum_ms: AtomicU64::new(0),
            latency_buckets: Default::default(),
        }
    }

    /// Records one finished request that took `duration_ms` milliseconds.
    fn record(&self, duration_ms: u64, is_error: bool) {
        self.requests_total.fetch_add(1, Ordering::Relaxed);
        if is_error {
            self.errors_total.fetch_add(1, Ordering::Relaxed);
        }
        self.latency_sum_ms.fetch_add(duration_ms, Ordering::Relaxed);
        // Prometheus buckets are cumulative: an observation increments every bucket whose
        // bound it does not exceed, and always the `+Inf` bucket.
        for (slot, &bound) in self.latency_buckets.iter().zip(LATENCY_BUCKETS_MS.iter()) {
            if duration_ms <= bound {
                slot.fetch_add(1, Ordering::Relaxed);
            }
        }
        self.latency_buckets[INF_BUCKET].fetch_add(1, Ordering::Relaxed);
    }

    /// Number of requests recorded so far.
    fn requests(&self) -> u64 {
        self.requests_total.load(Ordering::Relaxed)
    }

    /// Number of recorded requests that were flagged as errors.
    fn errors(&self) -> u64 {
        self.errors_total.load(Ordering::Relaxed)
    }

    /// Sum of all recorded latencies, in milliseconds.
    fn sum_ms(&self) -> u64 {
        self.latency_sum_ms.load(Ordering::Relaxed)
    }

    /// Cumulative count for bucket `idx` (`INF_BUCKET` is the `+Inf` bucket).
    ///
    /// # Panics
    /// Panics if `idx > INF_BUCKET`.
    fn bucket(&self, idx: usize) -> u64 {
        self.latency_buckets[idx].load(Ordering::Relaxed)
    }
}

/// Shared request metrics: global totals plus per-method counters and latency histograms.
///
/// Per-method entries are created on first use and never removed.
pub struct NetMetrics {
    /// Per-method metrics keyed by method name.
    /// NOTE(review): when keyed by raw request paths, a client sending arbitrary paths can grow
    /// this map without bound — consider capping or allow-listing method names.
    methods: Mutex<HashMap<String, Arc<MethodMetrics>>>,
    total_requests: AtomicU64,
    total_errors: AtomicU64,
}

impl NetMetrics {
    /// Creates an empty registry, already wrapped in an `Arc` for sharing.
    pub fn new() -> Arc<Self> {
        Arc::new(Self {
            methods: Mutex::new(HashMap::new()),
            total_requests: AtomicU64::new(0),
            total_errors: AtomicU64::new(0),
        })
    }

    /// Records one finished request for `method` in both the global and per-method counters.
    pub fn record_request(&self, method: &str, duration_ms: u64, is_error: bool) {
        self.total_requests.fetch_add(1, Ordering::Relaxed);
        if is_error {
            self.total_errors.fetch_add(1, Ordering::Relaxed);
        }
        self.method_metrics(method).record(duration_ms, is_error);
    }

    /// Returns the entry for `method`, creating it on first use.
    ///
    /// Looks up by `&str` first so the common (existing-entry) path does not allocate an
    /// owned key; the lock is released before the caller touches the atomics.
    fn method_metrics(&self, method: &str) -> Arc<MethodMetrics> {
        let mut map = self.methods.lock().unwrap_or_else(|p| p.into_inner());
        if let Some(existing) = map.get(method) {
            return Arc::clone(existing);
        }
        let created = Arc::new(MethodMetrics::new());
        map.insert(method.to_owned(), Arc::clone(&created));
        created
    }

    /// Renders all metrics in the Prometheus text exposition format.
    ///
    /// Each metric family gets exactly one `# HELP`/`# TYPE` pair followed by all of its
    /// samples. Methods are sorted by name within each family, and the histogram emits the
    /// `_bucket`, `_sum` and `_count` series.
    pub fn to_prometheus(&self) -> String {
        let mut out = String::with_capacity(4096);
        self.write_prometheus(&mut out)
            .expect("formatting into a String cannot fail");
        out
    }

    /// Writes the exposition text into `out`; see [`NetMetrics::to_prometheus`].
    fn write_prometheus(&self, out: &mut String) -> std::fmt::Result {
        use std::fmt::Write as _;

        let total_req = self.total_requests.load(Ordering::Relaxed);
        let total_err = self.total_errors.load(Ordering::Relaxed);
        writeln!(out, "# HELP amaters_net_requests_total Total gRPC requests")?;
        writeln!(out, "# TYPE amaters_net_requests_total counter")?;
        writeln!(out, "amaters_net_requests_total {total_req}")?;
        writeln!(out, "# HELP amaters_net_errors_total Total gRPC errors")?;
        writeln!(out, "# TYPE amaters_net_errors_total counter")?;
        writeln!(out, "amaters_net_errors_total {total_err}")?;

        let methods = self.sorted_method_snapshot();

        writeln!(
            out,
            "# HELP amaters_net_method_requests_total Total gRPC requests per method"
        )?;
        writeln!(out, "# TYPE amaters_net_method_requests_total counter")?;
        for (label, m) in &methods {
            writeln!(
                out,
                "amaters_net_method_requests_total{{method=\"{label}\"}} {}",
                m.requests()
            )?;
        }

        writeln!(
            out,
            "# HELP amaters_net_method_errors_total Total gRPC errors per method"
        )?;
        writeln!(out, "# TYPE amaters_net_method_errors_total counter")?;
        for (label, m) in &methods {
            writeln!(
                out,
                "amaters_net_method_errors_total{{method=\"{label}\"}} {}",
                m.errors()
            )?;
        }

        writeln!(
            out,
            "# HELP amaters_net_request_duration_ms Request latency histogram"
        )?;
        writeln!(out, "# TYPE amaters_net_request_duration_ms histogram")?;
        for (label, m) in &methods {
            for (idx, bound) in LATENCY_BUCKETS_MS.iter().enumerate() {
                writeln!(
                    out,
                    "amaters_net_request_duration_ms_bucket{{method=\"{label}\",le=\"{bound}\"}} {}",
                    m.bucket(idx)
                )?;
            }
            // Read the +Inf bucket once so `_count` is guaranteed to equal it, as the format requires.
            let count = m.bucket(INF_BUCKET);
            writeln!(
                out,
                "amaters_net_request_duration_ms_bucket{{method=\"{label}\",le=\"+Inf\"}} {count}"
            )?;
            writeln!(
                out,
                "amaters_net_request_duration_ms_sum{{method=\"{label}\"}} {}",
                m.sum_ms()
            )?;
            writeln!(
                out,
                "amaters_net_request_duration_ms_count{{method=\"{label}\"}} {count}"
            )?;
        }
        Ok(())
    }

    /// Returns `(escaped method label, metrics)` pairs sorted by raw method name.
    ///
    /// The registry lock is held only while cloning keys and `Arc`s, so rendering a scrape
    /// does not block `record_request`.
    fn sorted_method_snapshot(&self) -> Vec<(String, Arc<MethodMetrics>)> {
        let mut snapshot: Vec<(String, Arc<MethodMetrics>)> = {
            let map = self.methods.lock().unwrap_or_else(|p| p.into_inner());
            map.iter()
                .map(|(name, m)| (name.clone(), Arc::clone(m)))
                .collect()
        };
        snapshot.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
        snapshot
            .into_iter()
            .map(|(name, m)| (escape_label_value(&name), m))
            .collect()
    }
}
/// Tower [`Layer`] that wraps a service in a [`MetricsService`], so every request passing
/// through it is counted and timed in the shared [`NetMetrics`] registry.
#[derive(Clone)]
pub struct MetricsLayer {
    /// Registry shared by every service this layer produces.
    metrics: Arc<NetMetrics>,
}

impl MetricsLayer {
    /// Builds a layer that reports into `metrics`.
    pub fn new(metrics: Arc<NetMetrics>) -> Self {
        MetricsLayer { metrics }
    }
}

impl<S> Layer<S> for MetricsLayer {
    type Service = MetricsService<S>;

    /// Wraps `inner`. The returned service shares this layer's registry through an `Arc`
    /// clone, not a copy of the metrics.
    fn layer(&self, inner: S) -> Self::Service {
        let metrics = Arc::clone(&self.metrics);
        MetricsService { inner, metrics }
    }
}
/// Service produced by [`MetricsLayer`]: it forwards each request to `inner` and records the
/// request path, latency and outcome in the shared [`NetMetrics`].
#[derive(Clone)]
pub struct MetricsService<S> {
    inner: S,
    metrics: Arc<NetMetrics>,
}

impl<S, ReqBody, ResBody> Service<http::Request<ReqBody>> for MetricsService<S>
where
    S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Send + 'static,
    ReqBody: Send + 'static,
    ResBody: Send + 'static,
{
    type Response = http::Response<ResBody>;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Calls the inner service and, once its response future resolves, records the request.
    ///
    /// Latency is measured from `call` until the inner future resolves (response headers);
    /// time spent streaming the response body afterwards is not included.
    fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
        // The URI path is used as the per-method label (e.g. `/pkg.Service/Method`).
        let method = req.uri().path().to_owned();
        // `poll_ready` prepared `self.inner`; move that instance into the future and leave a
        // fresh clone behind for the next request.
        let fresh = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, fresh);
        let metrics = Arc::clone(&self.metrics);
        let start = Instant::now();
        Box::pin(async move {
            let result = inner.call(req).await;
            // Saturate instead of silently truncating the u128 millisecond count.
            let elapsed_ms = start.elapsed().as_millis().min(u128::from(u64::MAX)) as u64;
            let is_error = match &result {
                Err(_) => true,
                Ok(response) => response_is_error(response),
            };
            metrics.record_request(&method, elapsed_ms, is_error);
            result
        })
    }
}

/// Decides whether a response the inner service returned successfully is a failed request.
///
/// gRPC servers report failures with HTTP 200 plus a non-zero `grpc-status`. For
/// "Trailers-Only" responses (the usual shape of an immediate error) that status travels in
/// the response headers, so it is visible here. A `grpc-status` sent in trailers after a
/// streamed body is not seen by this layer. HTTP 4xx/5xx responses are also counted as errors.
fn response_is_error<B>(response: &http::Response<B>) -> bool {
    let status = response.status();
    if status.is_client_error() || status.is_server_error() {
        return true;
    }
    matches!(
        response.headers().get("grpc-status"),
        Some(code) if code.as_bytes() != b"0"
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;
    use tower_service::Service as _;

    /// Inner service that always succeeds with an empty body. Its `poll_ready` is always
    /// `Ready(Ok)`, so the tests call it without driving readiness first.
    #[derive(Clone)]
    struct OkService;

    impl Service<http::Request<String>> for OkService {
        type Response = http::Response<String>;
        type Error = Infallible;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: http::Request<String>) -> Self::Future {
            Box::pin(async { Ok(http::Response::new(String::new())) })
        }
    }

    /// Builds a request for `path` with an empty body.
    fn make_req(path: &str) -> http::Request<String> {
        http::Request::builder()
            .uri(path)
            .body(String::new())
            .expect("request builder should not fail")
    }

    #[tokio::test]
    async fn test_metrics_counter_increments() {
        let metrics = NetMetrics::new();
        let layer = MetricsLayer::new(Arc::clone(&metrics));
        let mut svc = layer.layer(OkService);
        for _ in 0..3 {
            svc.call(make_req("/amaters.AqlService/ExecuteQuery"))
                .await
                .expect("service call should not error");
        }
        assert_eq!(
            metrics.total_requests.load(Ordering::Relaxed),
            3,
            "total_requests should be 3 after 3 calls"
        );
    }

    // Purely synchronous: no async runtime needed.
    #[test]
    fn test_metrics_latency_histogram_records() {
        let metrics = NetMetrics::new();
        metrics.record_request("/test/Method", 10, false);
        let map = metrics
            .methods
            .lock()
            .expect("mutex should not be poisoned");
        let m = map
            .get("/test/Method")
            .expect("method entry should exist after recording");
        assert_eq!(
            m.bucket(2),
            1,
            "le=10ms bucket should be 1 for a 10ms observation"
        );
        assert_eq!(m.bucket(7), 1, "+Inf bucket should be 1");
        assert_eq!(
            m.bucket(0),
            0,
            "le=1ms bucket should be 0 for a 10ms observation"
        );
    }

    // Purely synchronous: no async runtime needed.
    #[test]
    fn test_metrics_prometheus_text_format() {
        let metrics = NetMetrics::new();
        metrics.record_request("/amaters.AqlService/ExecuteQuery", 5, false);
        metrics.record_request("/amaters.AqlService/ExecuteQuery", 50, false);
        metrics.record_request("/amaters.AqlService/ExecuteQuery", 200, true);
        let prom = metrics.to_prometheus();
        assert!(
            prom.contains("amaters_net_requests_total"),
            "output must contain amaters_net_requests_total"
        );
        assert!(
            prom.contains("amaters_net_errors_total"),
            "output must contain amaters_net_errors_total"
        );
        assert!(
            prom.contains("amaters_net_method_requests_total"),
            "output must contain per-method counter"
        );
        assert!(
            prom.contains("amaters_net_requests_total 3"),
            "total requests should be 3"
        );
        assert!(
            prom.contains("amaters_net_errors_total 1"),
            "total errors should be 1"
        );
        // Pin the per-method sample values, not just the presence of the metric names.
        assert!(
            prom.contains(
                "amaters_net_method_requests_total{method=\"/amaters.AqlService/ExecuteQuery\"} 3\n"
            ),
            "per-method request count should be 3"
        );
        assert!(
            prom.contains(
                "amaters_net_method_errors_total{method=\"/amaters.AqlService/ExecuteQuery\"} 1\n"
            ),
            "per-method error count should be 1"
        );
        assert!(
            prom.contains(
                "amaters_net_request_duration_ms_bucket{method=\"/amaters.AqlService/ExecuteQuery\",le=\"50\"} 2\n"
            ),
            "le=50 bucket should hold the 5ms and 50ms observations"
        );
        assert!(
            prom.contains(
                "amaters_net_request_duration_ms_bucket{method=\"/amaters.AqlService/ExecuteQuery\",le=\"+Inf\"} 3\n"
            ),
            "+Inf bucket should hold all three observations"
        );
    }

    #[tokio::test]
    async fn test_metrics_layer_wraps_service() {
        let metrics = NetMetrics::new();
        let layer = MetricsLayer::new(Arc::clone(&metrics));
        let mut svc = layer.layer(OkService);
        svc.call(make_req("/pkg.Svc/Method"))
            .await
            .expect("should succeed");
        let prom = metrics.to_prometheus();
        assert!(
            prom.contains("/pkg.Svc/Method"),
            "method should appear in Prometheus output"
        );
    }

    #[test]
    fn test_latency_bucket_boundaries() {
        let m = MethodMetrics::new();
        m.record(1, false);
        assert_eq!(m.bucket(0), 1, "le=1 should catch 1ms");
        assert_eq!(m.bucket(7), 1, "+Inf must always count");
        m.record(0, false);
        for i in 0..8 {
            assert_eq!(
                m.bucket(i),
                2,
                "all buckets should be 2 after recording 0ms and 1ms (bucket={i})"
            );
        }
    }

    #[test]
    fn test_metrics_error_counting() {
        let m = MethodMetrics::new();
        m.record(10, true);
        m.record(20, false);
        m.record(30, true);
        assert_eq!(m.requests(), 3);
        assert_eq!(m.errors(), 2);
    }
}