#[cfg(feature = "auth")]
pub mod api_key;
#[cfg(feature = "auth")]
pub mod authorization;
#[cfg(feature = "auth")]
pub mod error;
#[cfg(feature = "auth")]
pub mod introspection;
#[cfg(feature = "auth")]
pub mod permission;
#[cfg(feature = "auth")]
pub mod session;
#[cfg(feature = "auth")]
pub use api_key::{ApiKeyRow, TenantApiKeyStore, hash_api_key};
#[cfg(feature = "auth")]
pub use authorization::{AuthorizationClient, AuthorizationConfig};
#[cfg(feature = "auth")]
pub use introspection::{IntrospectionConfig, IntrospectionSessionClient};
#[cfg(feature = "auth")]
pub use permission::{ObjectExtractor, PermissionLayer, PermissionService};
#[cfg(feature = "auth")]
pub use session::{IdentityMapping, IdentityMappingStore, KratosSessionClient, SessionClient};
use axum::{
body::Body,
extract::Request,
http::StatusCode,
middleware::Next,
response::{IntoResponse, Response},
};
use std::sync::Arc;
pub use error::AuthError;
/// Header carrying a tenant id; consulted by `auth_middleware` only when no API key,
/// session token, or cookie is present. Its value must parse as a ULID.
pub const TENANT_ID_HEADER: &str = "x-tenant-id";
/// Header carrying a raw API key; checked first by `auth_middleware`.
pub const API_KEY_HEADER: &str = "x-api-key";
/// Header carrying a session token; takes precedence over an `Authorization: Bearer` token.
pub const SESSION_TOKEN_HEADER: &str = "x-session-token";
/// Tenant resolved for the current request, stored in the request extensions by
/// `auth_middleware`.
///
/// On the `x-tenant-id` path the value has been validated as a ULID. On the API-key and
/// session paths it is taken verbatim from the key row / identity mapping.
#[derive(Debug, Clone)]
pub struct TenantId(pub String);
/// Details of the API key that authenticated a request.
///
/// Inserted into the request extensions by `auth_middleware` when the `x-api-key` header is
/// accepted.
#[derive(Debug, Clone)]
pub struct ApiKeyContext {
    /// Id of the stored key row (also used as the `AuthContext` subject).
    pub key_id: String,
    /// Tenant the key belongs to.
    pub tenant_id: String,
    /// Scopes recorded on the key row.
    pub scopes: Vec<String>,
}
/// Caller identity attached to a request.
///
/// A context counts as authenticated exactly when `subject` is set; `tenant_id` and `scopes`
/// are not consulted by [`AuthContext::is_authenticated`].
#[derive(Debug, Clone, Default)]
pub struct AuthContext {
    pub tenant_id: Option<String>,
    pub subject: Option<String>,
    pub scopes: Vec<String>,
}

impl AuthContext {
    /// A context with no tenant, no subject and no scopes.
    pub fn unauthenticated() -> Self {
        Self {
            tenant_id: None,
            subject: None,
            scopes: Vec::new(),
        }
    }

    /// A context for `subject` acting within `tenant_id`, with no scopes yet.
    pub fn authenticated(tenant_id: impl Into<String>, subject: impl Into<String>) -> Self {
        let tenant_id = tenant_id.into();
        let subject = subject.into();
        Self {
            tenant_id: Some(tenant_id),
            subject: Some(subject),
            ..Self::default()
        }
    }

    /// Replaces the scope list, consuming and returning the context (builder style).
    pub fn with_scopes(self, scopes: Vec<String>) -> Self {
        Self { scopes, ..self }
    }

    /// `true` when a subject has been recorded.
    pub fn is_authenticated(&self) -> bool {
        matches!(self.subject, Some(_))
    }
}
/// Dependencies of `auth_middleware`, handed to it through axum's `State` extractor.
///
/// Every field is an `Arc`, so cloning the state per request only bumps reference counts.
#[derive(Clone)]
pub struct AuthMiddlewareState {
    /// Looks up API keys by hash for the `x-api-key` path.
    pub api_keys: Arc<dyn TenantApiKeyStore>,
    /// Validates session tokens / cookies.
    pub sessions: Arc<dyn SessionClient>,
    /// Maps upstream session identities to local tenants.
    pub identity_mappings: Arc<dyn IdentityMappingStore>,
}

