pub mod authz;
pub mod forward;
pub mod jwks;
pub mod policy;
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
use axum::extract::State;
use axum::http::header::{HeaderName, HeaderValue};
use axum::http::{HeaderMap, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use axum::Json;
use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
use serde_json::Value;
use crate::config::AuthConfig;
use jwks::JwksCache;
use policy::Policies;
/// Where token-verification keys come from.
enum KeySource {
    /// Keys fetched from a JWKS endpoint and selected per token by its `kid`.
    Jwks(JwksCache),
    /// A single static Ed25519 public key parsed from a PEM file.
    Pem(Arc<DecodingKey>),
}
/// JWT authenticator plus the per-route policy table it enforces.
pub struct Auth {
    /// Source of the key(s) used to check token signatures.
    keys: KeySource,
    /// Per-route authentication / role requirements.
    policies: Policies,
    /// Expected `iss` claim; when `None` the issuer is not checked.
    issuer: Option<String>,
    /// Expected `aud` claim; when `None` audience validation is disabled.
    audience: Option<String>,
    /// Dotted claim path holding the caller's roles as a JSON array of strings.
    roles_claim: String,
    /// Claim path -> request header name for claims forwarded upstream.
    claims_headers: HashMap<String, String>,
}
impl Auth {
    /// Constructs the JWT authenticator described by `config`.
    ///
    /// Returns `Ok(None)` when `config.mode` is not exactly `"jwt"`. Otherwise it
    /// picks the key source and compiles the forward-auth route policies, if any.
    /// `jwks_uri` wins over `public_key_pem_file` when both are set.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message in any of these cases:
    /// - `auth.jwt` is missing.
    /// - Neither key source is configured.
    /// - The PEM file cannot be read or parsed as an Ed25519 public key.
    /// - A route policy fails to compile.
    pub fn build(config: &AuthConfig) -> Result<Option<Arc<Self>>, String> {
        // NOTE(review): any other mode string silently disables authentication
        // instead of failing. That includes typos such as "JWT". Confirm intended.
        if config.mode != "jwt" {
            return Ok(None);
        }
        let Some(jwt) = config.jwt.as_ref() else {
            return Err("auth.mode is \"jwt\" but auth.jwt is not set".to_string());
        };

        let keys = match (&jwt.jwks_uri, &jwt.public_key_pem_file) {
            (Some(uri), _) => KeySource::Jwks(JwksCache::new(uri.clone())),
            (None, Some(pem_path)) => {
                // The key is read and parsed here, so a bad path or malformed key
                // surfaces as a build error rather than at request time.
                let pem_bytes = std::fs::read(pem_path)
                    .map_err(|e| format!("failed to read auth.jwt.public_key_pem_file: {e}"))?;
                let decoding_key = DecodingKey::from_ed_pem(&pem_bytes)
                    .map_err(|e| format!("invalid Ed25519 public key PEM: {e}"))?;
                KeySource::Pem(Arc::new(decoding_key))
            }
            (None, None) => {
                return Err(
                    "auth.jwt requires either jwks_uri or public_key_pem_file".to_string(),
                );
            }
        };

        let policies = match config.forward_auth.as_ref() {
            None => Policies::default(),
            Some(forward) => Policies::compile(&forward.policies)?,
        };

        Ok(Some(Arc::new(Self {
            keys,
            policies,
            issuer: jwt.issuer.clone(),
            audience: jwt.audience.clone(),
            roles_claim: jwt.roles_claim.clone(),
            claims_headers: jwt.claims_headers.clone(),
        })))
    }

    /// Verifies `token` and returns its claims, or `None` if it is rejected.
    ///
    /// A token is rejected when any of the following holds:
    /// - Its header cannot be decoded.
    /// - In JWKS mode, it has no `kid` or the cache yields no key for it.
    /// - Its header `alg` differs from the selected key's algorithm.
    /// - `jsonwebtoken` validation fails. That covers the signature, expiry, and
    ///   `iss` / `aud` when those are configured.
    async fn verify(&self, token: &str) -> Option<Value> {
        let header = decode_header(token).ok()?;
        let (key, algorithm) = match &self.keys {
            KeySource::Pem(pem_key) => (Arc::clone(pem_key), Algorithm::EdDSA),
            KeySource::Jwks(cache) => {
                let kid = header.kid.as_deref()?;
                let resolved = cache.key_for(kid).await?;
                (resolved.key, resolved.algorithm)
            }
        };
        // Refuse algorithm substitution: the token must declare exactly the
        // algorithm bound to the selected key.
        if header.alg != algorithm {
            return None;
        }

        let mut validation = Validation::new(algorithm);
        if let Some(iss) = &self.issuer {
            validation.set_issuer(&[iss]);
        }
        if let Some(aud) = &self.audience {
            validation.set_audience(&[aud]);
        } else {
            // No audience configured: accept tokens regardless of their `aud`.
            validation.validate_aud = false;
        }

        let token_data = decode::<Value>(token, &key, &validation).ok()?;
        Some(token_data.claims)
    }
}
/// Outcome of evaluating a request against the token and route policies.
pub(crate) enum AuthDecision {
    /// Request may proceed. The map holds headers rendered from verified claims,
    /// which the middleware copies onto the forwarded request.
    Allow(HeaderMap),
    /// A verified token was present but lacks a required role.
    /// `middleware` answers this with 403.
    Forbidden(&'static str),
    /// The token is missing, invalid or expired where one is needed.
    /// `middleware` answers this with 401.
    Unauthenticated(&'static str),
}
impl Auth {
    /// Authenticates the request headers and applies the route policy for
    /// `path` / `method`.
    ///
    /// A bearer token that is present but fails verification is always rejected,
    /// even on routes without a policy. A route whose policy requires auth or any
    /// role yields `Unauthenticated` when no token was supplied, and `Forbidden`
    /// when the token lacks one of the required roles. On success, the configured
    /// claim headers are rendered into the returned map.
    pub(crate) async fn decide(
        &self,
        headers: &HeaderMap,
        path: &str,
        method: &str,
    ) -> AuthDecision {
        let claims = if let Some(token) = bearer_token(headers) {
            let Some(verified) = self.verify(&token).await else {
                return AuthDecision::Unauthenticated("invalid or expired token");
            };
            Some(verified)
        } else {
            None
        };

        if let Some(policy) = self.policies.match_rule(path, method) {
            // Role checks need an identity too, so either requirement means an
            // anonymous caller is unauthenticated (401) rather than forbidden.
            let needs_identity = policy.require_auth || !policy.required_roles.is_empty();
            match claims.as_ref() {
                None if needs_identity => {
                    return AuthDecision::Unauthenticated("authentication required");
                }
                Some(verified) if !policy.required_roles.is_empty() => {
                    let granted = extract_roles(verified, &self.roles_claim);
                    let missing_role = policy
                        .required_roles
                        .iter()
                        .any(|role| !granted.contains(role));
                    if missing_role {
                        return AuthDecision::Forbidden("insufficient role");
                    }
                }
                _ => {}
            }
        }

        let mut forwarded = HeaderMap::new();
        if let Some(verified) = claims.as_ref() {
            inject_claim_headers(&mut forwarded, verified, &self.claims_headers);
        }
        AuthDecision::Allow(forwarded)
    }
}
/// Axum middleware enforcing JWT authentication and route policies.
///
/// Client-supplied copies of the configured claim headers are always removed
/// first, so only values derived from a verified token reach the inner service.
/// Rejections become 401/403 JSON responses. Allowed requests get the claim
/// headers inserted and are passed to `next`.
pub async fn middleware(
    State(auth): State<Arc<Auth>>,
    mut request: axum::extract::Request,
    next: Next,
) -> Response {
    let method = request.method().as_str().to_ascii_uppercase();
    let path = request.uri().path().to_string();

    // Never trust claim headers coming from the client.
    strip_claim_headers(request.headers_mut(), &auth.claims_headers);

    let claim_headers = match auth.decide(request.headers(), &path, &method).await {
        AuthDecision::Allow(rendered) => rendered,
        AuthDecision::Unauthenticated(reason) => return unauthorized(reason),
        AuthDecision::Forbidden(reason) => return forbidden(reason),
    };

    let outgoing = request.headers_mut();
    for (name, value) in claim_headers.iter() {
        outgoing.insert(name.clone(), value.clone());
    }
    next.run(request).await
}
/// Extracts the bearer credential from the `Authorization` header.
///
/// The auth-scheme is matched case-insensitively (RFC 7235 §2.1), so `Bearer`,
/// `bearer` and `BEARER` are all accepted. Surrounding whitespace around the
/// token is trimmed.
///
/// Returns `None` in any of these cases:
/// - The header is missing or not visible ASCII.
/// - It has no space after the scheme.
/// - It uses a different scheme.
/// - The token is empty.
fn bearer_token(headers: &HeaderMap) -> Option<String> {
    let value = headers.get("authorization")?.to_str().ok()?;
    let (scheme, token) = value.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("bearer") {
        return None;
    }
    let token = token.trim();
    (!token.is_empty()).then(|| token.to_string())
}
/// Looks up a claim by dotted path, e.g. `realm_access.roles`.
///
/// Each `.`-separated segment descends one level into nested JSON objects.
/// If that traversal finds nothing, the whole `path` is tried as a literal
/// top-level key. This allows namespaced claims whose names contain dots,
/// such as `https://example.com/roles`. Previously such claims were
/// unreachable because the dots were always treated as separators.
fn claim_at<'a>(claims: &'a Value, path: &str) -> Option<&'a Value> {
    path.split('.')
        .try_fold(claims, |current, segment| current.get(segment))
        .or_else(|| claims.get(path))
}
/// Collects the string entries of the array found at `roles_claim`.
///
/// Non-string entries are skipped. A missing claim, or one that is not an
/// array, yields an empty set.
fn extract_roles(claims: &Value, roles_claim: &str) -> HashSet<String> {
    let mut roles = HashSet::new();
    if let Some(Value::Array(items)) = claim_at(claims, roles_claim) {
        for item in items {
            if let Value::String(role) = item {
                roles.insert(role.clone());
            }
        }
    }
    roles
}
/// Removes every header named as a target in `mapping` from `headers`.
///
/// Target names that are not valid header names are ignored. Such names could
/// never be injected anyway.
fn strip_claim_headers(headers: &mut HeaderMap, mapping: &HashMap<String, String>) {
    let targets = mapping
        .values()
        .filter_map(|raw| HeaderName::try_from(raw.as_str()).ok());
    for target in targets {
        // Keep removing until nothing is left under this name, so no
        // client-supplied value can survive.
        while headers.remove(&target).is_some() {}
    }
}
/// Renders scalar claims into `headers` according to `mapping`.
///
/// `mapping` is keyed by claim path, with the target header name as the value.
/// Strings are copied verbatim, and numbers and booleans use their JSON text.
/// Missing claims, null/array/object values, and names or values that are not
/// valid in an HTTP header are skipped. An existing header of the same name is
/// replaced.
fn inject_claim_headers(
    headers: &mut HeaderMap,
    claims: &Value,
    mapping: &HashMap<String, String>,
) {
    for (claim_path, header_name) in mapping {
        let rendered = match claim_at(claims, claim_path) {
            Some(Value::String(text)) => text.clone(),
            Some(Value::Number(number)) => number.to_string(),
            Some(Value::Bool(flag)) => flag.to_string(),
            _ => continue,
        };
        let Ok(name) = HeaderName::try_from(header_name.as_str()) else {
            continue;
        };
        let Ok(value) = HeaderValue::try_from(rendered) else {
            continue;
        };
        headers.insert(name, value);
    }
}
fn unauthorized(message: &str) -> Response {
(
StatusCode::UNAUTHORIZED,
Json(serde_json::json!({ "error": "UNAUTHENTICATED", "message": message })),
)
.into_response()
}
fn forbidden(message: &str) -> Response {
(
StatusCode::FORBIDDEN,
Json(serde_json::json!({ "error": "PERMISSION_DENIED", "message": message })),
)
.into_response()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{AuthConfig, ForwardAuthConfig, JwtConfig, RoutePolicyConfig};
    use axum::http::Request as HttpRequest;
    use jsonwebtoken::{encode, EncodingKey, Header};
    use std::sync::atomic::{AtomicU32, Ordering};
    use tower::ServiceExt;

    const TEST_PRIV_PEM: &str = "-----BEGIN PRIVATE KEY-----\n\
MC4CAQAwBQYDK2VwBCIEIEVVO7H+T5tERRn/dzukOc8i9iYEKKtPh//qcrES+dCt\n\
-----END PRIVATE KEY-----\n";
    const TEST_PUB_PEM: &str = "-----BEGIN PUBLIC KEY-----\n\
MCowBQYDK2VwAyEARCMxEnaM2/dblLuPNgBZpTvSUXO5ir+XQ1nyzJm4CFw=\n\
-----END PUBLIC KEY-----\n";

    /// Writes the test public key to a uniquely named temp file and returns its path.
    fn temp_pub_pem() -> std::path::PathBuf {
        static N: AtomicU32 = AtomicU32::new(0);
        let path = std::env::temp_dir().join(format!(
            "sp_auth_{}_{}.pem",
            std::process::id(),
            N.fetch_add(1, Ordering::Relaxed)
        ));
        std::fs::write(&path, TEST_PUB_PEM).unwrap();
        path
    }

    /// Builds an `Auth` from `cfg`, then deletes the temporary PEM file it read.
    ///
    /// `Auth::build` reads the key file during the call, so the file is no longer
    /// needed afterwards. Without this cleanup every test leaked one file into
    /// the temp dir. Cleanup is best-effort: a failed delete must not fail the test.
    fn build_auth(cfg: &AuthConfig) -> Arc<Auth> {
        let built = Auth::build(cfg);
        if let Some(path) = cfg
            .jwt
            .as_ref()
            .and_then(|jwt| jwt.public_key_pem_file.as_ref())
        {
            let _ = std::fs::remove_file(path);
        }
        built.unwrap().unwrap()
    }

    /// Signs `claims` with the test Ed25519 private key.
    fn sign(claims: serde_json::Value) -> String {
        let key = EncodingKey::from_ed_pem(TEST_PRIV_PEM.as_bytes()).unwrap();
        encode(&Header::new(Algorithm::EdDSA), &claims, &key).unwrap()
    }

    /// Unix timestamp one hour from now.
    fn future_exp() -> i64 {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs() as i64;
        now + 3600
    }

    /// Auth with issuer/audience checks and one `/secure` policy that requires
    /// authentication plus `roles`.
    fn auth_with_policy(roles: &[&str]) -> Arc<Auth> {
        let cfg = AuthConfig {
            mode: "jwt".into(),
            jwt: Some(JwtConfig {
                jwks_uri: None,
                issuer: Some("test-iss".into()),
                audience: Some("test-aud".into()),
                public_key_pem_file: Some(temp_pub_pem()),
                claims_headers: HashMap::from([("sub".to_string(), "x-user".to_string())]),
                roles_claim: "roles".into(),
            }),
            forward_auth: Some(ForwardAuthConfig {
                enabled: true,
                path: "/auth/verify".into(),
                policies: vec![RoutePolicyConfig {
                    path: "/secure".into(),
                    methods: vec!["*".into()],
                    require_auth: true,
                    required_roles: roles.iter().map(|s| s.to_string()).collect(),
                }],
                login_url: None,
                applications_path: None,
            }),
            authz: None,
        };
        build_auth(&cfg)
    }

    /// Router whose handlers echo the `x-user` header, behind the auth middleware.
    fn app(auth: Arc<Auth>) -> axum::Router {
        let echo = |headers: HeaderMap| async move {
            headers
                .get("x-user")
                .and_then(|v| v.to_str().ok())
                .unwrap_or("")
                .to_string()
        };
        axum::Router::new()
            .route("/secure", axum::routing::get(echo))
            .route("/open", axum::routing::get(echo))
            .layer(axum::middleware::from_fn_with_state(auth, middleware))
    }

    /// Empty-bodied GET request, optionally carrying `Authorization: Bearer <token>`.
    fn get_request(path: &str, token: Option<&str>) -> HttpRequest<axum::body::Body> {
        let mut builder = HttpRequest::get(path);
        if let Some(token) = token {
            builder = builder.header("authorization", format!("Bearer {token}"));
        }
        builder.body(axum::body::Body::empty()).unwrap()
    }

    async fn body_string(resp: axum::response::Response) -> String {
        let bytes = axum::body::to_bytes(resp.into_body(), 1024).await.unwrap();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    #[test]
    fn bearer_token_parsing() {
        let mut h = HeaderMap::new();
        h.insert("authorization", "Bearer abc.def.ghi".parse().unwrap());
        assert_eq!(bearer_token(&h).as_deref(), Some("abc.def.ghi"));
        let mut h2 = HeaderMap::new();
        h2.insert("authorization", "Basic xyz".parse().unwrap());
        assert_eq!(bearer_token(&h2), None);
        assert_eq!(bearer_token(&HeaderMap::new()), None);
    }

    #[test]
    fn extract_roles_reads_array_and_dotted_path() {
        let claims = serde_json::json!({
            "roles": ["admin", "billing"],
            "realm_access": { "roles": ["nested"] }
        });
        assert!(extract_roles(&claims, "roles").contains("admin"));
        assert!(extract_roles(&claims, "realm_access.roles").contains("nested"));
        assert!(extract_roles(&claims, "missing").is_empty());
    }

    #[test]
    fn inject_claim_headers_renders_scalars() {
        let claims = serde_json::json!({ "sub": "u-1", "n": 7, "obj": {"x": 1} });
        let mapping = HashMap::from([
            ("sub".to_string(), "x-user-id".to_string()),
            ("n".to_string(), "x-n".to_string()),
            ("obj".to_string(), "x-obj".to_string()),
        ]);
        let mut headers = HeaderMap::new();
        inject_claim_headers(&mut headers, &claims, &mapping);
        assert_eq!(headers["x-user-id"], "u-1");
        assert_eq!(headers["x-n"], "7");
        assert!(!headers.contains_key("x-obj"));
    }

    #[tokio::test]
    async fn strips_client_supplied_claim_headers() {
        let app = app(auth_with_policy(&[]));
        let request = HttpRequest::get("/open")
            .header("x-user", "forged-admin")
            .body(axum::body::Body::empty())
            .unwrap();
        let resp = app.oneshot(request).await.unwrap();
        assert_eq!(resp.status(), 200);
        assert_eq!(body_string(resp).await, "");
    }

    #[tokio::test]
    async fn unauthenticated_role_check_is_401_not_403() {
        let cfg = AuthConfig {
            mode: "jwt".into(),
            jwt: Some(JwtConfig {
                jwks_uri: None,
                issuer: None,
                audience: None,
                public_key_pem_file: Some(temp_pub_pem()),
                claims_headers: HashMap::new(),
                roles_claim: "roles".into(),
            }),
            forward_auth: Some(ForwardAuthConfig {
                enabled: true,
                path: "/auth/verify".into(),
                policies: vec![RoutePolicyConfig {
                    path: "/secure".into(),
                    methods: vec!["*".into()],
                    require_auth: false,
                    required_roles: vec!["admin".into()],
                }],
                login_url: None,
                applications_path: None,
            }),
            authz: None,
        };
        let resp = app(build_auth(&cfg))
            .oneshot(get_request("/secure", None))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn rejects_missing_token_on_protected_route() {
        let resp = app(auth_with_policy(&[]))
            .oneshot(get_request("/secure", None))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn accepts_valid_token_and_injects_claim_header() {
        let token = sign(serde_json::json!({
            "iss": "test-iss", "aud": "test-aud", "exp": future_exp(),
            "sub": "user-42", "roles": ["admin"]
        }));
        let resp = app(auth_with_policy(&["admin"]))
            .oneshot(get_request("/secure", Some(&token)))
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);
        assert_eq!(body_string(resp).await, "user-42");
    }

    #[tokio::test]
    async fn forbids_when_required_role_missing() {
        let token = sign(serde_json::json!({
            "iss": "test-iss", "aud": "test-aud", "exp": future_exp(),
            "sub": "user-42", "roles": ["viewer"]
        }));
        let resp = app(auth_with_policy(&["admin"]))
            .oneshot(get_request("/secure", Some(&token)))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn rejects_expired_and_wrong_issuer() {
        let app = app(auth_with_policy(&[]));

        let expired = sign(serde_json::json!({
            "iss": "test-iss", "aud": "test-aud", "exp": 1, "sub": "u", "roles": ["admin"]
        }));
        let resp = app
            .clone()
            .oneshot(get_request("/secure", Some(&expired)))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);

        let wrong_iss = sign(serde_json::json!({
            "iss": "evil", "aud": "test-aud", "exp": future_exp(), "sub": "u", "roles": ["admin"]
        }));
        let resp = app
            .oneshot(get_request("/secure", Some(&wrong_iss)))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }
}