pub mod default;
mod layer;
use std::sync::Arc;
use async_trait::async_trait;
pub use layer::AuthLayer;
pub use tonin_client::auth::{AuthCtx, AuthError, PrincipalKind, RawToken};
/// Pulls the raw, not-yet-verified credential out of gRPC request metadata.
///
/// The `Send + Sync + 'static` bounds let one extractor instance be shared
/// freely between threads and tasks.
pub trait TokenExtractor: Send + Sync + 'static {
    /// Locates a token inside `metadata` and returns it unverified.
    ///
    /// # Errors
    ///
    /// Returns an [`AuthError`] when no usable token can be extracted; the
    /// exact variant is up to the implementation.
    fn extract(&self, metadata: &tonic::metadata::MetadataMap) -> Result<RawToken, AuthError>;
}
/// Validates a [`RawToken`] and turns it into an authenticated [`AuthCtx`].
///
/// `#[async_trait]` boxes the returned future, which keeps the trait
/// object-safe so verifiers can be stored as `Arc<dyn TokenVerifier>`.
#[async_trait]
pub trait TokenVerifier: Send + Sync + 'static {
    /// Checks `token` and, on success, returns the resulting context.
    ///
    /// # Errors
    ///
    /// Returns an [`AuthError`] describing why the token was rejected.
    async fn verify(&self, token: &RawToken) -> Result<AuthCtx, AuthError>;
}
/// Produces an [`AuthCtx`] representing the service itself; used by
/// [`service_token`].
///
/// NOTE(review): presumably the returned context carries a freshly minted
/// service credential for outbound calls — confirm against implementations.
#[async_trait]
pub trait ServiceTokenMinter: Send + Sync + 'static {
    /// Mints a new service context.
    ///
    /// # Errors
    ///
    /// Returns an [`AuthError`] if minting fails.
    async fn mint(&self) -> Result<AuthCtx, AuthError>;
}
/// A [`TokenVerifier`] that consults a sequence of verifiers in insertion order.
///
/// The first verifier that returns `Ok` wins, and later verifiers are not
/// called. If every verifier rejects the token, the error from the *last*
/// verifier is returned. An empty chain rejects every token with
/// [`AuthError::MissingToken`].
pub struct ChainVerifier {
    /// Verifiers in the order they were added via [`ChainVerifier::add`].
    inner: Vec<Arc<dyn TokenVerifier>>,
}

impl ChainVerifier {
    /// Creates an empty chain, which rejects everything until verifiers are added.
    #[must_use]
    pub fn new() -> Self {
        Self { inner: Vec::new() }
    }

    /// Appends `v` to the end of the chain and returns the chain (builder style).
    // `add` is a builder step, not arithmetic; implementing `std::ops::Add`
    // for this would be misleading, so clippy's suggestion is silenced.
    #[allow(clippy::should_implement_trait)]
    #[must_use = "`add` returns the extended chain; dropping it discards the verifier"]
    pub fn add<V: TokenVerifier>(mut self, v: V) -> Self {
        self.inner.push(Arc::new(v));
        self
    }
}

