#![forbid(unsafe_code)]
#![deny(missing_docs)]
use axum::{
async_trait,
extract::{FromRequestParts, Request},
http::{request::Parts, StatusCode},
response::Response,
BoxError,
};
use cookie::{
time::{Duration, OffsetDateTime},
Key,
};
use futures_util::future::BoxFuture;
use std::task::{Context, Poll};
use tower_layer::Layer;
use tower_service::Service;
/// Error produced by the authentication middleware when a request is rejected.
///
/// This happens, for example, when the request carries neither a valid session
/// cookie nor the configured URI token, or when the cookie-jar request
/// extension is missing. The `Display` text is generic; the individual
/// messages are available via [`ValidationErrors::errors`].
#[derive(thiserror::Error, Debug)]
#[error("one or more validation errors")]
pub struct ValidationErrors(Vec<String>);
impl ValidationErrors {
    /// Returns the individual validation messages, in the order they were recorded.
    pub fn errors(&self) -> impl Iterator<Item = &str> {
        self.0.iter().map(String::as_str)
    }
}
/// Identifier of an authenticated session.
///
/// The middleware stores it, hyphenated, in the signed session cookie. It also
/// inserts it into the request extensions, from where handlers can obtain it
/// by using `SessionKey` as an extractor.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct SessionKey(pub uuid::Uuid);
impl Default for SessionKey {
    /// Creates a new random (version 4 UUID) session key.
    fn default() -> Self {
        Self(uuid::Uuid::new_v4())
    }
}
/// Shared-secret token which, when present in the URI query string, allows a
/// request without a valid session cookie to open a new session.
///
/// With `name == "token"` a request must contain `?token=<value>`.
#[derive(Clone, Debug)]
pub struct TokenConfig {
    /// Name of the query-string parameter that carries the token.
    pub name: String,
    /// Expected token value, compared against the percent-decoded parameter.
    pub value: String,
}
impl TokenConfig {
    /// Creates a token configuration for the query parameter `name` whose value
    /// is a freshly generated random (version 4) UUID.
    pub fn new_token(name: &str) -> Self {
        Self {
            name: name.to_owned(),
            value: uuid::Uuid::new_v4().to_string(),
        }
    }

    /// Returns `true` if the query string of `req` contains a parameter named
    /// `self.name` whose percent-decoded value equals `self.value`.
    ///
    /// A missing query string is treated as empty. If the parameter is
    /// repeated, any matching occurrence is enough.
    // NOTE(review): plain `==` is not a constant-time comparison; consider one
    // if timing side channels on the token are a concern.
    fn parse_token_from_uri_query(&self, req: &Request) -> bool {
        let query = req.uri().query().unwrap_or_default();
        url::form_urlencoded::parse(query.as_bytes())
            .any(|(key, value)| key == self.name.as_str() && value == self.value.as_str())
    }
}
/// Configuration of the authentication middleware.
///
/// Convert it into a tower layer with [`AuthConfig::into_layer`].
#[derive(Clone, Debug)]
pub struct AuthConfig<'a> {
    /// Name of the signed cookie that stores the session key.
    pub cookie_name: &'a str,
    /// Key used to sign and verify the session cookie.
    ///
    /// Cookies signed with a different key are not accepted. Keep this key
    /// stable if sessions must stay valid across restarts; [`Key::generate`]
    /// (used by `Default`) yields a new random key every time.
    pub persistent_secret: Key,
    /// Token that must appear in the URI query to open a new session.
    ///
    /// When `None`, every request without a valid session cookie starts a new
    /// session.
    pub token_config: Option<TokenConfig>,
    /// Lifetime of newly issued session cookies.
    ///
    /// When `None`, the cookie is issued without an `Expires` attribute.
    pub cookie_expires: Option<std::time::Duration>,
}
impl Default for AuthConfig<'_> {
    /// Uses the crate's package name as cookie name and a freshly generated
    /// random signing key. No URI token is required, and cookies get no
    /// expiry.
    fn default() -> Self {
        Self {
            cookie_name: env!("CARGO_PKG_NAME"),
            persistent_secret: Key::generate(),
            token_config: None,
            cookie_expires: None,
        }
    }
}
impl AuthConfig<'_> {
    /// Builds the [`AuthLayer`] for this configuration.
    ///
    /// The cookie name is copied, so the returned layer does not borrow from
    /// the configuration.
    pub fn into_layer(self) -> AuthLayer {
        AuthLayer {
            access_info: AccessInfo::new(self),
        }
    }
}
/// Owned, cloneable form of [`AuthConfig`], shared by the layer and every
/// middleware instance it creates.
#[derive(Clone)]
struct AccessInfo {
    /// Name of the signed session cookie.
    cookie_name: String,
    /// Token that must appear in the URI query to open a new session, if any.
    token_config: Option<TokenConfig>,
    /// Lifetime given to newly issued session cookies, if any.
    cookie_expires: Option<std::time::Duration>,
    /// Key used to sign and verify the session cookie.
    key: tower_cookies::Key,
}
impl AccessInfo {
    /// Copies the borrowed cookie name and moves the remaining settings out of `cfg`.
    fn new(cfg: AuthConfig<'_>) -> Self {
        Self {
            cookie_name: cfg.cookie_name.to_owned(),
            token_config: cfg.token_config,
            cookie_expires: cfg.cookie_expires,
            key: cfg.persistent_secret,
        }
    }

