xjp_oidc/
verify.rs

1//! Resource Server JWT verification
2
3#[cfg(feature = "verifier")]
4use crate::{
5    cache::Cache,
6    discovery::discover,
7    errors::{Error, Result},
8    http::HttpClient,
9    id_token::fetch_jwks,
10    jwks::Jwks,
11    types::VerifiedClaims,
12};
13
14#[cfg(feature = "verifier")]
15use base64::{engine::general_purpose, Engine as _};
16#[cfg(feature = "verifier")]
17use josekit::{
18    jws::RS256,
19    jwt::{self, JwtPayload},
20};
21#[cfg(feature = "verifier")]
22use std::collections::HashMap;
23#[cfg(feature = "verifier")]
24use std::time::{SystemTime, UNIX_EPOCH};
25
/// Verifies bearer JWTs on behalf of a resource server.
///
/// The verifier resolves which issuer a token should be checked against,
/// obtains that issuer's JWKS through the configured HTTP client and cache,
/// and validates the token's signature and claims.
#[cfg(feature = "verifier")]
pub struct JwtVerifier<C, H>
where
    C: Cache<String, Jwks>,
    H: HttpClient,
{
    /// Issuer URLs keyed by tenant or host identifier.
    pub issuer_map: HashMap<String, String>,
    /// Audience value that must appear in a token's `aud` claim.
    pub audience: String,
    /// Shared HTTP client used for discovery and JWKS requests.
    pub http: std::sync::Arc<H>,
    /// Shared cache handed to the JWKS fetch.
    pub cache: std::sync::Arc<C>,
    /// Tolerated clock skew, in seconds, for time-based claim checks.
    pub clock_skew_sec: i64,
    /// Issuer to fall back on when `issuer_map` yields no match.
    pub default_issuer: Option<String>,
}
42
#[cfg(feature = "verifier")]
impl<C: Cache<String, Jwks>, H: HttpClient> JwtVerifier<C, H> {
    /// Creates a verifier with a 60 second clock-skew tolerance and no
    /// default issuer.
    pub fn new(
        issuer_map: HashMap<String, String>,
        audience: String,
        http: std::sync::Arc<H>,
        cache: std::sync::Arc<C>,
    ) -> Self {
        Self { issuer_map, audience, http, cache, clock_skew_sec: 60, default_issuer: None }
    }

    /// Returns an empty [`JwtVerifierBuilder`].
    pub fn builder() -> JwtVerifierBuilder<C, H> {
        JwtVerifierBuilder::default()
    }

    /// Verifies a bearer token and returns its validated claims.
    ///
    /// `bearer` may be either the raw compact JWT or a full `Authorization`
    /// header value (`Bearer <token>`). The scheme name is matched
    /// case-insensitively and surrounding whitespace is ignored.
    ///
    /// # Errors
    ///
    /// Returns an error when the token is malformed, lacks an `iss` claim or
    /// a `kid` header, names an issuer that cannot be resolved, when
    /// discovery or the JWKS fetch fails, when no key matches the `kid`, or
    /// when signature or claim validation fails.
    pub async fn verify(&self, bearer: &str) -> Result<VerifiedClaims> {
        let token = strip_bearer_scheme(bearer);

        // Run every offline check before touching the network, so that a
        // malformed or unroutable token never triggers outbound requests.
        let unverified_issuer = extract_unverified_issuer(token)?;
        let expected_issuer = self.resolve_issuer(&unverified_issuer)?;
        let kid = extract_kid(token)?.ok_or_else(|| Error::Jwt("Token missing kid".into()))?;

        // Use NoOpCache for metadata discovery - JWKS are the important thing to cache
        let metadata_cache = crate::cache::NoOpCache;
        let metadata = discover(&expected_issuer, self.http.as_ref(), &metadata_cache).await?;
        let jwks = fetch_jwks(&metadata.jwks_uri, self.http.as_ref(), self.cache.as_ref()).await?;

        let jwk = jwks
            .find_key(&kid)
            .ok_or_else(|| Error::Jwt(format!("Key with kid '{}' not found", kid)))?;

        let payload = verify_token_signature(token, jwk)?;

        extract_and_validate_access_token_claims(
            payload,
            &expected_issuer,
            &self.audience,
            self.clock_skew_sec,
        )
    }

