use std::pin::Pin;
use std::sync::Arc;
use async_trait::async_trait;
use base64::Engine;
use serde::{Deserialize, Serialize};
use crate::web::{HeaderMap, header};
/// Authenticated principal returned by an [`Authentication`] backend.
///
/// Serializable so it can be carried between components; both privilege
/// flags default to `false` when absent from serialized input, mirroring
/// what [`Identity::user`] produces.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Identity {
    /// Identifier of the authenticated user, stored as a string.
    pub user_id: String,
    /// Whether the user has staff privileges (defaults to `false` when missing,
    /// consistent with `is_superuser`).
    #[serde(default)]
    pub is_staff: bool,
    /// Whether the user has superuser privileges (defaults to `false` when missing).
    #[serde(default)]
    pub is_superuser: bool,
    /// Arbitrary additional key/value data attached to the identity.
    /// Defaults to empty when missing and is omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "std::collections::HashMap::is_empty")]
    pub extras: std::collections::HashMap<String, serde_json::Value>,
}
impl Identity {
    /// Creates a plain identity for `user_id`: not staff, not superuser, no extras.
    pub fn user(user_id: impl ToString) -> Self {
        Self {
            user_id: user_id.to_string(),
            extras: std::collections::HashMap::new(),
            is_superuser: false,
            is_staff: false,
        }
    }

    /// Builder shorthand that marks the identity as staff.
    pub fn staff(self) -> Self {
        self.with_staff(true)
    }

    /// Builder that sets the staff flag to `is_staff`.
    pub fn with_staff(self, is_staff: bool) -> Self {
        Self { is_staff, ..self }
    }

    /// Builder that sets the superuser flag to `is_superuser`.
    pub fn with_superuser(self, is_superuser: bool) -> Self {
        Self {
            is_superuser,
            ..self
        }
    }

    /// Builder that stores `value` under `key` in `extras`, replacing any
    /// previous value for the same key.
    pub fn with_extra(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        let key: String = key.into();
        self.extras.insert(key, value);
        self
    }
}
/// A pluggable backend that turns request headers into an [`Identity`].
///
/// Implementors must be `Send + Sync + 'static` so they can be shared
/// (e.g. behind an `Arc`) across tasks.
#[async_trait]
pub trait Authentication: Send + Sync + 'static {
    /// Attempts to authenticate a request from its headers.
    ///
    /// Returns `None` when this backend cannot identify the caller.
    async fn authenticate(&self, headers: &HeaderMap) -> Option<Identity>;

    /// The single security scheme this backend advertises, as a
    /// `(name, definition)` pair. Defaults to `None`.
    ///
    /// NOTE(review): presumably an OpenAPI-style security scheme object —
    /// confirm against the consumer of this value.
    fn security_scheme(&self) -> Option<(String, serde_json::Value)> {
        None
    }

    /// Every security scheme this backend advertises.
    ///
    /// By default this is just [`Authentication::security_scheme`], if any.
    fn security_schemes_all(&self) -> Vec<(String, serde_json::Value)> {
        match self.security_scheme() {
            Some(scheme) => vec![scheme],
            None => Vec::new(),
        }
    }

    /// Whether this backend is an "anonymous" placeholder. Defaults to `false`.
    fn is_anonymous(&self) -> bool {
        false
    }
}
/// Authentication backend that never authenticates any request and reports
/// itself as anonymous.
#[derive(Clone, Copy, Debug, Default)]
pub struct NoAuthentication;
#[async_trait]
impl Authentication for NoAuthentication {
    /// Never identifies a caller; the headers are ignored.
    async fn authenticate(&self, _headers: &HeaderMap) -> Option<Identity> {
        Option::None
    }

    /// Always `true` for this backend.
    fn is_anonymous(&self) -> bool {
        true
    }
}
/// Boxed, sendable future that resolves to an optional [`Identity`].
type AuthFuture = Pin<Box<dyn std::future::Future<Output = Option<Identity>> + Send>>;

/// Type-erased async authentication callback held by [`FnAuthentication`].
type AuthCallback = dyn Fn(HeaderMap) -> AuthFuture + Send + Sync;

/// Authentication backend driven by an arbitrary async closure.
///
/// The closure is stored behind an `Arc`, so cloning only bumps a refcount.
#[derive(Clone)]
pub struct FnAuthentication {
    f: Arc<AuthCallback>,
}

impl std::fmt::Debug for FnAuthentication {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The wrapped closure is opaque, so no fields are printed.
        let mut builder = formatter.debug_struct("FnAuthentication");
        builder.finish_non_exhaustive()
    }
}

