1pub mod peer_keys;
4
5pub use peer_keys::{
6 GatewayPeerKeyError, GatewayPeerKeys, KEY_FILE_NAME as GATEWAY_PEER_KEY_FILE,
7 PubkeyDecodeError, decode_pubkey_b64,
8};
9
10use std::fmt::{Display, Formatter};
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13
14use base64::Engine;
15use base64::engine::general_purpose::URL_SAFE_NO_PAD;
16use hmac::{Hmac, Mac};
17use ring::signature::{self, RsaPublicKeyComponents, UnparsedPublicKey};
18use serde::{Deserialize, Serialize};
19use serde_json::Value;
20use sha2::Sha256;
21use tokio::sync::RwLock;
22
/// HMAC instantiated with SHA-256; used to verify `HS256` JWT signatures.
type HmacSha256 = Hmac<Sha256>;
24
/// Settings for [`validate_jwt_locally`]: an HS256 shared secret plus claim checks.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct JwtValidationConfig {
    /// Shared secret whose UTF-8 bytes are used as the HS256 HMAC key.
    pub shared_secret: String,
    /// When set, the token's `iss` claim must equal this value exactly.
    pub issuer: Option<String>,
    /// When set, this value must appear in the token's `aud` claim.
    pub audience: Option<String>,
    /// The "current time" (seconds since the Unix epoch) used for `exp`/`nbf` checks.
    pub now_epoch_seconds: u64,
    /// Clock-skew allowance, in seconds, applied to the `exp`/`nbf` checks.
    pub leeway_seconds: u64,
}
33
/// Claim-check settings used once a token's signature has been verified.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct JwtClaimsValidationConfig {
    /// When set, the token's `iss` claim must equal this value exactly.
    pub issuer: Option<String>,
    /// When set, this value must appear in the token's `aud` claim (string or array).
    pub audience: Option<String>,
    /// The "current time" (seconds since the Unix epoch) used for `exp`/`nbf` checks.
    pub now_epoch_seconds: u64,
    /// Clock-skew allowance, in seconds: subtracted from `now` for `exp`, added for `nbf`.
    pub leeway_seconds: u64,
}
41
/// A key able to verify exactly one JWS algorithm.
///
/// The variant fixes the algorithm: a token whose header `alg` does not match the variant is
/// rejected before any signature check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JwtVerificationKey {
    /// Raw HMAC-SHA256 key bytes (`HS256`).
    Hs256(Vec<u8>),
    /// RSA public key for `RS256`, as big-endian modulus and public-exponent bytes.
    Rs256 { modulus: Vec<u8>, exponent: Vec<u8> },
    /// P-256 public key for `ES256`; when built from a JWK it is the 65-byte uncompressed
    /// point `0x04 || x || y`.
    Es256P256 { public_key: Vec<u8> },
}
48
/// Reasons a JWT can fail parsing, signature verification, or claim validation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JwtValidationError {
    /// The token is not exactly three `.`-separated segments.
    InvalidFormat,
    /// A segment is not valid unpadded base64url.
    InvalidBase64,
    /// The header or payload is not JSON of the expected shape.
    InvalidJson,
    /// The header `alg` cannot be used with the supplied key; carries the offending `alg`.
    UnsupportedAlgorithm(String),
    /// The signature did not verify.
    InvalidSignature,
    /// The `exp` claim is in the past (beyond the configured leeway).
    Expired,
    /// The `nbf` claim is in the future (beyond the configured leeway).
    NotYetValid,
    /// The `iss` claim does not equal the required issuer.
    IssuerMismatch,
    /// The required audience is not among the `aud` claim values.
    AudienceMismatch,
}
61
/// Claims extracted from a token that passed signature and claim validation.
///
/// String fields are `None` when the claim is absent or not a JSON string.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ValidatedJwt {
    /// The `sub` claim.
    pub subject: Option<String>,
    /// The `email` claim.
    pub email: Option<String>,
    /// The non-standard `provider` claim.
    pub provider: Option<String>,
    /// The non-standard `actor_type` claim.
    pub actor_type: Option<String>,
    /// The `iss` claim.
    pub issuer: Option<String>,
    /// The `aud` claim, normalized to a list (a single string becomes one entry; non-string
    /// array items are dropped).
    pub audience: Vec<String>,
    /// The `exp` claim in seconds since the Unix epoch, if present.
    pub expires_at_epoch_seconds: Option<u64>,
    /// The `nbf` claim in seconds since the Unix epoch, if present.
    pub not_before_epoch_seconds: Option<u64>,
}
73
/// The subset of the JOSE header this module reads.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
struct JwtHeader {
    /// Signing algorithm, e.g. `HS256`, `RS256`, `ES256`.
    alg: String,
    /// Key id used to pick a key from a JWKS; optional.
    #[serde(default)]
    kid: Option<String>,
}
80
/// The two OpenID Connect discovery members this module needs.
///
/// Other members of the discovery document are ignored during deserialization.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OidcDiscoveryDocument {
    /// The provider's issuer identifier.
    pub issuer: String,
    /// URL of the provider's JSON Web Key Set.
    pub jwks_uri: String,
}
86
/// A JSON Web Key Set: the `keys` array of a JWKS document.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct JwksDocument {
    /// The published keys, in document order.
    pub keys: Vec<Jwk>,
}
91
/// A single JSON Web Key, limited to the members used for `oct`, `RSA` and `EC` keys.
///
/// Binary members (`k`, `n`, `e`, `x`, `y`) are kept as their unpadded base64url text and
/// decoded only when a verification key is built.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Jwk {
    /// Key id, matched against a token header's `kid`.
    pub kid: Option<String>,
    /// Key type: `oct`, `RSA` or `EC` are understood here.
    pub kty: String,
    /// Algorithm this key is intended for; when absent, the key is not restricted by `alg`.
    #[serde(default)]
    pub alg: Option<String>,
    /// Symmetric key bytes (`oct`).
    #[serde(default)]
    pub k: Option<String>,
    /// RSA modulus (`RSA`).
    #[serde(default)]
    pub n: Option<String>,
    /// RSA public exponent (`RSA`).
    #[serde(default)]
    pub e: Option<String>,
    /// Elliptic curve name, e.g. `P-256` (`EC`).
    #[serde(default)]
    pub crv: Option<String>,
    /// EC point x coordinate (`EC`).
    #[serde(default)]
    pub x: Option<String>,
    /// EC point y coordinate (`EC`).
    #[serde(default)]
    pub y: Option<String>,
}
111
/// Failures while parsing OIDC/JWKS documents or turning a JWK into a verification key.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OidcContractError {
    /// The document is not JSON of the expected shape.
    InvalidJson,
    /// The discovery document's `issuer` is empty or whitespace.
    MissingIssuer,
    /// The discovery document's `jwks_uri` is empty or whitespace.
    MissingJwksUri,
    /// The JWKS contains no keys.
    MissingKeys,
    /// No key in the JWKS fits the token's `kid`/`alg`.
    NoMatchingKey,
    /// The JWK's `kty` cannot be used for the requested algorithm; carries the `kty`.
    UnsupportedKeyType(String),
    /// The token algorithm is not supported; carries the `alg`.
    UnsupportedJwtAlgorithm(String),
    /// An `oct` JWK lacks `k`.
    MissingSymmetricKeyMaterial,
    /// An `RSA` JWK lacks `n` or `e`.
    MissingRsaKeyMaterial,
    /// An `EC` JWK lacks `crv`, `x` or `y`.
    MissingEcKeyMaterial,
    /// A key member is not valid unpadded base64url.
    InvalidKeyEncoding,
    /// Symmetric key bytes are unusable (e.g. not UTF-8 when a text secret is required).
    InvalidSymmetricKeyMaterial,
    /// RSA key members decoded to unusable values (e.g. empty).
    InvalidRsaKeyMaterial,
    /// EC coordinates have the wrong length for the curve.
    InvalidEcKeyMaterial,
    /// The JWK names a curve other than `P-256`; carries the curve name.
    UnsupportedEllipticCurve(String),
}
130
/// Public view of an unverified token header, as returned by [`inspect_jwt_header`].
///
/// Nothing here is authenticated; use it only to choose a key, never to trust the token.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct JwtHeaderView {
    /// The header's `alg`.
    pub alg: String,
    /// The header's `kid`, if present.
    pub kid: Option<String>,
}
136
137pub fn validate_jwt_locally(
138 token: &str,
139 config: &JwtValidationConfig,
140) -> Result<ValidatedJwt, JwtValidationError> {
141 let claims_config = JwtClaimsValidationConfig {
142 issuer: config.issuer.clone(),
143 audience: config.audience.clone(),
144 now_epoch_seconds: config.now_epoch_seconds,
145 leeway_seconds: config.leeway_seconds,
146 };
147 let key = JwtVerificationKey::Hs256(config.shared_secret.as_bytes().to_vec());
148 validate_jwt_with_verification_key(token, &key, &claims_config)
149}
150
151pub fn validate_jwt_with_verification_key(
152 token: &str,
153 verification_key: &JwtVerificationKey,
154 config: &JwtClaimsValidationConfig,
155) -> Result<ValidatedJwt, JwtValidationError> {
156 let parsed = parse_jwt(token)?;
157 if !key_supports_algorithm(verification_key, parsed.header.alg.as_str()) {
158 return Err(JwtValidationError::UnsupportedAlgorithm(parsed.header.alg));
159 }
160 verify_signature(
161 verification_key,
162 parsed.header.alg.as_str(),
163 parsed.signing_input.as_bytes(),
164 &parsed.signature,
165 )?;
166 validate_claims(&parsed.claims, config)
167}
168
169pub fn build_jwt_verification_key(
170 key: &Jwk,
171 alg: &str,
172) -> Result<JwtVerificationKey, OidcContractError> {
173 match alg {
174 "HS256" => build_hs256_key(key),
175 "RS256" => build_rs256_key(key),
176 "ES256" => build_es256_key(key),
177 unsupported => Err(OidcContractError::UnsupportedJwtAlgorithm(
178 unsupported.to_string(),
179 )),
180 }
181}
182
183fn build_hs256_key(key: &Jwk) -> Result<JwtVerificationKey, OidcContractError> {
184 if key.kty != "oct" {
185 return Err(OidcContractError::UnsupportedKeyType(key.kty.clone()));
186 }
187 let encoded = key
188 .k
189 .as_deref()
190 .ok_or(OidcContractError::MissingSymmetricKeyMaterial)?;
191 let bytes = URL_SAFE_NO_PAD
192 .decode(encoded)
193 .map_err(|_| OidcContractError::InvalidKeyEncoding)?;
194 Ok(JwtVerificationKey::Hs256(bytes))
195}
196
197fn build_rs256_key(key: &Jwk) -> Result<JwtVerificationKey, OidcContractError> {
198 if key.kty != "RSA" {
199 return Err(OidcContractError::UnsupportedKeyType(key.kty.clone()));
200 }
201 let modulus = key
202 .n
203 .as_deref()
204 .ok_or(OidcContractError::MissingRsaKeyMaterial)
205 .and_then(decode_key_component)?;
206 let exponent = key
207 .e
208 .as_deref()
209 .ok_or(OidcContractError::MissingRsaKeyMaterial)
210 .and_then(decode_key_component)?;
211 if modulus.is_empty() || exponent.is_empty() {
212 return Err(OidcContractError::InvalidRsaKeyMaterial);
213 }
214 Ok(JwtVerificationKey::Rs256 { modulus, exponent })
215}
216
217fn build_es256_key(key: &Jwk) -> Result<JwtVerificationKey, OidcContractError> {
218 if key.kty != "EC" {
219 return Err(OidcContractError::UnsupportedKeyType(key.kty.clone()));
220 }
221 let curve = key
222 .crv
223 .as_deref()
224 .ok_or(OidcContractError::MissingEcKeyMaterial)?;
225 if curve != "P-256" {
226 return Err(OidcContractError::UnsupportedEllipticCurve(
227 curve.to_string(),
228 ));
229 }
230
231 let x = key
232 .x
233 .as_deref()
234 .ok_or(OidcContractError::MissingEcKeyMaterial)
235 .and_then(decode_key_component)?;
236 let y = key
237 .y
238 .as_deref()
239 .ok_or(OidcContractError::MissingEcKeyMaterial)
240 .and_then(decode_key_component)?;
241 if x.len() != 32 || y.len() != 32 {
242 return Err(OidcContractError::InvalidEcKeyMaterial);
243 }
244
245 let mut public_key = Vec::with_capacity(65);
246 public_key.push(0x04);
247 public_key.extend_from_slice(&x);
248 public_key.extend_from_slice(&y);
249 Ok(JwtVerificationKey::Es256P256 { public_key })
250}
251
252fn decode_key_component(encoded: &str) -> Result<Vec<u8>, OidcContractError> {
253 URL_SAFE_NO_PAD
254 .decode(encoded)
255 .map_err(|_| OidcContractError::InvalidKeyEncoding)
256}
257
/// A compact-serialized JWT split into its decoded parts (signature not yet verified).
struct ParsedJwt {
    /// Decoded JOSE header.
    header: JwtHeader,
    /// Decoded payload as arbitrary JSON.
    claims: Value,
    /// The exact `<header>.<payload>` text from the token, i.e. the bytes that were signed.
    signing_input: String,
    /// Decoded signature bytes.
    signature: Vec<u8>,
}
264
265fn parse_jwt(token: &str) -> Result<ParsedJwt, JwtValidationError> {
266 let parts: Vec<&str> = token.split('.').collect();
267 if parts.len() != 3 {
268 return Err(JwtValidationError::InvalidFormat);
269 }
270
271 let header_bytes = URL_SAFE_NO_PAD
272 .decode(parts[0])
273 .map_err(|_| JwtValidationError::InvalidBase64)?;
274 let payload_bytes = URL_SAFE_NO_PAD
275 .decode(parts[1])
276 .map_err(|_| JwtValidationError::InvalidBase64)?;
277 let signature = URL_SAFE_NO_PAD
278 .decode(parts[2])
279 .map_err(|_| JwtValidationError::InvalidBase64)?;
280
281 let header: JwtHeader =
282 serde_json::from_slice(&header_bytes).map_err(|_| JwtValidationError::InvalidJson)?;
283 let claims: Value =
284 serde_json::from_slice(&payload_bytes).map_err(|_| JwtValidationError::InvalidJson)?;
285 let signing_input = format!("{}.{}", parts[0], parts[1]);
286
287 Ok(ParsedJwt {
288 header,
289 claims,
290 signing_input,
291 signature,
292 })
293}
294
295fn key_supports_algorithm(key: &JwtVerificationKey, alg: &str) -> bool {
296 matches!(
297 (key, alg),
298 (JwtVerificationKey::Hs256(_), "HS256")
299 | (JwtVerificationKey::Rs256 { .. }, "RS256")
300 | (JwtVerificationKey::Es256P256 { .. }, "ES256")
301 )
302}
303
304fn verify_signature(
305 verification_key: &JwtVerificationKey,
306 alg: &str,
307 signing_input: &[u8],
308 signature: &[u8],
309) -> Result<(), JwtValidationError> {
310 match (verification_key, alg) {
311 (JwtVerificationKey::Hs256(secret), "HS256") => {
312 let mut mac = HmacSha256::new_from_slice(secret)
313 .map_err(|_| JwtValidationError::InvalidSignature)?;
314 mac.update(signing_input);
315 mac.verify_slice(signature)
320 .map_err(|_| JwtValidationError::InvalidSignature)?;
321 Ok(())
322 }
323 (JwtVerificationKey::Rs256 { modulus, exponent }, "RS256") => {
324 let components = RsaPublicKeyComponents {
325 n: modulus.as_slice(),
326 e: exponent.as_slice(),
327 };
328 components
329 .verify(
330 &signature::RSA_PKCS1_2048_8192_SHA256,
331 signing_input,
332 signature,
333 )
334 .map_err(|_| JwtValidationError::InvalidSignature)
335 }
336 (JwtVerificationKey::Es256P256 { public_key }, "ES256") => {
337 UnparsedPublicKey::new(&signature::ECDSA_P256_SHA256_FIXED, public_key.as_slice())
338 .verify(signing_input, signature)
339 .map_err(|_| JwtValidationError::InvalidSignature)
340 }
341 (_, unsupported) => Err(JwtValidationError::UnsupportedAlgorithm(
342 unsupported.to_string(),
343 )),
344 }
345}
346
347fn validate_claims(
348 claims: &Value,
349 config: &JwtClaimsValidationConfig,
350) -> Result<ValidatedJwt, JwtValidationError> {
351 let exp = claims.get("exp").and_then(Value::as_u64);
352 let nbf = claims.get("nbf").and_then(Value::as_u64);
353 let iss = claims
354 .get("iss")
355 .and_then(Value::as_str)
356 .map(ToString::to_string);
357 let aud = extract_audiences(claims);
358
359 if let Some(exp) = exp {
360 let threshold = config
361 .now_epoch_seconds
362 .saturating_sub(config.leeway_seconds);
363 if exp < threshold {
364 return Err(JwtValidationError::Expired);
365 }
366 }
367 if let Some(nbf) = nbf {
368 let threshold = config
369 .now_epoch_seconds
370 .saturating_add(config.leeway_seconds);
371 if nbf > threshold {
372 return Err(JwtValidationError::NotYetValid);
373 }
374 }
375
376 if let Some(expected_issuer) = &config.issuer
377 && iss.as_deref() != Some(expected_issuer.as_str())
378 {
379 return Err(JwtValidationError::IssuerMismatch);
380 }
381 if let Some(expected_audience) = &config.audience
382 && !aud.iter().any(|entry| entry == expected_audience)
383 {
384 return Err(JwtValidationError::AudienceMismatch);
385 }
386
387 Ok(ValidatedJwt {
388 subject: claims
389 .get("sub")
390 .and_then(Value::as_str)
391 .map(ToString::to_string),
392 email: claims
393 .get("email")
394 .and_then(Value::as_str)
395 .map(ToString::to_string),
396 provider: claims
397 .get("provider")
398 .and_then(Value::as_str)
399 .map(ToString::to_string),
400 actor_type: claims
401 .get("actor_type")
402 .and_then(Value::as_str)
403 .map(ToString::to_string),
404 issuer: iss,
405 audience: aud,
406 expires_at_epoch_seconds: exp,
407 not_before_epoch_seconds: nbf,
408 })
409}
410
411pub fn inspect_jwt_header(token: &str) -> Result<JwtHeaderView, JwtValidationError> {
412 let parts: Vec<&str> = token.split('.').collect();
413 if parts.len() != 3 {
414 return Err(JwtValidationError::InvalidFormat);
415 }
416 let header_bytes = URL_SAFE_NO_PAD
417 .decode(parts[0])
418 .map_err(|_| JwtValidationError::InvalidBase64)?;
419 let header: JwtHeader =
420 serde_json::from_slice(&header_bytes).map_err(|_| JwtValidationError::InvalidJson)?;
421 Ok(JwtHeaderView {
422 alg: header.alg,
423 kid: header.kid,
424 })
425}
426
427pub fn parse_oidc_discovery_json(
428 json_text: &str,
429) -> Result<OidcDiscoveryDocument, OidcContractError> {
430 let doc: OidcDiscoveryDocument =
431 serde_json::from_str(json_text).map_err(|_| OidcContractError::InvalidJson)?;
432 if doc.issuer.trim().is_empty() {
433 return Err(OidcContractError::MissingIssuer);
434 }
435 if doc.jwks_uri.trim().is_empty() {
436 return Err(OidcContractError::MissingJwksUri);
437 }
438 Ok(doc)
439}
440
441pub fn parse_jwks_json(json_text: &str) -> Result<JwksDocument, OidcContractError> {
442 let doc: JwksDocument =
443 serde_json::from_str(json_text).map_err(|_| OidcContractError::InvalidJson)?;
444 if doc.keys.is_empty() {
445 return Err(OidcContractError::MissingKeys);
446 }
447 Ok(doc)
448}
449
450pub fn select_jwk_for_token<'a>(
451 jwks: &'a JwksDocument,
452 kid: Option<&str>,
453 alg: &str,
454) -> Result<&'a Jwk, OidcContractError> {
455 if let Some(kid) = kid {
456 return jwks
457 .keys
458 .iter()
459 .find(|key| {
460 key.kid.as_deref() == Some(kid)
461 && key.alg.as_deref().is_none_or(|key_alg| key_alg == alg)
462 })
463 .ok_or(OidcContractError::NoMatchingKey);
464 }
465
466 jwks.keys
467 .iter()
468 .find(|key| key.alg.as_deref().is_none_or(|key_alg| key_alg == alg))
469 .ok_or(OidcContractError::NoMatchingKey)
470}
471
472pub fn extract_hs256_shared_secret(key: &Jwk) -> Result<String, OidcContractError> {
473 let bytes = match build_hs256_key(key)? {
474 JwtVerificationKey::Hs256(bytes) => bytes,
475 _ => return Err(OidcContractError::InvalidSymmetricKeyMaterial),
476 };
477 String::from_utf8(bytes).map_err(|_| OidcContractError::InvalidSymmetricKeyMaterial)
478}
479
480fn extract_audiences(claims: &Value) -> Vec<String> {
481 match claims.get("aud") {
482 Some(Value::String(aud)) => vec![aud.clone()],
483 Some(Value::Array(values)) => values
484 .iter()
485 .filter_map(Value::as_str)
486 .map(ToString::to_string)
487 .collect(),
488 _ => Vec::new(),
489 }
490}
491
/// Errors from [`JwksCache`] refreshes and token validation.
#[derive(Debug)]
pub enum JwksCacheError {
    /// A discovery/JWKS document was malformed, or a JWK could not be turned into a key.
    Discovery(OidcContractError),
    /// An HTTP request failed or returned a non-success status; carries a description.
    Http(String),
    /// The token itself was rejected (bad header, signature, or claims).
    Validation(JwtValidationError),
    /// No cached JWK fits the token's `kid`/`alg`.
    NoMatchingKey,
    /// Validation was attempted before any JWKS was cached.
    NotInitialized,
}
504
505impl Display for JwksCacheError {
506 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
507 match self {
508 Self::Discovery(err) => write!(f, "OIDC discovery error: {err:?}"),
509 Self::Http(msg) => write!(f, "HTTP fetch error: {msg}"),
510 Self::Validation(err) => write!(f, "JWT validation error: {err:?}"),
511 Self::NoMatchingKey => write!(f, "no matching JWK for token"),
512 Self::NotInitialized => write!(f, "JWKS cache not initialized"),
513 }
514 }
515}
516
// Uses the default `source()` (None): the wrapped error types do not implement `Error`.
impl std::error::Error for JwksCacheError {}
518
/// Configuration for [`JwksCache`]; see `JwksCacheConfig::new` for the defaults.
#[derive(Debug, Clone)]
pub struct JwksCacheConfig {
    /// URL of the OpenID Connect discovery document, fetched on every refresh.
    pub discovery_url: String,
    /// Maximum age of the cached JWKS before `validate_token` refreshes it.
    pub refresh_interval: Duration,
    /// Timeout configured on the HTTP client used for discovery/JWKS fetches.
    pub http_timeout: Duration,
    /// When set, validated tokens must carry exactly this `iss`.
    pub issuer: Option<String>,
    /// When set, validated tokens must list this value in `aud`.
    pub audience: Option<String>,
    /// Clock-skew allowance in seconds for `exp`/`nbf` checks.
    pub leeway_seconds: u64,
}
528
529impl JwksCacheConfig {
530 pub fn new(discovery_url: String) -> Self {
531 Self {
532 discovery_url,
533 refresh_interval: Duration::from_hours(1),
534 http_timeout: Duration::from_secs(10),
535 issuer: None,
536 audience: None,
537 leeway_seconds: 60,
538 }
539 }
540}
541
/// Mutable state shared by all clones of a [`JwksCache`].
struct JwksCacheInner {
    /// The most recently fetched key set; `None` until the first successful refresh.
    jwks: Option<JwksDocument>,
    /// When the last successful refresh completed; `None` until then.
    last_refresh: Option<Instant>,
}
546
/// A cache of an OIDC provider's signing keys that validates tokens against them.
///
/// Clones share the same cached key set and config through `Arc`.
#[derive(Clone)]
pub struct JwksCache {
    /// Cached JWKS and refresh timestamp behind an async read/write lock.
    inner: Arc<RwLock<JwksCacheInner>>,
    /// Immutable configuration.
    config: Arc<JwksCacheConfig>,
    /// Client used for discovery and JWKS fetches.
    http_client: reqwest::Client,
}
553
554impl JwksCache {
555 pub fn new(config: JwksCacheConfig) -> Self {
556 let http_client = reqwest::Client::builder()
557 .timeout(config.http_timeout)
558 .build()
559 .unwrap_or_else(|_| reqwest::Client::new());
560 Self {
561 inner: Arc::new(RwLock::new(JwksCacheInner {
562 jwks: None,
563 last_refresh: None,
564 })),
565 config: Arc::new(config),
566 http_client,
567 }
568 }
569
570 pub async fn refresh(&self) -> Result<(), JwksCacheError> {
572 let discovery_json = fetch_json(&self.http_client, &self.config.discovery_url).await?;
573 let discovery =
574 parse_oidc_discovery_json(&discovery_json).map_err(JwksCacheError::Discovery)?;
575
576 let jwks_json = fetch_json(&self.http_client, &discovery.jwks_uri).await?;
577 let jwks = parse_jwks_json(&jwks_json).map_err(JwksCacheError::Discovery)?;
578
579 let mut inner = self.inner.write().await;
580 inner.jwks = Some(jwks);
581 inner.last_refresh = Some(Instant::now());
582 Ok(())
583 }
584
585 pub async fn validate_token(&self, token: &str) -> Result<ValidatedJwt, JwksCacheError> {
587 self.maybe_refresh().await?;
588
589 let header = inspect_jwt_header(token).map_err(JwksCacheError::Validation)?;
590
591 match self.try_validate(token, &header).await {
593 Ok(jwt) => return Ok(jwt),
594 Err(JwksCacheError::NoMatchingKey) => {
595 }
597 Err(err) => return Err(err),
598 }
599
600 self.refresh().await?;
601 self.try_validate(token, &header).await
602 }
603
604 async fn try_validate(
605 &self,
606 token: &str,
607 header: &JwtHeaderView,
608 ) -> Result<ValidatedJwt, JwksCacheError> {
609 let inner = self.inner.read().await;
610 let jwks = inner.jwks.as_ref().ok_or(JwksCacheError::NotInitialized)?;
611
612 let jwk = select_jwk_for_token(jwks, header.kid.as_deref(), &header.alg)
613 .map_err(|_| JwksCacheError::NoMatchingKey)?;
614
615 let verification_key =
616 build_jwt_verification_key(jwk, &header.alg).map_err(JwksCacheError::Discovery)?;
617
618 let now = std::time::SystemTime::now()
619 .duration_since(std::time::UNIX_EPOCH)
620 .unwrap_or_default()
621 .as_secs();
622
623 let claims_config = JwtClaimsValidationConfig {
624 issuer: self.config.issuer.clone(),
625 audience: self.config.audience.clone(),
626 now_epoch_seconds: now,
627 leeway_seconds: self.config.leeway_seconds,
628 };
629
630 validate_jwt_with_verification_key(token, &verification_key, &claims_config)
631 .map_err(JwksCacheError::Validation)
632 }
633
634 async fn maybe_refresh(&self) -> Result<(), JwksCacheError> {
636 let needs_refresh = {
637 let inner = self.inner.read().await;
638 match inner.last_refresh {
639 Some(last) => last.elapsed() >= self.config.refresh_interval,
640 None => true,
641 }
642 };
643 if needs_refresh {
644 self.refresh().await?;
645 }
646 Ok(())
647 }
648}
649
650async fn fetch_json(client: &reqwest::Client, url: &str) -> Result<String, JwksCacheError> {
651 let response = client
652 .get(url)
653 .send()
654 .await
655 .map_err(|err| JwksCacheError::Http(format!("{err}")))?;
656 let status = response.status();
657 if !status.is_success() {
658 return Err(JwksCacheError::Http(format!("HTTP {status} from {url}")));
659 }
660 response
661 .text()
662 .await
663 .map_err(|err| JwksCacheError::Http(format!("{err}")))
664}