    /// Decides whether `req` may proceed.
    ///
    /// - If the cookie yielded a valid session key, it is reused: `Ok((false, key))`.
    /// - Otherwise, if the URI carries the configured token (or no token is
    ///   configured), a new random key is issued: `Ok((true, key))`.
    /// - Otherwise the request is rejected with a [`ValidationErrors`].
    ///
    /// The boolean tells the caller whether a new session cookie must be set.
    fn check_token_and_cookie(
        &self,
        req: &Request,
        valid_session_key: Option<SessionKey>,
    ) -> Result<(bool, SessionKey), ValidationErrors> {
        // An existing session always wins; the token is irrelevant then.
        if let Some(existing) = valid_session_key {
            return Ok((false, existing));
        }

        let token_accepted = match &self.token_config {
            Some(token) => token.parse_token_from_uri_query(req),
            None => true,
        };

        if token_accepted {
            Ok((true, SessionKey::default()))
        } else {
            Err(ValidationErrors(vec![
                "No (valid) token in uri and no (valid) session.".to_string(),
            ]))
        }
    }
}
/// Extracts the [`SessionKey`] that the authentication middleware stored in
/// the request extensions.
///
/// The extension is read, not removed, so the key can be extracted more than
/// once while handling the same request. Requests that did not pass through
/// the middleware are rejected with `401 Unauthorized`.
#[async_trait]
impl<S> FromRequestParts<S> for SessionKey
where
    S: Send + Sync,
{
    type Rejection = (StatusCode, &'static str);

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        parts
            .extensions
            .get::<SessionKey>()
            .cloned()
            .ok_or((StatusCode::UNAUTHORIZED, "(valid) session key is missing"))
    }
}
/// Tower [`Layer`] that wraps a service in [`AuthMiddleware`].
///
/// Outside of the middleware it adds a `tower_cookies::CookieManager`, which
/// supplies the cookie jar the middleware reads and writes. Create it with
/// [`AuthConfig::into_layer`].
#[derive(Clone)]
pub struct AuthLayer {
    access_info: AccessInfo,
}
impl<S> Layer<S> for AuthLayer {
    type Service = tower_cookies::CookieManager<AuthMiddleware<S>>;

