use http::{HeaderValue, Request, Response};
use std::{
convert::TryFrom,
task::{Context, Poll},
};
use tower_layer::Layer;
use tower_service::Service;
/// [`Layer`] that wraps services in [`AddAuthorization`], which sets the `Authorization`
/// header on every request passing through.
#[derive(Clone, Debug)]
pub struct AddAuthorizationLayer {
    /// Complete header value to send, for example `Basic <credentials>` or `Bearer <token>`.
    value: HeaderValue,
}
impl AddAuthorizationLayer {
    /// Authorize requests using HTTP Basic authentication.
    ///
    /// The `Authorization` header is set to `Basic {credentials}`, where `credentials` is
    /// `base64_encode("{username}:{password}")`.
    ///
    /// The credentials are only encoded, not encrypted. Use HTTPS/TLS when sending them.
    pub fn basic(username: &str, password: &str) -> Self {
        let encoded = base64::encode(format!("{}:{}", username, password));
        // Standard base64 output only uses `A-Z a-z 0-9 + / =`. Prefixed with the ASCII
        // string "Basic ", it is always a valid header value, so this cannot fail.
        let value = HeaderValue::try_from(format!("Basic {}", encoded))
            .expect("base64-encoded credentials are always a valid header value");
        Self { value }
    }

    /// Authorize requests using a bearer token.
    ///
    /// The `Authorization` header is set to `Bearer {token}`.
    ///
    /// # Panics
    ///
    /// Panics if `Bearer {token}` is not a valid HTTP header value. For example, this happens
    /// when `token` contains control characters such as `\n`.
    pub fn bearer(token: &str) -> Self {
        let value =
            HeaderValue::try_from(format!("Bearer {}", token)).expect("token is not valid header");
        Self { value }
    }

    /// Mark (or unmark) the header value as sensitive via [`HeaderValue::set_sensitive`].
    ///
    /// Sensitive values are hidden from `HeaderValue`'s `Debug` output. They also signal to
    /// encoders, such as HTTP/2 HPACK, that the value should not be indexed.
    // Clippy's `wrong_self_convention` expects `as_*` methods to borrow `self`. This is a
    // consuming builder-style method, and its name is part of the public API, so the lint is
    // silenced.
    #[allow(clippy::wrong_self_convention)]
    pub fn as_sensitive(mut self, sensitive: bool) -> Self {
        self.value.set_sensitive(sensitive);
        self
    }
}
impl<S> Layer<S> for AddAuthorizationLayer {
    type Service = AddAuthorization<S>;

    /// Wrap `inner` in an [`AddAuthorization`] holding its own copy of this layer's header value.
    fn layer(&self, inner: S) -> Self::Service {
        let value = self.value.clone();
        AddAuthorization { inner, value }
    }
}
/// Middleware that sets the `Authorization` header on each request before forwarding it to
/// the wrapped service.
#[derive(Clone, Debug)]
pub struct AddAuthorization<S> {
    /// The wrapped service that receives the modified requests.
    inner: S,
    /// Header value inserted as `Authorization` on every request.
    value: HeaderValue,
}
impl<S> AddAuthorization<S> {
    /// Wrap `inner` so every request carries HTTP Basic credentials for `username`/`password`.
    ///
    /// Shorthand for `AddAuthorizationLayer::basic(username, password).layer(inner)`.
    pub fn basic(inner: S, username: &str, password: &str) -> Self {
        let basic_layer = AddAuthorizationLayer::basic(username, password);
        basic_layer.layer(inner)
    }

    /// Wrap `inner` so every request carries `Authorization: Bearer {token}`.
    ///
    /// Shorthand for `AddAuthorizationLayer::bearer(token).layer(inner)`.
    ///
    /// # Panics
    ///
    /// Panics if `Bearer {token}` is not a valid HTTP header value.
    pub fn bearer(inner: S, token: &str) -> Self {
        let bearer_layer = AddAuthorizationLayer::bearer(token);
        bearer_layer.layer(inner)
    }

    define_inner_service_accessors!();

    /// Mark (or unmark) the header value as sensitive via [`HeaderValue::set_sensitive`].
    // Clippy wants `as_*` methods to take `&self`. This one consumes and returns `self` in
    // builder style, so the lint is allowed.
    #[allow(clippy::wrong_self_convention)]
    pub fn as_sensitive(self, sensitive: bool) -> Self {
        let mut service = self;
        service.value.set_sensitive(sensitive);
        service
    }
}
impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for AddAuthorization<S>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = S::Future;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Set the `Authorization` header on `req` and forward it to the wrapped service.
    ///
    /// `HeaderMap::insert` replaces any `Authorization` value(s) already on the request.
    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
        let authorization = self.value.clone();
        req.headers_mut()
            .insert(http::header::AUTHORIZATION, authorization);
        self.inner.call(req)
    }
}
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use crate::auth::RequireAuthorizationLayer;
    use http::{Response, StatusCode};
    use hyper::Body;
    use tower::{BoxError, Service, ServiceBuilder, ServiceExt};

    /// Inner service used by the tests: responds with the request's own body.
    async fn echo_body(request: Request<Body>) -> Result<Response<Body>, BoxError> {
        let body = request.into_body();
        Ok(Response::new(body))
    }

    #[tokio::test]
    async fn basic() {
        // Server side rejects requests lacking these basic credentials.
        let protected = ServiceBuilder::new()
            .layer(RequireAuthorizationLayer::basic("foo", "bar"))
            .service_fn(echo_body);
        // Client side attaches the same credentials to every request.
        let mut authorized = AddAuthorization::basic(protected, "foo", "bar");

        let svc_ready = authorized.ready().await.unwrap();
        let response = svc_ready.call(Request::new(Body::empty())).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn token() {
        // Server side rejects requests lacking this bearer token.
        let protected = ServiceBuilder::new()
            .layer(RequireAuthorizationLayer::bearer("foo"))
            .service_fn(echo_body);
        // Client side attaches the same token to every request.
        let mut authorized = AddAuthorization::bearer(protected, "foo");

        let svc_ready = authorized.ready().await.unwrap();
        let response = svc_ready.call(Request::new(Body::empty())).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn making_header_sensitive() {
        // The innermost service checks that the header arrives flagged as sensitive.
        let protected = ServiceBuilder::new()
            .layer(RequireAuthorizationLayer::bearer("foo"))
            .service_fn(|request: Request<Body>| async move {
                let header = request
                    .headers()
                    .get(http::header::AUTHORIZATION)
                    .unwrap();
                assert!(header.is_sensitive());
                Ok::<_, hyper::Error>(Response::new(Body::empty()))
            });
        let mut authorized = AddAuthorization::bearer(protected, "foo").as_sensitive(true);

        let svc_ready = authorized.ready().await.unwrap();
        let response = svc_ready.call(Request::new(Body::empty())).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }
}