use crate::marker::{MaybeSend, MaybeSync};
use alloc::collections::BTreeMap;
use alloc::string::{String, ToString};
use alloc::vec;
use alloc::vec::Vec;
use core::fmt;
use core::future::Future;
use serde::{Deserialize, Serialize};
/// A credential presented by a client, before it has been verified.
///
/// `Debug` is implemented by hand rather than derived so that secret material
/// (tokens, API keys, passwords, custom credential values) is never written
/// verbatim into logs or error reports. Non-secret parts (the variant name,
/// the Basic username, the custom scheme) are still shown.
#[derive(Clone, PartialEq, Eq)]
pub enum Credential {
    /// An opaque bearer token.
    Bearer(String),
    /// An API key.
    ApiKey(String),
    /// A username/password pair.
    Basic {
        username: String,
        password: String,
    },
    /// A credential for some other scheme, carried as scheme name plus value.
    Custom {
        scheme: String,
        value: String,
    },
}

impl fmt::Debug for Credential {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        /// Placeholder debug-printed (without quotes) in place of a secret.
        struct Redacted;
        impl fmt::Debug for Redacted {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.write_str("<redacted>")
            }
        }

        match self {
            Self::Bearer(_) => f.debug_tuple("Bearer").field(&Redacted).finish(),
            Self::ApiKey(_) => f.debug_tuple("ApiKey").field(&Redacted).finish(),
            Self::Basic { username, .. } => f
                .debug_struct("Basic")
                .field("username", username)
                .field("password", &Redacted)
                .finish(),
            Self::Custom { scheme, .. } => f
                .debug_struct("Custom")
                .field("scheme", scheme)
                .field("value", &Redacted)
                .finish(),
        }
    }
}

impl Credential {
    /// Creates a [`Credential::Bearer`] from `token`.
    pub fn bearer(token: impl Into<String>) -> Self {
        Self::Bearer(token.into())
    }

    /// Creates a [`Credential::ApiKey`] from `key`.
    pub fn api_key(key: impl Into<String>) -> Self {
        Self::ApiKey(key.into())
    }

    /// Creates a [`Credential::Basic`] from a username and password.
    pub fn basic(username: impl Into<String>, password: impl Into<String>) -> Self {
        Self::Basic {
            username: username.into(),
            password: password.into(),
        }
    }

    /// Creates a [`Credential::Custom`] with the given scheme name and value.
    pub fn custom(scheme: impl Into<String>, value: impl Into<String>) -> Self {
        Self::Custom {
            scheme: scheme.into(),
            value: value.into(),
        }
    }

    /// Returns `true` if this is a [`Credential::Bearer`].
    pub fn is_bearer(&self) -> bool {
        matches!(self, Self::Bearer(_))
    }

    /// Returns the bearer token, or `None` for any other variant.
    pub fn as_bearer(&self) -> Option<&str> {
        match self {
            Self::Bearer(token) => Some(token.as_str()),
            _ => None,
        }
    }
}
/// An authenticated identity, as produced by an [`Authenticator`].
///
/// Only `subject` is mandatory. Optional fields are omitted from serialized
/// output when unset, and `roles` / `claims` are omitted when empty (and
/// default to empty when absent on deserialization).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Principal {
    /// Identifier of the authenticated subject.
    pub subject: String,
    /// Issuer of the identity, if known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub issuer: Option<String>,
    /// Intended audience, if known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub audience: Option<String>,
    /// Expiry as whole seconds since the Unix epoch (`is_expired` compares it
    /// against `SystemTime::now()` measured from `UNIX_EPOCH`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub expires_at: Option<u64>,
    /// Email address, if provided.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub email: Option<String>,
    /// Display name, if provided.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Role names; matched exactly by `has_role` / `has_any_role`.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub roles: Vec<String>,
    /// Arbitrary extra claims keyed by claim name.
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub claims: BTreeMap<String, serde_json::Value>,
}
impl Principal {
    /// Creates a principal for `subject` with every optional field unset and
    /// no roles or claims.
    pub fn new(subject: impl Into<String>) -> Self {
        Self {
            subject: subject.into(),
            issuer: None,
            audience: None,
            expires_at: None,
            email: None,
            name: None,
            roles: Vec::new(),
            claims: BTreeMap::new(),
        }
    }

