use crate::context::Push;
use futures::future::FutureExt;
use headers::authorization::{Basic, Bearer, Credentials};
use headers::Authorization as Header;
use http::header::AUTHORIZATION;
use http::{HeaderMap, Request};
use hyper::service::Service;
use std::collections::BTreeSet;
use std::marker::PhantomData;
use std::string::ToString;
use zeroize::ZeroizeOnDrop;
/// Scopes carried by an [`Authorization`].
#[derive(Clone, Debug, PartialEq)]
pub enum Scopes {
    /// An explicit set of scope names.
    Some(BTreeSet<String>),
    /// No scope restriction; this is the value `AllowAllAuthenticator` assigns.
    /// NOTE(review): presumably means "all scopes granted" — confirm with consumers.
    All,
}
/// Authorization information pushed onto a request context.
///
/// `AllowAllAuthenticator` pushes `Some(Authorization { .. })` with
/// `scopes: Scopes::All` and `issuer: None` for every request.
#[derive(Clone, Debug, PartialEq)]
pub struct Authorization {
    /// Subject the request is attributed to.
    pub subject: String,
    /// Scopes associated with `subject`.
    pub scopes: Scopes,
    /// Issuer of the authorization, if any (`None` when set by `AllowAllAuthenticator`).
    pub issuer: Option<String>,
}
/// Credentials extracted from a request.
///
/// Contents are zeroized on drop (`ZeroizeOnDrop`). `Debug` is implemented by
/// hand so that passwords, tokens and API keys are never written to logs via
/// `{:?}`; only the Basic username is shown.
#[derive(Clone, PartialEq, ZeroizeOnDrop)]
pub enum AuthData {
    /// HTTP Basic credentials as `(username, password)`.
    Basic(String, String),
    /// HTTP Bearer token.
    Bearer(String),
    /// API key.
    ApiKey(String),
}

impl std::fmt::Debug for AuthData {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `format_args!` renders the placeholder without surrounding quotes.
        match self {
            AuthData::Basic(username, _) => f
                .debug_tuple("Basic")
                .field(username)
                .field(&format_args!("<redacted>"))
                .finish(),
            AuthData::Bearer(_) => f
                .debug_tuple("Bearer")
                .field(&format_args!("<redacted>"))
                .finish(),
            AuthData::ApiKey(_) => f
                .debug_tuple("ApiKey")
                .field(&format_args!("<redacted>"))
                .finish(),
        }
    }
}
impl AuthData {
    /// Builds Basic credentials from a borrowed username and password.
    pub fn basic(username: &str, password: &str) -> Self {
        Self::Basic(String::from(username), String::from(password))
    }

    /// Builds a Bearer credential.
    ///
    /// The token is first placed in an `Authorization: Bearer …` header and
    /// read back out; `None` is returned if that header cannot be built.
    pub fn bearer(token: &str) -> Option<Self> {
        let header = Header::bearer(token).ok()?;
        Some(Self::Bearer(String::from(header.token())))
    }

    /// Builds an API-key credential.
    pub fn apikey(apikey: &str) -> Self {
        Self::ApiKey(String::from(apikey))
    }
}
/// Shorthand bound for request-context types onto which an
/// `Option<Authorization>` can be pushed; blanket-implemented for all such types.
pub trait RcBound: Push<Option<Authorization>> + Send + 'static {}