impl Default for ChainVerifier {
    /// Equivalent to [`ChainVerifier::new`]: an empty chain.
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl TokenVerifier for ChainVerifier {
    /// Tries each verifier sequentially, returning the first success.
    ///
    /// # Errors
    ///
    /// If every verifier fails, returns the error from the last one tried;
    /// an empty chain returns [`AuthError::MissingToken`].
    async fn verify(&self, token: &RawToken) -> Result<AuthCtx, AuthError> {
        // The seed value is only observable when the chain is empty.
        let mut last_err = AuthError::MissingToken;
        for verifier in &self.inner {
            match verifier.verify(token).await {
                Ok(ctx) => return Ok(ctx),
                Err(err) => last_err = err,
            }
        }
        Err(last_err)
    }
}
/// A [`TokenVerifier`] that accepts any token and yields [`AuthCtx::anonymous`].
///
/// The token's contents are never inspected.
pub(crate) struct AnonymousVerifier;

#[async_trait]
impl TokenVerifier for AnonymousVerifier {
    /// Ignores `_token` and always succeeds with an anonymous context.
    async fn verify(&self, _token: &RawToken) -> Result<AuthCtx, AuthError> {
        let anonymous = AuthCtx::anonymous();
        Ok(anonymous)
    }
}
tokio::task_local! {
    /// Authentication context bound to the currently running tokio task.
    ///
    /// A value is present only inside a `CURRENT_AUTH.scope(ctx, fut)` call
    /// (NOTE(review): presumably installed by the auth layer — confirm).
    /// Read it with [`current`], which falls back to an anonymous context
    /// when unset.
    pub static CURRENT_AUTH: AuthCtx;
}
/// Returns a clone of the [`AuthCtx`] stored in [`CURRENT_AUTH`] for this task.
///
/// When the task-local has not been set (outside any
/// `CURRENT_AUTH.scope(..)`), this returns [`AuthCtx::anonymous`] rather
/// than failing.
pub fn current() -> AuthCtx {
    match CURRENT_AUTH.try_with(AuthCtx::clone) {
        Ok(ctx) => ctx,
        // `try_with` errs only when no value is set for this task.
        Err(_) => AuthCtx::anonymous(),
    }
}
/// Mints a context for this service using a lazily built, process-wide minter.
///
/// On first use, a `default::HttpServiceTokenMinter` is constructed from the
/// environment and cached. If that construction fails, the error is returned
/// and the cell stays empty, so the next call retries construction.
///
/// # Errors
///
/// Returns an [`AuthError`] if the minter cannot be built or if minting fails.
pub async fn service_token() -> Result<AuthCtx, AuthError> {
    static MINTER: tokio::sync::OnceCell<Arc<dyn ServiceTokenMinter>> =
        tokio::sync::OnceCell::const_new();

    /// One-shot initialiser for `MINTER`.
    async fn build_minter() -> Result<Arc<dyn ServiceTokenMinter>, AuthError> {
        let http_minter = default::HttpServiceTokenMinter::from_env()?;
        Ok(Arc::new(http_minter))
    }

    MINTER.get_or_try_init(build_minter).await?.mint().await
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Dummy bearer token shared by the tests; the verifiers under test
    /// never inspect its contents.
    fn token() -> RawToken {
        RawToken {
            value: "x".into(),
            kind: "bearer-jwt",
        }
    }

    #[tokio::test]
    async fn chain_verifier_first_success_wins() {
        struct AlwaysOk(AuthCtx);
        #[async_trait]
        impl TokenVerifier for AlwaysOk {
            async fn verify(&self, _: &RawToken) -> Result<AuthCtx, AuthError> {
                Ok(self.0.clone())
            }
        }
        struct AlwaysErr;
        #[async_trait]
        impl TokenVerifier for AlwaysErr {
            async fn verify(&self, _: &RawToken) -> Result<AuthCtx, AuthError> {
                Err(AuthError::Signature)
            }
        }
        let mut ok = AuthCtx::anonymous();
        ok.subject = "alice".into();
        let chain = ChainVerifier::new().add(AlwaysErr).add(AlwaysOk(ok));
        let out = chain.verify(&token()).await.unwrap();
        assert_eq!(out.subject, "alice");
    }

    #[tokio::test]
    async fn chain_verifier_returns_last_err_when_all_fail() {
        struct ErrA;
        struct ErrB;
        #[async_trait]
        impl TokenVerifier for ErrA {
            async fn verify(&self, _: &RawToken) -> Result<AuthCtx, AuthError> {
                Err(AuthError::Signature)
            }
        }
        #[async_trait]
        impl TokenVerifier for ErrB {
            async fn verify(&self, _: &RawToken) -> Result<AuthCtx, AuthError> {
                Err(AuthError::Expired)
            }
        }
        let chain = ChainVerifier::new().add(ErrA).add(ErrB);
        let err = chain.verify(&token()).await.unwrap_err();
        // `matches!` only evaluates to a bool; it must be asserted to test anything.
        assert!(
            matches!(err, AuthError::Expired),
            "expected Expired, got {:?}",
            err
        );
    }

    #[tokio::test]
    async fn empty_chain_rejects_with_missing_token() {
        let err = ChainVerifier::new().verify(&token()).await.unwrap_err();
        assert!(
            matches!(err, AuthError::MissingToken),
            "expected MissingToken, got {:?}",
            err
        );
    }

    #[test]
    fn current_falls_back_to_anonymous_outside_scope() {
        assert_eq!(current().subject, AuthCtx::anonymous().subject);
    }

    #[tokio::test]
    async fn current_reads_task_local_scope() {
        let mut ctx = AuthCtx::anonymous();
        ctx.subject = "bob".into();
        let seen = CURRENT_AUTH.scope(ctx, async { current().subject }).await;
        assert_eq!(seen, "bob");
    }
}