pub mod error;
use axum::{
body::Body,
extract::State,
http::{HeaderMap, Request},
middleware::Next,
response::{IntoResponse, Response},
};
use idprova_core::dat::constraints::EvaluationContext;
use idprova_core::dat::Dat;
use std::net::IpAddr;
pub use error::DatMiddlewareError;
/// Result of a successful DAT verification.
///
/// `dat_verification_middleware` builds this value from the verified token's
/// claims and inserts it into the request extensions before calling the next
/// handler.
#[derive(Debug, Clone)]
pub struct VerifiedDat {
/// The verified token itself.
pub dat: Dat,
/// Copy of the token's `sub` claim (the subject DID).
pub subject_did: String,
/// Copy of the token's `iss` claim (the issuer DID).
pub issuer_did: String,
/// Copy of the token's `scope` claim.
pub scopes: Vec<String>,
/// Copy of the token's `jti` claim (token identifier).
pub jti: String,
}
/// Configuration consumed by `dat_verification_middleware` as axum state.
#[derive(Debug, Clone)]
pub struct DatVerificationConfig {
/// Raw 32-byte public key passed to `idprova_verify::verify_dat`.
/// NOTE(review): presumably the issuer's Ed25519 verifying key — confirm.
pub public_key: [u8; 32],
/// Scope that every accepted token must satisfy; passed to `verify_dat`.
pub required_scope: String,
}
/// Builds the constraint-evaluation context for a request from its headers.
///
/// The client IP is taken from the leftmost entry of `X-Forwarded-For`.
/// `X-Real-IP` is consulted only when that entry is missing or does not parse
/// as an IP address. `current_timestamp` is left as `None`, and every other
/// field keeps its `Default` value.
///
/// NOTE(review): both headers are supplied by whoever sends the request. Any
/// IP-based constraint is only trustworthy if a reverse proxy overwrites
/// these headers before they reach this service — confirm the deployment.
fn build_eval_context(headers: &HeaderMap) -> EvaluationContext {
    /// Returns the named header as a `&str`, or `None` if it is absent or not
    /// convertible to text.
    fn header_text<'h>(headers: &'h HeaderMap, name: &str) -> Option<&'h str> {
        headers.get(name)?.to_str().ok()
    }

    /// Trims surrounding whitespace and parses the rest as an IPv4/IPv6 address.
    fn parse_ip(raw: &str) -> Option<IpAddr> {
        raw.trim().parse().ok()
    }

    let forwarded_ip = header_text(headers, "X-Forwarded-For")
        .and_then(|list| list.split(',').next())
        .and_then(parse_ip);

    // Lazily fall back so X-Real-IP is only read when X-Forwarded-For fails.
    let request_ip =
        forwarded_ip.or_else(|| header_text(headers, "X-Real-IP").and_then(parse_ip));

    EvaluationContext {
        request_ip,
        current_timestamp: None,
        ..Default::default()
    }
}
/// Extracts the bearer token from the `Authorization` header.
///
/// The authentication scheme is matched case-insensitively, as RFC 7235 §2.1
/// requires, so `Bearer`, `bearer` and `BEARER` are all accepted. Whitespace
/// around the token is trimmed. The returned slice borrows from `headers`.
///
/// # Errors
///
/// Returns an error built with `DatMiddlewareError::unauthorized` when:
/// - the header is missing;
/// - the header value cannot be read as text;
/// - the scheme is not `Bearer`;
/// - the token is empty.
fn extract_bearer_token(headers: &HeaderMap) -> Result<&str, DatMiddlewareError> {
    let auth = headers
        .get("Authorization")
        .ok_or_else(|| DatMiddlewareError::unauthorized("Authorization header required"))?;
    let auth_str = auth
        .to_str()
        .map_err(|_| DatMiddlewareError::unauthorized("invalid Authorization header encoding"))?;
    // RFC 7235: `credentials = auth-scheme [ 1*SP ( token68 / #auth-param ) ]`.
    // The scheme name is case-insensitive, so do not compare it byte-for-byte.
    // A missing separator or a foreign scheme yields "", which is rejected below.
    let token = auth_str
        .split_once(' ')
        .filter(|(scheme, _)| scheme.eq_ignore_ascii_case("Bearer"))
        .map_or("", |(_, credentials)| credentials.trim());
    if token.is_empty() {
        return Err(DatMiddlewareError::unauthorized("Bearer token required"));
    }
    Ok(token)
}
/// Axum middleware that authenticates a request with a DAT bearer token.
///
/// 1. Extracts the bearer token from the `Authorization` header. If that
///    fails, it answers with the extraction error's response.
/// 2. Builds an [`EvaluationContext`] from the request headers.
/// 3. Verifies the token with `idprova_verify::verify_dat`, using the
///    configured public key and required scope.
/// 4. On success, inserts a [`VerifiedDat`] into the request extensions and
///    forwards the request to `next`.
///
/// Verification failures are logged at `warn` level. The failure message
/// decides the status:
/// - if it contains `"scope"`, the request is rejected with
///   `DatMiddlewareError::forbidden`;
/// - otherwise it is rejected with `DatMiddlewareError::unauthorized`.
pub async fn dat_verification_middleware(
    State(config): State<DatVerificationConfig>,
    mut request: Request<Body>,
    next: Next,
) -> Response {
    // Read everything needed from the headers inside this scope. That way the
    // shared borrow of `request` ends before its extensions are mutated.
    let (token, ctx) = {
        let headers = request.headers();
        let token = match extract_bearer_token(headers) {
            Ok(raw) => raw.to_owned(),
            Err(rejection) => return rejection.into_response(),
        };
        (token, build_eval_context(headers))
    };

    let verification = idprova_verify::verify_dat(
        &token,
        &config.public_key,
        &config.required_scope,
        &ctx,
    );

    let dat = match verification {
        Ok(dat) => dat,
        Err(err) => {
            let msg = err.to_string();
            tracing::warn!("DAT verification failed: {msg}");
            // NOTE(review): the 403-vs-401 split depends on the verifier's
            // message text. If `idprova_verify` exposes a typed error kind,
            // matching on that would be more robust.
            let rejection = if msg.contains("scope") {
                DatMiddlewareError::forbidden(msg)
            } else {
                DatMiddlewareError::unauthorized(msg)
            };
            return rejection.into_response();
        }
    };

    // Copy the commonly used claims out before moving `dat` into the struct.
    let subject_did = dat.claims.sub.clone();
    let issuer_did = dat.claims.iss.clone();
    let scopes = dat.claims.scope.clone();
    let jti = dat.claims.jti.clone();

    request.extensions_mut().insert(VerifiedDat {
        dat,
        subject_did,
        issuer_did,
        scopes,
        jti,
    });
    next.run(request).await
}
/// Convenience constructor for [`DatVerificationConfig`].
///
/// The borrowed `required_scope` is copied into an owned `String`.
pub fn make_dat_config(public_key: [u8; 32], required_scope: &str) -> DatVerificationConfig {
    let required_scope = String::from(required_scope);
    DatVerificationConfig {
        public_key,
        required_scope,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::StatusCode;

    #[test]
    fn test_build_eval_context_with_forwarded_for() {
        let mut headers = HeaderMap::new();
        headers.insert("X-Forwarded-For", "192.168.1.1, 10.0.0.1".parse().unwrap());
        let ctx = build_eval_context(&headers);
        assert_eq!(
            ctx.request_ip,
            Some("192.168.1.1".parse::<IpAddr>().unwrap())
        );
    }

    #[test]
    fn test_build_eval_context_with_real_ip() {
        let mut headers = HeaderMap::new();
        headers.insert("X-Real-IP", "10.0.0.5".parse().unwrap());
        let ctx = build_eval_context(&headers);
        assert_eq!(ctx.request_ip, Some("10.0.0.5".parse::<IpAddr>().unwrap()));
    }

    #[test]
    fn test_build_eval_context_no_ip() {
        let headers = HeaderMap::new();
        let ctx = build_eval_context(&headers);
        assert!(ctx.request_ip.is_none());
    }

    /// A parseable X-Forwarded-For entry wins over X-Real-IP.
    #[test]
    fn test_build_eval_context_forwarded_for_takes_precedence() {
        let mut headers = HeaderMap::new();
        headers.insert("X-Forwarded-For", "192.168.1.1".parse().unwrap());
        headers.insert("X-Real-IP", "10.0.0.5".parse().unwrap());
        let ctx = build_eval_context(&headers);
        assert_eq!(
            ctx.request_ip,
            Some("192.168.1.1".parse::<IpAddr>().unwrap())
        );
    }

    /// An unparseable leftmost X-Forwarded-For entry falls back to X-Real-IP.
    /// Later X-Forwarded-For entries are not consulted.
    #[test]
    fn test_build_eval_context_falls_back_when_forwarded_for_invalid() {
        let mut headers = HeaderMap::new();
        headers.insert("X-Forwarded-For", "not-an-ip, 10.0.0.1".parse().unwrap());
        headers.insert("X-Real-IP", "10.0.0.5".parse().unwrap());
        let ctx = build_eval_context(&headers);
        assert_eq!(ctx.request_ip, Some("10.0.0.5".parse::<IpAddr>().unwrap()));
    }

    /// Whitespace around the leftmost entry is trimmed, and IPv6 is accepted.
    #[test]
    fn test_build_eval_context_trims_and_parses_ipv6() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "X-Forwarded-For",
            "  2001:db8::1 , 10.0.0.1".parse().unwrap(),
        );
        let ctx = build_eval_context(&headers);
        assert_eq!(
            ctx.request_ip,
            Some("2001:db8::1".parse::<IpAddr>().unwrap())
        );
    }

    #[test]
    fn test_build_eval_context_invalid_ips_yield_none() {
        let mut headers = HeaderMap::new();
        headers.insert("X-Forwarded-For", "garbage".parse().unwrap());
        headers.insert("X-Real-IP", "also-garbage".parse().unwrap());
        let ctx = build_eval_context(&headers);
        assert!(ctx.request_ip.is_none());
    }

    #[test]
    fn test_extract_bearer_missing_header() {
        let headers = HeaderMap::new();
        assert!(extract_bearer_token(&headers).is_err());
    }

    #[test]
    fn test_extract_bearer_empty_token() {
        let mut headers = HeaderMap::new();
        headers.insert("Authorization", "Bearer ".parse().unwrap());
        assert!(extract_bearer_token(&headers).is_err());
    }

    /// The scheme alone, with no separator or token, is rejected.
    #[test]
    fn test_extract_bearer_scheme_without_token() {
        let mut headers = HeaderMap::new();
        headers.insert("Authorization", "Bearer".parse().unwrap());
        assert!(extract_bearer_token(&headers).is_err());
    }

    #[test]
    fn test_extract_bearer_no_bearer_prefix() {
        let mut headers = HeaderMap::new();
        headers.insert("Authorization", "Basic abc123".parse().unwrap());
        assert!(extract_bearer_token(&headers).is_err());
    }

    #[test]
    fn test_extract_bearer_valid() {
        let mut headers = HeaderMap::new();
        headers.insert("Authorization", "Bearer my-token-here".parse().unwrap());
        let token = extract_bearer_token(&headers).unwrap();
        assert_eq!(token, "my-token-here");
    }

    /// Extra whitespace between the scheme and the token, and after the
    /// token, is trimmed away.
    #[test]
    fn test_extract_bearer_trims_surrounding_whitespace() {
        let mut headers = HeaderMap::new();
        headers.insert("Authorization", "Bearer   spaced-token  ".parse().unwrap());
        let token = extract_bearer_token(&headers).unwrap();
        assert_eq!(token, "spaced-token");
    }

    #[test]
    fn test_error_into_response_unauthorized() {
        let err = DatMiddlewareError::unauthorized("bad token");
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn test_error_into_response_forbidden() {
        let err = DatMiddlewareError::forbidden("scope denied");
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }
}