impl FnAuthentication {
    /// Wraps `f`, which receives an owned copy of the request headers and
    /// returns a future resolving to the caller's identity (or `None`).
    pub fn new<F, Fut>(f: F) -> Self
    where
        F: Fn(HeaderMap) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = Option<Identity>> + Send + 'static,
    {
        // Box the concrete future so every closure fits the same erased type.
        let erased = move |headers: HeaderMap| -> AuthFuture { Box::pin(f(headers)) };
        Self {
            f: Arc::new(erased),
        }
    }
}
#[async_trait]
impl Authentication for FnAuthentication {
    /// Calls the wrapped closure with a cloned `HeaderMap` and awaits its result.
    async fn authenticate(&self, headers: &HeaderMap) -> Option<Identity> {
        // The closure takes headers by value, so an owned copy is required.
        let owned = headers.clone();
        let pending = (self.f)(owned);
        pending.await
    }
}
/// Authentication backend that consults a list of backends in order.
#[derive(Clone)]
pub struct ChainAuthentication {
    backends: Vec<Arc<dyn Authentication>>,
}

impl std::fmt::Debug for ChainAuthentication {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `Authentication` has no `Debug` bound, so only the backend count is shown.
        let count = self.backends.len();
        let mut builder = formatter.debug_struct("ChainAuthentication");
        builder.field("backends_count", &count);
        builder.finish()
    }
}

impl ChainAuthentication {
    /// Builds a chain that consults `backends` in the given order.
    pub fn new(backends: Vec<Arc<dyn Authentication>>) -> Self {
        ChainAuthentication { backends }
    }
}
#[async_trait]
impl Authentication for ChainAuthentication {
    /// Returns the identity from the first backend that authenticates the
    /// request; later backends are not consulted. `None` if none succeed.
    async fn authenticate(&self, headers: &HeaderMap) -> Option<Identity> {
        for backend in &self.backends {
            if let Some(id) = backend.authenticate(headers).await {
                return Some(id);
            }
        }
        None
    }

    /// The first security scheme advertised by any backend, in chain order.
    fn security_scheme(&self) -> Option<(String, serde_json::Value)> {
        self.backends.iter().find_map(|b| b.security_scheme())
    }

    /// All security schemes advertised by all backends, in chain order.
    fn security_schemes_all(&self) -> Vec<(String, serde_json::Value)> {
        self.backends
            .iter()
            .flat_map(|b| b.security_schemes_all())
            .collect()
    }