    /// Sets the issuer.
    pub fn with_issuer(mut self, issuer: impl Into<String>) -> Self {
        self.issuer = Some(issuer.into());
        self
    }

    /// Sets the audience.
    pub fn with_audience(mut self, audience: impl Into<String>) -> Self {
        self.audience = Some(audience.into());
        self
    }

    /// Sets the expiry, in whole seconds since the Unix epoch.
    pub fn with_expires_at(mut self, expires_at: u64) -> Self {
        self.expires_at = Some(expires_at);
        self
    }

    /// Sets the email address.
    pub fn with_email(mut self, email: impl Into<String>) -> Self {
        self.email = Some(email.into());
        self
    }

    /// Sets the display name.
    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        self.name = Some(name.into());
        self
    }

    /// Appends a single role (duplicates are not removed).
    pub fn with_role(mut self, role: impl Into<String>) -> Self {
        self.roles.push(role.into());
        self
    }

    /// Appends every role yielded by `roles` (duplicates are not removed).
    pub fn with_roles(mut self, roles: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.roles.extend(roles.into_iter().map(Into::into));
        self
    }

    /// Inserts or replaces the claim `key`.
    pub fn with_claim(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.claims.insert(key.into(), value);
        self
    }

    /// Returns `true` if `role` is among `roles` (exact, case-sensitive match).
    pub fn has_role(&self, role: &str) -> bool {
        self.roles.iter().any(|r| r == role)
    }

    /// Returns `true` if at least one of `roles` is held; `false` for an empty slice.
    pub fn has_any_role(&self, roles: &[&str]) -> bool {
        roles.iter().any(|r| self.has_role(r))
    }

    /// Returns `true` once the current time has reached `expires_at`.
    ///
    /// Per RFC 7519 §4.1.4 a token is acceptable only while the current time
    /// is strictly *before* `exp`, so `now == exp` already counts as expired.
    /// No leeway is applied here. A principal without `expires_at` never
    /// expires.
    ///
    /// If the system clock reads earlier than the Unix epoch, the current time
    /// cannot be trusted and the principal is treated as expired (fail closed).
    ///
    /// Without the `std` feature there is no clock available and this always
    /// returns `false`.
    pub fn is_expired(&self) -> bool {
        #[cfg(feature = "std")]
        {
            use std::time::{SystemTime, UNIX_EPOCH};
            match self.expires_at {
                None => false,
                Some(exp) => SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map_or(true, |now| now.as_secs() >= exp),
            }
        }
        #[cfg(not(feature = "std"))]
        {
            let _ = self.expires_at;
            false
        }
    }

    /// Looks up the claim named `key`.
    pub fn get_claim(&self, key: &str) -> Option<&serde_json::Value> {
        self.claims.get(key)
    }
}
impl fmt::Display for Principal {
    /// Renders only the subject, wrapped as `Principal(<subject>)`; all other
    /// fields (roles, claims, email, ...) are deliberately left out.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { subject, .. } = self;
        write!(f, "Principal({})", subject)
    }
}
/// Reasons an authentication attempt can fail.
#[derive(Debug, Clone)]
pub enum AuthError {
    /// No credential was supplied.
    MissingCredentials,
    /// A credential was supplied but could not be parsed; carries a detail message.
    InvalidCredentialFormat(String),
    /// The credential's type is not handled.
    UnsupportedCredentialType,
    /// The token is past its expiry.
    TokenExpired,
    /// The token's signature did not verify.
    InvalidSignature,
    /// The token's claims were rejected; carries a detail message.
    InvalidClaims(String),
    /// The issuer did not match the expected value.
    InvalidIssuer {
        expected: String,
        actual: String,
    },
    /// The audience did not match the expected value.
    InvalidAudience {
        expected: String,
        actual: String,
    },
    /// No key was found for the given key id.
    KeyNotFound(String),
    /// Retrieving keys failed; carries a detail message.
    KeyFetchError(String),
    /// Any other internal failure; carries a detail message.
    Internal(String),
}

