use std::{sync::Arc, time::Duration};
use prometheus_client::{
encoding::{EncodeLabelSet, text::encode},
metrics::{counter::Counter, family::Family, histogram::Histogram},
registry::Registry,
};
use crate::observability::metrics::DURATION_BUCKETS;
/// Metrics registry built on `prometheus_client`.
///
/// Holds one counter family and one duration-histogram family for each of the
/// four instrumented areas (HTTP, RPC, SQL, Redis), plus the `Registry` that
/// those families are registered with and that `render` encodes.
#[derive(Debug, Clone)]
pub struct PrometheusClientMetricsRegistry {
// Registry the families below are registered with; wrapped in `Arc` so the
// derived `Clone` shares it instead of requiring `Registry: Clone`.
registry: Arc<Registry>,
// HTTP request counter, keyed by method / route / status.
http_requests: Family<HttpLabels, Counter>,
// HTTP request duration histogram, same label set as the counter.
http_duration: Family<HttpLabels, Histogram, HistogramConstructor>,
// RPC request counter, keyed by service / method / code.
rpc_requests: Family<RpcLabels, Counter>,
// RPC request duration histogram, same label set as the counter.
rpc_duration: Family<RpcLabels, Histogram, HistogramConstructor>,
// SQL query counter, keyed by db kind / repository / method / operation / result.
sql_queries: Family<SqlLabels, Counter>,
// SQL query duration histogram, same label set as the counter.
sql_duration: Family<SqlLabels, Histogram, HistogramConstructor>,
// Redis command counter, keyed by command / shard / result.
redis_commands: Family<RedisLabels, Counter>,
// Redis command duration histogram, same label set as the counter.
redis_duration: Family<RedisLabels, Histogram, HistogramConstructor>,
}
impl PrometheusClientMetricsRegistry {
pub fn new() -> Self {
let mut registry = Registry::default();
let http_requests = Family::<HttpLabels, Counter>::default();
let http_duration = duration_family();
let rpc_requests = Family::<RpcLabels, Counter>::default();
let rpc_duration = duration_family();
let sql_queries = Family::<SqlLabels, Counter>::default();
let sql_duration = duration_family();
let redis_commands = Family::<RedisLabels, Counter>::default();
let redis_duration = duration_family();
registry.register(
"rs_zero_http_requests",
"Total number of HTTP requests.",
http_requests.clone(),
);
registry.register(
"rs_zero_http_request_duration_seconds",
"HTTP request duration.",
http_duration.clone(),
);
registry.register(
"rs_zero_rpc_requests",
"Total number of gRPC requests.",
rpc_requests.clone(),
);
registry.register(
"rs_zero_rpc_request_duration_seconds",
"gRPC request duration.",
rpc_duration.clone(),
);
registry.register(
"rs_zero_sql_queries",
"Total number of SQL queries.",
sql_queries.clone(),
);
registry.register(
"rs_zero_sql_query_duration_seconds",
"SQL query duration.",
sql_duration.clone(),
);
registry.register(
"rs_zero_redis_commands",
"Total number of Redis commands.",
redis_commands.clone(),
);
registry.register(
"rs_zero_redis_command_duration_seconds",
"Redis command duration.",
redis_duration.clone(),
);
Self {
registry: Arc::new(registry),
http_requests,
http_duration,
rpc_requests,
rpc_duration,
sql_queries,
sql_duration,
redis_commands,
redis_duration,
}
}
pub fn record_http_request(&self, method: &str, route: &str, status: u16, duration: Duration) {
let labels = HttpLabels {
method: method.to_string(),
route: route.to_string(),
status,
};
self.http_requests.get_or_create(&labels).inc();
self.http_duration
.get_or_create(&labels)
.observe(duration.as_secs_f64());
}
pub fn record_rpc_request(&self, service: &str, method: &str, code: &str, duration: Duration) {
let labels = RpcLabels {
service: service.to_string(),
method: method.to_string(),
code: code.to_string(),
};
self.rpc_requests.get_or_create(&labels).inc();
self.rpc_duration
.get_or_create(&labels)
.observe(duration.as_secs_f64());
}
pub fn record_sql_query(
&self,
db_kind: &str,
repository: &str,
method: &str,
operation: &str,
result: &str,
duration: Duration,
) {
let labels = SqlLabels {
db_kind: db_kind.to_string(),
repository: repository.to_string(),
method: method.to_string(),
operation: operation.to_string(),
result: result.to_string(),
};
self.sql_queries.get_or_create(&labels).inc();
self.sql_duration
.get_or_create(&labels)
.observe(duration.as_secs_f64());
}
pub fn record_redis_command(
&self,
command: &str,
shard: &str,
result: &str,
duration: Duration,
) {
let labels = RedisLabels {
command: command.to_string(),
shard: shard.to_string(),
result: result.to_string(),
};
self.redis_commands.get_or_create(&labels).inc();
self.redis_duration
.get_or_create(&labels)
.observe(duration.as_secs_f64());
}
pub fn render(&self) -> String {
let mut output = String::new();
encode(&mut output, &self.registry).expect("prometheus-client text encoding");
output
}
}
impl Default for PrometheusClientMetricsRegistry {
fn default() -> Self {
Self::new()
}
}
/// Function-pointer type used as the constructor parameter of the histogram
/// families; it is called to build the `Histogram` for a label set.
type HistogramConstructor = fn() -> Histogram;
/// Creates an empty histogram family for label type `L` whose histograms are
/// built by [`duration_histogram`].
fn duration_family<L: Clone + std::hash::Hash + Eq>() -> Family<L, Histogram, HistogramConstructor>
{
    // Name the constructor with its fn-pointer type so the family's third
    // type parameter is spelled out rather than inferred from the return type.
    let make_histogram: HistogramConstructor = duration_histogram;
    Family::new_with_constructor(make_histogram)
}
/// Builds a new histogram whose buckets come from `DURATION_BUCKETS`.
fn duration_histogram() -> Histogram {
    let buckets = DURATION_BUCKETS;
    Histogram::new(buckets)
}
/// Label set shared by the HTTP request counter and the HTTP duration histogram.
///
/// Field order is significant: the derived `EncodeLabelSet` and `Hash`
/// implementations are generated from the fields as declared.
#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
struct HttpLabels {
// HTTP method, copied from the `method` argument of `record_http_request`.
method: String,
// Route, copied from the `route` argument of `record_http_request`.
// NOTE(review): the test passes a template ("/users/{id}"), presumably to
// keep cardinality low — confirm callers never pass concrete paths.
route: String,
// Numeric status passed to `record_http_request`.
status: u16,
}
/// Label set shared by the RPC request counter and the RPC duration histogram.
///
/// Field order is significant: the derived `EncodeLabelSet` and `Hash`
/// implementations are generated from the fields as declared.
#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
struct RpcLabels {
// Service name, copied from the `service` argument of `record_rpc_request`.
service: String,
// Method name, copied from the `method` argument of `record_rpc_request`.
method: String,
// Result code as a string, copied from the `code` argument of `record_rpc_request`.
code: String,
}
/// Label set shared by the SQL query counter and the SQL duration histogram.
///
/// Field order is significant: the derived `EncodeLabelSet` and `Hash`
/// implementations are generated from the fields as declared.
#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
struct SqlLabels {
// Database kind, copied from the `db_kind` argument of `record_sql_query`.
db_kind: String,
// Repository name, copied from the `repository` argument of `record_sql_query`.
repository: String,
// Repository method name, copied from the `method` argument of `record_sql_query`.
method: String,
// Operation name, copied from the `operation` argument of `record_sql_query`.
operation: String,
// Outcome string, copied from the `result` argument of `record_sql_query`.
result: String,
}
/// Label set shared by the Redis command counter and the Redis duration histogram.
///
/// Field order is significant: the derived `EncodeLabelSet` and `Hash`
/// implementations are generated from the fields as declared.
#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
struct RedisLabels {
// Command name, copied from the `command` argument of `record_redis_command`.
command: String,
// Shard identifier, copied from the `shard` argument of `record_redis_command`.
shard: String,
// Outcome string, copied from the `result` argument of `record_redis_command`.
result: String,
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::PrometheusClientMetricsRegistry;

