rs-zero 0.2.6

Rust-first microservice framework inspired by go-zero engineering practices
Documentation
use std::{sync::Arc, time::Duration};

use prometheus_client::{
    encoding::{EncodeLabelSet, text::encode},
    metrics::{counter::Counter, family::Family, histogram::Histogram},
    registry::Registry,
};

use crate::observability::metrics::DURATION_BUCKETS;

/// Metrics registry implemented on top of the `prometheus-client` crate.
///
/// This type is optional: it sits behind the `observability-prometheus-client`
/// feature, so the default runtime keeps its own small in-process registry and
/// a lighter dependency set.
///
/// The underlying [`Registry`] is held in an [`Arc`], so clones of this value
/// share the same registry.
#[derive(Clone, Debug)]
pub struct PrometheusClientMetricsRegistry {
    // Registry that every family below is registered with; encoded by `render`.
    registry: Arc<Registry>,
    // HTTP server traffic: request count and latency per method/route/status.
    http_requests: Family<HttpLabels, Counter>,
    http_duration: Family<HttpLabels, Histogram, HistogramConstructor>,
    // gRPC traffic: request count and latency per service/method/code.
    rpc_requests: Family<RpcLabels, Counter>,
    rpc_duration: Family<RpcLabels, Histogram, HistogramConstructor>,
    // SQL access: query count and latency per repository call.
    sql_queries: Family<SqlLabels, Counter>,
    sql_duration: Family<SqlLabels, Histogram, HistogramConstructor>,
    // Redis access: command count and latency per command/shard/result.
    redis_commands: Family<RedisLabels, Counter>,
    redis_duration: Family<RedisLabels, Histogram, HistogramConstructor>,
}

impl PrometheusClientMetricsRegistry {
    /// Creates and registers the supported rs-zero metric families.
    pub fn new() -> Self {
        let mut registry = Registry::default();
        let http_requests = Family::<HttpLabels, Counter>::default();
        let http_duration = duration_family();
        let rpc_requests = Family::<RpcLabels, Counter>::default();
        let rpc_duration = duration_family();
        let sql_queries = Family::<SqlLabels, Counter>::default();
        let sql_duration = duration_family();
        let redis_commands = Family::<RedisLabels, Counter>::default();
        let redis_duration = duration_family();

        registry.register(
            "rs_zero_http_requests",
            "Total number of HTTP requests.",
            http_requests.clone(),
        );
        registry.register(
            "rs_zero_http_request_duration_seconds",
            "HTTP request duration.",
            http_duration.clone(),
        );
        registry.register(
            "rs_zero_rpc_requests",
            "Total number of gRPC requests.",
            rpc_requests.clone(),
        );
        registry.register(
            "rs_zero_rpc_request_duration_seconds",
            "gRPC request duration.",
            rpc_duration.clone(),
        );
        registry.register(
            "rs_zero_sql_queries",
            "Total number of SQL queries.",
            sql_queries.clone(),
        );
        registry.register(
            "rs_zero_sql_query_duration_seconds",
            "SQL query duration.",
            sql_duration.clone(),
        );
        registry.register(
            "rs_zero_redis_commands",
            "Total number of Redis commands.",
            redis_commands.clone(),
        );
        registry.register(
            "rs_zero_redis_command_duration_seconds",
            "Redis command duration.",
            redis_duration.clone(),
        );

        Self {
            registry: Arc::new(registry),
            http_requests,
            http_duration,
            rpc_requests,
            rpc_duration,
            sql_queries,
            sql_duration,
            redis_commands,
            redis_duration,
        }
    }

    /// Records one HTTP request using low-cardinality labels.
    pub fn record_http_request(&self, method: &str, route: &str, status: u16, duration: Duration) {
        let labels = HttpLabels {
            method: method.to_string(),
            route: route.to_string(),
            status,
        };
        self.http_requests.get_or_create(&labels).inc();
        self.http_duration
            .get_or_create(&labels)
            .observe(duration.as_secs_f64());
    }

    /// Records one gRPC request using low-cardinality labels.
    pub fn record_rpc_request(&self, service: &str, method: &str, code: &str, duration: Duration) {
        let labels = RpcLabels {
            service: service.to_string(),
            method: method.to_string(),
            code: code.to_string(),
        };
        self.rpc_requests.get_or_create(&labels).inc();
        self.rpc_duration
            .get_or_create(&labels)
            .observe(duration.as_secs_f64());
    }

    /// Records one SQL query without SQL text or parameters.
    pub fn record_sql_query(
        &self,
        db_kind: &str,
        repository: &str,
        method: &str,
        operation: &str,
        result: &str,
        duration: Duration,
    ) {
        let labels = SqlLabels {
            db_kind: db_kind.to_string(),
            repository: repository.to_string(),
            method: method.to_string(),
            operation: operation.to_string(),
            result: result.to_string(),
        };
        self.sql_queries.get_or_create(&labels).inc();
        self.sql_duration
            .get_or_create(&labels)
            .observe(duration.as_secs_f64());
    }

    /// Records one Redis command without Redis keys or URLs.
    pub fn record_redis_command(
        &self,
        command: &str,
        shard: &str,
        result: &str,
        duration: Duration,
    ) {
        let labels = RedisLabels {
            command: command.to_string(),
            shard: shard.to_string(),
            result: result.to_string(),
        };
        self.redis_commands.get_or_create(&labels).inc();
        self.redis_duration
            .get_or_create(&labels)
            .observe(duration.as_secs_f64());
    }

