use once_cell::sync::Lazy;
use prometheus::{
register_histogram_vec, register_int_counter_vec, Encoder, HistogramVec, IntCounterVec,
TextEncoder,
};
use std::time::Instant;
/// Upper bounds (in seconds) of the latency histogram buckets used by both the GraphQL
/// and the gRPC backend duration histograms, spanning 1 ms to 10 s.
const LATENCY_BUCKETS: &[f64] = &[
    // millisecond range
    0.001, 0.005, 0.01, 0.025, 0.05,
    // sub-second range
    0.1, 0.25, 0.5,
    // multi-second range
    1.0, 2.5, 5.0, 10.0,
];
/// Process-wide metric set, built lazily by `GatewayMetrics::new` (which registers every
/// collector with the default Prometheus registry) the first time it is accessed.
/// Access it through `GatewayMetrics::global`.
static METRICS: Lazy<GatewayMetrics> = Lazy::new(GatewayMetrics::new);
/// Handles to every Prometheus collector the gateway records into.
///
/// The fields are public, but the `record_*` methods on this type are the intended
/// entry points because they supply labels in the order each collector was
/// registered with.
#[derive(Clone)]
pub struct GatewayMetrics {
    /// `graphql_requests_total` counter, labelled by `operation`.
    pub graphql_requests: IntCounterVec,
    /// `graphql_request_duration_seconds` histogram, labelled by `operation`.
    pub graphql_duration: HistogramVec,
    /// `graphql_errors_total` counter, labelled by `error_type`.
    pub graphql_errors: IntCounterVec,
    /// `grpc_backend_requests_total` counter, labelled by `service` and `method`.
    pub grpc_requests: IntCounterVec,
    /// `grpc_backend_duration_seconds` histogram, labelled by `service` and `method`.
    pub grpc_duration: HistogramVec,
    /// `grpc_backend_errors_total` counter, labelled by `service`, `method` and `code`.
    pub grpc_errors: IntCounterVec,
}
impl GatewayMetrics {
pub fn new() -> Self {
Self {
graphql_requests: register_int_counter_vec!(
"graphql_requests_total",
"Total number of GraphQL requests",
&["operation"]
)
.expect("metric can be created"),
graphql_duration: register_histogram_vec!(
"graphql_request_duration_seconds",
"GraphQL request duration in seconds",
&["operation"],
LATENCY_BUCKETS.to_vec()
)
.expect("metric can be created"),
graphql_errors: register_int_counter_vec!(
"graphql_errors_total",
"Total number of GraphQL errors",
&["error_type"]
)
.expect("metric can be created"),
grpc_requests: register_int_counter_vec!(
"grpc_backend_requests_total",
"Total number of gRPC backend requests",
&["service", "method"]
)
.expect("metric can be created"),
grpc_duration: register_histogram_vec!(
"grpc_backend_duration_seconds",
"gRPC backend request duration in seconds",
&["service", "method"],
LATENCY_BUCKETS.to_vec()
)
.expect("metric can be created"),
grpc_errors: register_int_counter_vec!(
"grpc_backend_errors_total",
"Total number of gRPC backend errors",
&["service", "method", "code"]
)
.expect("metric can be created"),
}
}
pub fn global() -> &'static Self {
&METRICS
}
pub fn record_graphql_request(&self, operation: &str) {
self.graphql_requests.with_label_values(&[operation]).inc();
}
pub fn record_graphql_duration(&self, operation: &str, duration_secs: f64) {
self.graphql_duration
.with_label_values(&[operation])
.observe(duration_secs);
}
pub fn record_graphql_error(&self, error_type: &str) {
self.graphql_errors.with_label_values(&[error_type]).inc();
}
pub fn record_grpc_request(&self, service: &str, method: &str) {
self.grpc_requests
.with_label_values(&[service, method])
.inc();
}
pub fn record_grpc_duration(&self, service: &str, method: &str, duration_secs: f64) {
self.grpc_duration
.with_label_values(&[service, method])
.observe(duration_secs);
}
pub fn record_grpc_error(&self, service: &str, method: &str, code: &str) {
self.grpc_errors
.with_label_values(&[service, method, code])
.inc();
}
pub fn requests_total(&self) -> u64 {
let mut total = 0;
for operation in &["query", "mutation", "subscription"] {
total += self.graphql_requests.with_label_values(&[operation]).get();
}
total
}
pub fn render(&self) -> String {
let encoder = TextEncoder::new();
let metric_families = prometheus::gather();
let mut buffer = Vec::new();
encoder
.encode(&metric_families, &mut buffer)
.expect("encoding metrics");
String::from_utf8(buffer).expect("valid utf8")
}
}
impl Default for GatewayMetrics {
    /// Returns a handle to the process-wide metric set (see [`GatewayMetrics::global`]).
    ///
    /// Every collector is registered on the default Prometheus registry under a fixed
    /// name, so building a second independent set with [`GatewayMetrics::new`] would
    /// fail registration and panic. That would happen whenever the global set already
    /// exists, and the later `global()` initialisation would panic if this ran first.
    /// Cloning the global instance avoids both problems. The prometheus vector types
    /// share their underlying storage when cloned, so every clone records into the
    /// same series.
    fn default() -> Self {
        Self::global().clone()
    }
}
/// Scope guard that times a single GraphQL request.
///
/// Creating the timer increments the request counter for its operation, and dropping
/// it observes the elapsed wall-clock time in the GraphQL latency histogram. It must
/// therefore be kept alive for the whole request.
#[must_use = "the request duration is recorded when the timer is dropped; bind it to a named variable that lives for the whole request"]
pub struct RequestTimer {
    /// Moment the timer was created.
    start: Instant,
    /// `operation` label used for the request counter and the latency histogram.
    operation: String,
    /// Metric set the measurements are recorded into.
    metrics: &'static GatewayMetrics,
}
impl RequestTimer {
    /// Counts a new GraphQL request for `operation` in the global metrics and starts
    /// timing it.
    pub fn new(operation: impl Into<String>) -> Self {
        let operation: String = operation.into();
        let metrics = GatewayMetrics::global();
        metrics.record_graphql_request(&operation);
        let start = Instant::now();
        Self {
            start,
            operation,
            metrics,
        }
    }

