pub mod bearer;
pub mod mtls;
pub mod oauth;
#[cfg(feature = "jwt")]
pub mod jwt;
#[cfg(feature = "oauth-pkce")]
pub mod pkce;
#[cfg(feature = "oauth-pkce-server")]
pub mod pkce_server;
use std::collections::BTreeMap;
use std::sync::Arc;
use crate::errors::RpcError;
/// Outcome of an authentication attempt, as returned by an [`Authenticate`] callback.
///
/// The derived `Default` yields an unauthenticated context with empty strings
/// and no claims.
#[derive(Clone, Debug, Default)]
pub struct AuthContext {
    /// Label of the source that produced this context.
    /// NOTE(review): the unit tests use values such as `"bearer"` and `"mtls"`,
    /// so this presumably names the authentication mechanism — confirm.
    pub domain: String,
    /// `true` once a principal has been established for the request.
    pub authenticated: bool,
    /// Identity of the caller; empty for an anonymous context.
    pub principal: String,
    /// Additional string key/value attributes attached to the identity.
    pub claims: BTreeMap<String, String>,
}

impl AuthContext {
    /// Returns an unauthenticated context: empty `domain` and `principal`,
    /// no claims, `authenticated == false`.
    pub fn anonymous() -> Self {
        Self {
            domain: String::new(),
            authenticated: false,
            principal: String::new(),
            claims: BTreeMap::new(),
        }
    }

    /// Returns an authenticated context for `principal` within `domain`,
    /// starting with an empty claim set.
    pub fn for_principal(domain: impl Into<String>, principal: impl Into<String>) -> Self {
        Self {
            domain: domain.into(),
            principal: principal.into(),
            authenticated: true,
            ..Self::anonymous()
        }
    }

    /// Guard that succeeds only for authenticated contexts.
    ///
    /// # Errors
    ///
    /// Returns `RpcError::permission_error("authentication required")` when
    /// `authenticated` is `false`.
    pub fn require_authenticated(&self) -> crate::errors::Result<()> {
        if !self.authenticated {
            return Err(RpcError::permission_error("authentication required"));
        }
        Ok(())
    }

    /// Builder-style helper: stores `value` under `key` in `claims` and hands
    /// the context back. An existing entry for the same key is overwritten.
    pub fn with_claim(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (name, content) = (key.into(), value.into());
        self.claims.insert(name, content);
        self
    }
}
/// Borrowed view of the request data handed to an authentication callback.
#[derive(Debug)]
pub struct AuthRequest<'a> {
    /// Name of the method being invoked.
    /// NOTE(review): presumably the RPC method name (tests pass `"echo"`) — confirm.
    pub method: &'a str,
    /// Header `(name, value)` pairs; [`AuthRequest::header`] looks names up
    /// ASCII-case-insensitively.
    pub headers: &'a [(String, String)],
    /// Address of the remote peer, when one is available.
    /// NOTE(review): the textual format is not established in this file.
    pub peer_addr: Option<&'a str>,
}

impl<'a> AuthRequest<'a> {
    /// Builds a request for `method` that carries no headers and no peer address.
    pub fn anonymous_pipe(method: &'a str) -> Self {
        Self {
            peer_addr: None,
            headers: &[],
            method,
        }
    }

    /// Returns the value of the first header whose name equals `name`,
    /// ignoring ASCII case, or `None` when no such header exists.
    pub fn header(&self, name: &str) -> Option<&str> {
        self.headers.iter().find_map(|(key, value)| {
            if key.eq_ignore_ascii_case(name) {
                Some(value.as_str())
            } else {
                None
            }
        })
    }
}
/// Result of an authentication callback: an [`AuthContext`] on success or an
/// [`RpcError`] on failure.
pub type AuthResult = std::result::Result<AuthContext, RpcError>;
/// Reference-counted authentication callback that maps an [`AuthRequest`] to an
/// [`AuthResult`]. The `Send + Sync` bounds allow it to be shared across threads.
pub type Authenticate = Arc<dyn Fn(&AuthRequest<'_>) -> AuthResult + Send + Sync>;
/// Combines two authenticators into one that falls back from `a` to `b`.
///
/// The returned callback runs `a` first:
/// * if `a` fails, its error is returned and `b` is never consulted;
/// * if `a` yields an authenticated context, that context is returned;
/// * otherwise (`a` succeeded but stayed anonymous) the result of `b` is
///   returned as-is.
pub fn chain_authenticate(a: Authenticate, b: Authenticate) -> Authenticate {
    Arc::new(move |req| match a(req) {
        Ok(ctx) if ctx.authenticated => Ok(ctx),
        // `a` had no opinion about this request; let `b` decide.
        Ok(_) => b(req),
        Err(err) => Err(err),
    })
}
/// Extracts the token from an `Authorization: Bearer <token>` header.
///
/// The scheme name is matched ASCII-case-insensitively and must be followed by
/// a single space; whitespace surrounding the token is trimmed.
///
/// Returns `None` when the header is absent, uses a different scheme, or
/// carries an empty token.
///
/// Header values come from the remote peer, so this function must never
/// panic: the scheme prefix is taken with the checked `str::get`, which yields
/// `None` both when the header is too short and when byte offset
/// `PREFIX.len()` falls inside a multi-byte UTF-8 character (where plain
/// slicing would panic).
pub(crate) fn extract_bearer<'a>(req: &'a AuthRequest<'a>) -> Option<&'a str> {
    const PREFIX: &str = "Bearer ";

    let header = req.header("authorization")?;
    let scheme = header.get(..PREFIX.len())?;
    if !scheme.eq_ignore_ascii_case(PREFIX) {
        return None;
    }
    // The offset was validated as a char boundary above, so this cannot fail;
    // `get` is used anyway to keep the function free of panicking paths.
    let token = header.get(PREFIX.len()..)?.trim();
    (!token.is_empty()).then_some(token)
}
pub fn chain_all<I: IntoIterator<Item = Authenticate>>(cbs: I) -> Option<Authenticate> {
let mut it = cbs.into_iter();
let mut acc = it.next()?;
for next in it {
acc = chain_authenticate(acc, next);
}
Some(acc)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The guard must fail for an anonymous context and pass for a principal.
    #[test]
    fn require_authenticated_rejects_anonymous() {
        let anonymous = AuthContext::anonymous();
        let alice = AuthContext::for_principal("bearer", "alice");

        assert!(anonymous.require_authenticated().is_err());
        assert!(alice.require_authenticated().is_ok());
    }

    /// When the first callback stays anonymous, the second one decides.
    #[test]
    fn chain_tries_second_when_first_anonymous() {
        let anon_cb: Authenticate = Arc::new(|_| Ok(AuthContext::anonymous()));
        let alice_cb: Authenticate =
            Arc::new(|_| Ok(AuthContext::for_principal("bearer", "alice")));
        let request = AuthRequest::anonymous_pipe("echo");

        let chained = chain_authenticate(anon_cb, alice_cb);
        let outcome = chained(&request).unwrap();

        assert_eq!(outcome.principal, "alice");
    }

    /// An authenticated result from the first callback wins over the second.
    #[test]
    fn chain_uses_first_when_authenticated() {
        let bob_cb: Authenticate = Arc::new(|_| Ok(AuthContext::for_principal("mtls", "bob")));
        let alice_cb: Authenticate =
            Arc::new(|_| Ok(AuthContext::for_principal("bearer", "alice")));
        let request = AuthRequest::anonymous_pipe("echo");

        let chained = chain_authenticate(bob_cb, alice_cb);

        assert_eq!(chained(&request).unwrap().principal, "bob");
    }
}