    /// Picks the issuer a token must be verified against.
    ///
    /// Resolution order:
    /// 1. `token_issuer` equals one of the mapped issuer URLs: use it as is.
    /// 2. `token_issuer` is itself a key of the map: use the mapped URL.
    /// 3. A default issuer is configured: use the default.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Verification`] when none of the above applies.
    fn resolve_issuer(&self, token_issuer: &str) -> Result<String> {
        // Check if token issuer is directly in our allowed list (as a value)
        if self.issuer_map.values().any(|v| v == token_issuer) {
            return Ok(token_issuer.to_string());
        }

        // Check if token issuer matches any mapped tenant/host key.
        // This supports cases where the key is the issuer itself.
        if let Some(mapped) = self.issuer_map.get(token_issuer) {
            return Ok(mapped.clone());
        }

        // Otherwise fall back to the default issuer, or reject the token.
        self.default_issuer.clone().ok_or_else(|| {
            Error::Verification(format!("Issuer '{}' not in allowed list", token_issuer))
        })
    }

    /// Resolves the issuer for a tenant identifier.
    ///
    /// Looks `tenant` up in the issuer map and falls back to the default
    /// issuer when the tenant has no entry.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Verification`] when the tenant is unmapped and no
    /// default issuer is configured.
    pub fn resolve_issuer_with_tenant(&self, tenant: &str) -> Result<String> {
        self.issuer_map
            .get(tenant)
            .or(self.default_issuer.as_ref())
            .cloned()
            .ok_or_else(|| {
                Error::Verification(format!("No issuer configured for tenant '{}'", tenant))
            })
    }
}

