promwrite 0.1.1

A simple Prometheus remote-write client library.
Documentation
mod types;

use crate::types::{Request, Sample as ProtoSample, Timeseries};
use chrono::Utc;
use reqwest::header::{
    CONTENT_ENCODING, CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue,
    USER_AGENT,
};
use std::collections::BTreeMap;
use std::time::Duration;
use tokio::sync::mpsc;

/// Bucket upper bounds used by [`Histogram::new_with_default`].
///
/// The bounds are listed in ascending order; the implicit `+Inf` bucket is
/// not part of this list.
pub const DEFAULT_BUCKETS: &[f64] = &[
    // thousandths / hundredths
    0.005, 0.01, 0.025, 0.05, 0.075,
    // tenths
    0.1, 0.25, 0.5, 0.75,
    // units
    1.0, 2.5, 5.0, 7.5,
    // tens
    10.0,
];

/// One sample queued on the channel between a [`MetricBuilder`] (or
/// [`Histogram`]) and the background worker.
struct MetricEvent {
    /// Metric name; the worker sends it as the `__name__` label.
    name: String,
    /// Additional label name/value pairs attached to the sample.
    labels: Vec<(String, String)>,
    /// Sample value.
    value: f64,
    /// Sample timestamp. Callers in this file pass Unix time in milliseconds.
    timestamp: i64,
}

/// Handle used to create metric builders.
///
/// Cloning the handle clones the channel sender, so every clone feeds the
/// same background worker.
#[derive(Clone)]
pub struct Metric {
    /// Sending half of the bounded channel read by the background worker.
    tx: mpsc::Sender<MetricEvent>,
}

impl Metric {
    pub fn new(url: impl Into<String>, buffer_size: usize) -> Self {
        let (tx, rx) = mpsc::channel(buffer_size);
        let url_str = url.into();

        tokio::spawn(async move {
            Self::background_worker(url_str, rx).await;
        });

        Self { tx }
    }

    pub fn name(&self, name: impl Into<String>) -> MetricBuilder {
        MetricBuilder {
            manager: self.clone(),
            name: name.into(),
            labels: Vec::new(),
        }
    }

    async fn background_worker(
        url: String,
        mut rx: mpsc::Receiver<MetricEvent>,
    ) {
        let client = reqwest::Client::builder()
            .pool_max_idle_per_host(10)
            .tcp_keepalive(Duration::from_secs(90))
            .build()
            .unwrap_or_default();

        let mut headers = HeaderMap::new();
        headers.insert(
            CONTENT_TYPE,
            HeaderValue::from_static(
                "application/x-protobuf;proto=io.prometheus.write.v2.Request",
            ),
        );
        headers.insert(CONTENT_ENCODING, HeaderValue::from_static("snappy"));
        headers.insert(
            HeaderName::from_static("x-prometheus-remote-write-version"),
            HeaderValue::from_static("2.0.0"),
        );
        headers.insert(USER_AGENT, HeaderValue::from_static("promwrite/1.0.0"));

        let mut buffer = Vec::with_capacity(2000);
        let mut interval = tokio::time::interval(Duration::from_millis(50));

        loop {
            tokio::select! {
                Some(event) = rx.recv() => {
                    buffer.push(event);
                    if buffer.len() >= 1000 {
                        Self::flush_batch(&client, &url, &headers, &mut buffer).await;
                    }
                }
                _ = interval.tick() => {
                    if !buffer.is_empty() {
                        Self::flush_batch(&client, &url, &headers, &mut buffer).await;
                    }
                }
            }
        }
    }

