use axum::extract::Request;
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use systemprompt_models::RequestContext;
use systemprompt_models::api::ApiError;
use systemprompt_models::auth::UserType;
/// Route-level authorization policy: the set of caller [`UserType`]s that may pass.
///
/// Instances are built with the `const` constructors on this type and enforced by
/// [`authz_gate`]. A caller whose type is not in the set is rejected.
#[derive(Debug, Clone, Copy)]
pub struct AuthzPolicy {
    /// Caller types admitted by this policy.
    allowed: &'static [UserType],
}
impl AuthzPolicy {
    /// Every caller type this module knows about, including `Anon`.
    const EVERY_CALLER: &'static [UserType] = &[
        UserType::Anon,
        UserType::User,
        UserType::Admin,
        UserType::A2a,
        UserType::Mcp,
        UserType::Service,
    ];

    /// Every caller type except `Anon`.
    const NON_ANONYMOUS: &'static [UserType] = &[
        UserType::User,
        UserType::Admin,
        UserType::A2a,
        UserType::Mcp,
        UserType::Service,
    ];

    /// Only `User` and `Admin` callers.
    const USER_OR_ADMIN: &'static [UserType] = &[UserType::User, UserType::Admin];

    /// Only `Admin` callers.
    const ADMIN_ONLY: &'static [UserType] = &[UserType::Admin];

    /// Policy admitting every caller type, including anonymous (`Anon`) callers.
    #[must_use]
    pub const fn public() -> Self {
        Self::restricted_to(Self::EVERY_CALLER)
    }

    /// Policy admitting every caller type except `Anon`.
    #[must_use]
    pub const fn authenticated() -> Self {
        Self::restricted_to(Self::NON_ANONYMOUS)
    }

    /// Policy admitting `User` and `Admin` callers only.
    #[must_use]
    pub const fn user() -> Self {
        Self::restricted_to(Self::USER_OR_ADMIN)
    }

    /// Policy admitting `Admin` callers only.
    #[must_use]
    pub const fn admin() -> Self {
        Self::restricted_to(Self::ADMIN_ONLY)
    }

    /// Policy admitting exactly the caller types in `allowed`.
    ///
    /// An empty slice yields a policy that rejects every caller.
    #[must_use]
    pub const fn restricted_to(allowed: &'static [UserType]) -> Self {
        Self { allowed }
    }

    /// Returns `true` when `user_type` is one of the admitted caller types.
    fn permits(self, user_type: UserType) -> bool {
        let Self { allowed } = self;
        allowed.contains(&user_type)
    }
}
/// Axum middleware that enforces `policy` before the rest of the stack runs.
///
/// The caller type is read from the request's [`RequestContext`] extension. When that
/// extension is absent, the caller is treated as [`UserType::Anon`]. A permitted request
/// is forwarded to `next`. Any other request is answered immediately with
/// [`ApiError::forbidden`], whose message names the rejected caller type.
pub async fn authz_gate(policy: AuthzPolicy, request: Request, next: Next) -> Response {
    let context = request.extensions().get::<RequestContext>();
    let caller = context.map_or(UserType::Anon, RequestContext::user_type);

    if !policy.permits(caller) {
        let reason = format!(
            "caller type '{}' is not authorized for this route",
            caller.as_str()
        );
        return ApiError::forbidden(reason).into_response();
    }

    next.run(request).await
}