1pub mod authz;
9pub mod forward;
10pub mod jwks;
11pub mod policy;
12
13use std::collections::HashMap;
14use std::collections::HashSet;
15use std::sync::Arc;
16
17use axum::extract::State;
18use axum::http::header::{HeaderName, HeaderValue};
19use axum::http::{HeaderMap, StatusCode};
20use axum::middleware::Next;
21use axum::response::{IntoResponse, Response};
22use axum::Json;
23use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
24use serde_json::Value;
25
26use crate::config::AuthConfig;
27use jwks::JwksCache;
28use policy::Policies;
29
30enum KeySource {
32 Pem(Arc<DecodingKey>),
34 Jwks(JwksCache),
36}
37
38pub struct Auth {
40 keys: KeySource,
41 issuer: Option<String>,
42 audience: Option<String>,
43 claims_headers: HashMap<String, String>,
44 roles_claim: String,
45 policies: Policies,
46}
47
48impl Auth {
49 pub fn build(config: &AuthConfig) -> Result<Option<Arc<Self>>, String> {
55 if config.mode != "jwt" {
56 return Ok(None);
57 }
58 let jwt = config
59 .jwt
60 .as_ref()
61 .ok_or("auth.mode is \"jwt\" but auth.jwt is not set")?;
62
63 let keys = if let Some(uri) = &jwt.jwks_uri {
64 KeySource::Jwks(JwksCache::new(uri.clone()))
65 } else if let Some(pem_path) = &jwt.public_key_pem_file {
66 let pem = std::fs::read(pem_path)
67 .map_err(|e| format!("failed to read auth.jwt.public_key_pem_file: {e}"))?;
68 let key = DecodingKey::from_ed_pem(&pem)
69 .map_err(|e| format!("invalid Ed25519 public key PEM: {e}"))?;
70 KeySource::Pem(Arc::new(key))
71 } else {
72 return Err("auth.jwt requires either jwks_uri or public_key_pem_file".to_string());
73 };
74
75 let policies = match &config.forward_auth {
76 Some(fa) => Policies::compile(&fa.policies)?,
77 None => Policies::default(),
78 };
79
80 Ok(Some(Arc::new(Self {
81 keys,
82 issuer: jwt.issuer.clone(),
83 audience: jwt.audience.clone(),
84 claims_headers: jwt.claims_headers.clone(),
85 roles_claim: jwt.roles_claim.clone(),
86 policies,
87 })))
88 }
89
90 async fn verify(&self, token: &str) -> Option<Value> {
92 let header = decode_header(token).ok()?;
93 let (key, algorithm) = match &self.keys {
94 KeySource::Pem(k) => (k.clone(), Algorithm::EdDSA),
95 KeySource::Jwks(cache) => {
96 let kid = header.kid.as_deref()?;
97 let vk = cache.key_for(kid).await?;
98 (vk.key, vk.algorithm)
99 }
100 };
101 if header.alg != algorithm {
103 return None;
104 }
105
106 let mut validation = Validation::new(algorithm);
107 if let Some(iss) = &self.issuer {
108 validation.set_issuer(&[iss]);
109 }
110 match &self.audience {
111 Some(aud) => validation.set_audience(&[aud]),
112 None => validation.validate_aud = false,
113 }
114
115 decode::<Value>(token, &key, &validation)
116 .ok()
117 .map(|data| data.claims)
118 }
119}
120
121pub(crate) enum AuthDecision {
123 Allow(HeaderMap),
125 Unauthenticated(&'static str),
127 Forbidden(&'static str),
129}
130
131impl Auth {
132 pub(crate) async fn decide(
136 &self,
137 headers: &HeaderMap,
138 path: &str,
139 method: &str,
140 ) -> AuthDecision {
141 let claims = match bearer_token(headers) {
143 Some(token) => match self.verify(&token).await {
144 Some(c) => Some(c),
145 None => return AuthDecision::Unauthenticated("invalid or expired token"),
146 },
147 None => None,
148 };
149
150 if let Some(policy) = self.policies.match_rule(path, method) {
151 if policy.require_auth && claims.is_none() {
152 return AuthDecision::Unauthenticated("authentication required");
153 }
154 if !policy.required_roles.is_empty() {
155 let Some(claims) = claims.as_ref() else {
158 return AuthDecision::Unauthenticated("authentication required");
159 };
160 let roles = extract_roles(claims, &self.roles_claim);
161 if !policy.required_roles.iter().all(|r| roles.contains(r)) {
162 return AuthDecision::Forbidden("insufficient role");
163 }
164 }
165 }
166
167 let mut claim_headers = HeaderMap::new();
168 if let Some(claims) = &claims {
169 inject_claim_headers(&mut claim_headers, claims, &self.claims_headers);
170 }
171 AuthDecision::Allow(claim_headers)
172 }
173}
174
175pub async fn middleware(
177 State(auth): State<Arc<Auth>>,
178 mut request: axum::extract::Request,
179 next: Next,
180) -> Response {
181 let path = request.uri().path().to_string();
182 let method = request.method().as_str().to_ascii_uppercase();
183
184 strip_claim_headers(request.headers_mut(), &auth.claims_headers);
188
189 match auth.decide(request.headers(), &path, &method).await {
190 AuthDecision::Unauthenticated(msg) => unauthorized(msg),
191 AuthDecision::Forbidden(msg) => forbidden(msg),
192 AuthDecision::Allow(claim_headers) => {
193 let dst = request.headers_mut();
194 for (name, value) in &claim_headers {
195 dst.insert(name.clone(), value.clone());
196 }
197 next.run(request).await
198 }
199 }
200}
201
202fn bearer_token(headers: &HeaderMap) -> Option<String> {
204 let value = headers.get("authorization")?.to_str().ok()?;
205 let token = value
206 .strip_prefix("Bearer ")
207 .or_else(|| value.strip_prefix("bearer "))?;
208 let token = token.trim();
209 (!token.is_empty()).then(|| token.to_string())
210}
211
212fn claim_at<'a>(claims: &'a Value, path: &str) -> Option<&'a Value> {
214 let mut cur = claims;
215 for seg in path.split('.') {
216 cur = cur.get(seg)?;
217 }
218 Some(cur)
219}
220
221fn extract_roles(claims: &Value, roles_claim: &str) -> HashSet<String> {
223 claim_at(claims, roles_claim)
224 .and_then(Value::as_array)
225 .map(|arr| {
226 arr.iter()
227 .filter_map(|v| v.as_str().map(str::to_string))
228 .collect()
229 })
230 .unwrap_or_default()
231}
232
233fn strip_claim_headers(headers: &mut HeaderMap, mapping: &HashMap<String, String>) {
236 for header in mapping.values() {
237 if let Ok(name) = HeaderName::try_from(header.as_str()) {
238 while headers.remove(&name).is_some() {}
239 }
240 }
241}
242
243fn inject_claim_headers(
245 headers: &mut HeaderMap,
246 claims: &Value,
247 mapping: &HashMap<String, String>,
248) {
249 for (claim, header) in mapping {
250 let Some(value) = claim_at(claims, claim) else {
251 continue;
252 };
253 let rendered = match value {
254 Value::String(s) => s.clone(),
255 Value::Number(n) => n.to_string(),
256 Value::Bool(b) => b.to_string(),
257 _ => continue,
259 };
260 if let (Ok(name), Ok(val)) = (
261 HeaderName::try_from(header.as_str()),
262 HeaderValue::try_from(rendered),
263 ) {
264 headers.insert(name, val);
265 }
266 }
267}
268
269fn unauthorized(message: &str) -> Response {
270 (
271 StatusCode::UNAUTHORIZED,
272 Json(serde_json::json!({ "error": "UNAUTHENTICATED", "message": message })),
273 )
274 .into_response()
275}
276
277fn forbidden(message: &str) -> Response {
278 (
279 StatusCode::FORBIDDEN,
280 Json(serde_json::json!({ "error": "PERMISSION_DENIED", "message": message })),
281 )
282 .into_response()
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{AuthConfig, ForwardAuthConfig, JwtConfig, RoutePolicyConfig};
    use axum::http::Request as HttpRequest;
    use jsonwebtoken::{encode, EncodingKey, Header};
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicU32, Ordering};
    use tower::ServiceExt;

    // Ed25519 key pair used only by these tests.
    const TEST_PRIV_PEM: &str = "-----BEGIN PRIVATE KEY-----\n\
        MC4CAQAwBQYDK2VwBCIEIEVVO7H+T5tERRn/dzukOc8i9iYEKKtPh//qcrES+dCt\n\
        -----END PRIVATE KEY-----\n";
    const TEST_PUB_PEM: &str = "-----BEGIN PUBLIC KEY-----\n\
        MCowBQYDK2VwAyEARCMxEnaM2/dblLuPNgBZpTvSUXO5ir+XQ1nyzJm4CFw=\n\
        -----END PUBLIC KEY-----\n";

    #[test]
    fn bearer_token_parsing() {
        let mut h = HeaderMap::new();
        h.insert("authorization", "Bearer abc.def.ghi".parse().unwrap());
        assert_eq!(bearer_token(&h).as_deref(), Some("abc.def.ghi"));

        let mut lower = HeaderMap::new();
        lower.insert("authorization", "bearer abc".parse().unwrap());
        assert_eq!(bearer_token(&lower).as_deref(), Some("abc"));

        let mut h2 = HeaderMap::new();
        h2.insert("authorization", "Basic xyz".parse().unwrap());
        assert_eq!(bearer_token(&h2), None);
        assert_eq!(bearer_token(&HeaderMap::new()), None);
    }

    #[test]
    fn extract_roles_reads_array_and_dotted_path() {
        let claims = serde_json::json!({
            "roles": ["admin", "billing"],
            "realm_access": { "roles": ["nested"] }
        });
        assert!(extract_roles(&claims, "roles").contains("admin"));
        assert!(extract_roles(&claims, "realm_access.roles").contains("nested"));
        assert!(extract_roles(&claims, "missing").is_empty());
    }

    #[test]
    fn inject_claim_headers_renders_scalars() {
        let claims = serde_json::json!({ "sub": "u-1", "n": 7, "obj": {"x": 1} });
        let mapping = HashMap::from([
            ("sub".to_string(), "x-user-id".to_string()),
            ("n".to_string(), "x-n".to_string()),
            ("obj".to_string(), "x-obj".to_string()),
        ]);
        let mut headers = HeaderMap::new();
        inject_claim_headers(&mut headers, &claims, &mapping);
        assert_eq!(headers["x-user-id"], "u-1");
        assert_eq!(headers["x-n"], "7");
        assert!(!headers.contains_key("x-obj"));
    }

    /// Writes the test public key to a unique temp file and returns its path.
    fn temp_pub_pem() -> PathBuf {
        static N: AtomicU32 = AtomicU32::new(0);
        let path = std::env::temp_dir().join(format!(
            "sp_auth_{}_{}.pem",
            std::process::id(),
            N.fetch_add(1, Ordering::Relaxed)
        ));
        std::fs::write(&path, TEST_PUB_PEM).unwrap();
        path
    }

    /// Signs `claims` as an EdDSA JWT with the test private key.
    fn sign(claims: serde_json::Value) -> String {
        let key = EncodingKey::from_ed_pem(TEST_PRIV_PEM.as_bytes()).unwrap();
        encode(&Header::new(Algorithm::EdDSA), &claims, &key).unwrap()
    }

    /// Unix timestamp one hour from now.
    fn future_exp() -> i64 {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs() as i64;
        now + 3600
    }

    /// JWT settings with issuer/audience checks and a `sub` -> `x-user` mapping.
    fn strict_jwt(pem: PathBuf) -> JwtConfig {
        JwtConfig {
            jwks_uri: None,
            issuer: Some("test-iss".into()),
            audience: Some("test-aud".into()),
            public_key_pem_file: Some(pem),
            claims_headers: HashMap::from([("sub".to_string(), "x-user".to_string())]),
            roles_claim: "roles".into(),
        }
    }

    /// JWT settings with no issuer/audience checks and no claim headers.
    fn lenient_jwt(pem: PathBuf) -> JwtConfig {
        JwtConfig {
            jwks_uri: None,
            issuer: None,
            audience: None,
            public_key_pem_file: Some(pem),
            claims_headers: HashMap::new(),
            roles_claim: "roles".into(),
        }
    }

    /// Forward-auth config with one policy guarding `/secure` for all methods.
    fn secure_policy(require_auth: bool, roles: &[&str]) -> ForwardAuthConfig {
        ForwardAuthConfig {
            enabled: true,
            path: "/auth/verify".into(),
            policies: vec![RoutePolicyConfig {
                path: "/secure".into(),
                methods: vec!["*".into()],
                require_auth,
                required_roles: roles.iter().map(|s| s.to_string()).collect(),
            }],
            login_url: None,
            applications_path: None,
        }
    }

    /// Builds an `Auth` around the test public key. The key file exists only
    /// for the duration of `Auth::build`, which reads it eagerly; it is removed
    /// afterwards so test runs do not litter the temp directory.
    fn build_auth(
        jwt: impl FnOnce(PathBuf) -> JwtConfig,
        forward_auth: ForwardAuthConfig,
    ) -> Arc<Auth> {
        let pem_path = temp_pub_pem();
        let cfg = AuthConfig {
            mode: "jwt".into(),
            jwt: Some(jwt(pem_path.clone())),
            forward_auth: Some(forward_auth),
            authz: None,
        };
        let built = Auth::build(&cfg);
        // Best-effort cleanup; a leftover file must not fail the test.
        let _ = std::fs::remove_file(&pem_path);
        built.unwrap().unwrap()
    }

    fn auth_with_policy(roles: &[&str]) -> Arc<Auth> {
        build_auth(strict_jwt, secure_policy(true, roles))
    }

    /// Router whose `/secure` and `/open` handlers echo the `x-user` header.
    fn app(auth: Arc<Auth>) -> axum::Router {
        let echo = |headers: HeaderMap| async move {
            headers
                .get("x-user")
                .and_then(|v| v.to_str().ok())
                .unwrap_or("")
                .to_string()
        };
        axum::Router::new()
            .route("/secure", axum::routing::get(echo))
            .route("/open", axum::routing::get(echo))
            .layer(axum::middleware::from_fn_with_state(auth, middleware))
    }

    /// GET request for `uri`, optionally with `Authorization: Bearer <token>`.
    fn get_request(uri: &str, token: Option<&str>) -> HttpRequest<axum::body::Body> {
        let builder = HttpRequest::get(uri);
        let builder = match token {
            Some(t) => builder.header("authorization", format!("Bearer {t}")),
            None => builder,
        };
        builder.body(axum::body::Body::empty()).unwrap()
    }

    async fn body_string(resp: axum::response::Response) -> String {
        let bytes = axum::body::to_bytes(resp.into_body(), 1024).await.unwrap();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn strips_client_supplied_claim_headers() {
        let request = HttpRequest::get("/open")
            .header("x-user", "forged-admin")
            .body(axum::body::Body::empty())
            .unwrap();
        let resp = app(auth_with_policy(&[])).oneshot(request).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(body_string(resp).await, "");
    }

    #[tokio::test]
    async fn unauthenticated_role_check_is_401_not_403() {
        let auth = build_auth(lenient_jwt, secure_policy(false, &["admin"]));
        let resp = app(auth)
            .oneshot(get_request("/secure", None))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn rejects_missing_token_on_protected_route() {
        let resp = app(auth_with_policy(&[]))
            .oneshot(get_request("/secure", None))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn accepts_valid_token_and_injects_claim_header() {
        let token = sign(serde_json::json!({
            "iss": "test-iss", "aud": "test-aud", "exp": future_exp(),
            "sub": "user-42", "roles": ["admin"]
        }));
        let resp = app(auth_with_policy(&["admin"]))
            .oneshot(get_request("/secure", Some(&token)))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(body_string(resp).await, "user-42");
    }

    #[tokio::test]
    async fn forbids_when_required_role_missing() {
        let token = sign(serde_json::json!({
            "iss": "test-iss", "aud": "test-aud", "exp": future_exp(),
            "sub": "user-42", "roles": ["viewer"]
        }));
        let resp = app(auth_with_policy(&["admin"]))
            .oneshot(get_request("/secure", Some(&token)))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn rejects_expired_and_wrong_issuer() {
        let app = app(auth_with_policy(&[]));

        let expired = sign(serde_json::json!({
            "iss": "test-iss", "aud": "test-aud", "exp": 1, "sub": "u", "roles": ["admin"]
        }));
        let resp = app
            .clone()
            .oneshot(get_request("/secure", Some(&expired)))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);

        let wrong_iss = sign(serde_json::json!({
            "iss": "evil", "aud": "test-aud", "exp": future_exp(), "sub": "u", "roles": ["admin"]
        }));
        let resp = app
            .oneshot(get_request("/secure", Some(&wrong_iss)))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }
}