    async fn flush_batch(
        client: &reqwest::Client,
        url: &str,
        headers: &HeaderMap,
        buffer: &mut Vec<MetricEvent>,
    ) {
        let mut req = Request {
            symbols: vec![],
            timeseries: vec![],
        };
        let mut symbol_map: BTreeMap<String, u32> = BTreeMap::new();

        let mut get_or_insert_symbol = |s: &str, req: &mut Request| -> u32 {
            if let Some(&idx) = symbol_map.get(s) {
                idx
            } else {
                let idx = req.symbols.len() as u32;
                req.symbols.push(s.to_string());
                symbol_map.insert(s.to_string(), idx);
                idx
            }
        };

        for event in buffer.drain(..) {
            let mut labels_map = BTreeMap::new();
            labels_map.insert("__name__".to_string(), event.name);
            for (k, v) in event.labels {
                labels_map.insert(k, v);
            }

            let mut labels_refs = Vec::with_capacity(labels_map.len() * 2);
            for (k, v) in &labels_map {
                labels_refs.push(get_or_insert_symbol(k, &mut req));
                labels_refs.push(get_or_insert_symbol(v, &mut req));
            }

            req.timeseries.push(Timeseries {
                labels_refs,
                samples: vec![ProtoSample {
                    value: event.value,
                    timestamp: event.timestamp,
                }],
            });
        }

        let encoded = prost::Message::encode_to_vec(&req);
        if let Ok(compressed) = snap::raw::Encoder::new().compress_vec(&encoded)
        {
            let mut request_builder =
                client.post(url).headers(headers.clone()).body(compressed);

            if let Ok(parsed_url) = reqwest::Url::parse(url) {
                let username = parsed_url.username();
                let password = parsed_url.password();
                if !username.is_empty() {
                    request_builder =
                        request_builder.basic_auth(username, password);
                }
            }

            match request_builder.send().await {
                Ok(resp) => {
                    if !resp.status().is_success() {
                        tracing::error!(
                            status = resp.status().as_u16(),
                            "Remote-write response error"
                        );
                    }
                }
                Err(err) => {
                    tracing::error!(error = %err, "Remote-write network request failed");
                }
            }
        }
    }
}

/// A metric name plus a set of labels, bound to the [`Metric`] handle that
/// will carry its samples.
#[derive(Clone)]
pub struct MetricBuilder {
    /// Handle whose channel receives the samples recorded through this builder.
    pub manager: Metric,
    /// Metric name.
    pub name: String,
    /// Label name/value pairs, in the order they were added.
    pub labels: Vec<(String, String)>,
}

impl MetricBuilder {
    pub fn label(
        mut self,
        name: impl Into<String>,
        value: impl Into<String>,
    ) -> Self {
        self.labels.push((name.into(), value.into()));
        self
    }

    pub fn set_with_ts(&self, value: f64, timestamp_ms: i64) {
        let event = MetricEvent {
            name: self.name.clone(),
            labels: self.labels.clone(),
            value,
            timestamp: timestamp_ms,
        };
        if let Err(_) = self.manager.tx.try_send(event) {
            tracing::warn!(metric = %self.name, "Buffer queue full, event dropped");
        }
    }

    pub fn set(&self, value: f64) {
        self.set_with_ts(value, Utc::now().timestamp_millis());
    }

    pub fn inc(&self) {
        self.set(1.0);
    }

    pub fn add(&self, value: f64) {
        if value > 0.0 {
            self.set(value);
        }
    }
}

/// A set of builders that together record histogram-style samples for one
/// base metric.
pub struct Histogram {
    /// Builder for the `<name>_sum` series.
    sum_builder: MetricBuilder,
    /// Builder for the `<name>_count` series.
    count_builder: MetricBuilder,
    /// Upper bound and builder (carrying the matching `le` label) per bucket.
    bucket_builders: Vec<(f64, MetricBuilder)>,
    /// Builder for the series labelled `le="+Inf"`.
    inf_builder: MetricBuilder,
}

impl Histogram {
    pub fn new_with_default(builder: MetricBuilder) -> Self {
        Self::new(builder, DEFAULT_BUCKETS.to_vec())
    }

