jwks_cache/metrics.rs

//! Metrics helpers and per-provider telemetry bookkeeping.

// std
#[cfg(feature = "prometheus")] use std::sync::OnceLock;
use std::sync::atomic::{AtomicU64, Ordering};
// crates.io
use metrics::Label;
#[cfg(feature = "prometheus")]
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
use smallvec::SmallVec;
// self
use crate::_prelude::*;

type LabelSet = SmallVec<[Label; 4]>;

const METRIC_REQUESTS_TOTAL: &str = "jwks_cache_requests_total";
const METRIC_HITS_TOTAL: &str = "jwks_cache_hits_total";
const METRIC_STALE_TOTAL: &str = "jwks_cache_stale_total";
const METRIC_MISSES_TOTAL: &str = "jwks_cache_misses_total";
const METRIC_REFRESH_TOTAL: &str = "jwks_cache_refresh_total";
const METRIC_REFRESH_DURATION: &str = "jwks_cache_refresh_duration_seconds";
const METRIC_REFRESH_ERRORS: &str = "jwks_cache_refresh_errors_total";

/// Shared Prometheus handle installed by [`install_default_exporter`].
#[cfg(feature = "prometheus")]
static PROMETHEUS_HANDLE: OnceLock<PrometheusHandle> = OnceLock::new();

/// Thread-safe metrics accumulator for a single provider registration.
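///
/// # Examples
///
/// A minimal sketch of the bookkeeping flow; the `jwks_cache::metrics` path is
/// assumed from this file's location.
///
/// ```rust,ignore
/// use jwks_cache::metrics::ProviderMetrics;
///
/// let metrics = ProviderMetrics::new();
///
/// metrics.record_hit(false);
/// metrics.record_hit(true); // served from a stale payload
/// metrics.record_miss();
///
/// let snapshot = metrics.snapshot();
///
/// assert_eq!(snapshot.total_requests, 3);
/// assert_eq!(snapshot.cache_hits, 2);
/// assert_eq!(snapshot.stale_serves, 1);
/// assert!((snapshot.hit_rate() - 2.0 / 3.0).abs() < f64::EPSILON);
/// ```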
#[derive(Debug, Default)]
pub struct ProviderMetrics {
	total_requests: AtomicU64,
	cache_hits: AtomicU64,
	stale_serves: AtomicU64,
	refresh_successes: AtomicU64,
	refresh_errors: AtomicU64,
	last_refresh_micros: AtomicU64,
}
impl ProviderMetrics {
	/// Create a new metrics accumulator.
	pub fn new() -> Arc<Self> {
		Arc::new(Self::default())
	}

	/// Record a hit outcome.
	pub fn record_hit(&self, stale: bool) {
		self.total_requests.fetch_add(1, Ordering::Relaxed);
		self.cache_hits.fetch_add(1, Ordering::Relaxed);
		if stale {
			self.stale_serves.fetch_add(1, Ordering::Relaxed);
		}
	}

	/// Record a miss outcome.
	pub fn record_miss(&self) {
		self.total_requests.fetch_add(1, Ordering::Relaxed);
	}

	/// Record a successful refresh and latency.
	pub fn record_refresh_success(&self, duration: Duration) {
		self.refresh_successes.fetch_add(1, Ordering::Relaxed);
		self.last_refresh_micros.store(duration.as_micros() as u64, Ordering::Relaxed);
	}

	/// Record a refresh failure.
	pub fn record_refresh_error(&self) {
		self.refresh_errors.fetch_add(1, Ordering::Relaxed);
	}

	/// Take a point-in-time snapshot for status reporting.
	pub fn snapshot(&self) -> ProviderMetricsSnapshot {
		ProviderMetricsSnapshot {
			total_requests: self.total_requests.load(Ordering::Relaxed),
			cache_hits: self.cache_hits.load(Ordering::Relaxed),
			stale_serves: self.stale_serves.load(Ordering::Relaxed),
			refresh_successes: self.refresh_successes.load(Ordering::Relaxed),
			refresh_errors: self.refresh_errors.load(Ordering::Relaxed),
			last_refresh_micros: match self.last_refresh_micros.load(Ordering::Relaxed) {
				0 => None,
				value => Some(value),
			},
		}
	}
}

/// Read-only snapshot of per-provider telemetry counters.
#[derive(Clone, Debug)]
pub struct ProviderMetricsSnapshot {
	/// Total number of cache lookups observed.
	pub total_requests: u64,
	/// Count of lookups served from the cache.
	pub cache_hits: u64,
	/// Count of lookups served from stale payloads.
	pub stale_serves: u64,
	/// Count of successful refresh operations.
	pub refresh_successes: u64,
	/// Count of refresh attempts that resulted in errors.
	pub refresh_errors: u64,
	/// Microsecond latency of the most recent refresh.
	pub last_refresh_micros: Option<u64>,
}
impl ProviderMetricsSnapshot {
	/// Ratio of cache hits over total requests.
	pub fn hit_rate(&self) -> f64 {
		if self.total_requests == 0 {
			0.0
		} else {
			self.cache_hits as f64 / self.total_requests as f64
		}
	}

	/// Ratio of stale serves over total requests.
	pub fn stale_ratio(&self) -> f64 {
		if self.total_requests == 0 {
			0.0
		} else {
			self.stale_serves as f64 / self.total_requests as f64
		}
	}
}

/// Install the default Prometheus recorder backed by `metrics`.
///
/// Multiple invocations are safe; subsequent calls become no-ops once the recorder is installed.
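///
/// # Examples
///
/// A sketch of wiring the exporter into a scrape endpoint; the crate path and
/// the surrounding setup are assumptions, only the functions below come from
/// this module (requires the `prometheus` feature).
///
/// ```rust,ignore
/// jwks_cache::metrics::install_default_exporter().expect("recorder installed");
///
/// // Render the Prometheus exposition text, e.g. from a `/metrics` handler.
/// if let Some(handle) = jwks_cache::metrics::prometheus_handle() {
///     let body = handle.render();
///     println!("{body}");
/// }
/// ```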
#[cfg(feature = "prometheus")]
pub fn install_default_exporter() -> Result<()> {
	if PROMETHEUS_HANDLE.get().is_some() {
		return Ok(());
	}

	let handle = PrometheusBuilder::new()
		.install_recorder()
		.map_err(|err| Error::Metrics(err.to_string()))?;
	let _ = PROMETHEUS_HANDLE.set(handle);

	Ok(())
}

/// Access the global Prometheus exporter handle when installed.
#[cfg(feature = "prometheus")]
pub fn prometheus_handle() -> Option<&'static PrometheusHandle> {
	PROMETHEUS_HANDLE.get()
}

/// Record a cache hit, tagging whether it was served stale.
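///
/// # Examples
///
/// Illustrative call (crate path assumed). This bumps `jwks_cache_requests_total`
/// and `jwks_cache_hits_total`, plus `jwks_cache_stale_total` when `stale` is
/// `true`, all labelled with the tenant and provider.
///
/// ```rust,ignore
/// jwks_cache::metrics::record_resolve_hit("tenant-a", "provider-1", true);
/// ```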
pub fn record_resolve_hit(tenant: &str, provider: &str, stale: bool) {
	let labels = base_labels(tenant, provider);

	metrics::counter!(METRIC_REQUESTS_TOTAL, labels.iter()).increment(1);
	metrics::counter!(METRIC_HITS_TOTAL, labels.iter()).increment(1);

	if stale {
		metrics::counter!(METRIC_STALE_TOTAL, labels.iter()).increment(1);
	}
}

/// Record a cache miss that required an upstream fetch.
pub fn record_resolve_miss(tenant: &str, provider: &str) {
	let labels = base_labels(tenant, provider);

	metrics::counter!(METRIC_REQUESTS_TOTAL, labels.iter()).increment(1);
	metrics::counter!(METRIC_MISSES_TOTAL, labels.iter()).increment(1);
}

/// Record a successful refresh attempt along with its latency.
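///
/// # Examples
///
/// Illustrative call (crate path assumed). This records one `status="success"`
/// sample on `jwks_cache_refresh_total` and the latency in seconds on the
/// `jwks_cache_refresh_duration_seconds` histogram.
///
/// ```rust,ignore
/// jwks_cache::metrics::record_refresh_success(
///     "tenant-a",
///     "provider-1",
///     std::time::Duration::from_millis(35),
/// );
/// ```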
pub fn record_refresh_success(tenant: &str, provider: &str, duration: Duration) {
	metrics::counter!(METRIC_REFRESH_TOTAL, status_labels(tenant, provider, "success").iter())
		.increment(1);
	metrics::histogram!(METRIC_REFRESH_DURATION, base_labels(tenant, provider).iter())
		.record(duration.as_secs_f64());
}

/// Record a failed refresh attempt.
pub fn record_refresh_error(tenant: &str, provider: &str) {
	metrics::counter!(METRIC_REFRESH_TOTAL, status_labels(tenant, provider, "error").iter())
		.increment(1);
	metrics::counter!(METRIC_REFRESH_ERRORS, base_labels(tenant, provider).iter()).increment(1);
}

fn base_labels(tenant: &str, provider: &str) -> LabelSet {
	let mut labels = LabelSet::with_capacity(2);

	labels.push(Label::new("tenant", tenant.to_owned()));
	labels.push(Label::new("provider", provider.to_owned()));

	labels
}

fn status_labels(tenant: &str, provider: &str, status: &'static str) -> LabelSet {
	let mut labels = base_labels(tenant, provider);

	labels.push(Label::new("status", status));

	labels
}

#[cfg(test)]
mod tests {
	// std
	use std::borrow::Borrow;
	// crates.io
	use metrics_util::{
		CompositeKey, MetricKind,
		debugging::{DebugValue, DebuggingRecorder},
	};
	// self
	use super::*;

	fn capture_metrics<F>(f: F) -> Vec<(CompositeKey, DebugValue)>
	where
		F: FnOnce(),
	{
		let recorder = DebuggingRecorder::new();
		let snapshotter = recorder.snapshotter();

		metrics::with_local_recorder(&recorder, f);

		snapshotter
			.snapshot()
			.into_vec()
			.into_iter()
			.map(|(key, _, _, value)| (key, value))
			.collect()
	}

	fn counter_value(
		snapshot: &[(CompositeKey, DebugValue)],
		name: &str,
		labels: &[(&str, &str)],
	) -> u64 {
		snapshot
			.iter()
			.find_map(|(key, value)| {
				(key.kind() == MetricKind::Counter
					&& Borrow::<str>::borrow(key.key().name()) == name
					&& labels_match(key, labels))
				.then_some(match value {
					DebugValue::Counter(value) => *value,
					_ => 0,
				})
			})
			.unwrap_or(0)
	}

	fn last_histogram_value(
		snapshot: &[(CompositeKey, DebugValue)],
		name: &str,
		labels: &[(&str, &str)],
	) -> Option<f64> {
		snapshot.iter().find_map(|(key, value)| {
			if key.kind() == MetricKind::Histogram
				&& Borrow::<str>::borrow(key.key().name()) == name
				&& labels_match(key, labels)
			{
				if let DebugValue::Histogram(values) = value {
					values.last().map(|v| v.into_inner())
				} else {
					None
				}
			} else {
				None
			}
		})
	}

	fn labels_match(key: &CompositeKey, expected: &[(&str, &str)]) -> bool {
		let mut labels: Vec<_> =
			key.key().labels().map(|label| (label.key(), label.value())).collect();

		labels.sort_unstable();

		let mut expected_sorted: Vec<_> = expected.to_vec();

		expected_sorted.sort_unstable();

		labels.len() == expected_sorted.len()
			&& labels
				.into_iter()
				.zip(expected_sorted)
				.all(|((lk, lv), (ek, ev))| lk == ek && lv == ev)
	}

	#[test]
	fn records_hits_misses_and_stale_counts() {
		let snapshot = capture_metrics(|| {
			record_resolve_hit("tenant-a", "provider-1", false);
			record_resolve_hit("tenant-a", "provider-1", true);
			record_resolve_miss("tenant-a", "provider-1");
		});
		let base = [("tenant", "tenant-a"), ("provider", "provider-1")];

		assert_eq!(counter_value(&snapshot, "jwks_cache_requests_total", &base), 3);
		assert_eq!(counter_value(&snapshot, "jwks_cache_hits_total", &base), 2);
		assert_eq!(counter_value(&snapshot, "jwks_cache_misses_total", &base), 1);
		assert_eq!(counter_value(&snapshot, "jwks_cache_stale_total", &base), 1);
	}

	#[test]
	#[cfg_attr(miri, ignore)]
	fn records_refresh_success_and_errors() {
		let snapshot = capture_metrics(|| {
			record_refresh_success("tenant-b", "provider-2", std::time::Duration::from_millis(20));
			record_refresh_error("tenant-b", "provider-2");
		});
		let base = [("tenant", "tenant-b"), ("provider", "provider-2")];
		let success = [("tenant", "tenant-b"), ("provider", "provider-2"), ("status", "success")];
		let error = [("tenant", "tenant-b"), ("provider", "provider-2"), ("status", "error")];

		assert_eq!(counter_value(&snapshot, "jwks_cache_refresh_total", &success), 1);
		assert_eq!(counter_value(&snapshot, "jwks_cache_refresh_total", &error), 1);
		assert_eq!(counter_value(&snapshot, "jwks_cache_refresh_errors_total", &base), 1);

		let duration =
			last_histogram_value(&snapshot, "jwks_cache_refresh_duration_seconds", &base)
				.expect("refresh duration recorded");

		assert!((duration - 0.020).abs() < 1e-6, "expected ~20ms histogram, got {duration}");
	}
}