#[cfg(feature = "prometheus")]
use std::sync::OnceLock;
use std::sync::atomic::{AtomicU64, Ordering};

use metrics::Label;
#[cfg(feature = "prometheus")]
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
use smallvec::SmallVec;

use crate::_prelude::*;

type LabelSet = SmallVec<[Label; 4]>;

const METRIC_REQUESTS_TOTAL: &str = "jwks_cache_requests_total";
const METRIC_HITS_TOTAL: &str = "jwks_cache_hits_total";
const METRIC_STALE_TOTAL: &str = "jwks_cache_stale_total";
const METRIC_MISSES_TOTAL: &str = "jwks_cache_misses_total";
const METRIC_REFRESH_TOTAL: &str = "jwks_cache_refresh_total";
const METRIC_REFRESH_DURATION: &str = "jwks_cache_refresh_duration_seconds";
const METRIC_REFRESH_ERRORS: &str = "jwks_cache_refresh_errors_total";

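/// Handle to the process-wide Prometheus recorder, set once by
/// [`install_default_exporter`] and exposed through [`prometheus_handle`].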
#[cfg(feature = "prometheus")]
static PROMETHEUS_HANDLE: OnceLock<PrometheusHandle> = OnceLock::new();

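/// Lock-free, per-provider counters for JWKS cache activity.
///
/// Counters are updated with relaxed atomics on the hot path and read via
/// [`ProviderMetrics::snapshot`].
///
/// A minimal usage sketch (not compiled as a doctest; it assumes the items
/// are already in scope):
///
/// ```ignore
/// let metrics = ProviderMetrics::new();
/// metrics.record_hit(false);
///
/// let snapshot = metrics.snapshot();
/// assert_eq!(snapshot.total_requests, 1);
/// assert_eq!(snapshot.hit_rate(), 1.0);
/// ```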
#[derive(Debug, Default)]
pub struct ProviderMetrics {
    total_requests: AtomicU64,
    cache_hits: AtomicU64,
    stale_serves: AtomicU64,
    refresh_successes: AtomicU64,
    refresh_errors: AtomicU64,
    last_refresh_micros: AtomicU64,
}

impl ProviderMetrics {
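    /// Creates a zeroed set of counters, wrapped in an `Arc` so the same
    /// instance can be shared cheaply.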
    pub fn new() -> Arc<Self> {
        Arc::new(Self::default())
    }

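    /// Records a request served from cache; when `stale` is true the serve is
    /// also counted as stale.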
    pub fn record_hit(&self, stale: bool) {
        self.total_requests.fetch_add(1, Ordering::Relaxed);
        self.cache_hits.fetch_add(1, Ordering::Relaxed);
        if stale {
            self.stale_serves.fetch_add(1, Ordering::Relaxed);
        }
    }

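    /// Records a request that missed the cache. Only the request total is
    /// incremented; this struct keeps no dedicated miss counter.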
    pub fn record_miss(&self) {
        self.total_requests.fetch_add(1, Ordering::Relaxed);
    }

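    /// Records a successful JWKS refresh and stores its duration, in
    /// microseconds, as the most recent refresh time.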
    pub fn record_refresh_success(&self, duration: Duration) {
        self.refresh_successes.fetch_add(1, Ordering::Relaxed);
        self.last_refresh_micros.store(duration.as_micros() as u64, Ordering::Relaxed);
    }

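    /// Records a failed JWKS refresh attempt.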
    pub fn record_refresh_error(&self) {
        self.refresh_errors.fetch_add(1, Ordering::Relaxed);
    }

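    /// Returns a point-in-time copy of all counters. A stored refresh time of
    /// zero is reported as `None`, meaning no successful refresh has been
    /// recorded yet.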
    pub fn snapshot(&self) -> ProviderMetricsSnapshot {
        ProviderMetricsSnapshot {
            total_requests: self.total_requests.load(Ordering::Relaxed),
            cache_hits: self.cache_hits.load(Ordering::Relaxed),
            stale_serves: self.stale_serves.load(Ordering::Relaxed),
            refresh_successes: self.refresh_successes.load(Ordering::Relaxed),
            refresh_errors: self.refresh_errors.load(Ordering::Relaxed),
            last_refresh_micros: match self.last_refresh_micros.load(Ordering::Relaxed) {
                0 => None,
                value => Some(value),
            },
        }
    }
}

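/// Point-in-time view of [`ProviderMetrics`], suitable for logging or
/// diagnostics.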
#[derive(Clone, Debug)]
pub struct ProviderMetricsSnapshot {
    pub total_requests: u64,
    pub cache_hits: u64,
    pub stale_serves: u64,
    pub refresh_successes: u64,
    pub refresh_errors: u64,
    pub last_refresh_micros: Option<u64>,
}

impl ProviderMetricsSnapshot {
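    /// Fraction of requests served from cache, in `0.0..=1.0`; returns `0.0`
    /// when no requests have been recorded.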
    pub fn hit_rate(&self) -> f64 {
        if self.total_requests == 0 {
            0.0
        } else {
            self.cache_hits as f64 / self.total_requests as f64
        }
    }

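    /// Fraction of requests that were served stale; returns `0.0` when no
    /// requests have been recorded.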
    pub fn stale_ratio(&self) -> f64 {
        if self.total_requests == 0 {
            0.0
        } else {
            self.stale_serves as f64 / self.total_requests as f64
        }
    }
}

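/// Installs the default Prometheus recorder for the whole process and keeps
/// its handle for later rendering. Repeated calls after a successful install
/// are no-ops.
///
/// A minimal usage sketch (not compiled as a doctest; it assumes the
/// `prometheus` feature is enabled and the items are in scope):
///
/// ```ignore
/// install_default_exporter()?;
///
/// let body = prometheus_handle()
///     .map(|handle| handle.render())
///     .unwrap_or_default();
/// // Serve `body` from a /metrics endpoint.
/// ```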
#[cfg(feature = "prometheus")]
pub fn install_default_exporter() -> Result<()> {
    if PROMETHEUS_HANDLE.get().is_some() {
        return Ok(());
    }

    let handle = PrometheusBuilder::new()
        .install_recorder()
        .map_err(|err| Error::Metrics(err.to_string()))?;
    let _ = PROMETHEUS_HANDLE.set(handle);

    Ok(())
}

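/// Returns the installed Prometheus handle, if [`install_default_exporter`]
/// has been called, e.g. for rendering the scrape body of a `/metrics` route.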
#[cfg(feature = "prometheus")]
pub fn prometheus_handle() -> Option<&'static PrometheusHandle> {
    PROMETHEUS_HANDLE.get()
}

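/// Emits cache-hit counters for a JWKS resolve, labelled by tenant and
/// provider; stale serves are counted separately as well.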
pub fn record_resolve_hit(tenant: &str, provider: &str, stale: bool) {
    let labels = base_labels(tenant, provider);

    metrics::counter!(METRIC_REQUESTS_TOTAL, labels.iter()).increment(1);
    metrics::counter!(METRIC_HITS_TOTAL, labels.iter()).increment(1);

    if stale {
        metrics::counter!(METRIC_STALE_TOTAL, labels.iter()).increment(1);
    }
}

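/// Emits cache-miss counters for a JWKS resolve, labelled by tenant and
/// provider.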
pub fn record_resolve_miss(tenant: &str, provider: &str) {
    let labels = base_labels(tenant, provider);

    metrics::counter!(METRIC_REQUESTS_TOTAL, labels.iter()).increment(1);
    metrics::counter!(METRIC_MISSES_TOTAL, labels.iter()).increment(1);
}

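/// Emits a `status="success"` refresh counter and records the refresh
/// duration in the histogram, labelled by tenant and provider.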
pub fn record_refresh_success(tenant: &str, provider: &str, duration: Duration) {
    metrics::counter!(METRIC_REFRESH_TOTAL, status_labels(tenant, provider, "success").iter())
        .increment(1);
    metrics::histogram!(METRIC_REFRESH_DURATION, base_labels(tenant, provider).iter())
        .record(duration.as_secs_f64());
}

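/// Emits a `status="error"` refresh counter and increments the dedicated
/// refresh-error counter, labelled by tenant and provider.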
pub fn record_refresh_error(tenant: &str, provider: &str) {
    metrics::counter!(METRIC_REFRESH_TOTAL, status_labels(tenant, provider, "error").iter())
        .increment(1);
    metrics::counter!(METRIC_REFRESH_ERRORS, base_labels(tenant, provider).iter()).increment(1);
}

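/// Builds the `tenant`/`provider` label pair shared by every metric in this
/// module.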
fn base_labels(tenant: &str, provider: &str) -> LabelSet {
    let mut labels = LabelSet::with_capacity(2);

    labels.push(Label::new("tenant", tenant.to_owned()));
    labels.push(Label::new("provider", provider.to_owned()));

    labels
}

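/// Extends the base labels with a static `status` label.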
fn status_labels(tenant: &str, provider: &str, status: &'static str) -> LabelSet {
    let mut labels = base_labels(tenant, provider);

    labels.push(Label::new("status", status));

    labels
}

#[cfg(test)]
mod tests {
    use std::borrow::Borrow;

    use metrics_util::{
        CompositeKey, MetricKind,
        debugging::{DebugValue, DebuggingRecorder},
    };

    use super::*;

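    /// Runs `f` with a thread-local debugging recorder installed and returns
    /// the recorded (key, value) pairs.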
    fn capture_metrics<F>(f: F) -> Vec<(CompositeKey, DebugValue)>
    where
        F: FnOnce(),
    {
        let recorder = DebuggingRecorder::new();
        let snapshotter = recorder.snapshotter();

        metrics::with_local_recorder(&recorder, f);

        snapshotter
            .snapshot()
            .into_vec()
            .into_iter()
            .map(|(key, _, _, value)| (key, value))
            .collect()
    }

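    /// Looks up a counter by name and label set, returning `0` when it was
    /// not recorded.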
    fn counter_value(
        snapshot: &[(CompositeKey, DebugValue)],
        name: &str,
        labels: &[(&str, &str)],
    ) -> u64 {
        snapshot
            .iter()
            .find_map(|(key, value)| {
                (key.kind() == MetricKind::Counter
                    && Borrow::<str>::borrow(key.key().name()) == name
                    && labels_match(key, labels))
                .then_some(match value {
                    DebugValue::Counter(value) => *value,
                    _ => 0,
                })
            })
            .unwrap_or(0)
    }

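    /// Returns the most recently recorded sample of the matching histogram,
    /// if any.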
    fn last_histogram_value(
        snapshot: &[(CompositeKey, DebugValue)],
        name: &str,
        labels: &[(&str, &str)],
    ) -> Option<f64> {
        snapshot.iter().find_map(|(key, value)| {
            if key.kind() == MetricKind::Histogram
                && Borrow::<str>::borrow(key.key().name()) == name
                && labels_match(key, labels)
            {
                if let DebugValue::Histogram(values) = value {
                    values.last().map(|v| v.into_inner())
                } else {
                    None
                }
            } else {
                None
            }
        })
    }

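    /// Compares a key's labels against the expected pairs, ignoring order.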
    fn labels_match(key: &CompositeKey, expected: &[(&str, &str)]) -> bool {
        let mut labels: Vec<_> =
            key.key().labels().map(|label| (label.key(), label.value())).collect();

        labels.sort_unstable();

        let mut expected_sorted: Vec<_> = expected.to_vec();

        expected_sorted.sort_unstable();

        labels.len() == expected_sorted.len()
            && labels
                .into_iter()
                .zip(expected_sorted)
                .all(|((lk, lv), (ek, ev))| lk == ek && lv == ev)
    }

    #[test]
    fn records_hits_misses_and_stale_counts() {
        let snapshot = capture_metrics(|| {
            record_resolve_hit("tenant-a", "provider-1", false);
            record_resolve_hit("tenant-a", "provider-1", true);
            record_resolve_miss("tenant-a", "provider-1");
        });
        let base = [("tenant", "tenant-a"), ("provider", "provider-1")];

        assert_eq!(counter_value(&snapshot, "jwks_cache_requests_total", &base), 3);
        assert_eq!(counter_value(&snapshot, "jwks_cache_hits_total", &base), 2);
        assert_eq!(counter_value(&snapshot, "jwks_cache_misses_total", &base), 1);
        assert_eq!(counter_value(&snapshot, "jwks_cache_stale_total", &base), 1);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn records_refresh_success_and_errors() {
        let snapshot = capture_metrics(|| {
            record_refresh_success("tenant-b", "provider-2", std::time::Duration::from_millis(20));
            record_refresh_error("tenant-b", "provider-2");
        });
        let base = [("tenant", "tenant-b"), ("provider", "provider-2")];
        let success = [("tenant", "tenant-b"), ("provider", "provider-2"), ("status", "success")];
        let error = [("tenant", "tenant-b"), ("provider", "provider-2"), ("status", "error")];

        assert_eq!(counter_value(&snapshot, "jwks_cache_refresh_total", &success), 1);
        assert_eq!(counter_value(&snapshot, "jwks_cache_refresh_total", &error), 1);
        assert_eq!(counter_value(&snapshot, "jwks_cache_refresh_errors_total", &base), 1);

        let duration =
            last_histogram_value(&snapshot, "jwks_cache_refresh_duration_seconds", &base)
                .expect("refresh duration recorded");

        assert!((duration - 0.020).abs() < 1e-6, "expected ~20ms histogram, got {duration}");
    }
}