use axum::{
extract::FromRequestParts,
http::{request::Parts, StatusCode},
};
use uuid::Uuid;
use crate::{
error::ApiError,
models::{ApiToken, TokenScope},
};
/// Authentication context pulled from the request extensions, used to enforce
/// API-token scopes in handlers.
#[derive(Debug, Clone)]
pub struct ScopedAuth {
    /// The `ApiToken` stored in the request extensions, if one was present.
    pub token: Option<ApiToken>,
    /// `true` when the request extensions contain `AuthType::ApiToken`.
    pub is_api_token: bool,
    /// User id parsed from a `String` request extension, when present and a valid UUID.
    // NOTE(review): presumably an auth middleware stores the user id as a `String`
    // extension; a bare `String` key is collision-prone — confirm with the middleware.
    pub user_id: Option<Uuid>,
}
impl<S> FromRequestParts<S> for ScopedAuth
where
    S: Send + Sync,
{
    type Rejection = StatusCode;

    /// Builds a `ScopedAuth` from whatever auth data earlier middleware stored in
    /// the request extensions.
    ///
    /// This never rejects: every field is optional or defaults to `false`, so a
    /// request with no auth extensions still yields a (permissive) `ScopedAuth`.
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let extensions = &parts.extensions;

        let token = extensions.get::<ApiToken>().cloned();

        // Only an explicit `AuthType::ApiToken` marker flags API-token auth.
        let is_api_token = matches!(extensions.get::<AuthType>(), Some(AuthType::ApiToken));

        // An unparsable id is treated the same as a missing one.
        let user_id = extensions
            .get::<String>()
            .and_then(|raw| Uuid::parse_str(raw.as_str()).ok());

        Ok(Self {
            token,
            is_api_token,
            user_id,
        })
    }
}
/// Marker stored in the request extensions describing how the request was
/// authenticated; `ScopedAuth` reads it to set `is_api_token`.
#[derive(Debug, Clone)]
pub enum AuthType {
    /// Authenticated with a JWT.
    Jwt,
    /// Authenticated with an `ApiToken`.
    ApiToken,
}
impl ScopedAuth {
    /// Returns `true` if this request may act with `scope`.
    ///
    /// Always agrees with [`ScopedAuth::require_scope`]:
    /// - a present token is checked against `scope`;
    /// - a request flagged as API-token auth but carrying no token is denied;
    /// - any other request (no token, not API-token auth) is allowed.
    pub fn has_scope(&self, scope: &TokenScope) -> bool {
        match &self.token {
            Some(token) => token.has_scope(scope),
            // Fail closed: API-token auth without the token itself must not pass.
            None => !self.is_api_token,
        }
    }

    /// Ensures this request may act with `scope`.
    ///
    /// A token found in the extensions is always checked, even if the
    /// `AuthType` marker is missing. This matches `check_scope` and avoids
    /// silently skipping enforcement when only the marker was omitted.
    ///
    /// # Errors
    ///
    /// Returns [`ApiError::InsufficientScope`] when:
    /// - a token is present but lacks `scope`; the error lists the token's scopes;
    /// - the request is flagged as API-token auth but no token was found; the
    ///   error carries an empty scope list.
    pub fn require_scope(&self, scope: TokenScope) -> Result<(), ApiError> {
        match &self.token {
            Some(token) if token.has_scope(&scope) => Ok(()),
            Some(token) => Err(ApiError::InsufficientScope {
                required: scope.to_string(),
                scopes: token.scopes.clone(),
            }),
            None if self.is_api_token => Err(ApiError::InsufficientScope {
                required: scope.to_string(),
                scopes: vec![],
            }),
            // Non-API-token auth (e.g. a session) is not scope-restricted here.
            None => Ok(()),
        }
    }

    /// Returns the acting user's id.
    ///
    /// When a token is present, the token's `user_id` is used exclusively, even if
    /// it is `None`. Otherwise the id parsed from the request extensions is used.
    pub fn user_id(&self) -> Option<Uuid> {
        match &self.token {
            Some(token) => token.user_id,
            None => self.user_id,
        }
    }
}
/// Checks the `ApiToken` in `request_extensions` (if any) against `required_scope`.
///
/// Returns:
/// - `Ok(Some(token))` when a token is present and grants `required_scope`;
/// - `Err(StatusCode::FORBIDDEN)` when a token is present but lacks it;
/// - `Ok(None)` when no token is present. Such requests are not checked here.
pub fn check_scope(
    request_extensions: &axum::http::Extensions,
    required_scope: TokenScope,
) -> Result<Option<ApiToken>, StatusCode> {
    let Some(found) = request_extensions.get::<ApiToken>() else {
        return Ok(None);
    };

    if !found.has_scope(&required_scope) {
        return Err(StatusCode::FORBIDDEN);
    }

    Ok(Some(found.clone()))
}
/// Returns the `org_id` of the `ApiToken` stored in `request_extensions`, or
/// `None` when no token is present.
pub fn get_org_id_from_extensions(request_extensions: &axum::http::Extensions) -> Option<Uuid> {
    let token = request_extensions.get::<ApiToken>()?;
    Some(token.org_id)
}