1use std::io;
2use std::net::SocketAddr;
3use std::sync::Arc;
4use std::time::Duration;
5
6use prometheus::{Registry, TextEncoder};
7use tokio::io::{AsyncReadExt, AsyncWriteExt};
8use tokio::net::{TcpListener, TcpStream};
9
10use crate::metrics::METRICS;
11#[cfg(feature = "system-metrics")]
12use crate::process::SystemMetrics;
13
/// Size threshold, in bytes, at which buffering of an incoming request stops.
const MAX_REQUEST_BYTES: usize = 4 * 1024;
/// Time limit applied when reading a request from a client.
const READ_TIMEOUT: Duration = Duration::from_secs(5);
/// Time limit applied when writing a response to a client.
const WRITE_TIMEOUT: Duration = Duration::from_secs(5);

/// Shareable callback that renders the current metrics as a text payload.
type MetricsHandler = Arc<dyn Fn() -> String + Send + Sync + 'static>;
19
/// Settings for the Prometheus metrics endpoint.
#[derive(Clone, Debug)]
pub struct PrometheusConfig {
    /// Socket address the metrics listener binds to.
    pub bind_address: SocketAddr,
    /// Request path under which the metrics are served.
    pub metrics_path: String,
    /// Whether [`PrometheusServer::new`] also sets up system metrics.
    #[cfg(feature = "system-metrics")]
    pub include_system_metrics: bool,
    /// Interval for system metrics, in seconds according to the builder's
    /// parameter name.
    ///
    /// NOTE(review): this value is not read anywhere in this file; confirm
    /// where (or whether) it is consumed.
    #[cfg(feature = "system-metrics")]
    pub system_metrics_interval: u64,
}
34
35impl Default for PrometheusConfig {
36 fn default() -> Self {
37 Self {
38 bind_address: "127.0.0.1:9090".parse().expect("Invalid default address"),
39 metrics_path: "/metrics".to_string(),
40 #[cfg(feature = "system-metrics")]
41 include_system_metrics: true,
42 #[cfg(feature = "system-metrics")]
43 system_metrics_interval: 15,
44 }
45 }
46}
47
48#[derive(Debug)]
50pub struct PrometheusServer {
51 config: PrometheusConfig,
52 registry: Arc<Registry>,
53 #[cfg(feature = "system-metrics")]
54 system_metrics: Option<SystemMetrics>,
55}
56
/// Reports whether `request` is a `GET` for exactly `metrics_path`.
///
/// Only the first line of `request` is inspected. It must consist of at least
/// three whitespace-separated tokens (method, target, version); the method
/// must be `GET`, and the target, with everything from the first `?` onwards
/// removed, must equal `metrics_path`. The version token only has to be
/// present; its value is not checked.
fn request_matches_path(request: &str, metrics_path: &str) -> bool {
    let mut tokens = request
        .lines()
        .next()
        .unwrap_or_default()
        .split_whitespace();

    // Tuple elements are evaluated left to right, so the tokens are consumed
    // in method / target / version order.
    let (Some("GET"), Some(target), Some(_version)) =
        (tokens.next(), tokens.next(), tokens.next())
    else {
        return false;
    };

    // Drop the query string, if any, before comparing.
    let path = target.split_once('?').map_or(target, |(path, _query)| path);

    path == metrics_path
}
77
78async fn read_request(stream: &mut TcpStream) -> io::Result<Option<String>> {
79 let mut request = Vec::with_capacity(1024);
80 let mut buffer = [0_u8; 1024];
81
82 loop {
83 let bytes_read = tokio::time::timeout(READ_TIMEOUT, stream.read(&mut buffer))
84 .await
85 .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "timed out reading request"))??;
86
87 if bytes_read == 0 {
88 if request.is_empty() {
89 return Ok(None);
90 }
91 break;
92 }
93
94 request.extend_from_slice(&buffer[..bytes_read]);
95
96 if request.windows(2).any(|window| window == b"\r\n")
97 || request.contains(&b'\n')
98 || request.len() >= MAX_REQUEST_BYTES
99 {
100 break;
101 }
102 }
103
104 Ok(Some(String::from_utf8_lossy(&request).to_string()))
105}
106
107async fn write_response(
108 stream: &mut TcpStream,
109 status: &str,
110 content_type: &str,
111 body: &str,
112) -> io::Result<()> {
113 let response = format!(
114 "HTTP/1.1 {status}\r\nContent-Type: {content_type}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
115 body.len(),
116 );
117
118 tokio::time::timeout(WRITE_TIMEOUT, stream.write_all(response.as_bytes()))
119 .await
120 .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "timed out writing response"))??;
121
122 Ok(())
123}
124
125async fn handle_connection(
126 mut stream: TcpStream,
127 metrics_path: String,
128 metrics_handler: MetricsHandler,
129) -> io::Result<()> {
130 let Some(request) = read_request(&mut stream).await? else {
131 return Ok(());
132 };
133
134 if request_matches_path(&request, &metrics_path) {
135 let metrics = metrics_handler();
136 write_response(
137 &mut stream,
138 "200 OK",
139 "text/plain; version=0.0.4; charset=utf-8",
140 &metrics,
141 )
142 .await
143 } else {
144 write_response(&mut stream, "404 Not Found", "text/plain", "Not Found").await
145 }
146}
147
148impl PrometheusServer {
149 pub fn new(config: PrometheusConfig) -> crate::Result<Self> {
154 let registry = METRICS.registry();
155
156 #[cfg(feature = "system-metrics")]
157 let system_metrics = if config.include_system_metrics {
158 let sys_metrics = SystemMetrics::new()?;
159 Some(sys_metrics)
160 } else {
161 None
162 };
163
164 Ok(Self {
165 config,
166 registry,
167 #[cfg(feature = "system-metrics")]
168 system_metrics,
169 })
170 }
171
172 #[must_use]
174 pub const fn with_registry(config: PrometheusConfig, registry: Arc<Registry>) -> Self {
175 Self {
176 config,
177 registry,
178 #[cfg(feature = "system-metrics")]
179 system_metrics: None,
180 }
181 }
182
183 fn create_metrics_handler(
185 registry: Arc<Registry>,
186 #[cfg(feature = "system-metrics")] system_metrics: Option<SystemMetrics>,
187 ) -> MetricsHandler {
188 Arc::new(move || {
189 let encoder = TextEncoder::new();
190
191 #[cfg(feature = "system-metrics")]
193 let mut metric_families = registry.gather();
194 #[cfg(not(feature = "system-metrics"))]
195 let metric_families = registry.gather();
196
197 #[cfg(feature = "system-metrics")]
199 if let Some(ref sys_metrics) = system_metrics {
200 if let Err(e) = sys_metrics.update_metrics() {
202 tracing::warn!("Failed to update system metrics: {e}");
203 }
204
205 let sys_registry = sys_metrics.registry();
206 let mut sys_families = sys_registry.gather();
207 metric_families.append(&mut sys_families);
208 }
209
210 encoder
212 .encode_to_string(&metric_families)
213 .unwrap_or_else(|e| {
214 tracing::error!("Failed to encode metrics: {e}");
215 format!("Failed to encode metrics: {e}")
216 })
217 })
218 }
219
220 pub async fn start(
225 self,
226 shutdown_signal: impl std::future::Future<Output = ()> + Send + 'static,
227 ) -> crate::Result<()> {
228 let binding = self.config.bind_address;
229 let registry_clone = Arc::<Registry>::clone(&self.registry);
230 let path = self.config.metrics_path.clone();
231
232 #[cfg(feature = "system-metrics")]
233 let metrics_handler =
234 Self::create_metrics_handler(registry_clone, self.system_metrics.clone());
235
236 #[cfg(not(feature = "system-metrics"))]
237 let metrics_handler = Self::create_metrics_handler(registry_clone);
238
239 let listener = TcpListener::bind(binding).await.map_err(|source| {
240 crate::error::PrometheusError::ServerBind {
241 address: binding.to_string(),
242 source,
243 }
244 })?;
245
246 tracing::info!("Started Prometheus server on {} at path {}", binding, path);
247
248 tokio::pin!(shutdown_signal);
249
250 loop {
251 tokio::select! {
252 _ = &mut shutdown_signal => {
253 tracing::info!("Shutdown signal received, stopping Prometheus server");
254 break;
255 }
256 accept_result = listener.accept() => {
257 match accept_result {
258 Ok((stream, _peer_addr)) => {
259 let metrics_path = path.clone();
260 let metrics_handler = Arc::clone(&metrics_handler);
261
262 tokio::spawn(async move {
263 if let Err(e) =
264 handle_connection(stream, metrics_path, metrics_handler).await
265 {
266 tracing::warn!("Failed to serve Prometheus scrape: {e}");
267 }
268 });
269 }
270 Err(e) => {
271 tracing::error!("Failed to accept connection: {e}");
272 tokio::time::sleep(Duration::from_millis(100)).await;
273 }
274 }
275 }
276 }
277 }
278
279 tracing::info!("Prometheus server stopped");
280
281 Ok(())
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::request_matches_path;

    /// Only a `GET` whose target path (query string ignored) equals the
    /// metrics path may match; header contents must not influence the result.
    #[test]
    fn request_matching_requires_exact_request_target() {
        let cases = [
            ("GET /metrics HTTP/1.1\r\n\r\n", true),
            ("GET /metrics?name=value HTTP/1.1\r\n\r\n", true),
            (
                "GET /not-metrics HTTP/1.1\r\nX-Path: /metrics\r\n\r\n",
                false,
            ),
            ("POST /metrics HTTP/1.1\r\n\r\n", false),
        ];

        for (request, expected) in cases {
            assert_eq!(
                request_matches_path(request, "/metrics"),
                expected,
                "unexpected result for request {request:?}"
            );
        }
    }
}
309
310#[derive(Debug)]
312pub struct PrometheusBuilder {
313 config: PrometheusConfig,
314}
315
316impl PrometheusBuilder {
317 #[must_use]
319 pub fn new() -> Self {
320 Self {
321 config: PrometheusConfig::default(),
322 }
323 }
324
325 #[must_use]
327 pub const fn bind_address(mut self, addr: SocketAddr) -> Self {
328 self.config.bind_address = addr;
329 self
330 }
331
332 #[must_use]
334 pub fn metrics_path<S: Into<String>>(mut self, path: S) -> Self {
335 self.config.metrics_path = path.into();
336 self
337 }
338
339 #[cfg(feature = "system-metrics")]
341 #[must_use]
342 pub const fn system_metrics(mut self, enabled: bool) -> Self {
343 self.config.include_system_metrics = enabled;
344 self
345 }
346
347 #[cfg(feature = "system-metrics")]
349 #[must_use]
350 pub const fn system_metrics_interval(mut self, seconds: u64) -> Self {
351 self.config.system_metrics_interval = seconds;
352 self
353 }
354
355 pub fn build_with_cdk_metrics(self) -> crate::Result<PrometheusServer> {
360 PrometheusServer::new(self.config)
361 }
362
363 #[must_use]
365 pub fn build_with_registry(self, registry: Arc<Registry>) -> PrometheusServer {
366 PrometheusServer::with_registry(self.config, registry)
367 }
368}
369
370impl Default for PrometheusBuilder {
371 fn default() -> Self {
372 Self::new()
373 }
374}