use std::marker::PhantomData;
use std::ops::Deref;
use axum::extract::FromRequestParts;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use rusty_gasket::auth::context::AuthContext;
use rusty_gasket::auth::identity::Identity;
/// Extractor for handlers that require an authenticated caller.
///
/// Wraps the [`Identity`] cloned out of the request's [`AuthContext`] extension.
/// Extraction is rejected with [`AuthRequired`] when the context is missing or
/// carries no identity.
#[derive(Clone, Debug)]
pub struct Authenticated(pub Identity);
/// Extractor exposing the caller's [`Identity`].
///
/// Its extraction delegates to [`Authenticated`], so it succeeds and fails under
/// exactly the same conditions and uses the same [`AuthRequired`] rejection.
#[derive(Clone, Debug)]
pub struct CurrentUser(pub Identity);
/// Extractor for handlers that work with or without an authenticated caller.
///
/// Holds `Some(identity)` when the request's [`AuthContext`] provides one and
/// `None` otherwise; extraction itself never fails.
#[derive(Clone, Debug)]
pub struct OptionalIdentity(pub Option<Identity>);
/// Rejection returned when a request carries no usable identity.
///
/// It is a field-less marker, so it is `Copy` and comparable; this keeps it
/// consistent with the other cloneable types in this module and lets callers
/// assert on it directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AuthRequired;
/// Rejection for callers that are authenticated but not permitted to proceed.
///
/// Both fields are forwarded to the error response builder when the rejection
/// is converted into a `403 Forbidden` response.
#[derive(Clone, Debug)]
pub struct AuthorizationRequired {
    /// Machine-readable error code, e.g. `"MISSING_SCOPE"`.
    code: &'static str,
    /// Human-readable explanation included in the response.
    message: String,
}
/// Compile-time description of a scope that [`RequireScope`] must find on the
/// caller's identity.
///
/// Implement it on a marker type and use that type as the `Scope` parameter of
/// [`RequireScope`].
pub trait RequiredScope: 'static + Send + Sync {
    /// Name of the scope that is passed to the identity's scope check.
    const SCOPE: &'static str;
}
/// Extractor proving that the caller is authenticated and holds the scope named
/// by `Scope::SCOPE`.
///
/// `Scope` is a type-level marker only: no value of that type is ever stored,
/// which is why the field is a [`PhantomData`].
pub struct RequireScope<Scope> {
    /// Identity of the caller.
    identity: Identity,
    /// Ties the extractor to its marker type without storing a value of it.
    scope: PhantomData<Scope>,
}

// `Clone` and `Debug` are written by hand instead of derived: the derive macros
// would add `Scope: Clone` / `Scope: Debug` bounds, forcing every marker type to
// implement traits that are never used on it.
impl<Scope> Clone for RequireScope<Scope> {
    fn clone(&self) -> Self {
        Self {
            identity: self.identity.clone(),
            scope: PhantomData,
        }
    }
}

impl<Scope> std::fmt::Debug for RequireScope<Scope> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same field layout as the derived implementation would print.
        f.debug_struct("RequireScope")
            .field("identity", &self.identity)
            .field("scope", &self.scope)
            .finish()
    }
}
impl<Scope> RequireScope<Scope> {
    /// Borrows the identity held by this extractor.
    #[must_use]
    pub const fn identity(&self) -> &Identity {
        let Self { identity, .. } = self;
        identity
    }

    /// Consumes the extractor and returns the owned identity.
    #[must_use]
    pub fn into_identity(self) -> Identity {
        let Self { identity, .. } = self;
        identity
    }
}
/// Lets a `RequireScope` be used wherever a `&Identity` is expected.
impl<Scope> Deref for RequireScope<Scope> {
    type Target = Identity;

