sunbeam-g2v 0.4.0

Sunbeam Service Framework - A ConnectRPC-based framework for building microservices
//! Authentication and authorization middleware.
//!
//! Provides multitenancy-aware authn/authz for Sunbeam services:
//!
//! - `auth_middleware` resolves the caller to a [`TenantId`] using an API key,
//!   a Kratos session cookie/token, or a bare `X-Tenant-Id` header.
//! - [`PermissionLayer`] enforces ReBAC permission checks against the configured
//!   authorization backend.
//! - [`AuthorizationClient`] talks to that backend over HTTP.

#[cfg(feature = "auth")]
pub mod api_key;
#[cfg(feature = "auth")]
pub mod authorization;
#[cfg(feature = "auth")]
pub mod error;
#[cfg(feature = "auth")]
pub mod introspection;
#[cfg(feature = "auth")]
pub mod permission;
#[cfg(feature = "auth")]
pub mod session;

#[cfg(feature = "auth")]
pub use api_key::{ApiKeyRow, TenantApiKeyStore, hash_api_key};
#[cfg(feature = "auth")]
pub use authorization::{AuthorizationClient, AuthorizationConfig};
#[cfg(feature = "auth")]
pub use introspection::{IntrospectionConfig, IntrospectionSessionClient};
#[cfg(feature = "auth")]
pub use permission::{ObjectExtractor, PermissionLayer, PermissionService};
#[cfg(feature = "auth")]
pub use session::{IdentityMapping, IdentityMappingStore, KratosSessionClient, SessionClient};

use axum::{
    body::Body,
    extract::Request,
    http::StatusCode,
    middleware::Next,
    response::{IntoResponse, Response},
};
use std::sync::Arc;

pub use error::AuthError;

/// Header carrying the tenant id.
///
/// Read by [`auth_middleware`] only when no API key, session token, or cookie
/// is present; its value must parse as a ULID.
pub const TENANT_ID_HEADER: &str = "x-tenant-id";
/// Header carrying an API key.
///
/// This is the first credential [`auth_middleware`] looks for.
pub const API_KEY_HEADER: &str = "x-api-key";
/// Header carrying a Kratos session token.
///
/// Checked before the `authorization` header when looking for a session token.
pub const SESSION_TOKEN_HEADER: &str = "x-session-token";

/// Resolved tenant for the request.
///
/// Inserted into the request extensions by [`auth_middleware`] once a caller
/// has been authenticated. The wrapped string is the tenant identifier as
/// reported by the credential that was accepted.
///
/// Equality and hashing are derived from the inner string so tenant ids can be
/// compared directly and used as keys in maps and sets.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TenantId(pub String);

/// Context set when the request authenticated with an API key.
///
/// Built from the stored key row by `authenticate_api_key` and inserted into
/// the request extensions by [`auth_middleware`] alongside [`TenantId`] and
/// [`AuthContext`].
#[derive(Debug, Clone)]
pub struct ApiKeyContext {
    /// Key identifier (the `id` of the stored key row).
    pub key_id: String,
    /// Tenant that owns the key.
    pub tenant_id: String,
    /// Scopes granted to the key.
    pub scopes: Vec<String>,
}

/// Lightweight description of the caller, meant to be read by handlers.
///
/// The default value represents an unauthenticated caller: no tenant, no
/// subject and no scopes.
#[derive(Debug, Clone, Default)]
pub struct AuthContext {
    /// Tenant the caller was resolved to, if any.
    pub tenant_id: Option<String>,
    /// Who authenticated (API key id or identity id), if anyone.
    pub subject: Option<String>,
    /// Scopes attached to the caller; empty unless set via [`AuthContext::with_scopes`].
    pub scopes: Vec<String>,
}

impl AuthContext {
    /// Build a context with no tenant, no subject and no scopes.
    pub fn unauthenticated() -> Self {
        Self {
            tenant_id: None,
            subject: None,
            scopes: Vec::new(),
        }
    }

    /// Build a context for `subject` acting within `tenant_id`.
    ///
    /// The scope list starts out empty.
    pub fn authenticated(tenant_id: impl Into<String>, subject: impl Into<String>) -> Self {
        let mut ctx = Self::unauthenticated();
        ctx.tenant_id = Some(tenant_id.into());
        ctx.subject = Some(subject.into());
        ctx
    }

    /// Replace the scope list, returning the updated context.
    pub fn with_scopes(self, scopes: Vec<String>) -> Self {
        Self { scopes, ..self }
    }

    /// Whether a subject is present; the tenant id is not consulted.
    pub fn is_authenticated(&self) -> bool {
        self.subject.as_deref().is_some()
    }
}

