cdk_prometheus/
server.rs

1use std::net::SocketAddr;
2use std::sync::Arc;
3use std::time::Duration;
4
5use prometheus::{Registry, TextEncoder};
6
7use crate::metrics::METRICS;
8#[cfg(feature = "system-metrics")]
9use crate::process::SystemMetrics;
10
/// Settings that control where and how the Prometheus metrics endpoint is served.
#[derive(Debug, Clone)]
pub struct PrometheusConfig {
    /// Socket address the HTTP listener binds to.
    ///
    /// Defaults to `127.0.0.1:9090`.
    pub bind_address: SocketAddr,
    /// HTTP path under which metrics are exposed.
    ///
    /// Defaults to `/metrics`.
    pub metrics_path: String,
    /// Whether process/system metrics are collected alongside the CDK metrics.
    ///
    /// Defaults to `true`. Only present with the `system-metrics` feature.
    #[cfg(feature = "system-metrics")]
    pub include_system_metrics: bool,
    /// Intended interval, in seconds, between system-metric refreshes.
    ///
    /// Defaults to `15`. Only present with the `system-metrics` feature.
    // NOTE(review): the server code next to this type never reads this field (system
    // metrics are refreshed on every scrape instead); confirm whether it is used elsewhere.
    #[cfg(feature = "system-metrics")]
    pub system_metrics_interval: u64,
}
25
26impl Default for PrometheusConfig {
27    fn default() -> Self {
28        Self {
29            bind_address: "127.0.0.1:9090".parse().expect("Invalid default address"),
30            metrics_path: "/metrics".to_string(),
31            #[cfg(feature = "system-metrics")]
32            include_system_metrics: true,
33            #[cfg(feature = "system-metrics")]
34            system_metrics_interval: 15,
35        }
36    }
37}
38
39/// Prometheus metrics server
40#[derive(Debug)]
41pub struct PrometheusServer {
42    config: PrometheusConfig,
43    registry: Arc<Registry>,
44    #[cfg(feature = "system-metrics")]
45    system_metrics: Option<SystemMetrics>,
46}
47
48impl PrometheusServer {
49    /// Create a new Prometheus server with CDK metrics
50    ///
51    /// # Errors
52    /// Returns an error if system metrics cannot be created (when enabled)
53    pub fn new(config: PrometheusConfig) -> crate::Result<Self> {
54        let registry = METRICS.registry();
55
56        #[cfg(feature = "system-metrics")]
57        let system_metrics = if config.include_system_metrics {
58            let sys_metrics = SystemMetrics::new()?;
59            Some(sys_metrics)
60        } else {
61            None
62        };
63
64        Ok(Self {
65            config,
66            registry,
67            #[cfg(feature = "system-metrics")]
68            system_metrics,
69        })
70    }
71
72    /// Create a new Prometheus server with custom registry
73    #[must_use]
74    pub const fn with_registry(config: PrometheusConfig, registry: Arc<Registry>) -> Self {
75        Self {
76            config,
77            registry,
78            #[cfg(feature = "system-metrics")]
79            system_metrics: None,
80        }
81    }
82
83    /// Create a metrics handler function that gathers and encodes metrics
84    fn create_metrics_handler(
85        registry: Arc<Registry>,
86        #[cfg(feature = "system-metrics")] system_metrics: Option<SystemMetrics>,
87    ) -> impl Fn() -> String {
88        move || {
89            let encoder = TextEncoder::new();
90
91            // Collect metrics from our registry
92            #[cfg(feature = "system-metrics")]
93            let mut metric_families = registry.gather();
94            #[cfg(not(feature = "system-metrics"))]
95            let metric_families = registry.gather();
96
97            // Add system metrics if available
98            #[cfg(feature = "system-metrics")]
99            if let Some(ref sys_metrics) = system_metrics {
100                // Update system metrics before collection
101                if let Err(e) = sys_metrics.update_metrics() {
102                    tracing::warn!("Failed to update system metrics: {e}");
103                }
104
105                let sys_registry = sys_metrics.registry();
106                let mut sys_families = sys_registry.gather();
107                metric_families.append(&mut sys_families);
108            }
109
110            // Encode metrics to string
111            encoder
112                .encode_to_string(&metric_families)
113                .unwrap_or_else(|e| {
114                    tracing::error!("Failed to encode metrics: {e}");
115                    format!("Failed to encode metrics: {e}")
116                })
117        }
118    }
119
120    /// Start the Prometheus HTTP server
121    ///
122    /// # Errors
123    /// This function always returns Ok as errors are handled internally
124    pub async fn start(
125        self,
126        shutdown_signal: impl std::future::Future<Output = ()> + Send + 'static,
127    ) -> crate::Result<()> {
128        // Create and start the exporter
129        let binding = self.config.bind_address;
130        let registry_clone = Arc::<Registry>::clone(&self.registry);
131
132        // Create a handler that exposes our registry
133        #[cfg(feature = "system-metrics")]
134        let metrics_handler =
135            Self::create_metrics_handler(registry_clone, self.system_metrics.clone());
136
137        #[cfg(not(feature = "system-metrics"))]
138        let metrics_handler = Self::create_metrics_handler(registry_clone);
139
140        // Start the exporter in a background task
141        let path = self.config.metrics_path.clone();
142
143        // Create a channel for signaling the server task to shutdown
144        let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel::<()>();
145
146        // Spawn the server task
147        let server_handle = tokio::spawn(async move {
148            // We're using a simple HTTP server to expose our metrics
149            use std::io::{Read, Write};
150            use std::net::TcpListener;
151
152            // Create a TCP listener
153            let listener = match TcpListener::bind(binding) {
154                Ok(listener) => {
155                    // Set non-blocking mode to allow for shutdown checking
156                    if let Err(e) = listener.set_nonblocking(true) {
157                        tracing::error!("Failed to set non-blocking mode: {e}");
158                        return;
159                    }
160                    listener
161                }
162                Err(e) => {
163                    tracing::error!("Failed to bind TCP listener: {e}");
164                    return;
165                }
166            };
167            tracing::info!("Started Prometheus server on {} at path {}", binding, path);
168
169            // Accept connections with shutdown signal handling
170            loop {
171                // Check for shutdown signal
172                if shutdown_rx.try_recv().is_ok() {
173                    tracing::info!("Shutdown signal received, stopping Prometheus server");
174                    break;
175                }
176
177                // Try to accept a connection (non-blocking)
178                match listener.accept() {
179                    Ok((mut stream, _)) => {
180                        // Handle the connection
181                        let mut buffer = [0; 1024];
182                        match stream.read(&mut buffer) {
183                            Ok(0) => {}
184                            Ok(bytes_read) => {
185                                // Convert the buffer to a string
186                                let request = String::from_utf8_lossy(&buffer[..bytes_read]);
187
188                                // Check if the request is for our metrics path
189                                if request.contains(&format!("GET {path} HTTP")) {
190                                    // Get the metrics
191                                    let metrics = metrics_handler();
192
193                                    // Write the response
194                                    let response = format!(
195                                        "HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nContent-Length: {}\r\n\r\n{}",
196                                        metrics.len(),
197                                        metrics
198                                    );
199
200                                    if let Err(e) = stream.write_all(response.as_bytes()) {
201                                        tracing::error!("Failed to write response: {e}");
202                                    }
203                                } else {
204                                    // Write a 404 response
205                                    let response = "HTTP/1.1 404 Not Found\r\nContent-Type: text/plain\r\nContent-Length: 9\r\n\r\nNot Found";
206                                    if let Err(e) = stream.write_all(response.as_bytes()) {
207                                        tracing::error!("Failed to write response: {e}");
208                                    }
209                                }
210                            }
211                            Err(e) => {
212                                tracing::error!("Failed to read from stream: {e}");
213                            }
214                        }
215                    }
216                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
217                        // No connection available, continue the loop
218                        tokio::time::sleep(Duration::from_millis(10)).await;
219                    }
220                    Err(e) => {
221                        tracing::error!("Failed to accept connection: {e}");
222                        // Add a small delay to prevent busy looping on persistent errors
223                        tokio::time::sleep(Duration::from_millis(100)).await;
224                    }
225                }
226            }
227
228            tracing::info!("Prometheus server stopped");
229        });
230
231        // Wait for the shutdown signal
232        shutdown_signal.await;
233
234        // Signal the server to shutdown
235        let _ = shutdown_tx.send(());
236
237        // Wait for the server task to complete (with a timeout)
238        match tokio::time::timeout(Duration::from_secs(5), server_handle).await {
239            Ok(result) => {
240                if let Err(e) = result {
241                    tracing::error!("Server task failed: {e}");
242                }
243            }
244            Err(_) => {
245                tracing::warn!("Server shutdown timed out after 5 seconds");
246            }
247        }
248
249        Ok(())
250    }
251}
252
253/// Builder for easy Prometheus server setup
254#[derive(Debug)]
255pub struct PrometheusBuilder {
256    config: PrometheusConfig,
257}
258
259impl PrometheusBuilder {
260    /// Create a new builder with default configuration
261    #[must_use]
262    pub fn new() -> Self {
263        Self {
264            config: PrometheusConfig::default(),
265        }
266    }
267
268    /// Set the bind address
269    #[must_use]
270    pub const fn bind_address(mut self, addr: SocketAddr) -> Self {
271        self.config.bind_address = addr;
272        self
273    }
274
275    /// Set the metrics path
276    #[must_use]
277    pub fn metrics_path<S: Into<String>>(mut self, path: S) -> Self {
278        self.config.metrics_path = path.into();
279        self
280    }
281
282    /// Enable or disable system metrics
283    #[cfg(feature = "system-metrics")]
284    #[must_use]
285    pub const fn system_metrics(mut self, enabled: bool) -> Self {
286        self.config.include_system_metrics = enabled;
287        self
288    }
289
290    /// Set system metrics update interval
291    #[cfg(feature = "system-metrics")]
292    #[must_use]
293    pub const fn system_metrics_interval(mut self, seconds: u64) -> Self {
294        self.config.system_metrics_interval = seconds;
295        self
296    }
297
298    /// Build the server with specific CDK metrics instance
299    ///
300    /// # Errors
301    /// Returns an error if system metrics cannot be created (when enabled)
302    pub fn build_with_cdk_metrics(self) -> crate::Result<PrometheusServer> {
303        PrometheusServer::new(self.config)
304    }
305
306    /// Build the server with custom registry
307    #[must_use]
308    pub fn build_with_registry(self, registry: Arc<Registry>) -> PrometheusServer {
309        PrometheusServer::with_registry(self.config, registry)
310    }
311}
312
313impl Default for PrometheusBuilder {
314    fn default() -> Self {
315        Self::new()
316    }
317}