    /// Increments the GraphQL error counter for `error_type`.
    ///
    /// The error is labelled only by `error_type`, not by this timer's operation.
    pub fn record_error(&self, error_type: &str) {
        GatewayMetrics::record_graphql_error(self.metrics, error_type);
    }
}
impl Drop for RequestTimer {
    /// Records the time elapsed since construction in the GraphQL latency histogram.
    fn drop(&mut self) {
        let elapsed_secs = self.start.elapsed().as_secs_f64();
        let operation = self.operation.as_str();
        self.metrics.record_graphql_duration(operation, elapsed_secs);
    }
}
/// Scope guard that times a single gRPC backend call.
///
/// Creating the timer increments the backend request counter for its
/// `service`/`method` pair, and dropping it observes the elapsed wall-clock time in the
/// gRPC latency histogram. It must therefore be kept alive for the whole call.
#[must_use = "the call duration is recorded when the timer is dropped; bind it to a named variable that lives for the whole call"]
pub struct GrpcTimer {
    /// Moment the timer was created.
    start: Instant,
    /// `service` label for every metric this timer records.
    service: String,
    /// `method` label for every metric this timer records.
    method: String,
    /// Metric set the measurements are recorded into.
    metrics: &'static GatewayMetrics,
}
impl GrpcTimer {
    /// Counts a new backend call to `service`/`method` in the global metrics and
    /// starts timing it.
    pub fn new(service: impl Into<String>, method: impl Into<String>) -> Self {
        let (service, method): (String, String) = (service.into(), method.into());
        let metrics = GatewayMetrics::global();
        metrics.record_grpc_request(&service, &method);
        let start = Instant::now();
        Self {
            start,
            service,
            method,
            metrics,
        }
    }