/// Shared state required by [`auth_middleware`].
///
/// Every collaborator is held behind an `Arc` trait object, so cloning the
/// state only bumps reference counts.
#[derive(Clone)]
pub struct AuthMiddlewareState {
    /// Validates API keys (looked up by the hash of the presented key).
    pub api_keys: Arc<dyn TenantApiKeyStore>,
    /// Validates Kratos sessions from a cookie or a session token.
    pub sessions: Arc<dyn SessionClient>,
    /// Maps an identity to a tenant.
    pub identity_mappings: Arc<dyn IdentityMappingStore>,
}

/// Hand-written `Debug`: each field is a trait object, so a fixed placeholder
/// naming the trait is printed in place of the value.
impl std::fmt::Debug for AuthMiddlewareState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // (field name, placeholder) pairs, in declaration order.
        const PLACEHOLDERS: [(&str, &str); 3] = [
            ("api_keys", "<dyn TenantApiKeyStore>"),
            ("sessions", "<dyn SessionClient>"),
            ("identity_mappings", "<dyn IdentityMappingStore>"),
        ];

        let mut out = f.debug_struct("AuthMiddlewareState");
        for (field, placeholder) in &PLACEHOLDERS {
            out.field(field, placeholder);
        }
        out.finish()
    }
}

/// Axum middleware that authenticates the caller and records the result in the
/// request extensions as a [`TenantId`] and an [`AuthContext`].
///
/// Credentials are tried in a fixed order. The first kind that is *present*
/// decides the outcome: if it is rejected, the error response is returned
/// immediately and later kinds are not consulted.
///
/// 1. [`API_KEY_HEADER`] — additionally inserts an [`ApiKeyContext`].
/// 2. Session token ([`SESSION_TOKEN_HEADER`] or an `authorization` bearer token).
/// 3. `cookie` header, passed to the session client as-is.
/// 4. Bare [`TENANT_ID_HEADER`], which must be a valid ULID; the tenant id is
///    also used as the subject.
///
/// When none of these is present the request is answered with `401`.
///
/// Public OAuth2/OIDC discovery paths are not validated here; handlers for those
/// routes perform their own tenant validation.
pub async fn auth_middleware(
    axum::extract::State(state): axum::extract::State<AuthMiddlewareState>,
    mut request: Request,
    next: Next,
) -> Response {
    if is_public_path(request.uri().path()) {
        return next.run(request).await;
    }

    // 1. API key authentication.
    if let Some(api_key) = header_value(&request, API_KEY_HEADER) {
        let key_ctx = match authenticate_api_key(state.api_keys.as_ref(), &api_key).await {
            Ok(key_ctx) => key_ctx,
            Err(rejection) => return rejection,
        };
        let identity = AuthContext::authenticated(&key_ctx.tenant_id, &key_ctx.key_id)
            .with_scopes(key_ctx.scopes.clone());
        let extensions = request.extensions_mut();
        extensions.insert(TenantId(key_ctx.tenant_id.clone()));
        extensions.insert(key_ctx);
        extensions.insert(identity);
        return next.run(request).await;
    }

    // 2. Session token (API/Bearer flows), falling back to
    // 3. the session cookie. Only one of the two is ever forwarded: a token,
    //    when present, wins and the cookie is not sent along.
    let session_credential = match session_token_from_headers(&request) {
        Some(token) => Some((None, Some(token))),
        None => header_value(&request, "cookie").map(|cookie| (Some(cookie), None)),
    };
    if let Some((cookie, token)) = session_credential {
        let resolved = authenticate_session(
            state.sessions.as_ref(),
            state.identity_mappings.as_ref(),
            cookie.as_deref(),
            token.as_deref(),
        )
        .await;
        let (tenant_id, identity_id) = match resolved {
            Ok(pair) => pair,
            Err(rejection) => return rejection,
        };
        let identity = AuthContext::authenticated(&tenant_id, &identity_id);
        let extensions = request.extensions_mut();
        extensions.insert(TenantId(tenant_id));
        extensions.insert(identity);
        return next.run(request).await;
    }

    // 4. Bare tenant header (bootstrap / service-to-service).
    let tenant_id = match header_value(&request, TENANT_ID_HEADER) {
        None => {
            return auth_error(
                StatusCode::UNAUTHORIZED,
                "missing x-tenant-id or x-api-key header",
            );
        }
        Some(raw) => match parse_tenant_id(&raw) {
            Ok(tenant_id) => tenant_id,
            Err(rejection) => return rejection,
        },
    };
    let identity = AuthContext::authenticated(&tenant_id, &tenant_id);
    let extensions = request.extensions_mut();
    extensions.insert(TenantId(tenant_id));
    extensions.insert(identity);

    next.run(request).await
}

