use std::sync::Arc;
use axum::extract::{self};
use axum::http::Request;
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use serde::Deserialize;
use torrust_tracker_configuration::AccessTokens;
use crate::servers::apis::v1::responses::unhandled_rejection_response;
/// Query-string parameters read by the [`auth`] middleware.
#[derive(Debug, Deserialize)]
pub struct QueryParams {
    /// Value of the `token` query parameter; `None` when the request does not carry one.
    pub token: Option<String>,
}
/// State handed to the [`auth`] middleware through axum's `State` extractor.
#[derive(Debug, Clone)]
pub struct State {
    /// The configured access tokens that incoming request tokens are checked against.
    pub access_tokens: Arc<AccessTokens>,
}
/// Axum middleware that gates a request on the `token` query parameter.
///
/// - No `token` parameter: responds with [`AuthError::Unauthorized`].
/// - A `token` that matches none of `state.access_tokens`: responds with
///   [`AuthError::TokenNotValid`].
/// - Otherwise: forwards the request to the next layer and returns its response.
pub async fn auth(
    extract::State(state): extract::State<State>,
    extract::Query(params): extract::Query<QueryParams>,
    request: Request<axum::body::Body>,
    next: Next,
) -> Response {
    match params.token {
        None => AuthError::Unauthorized.into_response(),
        Some(candidate) if !authenticate(&candidate, &state.access_tokens) => {
            AuthError::TokenNotValid.into_response()
        }
        Some(_) => next.run(request).await,
    }
}
enum AuthError {
Unauthorized,
TokenNotValid,
}
impl IntoResponse for AuthError {
fn into_response(self) -> Response {
match self {
AuthError::Unauthorized => unauthorized_response(),
AuthError::TokenNotValid => token_not_valid_response(),
}
}
}
/// Returns `true` if `token` equals any of the configured access tokens.
///
/// `token` comes from the request's query string, so it is attacker-controlled.
/// Two measures keep response timing from leaking information:
/// - Each comparison uses [`constant_time_eq`] instead of `==`. A plain
///   comparison stops at the first differing byte, which lets an attacker
///   recover a token byte by byte.
/// - Every configured token is checked (non-short-circuiting fold with `|`),
///   so timing does not reveal which entry matched.
fn authenticate(token: &str, tokens: &AccessTokens) -> bool {
    let candidate = token.as_bytes();
    tokens
        .values()
        .fold(false, |matched, t| constant_time_eq(t.as_bytes(), candidate) | matched)
}

/// Compares two byte slices without exiting early on the first mismatching byte.
///
/// Slices of different lengths are rejected immediately. Only the length of a
/// secret, never its contents, is observable through timing. This is
/// best-effort: the compiler gives no formal constant-time guarantee.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    // OR together the XOR of every byte pair; the result is zero only if all pairs are equal.
    a.iter().zip(b).fold(0u8, |diff, (x, y)| diff | (x ^ y)) == 0
}
/// Builds the rejection response sent when a request carries no `token` query parameter.
///
/// Delegates to `unhandled_rejection_response` with the message `"unauthorized"`.
#[must_use]
pub fn unauthorized_response() -> Response {
    let message = String::from("unauthorized");
    unhandled_rejection_response(message)
}
/// Builds the rejection response sent when the supplied token matches no configured access token.
///
/// Delegates to `unhandled_rejection_response` with the message `"token not valid"`.
#[must_use]
pub fn token_not_valid_response() -> Response {
    let message = "token not valid".to_owned();
    unhandled_rejection_response(message)
}