use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use base64::Engine;
use http::HeaderValue;
use http::StatusCode;
use http::header;
use crate::body::TakoBody;
use crate::middleware::IntoMiddleware;
use crate::middleware::Next;
use crate::responder::Responder;
use crate::types::BuildHasher;
use crate::types::Request;
use crate::types::Response;
/// Username → password table used for static credential checks.
type UserTable = HashMap<String, String, BuildHasher>;

/// Caller-supplied credential check, invoked with `(username, password)`.
type VerifyFn = dyn Fn(&str, &str) -> bool + Send + Sync + 'static;

/// Configuration for the HTTP Basic authentication middleware.
///
/// Credentials may be validated against a static table, a custom callback, or both.
pub struct BasicAuth {
    /// Static credentials, shared cheaply between request handlers; `None` if not configured.
    users: Option<Arc<UserTable>>,
    /// Optional custom verifier; `None` if not configured.
    verify: Option<Arc<VerifyFn>>,
    /// Realm name advertised in the `WWW-Authenticate` challenge.
    realm: &'static str,
}
impl BasicAuth {
    /// Realm advertised in the `WWW-Authenticate` challenge unless overridden via
    /// [`BasicAuth::realm`].
    const DEFAULT_REALM: &'static str = "Restricted";

    /// Creates an authenticator that accepts exactly one username/password pair.
    pub fn single(user: impl Into<String>, pass: impl Into<String>) -> Self {
        Self::multiple(std::iter::once((user, pass)))
    }

    /// Creates an authenticator backed by a static table of username/password pairs.
    ///
    /// If the same username appears more than once, the last pair wins (the pairs are
    /// collected into a map).
    pub fn multiple<I, T, P>(pairs: I) -> Self
    where
        I: IntoIterator<Item = (T, P)>,
        T: Into<String>,
        P: Into<String>,
    {
        let table = pairs
            .into_iter()
            .map(|(user, pass)| (user.into(), pass.into()))
            .collect();
        Self {
            users: Some(Arc::new(table)),
            verify: None,
            realm: Self::DEFAULT_REALM,
        }
    }

    /// Creates an authenticator that delegates every check to `cb`.
    ///
    /// `cb` receives the decoded username and password; returning `true` grants access.
    pub fn with_verify<F>(cb: F) -> Self
    where
        F: Fn(&str, &str) -> bool + Send + Sync + 'static,
    {
        Self {
            users: None,
            verify: Some(Arc::new(cb)),
            realm: Self::DEFAULT_REALM,
        }
    }

    /// Creates an authenticator with both a static table and a fallback callback.
    ///
    /// The static table is consulted first; `cb` is only called when the table does not
    /// hold a matching password.
    pub fn users_with_verify<I, S, F>(pairs: I, cb: F) -> Self
    where
        I: IntoIterator<Item = (S, S)>,
        S: Into<String>,
        F: Fn(&str, &str) -> bool + Send + Sync + 'static,
    {
        // Reuse `multiple` for the table so both constructors build it identically.
        Self {
            verify: Some(Arc::new(cb)),
            ..Self::multiple(pairs)
        }
    }

    /// Overrides the realm advertised in the `WWW-Authenticate` challenge
    /// (default `"Restricted"`).
    #[must_use = "`realm` consumes the authenticator and returns the updated one"]
    pub fn realm(mut self, r: &'static str) -> Self {
        self.realm = r;
        self
    }
}
impl IntoMiddleware for BasicAuth {
fn into_middleware(
self,
) -> impl Fn(Request, Next) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
+ Clone
+ Send
+ Sync
+ 'static {
let users = self.users;
let verify = self.verify;
let realm = self.realm;
let www_authenticate =
HeaderValue::from_str(&format!("Basic realm=\"{realm}\"")).expect("valid realm header");
move |req: Request, next: Next| {
let users = users.clone();
let verify = verify.clone();
let www_authenticate = www_authenticate.clone();
Box::pin(async move {
let creds = req
.headers()
.get(header::AUTHORIZATION)
.and_then(|h| h.to_str().ok())
.and_then(|h| h.strip_prefix("Basic "))
.and_then(|b64| base64::engine::general_purpose::STANDARD.decode(b64).ok());
match creds {
Some(raw) => {
let Some(decoded) = std::str::from_utf8(&raw).ok() else {
let mut res = Response::new(TakoBody::empty());
*res.status_mut() = StatusCode::UNAUTHORIZED;
res
.headers_mut()
.append(header::WWW_AUTHENTICATE, www_authenticate.clone());
return res;
};
let Some((u, p)) = decoded.split_once(':') else {
let mut res = Response::new(TakoBody::empty());
*res.status_mut() = StatusCode::UNAUTHORIZED;
res
.headers_mut()
.append(header::WWW_AUTHENTICATE, www_authenticate.clone());
return res;
};
if users
.as_ref()
.and_then(|map| map.get(u))
.map(|pw| pw == p)
.unwrap_or(false)
{
return next.run(req).await.into_response();
}
if let Some(cb) = &verify
&& cb(&u, &p)
{
return next.run(req).await.into_response();
}
}
None => {
return http::Response::builder()
.status(StatusCode::UNAUTHORIZED)
.header(header::WWW_AUTHENTICATE, www_authenticate.clone())
.body(TakoBody::from("Missing credentials"))
.unwrap()
.into_response();
}
}
let mut res = Response::new(TakoBody::empty());
*res.status_mut() = StatusCode::UNAUTHORIZED;
res.headers_mut().append(
header::WWW_AUTHENTICATE,
www_authenticate,
);
res
})
}
}
}