    /// Renders OpenMetrics text exposition.
    pub fn render(&self) -> String {
        let mut output = String::new();
        encode(&mut output, &self.registry).expect("prometheus-client text encoding");
        output
    }
}

impl Default for PrometheusClientMetricsRegistry {
    fn default() -> Self {
        Self::new()
    }
}

/// Constructor type shared by every duration histogram family.
///
/// Using a plain function pointer keeps the family type nameable, so it can
/// be spelled out in struct fields.
type HistogramConstructor = fn() -> Histogram;

/// Creates an empty histogram family whose members are built by
/// `duration_histogram`.
fn duration_family<L>() -> Family<L, Histogram, HistogramConstructor>
where
    L: Clone + Eq + std::hash::Hash,
{
    let constructor: HistogramConstructor = duration_histogram;
    Family::new_with_constructor(constructor)
}

/// Creates a single histogram using the shared `DURATION_BUCKETS` boundaries.
fn duration_histogram() -> Histogram {
    Histogram::new(DURATION_BUCKETS)
}

/// Label set attached to the HTTP request counter and duration histogram.
///
/// Field order is preserved when the labels are encoded.
#[derive(Debug, Clone, PartialEq, Eq, Hash, EncodeLabelSet)]
struct HttpLabels {
    // Request method as passed by the caller, e.g. `GET`.
    method: String,
    // Route template rather than the concrete path, e.g. `/users/{id}`.
    route: String,
    // Numeric HTTP response status.
    status: u16,
}

/// Label set attached to the gRPC request counter and duration histogram.
///
/// Field order is preserved when the labels are encoded.
#[derive(Debug, Clone, PartialEq, Eq, Hash, EncodeLabelSet)]
struct RpcLabels {
    // Fully-qualified service name as passed by the caller, e.g. `hello.Hello`.
    service: String,
    // RPC method name within the service.
    method: String,
    // Status code string as passed by the caller, e.g. `Ok`.
    code: String,
}

/// Label set attached to the SQL query counter and duration histogram.
///
/// Describes where a query came from, never its text or parameters. Field
/// order is preserved when the labels are encoded.
#[derive(Debug, Clone, PartialEq, Eq, Hash, EncodeLabelSet)]
struct SqlLabels {
    // Database backend identifier, e.g. `sqlite`.
    db_kind: String,
    // Repository that issued the query.
    repository: String,
    // Repository method that issued the query.
    method: String,
    // Statement kind as passed by the caller, e.g. `select`.
    operation: String,
    // Outcome string as passed by the caller, e.g. `success`.
    result: String,
}

/// Label set attached to the Redis command counter and duration histogram.
///
/// Holds no keys or URLs. Field order is preserved when the labels are
/// encoded.
#[derive(Debug, Clone, PartialEq, Eq, Hash, EncodeLabelSet)]
struct RedisLabels {
    // Command name, e.g. `GET`.
    command: String,
    // Logical shard name as passed by the caller, e.g. `primary`.
    shard: String,
    // Outcome string as passed by the caller, e.g. `success`.
    result: String,
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::PrometheusClientMetricsRegistry;

    /// Records one sample per subsystem and checks the rendered exposition.
    #[test]
    fn prometheus_client_registry_records_low_cardinality_metrics() {
        let registry = PrometheusClientMetricsRegistry::new();
        registry.record_http_request("GET", "/users/{id}", 200, Duration::from_millis(5));
        registry.record_rpc_request("hello.Hello", "Say", "Ok", Duration::from_millis(3));
        registry.record_sql_query(
            "sqlite",
            "users",
            "find_by_id",
            "select",
            "success",
            Duration::from_millis(1),
        );
        registry.record_redis_command("GET", "primary", "success", Duration::from_millis(1));

        let text = registry.render();

        // Every counter family is exposed with the `_total` suffix.
        assert!(text.contains("rs_zero_http_requests_total"));
        assert!(text.contains("rs_zero_rpc_requests_total"));
        assert!(text.contains("rs_zero_sql_queries_total"));
        assert!(text.contains("rs_zero_redis_commands_total"));

        // Every duration histogram family is exposed as well.
        assert!(text.contains("rs_zero_http_request_duration_seconds_count"));
        assert!(text.contains("rs_zero_rpc_request_duration_seconds_count"));
        assert!(text.contains("rs_zero_sql_query_duration_seconds_count"));
        assert!(text.contains("rs_zero_redis_command_duration_seconds_count"));

        // Label values come through exactly as recorded.
        assert!(text.contains("route=\"/users/{id}\""));
        assert!(text.contains("status=\"200\""));

        // Only the route template, never a concrete path, key or secret, may appear.
        assert!(!text.contains("/users/42"));
        assert!(!text.contains("users:42"));
        assert!(!text.contains("password"));

        // OpenMetrics text is terminated by the EOF marker.
        assert!(text.ends_with("# EOF\n"));
    }

    /// A clone records into the same registry as the value it was cloned from.
    #[test]
    fn prometheus_client_registry_clones_share_state() {
        let registry = PrometheusClientMetricsRegistry::default();
        let handle = registry.clone();
        handle.record_http_request("POST", "/users", 201, Duration::from_millis(2));

        let text = registry.render();
        assert!(text.contains("method=\"POST\""));
        assert!(text.contains("status=\"201\""));
    }
}