impl<C: Push<Option<Authorization>> + Send + 'static> RcBound for C {}
/// Make-service wrapper whose produced services are wrapped in
/// [`AllowAllAuthenticator`], so every request is attributed to `subject`
/// with `Scopes::All`.
#[derive(Debug)]
pub struct MakeAllowAllAuthenticator<T, RC>
where
    RC: RcBound,
    RC::Result: Send + 'static,
{
    // The wrapped make-service.
    inner: T,
    // Subject handed to every produced `AllowAllAuthenticator`.
    subject: String,
    // Ties the type to the context type `RC` without storing one.
    marker: PhantomData<RC>,
}
// Hand-written so that cloning does not require `RC: Clone` (only a
// `PhantomData<RC>` is stored).
impl<T, RC> Clone for MakeAllowAllAuthenticator<T, RC>
where
    T: Clone,
    RC: RcBound,
    RC::Result: Send + 'static,
{
    fn clone(&self) -> Self {
        Self::new(self.inner.clone(), self.subject.clone())
    }
}
impl<T, RC> MakeAllowAllAuthenticator<T, RC>
where
    RC: RcBound,
    RC::Result: Send + 'static,
{
    /// Wraps the make-service `inner`; every service it produces will
    /// attribute requests to `subject`.
    pub fn new<U: Into<String>>(inner: T, subject: U) -> Self {
        let subject: String = subject.into();
        Self {
            inner,
            subject,
            marker: PhantomData,
        }
    }
}
impl<Inner, RC, Target> Service<Target> for MakeAllowAllAuthenticator<Inner, RC>
where
    RC: RcBound,
    RC::Result: Send + 'static,
    Inner: Service<Target>,
    Inner::Future: Send + 'static,
{
    type Error = Inner::Error;
    type Response = AllowAllAuthenticator<Inner::Response, RC>;
    type Future = futures::future::BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Asks the inner make-service for a service and, on success, wraps it in
    /// an `AllowAllAuthenticator` for this subject. Errors pass through unchanged.
    fn call(&self, target: Target) -> Self::Future {
        let subject = self.subject.clone();
        let make_inner = self.inner.call(target);
        make_inner
            .map(move |made| {
                made.map(|service| AllowAllAuthenticator::<_, RC>::new(service, subject))
            })
            .boxed()
    }
}
/// Service wrapper that, for every request, pushes an `Authorization` for
/// `subject` with `Scopes::All` and no issuer onto the request context before
/// forwarding to `inner`.
#[derive(Debug)]
pub struct AllowAllAuthenticator<T, RC>
where
    RC: RcBound,
    RC::Result: Send + 'static,
{
    // The wrapped service.
    inner: T,
    // Subject recorded in the pushed `Authorization`.
    subject: String,
    // Ties the type to the context type `RC` without storing one.
    marker: PhantomData<RC>,
}
impl<T, RC> AllowAllAuthenticator<T, RC>
where
    RC: RcBound,
    RC::Result: Send + 'static,
{
    /// Wraps `inner` so that every request it receives is attributed to `subject`.
    pub fn new<U: Into<String>>(inner: T, subject: U) -> Self {
        let subject: String = subject.into();
        Self {
            inner,
            subject,
            marker: PhantomData,
        }
    }
}
// Hand-written so that cloning does not require `RC: Clone` (only a
// `PhantomData<RC>` is stored).
impl<T, RC> Clone for AllowAllAuthenticator<T, RC>
where
    T: Clone,
    RC: RcBound,
    RC::Result: Send + 'static,
{
    fn clone(&self) -> Self {
        Self::new(self.inner.clone(), self.subject.clone())
    }
}
impl<T, B, RC> Service<(Request<B>, RC)> for AllowAllAuthenticator<T, RC>
where
    RC: RcBound,
    RC::Result: Send + 'static,
    T: Service<(Request<B>, RC::Result)>,
{
    type Response = T::Response;
    type Error = T::Error;
    type Future = T::Future;

    /// Pushes an authorization for the configured subject (all scopes, no
    /// issuer) onto the context, then delegates to the inner service.
    fn call(&self, (request, context): (Request<B>, RC)) -> Self::Future {
        let granted = Authorization {
            subject: self.subject.clone(),
            scopes: Scopes::All,
            issuer: None,
        };
        let extended_context = context.push(Some(granted));
        self.inner.call((request, extended_context))
    }
}
/// Extracts Basic or Bearer credentials from the `Authorization` header.
///
/// The scheme name (`Basic ` / `Bearer `) is matched ASCII case-insensitively.
///
/// Returns `None` when the header is absent, its value cannot be read with
/// `to_str`, the scheme is neither Basic nor Bearer, or decoding fails.
pub fn from_headers(headers: &HeaderMap) -> Option<AuthData> {
    let value = headers.get(AUTHORIZATION)?;
    // `to_str` only succeeds for ASCII values, so ASCII case folding in
    // `has_auth_scheme` matches what `to_lowercase` would have done.
    let value_str = value.to_str().ok()?;

    if has_auth_scheme(value_str, "basic ") {
        Basic::decode(value).map(|basic| {
            AuthData::Basic(basic.username().to_string(), basic.password().to_string())
        })
    } else if has_auth_scheme(value_str, "bearer ") {
        Bearer::decode(value).map(|bearer| AuthData::Bearer(bearer.token().to_string()))
    } else {
        None
    }
}