    pub fn new(builder: MetricBuilder, buckets: Vec<f64>) -> Self {
        let mut sorted_buckets = buckets;
        sorted_buckets.sort_by(|a, b| a.partial_cmp(b).unwrap());

        let mut sum_builder =
            builder.manager.name(format!("{}_sum", builder.name));
        for (k, v) in &builder.labels {
            sum_builder = sum_builder.label(k, v);
        }

        let mut count_builder =
            builder.manager.name(format!("{}_count", builder.name));
        for (k, v) in &builder.labels {
            count_builder = count_builder.label(k, v);
        }

        let mut bucket_builders = Vec::with_capacity(sorted_buckets.len());
        for &bucket in &sorted_buckets {
            let b_builder = builder.clone().label("le", bucket.to_string());
            bucket_builders.push((bucket, b_builder));
        }

        let inf_builder = builder.clone().label("le", "+Inf");

        Self {
            sum_builder,
            count_builder,
            bucket_builders,
            inf_builder,
        }
    }

    pub fn observe_with_ts(&self, value: f64, timestamp_ms: i64) {
        let _ = self.sum_builder.manager.tx.try_send(MetricEvent {
            name: self.sum_builder.name.clone(),
            labels: self.sum_builder.labels.clone(),
            value,
            timestamp: timestamp_ms,
        });

        let _ = self.count_builder.manager.tx.try_send(MetricEvent {
            name: self.count_builder.name.clone(),
            labels: self.count_builder.labels.clone(),
            value: 1.0,
            timestamp: timestamp_ms,
        });

        for (bucket, b_builder) in &self.bucket_builders {
            if value <= *bucket {
                let _ = b_builder.manager.tx.try_send(MetricEvent {
                    name: b_builder.name.clone(),
                    labels: b_builder.labels.clone(),
                    value: 1.0,
                    timestamp: timestamp_ms,
                });
            }
        }

        let _ = self.inf_builder.manager.tx.try_send(MetricEvent {
            name: self.inf_builder.name.clone(),
            labels: self.inf_builder.labels.clone(),
            value: 1.0,
            timestamp: timestamp_ms,
        });
    }

    pub fn observe(&self, value: f64) {
        self.observe_with_ts(value, Utc::now().timestamp_millis());
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::time::sleep;

    /// Smoke test: pushes counter-, gauge- and histogram-style samples plus a
    /// burst from ten concurrent tasks through one `Metric` handle, then
    /// waits two seconds so the background worker has time to flush.
    ///
    /// It makes no assertions; it only checks that nothing panics.
    #[tokio::test]
    async fn test_all_prometheus_metric_types() {
        let client = Metric::new(
            "http://admin:admin@127.0.0.1:9090/api/v1/write",
            50000,
        );

        // Counter-style usage.
        let requests_total = client
            .name("rust_sdk_requests_total")
            .label("path", "/api/v1/user");
        requests_total.inc();
        requests_total.inc();

        // Gauge-style usage.
        let cpu_ratio =
            client.name("rust_sdk_cpu_usage_ratio").label("core", "0");
        cpu_ratio.set(0.12);
        cpu_ratio.set(0.45);

        // Histogram with the default buckets.
        let request_duration = Histogram::new_with_default(
            client
                .name("rust_sdk_http_request_duration_seconds")
                .label("method", "POST"),
        );
        for seconds in [0.025, 0.080, 0.350] {
            request_duration.observe(seconds);
        }

        // Sample with an explicit timestamp.
        let explicit_ts = Utc::now().timestamp_millis();
        let history =
            client.name("rust_sdk_history_gauge").label("host", "db-01");
        history.set_with_ts(85.5, explicit_ts);

        // Ten tasks, each recording one hundred samples a millisecond apart.
        let stress_base = client
            .name("rust_sdk_high_qps_gauge")
            .label("type", "stress");

        let tasks: Vec<_> = (0..10)
            .map(|i| {
                let per_task =
                    stress_base.clone().label("thread_id", i.to_string());
                tokio::spawn(async move {
                    for count in 0..100 {
                        per_task.set(count as f64);
                        sleep(Duration::from_millis(1)).await;
                    }
                })
            })
            .collect();

        for task in tasks {
            task.await.unwrap();
        }

        sleep(Duration::from_secs(2)).await;
    }
}