Skip to main content

cdk_prometheus/
server.rs

1use std::io;
2use std::net::SocketAddr;
3use std::sync::Arc;
4use std::time::Duration;
5
6use prometheus::{Registry, TextEncoder};
7use tokio::io::{AsyncReadExt, AsyncWriteExt};
8use tokio::net::{TcpListener, TcpStream};
9
10use crate::metrics::METRICS;
11#[cfg(feature = "system-metrics")]
12use crate::process::SystemMetrics;
13
/// Once this many request bytes have been buffered, the server stops reading.
const MAX_REQUEST_BYTES: usize = 4 * 1024;
/// Deadline applied to each individual read from a client socket.
const READ_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Deadline applied to writing a complete response to a client socket.
const WRITE_TIMEOUT: Duration = Duration::from_millis(5_000);

/// Shareable, thread-safe closure producing the body of a metrics scrape.
type MetricsHandler = Arc<dyn Fn() -> String + Send + Sync + 'static>;
19
/// Settings that control where and how the Prometheus server exposes metrics.
#[derive(Debug, Clone)]
pub struct PrometheusConfig {
    /// Socket address the HTTP listener binds to. Defaults to `127.0.0.1:9090`.
    pub bind_address: SocketAddr,
    /// Request path on which metrics are served. Defaults to `/metrics`.
    pub metrics_path: String,
    /// Whether system metrics are collected and served as well.
    /// Defaults to `true` when the `system-metrics` feature is enabled.
    #[cfg(feature = "system-metrics")]
    pub include_system_metrics: bool,
    /// Intended refresh interval for system metrics, in seconds. Defaults to `15`.
    ///
    /// NOTE(review): the server in this file never reads this value (system metrics
    /// are refreshed on every scrape). Confirm whether it is used elsewhere.
    #[cfg(feature = "system-metrics")]
    pub system_metrics_interval: u64,
}

