use axum::{
extract::{Request, State},
http::{HeaderMap, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use subtle::ConstantTimeEq;
/// Shared authentication state for [`auth_middleware`].
///
/// `Debug` is implemented by hand rather than derived so that formatting this
/// value (for example in a log line) never prints the secret token.
#[derive(Clone)]
pub struct AuthState {
    /// The secret that an incoming request must present to be let through.
    pub token: String,
}

impl std::fmt::Debug for AuthState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Show that the field exists, but never its contents.
        f.debug_struct("AuthState")
            .field("token", &"<redacted>")
            .finish()
    }
}
pub async fn auth_middleware(
State(auth): State<AuthState>,
headers: HeaderMap,
request: Request,
next: Next,
) -> Response {
if let Some(auth_header) = headers.get("authorization")
&& let Ok(value) = auth_header.to_str()
&& let Some(token) = value.strip_prefix("Bearer ")
&& bool::from(token.as_bytes().ct_eq(auth.token.as_bytes()))
{
return next.run(request).await;
}
if let Some(query) = request.uri().query() {
for pair in query.split('&') {
if let Some(token) = pair.strip_prefix("token=")
&& bool::from(token.as_bytes().ct_eq(auth.token.as_bytes()))
{
return next.run(request).await;
}
}
}
(StatusCode::UNAUTHORIZED, "Invalid or missing auth token").into_response()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A cloned `AuthState` must carry the same token as the value it was
    /// cloned from.
    #[test]
    fn test_auth_state_clone() {
        let original = AuthState {
            token: String::from("test-token"),
        };

        let AuthState { token } = original.clone();

        assert_eq!(token, "test-token");
    }
}