impl std::fmt::Debug for AuthMiddlewareState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The trait objects are not assumed to implement `Debug`, so each field is rendered as
        // a placeholder naming its trait.
        let mut builder = f.debug_struct("AuthMiddlewareState");
        builder.field("api_keys", &"<dyn TenantApiKeyStore>");
        builder.field("sessions", &"<dyn SessionClient>");
        builder.field("identity_mappings", &"<dyn IdentityMappingStore>");
        builder.finish()
    }
}
pub async fn auth_middleware(
axum::extract::State(state): axum::extract::State<AuthMiddlewareState>,
mut request: Request,
next: Next,
) -> Response {
let path = request.uri().path();
if is_public_path(path) {
return next.run(request).await;
}
if let Some(key) = header_value(&request, API_KEY_HEADER) {
match authenticate_api_key(state.api_keys.as_ref(), &key).await {
Ok(ctx) => {
let tenant_id = ctx.tenant_id.clone();
let auth_ctx = AuthContext::authenticated(&tenant_id, &ctx.key_id)
.with_scopes(ctx.scopes.clone());
request.extensions_mut().insert(TenantId(tenant_id));
request.extensions_mut().insert(ctx);
request.extensions_mut().insert(auth_ctx);
}
Err(resp) => return resp,
}
return next.run(request).await;
}
if let Some(token) = session_token_from_headers(&request) {
match authenticate_session(
state.sessions.as_ref(),
state.identity_mappings.as_ref(),
None,
Some(&token),
)
.await
{
Ok((tenant_id, identity_id)) => {
let auth_ctx = AuthContext::authenticated(&tenant_id, &identity_id);
request.extensions_mut().insert(TenantId(tenant_id));
request.extensions_mut().insert(auth_ctx);
}
Err(resp) => return resp,
}
return next.run(request).await;
}
if let Some(cookie) = header_value(&request, "cookie") {
match authenticate_session(
state.sessions.as_ref(),
state.identity_mappings.as_ref(),
Some(&cookie),
None,
)
.await
{
Ok((tenant_id, identity_id)) => {
let auth_ctx = AuthContext::authenticated(&tenant_id, &identity_id);
request.extensions_mut().insert(TenantId(tenant_id));
request.extensions_mut().insert(auth_ctx);
}
Err(resp) => return resp,
}
return next.run(request).await;
}
match header_value(&request, TENANT_ID_HEADER) {
Some(value) => match parse_tenant_id(&value) {
Ok(tenant_id) => {
request.extensions_mut().insert(TenantId(tenant_id.clone()));
request
.extensions_mut()
.insert(AuthContext::authenticated(&tenant_id, &tenant_id));
}
Err(resp) => return resp,
},
None => {
return auth_error(
StatusCode::UNAUTHORIZED,
"missing x-tenant-id or x-api-key header",
);
}
}
next.run(request).await
}
/// Whether `path` is served without authentication.
///
/// Matching is a plain prefix test against the list below; each prefix ends in `/`, so the bare
/// segment (e.g. `/health`) is *not* public.
fn is_public_path(path: &str) -> bool {
    const PUBLIC_PREFIXES: [&str; 3] = ["/.well-known/", "/oauth2/", "/health/"];
    PUBLIC_PREFIXES
        .iter()
        .any(|&prefix| path.starts_with(prefix))
}
/// Returns the first value of header `name` as an owned `String`.
///
/// Yields `None` when the header is absent or when its value is not visible ASCII (the case
/// in which `HeaderValue::to_str` fails).
fn header_value(request: &Request, name: &str) -> Option<String> {
    let value = request.headers().get(name)?;
    value.to_str().ok().map(String::from)
}
/// Resolves a raw API key to the key row it belongs to.
///
/// The key is hashed with `hash_api_key` and the store is queried by that hash, so the
/// plaintext key itself is never passed to the store.
///
/// # Errors
///
/// * `401` (`"invalid or expired api key"`) when the store reports `AuthError::InvalidApiKey`.
/// * `500` (`"failed to authenticate api key"`) for any other store error.
// The error type is a full `Response` handed straight back to the client, which trips
// clippy's `result_large_err`; boxing it would gain nothing here.
#[allow(clippy::result_large_err)]
async fn authenticate_api_key(
    store: &dyn TenantApiKeyStore,
    key: &str,
) -> Result<ApiKeyContext, Response> {
    let key_hash = hash_api_key(key);
    let row = store.get_by_hash(&key_hash).await.map_err(|err| match err {
        AuthError::InvalidApiKey => {
            auth_error(StatusCode::UNAUTHORIZED, "invalid or expired api key")
        }
        _ => auth_error(
            StatusCode::INTERNAL_SERVER_ERROR,
            "failed to authenticate api key",
        ),
    })?;
    Ok(ApiKeyContext {
        key_id: row.id,
        tenant_id: row.tenant_id,
        scopes: row.scopes,
    })
}
/// Validates a session and maps its upstream identity to a local tenant.
///
/// `cookie` and `token` are forwarded unchanged to `SessionClient::to_session`. The identity id
/// is read from `identity.id` in the returned session and looked up under provider `"kratos"`.
///
/// Returns `(tenant_id, identity_id)` taken from the stored identity mapping.
///
/// # Errors
///
/// * `401` `"invalid or expired session"`: `to_session` failed (for any reason).
/// * `401` `"session missing identity"`: `identity.id` is absent, not a string, or empty.
/// * `500` `"failed to resolve tenant"`: the mapping lookup returned an error.
/// * `401` `"identity not registered"`: no mapping exists for the identity.
// `Response` as the error type trips clippy's `result_large_err`; it is returned to the
// client as-is.
#[allow(clippy::result_large_err)]
async fn authenticate_session(
    client: &dyn SessionClient,
    mappings: &dyn IdentityMappingStore,
    cookie: Option<&str>,
    token: Option<&str>,
) -> Result<(String, String), Response> {
    let session = match client.to_session(cookie, token).await {
        Ok(session) => session,
        Err(_) => {
            return Err(auth_error(
                StatusCode::UNAUTHORIZED,
                "invalid or expired session",
            ));
        }
    };

    let upstream_identity_id = session["identity"]["id"]
        .as_str()
        .filter(|id| !id.is_empty())
        .map(str::to_owned)
        .ok_or_else(|| auth_error(StatusCode::UNAUTHORIZED, "session missing identity"))?;

    let lookup = mappings
        .get_identity_mapping("kratos", &upstream_identity_id)
        .await;
    let mapping = match lookup {
        Ok(Some(mapping)) => mapping,
        Ok(None) => {
            return Err(auth_error(
                StatusCode::UNAUTHORIZED,
                "identity not registered",
            ));
        }
        Err(_) => {
            return Err(auth_error(
                StatusCode::INTERNAL_SERVER_ERROR,
                "failed to resolve tenant",
            ));
        }
    };
    Ok((mapping.tenant_id, mapping.identity_id))
}
/// Extracts a session token from the request headers, if one is present.
///
/// [`SESSION_TOKEN_HEADER`] takes precedence. Otherwise the `Authorization` header is checked
/// for the `Bearer` scheme. An `Authorization` header with any other scheme (e.g. `Basic`)
/// yields `None`.
fn session_token_from_headers(request: &Request) -> Option<String> {
    if let Some(token) = header_value(request, SESSION_TOKEN_HEADER) {
        return Some(token);
    }
    let authorization = header_value(request, "authorization")?;
    bearer_credentials(&authorization).map(str::to_string)
}

