use std::borrow::Cow;
use std::collections::HashSet;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use http::HeaderValue;
use http::StatusCode;
use http::header;
use crate::body::TakoBody;
use crate::middleware::IntoMiddleware;
use crate::middleware::Next;
use crate::responder::Responder;
use crate::types::BuildHasher;
use crate::types::Request;
use crate::types::Response;
/// Where [`ApiKeyAuth`] looks for the API key in an incoming request.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ApiKeyLocation {
    /// Read the key from the named request header; surrounding whitespace is trimmed.
    Header(&'static str),
    /// Read the key from the named query-string parameter (form-urlencoded decoded).
    Query(&'static str),
    /// Try the named header first (first field), then fall back to the named query
    /// parameter (second field) when the header is absent or not valid text.
    HeaderOrQuery(&'static str, &'static str),
}

impl Default for ApiKeyLocation {
    /// Reads the key from the `X-API-Key` request header.
    fn default() -> Self {
        Self::Header("X-API-Key")
    }
}
/// Configuration for API-key authentication, turned into middleware via `IntoMiddleware`.
///
/// A request is accepted when its key is contained in `keys` or accepted by `verify`.
/// Every constructor in this file sets at least one of the two.
pub struct ApiKeyAuth {
    /// Where in the request the key is read from.
    location: ApiKeyLocation,
    /// Static set of accepted keys, if configured.
    keys: Option<HashSet<String, BuildHasher>>,
    /// Custom predicate deciding whether a key is accepted, if configured.
    verify: Option<Arc<dyn Fn(&str) -> bool + Send + Sync + 'static>>,
}
impl ApiKeyAuth {
pub fn new(key: impl Into<String>) -> Self {
let mut set: HashSet<String, BuildHasher> = HashSet::with_hasher(BuildHasher::default());
set.insert(key.into());
Self {
keys: Some(set),
verify: None,
location: ApiKeyLocation::default(),
}
}
pub fn from_keys<I>(keys: I) -> Self
where
I: IntoIterator,
I::Item: Into<String>,
{
Self {
keys: Some(keys.into_iter().map(Into::into).collect()),
verify: None,
location: ApiKeyLocation::default(),
}
}
pub fn with_verify<F>(f: F) -> Self
where
F: Fn(&str) -> bool + Send + Sync + 'static,
{
Self {
keys: None,
verify: Some(Arc::new(f)),
location: ApiKeyLocation::default(),
}
}
pub fn from_keys_with_verify<I, F>(keys: I, f: F) -> Self
where
I: IntoIterator,
I::Item: Into<String>,
F: Fn(&str) -> bool + Send + Sync + 'static,
{
Self {
keys: Some(keys.into_iter().map(Into::into).collect()),
verify: Some(Arc::new(f)),
location: ApiKeyLocation::default(),
}
}
pub fn location(mut self, location: ApiKeyLocation) -> Self {
self.location = location;
self
}
pub fn header_name(mut self, name: &'static str) -> Self {
self.location = ApiKeyLocation::Header(name);
self
}
pub fn query_param(mut self, name: &'static str) -> Self {
self.location = ApiKeyLocation::Query(name);
self
}
}
/// Extracts the raw API key from `req` according to `location`.
///
/// - Header values are returned borrowed and trimmed of surrounding whitespace. A header
///   that `HeaderValue::to_str` cannot convert is treated as absent. A present but empty
///   header yields `Some("")`.
/// - Query parameters are form-urlencoded decoded, so the result may be owned. They are
///   not trimmed. If the parameter occurs several times, the first occurrence wins.
/// - For `HeaderOrQuery`, the query is consulted only when the header lookup yields `None`.
///
/// Returns `None` when no key is found at the configured location(s).
fn extract_api_key<'a>(req: &'a Request, location: &ApiKeyLocation) -> Option<Cow<'a, str>> {
    /// Returns the trimmed value of header `name`, if present and convertible to text.
    fn from_header<'r>(req: &'r Request, name: &str) -> Option<Cow<'r, str>> {
        req.headers()
            .get(name)
            .and_then(|value| value.to_str().ok())
            .map(|text| Cow::Borrowed(text.trim()))
    }

    /// Returns the decoded value of the first query parameter called `name`.
    fn from_query<'r>(req: &'r Request, name: &str) -> Option<Cow<'r, str>> {
        let query = req.uri().query()?;
        url::form_urlencoded::parse(query.as_bytes())
            .find(|(key, _)| key == name)
            .map(|(_, value)| value)
    }

    match location {
        ApiKeyLocation::Header(name) => from_header(req, name),
        ApiKeyLocation::Query(name) => from_query(req, name),
        ApiKeyLocation::HeaderOrQuery(header, query) => {
            from_header(req, header).or_else(|| from_query(req, query))
        }
    }
}
impl IntoMiddleware for ApiKeyAuth {
fn into_middleware(
self,
) -> impl Fn(Request, Next) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
+ Clone
+ Send
+ Sync
+ 'static {
let keys = self.keys.map(Arc::new);
let verify = self.verify;
let location = self.location;
let api_key_authenticate = HeaderValue::from_static("ApiKey");
move |req: Request, next: Next| {
let keys = keys.clone();
let verify = verify.clone();
let location = location.clone();
let api_key_authenticate = api_key_authenticate.clone();
Box::pin(async move {
let api_key = match extract_api_key(&req, &location) {
Some(key) => key,
None => {
return http::Response::builder()
.status(StatusCode::UNAUTHORIZED)
.header(header::WWW_AUTHENTICATE, api_key_authenticate.clone())
.body(TakoBody::from("API key is missing"))
.unwrap()
.into_response();
}
};
if let Some(set) = &keys {
if set.contains(api_key.as_ref()) {
return next.run(req).await.into_response();
}
}
if let Some(v) = verify.as_ref() {
if v(api_key.as_ref()) {
return next.run(req).await.into_response();
}
}
http::Response::builder()
.status(StatusCode::UNAUTHORIZED)
.header(header::WWW_AUTHENTICATE, api_key_authenticate)
.body(TakoBody::from("Invalid API key"))
.unwrap()
.into_response()
})
}
}
}