impl fmt::Display for AuthError {
    /// Writes a human-readable, single-line description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Variants without payload map to fixed messages.
            Self::MissingCredentials => f.write_str("No credentials provided"),
            Self::UnsupportedCredentialType => f.write_str("Unsupported credential type"),
            Self::TokenExpired => f.write_str("Token has expired"),
            Self::InvalidSignature => f.write_str("Invalid token signature"),
            // Single-payload variants append their detail.
            Self::InvalidCredentialFormat(detail) => {
                write!(f, "Invalid credential format: {}", detail)
            }
            Self::InvalidClaims(detail) => write!(f, "Invalid claims: {}", detail),
            Self::KeyNotFound(key_id) => write!(f, "Key not found: {}", key_id),
            Self::KeyFetchError(detail) => write!(f, "Failed to fetch keys: {}", detail),
            Self::Internal(detail) => write!(f, "Internal auth error: {}", detail),
            // Mismatch variants report both sides.
            Self::InvalidIssuer { expected, actual } => write!(
                f,
                "Invalid issuer: expected '{}', got '{}'",
                expected, actual
            ),
            Self::InvalidAudience { expected, actual } => write!(
                f,
                "Invalid audience: expected '{}', got '{}'",
                expected, actual
            ),
        }
    }
}

#[cfg(feature = "std")]
impl std::error::Error for AuthError {}
/// Verifies a [`Credential`] and resolves it to a [`Principal`].
///
/// Implementors must be `Clone` and satisfy the crate's `MaybeSend` /
/// `MaybeSync` marker bounds.
pub trait Authenticator: MaybeSend + MaybeSync + Clone {
    /// Error returned when authentication fails.
    type Error: fmt::Debug + fmt::Display + MaybeSend;
    /// Authenticates `credential`, resolving to the matching principal or an error.
    ///
    /// The returned future must satisfy `MaybeSend`.
    fn authenticate(
        &self,
        credential: &Credential,
    ) -> impl Future<Output = Result<Principal, Self::Error>> + MaybeSend;
}
/// Pulls a [`Credential`] out of request metadata.
pub trait CredentialExtractor: MaybeSend + MaybeSync {
    /// Reads request headers through `get_header` (called with a header name,
    /// returning its value if present) and returns the credential found, or
    /// `None` if there is none.
    fn extract<F>(&self, get_header: F) -> Option<Credential>
    where
        F: Fn(&str) -> Option<String>;
}
/// [`CredentialExtractor`] that reads standard HTTP authentication headers.
///
/// Lookup order:
/// 1. `authorization` with scheme `Bearer` → [`Credential::Bearer`];
///    scheme `Basic` (only with the `std` feature; base64 `user:pass`) →
///    [`Credential::Basic`]; scheme `ApiKey` → [`Credential::ApiKey`].
///    The scheme name is matched ASCII-case-insensitively (RFC 7235 §2.1).
/// 2. `x-api-key` → [`Credential::ApiKey`], unless its value is blank.
///
/// Header names are requested in lowercase; `get_header` is presumably
/// expected to match names case-insensitively — confirm with callers.
#[derive(Debug, Clone, Copy, Default)]
pub struct HeaderExtractor;
impl CredentialExtractor for HeaderExtractor {
    fn extract<F>(&self, get_header: F) -> Option<Credential>
    where
        F: Fn(&str) -> Option<String>,
    {
        if let Some(auth) = get_header("authorization") {
            // Split `<scheme> <params>` at the first space/tab. Auth-scheme
            // names are case-insensitive (RFC 7235 §2.1), so compare with
            // `eq_ignore_ascii_case` instead of a fixed set of spellings.
            if let Some((scheme, params)) = auth.trim().split_once([' ', '\t']) {
                let params = params.trim();
                if scheme.eq_ignore_ascii_case("bearer") {
                    return Some(Credential::Bearer(params.to_string()));
                }
                #[cfg(feature = "std")]
                if scheme.eq_ignore_ascii_case("basic") {
                    use base64::Engine;
                    // A malformed Basic payload is not an error here: fall
                    // through so the `x-api-key` header can still be used.
                    if let Ok(decoded) = base64::engine::general_purpose::STANDARD.decode(params)
                        && let Ok(decoded_str) = String::from_utf8(decoded)
                        && let Some((username, password)) = decoded_str.split_once(':')
                    {
                        return Some(Credential::Basic {
                            username: username.to_string(),
                            password: password.to_string(),
                        });
                    }
                }
                if scheme.eq_ignore_ascii_case("apikey") {
                    return Some(Credential::ApiKey(params.to_string()));
                }
            }
        }
        // A blank `x-api-key` carries no credential; report it as absent
        // rather than as an empty API key.
        get_header("x-api-key")
            .map(|key| key.trim().to_string())
            .filter(|key| !key.is_empty())
            .map(Credential::ApiKey)
    }
}
/// Settings for validating JSON Web Tokens.
///
/// Construct with [`JwtConfig::new`] and refine with the chained setters.
#[derive(Debug, Clone)]
pub struct JwtConfig {
    /// Expected issuer; presumably `None` disables the issuer check — confirm
    /// against the validator that consumes this config.
    pub issuer: Option<String>,
    /// Expected audience; presumably `None` disables the audience check — confirm.
    pub audience: Option<String>,
    /// Signature algorithms that are accepted.
    pub algorithms: Vec<JwtAlgorithm>,
    /// Clock-skew tolerance in seconds, presumably applied to the time-based
    /// claim checks — TODO confirm.
    pub leeway_seconds: u64,
    /// Whether the `exp` claim is validated.
    pub validate_exp: bool,
    /// Whether the `nbf` claim is validated.
    pub validate_nbf: bool,
}
impl JwtConfig {
    /// Returns the baseline policy: no issuer or audience set, RS256 and
    /// ES256 accepted, 60 seconds of leeway, and both `exp` and `nbf`
    /// validation switched on.
    // NOTE(review): the lint is silenced instead of implementing `Default`;
    // confirm whether omitting `Default` is deliberate.
    #[allow(clippy::new_without_default)]
    pub fn new() -> Self {
        let algorithms = vec![JwtAlgorithm::RS256, JwtAlgorithm::ES256];
        Self {
            algorithms,
            leeway_seconds: 60,
            validate_exp: true,
            validate_nbf: true,
            issuer: None,
            audience: None,
        }
    }

