1use std::collections::HashMap;
2use std::fmt;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use sha2::{Digest, Sha256};
9use tokio::sync::{Mutex, RwLock};
10
11use crate::jwks::validate_https_public_uri;
12use crate::types::AuthError;
13
14#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
15pub struct IntrospectionResult {
16 pub active: bool,
17 #[serde(default)]
18 pub sub: Option<String>,
19 #[serde(default)]
20 pub exp: Option<u64>,
21 #[serde(default)]
22 pub iat: Option<u64>,
23 #[serde(default)]
24 pub nbf: Option<u64>,
25 #[serde(default)]
26 pub scope: Option<String>,
27 #[serde(default)]
28 pub client_id: Option<String>,
29 #[serde(default)]
30 pub token_type: Option<String>,
31 #[serde(default)]
32 pub iss: Option<String>,
33 #[serde(default)]
34 pub aud: Option<serde_json::Value>,
35 #[serde(flatten)]
36 pub extra: serde_json::Map<String, serde_json::Value>,
37}
38
/// Tuning knobs for the in-memory result cache of [`CachingTokenIntrospector`].
#[derive(Debug, Clone)]
pub struct IntrospectionCacheOptions {
    /// Upper bound on the number of cached introspection results. When the cache is full,
    /// expired entries are purged first, then the entry closest to expiry is evicted.
    pub max_entries: usize,
    /// Longest time an *active* result is served from cache. It is further capped by the
    /// token's own `exp`.
    pub default_ttl: Duration,
    /// How long an *inactive* result is served from cache.
    pub negative_ttl: Duration,
}