impl Default for PrometheusConfig {
    fn default() -> Self {
        Self {
            // Loopback only, so metrics are not reachable from other hosts by default.
            bind_address: SocketAddr::from(([127, 0, 0, 1], 9090)),
            metrics_path: String::from("/metrics"),
            #[cfg(feature = "system-metrics")]
            include_system_metrics: true,
            #[cfg(feature = "system-metrics")]
            system_metrics_interval: 15,
        }
    }
}
47
48/// Prometheus metrics server
49#[derive(Debug)]
50pub struct PrometheusServer {
51    config: PrometheusConfig,
52    registry: Arc<Registry>,
53    #[cfg(feature = "system-metrics")]
54    system_metrics: Option<SystemMetrics>,
55}
56
/// Returns `true` when the request line of `request` is a `GET` for exactly `metrics_path`.
///
/// Only the first line of `request` is inspected. It must contain at least three
/// whitespace-separated tokens (method, request target, HTTP version), and the
/// method must be exactly `GET` (case-sensitive). Any query string, meaning everything
/// from the first `?` onward, is ignored when comparing the target.
fn request_matches_path(request: &str, metrics_path: &str) -> bool {
    let first_line = match request.lines().next() {
        Some(line) => line,
        None => return false,
    };

    let mut tokens = first_line.split_whitespace();
    match (tokens.next(), tokens.next(), tokens.next()) {
        (Some("GET"), Some(target), Some(_version)) => {
            let path = target.split_once('?').map_or(target, |(before, _)| before);
            path == metrics_path
        }
        _ => false,
    }
}
77
78async fn read_request(stream: &mut TcpStream) -> io::Result<Option<String>> {
79    let mut request = Vec::with_capacity(1024);
80    let mut buffer = [0_u8; 1024];
81
82    loop {
83        let bytes_read = tokio::time::timeout(READ_TIMEOUT, stream.read(&mut buffer))
84            .await
85            .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "timed out reading request"))??;
86
87        if bytes_read == 0 {
88            if request.is_empty() {
89                return Ok(None);
90            }
91            break;
92        }
93
94        request.extend_from_slice(&buffer[..bytes_read]);
95
96        if request.windows(2).any(|window| window == b"\r\n")
97            || request.contains(&b'\n')
98            || request.len() >= MAX_REQUEST_BYTES
99        {
100            break;
101        }
102    }
103
104    Ok(Some(String::from_utf8_lossy(&request).to_string()))
105}
106
107async fn write_response(
108    stream: &mut TcpStream,
109    status: &str,
110    content_type: &str,
111    body: &str,
112) -> io::Result<()> {
113    let response = format!(
114        "HTTP/1.1 {status}\r\nContent-Type: {content_type}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
115        body.len(),
116    );
117
118    tokio::time::timeout(WRITE_TIMEOUT, stream.write_all(response.as_bytes()))
119        .await
120        .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "timed out writing response"))??;
121
122    Ok(())
123}
124
125async fn handle_connection(
126    mut stream: TcpStream,
127    metrics_path: String,
128    metrics_handler: MetricsHandler,
129) -> io::Result<()> {
130    let Some(request) = read_request(&mut stream).await? else {
131        return Ok(());
132    };
133
134    if request_matches_path(&request, &metrics_path) {
135        let metrics = metrics_handler();
136        write_response(
137            &mut stream,
138            "200 OK",
139            "text/plain; version=0.0.4; charset=utf-8",
140            &metrics,
141        )
142        .await
143    } else {
144        write_response(&mut stream, "404 Not Found", "text/plain", "Not Found").await
145    }
146}
147
148impl PrometheusServer {
149    /// Create a new Prometheus server with CDK metrics
150    ///
151    /// # Errors
152    /// Returns an error if system metrics cannot be created (when enabled)
153    pub fn new(config: PrometheusConfig) -> crate::Result<Self> {
154        let registry = METRICS.registry();
155
156        #[cfg(feature = "system-metrics")]
157        let system_metrics = if config.include_system_metrics {
158            let sys_metrics = SystemMetrics::new()?;
159            Some(sys_metrics)
160        } else {
161            None
162        };
163
164        Ok(Self {
165            config,
166            registry,
167            #[cfg(feature = "system-metrics")]
168            system_metrics,
169        })
170    }
171
172    /// Create a new Prometheus server with custom registry
173    #[must_use]
174    pub const fn with_registry(config: PrometheusConfig, registry: Arc<Registry>) -> Self {
175        Self {
176            config,
177            registry,
178            #[cfg(feature = "system-metrics")]
179            system_metrics: None,
180        }
181    }
182
183    /// Create a metrics handler function that gathers and encodes metrics
184    fn create_metrics_handler(
185        registry: Arc<Registry>,
186        #[cfg(feature = "system-metrics")] system_metrics: Option<SystemMetrics>,
187    ) -> MetricsHandler {
188        Arc::new(move || {
189            let encoder = TextEncoder::new();
190
191            // Collect metrics from our registry
192            #[cfg(feature = "system-metrics")]
193            let mut metric_families = registry.gather();
194            #[cfg(not(feature = "system-metrics"))]
195            let metric_families = registry.gather();
196
197            // Add system metrics if available
198            #[cfg(feature = "system-metrics")]
199            if let Some(ref sys_metrics) = system_metrics {
200                // Update system metrics before collection
201                if let Err(e) = sys_metrics.update_metrics() {
202                    tracing::warn!("Failed to update system metrics: {e}");
203                }
204
205                let sys_registry = sys_metrics.registry();
206                let mut sys_families = sys_registry.gather();
207                metric_families.append(&mut sys_families);
208            }
209
210            // Encode metrics to string
211            encoder
212                .encode_to_string(&metric_families)
213                .unwrap_or_else(|e| {
214                    tracing::error!("Failed to encode metrics: {e}");
215                    format!("Failed to encode metrics: {e}")
216                })
217        })
218    }
219
220    /// Start the Prometheus HTTP server
221    ///
222    /// # Errors
223    /// Returns an error if the server cannot bind to the configured address
224    pub async fn start(
225        self,
226        shutdown_signal: impl std::future::Future<Output = ()> + Send + 'static,
227    ) -> crate::Result<()> {
228        let binding = self.config.bind_address;
229        let registry_clone = Arc::<Registry>::clone(&self.registry);
230        let path = self.config.metrics_path.clone();
231
232        #[cfg(feature = "system-metrics")]
233        let metrics_handler =
234            Self::create_metrics_handler(registry_clone, self.system_metrics.clone());
235
236        #[cfg(not(feature = "system-metrics"))]
237        let metrics_handler = Self::create_metrics_handler(registry_clone);
238
239        let listener = TcpListener::bind(binding).await.map_err(|source| {
240            crate::error::PrometheusError::ServerBind {
241                address: binding.to_string(),
242                source,
243            }
244        })?;
245
246        tracing::info!("Started Prometheus server on {} at path {}", binding, path);
247
248        tokio::pin!(shutdown_signal);
249
250        loop {
251            tokio::select! {
252                _ = &mut shutdown_signal => {
253                    tracing::info!("Shutdown signal received, stopping Prometheus server");
254                    break;
255                }
256                accept_result = listener.accept() => {
257                    match accept_result {
258                        Ok((stream, _peer_addr)) => {
259                            let metrics_path = path.clone();
260                            let metrics_handler = Arc::clone(&metrics_handler);
261
262                            tokio::spawn(async move {
263                                if let Err(e) =
264                                    handle_connection(stream, metrics_path, metrics_handler).await
265                                {
266                                    tracing::warn!("Failed to serve Prometheus scrape: {e}");
267                                }
268                            });
269                        }
270                        Err(e) => {
271                            tracing::error!("Failed to accept connection: {e}");
272                            tokio::time::sleep(Duration::from_millis(100)).await;
273                        }
274                    }
275                }
276            }
277        }
278
279        tracing::info!("Prometheus server stopped");
280
281        Ok(())
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::request_matches_path;

    #[test]
    fn request_matching_requires_exact_request_target() {
        assert!(request_matches_path(
            "GET /metrics HTTP/1.1\r\n\r\n",
            "/metrics"
        ));
        assert!(request_matches_path(
            "GET /metrics?name=value HTTP/1.1\r\n\r\n",
            "/metrics"
        ));
        assert!(!request_matches_path(
            "GET /not-metrics HTTP/1.1\r\nX-Path: /metrics\r\n\r\n",
            "/metrics"
        ));
        assert!(!request_matches_path(
            "POST /metrics HTTP/1.1\r\n\r\n",
            "/metrics"
        ));
        // Prefix and trailing-slash variants are different targets.
        assert!(!request_matches_path(
            "GET /metrics/ HTTP/1.1\r\n\r\n",
            "/metrics"
        ));
        assert!(!request_matches_path(
            "GET /metrics-extra HTTP/1.1\r\n\r\n",
            "/metrics"
        ));
    }

    #[test]
    fn request_matching_rejects_malformed_request_lines() {
        // No request line at all.
        assert!(!request_matches_path("", "/metrics"));
        // Missing HTTP version.
        assert!(!request_matches_path("GET /metrics\r\n\r\n", "/metrics"));
        // Method comparison is case-sensitive.
        assert!(!request_matches_path(
            "get /metrics HTTP/1.1\r\n\r\n",
            "/metrics"
        ));
        // Only the first line counts, even if a later line would match.
        assert!(!request_matches_path(
            "\r\nGET /metrics HTTP/1.1\r\n\r\n",
            "/metrics"
        ));
    }

    #[test]
    fn request_matching_accepts_lf_endings_and_custom_paths() {
        assert!(request_matches_path("GET /metrics HTTP/1.1\n\n", "/metrics"));
        assert!(request_matches_path(
            "GET /custom HTTP/1.0\r\n\r\n",
            "/custom"
        ));
    }
}
309
310/// Builder for easy Prometheus server setup
311#[derive(Debug)]
312pub struct PrometheusBuilder {
313    config: PrometheusConfig,
314}
315
316impl PrometheusBuilder {
317    /// Create a new builder with default configuration
318    #[must_use]
319    pub fn new() -> Self {
320        Self {
321            config: PrometheusConfig::default(),
322        }
323    }
324
325    /// Set the bind address
326    #[must_use]
327    pub const fn bind_address(mut self, addr: SocketAddr) -> Self {
328        self.config.bind_address = addr;
329        self
330    }
331
332    /// Set the metrics path
333    #[must_use]
334    pub fn metrics_path<S: Into<String>>(mut self, path: S) -> Self {
335        self.config.metrics_path = path.into();
336        self
337    }
338
339    /// Enable or disable system metrics
340    #[cfg(feature = "system-metrics")]
341    #[must_use]
342    pub const fn system_metrics(mut self, enabled: bool) -> Self {
343        self.config.include_system_metrics = enabled;
344        self
345    }
346
347    /// Set system metrics update interval
348    #[cfg(feature = "system-metrics")]
349    #[must_use]
350    pub const fn system_metrics_interval(mut self, seconds: u64) -> Self {
351        self.config.system_metrics_interval = seconds;
352        self
353    }
354
355    /// Build the server with specific CDK metrics instance
356    ///
357    /// # Errors
358    /// Returns an error if system metrics cannot be created (when enabled)
359    pub fn build_with_cdk_metrics(self) -> crate::Result<PrometheusServer> {
360        PrometheusServer::new(self.config)
361    }
362
363    /// Build the server with custom registry
364    #[must_use]
365    pub fn build_with_registry(self, registry: Arc<Registry>) -> PrometheusServer {
366        PrometheusServer::with_registry(self.config, registry)
367    }
368}
369
370impl Default for PrometheusBuilder {
371    fn default() -> Self {
372        Self::new()
373    }
374}