use axum::{
body::Body,
extract::{Query, State},
http::{header, Request, StatusCode},
middleware::Next,
response::Response,
};
use serde::Deserialize;
use std::sync::Arc;
/// Query-string parameters read by [`auth_middleware`].
///
/// Deserialized from the request URI via axum's `Query` extractor; only the
/// `token` parameter is looked at.
#[derive(Debug, Deserialize)]
pub struct TokenQuery {
/// Access token supplied as `?token=...`; `None` when the parameter is absent.
pub token: Option<String>,
}
/// Shared state for [`auth_middleware`]: the single token every request must present.
///
/// NOTE(review): this struct does not derive `Debug`; if one is ever added,
/// consider redacting `token` so the secret does not end up in logs.
pub struct AuthState {
/// Expected token, compared against the bearer header or the `token` query parameter.
pub token: String,
}
/// Creates a fresh random access token.
///
/// The token is the string form of a newly generated version-4 (random) UUID,
/// suitable for storing in [`AuthState::token`].
pub fn generate_token() -> String {
    let id = uuid::Uuid::new_v4();
    id.to_string()
}
/// Returns the credential carried in an `Authorization: Bearer <token>` header.
///
/// The scheme name is matched case-insensitively, since HTTP auth-scheme names
/// are case-insensitive (RFC 7235 §2.1). The one-or-more spaces separating it
/// from the credential (RFC 6750 §2.1) are skipped.
///
/// Returns `None` in these cases, so the caller can fall back to another token
/// source:
/// - the header is missing;
/// - its value cannot be viewed as a `&str` (`HeaderValue::to_str` fails);
/// - it has no space;
/// - it uses a different scheme (e.g. `Basic`).
///
/// A header of exactly `"Bearer "` yields `Some("")`, which the caller rejects
/// as a mismatch.
fn extract_bearer_token(request: &Request<Body>) -> Option<&str> {
    let value = request
        .headers()
        .get(header::AUTHORIZATION)?
        .to_str()
        .ok()?;

    // Split "<scheme> <credentials>" at the first space.
    let (scheme, credentials) = value.split_once(' ')?;

    if scheme.eq_ignore_ascii_case("Bearer") {
        // Tolerate additional separator spaces before the token itself.
        Some(credentials.trim_start_matches(' '))
    } else {
        None
    }
}
/// Compares two token strings without stopping at the first differing byte.
///
/// `==` on `str` returns as soon as it sees a mismatch, so its running time
/// reveals how long a correct prefix a caller has guessed. For equal-length
/// inputs this helper XORs every byte pair and only inspects the accumulated
/// result at the end. Lengths are compared up front, so the expected token's
/// length is not hidden.
///
/// This is best-effort: the optimizer is not formally prevented from adding an
/// early exit. If the `subtle` crate becomes a dependency, prefer
/// `subtle::ConstantTimeEq`.
fn constant_time_eq(presented: &str, expected: &str) -> bool {
    let (presented, expected) = (presented.as_bytes(), expected.as_bytes());
    if presented.len() != expected.len() {
        return false;
    }
    presented
        .iter()
        .zip(expected)
        .fold(0u8, |diff, (a, b)| diff | (a ^ b))
        == 0
}

/// Axum middleware that rejects requests which do not present the shared token.
///
/// The token is looked up in two places, in this order:
///
/// 1. An `Authorization: Bearer <token>` header. When a bearer credential is
///    present it is authoritative: a mismatch is rejected and the query string
///    is not consulted.
/// 2. Otherwise, the `token` query parameter ([`TokenQuery::token`]).
///
/// Tokens are compared without short-circuiting on the first differing byte.
/// This prevents response timing from revealing how much of a guessed token was
/// correct.
///
/// NOTE(review): a token passed in the query string becomes part of the URL
/// and may be recorded by proxies or access logs. Prefer the header wherever
/// the client can set it.
///
/// # Errors
///
/// Returns [`StatusCode::UNAUTHORIZED`] when no token is supplied or when the
/// supplied token does not match [`AuthState::token`].
pub async fn auth_middleware(
    State(auth): State<Arc<AuthState>>,
    Query(query): Query<TokenQuery>,
    request: Request<Body>,
    next: Next,
) -> Result<Response, StatusCode> {
    // `or` keeps the precedence: the query parameter is used only when no
    // bearer credential was sent at all.
    let presented = extract_bearer_token(&request).or(query.token.as_deref());

    let authorized = match presented {
        Some(token) => constant_time_eq(token, &auth.token),
        None => false,
    };

    // `presented` borrows `request` and is not used past this point, so the
    // request can now be moved into the rest of the stack.
    if authorized {
        Ok(next.run(request).await)
    } else {
        Err(StatusCode::UNAUTHORIZED)
    }
}