impl Default for IntrospectionCacheOptions {
    /// 10 000 entries, a 60 second positive TTL and a 5 second negative TTL.
    fn default() -> Self {
        const MAX_ENTRIES: usize = 10_000;
        const POSITIVE_TTL_SECS: u64 = 60;
        const NEGATIVE_TTL_SECS: u64 = 5;

        Self {
            max_entries: MAX_ENTRIES,
            default_ttl: Duration::from_secs(POSITIVE_TTL_SECS),
            negative_ttl: Duration::from_secs(NEGATIVE_TTL_SECS),
        }
    }
}
55
56#[async_trait]
57pub trait TokenIntrospector: Send + Sync {
58 async fn introspect(&self, token: &str) -> Result<IntrospectionResult, AuthError>;
59}
60
61pub(crate) struct CachedEntry {
62 result: IntrospectionResult,
63 expires_at: Instant,
64}
65
66pub struct CachingTokenIntrospector {
67 endpoint: String,
68 client_id: String,
69 client_secret: String,
70 http: reqwest::Client,
71 pub(crate) cache: Arc<RwLock<HashMap<String, CachedEntry>>>,
72 in_flight: Mutex<()>,
73 max_cache_size: usize,
74 default_ttl: Duration,
75 negative_ttl: Duration,
76}
77
78impl CachingTokenIntrospector {
79 pub fn new(
80 endpoint: String,
81 client_id: String,
82 client_secret: String,
83 options: IntrospectionCacheOptions,
84 ) -> Result<Self, AuthError> {
85 validate_https_public_uri(&endpoint, "introspection endpoint URI")?;
86 let http = reqwest::Client::builder()
87 .connect_timeout(Duration::from_secs(5))
88 .timeout(Duration::from_secs(10))
89 .build()
90 .map_err(|e| AuthError::ConfigError(format!("failed to build HTTP client: {e}")))?;
91 Ok(Self::with_client(
92 endpoint,
93 client_id,
94 client_secret,
95 options,
96 http,
97 ))
98 }
99
100 #[doc(hidden)]
101 pub fn new_unchecked_for_test(
102 endpoint: String,
103 client_id: String,
104 client_secret: String,
105 options: IntrospectionCacheOptions,
106 ) -> Self {
107 let http = reqwest::Client::builder()
108 .connect_timeout(Duration::from_secs(5))
109 .timeout(Duration::from_secs(10))
110 .build()
111 .unwrap_or_default();
112 Self::with_client(endpoint, client_id, client_secret, options, http)
113 }
114
115 fn with_client(
116 endpoint: String,
117 client_id: String,
118 client_secret: String,
119 options: IntrospectionCacheOptions,
120 http: reqwest::Client,
121 ) -> Self {
122 Self {
123 endpoint,
124 client_id,
125 client_secret,
126 http,
127 cache: Arc::new(RwLock::new(HashMap::new())),
128 in_flight: Mutex::new(()),
129 max_cache_size: options.max_entries,
130 default_ttl: options.default_ttl,
131 negative_ttl: options.negative_ttl,
132 }
133 }
134
135 fn token_hash(token: &str) -> String {
136 let mut hasher = Sha256::new();
137 hasher.update(token.as_bytes());
138 hex::encode(hasher.finalize())
139 }
140
141 fn compute_ttl(&self, result: &IntrospectionResult) -> Duration {
142 if !result.active {
143 return self.negative_ttl;
144 }
145 if let Some(exp) = result.exp {
146 let now = std::time::SystemTime::now()
147 .duration_since(std::time::UNIX_EPOCH)
148 .unwrap_or_default()
149 .as_secs();
150 if exp > now {
151 let remaining = Duration::from_secs(exp - now);
152 return remaining.min(self.default_ttl);
153 }
154 }
155 self.default_ttl
156 }
157
158 async fn evict_if_needed(&self) {
159 let mut cache = self.cache.write().await;
160 if cache.len() < self.max_cache_size {
161 return;
162 }
163 let now = Instant::now();
164 cache.retain(|_, entry| entry.expires_at > now);
165 if cache.len() >= self.max_cache_size {
166 let oldest_key = cache
167 .iter()
168 .min_by_key(|(_, e)| e.expires_at)
169 .map(|(k, _)| k.clone());
170 if let Some(key) = oldest_key {
171 cache.remove(&key);
172 }
173 }
174 }
175}
176
177impl fmt::Debug for CachingTokenIntrospector {
178 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
179 f.debug_struct("CachingTokenIntrospector")
180 .field("endpoint", &self.endpoint)
181 .field("client_id", &self.client_id)
182 .field("client_secret", &"[REDACTED]")
183 .field("max_cache_size", &self.max_cache_size)
184 .field("default_ttl", &self.default_ttl)
185 .field("negative_ttl", &self.negative_ttl)
186 .finish_non_exhaustive()
187 }
188}
189
190#[async_trait]
191impl TokenIntrospector for CachingTokenIntrospector {
192 async fn introspect(&self, token: &str) -> Result<IntrospectionResult, AuthError> {
193 let key = Self::token_hash(token);
194
195 {
196 let cache = self.cache.read().await;
197 if let Some(entry) = cache.get(&key)
198 && entry.expires_at > Instant::now()
199 {
200 tracing::debug!(target: "camel_auth::introspection", cache_outcome = "hit");
201 return Ok(entry.result.clone());
202 }
203 }
204
205 let _guard = self.in_flight.lock().await;
206
207 {
208 let cache = self.cache.read().await;
209 if let Some(entry) = cache.get(&key)
210 && entry.expires_at > Instant::now()
211 {
212 tracing::debug!(target: "camel_auth::introspection", cache_outcome = "hit_after_wait");
213 return Ok(entry.result.clone());
214 }
215 }
216
217 tracing::debug!(
218 target: "camel_auth::introspection",
219 cache_outcome = "miss"
220 );
221
222 let response = self
223 .http
224 .post(&self.endpoint)
225 .form(&[
226 ("token", token),
227 ("client_id", &self.client_id),
228 ("client_secret", &self.client_secret),
229 ])
230 .send()
231 .await
232 .map_err(|e| {
233 AuthError::ProviderUnavailable(format!("introspection request failed: {e}"))
234 })?;
235
236 let status = response.status();
237 if status.as_u16() == 401 || status.as_u16() == 403 {
238 return Err(AuthError::ProviderUnavailable(
239 "introspection client unauthorized".into(),
240 ));
241 }
242 if status.is_server_error() {
243 return Err(AuthError::ProviderUnavailable(format!(
244 "introspection endpoint returned {}",
245 status
246 )));
247 }
248 if status.is_client_error() {
249 return Err(AuthError::TokenInvalid(format!(
250 "introspection endpoint returned client error {}",
251 status
252 )));
253 }
254
255 let result: IntrospectionResult = response.json().await.map_err(|e| {
256 AuthError::ProviderUnavailable(format!("invalid introspection response: {e}"))
257 })?;
258
259 let ttl = self.compute_ttl(&result);
260 let entry = CachedEntry {
261 result: result.clone(),
262 expires_at: Instant::now() + ttl,
263 };
264
265 self.evict_if_needed().await;
266 {
267 let mut cache = self.cache.write().await;
268 cache.insert(key, entry);
269 }
270
271 Ok(result)
272 }
273}
274
275#[cfg(test)]
276mod tests {
277 use super::*;
278 use wiremock::matchers::{body_string_contains, method};
279 use wiremock::{Mock, MockServer, ResponseTemplate};
280
281 #[test]
282 fn deserialize_minimal_active() {
283 let json = r#"{"active": true}"#;
284 let result: IntrospectionResult = serde_json::from_str(json).unwrap();
285 assert!(result.active);
286 assert!(result.sub.is_none());
287 assert!(result.extra.is_empty());
288 }
289
290 #[test]
291 fn deserialize_full_rfc7662() {
292 let json = r#"{
293 "active": true,
294 "sub": "user-1",
295 "exp": 1700000000,
296 "iat": 1699999999,
297 "nbf": 1699999900,
298 "scope": "read write",
299 "client_id": "my-client",
300 "token_type": "Bearer",
301 "iss": "https://kc.example.com/realms/test",
302 "aud": ["my-api"],
303 "realm_access": {"roles": ["admin", "user"]},
304 "resource_access": {"my-client": {"roles": ["client-role"]}}
305 }"#;
306 let result: IntrospectionResult = serde_json::from_str(json).unwrap();
307 assert!(result.active);
308 assert_eq!(result.sub.as_deref(), Some("user-1"));
309 assert_eq!(result.exp, Some(1700000000));
310 assert_eq!(result.scope.as_deref(), Some("read write"));
311 assert_eq!(result.client_id.as_deref(), Some("my-client"));
312 assert_eq!(result.token_type.as_deref(), Some("Bearer"));
313 assert_eq!(
314 result.iss.as_deref(),
315 Some("https://kc.example.com/realms/test")
316 );
317 assert!(result.extra.contains_key("realm_access"));
318 assert!(result.extra.contains_key("resource_access"));
319 }
320
321 #[test]
322 fn deserialize_inactive() {
323 let json = r#"{"active": false}"#;
324 let result: IntrospectionResult = serde_json::from_str(json).unwrap();
325 assert!(!result.active);
326 }
327
328 #[test]
329 fn deserialize_unknown_fields_go_to_extra() {
330 let json = r#"{"active": true, "custom_field": "hello"}"#;
331 let result: IntrospectionResult = serde_json::from_str(json).unwrap();
332 assert_eq!(result.extra["custom_field"], "hello");
333 }
334
335 #[test]
336 fn cache_options_defaults() {
337 let opts = IntrospectionCacheOptions::default();
338 assert_eq!(opts.max_entries, 10_000);
339 assert_eq!(opts.default_ttl, Duration::from_secs(60));
340 assert_eq!(opts.negative_ttl, Duration::from_secs(5));
341 }
342
343 fn test_cache_opts() -> IntrospectionCacheOptions {
344 IntrospectionCacheOptions {
345 max_entries: 100,
346 default_ttl: Duration::from_secs(60),
347 negative_ttl: Duration::from_secs(2),
348 }
349 }
350
351 #[tokio::test]
352 async fn cache_hit_returns_cached_result_without_http_call() {
353 let server = MockServer::start().await;
354 Mock::given(method("POST"))
355 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
356 "active": true,
357 "sub": "cached-user"
358 })))
359 .expect(1)
360 .mount(&server)
361 .await;
362
363 let introspector = CachingTokenIntrospector::new_unchecked_for_test(
364 server.uri(),
365 "client-id".into(),
366 "client-secret".into(),
367 test_cache_opts(),
368 );
369
370 let r1 = introspector.introspect("token-a").await.unwrap();
371 let r2 = introspector.introspect("token-a").await.unwrap();
372 assert_eq!(r1.sub, r2.sub);
373 assert!(r1.active);
374 }
375
376 #[tokio::test]
377 async fn expired_entry_re_introspects() {
378 let server = MockServer::start().await;
379 Mock::given(method("POST"))
380 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
381 "active": true, "sub": "user"
382 })))
383 .expect(2)
384 .mount(&server)
385 .await;
386
387 let opts = IntrospectionCacheOptions {
388 max_entries: 100,
389 default_ttl: Duration::from_millis(50),
390 negative_ttl: Duration::from_secs(2),
391 };
392 let introspector = CachingTokenIntrospector::new_unchecked_for_test(
393 server.uri(),
394 "cid".into(),
395 "cs".into(),
396 opts,
397 );
398
399 introspector.introspect("tok").await.unwrap();
400 tokio::time::sleep(Duration::from_millis(80)).await;
401 introspector.introspect("tok").await.unwrap();
402 }
403
404 #[tokio::test]
405 async fn inactive_token_cached_with_negative_ttl() {
406 let server = MockServer::start().await;
407 Mock::given(method("POST"))
408 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
409 "active": false
410 })))
411 .expect(1)
412 .mount(&server)
413 .await;
414
415 let opts = IntrospectionCacheOptions {
416 max_entries: 100,
417 default_ttl: Duration::from_secs(60),
418 negative_ttl: Duration::from_secs(10),
419 };
420 let introspector = CachingTokenIntrospector::new_unchecked_for_test(
421 server.uri(),
422 "cid".into(),
423 "cs".into(),
424 opts,
425 );
426
427 let r = introspector.introspect("dead-token").await.unwrap();
428 assert!(!r.active);
429 let r2 = introspector.introspect("dead-token").await.unwrap();
430 assert!(!r2.active);
431 }
432
433 #[tokio::test]
434 async fn cache_key_does_not_contain_raw_token() {
435 let server = MockServer::start().await;
436 Mock::given(method("POST"))
437 .respond_with(
438 ResponseTemplate::new(200).set_body_json(serde_json::json!({"active": true})),
439 )
440 .mount(&server)
441 .await;
442
443 let introspector = CachingTokenIntrospector::new_unchecked_for_test(
444 server.uri(),
445 "cid".into(),
446 "cs".into(),
447 test_cache_opts(),
448 );
449
450 introspector.introspect("secret-token-value").await.unwrap();
451 let cache = introspector.cache.read().await;
452 for key in cache.keys() {
453 assert!(
454 !key.contains("secret-token-value"),
455 "cache key must not contain raw token"
456 );
457 }
458 }
459
460 #[tokio::test]
461 async fn eviction_removes_oldest_when_over_capacity() {
462 let server = MockServer::start().await;
463 for i in 0..5 {
464 Mock::given(method("POST"))
465 .and(body_string_contains(format!("token-{i}"))) .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
467 "active": true, "sub": format!("user-{i}")
468 })))
469 .mount(&server)
470 .await;
471 }
472
473 let opts = IntrospectionCacheOptions {
474 max_entries: 2,
475 default_ttl: Duration::from_secs(600),
476 negative_ttl: Duration::from_secs(5),
477 };
478 let introspector = CachingTokenIntrospector::new_unchecked_for_test(
479 server.uri(),
480 "cid".into(),
481 "cs".into(),
482 opts,
483 );
484
485 introspector.introspect("token-0").await.unwrap();
486 introspector.introspect("token-1").await.unwrap();
487 introspector.introspect("token-2").await.unwrap();
488 introspector.introspect("token-3").await.unwrap();
489 introspector.introspect("token-4").await.unwrap();
490
491 let cache = introspector.cache.read().await;
492 assert!(cache.len() <= 2, "cache must respect max_entries");
493 }
494
495 #[test]
496 fn debug_redacts_client_secret() {
497 let introspector = CachingTokenIntrospector::new_unchecked_for_test(
498 "https://example.com".into(),
499 "cid".into(),
500 "super-secret-value".into(),
501 test_cache_opts(),
502 );
503 let debug = format!("{introspector:?}");
504 assert!(
505 !debug.contains("super-secret-value"),
506 "Debug must not leak client_secret"
507 );
508 assert!(debug.contains("REDACTED"));
509 }
510
511 #[tokio::test]
512 async fn production_constructor_rejects_http_endpoint() {
513 let result = CachingTokenIntrospector::new(
514 "http://insecure.example.com/introspect".into(),
515 "cid".into(),
516 "cs".into(),
517 test_cache_opts(),
518 );
519 assert!(result.is_err());
520 let err = result.unwrap_err();
521 assert!(matches!(err, AuthError::ConfigError(ref s) if s.contains("HTTPS")));
522 }
523
524 #[tokio::test]
525 async fn production_constructor_rejects_localhost() {
526 let result = CachingTokenIntrospector::new(
527 "https://localhost:8080/introspect".into(),
528 "cid".into(),
529 "cs".into(),
530 test_cache_opts(),
531 );
532 assert!(result.is_err());
533 let err = result.unwrap_err();
534 assert!(
535 matches!(err, AuthError::ConfigError(ref s) if s.contains("private") || s.contains("loopback"))
536 );
537 }
538
539 #[tokio::test]
540 async fn http_error_500_returns_provider_unavailable() {
541 let server = MockServer::start().await;
542 Mock::given(method("POST"))
543 .respond_with(ResponseTemplate::new(500))
544 .mount(&server)
545 .await;
546
547 let introspector = CachingTokenIntrospector::new_unchecked_for_test(
548 server.uri(),
549 "cid".into(),
550 "cs".into(),
551 test_cache_opts(),
552 );
553 let result = introspector.introspect("tok").await;
554 assert!(result.is_err());
555 let err = result.unwrap_err();
556 assert!(matches!(err, AuthError::ProviderUnavailable(_)));
557 }
558
559 #[tokio::test]
560 async fn http_401_returns_provider_unavailable() {
561 let server = MockServer::start().await;
562 Mock::given(method("POST"))
563 .respond_with(ResponseTemplate::new(401))
564 .mount(&server)
565 .await;
566
567 let introspector = CachingTokenIntrospector::new_unchecked_for_test(
568 server.uri(),
569 "cid".into(),
570 "cs".into(),
571 test_cache_opts(),
572 );
573 let result = introspector.introspect("tok").await;
574 assert!(result.is_err());
575 let err = result.unwrap_err();
576 assert!(matches!(err, AuthError::ProviderUnavailable(ref s) if s.contains("unauthorized")));
577 }
578}