use std::fmt;
#[cfg(feature = "auth-jwt")]
pub mod jwt;
#[cfg(feature = "auth-axum")]
pub mod axum;
#[cfg(feature = "auth-tonic")]
pub mod tonic;
#[cfg(feature = "auth-jwt")]
pub use jwt::{JwtAlgorithm, JwtConfig, JwtValidator};
#[cfg(feature = "auth-axum")]
pub use self::axum::{AuthLayer, AuthenticatedUser};
#[cfg(feature = "auth-tonic")]
pub use self::tonic::AuthInterceptor;
/// Errors reported by the authentication layer.
///
/// Every variant has a human-readable `Display` message and a numeric status
/// code available through `AuthError::status_code`.
#[derive(Debug, Clone)]
pub enum AuthError {
/// No authentication token was supplied (status code 401).
MissingToken,
/// The token was rejected as invalid; the payload is a detail message
/// appended to the `Display` output (status code 401).
InvalidToken(String),
/// The token has expired (status code 401).
TokenExpired,
/// The token signature is invalid (status code 401).
InvalidSignature,
/// The token issuer is invalid (status code 401).
InvalidIssuer,
/// The token audience is invalid (status code 401).
InvalidAudience,
/// Validation failed; the payload is a detail message appended to the
/// `Display` output (status code 403).
ValidationFailed(String),
/// An internal authentication error; the payload is a detail message
/// appended to the `Display` output (status code 500).
Internal(String),
}
impl fmt::Display for AuthError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
AuthError::MissingToken => write!(f, "missing authentication token"),
AuthError::InvalidToken(msg) => write!(f, "invalid token: {}", msg),
AuthError::TokenExpired => write!(f, "token has expired"),
AuthError::InvalidSignature => write!(f, "invalid token signature"),
AuthError::InvalidIssuer => write!(f, "invalid token issuer"),
AuthError::InvalidAudience => write!(f, "invalid token audience"),
AuthError::ValidationFailed(msg) => write!(f, "validation failed: {}", msg),
AuthError::Internal(msg) => write!(f, "internal auth error: {}", msg),
}
}
}
// Marks `AuthError` as a standard error type. The body is empty, so only the
// trait's default methods apply and no underlying source error is exposed.
impl std::error::Error for AuthError {}
impl AuthError {
    /// Returns `true` only for [`AuthError::MissingToken`].
    pub fn is_missing(&self) -> bool {
        matches!(self, AuthError::MissingToken)
    }

    /// Returns `true` only for [`AuthError::TokenExpired`].
    pub fn is_expired(&self) -> bool {
        matches!(self, AuthError::TokenExpired)
    }

    /// Maps the error to a numeric status code.
    ///
    /// The values coincide with HTTP status codes: `403` for failed
    /// validation, `500` for internal errors, and `401` for every
    /// token-related rejection.
    pub fn status_code(&self) -> u16 {
        // Kept exhaustive (no `_` arm) so a new variant forces a decision here.
        match self {
            AuthError::ValidationFailed(_) => 403,
            AuthError::Internal(_) => 500,
            AuthError::MissingToken
            | AuthError::InvalidToken(_)
            | AuthError::TokenExpired
            | AuthError::InvalidSignature
            | AuthError::InvalidIssuer
            | AuthError::InvalidAudience => 401,
        }
    }
}
/// Asynchronous authenticator that turns a token string into claims.
///
/// Implementors must be `Send + Sync`.
#[async_trait::async_trait]
pub trait Authenticator: Send + Sync {
/// Claims type returned when authentication succeeds.
type Claims: Clone + Send + Sync + 'static;
/// Authenticates `token`.
///
/// Returns the claims on success, or an `AuthError` describing why the
/// token was not accepted.
async fn authenticate(&self, token: &str) -> Result<Self::Claims, AuthError>;
}
/// Pairs a claims value of type `C` with a token string.
///
/// NOTE(review): the derived `Debug` output includes `token` verbatim; confirm
/// this type is never logged where tokens must stay secret.
#[derive(Clone, Debug)]
pub struct AuthContext<C> {
/// Claims value of the caller-chosen type `C`.
pub claims: C,
/// Token string stored alongside the claims; presumably the raw token the
/// claims were derived from (confirm against callers).
pub token: String,
}
impl<C: Clone> AuthContext<C> {
    /// Creates a context from `claims` and `token`.
    ///
    /// `token` accepts anything convertible into a `String` and is stored as
    /// an owned value.
    pub fn new(claims: C, token: impl Into<String>) -> Self {
        let token: String = token.into();
        Self { claims, token }
    }

    /// Borrows the stored claims.
    pub fn claims(&self) -> &C {
        &self.claims
    }

    /// Borrows the stored token as a string slice.
    pub fn token(&self) -> &str {
        self.token.as_str()
    }

