Skip to main content

entrenar/monitor/prometheus/
exporter.rs

1//! Prometheus metrics exporter implementation.
2
3use std::collections::HashMap;
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::sync::RwLock;
6use std::time::{SystemTime, UNIX_EPOCH};
7
8use super::types::{LabelSet, MetricDef, MetricType, MetricValue};
9
10/// Prometheus metrics exporter for training monitoring
11#[derive(Debug)]
12pub struct PrometheusExporter {
13    /// Default labels applied to all metrics
14    default_labels: LabelSet,
15    /// Metric definitions
16    pub(crate) definitions: HashMap<String, MetricDef>,
17    /// Current metric values
18    values: RwLock<HashMap<String, Vec<MetricValue>>>,
19    /// Total samples counter
20    total_samples: AtomicU64,
21}
22
23impl PrometheusExporter {
24    /// Create a new exporter with experiment and run labels
25    pub fn new(experiment: &str, run: &str) -> Self {
26        let default_labels = LabelSet::from_pairs(&[("experiment", experiment), ("run", run)]);
27
28        let mut exporter = Self {
29            default_labels,
30            definitions: HashMap::new(),
31            values: RwLock::new(HashMap::new()),
32            total_samples: AtomicU64::new(0),
33        };
34
35        // Register default training metrics
36        exporter.register_default_metrics();
37        exporter
38    }
39
40    /// Create with custom default labels
41    pub fn with_labels(labels: LabelSet) -> Self {
42        let mut exporter = Self {
43            default_labels: labels,
44            definitions: HashMap::new(),
45            values: RwLock::new(HashMap::new()),
46            total_samples: AtomicU64::new(0),
47        };
48
49        exporter.register_default_metrics();
50        exporter
51    }
52
53    /// Register default training metrics
54    fn register_default_metrics(&mut self) {
55        // Training metrics
56        self.register(
57            MetricDef::gauge("entrenar_epoch_loss", "Training loss per epoch")
58                .with_labels(&["experiment", "run"]),
59        );
60        self.register(
61            MetricDef::gauge("entrenar_validation_loss", "Validation loss per epoch")
62                .with_labels(&["experiment", "run"]),
63        );
64        self.register(
65            MetricDef::gauge("entrenar_learning_rate", "Current learning rate")
66                .with_labels(&["experiment", "run"]),
67        );
68        self.register(
69            MetricDef::gauge("entrenar_batch_throughput", "Batches processed per second")
70                .with_labels(&["experiment", "run"]),
71        );
72        self.register(
73            MetricDef::gauge("entrenar_validation_accuracy", "Validation accuracy")
74                .with_labels(&["experiment", "run"]),
75        );
76        self.register(
77            MetricDef::counter("entrenar_training_steps_total", "Total training steps completed")
78                .with_labels(&["experiment", "run"]),
79        );
80
81        // GPU metrics
82        self.register(
83            MetricDef::gauge("entrenar_gpu_utilization", "GPU utilization percentage")
84                .with_labels(&["experiment", "run", "device"]),
85        );
86        self.register(
87            MetricDef::gauge("entrenar_gpu_memory_used_bytes", "GPU memory used in bytes")
88                .with_labels(&["experiment", "run", "device"]),
89        );
90        self.register(
91            MetricDef::gauge("entrenar_gpu_temperature_celsius", "GPU temperature in Celsius")
92                .with_labels(&["experiment", "run", "device"]),
93        );
94        self.register(
95            MetricDef::gauge("entrenar_gpu_power_watts", "GPU power draw in watts").with_labels(&[
96                "experiment",
97                "run",
98                "device",
99            ]),
100        );
101
102        // System metrics
103        self.register(
104            MetricDef::gauge("entrenar_memory_used_bytes", "Process memory usage in bytes")
105                .with_labels(&["experiment", "run"]),
106        );
107    }
108
109    /// Register a metric definition
110    pub fn register(&mut self, def: MetricDef) {
111        self.definitions.insert(def.name.clone(), def);
112    }
113
114    /// Record a metric value
115    pub fn record(&self, name: &str, value: f64) {
116        self.record_with_labels(name, value, self.default_labels.clone());
117    }
118
119    /// Record a metric value with additional labels
120    pub fn record_with_labels(&self, name: &str, value: f64, mut labels: LabelSet) {
121        // Merge default labels
122        for (k, v) in &self.default_labels.values {
123            if !labels.values.iter().any(|(lk, _)| lk == k) {
124                labels.values.push((k.clone(), v.clone()));
125            }
126        }
127
128        let metric_value = MetricValue { labels, value, timestamp: Some(current_timestamp_ms()) };
129
130        if let Ok(mut values) = self.values.write() {
131            values.entry(name.to_string()).or_default().push(metric_value);
132        }
133
134        self.total_samples.fetch_add(1, Ordering::Relaxed);
135    }
136
137    /// Record epoch metrics
138    pub fn record_epoch(&self, epoch: u32, loss: f64, lr: f64) {
139        self.record("entrenar_epoch_loss", loss);
140        self.record("entrenar_learning_rate", lr);
141        self.record("entrenar_training_steps_total", f64::from(epoch));
142    }
143
144    /// Record validation metrics
145    pub fn record_validation(&self, loss: f64, accuracy: f64) {
146        self.record("entrenar_validation_loss", loss);
147        self.record("entrenar_validation_accuracy", accuracy);
148    }
149
150    /// Record batch throughput
151    pub fn record_batch(&self, batches_per_second: f64) {
152        self.record("entrenar_batch_throughput", batches_per_second);
153    }
154
155    /// Record GPU metrics for a device
156    pub fn record_gpu(
157        &self,
158        device_id: u32,
159        utilization: f64,
160        memory_bytes: f64,
161        temp: f64,
162        power: f64,
163    ) {
164        let labels = self.default_labels.clone().add("device", &device_id.to_string());
165
166        self.record_with_labels("entrenar_gpu_utilization", utilization, labels.clone());
167        self.record_with_labels("entrenar_gpu_memory_used_bytes", memory_bytes, labels.clone());
168        self.record_with_labels("entrenar_gpu_temperature_celsius", temp, labels.clone());
169        self.record_with_labels("entrenar_gpu_power_watts", power, labels);
170    }
171
172    /// Record system memory usage
173    pub fn record_memory(&self, used_bytes: f64) {
174        self.record("entrenar_memory_used_bytes", used_bytes);
175    }
176
177    /// Get total samples recorded
178    pub fn total_samples(&self) -> u64 {
179        self.total_samples.load(Ordering::Relaxed)
180    }
181
182    /// Clear all recorded values
183    pub fn clear(&self) {
184        if let Ok(mut values) = self.values.write() {
185            values.clear();
186        }
187    }
188
189    /// Export metrics in Prometheus text format
190    pub fn export(&self) -> String {
191        let mut output = String::new();
192
193        let values = match self.values.read() {
194            Ok(v) => v,
195            Err(_) => return output,
196        };
197
198        for (name, def) in &self.definitions {
199            if let Some(metric_values) = values.get(name) {
200                // Only export if we have values
201                if metric_values.is_empty() {
202                    continue;
203                }
204
205                // HELP line
206                output.push_str(&format!("# HELP {} {}\n", name, def.help));
207
208                // TYPE line
209                let type_str = match def.metric_type {
210                    MetricType::Gauge => "gauge",
211                    MetricType::Counter => "counter",
212                    MetricType::Histogram => "histogram",
213                };
214                output.push_str(&format!("# TYPE {name} {type_str}\n"));
215
216                // Get latest value for each unique label set
217                let mut latest: HashMap<String, &MetricValue> = HashMap::new();
218                for mv in metric_values {
219                    let key = mv.labels.format();
220                    latest.insert(key, mv);
221                }
222
223                // Metric lines
224                for mv in latest.values() {
225                    let labels_str = mv.labels.format();
226                    if let Some(ts) = mv.timestamp {
227                        output.push_str(&format!("{}{} {} {}\n", name, labels_str, mv.value, ts));
228                    } else {
229                        output.push_str(&format!("{}{} {}\n", name, labels_str, mv.value));
230                    }
231                }
232            }
233        }
234
235        output
236    }
237
238    /// Export metrics as JSON (for programmatic access)
239    pub fn export_json(&self) -> String {
240        let values = match self.values.read() {
241            Ok(v) => v,
242            Err(_) => return "{}".to_string(),
243        };
244
245        let mut metrics: HashMap<String, Vec<serde_json::Value>> = HashMap::new();
246
247        for (name, metric_values) in values.iter() {
248            let json_values: Vec<serde_json::Value> = metric_values
249                .iter()
250                .map(|mv| {
251                    let mut obj = serde_json::Map::new();
252                    for (k, v) in &mv.labels.values {
253                        obj.insert(k.clone(), serde_json::Value::String(v.clone()));
254                    }
255                    obj.insert("value".to_string(), serde_json::json!(mv.value));
256                    if let Some(ts) = mv.timestamp {
257                        obj.insert("timestamp".to_string(), serde_json::json!(ts));
258                    }
259                    serde_json::Value::Object(obj)
260                })
261                .collect();
262
263            metrics.insert(name.clone(), json_values);
264        }
265
266        serde_json::to_string_pretty(&metrics).unwrap_or_default()
267    }
268}
269
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// Yields 0 when the system clock reads earlier than the epoch. A millisecond count
/// that does not fit in a `u64` saturates at `u64::MAX` instead of wrapping.
fn current_timestamp_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => {
            // Clamp first so the narrowing cast below can never truncate.
            let millis = elapsed.as_millis().min(u128::from(u64::MAX));
            millis as u64
        }
        Err(_) => 0,
    }
}