use axum::http::StatusCode;
use std::collections::BTreeMap;
use crate::{
api_tokens::{ApiToken, ParseApiTokenError},
paseto_tokens::{AccessTokenClaims, PasetoError},
types::ServiceId,
};
/// A bearer credential taken from a request, classified by its textual shape.
#[derive(Debug, Clone)]
pub enum AuthToken {
    /// A legacy credential (a token string containing `|`), parsed into an [`ApiToken`].
    Legacy(ApiToken),
    /// An API key (a token string containing `.` but no `|`), kept verbatim.
    ApiKey(String),
    /// Claims from a validated access token.
    ///
    /// [`AuthToken::parse`] never produces this variant. NOTE(review): presumably it is
    /// built after PASETO validation elsewhere — confirm against callers.
    AccessToken(AccessTokenClaims),
}
/// Errors produced while interpreting a bearer credential.
#[derive(Debug, thiserror::Error)]
pub enum AuthTokenError {
    /// The string matched no accepted shape. `AuthToken::parse` also returns this for
    /// `v4.local.` tokens, which it does not accept.
    #[error("Invalid token format")]
    InvalidFormat,
    /// A `|`-containing string failed to parse as a legacy [`ApiToken`].
    #[error("Legacy token error: {0}")]
    LegacyToken(#[from] ParseApiTokenError),
    /// Wraps a [`PasetoError`]. Not returned by any code visible in this module.
    #[error("Paseto token error: {0}")]
    PasetoToken(#[from] PasetoError),
    /// Not returned by any code visible in this module. NOTE(review): presumably used by
    /// API-key validation elsewhere — confirm.
    #[error("Malformed API key")]
    MalformedApiKey,
}
impl AuthToken {
    /// Classifies a raw bearer credential by its shape.
    ///
    /// The rules are tried in order:
    /// 1. a string starting with `v4.local.` is refused with
    ///    [`AuthTokenError::InvalidFormat`] (it is not decoded here);
    /// 2. a string containing `|` is parsed as a legacy [`ApiToken`];
    /// 3. a string containing `.` is kept verbatim as [`AuthToken::ApiKey`];
    /// 4. anything else, including the empty string, is [`AuthTokenError::InvalidFormat`].
    ///
    /// This function never yields [`AuthToken::AccessToken`].
    ///
    /// # Errors
    ///
    /// [`AuthTokenError::InvalidFormat`] as described above, or
    /// [`AuthTokenError::LegacyToken`] when a `|`-containing string fails legacy parsing.
    pub fn parse(token_str: &str) -> Result<Self, AuthTokenError> {
        let is_paseto_local = token_str.starts_with("v4.local.");
        if is_paseto_local {
            Err(AuthTokenError::InvalidFormat)
        } else if token_str.contains('|') {
            Ok(Self::Legacy(ApiToken::from_str(token_str)?))
        } else if token_str.contains('.') {
            Ok(Self::ApiKey(token_str.to_owned()))
        } else {
            Err(AuthTokenError::InvalidFormat)
        }
    }

    /// Returns the service id from access-token claims.
    ///
    /// Returns `None` for legacy tokens and API keys.
    pub fn service_id(&self) -> Option<ServiceId> {
        match self {
            Self::AccessToken(claims) => Some(claims.service_id),
            Self::Legacy(_) | Self::ApiKey(_) => None,
        }
    }

    /// Returns an owned copy of the extra headers carried by access-token claims.
    ///
    /// Returns an empty map for every other variant.
    pub fn additional_headers(&self) -> BTreeMap<String, String> {
        if let Self::AccessToken(claims) = self {
            claims.additional_headers.clone()
        } else {
            BTreeMap::new()
        }
    }

    /// Reports whether the credential has expired.
    ///
    /// Only access tokens carry an expiry, which is delegated to the claims. Legacy tokens
    /// and API keys always report `false`.
    pub fn is_expired(&self) -> bool {
        matches!(self, Self::AccessToken(claims) if claims.is_expired())
    }
}
/// Outcome of extracting a credential, with one variant per [`AuthToken`] variant.
///
/// NOTE(review): not referenced by any code visible in this module. The
/// `ValidatedAccessToken` name suggests it holds claims that were already validated,
/// but confirm against callers.
#[derive(Debug)]
pub enum TokenExtractionResult {
    /// A parsed legacy token.
    Legacy(ApiToken),
    /// A verbatim API key.
    ApiKey(String),
    /// Access-token claims (presumably already validated — confirm).
    ValidatedAccessToken(AccessTokenClaims),
}
impl<S> axum::extract::FromRequestParts<S> for AuthToken
where
    S: Send + Sync,
{
    type Rejection = axum::response::Response;

    /// Extracts an [`AuthToken`] from the request's `Authorization: Bearer <token>` header.
    ///
    /// The `Bearer` scheme name is matched ASCII case-insensitively. RFC 7235 §2.1 makes
    /// auth-scheme names case-insensitive, so `bearer <token>` is accepted too.
    ///
    /// # Errors
    ///
    /// Rejects with:
    /// - `401 Unauthorized` if the `Authorization` header is missing;
    /// - `400 Bad Request` if the header value cannot be read as a string, does not use
    ///   the Bearer scheme, or carries a token that [`AuthToken::parse`] rejects.
    async fn from_request_parts(
        parts: &mut axum::http::request::Parts,
        _state: &S,
    ) -> Result<Self, Self::Rejection> {
        use axum::response::IntoResponse;

        /// Returns the credentials that follow a `Bearer ` prefix (scheme matched
        /// case-insensitively), or `None` if the value uses another scheme or is too short.
        fn strip_bearer(value: &str) -> Option<&str> {
            const PREFIX: &str = "Bearer ";
            let scheme = value.get(..PREFIX.len())?;
            if scheme.eq_ignore_ascii_case(PREFIX) {
                value.get(PREFIX.len()..)
            } else {
                None
            }
        }

        let Some(header) = parts.headers.get(crate::types::headers::AUTHORIZATION) else {
            return Err(
                (StatusCode::UNAUTHORIZED, "Missing Authorization header").into_response(),
            );
        };

        let header_str = match header.to_str() {
            Ok(value) => match strip_bearer(value) {
                Some(token) => token,
                None => {
                    return Err((
                        StatusCode::BAD_REQUEST,
                        "Invalid Authorization header; expected Bearer token",
                    )
                        .into_response());
                }
            },
            Err(_) => {
                return Err((
                    StatusCode::BAD_REQUEST,
                    "Invalid Authorization header; not valid UTF-8",
                )
                    .into_response());
            }
        };

        match AuthToken::parse(header_str) {
            Ok(token) => Ok(token),
            // `AuthToken::parse` returns `InvalidFormat` for `v4.local.` tokens. Report
            // that case separately so the client learns the token kind is the issue.
            Err(AuthTokenError::InvalidFormat) if header_str.starts_with("v4.local.") => Err(
                (StatusCode::BAD_REQUEST, "Paseto token validation required").into_response(),
            ),
            Err(AuthTokenError::InvalidFormat) => {
                Err((StatusCode::BAD_REQUEST, "Invalid token format").into_response())
            }
            Err(e) => Err((StatusCode::BAD_REQUEST, format!("Invalid token: {e}")).into_response()),
        }
    }
}
/// Request body for exchanging a credential for an access token.
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
pub struct TokenExchangeRequest {
    /// Extra headers to attach. The field may be omitted on input (it defaults to an empty
    /// map) and is left out of serialized output when empty.
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub additional_headers: BTreeMap<String, String>,
    /// Optional lifetime for the issued token, presumably in seconds per the field
    /// name — confirm. Left out of serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ttl_seconds: Option<u64>,
}
/// Response body returned after a successful token exchange.
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
pub struct TokenExchangeResponse {
    /// The issued access token string.
    pub access_token: String,
    /// Token type; `TokenExchangeResponse::new` always sets this to `"Bearer"`.
    pub token_type: String,
    /// Expiry time in seconds since the Unix epoch (this is how `new` interprets it).
    pub expires_at: u64,
    /// Seconds remaining until `expires_at` when the response was built, clamped at 0.
    pub expires_in: u64,
}
impl TokenExchangeResponse {
    /// Builds a `"Bearer"` response for `access_token`, which expires at `expires_at`.
    ///
    /// `expires_at` is seconds since the Unix epoch. `expires_in` is measured against the
    /// current system clock and is 0 if `expires_at` has already passed. If the clock
    /// reads earlier than the epoch, the current time is treated as 0.
    pub fn new(access_token: String, expires_at: u64) -> Self {
        use std::time::{SystemTime, UNIX_EPOCH};

        let now_secs = match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(since_epoch) => since_epoch.as_secs(),
            Err(_) => 0,
        };
        let remaining = expires_at.saturating_sub(now_secs);

        Self {
            access_token,
            token_type: String::from("Bearer"),
            expires_at,
            expires_in: remaining,
        }
    }
}