    /// Applies `f` to the stored claims and returns whatever `f` produces.
    pub fn get<T>(&self, f: impl FnOnce(&C) -> T) -> T {
        let claims = &self.claims;
        f(claims)
    }
}
/// Extracts the token from a header value that uses the `Bearer` scheme.
///
/// Leading and trailing whitespace is trimmed from both the header value and
/// the extracted token. The scheme name is matched ASCII-case-insensitively
/// and must be followed by a space.
///
/// Returns `None` when the value does not start with `bearer ` or when no
/// token follows the scheme. This function never panics, including for
/// values that contain multi-byte UTF-8 characters.
pub fn extract_bearer_token(header_value: &str) -> Option<&str> {
    const SCHEME_PREFIX: &str = "bearer ";

    let header = header_value.trim();

    // `str::get` yields `None` instead of panicking when the header is too
    // short or when the prefix length lands inside a multi-byte character.
    let scheme = header.get(..SCHEME_PREFIX.len())?;
    if !scheme.eq_ignore_ascii_case(SCHEME_PREFIX) {
        return None;
    }

    // Slicing here is safe: the successful `get` above proved that
    // `SCHEME_PREFIX.len()` is a character boundary within `header`.
    let token = header[SCHEME_PREFIX.len()..].trim();
    if token.is_empty() {
        None
    } else {
        Some(token)
    }
}
/// Implemented by types that expose a subject string.
pub trait HasSubject {
/// Returns the subject as a borrowed string slice.
///
/// NOTE(review): presumably the identity the claims were issued for
/// (e.g. a JWT `sub` claim); confirm against implementors.
fn subject(&self) -> &str;
}
/// Implemented by types that may carry an expiration time.
pub trait HasExpiration {
    /// Returns the expiration time, or `None` when there is none.
    ///
    /// The default [`HasExpiration::is_expired`] compares this value against
    /// the current time expressed as whole seconds since the Unix epoch.
    fn expiration(&self) -> Option<i64>;

    /// Returns `true` when the expiration time lies strictly before the
    /// current system time.
    ///
    /// Values without an expiration are never considered expired, and a value
    /// whose expiration equals the current second is still not expired.
    ///
    /// This method does not panic: a system clock set before the Unix epoch
    /// is represented as a negative number of seconds rather than unwrapping
    /// the resulting error.
    fn is_expired(&self) -> bool {
        let exp = match self.expiration() {
            Some(exp) => exp,
            None => return false,
        };

        // Clamp to `i64::MAX` before casting so the `u64 -> i64` conversion
        // can never wrap around to a negative value.
        let clamp = |secs: u64| secs.min(i64::MAX as u64) as i64;

        let now = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
            Ok(since_epoch) => clamp(since_epoch.as_secs()),
            // Clock is before 1970: the error carries how far before.
            Err(before_epoch) => -clamp(before_epoch.duration().as_secs()),
        };

        exp < now
    }
}
// Unit tests for the items defined in this module.
#[cfg(test)]
mod tests {
use super::*;
// Accepts the `Bearer` scheme in any ASCII case and rejects other schemes,
// bare tokens, and headers with no token after the scheme.
#[test]
fn test_extract_bearer_token() {
assert_eq!(extract_bearer_token("Bearer abc123"), Some("abc123"));
assert_eq!(extract_bearer_token("bearer ABC"), Some("ABC"));
assert_eq!(extract_bearer_token("BEARER token"), Some("token"));
assert_eq!(extract_bearer_token("Bearer spaced"), Some("spaced"));
assert_eq!(extract_bearer_token("Basic xyz"), None);
assert_eq!(extract_bearer_token("abc123"), None);
assert_eq!(extract_bearer_token(""), None);
assert_eq!(extract_bearer_token("Bearer"), None);
assert_eq!(extract_bearer_token("Bearer "), None);
}
// Pins the `Display` text for a sample of `AuthError` variants.
#[test]
fn test_auth_error_display() {
assert_eq!(
AuthError::MissingToken.to_string(),
"missing authentication token"
);
assert_eq!(AuthError::TokenExpired.to_string(), "token has expired");
assert_eq!(
AuthError::InvalidToken("bad".into()).to_string(),
"invalid token: bad"
);
}
// Checks one representative variant for each distinct status code.
#[test]
fn test_auth_error_status_codes() {
assert_eq!(AuthError::MissingToken.status_code(), 401);
assert_eq!(AuthError::TokenExpired.status_code(), 401);
assert_eq!(AuthError::ValidationFailed("".into()).status_code(), 403);
assert_eq!(AuthError::Internal("".into()).status_code(), 500);
}
// Exercises the constructor and all three accessors of `AuthContext`.
#[test]
fn test_auth_context() {
#[derive(Clone, Debug)]
struct TestClaims {
sub: String,
role: String,
}
let ctx = AuthContext::new(
TestClaims {
sub: "user123".into(),
role: "admin".into(),
},
"token123",
);
assert_eq!(ctx.claims().sub, "user123");
assert_eq!(ctx.token(), "token123");
assert_eq!(ctx.get(|c| c.role.clone()), "admin");
}
// Each predicate is true for its own variant and false for another one.
#[test]
fn test_auth_error_predicates() {
assert!(AuthError::MissingToken.is_missing());
assert!(!AuthError::TokenExpired.is_missing());
assert!(AuthError::TokenExpired.is_expired());
assert!(!AuthError::MissingToken.is_expired());
}
// Minimal claims type used to exercise the default `HasExpiration::is_expired`.
#[derive(Clone)]
struct MockClaims {
exp: Option<i64>,
}
impl HasExpiration for MockClaims {
fn expiration(&self) -> Option<i64> {
self.exp
}
}
// Covers an expiration in the past, one in the far future, and no expiration.
#[test]
fn test_has_expiration() {
let past = MockClaims { exp: Some(0) };
assert!(past.is_expired());
let future = MockClaims {
exp: Some(i64::MAX),
};
assert!(!future.is_expired());
let no_exp = MockClaims { exp: None };
assert!(!no_exp.is_expired());
}
}