/// Returns the credentials part of a `Bearer` authorization value.
///
/// The scheme name is compared ASCII-case-insensitively, because HTTP auth schemes are
/// case-insensitive (RFC 9110 §11.1). One or more spaces may separate the scheme from the
/// token; extra spaces are not included in the result. `"Bearer "` yields `Some("")`.
fn bearer_credentials(value: &str) -> Option<&str> {
    let (scheme, credentials) = value.split_once(' ')?;
    if scheme.eq_ignore_ascii_case("bearer") {
        Some(credentials.trim_start_matches(' '))
    } else {
        None
    }
}
/// Validates a tenant id taken from the [`TENANT_ID_HEADER`] header.
///
/// The value must be non-empty and parse as a ULID. On success it is returned unchanged (it is
/// not re-encoded into canonical form).
///
/// # Errors
///
/// `400` with `"missing tenant id"` for an empty value, or `"invalid tenant id"` when the value
/// is not a ULID.
// `Response` as the error type trips clippy's `result_large_err`; it is returned to the
// client as-is.
#[allow(clippy::result_large_err)]
fn parse_tenant_id(value: &str) -> Result<String, Response> {
    let problem = if value.is_empty() {
        Some("missing tenant id")
    } else if ulid::Ulid::from_string(value).is_err() {
        Some("invalid tenant id")
    } else {
        None
    };
    match problem {
        Some(message) => Err(auth_error(StatusCode::BAD_REQUEST, message)),
        None => Ok(value.to_owned()),
    }
}
/// Builds a JSON error response of the form `{"error":"<message>"}` with the given status.
///
/// `message` is interpolated verbatim without JSON escaping. It must therefore not contain `"`
/// or `\`; every caller passes a fixed literal.
fn auth_error(status: StatusCode, message: &'static str) -> Response {
    let headers = [(axum::http::header::CONTENT_TYPE, "application/json")];
    let payload = format!("{{\"error\":\"{message}\"}}");
    (status, headers, Body::from(payload)).into_response()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    // Each public prefix is recognised; an ordinary API path is not.
    #[test]
    fn is_public_path_matches_public_prefixes() {
        assert!(is_public_path("/.well-known/openid-configuration"));
        assert!(is_public_path("/oauth2/auth"));
        assert!(is_public_path("/health/live"));
        assert!(!is_public_path("/api/v1/things"));
    }

    // A freshly generated ULID is accepted and returned unchanged.
    #[test]
    fn parse_tenant_id_accepts_valid_ulid() {
        let valid = ulid::Ulid::new().to_string();
        assert_eq!(parse_tenant_id(&valid).unwrap(), valid);
    }

    // An empty value is a client error (400), not an authentication failure.
    #[test]
    fn parse_tenant_id_rejects_empty() {
        let resp = parse_tenant_id("").unwrap_err();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    // A non-ULID value is rejected with 400.
    #[test]
    fn parse_tenant_id_rejects_invalid_ulid() {
        let resp = parse_tenant_id("not-a-ulid").unwrap_err();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    /// Test double for `TenantApiKeyStore`.
    ///
    /// It ignores the hash it is given. The queued result is returned on the first call;
    /// every later call reports `AuthError::InvalidApiKey`.
    struct StubApiKeyStore(Mutex<Option<Result<ApiKeyRow, AuthError>>>);

    #[async_trait::async_trait]
    impl TenantApiKeyStore for StubApiKeyStore {
        async fn get_by_hash(&self, _hash: &str) -> Result<ApiKeyRow, AuthError> {
            // `take` empties the slot, so the queued result is handed out at most once.
            self.0
                .lock()
                .unwrap()
                .take()
                .unwrap_or(Err(AuthError::InvalidApiKey))
        }
    }

    /// Key row for `tenant-1` with a single `tenant:read` scope. Its `key_hash` is the hash of
    /// the plaintext `"secret"`.
    fn dummy_api_key_row() -> ApiKeyRow {
        ApiKeyRow {
            id: "key-1".into(),
            tenant_id: "tenant-1".into(),
            key_hash: hash_api_key("secret"),
            name: "test".into(),
            scopes: vec!["tenant:read".into()],
        }
    }

    // A store hit is copied field-for-field into `ApiKeyContext`.
    #[tokio::test]
    async fn authenticate_api_key_returns_context_for_valid_key() {
        let store = StubApiKeyStore(Mutex::new(Some(Ok(dummy_api_key_row()))));
        let ctx = authenticate_api_key(&store, "secret").await.unwrap();
        assert_eq!(ctx.key_id, "key-1");
        assert_eq!(ctx.tenant_id, "tenant-1");
        assert_eq!(ctx.scopes, vec!["tenant:read".to_string()]);
    }

    // `AuthError::InvalidApiKey` from the store surfaces as 401.
    #[tokio::test]
    async fn authenticate_api_key_returns_unauthorized_for_unknown_key() {
        let store = StubApiKeyStore(Mutex::new(Some(Err(AuthError::InvalidApiKey))));
        let err = authenticate_api_key(&store, "secret").await.unwrap_err();
        assert_eq!(err.status(), StatusCode::UNAUTHORIZED);
    }
}