/// Strips an optional `Bearer` authorization scheme from a header value.
///
/// HTTP authentication scheme names are case-insensitive, so `bearer`,
/// `BEARER`, etc. are accepted. Input without the scheme is returned trimmed
/// but otherwise untouched.
#[cfg(feature = "verifier")]
fn strip_bearer_scheme(value: &str) -> &str {
    const SCHEME: &str = "Bearer ";

    let value = value.trim();
    match value.get(..SCHEME.len()) {
        // `get` succeeded, so `SCHEME.len()` is a valid char boundary.
        Some(prefix) if prefix.eq_ignore_ascii_case(SCHEME) => {
            value[SCHEME.len()..].trim_start()
        }
        _ => value,
    }
}
136
/// Step-by-step builder for [`JwtVerifier`].
///
/// Every option starts out unset; `build` decides which ones are mandatory
/// and which receive a default.
#[cfg(feature = "verifier")]
pub struct JwtVerifierBuilder<C, H>
where
    C: Cache<String, Jwks>,
    H: HttpClient,
{
    // Tenant/host to issuer URL mapping.
    issuer_map: Option<HashMap<String, String>>,
    // Expected `aud` value.
    audience: Option<String>,
    // Shared HTTP client.
    http: Option<std::sync::Arc<H>>,
    // Shared JWKS cache.
    cache: Option<std::sync::Arc<C>>,
    // Clock skew tolerance in seconds.
    clock_skew_sec: Option<i64>,
    // Fallback issuer.
    default_issuer: Option<String>,
}
147
// Implemented by hand: `#[derive(Default)]` would additionally require
// `C: Default` and `H: Default`, which the builder does not need.
#[cfg(feature = "verifier")]
impl<C, H> Default for JwtVerifierBuilder<C, H>
where
    C: Cache<String, Jwks>,
    H: HttpClient,
{
    /// Returns a builder with every option unset.
    fn default() -> Self {
        Self {
            issuer_map: None,
            audience: None,
            http: None,
            cache: None,
            clock_skew_sec: None,
            default_issuer: None,
        }
    }
}
161
#[cfg(feature = "verifier")]
impl<C, H> JwtVerifierBuilder<C, H>
where
    C: Cache<String, Jwks>,
    H: HttpClient,
{
    /// Sets the tenant/host to issuer URL mapping.
    pub fn issuer_map(self, map: HashMap<String, String>) -> Self {
        Self { issuer_map: Some(map), ..self }
    }

    /// Sets the audience tokens must be issued for.
    pub fn audience(self, audience: impl Into<String>) -> Self {
        Self { audience: Some(audience.into()), ..self }
    }

    /// Sets the HTTP client.
    pub fn http(self, http: std::sync::Arc<H>) -> Self {
        Self { http: Some(http), ..self }
    }

    /// Sets the JWKS cache.
    pub fn cache(self, cache: std::sync::Arc<C>) -> Self {
        Self { cache: Some(cache), ..self }
    }

    /// Sets the clock skew tolerance, in seconds.
    pub fn clock_skew(self, seconds: i64) -> Self {
        Self { clock_skew_sec: Some(seconds), ..self }
    }

    /// Sets the issuer used when no mapping matches.
    pub fn default_issuer(self, issuer: impl Into<String>) -> Self {
        Self { default_issuer: Some(issuer.into()), ..self }
    }

    /// Finalizes the builder into a [`JwtVerifier`].
    ///
    /// An unset issuer map becomes an empty map and an unset clock skew
    /// becomes 60 seconds.
    ///
    /// # Errors
    ///
    /// Returns [`Error::MissingConfig`] when the audience, HTTP client or
    /// cache has not been provided (checked in that order).
    pub fn build(self) -> Result<JwtVerifier<C, H>> {
        let Self { issuer_map, audience, http, cache, clock_skew_sec, default_issuer } = self;

        let audience = audience.ok_or(Error::MissingConfig("audience"))?;
        let http = http.ok_or(Error::MissingConfig("http client"))?;
        let cache = cache.ok_or(Error::MissingConfig("cache"))?;

        Ok(JwtVerifier {
            issuer_map: issuer_map.unwrap_or_default(),
            audience,
            http,
            cache,
            clock_skew_sec: clock_skew_sec.unwrap_or(60),
            default_issuer,
        })
    }
}
212
/// Reads the `kid` header parameter of a compact JWT without verifying it.
///
/// Returns `Ok(None)` when the header has no string-valued `kid`.
///
/// # Errors
///
/// Fails when the input does not consist of exactly three dot-separated
/// segments, or when the header segment is not valid unpadded base64url or
/// not valid JSON.
#[cfg(feature = "verifier")]
fn extract_kid(jwt: &str) -> Result<Option<String>> {
    let segments: Vec<&str> = jwt.split('.').collect();
    let encoded_header = match segments.as_slice() {
        [header, _, _] => *header,
        _ => return Err(Error::Jwt("Invalid JWT format".into())),
    };

    let raw_header = general_purpose::URL_SAFE_NO_PAD
        .decode(encoded_header)
        .map_err(|err| Error::Base64(format!("Failed to decode header: {}", err)))?;

    let header: serde_json::Value = serde_json::from_slice(&raw_header)
        .map_err(|err| Error::Jwt(format!("Failed to parse header JSON: {}", err)))?;

    let kid = header.get("kid").and_then(serde_json::Value::as_str).map(str::to_owned);
    Ok(kid)
}
230
/// Reads the `iss` claim from a compact JWT's payload without verifying it.
///
/// The returned value is untrusted; it is only suitable for deciding which
/// issuer's keys to verify the token against.
///
/// # Errors
///
/// Fails when the input does not consist of exactly three dot-separated
/// segments, when the payload segment is not valid unpadded base64url or not
/// valid JSON, or when the payload has no string-valued `iss`.
#[cfg(feature = "verifier")]
fn extract_unverified_issuer(token: &str) -> Result<String> {
    let segments: Vec<&str> = token.split('.').collect();
    let encoded_payload = match segments.as_slice() {
        [_, payload, _] => *payload,
        _ => return Err(Error::Jwt("Invalid JWT format".into())),
    };

    let raw_payload = general_purpose::URL_SAFE_NO_PAD
        .decode(encoded_payload)
        .map_err(|err| Error::Base64(err.to_string()))?;

    let claims: serde_json::Value = serde_json::from_slice(&raw_payload)?;

    match claims.get("iss").and_then(serde_json::Value::as_str) {
        Some(issuer) => Ok(issuer.to_string()),
        None => Err(Error::Jwt("Token missing issuer".into())),
    }
}
250
/// Verifies a compact JWT's signature with the given key and returns its
/// payload.
///
/// The algorithm is taken from the JWK's `alg` field and defaults to `RS256`
/// when the field is absent. Only `RS256` is supported.
///
/// # Errors
///
/// Returns [`Error::Jwt`] when the JWK cannot be converted to josekit's
/// representation, when the algorithm is unsupported, when the verifier
/// cannot be created from the key, or when verification of the token fails.
/// A failure to serialize the JWK is propagated via `?`.
#[cfg(feature = "verifier")]
fn verify_token_signature(token: &str, jwk: &crate::jwks::Jwk) -> Result<JwtPayload> {
    // Convert JWK to josekit format. A JWK must serialize to a JSON object;
    // report anything else as an error instead of panicking.
    let jwk_map = match serde_json::to_value(jwk)? {
        serde_json::Value::Object(map) => map,
        _ => return Err(Error::Jwt("Invalid JWK: expected a JSON object".into())),
    };
    let key = josekit::jwk::Jwk::from_map(jwk_map)
        .map_err(|e| Error::Jwt(format!("Invalid JWK: {}", e)))?;

    // Verify based on algorithm (default to RS256 if not specified)
    let alg = jwk.alg.as_deref().unwrap_or("RS256");
    let verifier = match alg {
        "RS256" => RS256.verifier_from_jwk(&key),
        alg => return Err(Error::Jwt(format!("Unsupported algorithm: {}", alg))),
    }
    .map_err(|e| Error::Jwt(format!("Failed to create verifier: {}", e)))?;

    let (payload, _header) = jwt::decode_with_verifier(token, &verifier)
        .map_err(|e| Error::Jwt(format!("Token verification failed: {}", e)))?;

    Ok(payload)
}
271
/// Validates the claims of a signature-verified access token and converts
/// them into [`VerifiedClaims`].
///
/// Checks performed, in order:
/// - `iss`, `sub`, `exp` and `iat` are present and well-formed;
/// - `iss` equals `expected_issuer`;
/// - `aud` is present and contains `expected_audience`;
/// - the token has not expired (`exp`), allowing `clock_skew` seconds;
/// - the token is already valid (`nbf`, only when present), allowing
///   `clock_skew` seconds;
/// - the token was not issued in the future (`iat`), allowing `clock_skew`
///   seconds.
///
/// Optional claims (`jti`, `scope`, `xjp_admin`, `amr`, `auth_time`) are
/// copied over when they have the expected JSON type; a missing `jti`
/// becomes an empty string.
///
/// # Errors
///
/// Returns [`Error::Verification`] when any check above fails or when the
/// system clock reports a time before the Unix epoch.
#[cfg(feature = "verifier")]
fn extract_and_validate_access_token_claims(
    payload: JwtPayload,
    expected_issuer: &str,
    expected_audience: &str,
    clock_skew: i64,
) -> Result<VerifiedClaims> {
    let now = unix_seconds(SystemTime::now(), "current")?;

    // Extract standard claims
    let iss = payload.issuer().ok_or_else(|| Error::Verification("Missing iss claim".into()))?;
    let sub = payload.subject().ok_or_else(|| Error::Verification("Missing sub claim".into()))?;
    let exp = payload
        .expires_at()
        .ok_or_else(|| Error::Verification("Missing exp claim".into()))
        .and_then(|t| unix_seconds(t, "exp"))?;
    let iat = payload
        .issued_at()
        .ok_or_else(|| Error::Verification("Missing iat claim".into()))
        .and_then(|t| unix_seconds(t, "iat"))?;

    // Validate issuer
    if iss != expected_issuer {
        return Err(Error::Verification(format!(
            "Invalid issuer: expected '{}', got '{}'",
            expected_issuer, iss
        )));
    }

    // Validate audience
    let audiences =
        payload.audience().ok_or_else(|| Error::Verification("Missing aud claim".into()))?;
    if !audiences.iter().any(|a| *a == expected_audience) {
        return Err(Error::Verification(format!(
            "Invalid audience: expected '{}'",
            expected_audience
        )));
    }

    // Saturating arithmetic keeps extreme skew values from overflowing.
    let earliest_acceptable_exp = now.saturating_sub(clock_skew);
    let latest_acceptable_start = now.saturating_add(clock_skew);

    // Validate expiration
    if exp < earliest_acceptable_exp {
        return Err(Error::Verification("Token expired".into()));
    }

    // Validate not-before: RFC 7519 requires `nbf` to be honoured when present.
    if let Some(not_before) = payload.not_before() {
        if unix_seconds(not_before, "nbf")? > latest_acceptable_start {
            return Err(Error::Verification("Token not yet valid".into()));
        }
    }

    // Validate issued at
    if iat > latest_acceptable_start {
        return Err(Error::Verification("Token issued in the future".into()));
    }

    // Extract custom claims
    let claims_map = payload.claims_set();

    let jti = claims_map.get("jti").and_then(|v| v.as_str()).unwrap_or("").to_string();

    let scope = claims_map.get("scope").and_then(|v| v.as_str()).map(|s| s.to_string());

    let xjp_admin = claims_map.get("xjp_admin").and_then(|v| v.as_bool());

    // `amr` is only accepted when it is an array consisting solely of strings.
    let amr = claims_map.get("amr").and_then(|v| {
        v.as_array()?
            .iter()
            .map(|item| item.as_str().map(|s| s.to_string()))
            .collect::<Option<Vec<String>>>()
    });

    let auth_time = claims_map.get("auth_time").and_then(|v| v.as_i64());

    Ok(VerifiedClaims {
        iss: iss.to_string(),
        sub: sub.to_string(),
        aud: expected_audience.to_string(),
        exp,
        iat,
        jti,
        scope,
        xjp_admin,
        amr,
        auth_time,
    })
}

/// Converts a point in time to whole seconds since the Unix epoch.
///
/// `what` names the value in the error message (for example `exp`). Fails
/// with [`Error::Verification`] when the time lies before the epoch or the
/// second count does not fit in an `i64`.
#[cfg(feature = "verifier")]
fn unix_seconds(time: SystemTime, what: &str) -> Result<i64> {
    let invalid = || Error::Verification(format!("Invalid {} time", what));

    let secs = time.duration_since(UNIX_EPOCH).map_err(|_| invalid())?.as_secs();
    if secs > i64::MAX as u64 {
        return Err(invalid());
    }
    // Lossless: the range was checked just above.
    Ok(secs as i64)
}
360
#[cfg(all(test, feature = "verifier"))]
mod tests {
    use super::*;

    /// The issuer is read from the payload segment even though the signature
    /// segment is a placeholder.
    #[test]
    fn test_extract_unverified_issuer() {
        // Dummy JWT for testing only, not a real token: the payload segment
        // carries the `iss` claim asserted below.
        let token = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5In0.eyJpc3MiOiJodHRwczovL2F1dGguZXhhbXBsZS5jb20ifQ.dummy";

        assert_eq!(extract_unverified_issuer(token).unwrap(), "https://auth.example.com");
    }
}
373}