    /// Sets the expected issuer.
    pub fn issuer(self, issuer: impl Into<String>) -> Self {
        Self {
            issuer: Some(issuer.into()),
            ..self
        }
    }

    /// Sets the expected audience.
    pub fn audience(self, audience: impl Into<String>) -> Self {
        Self {
            audience: Some(audience.into()),
            ..self
        }
    }

    /// Replaces the list of accepted algorithms.
    pub fn algorithms(self, algorithms: Vec<JwtAlgorithm>) -> Self {
        Self { algorithms, ..self }
    }

    /// Sets the clock-skew tolerance in seconds.
    pub fn leeway_seconds(self, seconds: u64) -> Self {
        Self {
            leeway_seconds: seconds,
            ..self
        }
    }

    /// Turns off `exp` validation.
    pub fn skip_exp_validation(self) -> Self {
        Self {
            validate_exp: false,
            ..self
        }
    }

    /// Turns off `nbf` validation.
    pub fn skip_nbf_validation(self) -> Self {
        Self {
            validate_nbf: false,
            ..self
        }
    }
}
/// JWS signature algorithms understood by this module.
///
/// Serialized as the upper-cased variant name (e.g. `"RS256"`).
/// `HS*` are HMAC-based (symmetric); `RS*` and `ES*` are asymmetric — see
/// `is_symmetric` / `is_asymmetric`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "UPPERCASE")]
pub enum JwtAlgorithm {
    HS256,
    HS384,
    HS512,
    RS256,
    RS384,
    RS512,
    ES256,
    ES384,
}
impl JwtAlgorithm {
    /// Returns the canonical identifier, e.g. `"RS256"`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::HS256 => "HS256",
            Self::HS384 => "HS384",
            Self::HS512 => "HS512",
            Self::RS256 => "RS256",
            Self::RS384 => "RS384",
            Self::RS512 => "RS512",
            Self::ES256 => "ES256",
            Self::ES384 => "ES384",
        }
    }

    /// Parses an identifier via the `FromStr` impl, discarding the error.
    pub fn parse(s: &str) -> Option<Self> {
        <Self as core::str::FromStr>::from_str(s).ok()
    }

    /// Returns `true` for the RSA (`RS*`) and ECDSA (`ES*`) variants.
    pub fn is_asymmetric(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::RS256 | Self::RS384 | Self::RS512 | Self::ES256 | Self::ES384 => true,
            Self::HS256 | Self::HS384 | Self::HS512 => false,
        }
    }

    /// Returns `true` for the HMAC (`HS*`) variants.
    pub fn is_symmetric(&self) -> bool {
        match self {
            Self::HS256 | Self::HS384 | Self::HS512 => true,
            Self::RS256 | Self::RS384 | Self::RS512 | Self::ES256 | Self::ES384 => false,
        }
    }
}
impl fmt::Display for JwtAlgorithm {
    /// Writes the same identifier that `as_str` returns.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = self.as_str();
        write!(f, "{}", name)
    }
}
impl core::str::FromStr for JwtAlgorithm {
    type Err = ();

