use std::collections::HashSet;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use http::HeaderValue;
use http::StatusCode;
use http::header;
use crate::body::TakoBody;
use crate::middleware::IntoMiddleware;
use crate::middleware::Next;
use crate::responder::Responder;
use crate::types::BuildHasher;
use crate::types::Request;
use crate::types::Response;
/// Middleware configuration that authenticates requests carrying an
/// `Authorization` header with a bearer token.
///
/// A token is accepted when it is contained in the static token set (if one
/// was configured) or when the verification callback (if one was configured)
/// returns `true` for it.
pub struct BearerAuth {
    /// Exact-match allow list of accepted tokens, if configured.
    tokens: Option<HashSet<String, BuildHasher>>,
    /// Custom predicate deciding whether a token is valid, if configured.
    /// (`Box<dyn ..>` already defaults to a `'static` object lifetime.)
    verify: Option<Box<dyn Fn(&str) -> bool + Send + Sync>>,
}
impl BearerAuth {
pub fn static_token(token: impl Into<String>) -> Self {
let mut set: HashSet<String, BuildHasher> = HashSet::with_hasher(BuildHasher::default());
set.insert(token.into());
Self {
tokens: Some(set),
verify: None,
}
}
pub fn static_tokens<I>(tokens: I) -> Self
where
I: IntoIterator,
I::Item: Into<String>,
{
Self {
tokens: Some(tokens.into_iter().map(Into::into).collect()),
verify: None,
}
}
pub fn with_verify<F>(f: F) -> Self
where
F: Fn(&str) -> bool + Clone + Send + Sync + 'static,
{
Self {
tokens: None,
verify: Some(Box::new(f)),
}
}
pub fn static_tokens_with_verify<I, F>(tokens: I, f: F) -> Self
where
I: IntoIterator,
I::Item: Into<String>,
F: Fn(&str) -> bool + Clone + Send + Sync + 'static,
{
Self {
tokens: Some(tokens.into_iter().map(Into::into).collect()),
verify: Some(Box::new(f)),
}
}
}
impl IntoMiddleware for BearerAuth {
    /// Converts the configuration into a middleware closure.
    ///
    /// For each request the closure:
    /// 1. extracts the token from the `Authorization` header (the `Bearer`
    ///    scheme is matched case-insensitively and an empty token counts as
    ///    missing);
    /// 2. responds `401 Unauthorized` with body `"Token is missing"` when no
    ///    usable token is present;
    /// 3. forwards the request to `next` when the token is in the static set or
    ///    the verification callback accepts it (the set is consulted first);
    /// 4. otherwise responds `401 Unauthorized` with an empty body.
    ///
    /// Every rejection carries `WWW-Authenticate: Bearer`, which RFC 7235
    /// requires on 401 responses.
    fn into_middleware(
        self,
    ) -> impl Fn(Request, Next) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
    + Clone
    + Send
    + Sync
    + 'static {
        // Shared, immutable configuration: each request only bumps a refcount.
        let tokens = self.tokens.map(Arc::new);
        let verify = self.verify.map(Arc::new);
        let bearer_authenticate = HeaderValue::from_static("Bearer");
        move |req: Request, next: Next| {
            let tokens = tokens.clone();
            let verify = verify.clone();
            let bearer_authenticate = bearer_authenticate.clone();
            Box::pin(async move {
                let token = req
                    .headers()
                    .get(header::AUTHORIZATION)
                    .and_then(|h| h.to_str().ok())
                    .and_then(extract_bearer_token);

                let Some(token) = token else {
                    // RFC 6750 §3.1: a request lacking credentials gets a 401
                    // with a bare `Bearer` challenge rather than a 400.
                    return http::Response::builder()
                        .status(StatusCode::UNAUTHORIZED)
                        .header(header::WWW_AUTHENTICATE, bearer_authenticate)
                        .body(TakoBody::from("Token is missing"))
                        .expect("static status and header values are always valid")
                        .into_response();
                };

                // NOTE(review): neither the hash-set lookup nor a typical
                // verify callback is constant-time; consider that if tokens are
                // high-value secrets.
                let authorized = tokens.as_ref().is_some_and(|set| set.contains(token))
                    || verify.as_ref().is_some_and(|v| v(token));
                if authorized {
                    return next.run(req).await.into_response();
                }

                http::Response::builder()
                    .status(StatusCode::UNAUTHORIZED)
                    .header(header::WWW_AUTHENTICATE, bearer_authenticate)
                    .body(TakoBody::empty())
                    .expect("static status and header values are always valid")
                    .into_response()
            })
        }
    }
}

/// Returns the credential from an `Authorization` header value that uses the
/// `Bearer` scheme.
///
/// The scheme name is compared case-insensitively (RFC 7235 §2.1), and
/// whitespace around the token is ignored. Returns `None` for any other
/// scheme, for a value without a credential, or for an empty credential.
fn extract_bearer_token(value: &str) -> Option<&str> {
    let (scheme, credentials) = value.trim_start().split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    let token = credentials.trim();
    (!token.is_empty()).then_some(token)
}