    /// Returns a reference to the wrapped identity.
    fn deref(&self) -> &Identity {
        self.identity()
    }
}
/// Extractor for endpoints restricted to service accounts.
///
/// Extraction requires an authenticated caller whose identity reports
/// `is_service_account()`.
#[derive(Clone, Debug)]
pub struct ServiceAccount(pub Identity);
/// Extractor for endpoints restricted to privileged identities.
///
/// Extraction requires an authenticated caller whose identity reports
/// `is_privileged()`.
#[derive(Clone, Debug)]
pub struct SuperuserOnly(pub Identity);
impl IntoResponse for AuthRequired {
fn into_response(self) -> Response {
rusty_gasket::error::quick_error_response(
StatusCode::UNAUTHORIZED,
"AUTHENTICATION_REQUIRED",
"Missing or invalid credentials. Authentication is required for this endpoint.",
)
}
}
impl IntoResponse for AuthorizationRequired {
fn into_response(self) -> Response {
rusty_gasket::error::quick_error_response(StatusCode::FORBIDDEN, self.code, &self.message)
}
}
impl<S> FromRequestParts<S> for Authenticated
where
    S: Send + Sync,
{
    type Rejection = AuthRequired;

    /// Clones the identity out of the request's [`AuthContext`] extension.
    ///
    /// # Errors
    ///
    /// Returns [`AuthRequired`] when the request has no [`AuthContext`]
    /// extension or the context has no identity.
    async fn from_request_parts(
        parts: &mut http::request::Parts,
        _state: &S,
    ) -> Result<Self, Self::Rejection> {
        let Some(identity) = parts
            .extensions
            .get::<AuthContext>()
            .and_then(|ctx| ctx.identity())
        else {
            return Err(AuthRequired);
        };

        Ok(Self(identity.clone()))
    }
}
impl<S> FromRequestParts<S> for CurrentUser
where
    S: Send + Sync,
{
    type Rejection = AuthRequired;

    /// Runs the [`Authenticated`] extractor and re-wraps its identity.
    ///
    /// # Errors
    ///
    /// Propagates the [`AuthRequired`] rejection produced by [`Authenticated`].
    async fn from_request_parts(
        parts: &mut http::request::Parts,
        state: &S,
    ) -> Result<Self, Self::Rejection> {
        let Authenticated(identity) = Authenticated::from_request_parts(parts, state).await?;
        Ok(Self(identity))
    }
}
impl<S> FromRequestParts<S> for OptionalIdentity
where
    S: Send + Sync,
{
    type Rejection = std::convert::Infallible;

    /// Looks up the request's [`AuthContext`] and clones its identity if present.
    ///
    /// A missing context and a context without an identity both yield `None`;
    /// this extractor never rejects.
    async fn from_request_parts(
        parts: &mut http::request::Parts,
        _state: &S,
    ) -> Result<Self, Self::Rejection> {
        let maybe_identity = match parts.extensions.get::<AuthContext>() {
            Some(ctx) => ctx.identity().cloned(),
            None => None,
        };

        Ok(Self(maybe_identity))
    }
}
impl<S, Scope> FromRequestParts<S> for RequireScope<Scope>
where
    S: Send + Sync,
    Scope: RequiredScope,
{
    type Rejection = Response;

    /// Authenticates the caller, then checks the identity for `Scope::SCOPE`.
    ///
    /// # Errors
    ///
    /// Returns the [`AuthRequired`] response when authentication fails, or an
    /// [`AuthorizationRequired`] response with code `MISSING_SCOPE` when the
    /// identity lacks the scope.
    async fn from_request_parts(
        parts: &mut http::request::Parts,
        state: &S,
    ) -> Result<Self, Self::Rejection> {
        // Authentication comes first so unauthenticated callers get the 401 response.
        let identity = match Authenticated::from_request_parts(parts, state).await {
            Ok(Authenticated(identity)) => identity,
            Err(unauthenticated) => return Err(unauthenticated.into_response()),
        };

        if !identity.has_scope(Scope::SCOPE) {
            let rejection = AuthorizationRequired {
                code: "MISSING_SCOPE",
                message: format!("Missing required scope '{}'.", Scope::SCOPE),
            };
            return Err(rejection.into_response());
        }

        Ok(Self {
            identity,
            scope: PhantomData,
        })
    }
}
impl<S> FromRequestParts<S> for ServiceAccount
where
    S: Send + Sync,
{
    type Rejection = Response;

    /// Authenticates the caller, then requires `is_service_account()` on the identity.
    ///
    /// # Errors
    ///
    /// Returns the [`AuthRequired`] response when authentication fails, or an
    /// [`AuthorizationRequired`] response with code `SERVICE_ACCOUNT_REQUIRED`
    /// when the identity is not a service account.
    async fn from_request_parts(
        parts: &mut http::request::Parts,
        state: &S,
    ) -> Result<Self, Self::Rejection> {
        let identity = match Authenticated::from_request_parts(parts, state).await {
            Ok(Authenticated(identity)) => identity,
            Err(unauthenticated) => return Err(unauthenticated.into_response()),
        };

        if !identity.is_service_account() {
            let rejection = AuthorizationRequired {
                code: "SERVICE_ACCOUNT_REQUIRED",
                message: String::from("This endpoint requires a service account."),
            };
            return Err(rejection.into_response());
        }

        Ok(Self(identity))
    }
}
impl<S> FromRequestParts<S> for SuperuserOnly
where
    S: Send + Sync,
{
    type Rejection = Response;

    /// Authenticates the caller, then requires `is_privileged()` on the identity.
    ///
    /// # Errors
    ///
    /// Returns the [`AuthRequired`] response when authentication fails, or an
    /// [`AuthorizationRequired`] response with code `SUPERUSER_REQUIRED` when
    /// the identity is not privileged.
    async fn from_request_parts(
        parts: &mut http::request::Parts,
        state: &S,
    ) -> Result<Self, Self::Rejection> {
        let identity = match Authenticated::from_request_parts(parts, state).await {
            Ok(Authenticated(identity)) => identity,
            Err(unauthenticated) => return Err(unauthenticated.into_response()),
        };

        if !identity.is_privileged() {
            let rejection = AuthorizationRequired {
                code: "SUPERUSER_REQUIRED",
                message: String::from("This endpoint requires a privileged identity."),
            };
            return Err(rejection.into_response());
        }

        Ok(Self(identity))
    }
}