/// Returns `true` when `path` begins with one of the prefixes that bypass
/// authentication in `auth_middleware`.
///
/// Each prefix ends with `/`, so a path such as `/health` (no trailing
/// segment) does not match.
fn is_public_path(path: &str) -> bool {
    const PUBLIC_PREFIXES: [&str; 3] = ["/.well-known/", "/oauth2/", "/health/"];
    PUBLIC_PREFIXES
        .iter()
        .any(|&prefix| path.starts_with(prefix))
}

/// Returns the first value of header `name` as an owned string.
///
/// Yields `None` when the header is absent or when its value cannot be
/// represented as a string by `HeaderValue::to_str`.
fn header_value(request: &Request, name: &str) -> Option<String> {
    let raw = request.headers().get(name)?;
    match raw.to_str() {
        Ok(text) => Some(text.to_owned()),
        Err(_) => None,
    }
}

/// Looks up `key` in `store` by its hash and builds the matching [`ApiKeyContext`].
///
/// # Errors
///
/// Returns a ready-to-send response:
/// - `401` when the store reports [`AuthError::InvalidApiKey`];
/// - `500` for any other store error.
// The error type is a full HTTP `Response`, which trips `result_large_err`.
#[allow(clippy::result_large_err)]
async fn authenticate_api_key(
    store: &dyn TenantApiKeyStore,
    key: &str,
) -> Result<ApiKeyContext, Response> {
    let digest = hash_api_key(key);
    let lookup = store.get_by_hash(&digest).await;

    lookup
        .map(|row| ApiKeyContext {
            key_id: row.id,
            tenant_id: row.tenant_id,
            scopes: row.scopes,
        })
        .map_err(|err| {
            let (status, message) = match err {
                AuthError::InvalidApiKey => {
                    (StatusCode::UNAUTHORIZED, "invalid or expired api key")
                }
                _ => (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    "failed to authenticate api key",
                ),
            };
            auth_error(status, message)
        })
}

/// Validates a session and resolves it to `(tenant_id, identity_id)`.
///
/// `cookie` and `token` are handed to `client.to_session` unchanged. The
/// identity id found at `identity.id` in the returned session is then looked up
/// in `mappings` under the `"kratos"` provider, and the tenant and identity
/// ids of that mapping are returned.
///
/// # Errors
///
/// Returns a ready-to-send response:
/// - `401` when the session client fails, the session carries no (or an empty)
///   identity id, or no mapping exists for the identity;
/// - `500` when the mapping store itself fails.
// The error type is a full HTTP `Response`, which trips `result_large_err`.
#[allow(clippy::result_large_err)]
async fn authenticate_session(
    client: &dyn SessionClient,
    mappings: &dyn IdentityMappingStore,
    cookie: Option<&str>,
    token: Option<&str>,
) -> Result<(String, String), Response> {
    let session = match client.to_session(cookie, token).await {
        Ok(session) => session,
        Err(_) => {
            return Err(auth_error(
                StatusCode::UNAUTHORIZED,
                "invalid or expired session",
            ));
        }
    };

    // A missing, non-string, or empty id are all treated the same way.
    let upstream_identity_id = match session["identity"]["id"].as_str() {
        Some(id) if !id.is_empty() => id.to_string(),
        _ => {
            return Err(auth_error(
                StatusCode::UNAUTHORIZED,
                "session missing identity",
            ));
        }
    };

    let lookup = mappings
        .get_identity_mapping("kratos", &upstream_identity_id)
        .await;
    match lookup {
        Ok(Some(mapping)) => Ok((mapping.tenant_id, mapping.identity_id)),
        Ok(None) => Err(auth_error(
            StatusCode::UNAUTHORIZED,
            "identity not registered",
        )),
        Err(_) => Err(auth_error(
            StatusCode::INTERNAL_SERVER_ERROR,
            "failed to resolve tenant",
        )),
    }
}

/// Extracts a session token from the request headers.
///
/// The dedicated [`SESSION_TOKEN_HEADER`] takes precedence. Otherwise the
/// `authorization` header is used when its scheme is `Bearer`. The scheme is
/// compared case-insensitively because HTTP authentication scheme names are
/// case-insensitive, so `Bearer`, `bearer` and `BEARER` are all accepted.
/// Everything after the first space is returned unmodified.
///
/// Returns `None` when neither header yields a token, including when the
/// `authorization` header uses a different scheme or has no space-separated
/// credentials.
fn session_token_from_headers(request: &Request) -> Option<String> {
    if let Some(token) = header_value(request, SESSION_TOKEN_HEADER) {
        return Some(token);
    }

    let authorization = header_value(request, "authorization")?;
    let (scheme, credentials) = authorization.split_once(' ')?;
    scheme
        .eq_ignore_ascii_case("bearer")
        .then(|| credentials.to_string())
}