/// ASCII case-insensitive prefix test. Unlike `to_lowercase().starts_with(..)`
/// it makes no heap copy of the (credential-bearing) header value.
fn has_auth_scheme(value: &str, scheme: &str) -> bool {
    value.len() >= scheme.len()
        && value.as_bytes()[..scheme.len()].eq_ignore_ascii_case(scheme.as_bytes())
}
/// Returns the value of the header named `header` as an owned `String`.
///
/// Yields `None` if the header is missing or its value cannot be converted
/// with `to_str`.
pub fn api_key_from_header(headers: &HeaderMap, header: &str) -> Option<String> {
    let raw = headers.get(header)?;
    let text = raw.to_str().ok()?;
    Some(ToString::to_string(text))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{ContextBuilder, Has};
    use crate::EmptyContext;
    use bytes::Bytes;
    use http::Response;
    use http_body_util::Full;
    use hyper::service::Service;

    /// Make-service that always yields a `TestService`.
    struct MakeTestService;

    /// Request shape seen by the inner service once an authorization has been pushed.
    type ReqWithAuth = (
        Request<Full<Bytes>>,
        ContextBuilder<Option<Authorization>, EmptyContext>,
    );

    impl<Target> Service<Target> for MakeTestService {
        type Response = TestService;
        type Error = ();
        type Future = futures::future::Ready<Result<Self::Response, Self::Error>>;
        fn call(&self, _target: Target) -> Self::Future {
            futures::future::ok(TestService)
        }
    }

    /// Inner service that succeeds only if the context holds the expected authorization.
    struct TestService;

    impl Service<ReqWithAuth> for TestService {
        type Response = Response<Full<Bytes>>;
        type Error = String;
        type Future = futures::future::BoxFuture<'static, Result<Self::Response, Self::Error>>;
        fn call(&self, req: ReqWithAuth) -> Self::Future {
            Box::pin(async move {
                let auth: &Option<Authorization> = req.1.get();
                let expected = Some(Authorization {
                    subject: "foo".to_string(),
                    scopes: Scopes::All,
                    issuer: None,
                });
                if *auth == expected {
                    Ok(Response::new(Full::default()))
                } else {
                    Err(format!("{:?} != {:?}", auth, expected))
                }
            })
        }
    }

    #[tokio::test]
    async fn test_make_service() {
        let make_svc = MakeTestService;
        let a: MakeAllowAllAuthenticator<_, EmptyContext> =
            MakeAllowAllAuthenticator::new(make_svc, "foo");
        let service = a.call(&()).await.unwrap();
        let response = service
            .call((
                Request::get("http://localhost")
                    .body(Full::default())
                    .unwrap(),
                EmptyContext,
            ))
            .await;
        response.unwrap();
    }

    #[test]
    fn test_from_headers_basic() {
        let mut headers = HeaderMap::new();
        headers.append(
            AUTHORIZATION,
            headers::HeaderValue::from_static("Basic Zm9vOmJhcg=="),
        );
        assert_eq!(
            from_headers(&headers),
            Some(AuthData::Basic("foo".to_string(), "bar".to_string()))
        )
    }

    #[test]
    fn test_from_headers_bearer() {
        let mut headers = HeaderMap::new();
        headers.append(
            AUTHORIZATION,
            headers::HeaderValue::from_static("Bearer foo"),
        );
        assert_eq!(
            from_headers(&headers),
            Some(AuthData::Bearer("foo".to_string()))
        )
    }

    #[test]
    fn test_from_headers_missing_header() {
        assert_eq!(from_headers(&HeaderMap::new()), None);
    }

    #[test]
    fn test_from_headers_unknown_scheme() {
        let mut headers = HeaderMap::new();
        headers.append(
            AUTHORIZATION,
            headers::HeaderValue::from_static("Digest abc"),
        );
        assert_eq!(from_headers(&headers), None);
    }

    #[test]
    fn test_from_headers_non_ascii_value() {
        let mut headers = HeaderMap::new();
        headers.append(
            AUTHORIZATION,
            headers::HeaderValue::from_bytes(b"Basic \xff").unwrap(),
        );
        assert_eq!(from_headers(&headers), None);
    }

    #[test]
    fn test_api_key_from_header() {
        let mut headers = HeaderMap::new();
        headers.append("x-api-key", headers::HeaderValue::from_static("secret"));
        assert_eq!(
            api_key_from_header(&headers, "x-api-key"),
            Some("secret".to_string())
        );
        assert_eq!(api_key_from_header(&headers, "x-other-key"), None);
    }

    #[test]
    fn test_auth_data_constructors() {
        assert_eq!(
            AuthData::basic("foo", "bar"),
            AuthData::Basic("foo".to_string(), "bar".to_string())
        );
        assert_eq!(
            AuthData::bearer("foo"),
            Some(AuthData::Bearer("foo".to_string()))
        );
        // A token containing a newline cannot be carried in a header value.
        assert_eq!(AuthData::bearer("bad\ntoken"), None);
        assert_eq!(AuthData::apikey("key"), AuthData::ApiKey("key".to_string()));
    }
}