Skip to main content

mcp_postgres/
metrics.rs

1use anyhow::Result;
2use once_cell::sync::Lazy;
3use std::hash::{DefaultHasher, Hash, Hasher};
4use std::sync::atomic::{AtomicU64, Ordering};
5
6/// Per-CPU-shard counters (data-oriented: no queue allocations on hot path).
7/// Each shard is on its own cache line to prevent false sharing.
8/// Producers increment atomics directly; consumers sum all shards.
9const NUM_SHARDS: usize = 16;
10
11#[repr(align(64))]
12struct MetricShard {
13    requests: AtomicU64,
14    errors: AtomicU64,
15}
16
17static SHARDS: Lazy<[MetricShard; NUM_SHARDS]> = Lazy::new(|| {
18    [0; NUM_SHARDS].map(|_| MetricShard {
19        requests: AtomicU64::new(0),
20        errors: AtomicU64::new(0),
21    })
22});
23
24thread_local! {
25    static THREAD_SHARD: usize = calc_thread_shard();
26}
27
28fn calc_thread_shard() -> usize {
29    const MASK: usize = NUM_SHARDS - 1;
30    let tid = std::thread::current().id();
31    let mut hasher = DefaultHasher::new();
32    tid.hash(&mut hasher);
33    hasher.finish() as usize & MASK
34}
35
36/// Increment request count on the calling thread's shard (cheap: one atomic add).
37#[inline]
38pub fn inc_requests() {
39    let shard = THREAD_SHARD.with(|s| *s);
40    SHARDS[shard].requests.fetch_add(1, Ordering::Relaxed);
41}
42
43/// Increment error count on the calling thread's shard.
44#[inline]
45pub fn inc_errors() {
46    let shard = THREAD_SHARD.with(|s| *s);
47    SHARDS[shard].errors.fetch_add(1, Ordering::Relaxed);
48}
49
50/// Read-and-reset all shard counters, returning totals.
51pub fn drain_counters() -> (u64, u64) {
52    let mut total_reqs = 0u64;
53    let mut total_errs = 0u64;
54    for shard in SHARDS.iter() {
55        total_reqs += shard.requests.swap(0, Ordering::Relaxed);
56        total_errs += shard.errors.swap(0, Ordering::Relaxed);
57    }
58    (total_reqs, total_errs)
59}
60
61pub fn init_metrics(port: u16) -> Result<()> {
62    use prometheus::{Encoder, IntCounter, Registry, TextEncoder};
63    use std::net::SocketAddr;
64    use std::sync::Arc;
65    use tokio::io::{AsyncReadExt, AsyncWriteExt};
66
67    let addr: SocketAddr = format!("127.0.0.1:{}", port).parse()?;
68
69    let registry = Arc::new(Registry::new());
70    let request_total = IntCounter::new("requests_total", "Total requests processed")?;
71    let error_total = IntCounter::new("request_errors_total", "Total request errors")?;
72    registry.register(Box::new(request_total.clone()))?;
73    registry.register(Box::new(error_total.clone()))?;
74
75    // Background task: drain shard counters into Prometheus counters
76    tokio::spawn(async move {
77        loop {
78            tokio::time::sleep(std::time::Duration::from_millis(1000)).await;
79            let (reqs, errs) = drain_counters();
80            if reqs > 0 {
81                request_total.inc_by(reqs);
82            }
83            if errs > 0 {
84                error_total.inc_by(errs);
85            }
86        }
87    });
88
89    // Metrics HTTP endpoint
90    tokio::spawn(async move {
91        let listener = tokio::net::TcpListener::bind(&addr)
92            .await
93            .expect("Failed to bind metrics server");
94
95        loop {
96            match listener.accept().await {
97                Ok((stream, _)) => {
98                    let reg = Arc::clone(&registry);
99                    tokio::spawn(async move {
100                        let (mut reader, mut writer) = tokio::io::split(stream);
101                        let mut buf = vec![0; 1024];
102
103                        if let Ok(n) = reader.read(&mut buf).await {
104                            let request = String::from_utf8_lossy(&buf[..n]);
105                            if request.starts_with("GET /metrics") {
106                                let encoder = TextEncoder::new();
107                                let metric_families = reg.gather();
108                                let mut metrics_buf = Vec::new();
109                                if encoder.encode(&metric_families, &mut metrics_buf).is_ok() {
110                                    let response = format!(
111                                        "HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n{}",
112                                        String::from_utf8_lossy(&metrics_buf)
113                                    );
114                                    let _ = writer.write_all(response.as_bytes()).await;
115                                }
116                            }
117                        }
118                    });
119                }
120                Err(e) => eprintln!("Metrics server error: {}", e),
121            }
122        }
123    });
124
125    Ok(())
126}