/// Validates a raw tenant id taken from a header and returns it as an owned string.
///
/// # Errors
///
/// Returns a `400` response when `value` is empty or does not parse as a ULID.
// The error type is a full HTTP `Response`, which trips `result_large_err`.
#[allow(clippy::result_large_err)]
fn parse_tenant_id(value: &str) -> Result<String, Response> {
    // The emptiness check comes first so an empty header gets its own message.
    let problem = if value.is_empty() {
        Some("missing tenant id")
    } else if ulid::Ulid::from_string(value).is_err() {
        Some("invalid tenant id")
    } else {
        None
    };

    match problem {
        Some(message) => Err(auth_error(StatusCode::BAD_REQUEST, message)),
        None => Ok(value.to_owned()),
    }
}

/// Builds a JSON error response of the form `{"error":"<message>"}` with the
/// given status and an `application/json` content type.
///
/// `message` is spliced into the body verbatim (no JSON escaping), so it must
/// not contain quotes or backslashes.
fn auth_error(status: StatusCode, message: &'static str) -> Response {
    let json = ["{\"error\":\"", message, "\"}"].concat();
    let content_type = (axum::http::header::CONTENT_TYPE, "application/json");
    (status, [content_type], Body::from(json)).into_response()
}

// Unit tests for the pure helpers and for API-key authentication against a
// stubbed key store. The middleware itself is not exercised here.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    // Each public prefix is accepted; an ordinary API path is not.
    #[test]
    fn is_public_path_matches_public_prefixes() {
        assert!(is_public_path("/.well-known/openid-configuration"));
        assert!(is_public_path("/oauth2/auth"));
        assert!(is_public_path("/health/live"));
        assert!(!is_public_path("/api/v1/things"));
    }

    // A freshly generated ULID round-trips unchanged.
    #[test]
    fn parse_tenant_id_accepts_valid_ulid() {
        let valid = ulid::Ulid::new().to_string();
        assert_eq!(parse_tenant_id(&valid).unwrap(), valid);
    }

    // An empty header value is a client error.
    #[test]
    fn parse_tenant_id_rejects_empty() {
        let resp = parse_tenant_id("").unwrap_err();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    // A non-ULID value is a client error.
    #[test]
    fn parse_tenant_id_rejects_invalid_ulid() {
        let resp = parse_tenant_id("not-a-ulid").unwrap_err();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    /// Key store stub that hands out one pre-loaded result.
    ///
    /// The result is taken on the first lookup; any later lookup yields
    /// `AuthError::InvalidApiKey`. The hash argument is ignored.
    struct StubApiKeyStore(Mutex<Option<Result<ApiKeyRow, AuthError>>>);

    #[async_trait::async_trait]
    impl TenantApiKeyStore for StubApiKeyStore {
        async fn get_by_hash(&self, _hash: &str) -> Result<ApiKeyRow, AuthError> {
            self.0
                .lock()
                .unwrap()
                .take()
                .unwrap_or(Err(AuthError::InvalidApiKey))
        }
    }

    /// Row fixture whose id, tenant and scopes are asserted on by the tests below.
    fn dummy_api_key_row() -> ApiKeyRow {
        ApiKeyRow {
            id: "key-1".into(),
            tenant_id: "tenant-1".into(),
            key_hash: hash_api_key("secret"),
            name: "test".into(),
            scopes: vec!["tenant:read".into()],
        }
    }

    // A row returned by the store is copied field-for-field into the context.
    #[tokio::test]
    async fn authenticate_api_key_returns_context_for_valid_key() {
        let store = StubApiKeyStore(Mutex::new(Some(Ok(dummy_api_key_row()))));
        let ctx = authenticate_api_key(&store, "secret").await.unwrap();
        assert_eq!(ctx.key_id, "key-1");
        assert_eq!(ctx.tenant_id, "tenant-1");
        assert_eq!(ctx.scopes, vec!["tenant:read".to_string()]);
    }

    // `InvalidApiKey` from the store maps to a 401 response.
    #[tokio::test]
    async fn authenticate_api_key_returns_unauthorized_for_unknown_key() {
        let store = StubApiKeyStore(Mutex::new(Some(Err(AuthError::InvalidApiKey))));
        let err = authenticate_api_key(&store, "secret").await.unwrap_err();
        assert_eq!(err.status(), StatusCode::UNAUTHORIZED);
    }
}