use axum::http::{header, HeaderMap};
use crate::error::{AppError, AppErrorKind};
/// Verifies that `headers` carries an `Authorization: Bearer <token>` credential
/// whose token equals `expected`.
///
/// The auth-scheme is matched case-insensitively and may be followed by one or
/// more spaces before the token, as HTTP's `credentials = auth-scheme 1*SP token68`
/// grammar allows. The token itself is compared with [`constant_time_eq`], so the
/// comparison does not stop at the first differing byte.
///
/// # Errors
///
/// Returns an [`AppErrorKind::Unauthorized`] error when:
/// - the `Authorization` header is absent;
/// - its value cannot be read as a string (`HeaderValue::to_str` fails);
/// - the scheme is not `Bearer` (in any letter case);
/// - the token after the scheme is empty;
/// - the token does not match `expected`.
pub fn require_bearer(headers: &HeaderMap, expected: &str) -> Result<(), AppError> {
    let auth = headers
        .get(header::AUTHORIZATION)
        .ok_or_else(|| AppError::new(AppErrorKind::Unauthorized, "missing Authorization header"))?
        .to_str()
        .map_err(|_| AppError::new(AppErrorKind::Unauthorized, "non-ascii Authorization header"))?;

    // Split off the scheme at the first space; a bare "Bearer" yields an empty remainder
    // and is rejected below as an empty token.
    let (scheme, rest) = auth.split_once(' ').unwrap_or((auth, ""));
    // Auth-schemes are case-insensitive, so "bearer" / "BEARER" must be accepted too.
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return Err(AppError::new(AppErrorKind::Unauthorized, "expected Bearer scheme"));
    }

    // One or more spaces may separate the scheme from the token.
    let token = rest.trim_start_matches(' ');
    // Refuse empty tokens outright: otherwise a blank `expected` secret would let
    // `Authorization: Bearer ` through.
    if token.is_empty() {
        return Err(AppError::new(AppErrorKind::Unauthorized, "empty bearer token"));
    }

    if constant_time_eq(token.as_bytes(), expected.as_bytes()) {
        Ok(())
    } else {
        Err(AppError::new(AppErrorKind::Unauthorized, "invalid token"))
    }
}
/// Compares two byte slices for equality without exiting early at the first mismatch.
///
/// Every index up to the longer slice's length is visited. A missing byte in the
/// shorter slice is treated as `0`, and any length difference independently forces
/// the result to `false`. The loop count therefore depends only on the longer length,
/// not on where (or whether) the contents differ.
///
/// NOTE(review): this is a best-effort source-level property. The optimizer is not
/// formally forbidden from introducing data-dependent branches.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    let span = usize::max(a.len(), b.len());
    // Seed the accumulator with a non-zero bit when the lengths disagree, so
    // zero-padding can never make slices of different lengths compare equal.
    let seed = u8::from(a.len() != b.len());
    let acc = (0..span).fold(seed, |acc, idx| {
        let lhs = a.get(idx).copied().unwrap_or(0);
        let rhs = b.get(idx).copied().unwrap_or(0);
        acc | (lhs ^ rhs)
    });
    acc == 0
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Secret configured as the expected bearer token in every test.
    const SECRET: &str = "s3cr3t";

    /// Builds a header map containing a single `Authorization` header set to `value`.
    fn auth_headers(value: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(header::AUTHORIZATION, value.parse().unwrap());
        headers
    }

    #[test]
    fn missing_header_unauthorized() {
        let empty = HeaderMap::new();
        assert!(require_bearer(&empty, SECRET).is_err());
    }

    #[test]
    fn wrong_scheme_unauthorized() {
        assert!(require_bearer(&auth_headers("Basic abc"), SECRET).is_err());
    }

    #[test]
    fn wrong_token_unauthorized() {
        assert!(require_bearer(&auth_headers("Bearer wrong"), SECRET).is_err());
    }

    #[test]
    fn correct_token_ok() {
        assert!(require_bearer(&auth_headers("Bearer s3cr3t"), SECRET).is_ok());
    }
}