    fn layer(&self, inner: S) -> Self::Service {
        tower_cookies::CookieManager::new(AuthMiddleware {
            inner,
            access_info: self.access_info.clone(),
        })
    }
}
#[derive(Clone)]
pub struct AuthMiddleware<S> {
inner: S,
access_info: AccessInfo,
}
impl<S> Service<Request> for AuthMiddleware<S>
where
S: Service<Request, Response = Response> + Send + 'static,
S::Error: Into<BoxError>,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = BoxError;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match self.inner.poll_ready(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(r) => Poll::Ready(r.map_err(Into::into)),
}
}
fn call(&mut self, mut request: Request) -> Self::Future {
let Some(cookies) = request
.extensions()
.get::<tower_cookies::Cookies>()
.cloned()
else {
tracing::error!("missing cookies request extension");
return Box::pin(std::future::ready(Err(Box::new(ValidationErrors(vec![
"missing cookies request extension".into(),
])) as BoxError)));
};
let signed = cookies.signed(&self.access_info.key);
let err_info = {
let opt_session_key: Option<SessionKey> = signed
.get(&self.access_info.cookie_name)
.map(|received_cookie| {
SessionKey(uuid::Uuid::parse_str(received_cookie.value()).unwrap())
});
match self
.access_info
.check_token_and_cookie(&request, opt_session_key)
{
Ok((new_cookie_value, session_key)) => {
let expires = self.access_info.cookie_expires.as_ref().map(|exp| {
OffsetDateTime::now_utc()
.checked_add(Duration::try_from(*exp).unwrap())
.unwrap()
});
request.extensions_mut().insert(session_key.clone());
if new_cookie_value {
let value = format!("{}", session_key.0.as_hyphenated());
let mut set_cookie =
tower_cookies::Cookie::new(self.access_info.cookie_name.clone(), value);
if let Some(expires) = expires {
set_cookie.set_expires(expires);
}
signed.add(set_cookie);
}
None
}
Err(val_err) => Some(val_err),
}
};
let fut = match err_info {
None => self.inner.call(request),
Some(val_err) => {
return Box::pin(std::future::ready(Err(val_err.into())));
}
};
Box::pin(async move {
let response: Response = fut.await.map_err(|e| e.into())?;
Ok(response)
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use axum::body::Body;
    use cookie::Cookie;
    use http::{Request, StatusCode};
    use std::convert::Infallible;
    use tower::{ServiceBuilder, ServiceExt};

    /// Inner service: always answers with an empty `200 OK` response.
    async fn handler(_: Request<Body>) -> std::result::Result<Response<Body>, Infallible> {
        Ok(Response::new(Body::empty()))
    }

    /// Test configuration: cookie `auth`, a fresh random key, and a required
    /// `token=token_value` query parameter.
    fn get_cfg() -> AuthConfig<'static> {
        let token_config = Some(TokenConfig {
            name: "token".into(),
            value: "token_value".into(),
        });
        AuthConfig {
            cookie_name: "auth",
            persistent_secret: Key::generate(),
            token_config,
            cookie_expires: None,
        }
    }

    /// Sends `req` through a middleware stack built from `cfg`. Returns the
    /// validation messages of the resulting error.
    ///
    /// Panics if the request is accepted, or fails with another error type.
    async fn rejection_messages(cfg: AuthConfig<'_>, req: Request<Body>) -> Vec<String> {
        let svc = ServiceBuilder::new()
            .layer(cfg.into_layer())
            .service_fn(handler);
        let Err(err) = svc.oneshot(req).await else {
            panic!("request was expected to be rejected");
        };
        err.downcast::<ValidationErrors>()
            .expect("rejection should be a ValidationErrors")
            .errors()
            .map(str::to_owned)
            .collect()
    }

    #[tokio::test]
    async fn fail_without_token_or_cookie() -> Result<()> {
        let req = Request::builder().body(Body::empty())?;
        let messages = rejection_messages(get_cfg(), req).await;
        assert_eq!(
            messages,
            ["No (valid) token in uri and no (valid) session."]
        );
        Ok(())
    }

    #[tokio::test]
    async fn fail_with_wrong_token() -> Result<()> {
        let req = Request::builder()
            .uri("http://example.com/path?token=not_the_token_value")
            .body(Body::empty())?;
        let messages = rejection_messages(get_cfg(), req).await;
        assert!(!messages.is_empty());
        Ok(())
    }

    /// Sends `req` (which must open a session) and checks that exactly one
    /// cookie is set. Then replays that cookie in a second request and returns
    /// the second response.
    async fn get_second_response(
        cfg: AuthConfig<'_>,
        req: Request<Body>,
    ) -> Result<Response<Body>> {
        let auth_layer = cfg.into_layer();
        let svc = ServiceBuilder::new().layer(auth_layer).service_fn(handler);
        let res = svc.clone().oneshot(req).await.unwrap();
        let cookie = {
            let set_cookie: Vec<_> = res
                .headers()
                .get_all(http::header::SET_COOKIE)
                .iter()
                .collect();
            assert_eq!(set_cookie.len(), 1);
            Cookie::parse(set_cookie[0].to_str()?.to_string())?
        };
        let req2 = Request::builder()
            .header(http::header::COOKIE, cookie.stripped().to_string())
            .body(Body::empty())
            .unwrap();
        let res2 = svc.oneshot(req2).await.unwrap();
        // The existing session is reused, so no new cookie is issued.
        assert!(res2.headers().get(http::header::SET_COOKIE).is_none());
        Ok(res2)
    }

    #[tokio::test]
    async fn set_cookie_with_trusted_socket() -> Result<()> {
        let mut cfg = get_cfg();
        cfg.token_config = None;
        let uri = "http://example.com/path";
        let req = Request::builder().uri(uri).body(Body::empty()).unwrap();
        let res2 = get_second_response(cfg, req).await?;
        assert_eq!(res2.status(), StatusCode::OK);
        Ok(())
    }

    #[tokio::test]
    async fn set_cookie_with_valid_token() -> Result<()> {
        let cfg = get_cfg();
        let uri = {
            let x = cfg.token_config.as_ref().unwrap();
            format!("http://example.com/path?{}={}", x.name, x.value)
        };
        let req = Request::builder().uri(uri).body(Body::empty()).unwrap();
        let res2 = get_second_response(cfg, req).await?;
        assert_eq!(res2.status(), StatusCode::OK);
        Ok(())
    }
}