1use axum::http::{header, HeaderMap};
9
10use crate::error::{AppError, AppErrorKind};
11
12pub fn require_bearer(headers: &HeaderMap, expected: &str) -> Result<(), AppError> {
15 let auth = headers
16 .get(header::AUTHORIZATION)
17 .ok_or_else(|| AppError::new(AppErrorKind::Unauthorized, "missing Authorization header"))?
18 .to_str()
19 .map_err(|_| AppError::new(AppErrorKind::Unauthorized, "non-ascii Authorization header"))?;
20
21 let token = auth
22 .strip_prefix("Bearer ")
23 .ok_or_else(|| AppError::new(AppErrorKind::Unauthorized, "expected Bearer scheme"))?;
24
25 if constant_time_eq(token.as_bytes(), expected.as_bytes()) {
26 Ok(())
27 } else {
28 Err(AppError::new(AppErrorKind::Unauthorized, "invalid token"))
29 }
30}
31
/// Compares two byte slices for equality without stopping at the first
/// differing byte.
///
/// Every index up to the longer slice's length is visited; missing bytes on the
/// shorter side are treated as `0`. A length difference is recorded up front, so
/// zero-padding can never make slices of different lengths compare equal.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    let span = usize::max(a.len(), b.len());
    // Seed the accumulator with a non-zero bit if the lengths disagree.
    let seed = u8::from(a.len() != b.len());

    let acc = (0..span)
        .map(|idx| {
            let lhs = a.get(idx).copied().unwrap_or(0);
            let rhs = b.get(idx).copied().unwrap_or(0);
            lhs ^ rhs
        })
        .fold(seed, |bits, delta| bits | delta);

    acc == 0
}
47
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;

    /// Builds a header map containing a single `Authorization` value.
    fn with_auth(value: &'static str) -> HeaderMap {
        let mut h = HeaderMap::new();
        h.insert(header::AUTHORIZATION, HeaderValue::from_static(value));
        h
    }

    #[test]
    fn missing_header_unauthorized() {
        assert!(require_bearer(&HeaderMap::new(), "s3cr3t").is_err());
    }

    #[test]
    fn wrong_scheme_unauthorized() {
        assert!(require_bearer(&with_auth("Basic abc"), "s3cr3t").is_err());
    }

    #[test]
    fn wrong_token_unauthorized() {
        assert!(require_bearer(&with_auth("Bearer wrong"), "s3cr3t").is_err());
    }

    #[test]
    fn prefix_or_extended_token_unauthorized() {
        // A token that is a prefix or an extension of the secret must not match.
        assert!(require_bearer(&with_auth("Bearer s3cr3"), "s3cr3t").is_err());
        assert!(require_bearer(&with_auth("Bearer s3cr3t0"), "s3cr3t").is_err());
    }

    #[test]
    fn non_visible_ascii_header_unauthorized() {
        // 0xFF is a legal header byte (obs-text), but `HeaderValue::to_str` rejects it.
        let mut h = HeaderMap::new();
        h.insert(
            header::AUTHORIZATION,
            HeaderValue::from_bytes(b"Bearer s3cr3t\xff").expect("0xFF is a valid header byte"),
        );
        assert!(require_bearer(&h, "s3cr3t").is_err());
    }

    #[test]
    fn correct_token_ok() {
        assert!(require_bearer(&with_auth("Bearer s3cr3t"), "s3cr3t").is_ok());
    }

    #[test]
    fn constant_time_eq_cases() {
        assert!(constant_time_eq(b"", b""));
        assert!(constant_time_eq(b"abc", b"abc"));
        assert!(!constant_time_eq(b"abc", b"abd"));
        // Zero padding of the shorter input must not hide a length difference.
        assert!(!constant_time_eq(b"ab", b"ab\0"));
        assert!(!constant_time_eq(b"", b"a"));
    }
}