    /// Parses an algorithm identifier such as `"RS256"`, ignoring ASCII case.
    ///
    /// Case folding is ASCII-only on purpose. Full Unicode upper-casing would
    /// accept look-alike input: `"rſ256"` (U+017F LONG S upper-cases to `S`)
    /// would parse as `RS256`. The comparison also needs no allocation, so it
    /// works without `alloc`.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` if `s` names no known algorithm.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        [
            Self::HS256,
            Self::HS384,
            Self::HS512,
            Self::RS256,
            Self::RS384,
            Self::RS512,
            Self::ES256,
            Self::ES384,
        ]
        .into_iter()
        .find(|alg| alg.as_str().eq_ignore_ascii_case(s))
        .ok_or(())
    }
}
/// The JWT claims with dedicated fields; names mirror the RFC 7519
/// registered claim names. Unset fields are omitted when serialized.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct StandardClaims {
    /// Subject.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sub: Option<String>,
    /// Issuer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub iss: Option<String>,
    /// Audience: a single string or a list (see [`Audience`]).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub aud: Option<Audience>,
    /// Expiration time; presumably seconds since the Unix epoch — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exp: Option<u64>,
    /// Not-before time; presumably seconds since the Unix epoch — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nbf: Option<u64>,
    /// Issued-at time; presumably seconds since the Unix epoch — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub iat: Option<u64>,
    /// Token identifier.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jti: Option<String>,
}
/// The `aud` claim, which may be a single string or an array of strings.
///
/// `untagged` makes serde accept (and emit) either shape directly, with no
/// wrapper or tag.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Audience {
    /// A single audience string.
    Single(String),
    /// A list of audience strings.
    Multiple(Vec<String>),
}
impl Audience {
    /// Returns `true` if `expected` exactly equals the single audience or any
    /// entry of the list.
    pub fn contains(&self, expected: &str) -> bool {
        // View both shapes uniformly as a slice.
        let entries: &[String] = match self {
            Self::Single(only) => core::slice::from_ref(only),
            Self::Multiple(all) => all,
        };
        entries.iter().any(|entry| entry == expected)
    }

