1use std::collections::HashMap;
39
40use super::{Role, User};
41
/// Settings consumed by [`OAuthValidator`] when it validates signed JWTs and maps
/// them to local identities.
#[derive(Debug, Clone)]
pub struct OAuthConfig {
    /// Master switch. When `false`, `OAuthValidator::validate` returns
    /// [`OAuthError::Disabled`] for every token.
    pub enabled: bool,
    /// Expected `iss` claim; compared for exact string equality.
    pub issuer: String,
    /// Audience this service accepts; it must be one of the token's `aud` values.
    pub audience: String,
    /// Location of the issuer's JWKS document.
    ///
    /// NOTE(review): not read anywhere in this file — keys reach the validator through
    /// `OAuthValidator::set_jwks`; presumably the caller fetches them from here. Confirm.
    pub jwks_url: String,
    /// Which claim supplies the authenticated username.
    pub identity_mode: OAuthIdentityMode,
    /// Name of the claim holding the role. `None` means claim-derived roles always
    /// fall back to `default_role`.
    pub role_claim: Option<String>,
    /// Name of the claim holding the tenant id. `None` leaves `OAuthIdentity::tenant`
    /// unset; an empty claim value is also treated as absent.
    pub tenant_claim: Option<String>,
    /// Role used when no role claim is configured or the token does not carry it.
    pub default_role: Role,
    /// When `true`, the `lookup_user` callback passed to `validate` is consulted and a
    /// found user's stored role takes precedence over anything in the token.
    pub map_to_existing_users: bool,
    /// When `false`, `OAuthValidator::extract_bearer` returns `None` for every header.
    pub accept_bearer: bool,
}
73
74impl Default for OAuthConfig {
75 fn default() -> Self {
76 Self {
77 enabled: false,
78 issuer: String::new(),
79 audience: String::new(),
80 jwks_url: String::new(),
81 identity_mode: OAuthIdentityMode::SubClaim,
82 role_claim: None,
83 tenant_claim: None,
84 default_role: Role::Read,
85 map_to_existing_users: true,
86 accept_bearer: true,
87 }
88 }
89}
90
/// Selects which JWT claim becomes [`OAuthIdentity::username`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OAuthIdentityMode {
    /// Use the registered `sub` claim.
    SubClaim,
    /// Use the named claim, resolved through [`JwtClaims::claim`]: `"iss"` and `"sub"`
    /// map to the registered fields, any other name is looked up in `extra`.
    ClaimField(String),
}
98
/// A JWT split into header, claims and signature parts.
///
/// Nothing in this type records whether the signature has been checked;
/// `OAuthValidator::validate` verifies it from `signing_input` and `signature`.
#[derive(Debug, Clone)]
pub struct DecodedJwt {
    pub header: JwtHeader,
    pub claims: JwtClaims,
    /// Raw signature bytes, passed to the [`JwtVerifier`] as its third argument.
    pub signature: Vec<u8>,
    /// The bytes covered by the signature, passed to the verifier as its second argument.
    /// NOTE(review): presumably the ASCII `header.payload` segments of the compact
    /// serialization (RFC 7515) — confirm against the decoder that builds this.
    pub signing_input: Vec<u8>,
}
113
/// The JWT header fields the validator uses to select a verification key.
#[derive(Debug, Clone)]
pub struct JwtHeader {
    /// Signature algorithm name (e.g. `"RS256"`); must equal the selected [`Jwk::alg`].
    pub alg: String,
    /// Key id; must equal the selected [`Jwk::kid`]. A header without `kid` matches no key.
    pub kid: Option<String>,
}
120
/// Claims from a JWT payload: the registered claims the validator uses, plus the rest.
#[derive(Debug, Clone, Default)]
pub struct JwtClaims {
    /// Issuer (`iss`).
    pub iss: Option<String>,
    /// Subject (`sub`).
    pub sub: Option<String>,
    /// Audiences (`aud`), held as a list; the validator accepts the token when any
    /// entry equals the configured audience.
    pub aud: Vec<String>,
    /// Expiry (`exp`), compared against the caller-supplied `now_unix_secs`.
    pub exp: Option<i64>,
    /// Not-before (`nbf`), compared against the caller-supplied `now_unix_secs`.
    pub nbf: Option<i64>,
    /// Issued-at (`iat`); carried but not checked anywhere in this file.
    pub iat: Option<i64>,
    /// Every other claim, as a string, keyed by claim name.
    /// NOTE(review): how non-string claims (numbers, arrays) land here is decided by
    /// the decoder — confirm.
    pub extra: HashMap<String, String>,
}
133
134impl JwtClaims {
135 pub fn claim(&self, key: &str) -> Option<&str> {
136 match key {
137 "iss" => self.iss.as_deref(),
138 "sub" => self.sub.as_deref(),
139 _ => self.extra.get(key).map(|s| s.as_str()),
140 }
141 }
142}
143
/// The local identity produced by a successful `OAuthValidator::validate`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthIdentity {
    /// Username taken from the claim selected by [`OAuthIdentityMode`].
    pub username: String,
    /// Tenant read from `OAuthConfig::tenant_claim`, when configured and non-empty.
    pub tenant: Option<String>,
    /// Effective role: an existing user's stored role when one was looked up,
    /// otherwise the role derived from the token's claims.
    pub role: Role,
    /// The configured issuer (the token's `iss` was checked to equal it).
    pub issuer: String,
    /// The token's `sub` claim, kept even when the username came from another claim.
    pub subject: Option<String>,
    /// The token's `exp` claim, if present.
    pub expires_at_unix_secs: Option<i64>,
}
156
/// Reasons a bearer JWT can be rejected.
#[derive(Debug, Clone)]
pub enum OAuthError {
    /// `OAuthConfig::enabled` is `false`.
    Disabled,
    /// No bearer token was supplied.
    /// NOTE(review): not constructed anywhere in this file; presumably returned by
    /// callers when `extract_bearer` yields nothing — confirm.
    MissingToken,
    /// The token is structurally unusable (in this file: it lacks an `iss` claim).
    Malformed(String),
    /// The `iss` claim differs from the configured issuer.
    WrongIssuer {
        /// The configured issuer.
        expected: String,
        /// The issuer found in the token.
        actual: String,
    },
    /// The configured audience is not among the token's `aud` values.
    WrongAudience {
        /// The configured audience.
        expected: String,
        /// All audiences found in the token.
        actual: Vec<String>,
    },
    /// The token's `exp` is at or before the validation time.
    Expired {
        /// The token's `exp` value.
        exp: i64,
    },
    /// The token's `nbf` is after the validation time.
    NotYetValid {
        /// The token's `nbf` value.
        nbf: i64,
    },
    /// No installed key matches the header's `kid`/`alg`, or the verifier rejected
    /// the signature (its error message is carried here).
    BadSignature(String),
    /// The claim selected by the identity mode could not be read from the token.
    MissingIdentityClaim(OAuthIdentityMode),
    /// The configured role claim is present but does not parse as a [`Role`];
    /// carries the claim name.
    MissingOrInvalidRole(String),
    /// NOTE(review): not constructed anywhere in this file; presumably used by callers
    /// that require the user to already exist — confirm.
    UnknownUser(String),
    /// NOTE(review): not constructed anywhere in this file; presumably reports a failure
    /// to load keys from `OAuthConfig::jwks_url` — confirm.
    JwksFetch(String),
}
193
194impl std::fmt::Display for OAuthError {
195 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
196 match self {
197 OAuthError::Disabled => write!(f, "OAuth disabled on this listener"),
198 OAuthError::MissingToken => write!(f, "no Bearer token"),
199 OAuthError::Malformed(m) => write!(f, "malformed JWT: {m}"),
200 OAuthError::WrongIssuer { expected, actual } => {
201 write!(f, "issuer mismatch: expected {expected}, got {actual}")
202 }
203 OAuthError::WrongAudience { expected, actual } => {
204 write!(
205 f,
206 "audience mismatch: expected {expected}, got {:?}",
207 actual
208 )
209 }
210 OAuthError::Expired { exp } => write!(f, "token expired at unix {exp}"),
211 OAuthError::NotYetValid { nbf } => {
212 write!(f, "token not valid before unix {nbf}")
213 }
214 OAuthError::BadSignature(m) => write!(f, "signature verification failed: {m}"),
215 OAuthError::MissingIdentityClaim(mode) => {
216 write!(f, "identity claim missing for mode {:?}", mode)
217 }
218 OAuthError::MissingOrInvalidRole(c) => {
219 write!(f, "role claim '{c}' missing or not a valid Role")
220 }
221 OAuthError::UnknownUser(u) => write!(f, "OAuth user '{u}' not in auth store"),
222 OAuthError::JwksFetch(m) => write!(f, "JWKS fetch failed: {m}"),
223 }
224 }
225}
226
227impl std::error::Error for OAuthError {}
228
/// One verification key from the issuer's key set.
#[derive(Debug, Clone)]
pub struct Jwk {
    /// Key id, matched against the token header's `kid`.
    pub kid: String,
    /// Algorithm name, matched against the token header's `alg`.
    pub alg: String,
    /// Key material, passed untouched to the [`JwtVerifier`].
    /// NOTE(review): the encoding is whatever the installed verifier expects — confirm.
    pub key_bytes: Vec<u8>,
}
243
/// Pluggable signature check, called as `verifier(key, signing_input, signature)`.
///
/// Must return `Ok(())` only when `signature` is valid over `signing_input` for `key`;
/// an `Err` message is surfaced to callers as [`OAuthError::BadSignature`]. The
/// `Send + Sync` bounds allow the boxed closure to be called from multiple threads.
pub type JwtVerifier = Box<dyn Fn(&Jwk, &[u8], &[u8]) -> Result<(), String> + Send + Sync>;
248
/// Validates decoded JWTs against an [`OAuthConfig`] and a replaceable key set.
pub struct OAuthValidator {
    /// Settings applied by `validate` and `extract_bearer`.
    config: OAuthConfig,
    /// Current verification keys; kept behind a lock so `set_jwks` can replace them
    /// through `&self`.
    jwks: parking_lot::RwLock<Vec<Jwk>>,
    /// Signature check applied with the key selected for each token.
    verifier: JwtVerifier,
}
254
255impl OAuthValidator {
256 pub fn with_verifier(config: OAuthConfig, verifier: JwtVerifier) -> Self {
260 Self {
261 config,
262 jwks: parking_lot::RwLock::new(Vec::new()),
263 verifier,
264 }
265 }
266
267 pub fn set_jwks(&self, keys: Vec<Jwk>) {
270 *self.jwks.write() = keys;
271 }
272
273 pub fn config(&self) -> &OAuthConfig {
274 &self.config
275 }
276
277 pub fn extract_bearer(&self, header_value: &str) -> Option<String> {
280 if !self.config.accept_bearer {
281 return None;
282 }
283 let trimmed = header_value.trim();
284 let prefix = "Bearer ";
285 if trimmed.len() > prefix.len() && trimmed[..prefix.len()].eq_ignore_ascii_case(prefix) {
286 Some(trimmed[prefix.len()..].trim().to_string())
287 } else {
288 None
289 }
290 }
291
292 pub fn validate<F>(
296 &self,
297 token: &DecodedJwt,
298 now_unix_secs: i64,
299 lookup_user: F,
300 ) -> Result<OAuthIdentity, OAuthError>
301 where
302 F: Fn(&str) -> Option<User>,
303 {
304 if !self.config.enabled {
305 return Err(OAuthError::Disabled);
306 }
307
308 let jwk = {
310 let jwks = self.jwks.read();
311 let kid = token.header.kid.as_deref();
312 jwks.iter()
313 .find(|j| kid.map(|k| k == j.kid).unwrap_or(false) && j.alg == token.header.alg)
314 .cloned()
315 };
316 let Some(jwk) = jwk else {
317 return Err(OAuthError::BadSignature(format!(
318 "no JWK for kid {:?} alg {}",
319 token.header.kid, token.header.alg
320 )));
321 };
322 (self.verifier)(&jwk, &token.signing_input, &token.signature)
323 .map_err(OAuthError::BadSignature)?;
324
325 match &token.claims.iss {
327 Some(iss) if iss == &self.config.issuer => {}
328 Some(iss) => {
329 return Err(OAuthError::WrongIssuer {
330 expected: self.config.issuer.clone(),
331 actual: iss.clone(),
332 });
333 }
334 None => {
335 return Err(OAuthError::Malformed("missing iss".into()));
336 }
337 }
338 if !token.claims.aud.iter().any(|a| a == &self.config.audience) {
339 return Err(OAuthError::WrongAudience {
340 expected: self.config.audience.clone(),
341 actual: token.claims.aud.clone(),
342 });
343 }
344 if let Some(exp) = token.claims.exp {
345 if exp <= now_unix_secs {
346 return Err(OAuthError::Expired { exp });
347 }
348 }
349 if let Some(nbf) = token.claims.nbf {
350 if nbf > now_unix_secs {
351 return Err(OAuthError::NotYetValid { nbf });
352 }
353 }
354
355 let username = match &self.config.identity_mode {
357 OAuthIdentityMode::SubClaim => {
358 token
359 .claims
360 .sub
361 .clone()
362 .ok_or(OAuthError::MissingIdentityClaim(
363 OAuthIdentityMode::SubClaim,
364 ))?
365 }
366 OAuthIdentityMode::ClaimField(name) => token
367 .claims
368 .claim(name)
369 .map(|s| s.to_string())
370 .ok_or_else(|| {
371 OAuthError::MissingIdentityClaim(OAuthIdentityMode::ClaimField(name.clone()))
372 })?,
373 };
374
375 let role = if self.config.map_to_existing_users {
377 match lookup_user(&username) {
378 Some(user) => user.role,
379 None => self.derive_role_from_claims(&token.claims)?,
380 }
381 } else {
382 self.derive_role_from_claims(&token.claims)?
383 };
384
385 let tenant = self
389 .config
390 .tenant_claim
391 .as_deref()
392 .and_then(|name| token.claims.claim(name).map(|s| s.to_string()))
393 .filter(|s| !s.is_empty());
394
395 Ok(OAuthIdentity {
396 username,
397 tenant,
398 role,
399 issuer: self.config.issuer.clone(),
400 subject: token.claims.sub.clone(),
401 expires_at_unix_secs: token.claims.exp,
402 })
403 }
404
405 fn derive_role_from_claims(&self, claims: &JwtClaims) -> Result<Role, OAuthError> {
406 let Some(name) = &self.config.role_claim else {
407 return Ok(self.config.default_role);
408 };
409 let Some(raw) = claims.claim(name) else {
414 return Ok(self.config.default_role);
415 };
416 Role::from_str(raw.trim()).ok_or_else(|| OAuthError::MissingOrInvalidRole(name.clone()))
417 }
418}
419
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;

    /// Fixed "current time" (unix seconds) shared by every test.
    const NOW: i64 = 1_700_000_000;

    /// Verifier that accepts any signature, so tests focus on key selection and claims.
    fn noop_verifier() -> JwtVerifier {
        Box::new(|_jwk, _input, _sig| Ok(()))
    }

    /// Verifier that rejects every signature with the message `"nope"`.
    fn rejecting_verifier() -> JwtVerifier {
        Box::new(|_jwk, _input, _sig| Err::<(), String>("nope".to_string()))
    }

    /// Enabled config trusting `https://id.example.com` for audience `reddb`.
    fn base_config() -> OAuthConfig {
        OAuthConfig {
            enabled: true,
            issuer: "https://id.example.com".to_string(),
            audience: "reddb".to_string(),
            jwks_url: String::new(),
            identity_mode: OAuthIdentityMode::SubClaim,
            role_claim: None,
            tenant_claim: None,
            default_role: Role::Read,
            map_to_existing_users: false,
            accept_bearer: true,
        }
    }

    /// A token for `alice` that `base_config` accepts at time `now`.
    fn base_token(now: i64) -> DecodedJwt {
        DecodedJwt {
            header: JwtHeader {
                alg: "RS256".to_string(),
                kid: Some("k1".to_string()),
            },
            claims: JwtClaims {
                iss: Some("https://id.example.com".to_string()),
                sub: Some("alice".to_string()),
                aud: vec!["reddb".to_string()],
                exp: Some(now + 3600),
                nbf: Some(now - 60),
                iat: Some(now),
                extra: HashMap::new(),
            },
            signature: vec![0u8; 8],
            signing_input: b"header.payload".to_vec(),
        }
    }

    /// The RS256 key with `kid = "k1"` that `base_token` references.
    fn k1_jwk() -> Jwk {
        Jwk {
            kid: "k1".to_string(),
            alg: "RS256".to_string(),
            key_bytes: Vec::new(),
        }
    }

    /// Builds a validator from `config` and `verifier` with the `k1` key installed.
    fn validator_with(config: OAuthConfig, verifier: JwtVerifier) -> OAuthValidator {
        let v = OAuthValidator::with_verifier(config, verifier);
        v.set_jwks(vec![k1_jwk()]);
        v
    }

    /// Validator for `config` that accepts all signatures and knows key `k1`.
    fn seeded_with(config: OAuthConfig) -> OAuthValidator {
        validator_with(config, noop_verifier())
    }

    fn seeded_validator() -> OAuthValidator {
        seeded_with(base_config())
    }

    #[test]
    fn extract_bearer_case_insensitive() {
        let v = seeded_validator();
        assert_eq!(
            v.extract_bearer("Bearer abc.def.ghi").as_deref(),
            Some("abc.def.ghi")
        );
        assert_eq!(v.extract_bearer("bearer xyz").as_deref(), Some("xyz"));
        assert!(v.extract_bearer("Basic QQ==").is_none());
    }

    #[test]
    fn extract_bearer_trims_and_requires_token() {
        let v = seeded_validator();
        assert_eq!(v.extract_bearer("  Bearer   tok  ").as_deref(), Some("tok"));
        assert!(v.extract_bearer("Bearer").is_none());
        assert!(v.extract_bearer("Bearer   ").is_none());
        assert!(v.extract_bearer("").is_none());
    }

    #[test]
    fn extract_bearer_disabled_by_config() {
        let mut config = base_config();
        config.accept_bearer = false;
        let v = seeded_with(config);
        assert!(v.extract_bearer("Bearer abc").is_none());
    }

    #[test]
    fn valid_token_yields_sub_identity() {
        let v = seeded_validator();
        let token = base_token(NOW);
        let id = v.validate(&token, NOW, |_| None).unwrap();
        assert_eq!(id.username, "alice");
        assert_eq!(id.role, Role::Read);
        assert_eq!(id.issuer, "https://id.example.com");
        assert_eq!(id.subject.as_deref(), Some("alice"));
        assert_eq!(id.expires_at_unix_secs, Some(NOW + 3600));
    }

    #[test]
    fn disabled_config_rejects_token() {
        let mut config = base_config();
        config.enabled = false;
        let v = seeded_with(config);
        assert!(matches!(
            v.validate(&base_token(NOW), NOW, |_| None),
            Err(OAuthError::Disabled)
        ));
    }

    #[test]
    fn issuer_mismatch_rejected() {
        let v = seeded_validator();
        let mut token = base_token(NOW);
        token.claims.iss = Some("https://evil.example.com".to_string());
        assert!(matches!(
            v.validate(&token, NOW, |_| None),
            Err(OAuthError::WrongIssuer { .. })
        ));
    }

    #[test]
    fn missing_issuer_is_malformed() {
        let v = seeded_validator();
        let mut token = base_token(NOW);
        token.claims.iss = None;
        assert!(matches!(
            v.validate(&token, NOW, |_| None),
            Err(OAuthError::Malformed(_))
        ));
    }

    #[test]
    fn audience_mismatch_rejected() {
        let v = seeded_validator();
        let mut token = base_token(NOW);
        token.claims.aud = vec!["other".to_string()];
        assert!(matches!(
            v.validate(&token, NOW, |_| None),
            Err(OAuthError::WrongAudience { .. })
        ));
    }

    #[test]
    fn expired_token_rejected() {
        let v = seeded_validator();
        let mut token = base_token(NOW);
        token.claims.exp = Some(1_600_000_000);
        assert!(matches!(
            v.validate(&token, NOW, |_| None),
            Err(OAuthError::Expired { .. })
        ));
    }

    #[test]
    fn expiry_at_now_is_rejected() {
        let v = seeded_validator();
        let mut token = base_token(NOW);
        token.claims.exp = Some(NOW);
        assert!(matches!(
            v.validate(&token, NOW, |_| None),
            Err(OAuthError::Expired { .. })
        ));
    }

    #[test]
    fn not_yet_valid_rejected() {
        let v = seeded_validator();
        let mut token = base_token(NOW);
        token.claims.nbf = Some(1_800_000_000);
        assert!(matches!(
            v.validate(&token, NOW, |_| None),
            Err(OAuthError::NotYetValid { .. })
        ));
    }

    #[test]
    fn nbf_at_now_is_accepted() {
        let v = seeded_validator();
        let mut token = base_token(NOW);
        token.claims.nbf = Some(NOW);
        assert!(v.validate(&token, NOW, |_| None).is_ok());
    }

    #[test]
    fn missing_jwk_fails_signature() {
        let v = OAuthValidator::with_verifier(base_config(), noop_verifier());
        let token = base_token(NOW);
        assert!(matches!(
            v.validate(&token, NOW, |_| None),
            Err(OAuthError::BadSignature(_))
        ));
    }

    #[test]
    fn token_without_kid_matches_no_jwk() {
        let v = seeded_validator();
        let mut token = base_token(NOW);
        token.header.kid = None;
        assert!(matches!(
            v.validate(&token, NOW, |_| None),
            Err(OAuthError::BadSignature(_))
        ));
    }

    #[test]
    fn alg_mismatch_matches_no_jwk() {
        let v = seeded_validator();
        let mut token = base_token(NOW);
        token.header.alg = "HS256".to_string();
        assert!(matches!(
            v.validate(&token, NOW, |_| None),
            Err(OAuthError::BadSignature(_))
        ));
    }

    #[test]
    fn verifier_rejection_is_bad_signature() {
        let v = validator_with(base_config(), rejecting_verifier());
        match v.validate(&base_token(NOW), NOW, |_| None) {
            Err(OAuthError::BadSignature(msg)) => assert_eq!(msg, "nope"),
            other => panic!("expected BadSignature, got {other:?}"),
        }
    }

    #[test]
    fn role_claim_parses_from_extra() {
        let mut config = base_config();
        config.role_claim = Some("role".to_string());
        let v = seeded_with(config);
        let mut token = base_token(NOW);
        token
            .claims
            .extra
            .insert("role".to_string(), "admin".to_string());
        let id = v.validate(&token, NOW, |_| None).unwrap();
        assert_eq!(id.role, Role::Admin);
    }

    #[test]
    fn absent_role_claim_falls_back_to_default() {
        let mut config = base_config();
        config.role_claim = Some("role".to_string());
        let v = seeded_with(config);
        let id = v.validate(&base_token(NOW), NOW, |_| None).unwrap();
        assert_eq!(id.role, Role::Read);
    }

    #[test]
    fn invalid_role_claim_rejected() {
        let mut config = base_config();
        config.role_claim = Some("role".to_string());
        let v = seeded_with(config);
        let mut token = base_token(NOW);
        token
            .claims
            .extra
            .insert("role".to_string(), "definitely-not-a-role".to_string());
        assert!(matches!(
            v.validate(&token, NOW, |_| None),
            Err(OAuthError::MissingOrInvalidRole(_))
        ));
    }

    #[test]
    fn claim_field_identity_mode() {
        let mut config = base_config();
        config.identity_mode = OAuthIdentityMode::ClaimField("preferred_username".into());
        let v = seeded_with(config);
        let mut token = base_token(NOW);
        token
            .claims
            .extra
            .insert("preferred_username".into(), "alice.smith".into());
        let id = v.validate(&token, NOW, |_| None).unwrap();
        assert_eq!(id.username, "alice.smith");
    }

    #[test]
    fn missing_identity_claim_rejected() {
        let v = seeded_validator();
        let mut token = base_token(NOW);
        token.claims.sub = None;
        assert!(matches!(
            v.validate(&token, NOW, |_| None),
            Err(OAuthError::MissingIdentityClaim(OAuthIdentityMode::SubClaim))
        ));

        let mut config = base_config();
        config.identity_mode = OAuthIdentityMode::ClaimField("email".into());
        let v = seeded_with(config);
        match v.validate(&base_token(NOW), NOW, |_| None) {
            Err(OAuthError::MissingIdentityClaim(OAuthIdentityMode::ClaimField(name))) => {
                assert_eq!(name, "email")
            }
            other => panic!("expected MissingIdentityClaim, got {other:?}"),
        }
    }

    #[test]
    fn user_lookup_consulted_only_when_mapping_enabled() {
        let calls = Cell::new(0);

        let mut config = base_config();
        config.map_to_existing_users = true;
        let v = seeded_with(config);
        let id = v
            .validate(&base_token(NOW), NOW, |name| {
                assert_eq!(name, "alice");
                calls.set(calls.get() + 1);
                None
            })
            .unwrap();
        assert_eq!(calls.get(), 1);
        // No stored user was found, so the role falls back to the claim-derived default.
        assert_eq!(id.role, Role::Read);

        // `base_config` has `map_to_existing_users = false`: the callback must not run.
        let v = seeded_validator();
        v.validate(&base_token(NOW), NOW, |_| {
            calls.set(calls.get() + 1);
            None
        })
        .unwrap();
        assert_eq!(calls.get(), 1);
    }

    #[test]
    fn tenant_claim_extracted_when_configured() {
        let mut config = base_config();
        config.tenant_claim = Some("tenant".into());
        let v = seeded_with(config);
        let mut token = base_token(NOW);
        token.claims.extra.insert("tenant".into(), "acme".into());
        let id = v.validate(&token, NOW, |_| None).unwrap();
        assert_eq!(id.tenant.as_deref(), Some("acme"));
    }

    #[test]
    fn tenant_absent_when_claim_unconfigured() {
        let v = seeded_validator();
        let mut token = base_token(NOW);
        token.claims.extra.insert("tenant".into(), "acme".into());
        let id = v.validate(&token, NOW, |_| None).unwrap();
        assert!(id.tenant.is_none());
    }

    #[test]
    fn tenant_claim_custom_name() {
        let mut config = base_config();
        config.tenant_claim = Some("org_id".into());
        let v = seeded_with(config);
        let mut token = base_token(NOW);
        token.claims.extra.insert("org_id".into(), "globex".into());
        let id = v.validate(&token, NOW, |_| None).unwrap();
        assert_eq!(id.tenant.as_deref(), Some("globex"));
    }
}