    /// Records one sample in each metric family and checks that the rendered
    /// text contains the expected series while omitting concrete identifiers.
    #[test]
    fn prometheus_client_registry_records_low_cardinality_metrics() {
        let metrics = PrometheusClientMetricsRegistry::new();
        let one_ms = Duration::from_millis(1);

        metrics.record_http_request("GET", "/users/{id}", 200, Duration::from_millis(5));
        metrics.record_rpc_request("hello.Hello", "Say", "Ok", Duration::from_millis(3));
        metrics.record_sql_query("sqlite", "users", "find_by_id", "select", "success", one_ms);
        metrics.record_redis_command("GET", "primary", "success", one_ms);

        let rendered = metrics.render();

        // Each counter family is present (the encoder adds the `_total` suffix).
        assert!(rendered.contains("rs_zero_http_requests_total"));
        assert!(rendered.contains("route=\"/users/{id}\""));
        assert!(rendered.contains("rs_zero_rpc_requests_total"));
        assert!(rendered.contains("rs_zero_sql_queries_total"));
        assert!(rendered.contains("rs_zero_redis_commands_total"));

        // Nothing high-cardinality or sensitive appears in the output.
        assert!(!rendered.contains("/users/42"));
        assert!(!rendered.contains("users:42"));
        assert!(!rendered.contains("password"));
    }
}