    /// Increments the backend error counter for this call's `service`/`method` pair
    /// under the given `code` label.
    pub fn record_error(&self, code: &str) {
        let (service, method) = (self.service.as_str(), self.method.as_str());
        GatewayMetrics::record_grpc_error(self.metrics, service, method, code);
    }
}
impl Drop for GrpcTimer {
    /// Records the time elapsed since construction in the gRPC backend latency
    /// histogram.
    fn drop(&mut self) {
        let elapsed_secs = self.start.elapsed().as_secs_f64();
        let (service, method) = (self.service.as_str(), self.method.as_str());
        self.metrics
            .record_grpc_duration(service, method, elapsed_secs);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Every test shares the process-wide registry, and the harness runs tests in
    // parallel. Assertions therefore compare against values captured at the start of
    // each test (using `>`) instead of expecting absolute counts.

    #[test]
    fn test_metrics_registry() {
        let metrics = GatewayMetrics::global();
        let initial_count = metrics.graphql_requests.with_label_values(&["query"]).get();
        metrics.record_graphql_request("query");
        let new_count = metrics.graphql_requests.with_label_values(&["query"]).get();
        assert!(
            new_count > initial_count,
            "Expected query count to increase"
        );
    }

    #[test]
    fn test_graphql_metrics() {
        let metrics = GatewayMetrics::global();
        let initial_q = metrics.graphql_requests.with_label_values(&["query"]).get();
        let initial_m = metrics
            .graphql_requests
            .with_label_values(&["mutation"])
            .get();
        metrics.record_graphql_request("query");
        metrics.record_graphql_request("mutation");
        assert!(metrics.graphql_requests.with_label_values(&["query"]).get() > initial_q);
        assert!(
            metrics
                .graphql_requests
                .with_label_values(&["mutation"])
                .get()
                > initial_m
        );

        let initial_err_count = metrics
            .graphql_errors
            .with_label_values(&["validation"])
            .get();
        metrics.record_graphql_error("validation");
        assert!(
            metrics
                .graphql_errors
                .with_label_values(&["validation"])
                .get()
                > initial_err_count
        );

        let initial_samples = metrics
            .graphql_duration
            .with_label_values(&["query"])
            .get_sample_count();
        metrics.record_graphql_duration("query", 0.5);
        let count = metrics
            .graphql_duration
            .with_label_values(&["query"])
            .get_sample_count();
        assert!(count > initial_samples);
    }

    #[test]
    fn test_grpc_metrics() {
        let metrics = GatewayMetrics::global();
        let svc = "TestService";
        let method = "TestMethod";

        let initial_reqs = metrics
            .grpc_requests
            .with_label_values(&[svc, method])
            .get();
        metrics.record_grpc_request(svc, method);
        assert!(
            metrics
                .grpc_requests
                .with_label_values(&[svc, method])
                .get()
                > initial_reqs
        );

        let initial_samples = metrics
            .grpc_duration
            .with_label_values(&[svc, method])
            .get_sample_count();
        metrics.record_grpc_duration(svc, method, 0.1);
        assert!(
            metrics
                .grpc_duration
                .with_label_values(&[svc, method])
                .get_sample_count()
                > initial_samples
        );

        let initial_errs = metrics
            .grpc_errors
            .with_label_values(&[svc, method, "INTERNAL"])
            .get();
        metrics.record_grpc_error(svc, method, "INTERNAL");
        assert!(
            metrics
                .grpc_errors
                .with_label_values(&[svc, method, "INTERNAL"])
                .get()
                > initial_errs
        );
    }

    #[test]
    fn test_request_timer_workflow() {
        let metrics = GatewayMetrics::global();
        let start_count = metrics
            .graphql_requests
            .with_label_values(&["subscription"])
            .get();
        let start_errors = metrics.graphql_errors.with_label_values(&["timeout"]).get();
        let start_samples = metrics
            .graphql_duration
            .with_label_values(&["subscription"])
            .get_sample_count();
        {
            let timer = RequestTimer::new("subscription");
            timer.record_error("timeout");
        }
        let end_count = metrics
            .graphql_requests
            .with_label_values(&["subscription"])
            .get();
        assert!(end_count > start_count);
        assert!(metrics.graphql_errors.with_label_values(&["timeout"]).get() > start_errors);
        assert!(
            metrics
                .graphql_duration
                .with_label_values(&["subscription"])
                .get_sample_count()
                > start_samples,
            "dropping the timer should observe a duration"
        );
    }

    #[test]
    fn test_grpc_timer_workflow() {
        let metrics = GatewayMetrics::global();
        let svc = "TimerService";
        let method = "TimerMethod";
        let start_count = metrics
            .grpc_requests
            .with_label_values(&[svc, method])
            .get();
        let start_errors = metrics
            .grpc_errors
            .with_label_values(&[svc, method, "UNAVAILABLE"])
            .get();
        let start_samples = metrics
            .grpc_duration
            .with_label_values(&[svc, method])
            .get_sample_count();
        {
            let timer = GrpcTimer::new(svc, method);
            timer.record_error("UNAVAILABLE");
        }
        assert!(
            metrics
                .grpc_requests
                .with_label_values(&[svc, method])
                .get()
                > start_count
        );
        assert!(
            metrics
                .grpc_errors
                .with_label_values(&[svc, method, "UNAVAILABLE"])
                .get()
                > start_errors
        );
        assert!(
            metrics
                .grpc_duration
                .with_label_values(&[svc, method])
                .get_sample_count()
                > start_samples,
            "dropping the timer should observe a duration"
        );
    }

    #[test]
    fn test_metrics_rendering_output() {
        let metrics = GatewayMetrics::global();
        metrics.record_graphql_request("render_test");
        let output = metrics.render();
        assert!(output.contains("graphql_requests_total"));
        assert!(output.contains("render_test"));
        assert!(output.contains("TYPE graphql_requests_total counter"));
    }

    #[test]
    fn test_requests_total_aggregator() {
        let metrics = GatewayMetrics::global();
        let before_total = metrics.requests_total();
        metrics.record_graphql_request("query");
        metrics.record_graphql_request("mutation");
        metrics.record_graphql_request("subscription");
        let after_total = metrics.requests_total();
        assert!(after_total >= before_total + 3);
    }
}