    /// A chain is anonymous exactly when every backend is.
    ///
    /// This keeps the chain consistent with how it already delegates the
    /// security-scheme queries. A chain of only anonymous backends, or an
    /// empty chain, never authenticates anyone, just like
    /// `NoAuthentication`, so it reports `true`.
    fn is_anonymous(&self) -> bool {
        self.backends.iter().all(|b| b.is_anonymous())
    }
}
/// Extracts `(username, password)` from an HTTP Basic `Authorization` header.
///
/// The auth scheme is matched case-insensitively and may be followed by one
/// or more spaces, as RFC 9110 allows (`Basic`, `basic` and `BASIC` are all
/// accepted). The credentials are split at the first `:`, so the password may
/// itself contain colons.
///
/// Returns `None` if any of these hold:
/// - the header is missing or its value cannot be read as a string;
/// - the scheme is not `Basic`;
/// - the token is not valid base64 or does not decode to UTF-8;
/// - the decoded text contains no `:`.
pub fn parse_basic_credentials(headers: &HeaderMap) -> Option<(String, String)> {
    let value = headers.get(header::AUTHORIZATION)?.to_str().ok()?;
    // credentials = auth-scheme 1*SP token68; the scheme is case-insensitive.
    let (scheme, token) = value.trim().split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Basic") {
        return None;
    }
    let decoded = base64::engine::general_purpose::STANDARD
        .decode(token.trim_start_matches(' '))
        .ok()?;
    let decoded = String::from_utf8(decoded).ok()?;
    let (user, pass) = decoded.split_once(':')?;
    Some((user.to_owned(), pass.to_owned()))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::web::header::AUTHORIZATION;

    /// Builds a `HeaderMap` containing a single header.
    fn headers_with(name: &str, value: &str) -> HeaderMap {
        let mut h = HeaderMap::new();
        h.insert(
            crate::web::header::HeaderName::from_bytes(name.as_bytes()).unwrap(),
            value.parse().unwrap(),
        );
        h
    }

    /// Test backend that never authenticates but advertises one named scheme.
    struct SchemeOnly(&'static str);

    #[async_trait]
    impl Authentication for SchemeOnly {
        async fn authenticate(&self, _headers: &HeaderMap) -> Option<Identity> {
            None
        }
        fn security_scheme(&self) -> Option<(String, serde_json::Value)> {
            Some((self.0.to_string(), serde_json::Value::Null))
        }
    }

    #[tokio::test]
    async fn no_authentication_always_returns_none() {
        let headers = HeaderMap::new();
        assert!(NoAuthentication.authenticate(&headers).await.is_none());
    }

    #[test]
    fn no_authentication_is_anonymous_without_schemes() {
        assert!(NoAuthentication.is_anonymous());
        assert!(NoAuthentication.security_scheme().is_none());
        assert!(NoAuthentication.security_schemes_all().is_empty());
    }

    #[tokio::test]
    async fn fn_authentication_invokes_closure() {
        let auth = FnAuthentication::new(|_headers| async move { Some(Identity::user(42)) });
        let id = auth.authenticate(&HeaderMap::new()).await.unwrap();
        assert_eq!(id.user_id, "42");
        assert!(!id.is_staff);
    }

    #[test]
    fn fn_authentication_is_not_anonymous() {
        let auth = FnAuthentication::new(|_| async move { None });
        assert!(!auth.is_anonymous());
    }

    #[test]
    fn identity_builders_set_flags_and_extras() {
        let id = Identity::user("bob")
            .with_staff(true)
            .with_superuser(true)
            .with_extra("tenant", serde_json::Value::String("acme".to_string()));
        assert_eq!(id.user_id, "bob");
        assert!(id.is_staff);
        assert!(id.is_superuser);
        assert_eq!(
            id.extras.get("tenant"),
            Some(&serde_json::Value::String("acme".to_string()))
        );

        let plain = Identity::user("carol");
        assert!(!plain.is_staff);
        assert!(!plain.is_superuser);
        assert!(plain.extras.is_empty());
        assert!(plain.staff().is_staff);
    }

    #[tokio::test]
    async fn chain_authentication_first_success_wins() {
        let first = FnAuthentication::new(|_| async move { None });
        let second = FnAuthentication::new(|_| async move { Some(Identity::user(7).staff()) });
        let third = FnAuthentication::new(|_| async move { Some(Identity::user(99)) });
        let chain = ChainAuthentication::new(vec![
            Arc::new(first) as Arc<dyn Authentication>,
            Arc::new(second) as Arc<dyn Authentication>,
            Arc::new(third) as Arc<dyn Authentication>,
        ]);
        let id = chain.authenticate(&HeaderMap::new()).await.unwrap();
        assert_eq!(id.user_id, "7");
        assert!(id.is_staff);
    }

    #[tokio::test]
    async fn chain_authentication_returns_none_when_all_fail() {
        let chain = ChainAuthentication::new(vec![
            Arc::new(NoAuthentication) as Arc<dyn Authentication>,
            Arc::new(NoAuthentication) as Arc<dyn Authentication>,
        ]);
        assert!(chain.authenticate(&HeaderMap::new()).await.is_none());
    }

    #[test]
    fn chain_authentication_aggregates_security_schemes_in_order() {
        let chain = ChainAuthentication::new(vec![
            Arc::new(SchemeOnly("a")) as Arc<dyn Authentication>,
            Arc::new(NoAuthentication) as Arc<dyn Authentication>,
            Arc::new(SchemeOnly("b")) as Arc<dyn Authentication>,
        ]);
        assert_eq!(chain.security_scheme().map(|(name, _)| name), Some("a".to_string()));
        let names: Vec<String> = chain
            .security_schemes_all()
            .into_iter()
            .map(|(name, _)| name)
            .collect();
        assert_eq!(names, vec!["a".to_string(), "b".to_string()]);
    }

    #[test]
    fn parse_basic_credentials_extracts_user_and_pass() {
        let headers = headers_with(AUTHORIZATION.as_str(), "Basic YWxpY2U6c2VjcmV0");
        let (user, pass) = parse_basic_credentials(&headers).unwrap();
        assert_eq!(user, "alice");
        assert_eq!(pass, "secret");
    }

    #[test]
    fn parse_basic_credentials_keeps_colons_in_password() {
        // base64("user:pa:ss")
        let headers = headers_with(AUTHORIZATION.as_str(), "Basic dXNlcjpwYTpzcw==");
        let (user, pass) = parse_basic_credentials(&headers).unwrap();
        assert_eq!(user, "user");
        assert_eq!(pass, "pa:ss");
    }

    #[test]
    fn parse_basic_credentials_returns_none_without_separator() {
        // base64("alice") — no ':' between user and password.
        let headers = headers_with(AUTHORIZATION.as_str(), "Basic YWxpY2U=");
        assert!(parse_basic_credentials(&headers).is_none());
    }

    #[test]
    fn parse_basic_credentials_returns_none_for_missing_header() {
        assert!(parse_basic_credentials(&HeaderMap::new()).is_none());
    }

    #[test]
    fn parse_basic_credentials_returns_none_for_wrong_scheme() {
        let headers = headers_with(AUTHORIZATION.as_str(), "Bearer abc");
        assert!(parse_basic_credentials(&headers).is_none());
    }

    #[test]
    fn parse_basic_credentials_returns_none_for_invalid_base64() {
        let headers = headers_with(AUTHORIZATION.as_str(), "Basic !!!notbase64");
        assert!(parse_basic_credentials(&headers).is_none());
    }
}