    /// Returns the audience value(s) as an owned list; `Single` yields one element.
    pub fn to_vec(&self) -> Vec<String> {
        let entries: &[String] = match self {
            Self::Single(only) => core::slice::from_ref(only),
            Self::Multiple(all) => all,
        };
        entries.to_vec()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a header lookup that answers only `name`, with `value`.
    fn single_header(
        name: &'static str,
        value: &'static str,
    ) -> impl Fn(&str) -> Option<String> {
        move |requested: &str| (requested == name).then(|| value.to_string())
    }

    #[test]
    fn test_credential_constructors() {
        let bearer = Credential::bearer("token123");
        assert!(bearer.is_bearer());
        assert_eq!(bearer.as_bearer(), Some("token123"));

        let api_key = Credential::api_key("key456");
        assert!(!api_key.is_bearer());
        assert_eq!(api_key.as_bearer(), None);

        assert!(!Credential::basic("user", "pass").is_bearer());
    }

    #[test]
    fn test_principal_builder() {
        let principal = Principal::new("user123")
            .with_issuer("https://auth.example.com")
            .with_audience("my-api")
            .with_email("user@example.com")
            .with_role("admin")
            .with_role("user");

        assert_eq!(principal.subject, "user123");
        assert_eq!(principal.issuer.as_deref(), Some("https://auth.example.com"));
        for held in ["admin", "user"] {
            assert!(principal.has_role(held));
        }
        assert!(!principal.has_role("guest"));
        assert!(principal.has_any_role(&["admin", "guest"]));
    }

    #[test]
    fn test_header_extractor_bearer() {
        let cred = HeaderExtractor.extract(single_header("authorization", "Bearer my-token"));
        assert_eq!(cred, Some(Credential::Bearer("my-token".to_string())));
    }

    #[test]
    fn test_header_extractor_api_key() {
        let from_header = HeaderExtractor.extract(single_header("x-api-key", "my-api-key"));
        assert_eq!(from_header, Some(Credential::ApiKey("my-api-key".to_string())));

        let from_auth =
            HeaderExtractor.extract(single_header("authorization", "ApiKey another-key"));
        assert_eq!(from_auth, Some(Credential::ApiKey("another-key".to_string())));
    }

    #[test]
    fn test_jwt_algorithm() {
        let rs = JwtAlgorithm::RS256;
        let hs = JwtAlgorithm::HS256;
        assert_eq!(rs.as_str(), "RS256");
        assert!(rs.is_asymmetric() && !rs.is_symmetric());
        assert!(hs.is_symmetric() && !hs.is_asymmetric());
        assert_eq!(JwtAlgorithm::parse("es256"), Some(JwtAlgorithm::ES256));
        assert_eq!(JwtAlgorithm::parse("unknown"), None);
    }

    #[test]
    fn test_audience() {
        let single = Audience::Single("my-api".to_string());
        assert!(single.contains("my-api"));
        assert!(!single.contains("other"));

        let multiple = Audience::Multiple(vec!["api1".to_string(), "api2".to_string()]);
        for member in ["api1", "api2"] {
            assert!(multiple.contains(member));
        }
        assert!(!multiple.contains("api3"));
    }

    #[test]
    fn test_jwt_config_builder() {
        let config = JwtConfig::new()
            .issuer("https://auth.example.com")
            .audience("my-api")
            .algorithms(vec![JwtAlgorithm::RS256])
            .leeway_seconds(120);

        assert_eq!(config.issuer.as_deref(), Some("https://auth.example.com"));
        assert_eq!(config.audience.as_deref(), Some("my-api"));
        assert_eq!(config.algorithms, vec![JwtAlgorithm::RS256]);
        assert_eq!(config.leeway_seconds, 120);
    }
}