use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use base64::Engine;
use http::HeaderValue;
use http::StatusCode;
use http::header;
use subtle::ConstantTimeEq;
use tako_rs_core::body::TakoBody;
use tako_rs_core::middleware::IntoMiddleware;
use tako_rs_core::middleware::Next;
use tako_rs_core::responder::Responder;
use tako_rs_core::types::BuildHasher;
use tako_rs_core::types::Request;
use tako_rs_core::types::Response;
/// Shared, thread-safe credential-check callback.
///
/// The Basic-auth middleware in this file invokes it as `cb(username, password)` with the
/// pair decoded from the request's `Authorization` header; returning `true` lets the
/// request continue to the next handler.
pub type BasicAuthVerifyFn = Arc<dyn Fn(&str, &str) -> bool + Send + Sync + 'static>;
/// Configuration for the HTTP Basic authentication middleware.
///
/// Credentials can be checked against a fixed user table, a verification callback, or
/// both. When both are present the table is consulted first and the callback is only
/// asked if no table entry matched.
pub struct BasicAuth {
/// Optional shared table of accepted credentials: username as key, password as value.
/// Passwords are kept as the plain strings that were supplied and compared byte-for-byte.
users: Option<Arc<HashMap<String, String, BuildHasher>>>,
/// Optional callback consulted when no entry of `users` matched the supplied pair.
verify: Option<BasicAuthVerifyFn>,
/// Realm name placed in the `WWW-Authenticate: Basic realm="..."` challenge sent with
/// `401` responses.
realm: &'static str,
}
impl BasicAuth {
pub fn single(user: impl Into<String>, pass: impl Into<String>) -> Self {
Self::multiple(std::iter::once((user, pass)))
}
pub fn multiple<I, T, P>(pairs: I) -> Self
where
I: IntoIterator<Item = (T, P)>,
T: Into<String>,
P: Into<String>,
{
Self {
users: Some(Arc::new(
pairs
.into_iter()
.map(|(u, p)| (u.into(), p.into()))
.collect(),
)),
verify: None,
realm: "Restricted",
}
}
pub fn with_verify<F>(cb: F) -> Self
where
F: Fn(&str, &str) -> bool + Send + Sync + 'static,
{
Self {
users: None,
verify: Some(Arc::new(cb)),
realm: "Restricted",
}
}
pub fn users_with_verify<I, S, F>(pairs: I, cb: F) -> Self
where
I: IntoIterator<Item = (S, S)>,
S: Into<String>,
F: Fn(&str, &str) -> bool + Send + Sync + 'static,
{
Self {
users: Some(Arc::new(
pairs
.into_iter()
.map(|(u, p)| (u.into(), p.into()))
.collect(),
)),
verify: Some(Arc::new(cb)),
realm: "Restricted",
}
}
pub fn realm(mut self, r: &'static str) -> Self {
self.realm = r;
self
}
}
impl IntoMiddleware for BasicAuth {
fn into_middleware(
self,
) -> impl Fn(Request, Next) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
+ Clone
+ Send
+ Sync
+ 'static {
let users = self.users;
let verify = self.verify;
let realm = self.realm;
let www_authenticate = HeaderValue::from_str(&format!("Basic realm=\"{realm}\""))
.unwrap_or_else(|_| HeaderValue::from_static("Basic"));
move |req: Request, next: Next| {
let users = users.clone();
let verify = verify.clone();
let www_authenticate = www_authenticate.clone();
Box::pin(async move {
let creds = req
.headers()
.get(header::AUTHORIZATION)
.and_then(|h| h.to_str().ok())
.and_then(|h| {
let (scheme, rest) = h.trim_start().split_once(' ')?;
scheme.eq_ignore_ascii_case("Basic").then(|| rest.trim())
})
.and_then(|b64| base64::engine::general_purpose::STANDARD.decode(b64).ok());
match creds {
Some(raw) => {
let Some(decoded) = std::str::from_utf8(&raw).ok() else {
let mut res = Response::new(TakoBody::empty());
*res.status_mut() = StatusCode::UNAUTHORIZED;
res
.headers_mut()
.append(header::WWW_AUTHENTICATE, www_authenticate.clone());
return res;
};
let Some((u, p)) = decoded.split_once(':') else {
let mut res = Response::new(TakoBody::empty());
*res.status_mut() = StatusCode::UNAUTHORIZED;
res
.headers_mut()
.append(header::WWW_AUTHENTICATE, www_authenticate.clone());
return res;
};
let mut authed = false;
if let Some(map) = users.as_ref() {
for (known_user, known_pw) in map.iter() {
let user_match = constant_time_eq(known_user.as_bytes(), u.as_bytes());
let pw_match = constant_time_eq(known_pw.as_bytes(), p.as_bytes());
authed |= user_match & pw_match;
}
}
if authed {
return next.run(req).await.into_response();
}
if let Some(cb) = &verify
&& cb(u, p)
{
return next.run(req).await.into_response();
}
}
None => {
return http::Response::builder()
.status(StatusCode::UNAUTHORIZED)
.header(header::WWW_AUTHENTICATE, www_authenticate.clone())
.body(TakoBody::from("Missing credentials"))
.unwrap()
.into_response();
}
}
let mut res = Response::new(TakoBody::empty());
*res.status_mut() = StatusCode::UNAUTHORIZED;
res
.headers_mut()
.append(header::WWW_AUTHENTICATE, www_authenticate);
res
})
}
}
}
/// Reports whether `a` and `b` hold exactly the same bytes.
///
/// Slices of different length are rejected up front, so the length itself is not
/// hidden; slices of equal length are compared through `subtle`'s `ConstantTimeEq`.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len() && bool::from(a.ct_eq(b))
}