Skip to main content

camel_auth/
jwt.rs

1use async_trait::async_trait;
2use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
3use std::sync::Arc;
4
5use crate::claims::ClaimsMapper;
6use crate::jwks::JwksProvider;
7use crate::types::AuthError;
8use camel_api::security_policy::Principal;
9
10/// Validates JWT tokens and extracts a [`Principal`].
11#[async_trait]
12pub trait JwtValidator: Send + Sync {
13    async fn validate(&self, token: &str) -> Result<Principal, AuthError>;
14}
15
16/// Production JWT validator backed by a dynamic JWKS provider.
17///
18/// Delegates Principal construction to a configurable [`ClaimsMapper`],
19/// allowing provider-specific claim shapes without hardcoding extraction logic.
20pub struct LocalJwtValidator {
21    audience: Vec<String>,
22    issuer: String,
23    jwks: Arc<dyn JwksProvider>,
24    mapper: Arc<dyn ClaimsMapper>,
25}
26
27impl LocalJwtValidator {
28    pub fn new(
29        audience: Vec<String>,
30        issuer: String,
31        jwks: Arc<dyn JwksProvider>,
32        mapper: Arc<dyn ClaimsMapper>,
33    ) -> Self {
34        Self {
35            audience,
36            issuer,
37            jwks,
38            mapper,
39        }
40    }
41}
42
43/// Convert a JWK to a [`DecodingKey`].
44///
45/// Supports both PEM-encoded public keys (stored in `n` with a `-----BEGIN` prefix,
46/// useful for testing) and standard JWKS base64url components (production).
47fn jwk_to_decoding_key(n: &str, e: &str) -> Result<DecodingKey, AuthError> {
48    if n.starts_with("-----BEGIN") {
49        DecodingKey::from_rsa_pem(n.as_bytes())
50            .map_err(|e| AuthError::TokenInvalid(format!("invalid RSA PEM: {e}"))) // allow-secret
51    } else {
52        DecodingKey::from_rsa_components(n, e)
53            .map_err(|e| AuthError::TokenInvalid(format!("invalid JWK components: {e}"))) // allow-secret
54    }
55}
56
57#[async_trait]
58impl JwtValidator for LocalJwtValidator {
59    async fn validate(&self, token: &str) -> Result<Principal, AuthError> {
60        // Decode header to extract kid
61        let header = decode_header(token)
62            .map_err(|e| AuthError::TokenInvalid(format!("invalid JWT header: {e}")))?;
63
64        let kid = header
65            .kid
66            .ok_or_else(|| AuthError::TokenInvalid("JWT missing kid".into()))?;
67
68        // Fetch signing keys; on kid miss, force a JWKS refresh (handles key rotation)
69        let keys = self.jwks.get_signing_keys().await?;
70        let jwk = if let Some(k) = keys.iter().find(|k| k.kid == kid) {
71            k.clone()
72        } else {
73            // Key not in cache — might be a newly rotated key; refresh once and retry
74            self.jwks.refresh().await?;
75            self.jwks
76                .get_signing_keys()
77                .await?
78                .into_iter()
79                .find(|k| k.kid == kid)
80                .ok_or_else(|| {
81                    AuthError::TokenInvalid(format!("no key for kid={kid} after refresh"))
82                })?
83        };
84
85        let decoding_key = jwk_to_decoding_key(&jwk.n, &jwk.e)?;
86
87        // Configure validation
88        let mut validation = Validation::new(Algorithm::RS256);
89        validation.set_audience(&self.audience);
90        validation.set_issuer(&[&self.issuer]);
91
92        // Decode and verify
93        let token_data =
94            decode::<serde_json::Value>(token, &decoding_key, &validation).map_err(|e| match e
95                .kind()
96            {
97                jsonwebtoken::errors::ErrorKind::ExpiredSignature => AuthError::TokenExpired,
98                _ => AuthError::TokenInvalid(e.to_string()),
99            })?;
100
101        let claims = token_data.claims;
102
103        // Delegate Principal construction to the configured ClaimsMapper
104        self.mapper.to_principal(&claims)
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::claims::{ClaimPaths, JsonPointerClaimsMapper};
    use crate::jwks::Jwk;
    use jsonwebtoken::{EncodingKey, Header, encode};
    use serde_json::json;
    use std::sync::atomic::{AtomicBool, Ordering};

    static TEST_RSA_PRIVATE_PEM: &[u8] = include_bytes!("../tests/fixtures/test_rsa_private.pem");
    static TEST_RSA_PUBLIC_PEM: &[u8] = include_bytes!("../tests/fixtures/test_rsa_public.pem");

    /// Issuer the test validators expect and the test tokens carry.
    const TEST_ISSUER: &str = "http://localhost:8080/realms/test";

    /// `kid` under which the mock providers publish the test public key.
    const TEST_KID: &str = "test-key";

    /// Build the RS256 JWK published by the mocks. The public key is stored as PEM in
    /// `n`, a shape `jwk_to_decoding_key` accepts.
    fn test_jwk(kid: &str, public_pem: &[u8]) -> Jwk {
        Jwk {
            kid: kid.to_string(),
            kty: "RSA".into(),
            alg: Some("RS256".into()),
            r#use: None,
            n: String::from_utf8_lossy(public_pem).into_owned(),
            e: "AQAB".into(),
        }
    }

    /// Mock JWKS provider that always returns a single PEM-encoded public key.
    struct MockJwks {
        kid: String,
        public_pem: &'static [u8],
    }

    #[async_trait]
    impl JwksProvider for MockJwks {
        async fn get_signing_keys(&self) -> Result<Vec<Jwk>, AuthError> {
            Ok(vec![test_jwk(&self.kid, self.public_pem)])
        }

        async fn refresh(&self) -> Result<(), AuthError> {
            Ok(())
        }
    }

    /// Mock JWKS that starts empty and gains a key after refresh (simulates rotation).
    struct RotatingMockJwks {
        kid: String,
        public_pem: &'static [u8],
        refreshed: AtomicBool,
    }

    #[async_trait]
    impl JwksProvider for RotatingMockJwks {
        async fn get_signing_keys(&self) -> Result<Vec<Jwk>, AuthError> {
            if self.refreshed.load(Ordering::SeqCst) {
                Ok(vec![test_jwk(&self.kid, self.public_pem)])
            } else {
                Ok(vec![]) // key not yet known
            }
        }

        async fn refresh(&self) -> Result<(), AuthError> {
            self.refreshed.store(true, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Build a mapper configured for multiple role paths.
    fn multi_role_mapper(role_paths: Vec<String>) -> Arc<JsonPointerClaimsMapper> {
        Arc::new(JsonPointerClaimsMapper::new(ClaimPaths {
            subject: "/sub".into(),
            roles: role_paths,
            scopes: Some("/scope".into()),
        }))
    }

    /// Validator backed by [`MockJwks`] serving the test key under [`TEST_KID`].
    fn validator(audience: Vec<&str>, mapper: Arc<dyn ClaimsMapper>) -> LocalJwtValidator {
        LocalJwtValidator::new(
            audience.iter().map(|s| s.to_string()).collect(),
            TEST_ISSUER.into(),
            Arc::new(MockJwks {
                kid: TEST_KID.into(),
                public_pem: TEST_RSA_PUBLIC_PEM,
            }),
            mapper,
        )
    }

    /// Current Unix time in seconds.
    fn unix_now() -> u64 {
        chrono::Utc::now().timestamp() as u64
    }

    /// Claims for subject `user-123` and audience `aud`, valid for the next hour.
    fn valid_claims(aud: &str) -> serde_json::Value {
        let now = unix_now();
        json!({
            "sub": "user-123",
            "iss": TEST_ISSUER,
            "aud": aud,
            "exp": now + 3600,
            "iat": now,
        })
    }

    /// Sign `claims` with the test private key using the given `header`.
    fn sign(header: &Header, claims: &serde_json::Value) -> String {
        let encoding_key = EncodingKey::from_rsa_pem(TEST_RSA_PRIVATE_PEM).unwrap();
        encode(header, claims, &encoding_key).unwrap()
    }

    /// Sign `claims` as an RS256 token whose header carries `kid`.
    fn make_token(kid: &str, claims: &serde_json::Value) -> String {
        let mut header = Header::new(Algorithm::RS256);
        header.kid = Some(kid.to_string());
        sign(&header, claims)
    }

    #[tokio::test]
    async fn validates_valid_token() {
        let v = validator(vec!["my-api"], multi_role_mapper(vec!["/groups".into()]));
        let token = make_token(TEST_KID, &valid_claims("my-api"));
        let principal = v.validate(&token).await.unwrap();
        assert_eq!(principal.subject, "user-123");
    }

    #[tokio::test]
    async fn rejects_expired_token() {
        let v = validator(vec!["my-api"], multi_role_mapper(vec!["/groups".into()]));
        let now = unix_now();
        let claims = json!({
            "sub": "user-123",
            "iss": TEST_ISSUER,
            "aud": "my-api",
            "exp": now - 3600,
            "iat": now - 7200,
        });
        let token = make_token(TEST_KID, &claims);
        assert!(matches!(
            v.validate(&token).await,
            Err(AuthError::TokenExpired)
        ));
    }

    #[tokio::test]
    async fn rejects_wrong_audience() {
        let v = validator(vec!["my-api"], multi_role_mapper(vec!["/groups".into()]));
        let token = make_token(TEST_KID, &valid_claims("wrong-audience"));
        assert!(matches!(
            v.validate(&token).await,
            Err(AuthError::TokenInvalid(_))
        ));
    }

    #[tokio::test]
    async fn rejects_token_without_kid() {
        let v = validator(vec!["my-api"], multi_role_mapper(vec!["/groups".into()]));
        // Header::new leaves `kid` unset.
        let token = sign(&Header::new(Algorithm::RS256), &valid_claims("my-api"));
        assert!(matches!(
            v.validate(&token).await,
            Err(AuthError::TokenInvalid(_))
        ));
    }

    #[tokio::test]
    async fn rejects_unknown_kid_after_refresh() {
        // MockJwks only ever publishes TEST_KID, so the refresh cannot help.
        let v = validator(vec!["my-api"], multi_role_mapper(vec!["/groups".into()]));
        let token = make_token("unknown-key", &valid_claims("my-api"));
        assert!(matches!(
            v.validate(&token).await,
            Err(AuthError::TokenInvalid(_))
        ));
    }

    #[tokio::test]
    async fn extracts_resource_access_roles() {
        let mapper = multi_role_mapper(vec![
            "/realm_access/roles".into(),
            "/resource_access/my-client/roles".into(),
        ]);
        let v = validator(vec!["my-client"], mapper);
        let mut claims = valid_claims("my-client");
        claims["realm_access"] = json!({ "roles": ["realm-role"] });
        claims["resource_access"] = json!({
            "my-client": { "roles": ["client-role-a"] }
        });
        let token = make_token(TEST_KID, &claims);
        let principal = v.validate(&token).await.unwrap();
        assert!(principal.has_role("realm-role"));
        assert!(principal.has_role("client-role-a"));
    }

    #[tokio::test]
    async fn rejects_missing_sub() {
        let v = validator(vec!["my-api"], multi_role_mapper(vec!["/groups".into()]));
        let mut claims = valid_claims("my-api");
        claims
            .as_object_mut()
            .expect("valid_claims builds a JSON object")
            .remove("sub");
        let token = make_token(TEST_KID, &claims);
        assert!(matches!(
            v.validate(&token).await,
            Err(AuthError::TokenInvalid(_))
        ));
    }

    #[tokio::test]
    async fn refreshes_on_unknown_kid() {
        let token = make_token(TEST_KID, &valid_claims("my-api"));

        // Validator backed by a JWKS that returns the key only after refresh.
        let v = LocalJwtValidator::new(
            vec!["my-api".into()],
            TEST_ISSUER.into(),
            Arc::new(RotatingMockJwks {
                kid: TEST_KID.into(),
                public_pem: TEST_RSA_PUBLIC_PEM,
                refreshed: AtomicBool::new(false),
            }),
            multi_role_mapper(vec!["/groups".into()]),
        );

        // The token should validate after the forced JWKS refresh.
        let principal = v.validate(&token).await.unwrap();
        assert_eq!(principal.subject, "user-123");
    }

    #[tokio::test]
    async fn mapper_configures_role_paths_independently_of_audience() {
        // The mapper is configured with explicit role paths, so no audience heuristic is
        // needed. The token audience is "other-audience", but the mapper looks up roles
        // under "my-service".
        let mapper = multi_role_mapper(vec![
            "/realm_access/roles".into(),
            "/resource_access/my-service/roles".into(),
        ]);
        let v = validator(vec!["other-audience"], mapper);

        let mut claims = valid_claims("other-audience");
        claims["resource_access"] = json!({
            "my-service": { "roles": ["svc-role"] },
            "other-audience": { "roles": ["aud-role"] },
        });
        let token = make_token(TEST_KID, &claims);
        let principal = v.validate(&token).await.unwrap();

        // The mapper finds "svc-role" under "my-service" via the configured path,
        // NOT "aud-role" under "other-audience".
        assert!(
            principal.has_role("svc-role"),
            "expected svc-role from my-service path"
        );
        assert!(
            !principal.has_role("aud-role"),
            "must not pick aud-role when mapper path targets my-service"
        );
    }

    #[tokio::test]
    async fn extracts_scopes_from_scope_claim() {
        let v = validator(vec!["my-api"], multi_role_mapper(vec!["/groups".into()]));
        let mut claims = valid_claims("my-api");
        claims["scope"] = json!("read write admin");
        let token = make_token(TEST_KID, &claims);
        let principal = v.validate(&token).await.unwrap();
        assert_eq!(principal.scopes, vec!["read", "write", "admin"]);
    }

    #[tokio::test]
    async fn extracts_generic_groups_roles() {
        // Generic /groups claim path.
        let v = validator(vec!["my-api"], multi_role_mapper(vec!["/groups".into()]));
        let mut claims = valid_claims("my-api");
        claims["groups"] = json!(["admin", "editor", "viewer"]);
        let token = make_token(TEST_KID, &claims);
        let principal = v.validate(&token).await.unwrap();
        assert!(principal.has_role("admin"));
        assert!(principal.has_role("editor"));
        assert!(principal.has_role("viewer"));
    }
}