use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use tonic::{Request, Status};
/// A raw token string paired with a static label describing what kind of token it is.
///
/// NOTE(review): this type is not referenced anywhere in the visible code, and the set of
/// valid `kind` labels is not established here — confirm with callers. Also note the
/// derived `Debug` prints `value` verbatim, so avoid logging this with `{:?}` if `value`
/// holds a live credential.
#[derive(Clone, Debug)]
pub struct RawToken {
/// The token text itself.
pub value: String,
/// Static label for the token's kind (meaning of the label is defined by callers).
pub kind: &'static str,
}
/// Authentication context for a single request / principal.
///
/// Serialized with serde using the Rust field names as JSON keys (`subject`, `issuer`,
/// `audience`, `scopes`, `kind`, `raw_token`, `expires_at`, `extra`); the test module pins
/// this shape because other-language consumers read it.
///
/// NOTE(review): `raw_token` is included in both the derived `Debug` output and the
/// serialized JSON, so logging or persisting an `AuthCtx` exposes the bearer token —
/// confirm this is intended.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AuthCtx {
/// Principal identifier; empty for anonymous contexts.
pub subject: String,
/// Token issuer; empty for anonymous contexts.
pub issuer: String,
/// Token audience; empty for anonymous contexts.
pub audience: String,
/// Granted scopes, matched exactly (case-sensitive) by `require_scope`.
pub scopes: Vec<String>,
/// What sort of principal this context represents.
pub kind: PrincipalKind,
/// The original bearer token, forwarded by `propagate`; empty means "no token".
pub raw_token: String,
/// Expiry as fractional seconds since the Unix epoch; values `<= 0.0` mean "no expiry
/// known" and map to `UNIX_EPOCH` in `expires_at_systime`.
pub expires_at: f64,
/// Additional free-form claims; defaults to an empty map when absent from the input.
#[serde(default)]
pub extra: HashMap<String, serde_json::Value>,
}
/// Category of principal behind an [`AuthCtx`].
///
/// Serialized as a lowercase string (`"user"`, `"service"`, `"agent"`, `"anonymous"`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum PrincipalKind {
/// Human user; this is what `AuthCtx::from_bearer` assigns.
User,
/// Service principal (not assigned anywhere in the visible code).
Service,
/// Agent principal (not assigned anywhere in the visible code).
Agent,
/// No authenticated principal; assigned by `AuthCtx::anonymous`.
Anonymous,
}
impl AuthCtx {
    /// Returns a context representing an unauthenticated caller.
    ///
    /// All string fields are empty, `scopes` and `extra` are empty, `kind` is
    /// [`PrincipalKind::Anonymous`], and `expires_at` is `0.0` (no known expiry).
    pub fn anonymous() -> Self {
        Self {
            subject: String::new(),
            issuer: String::new(),
            audience: String::new(),
            scopes: Vec::new(),
            kind: PrincipalKind::Anonymous,
            raw_token: String::new(),
            expires_at: 0.0,
            extra: HashMap::new(),
        }
    }

    /// Builds a context that carries `token` as its bearer credential.
    ///
    /// The token is stored as-is and is NOT verified or decoded here. `kind` is set to
    /// [`PrincipalKind::User`]; every other field keeps its [`AuthCtx::anonymous`] value
    /// (empty subject/issuer/audience, no scopes, no expiry).
    pub fn from_bearer(token: impl Into<String>) -> Self {
        Self {
            raw_token: token.into(),
            kind: PrincipalKind::User,
            ..Self::anonymous()
        }
    }

    /// Returns a clone of the `AuthCtx` stored in `req`'s extensions, or
    /// [`AuthCtx::anonymous`] when none is present.
    ///
    /// NOTE(review): presumably an interceptor elsewhere inserts the verified context into
    /// the extensions — confirm.
    pub fn from<T>(req: &Request<T>) -> Self {
        req.extensions()
            .get::<AuthCtx>()
            .cloned()
            .unwrap_or_else(Self::anonymous)
    }

    /// Forwards this context's credential on an outgoing request by inserting
    /// `authorization: Bearer <raw_token>` into its metadata.
    ///
    /// Does nothing when `raw_token` is empty. This is deliberately best-effort: if
    /// `"Bearer <raw_token>"` cannot be parsed into a metadata value, the request is left
    /// unchanged instead of failing.
    pub fn propagate<T>(&self, req: &mut Request<T>) {
        if self.raw_token.is_empty() {
            return;
        }
        if let Ok(value) = format!("Bearer {}", self.raw_token).parse() {
            req.metadata_mut().insert("authorization", value);
        }
    }

    /// Succeeds when `scope` is exactly (case-sensitively) one of `self.scopes`.
    ///
    /// # Errors
    ///
    /// Returns a [`Status`] converted from [`AuthError::InsufficientScope`] (a
    /// `PERMISSION_DENIED` status) when the scope is not granted.
    // `tonic::Status` is large enough to trip clippy's `result_large_err`; the `Status`
    // error type is kept, presumably so gRPC handlers can `?` the result directly.
    #[allow(clippy::result_large_err)]
    pub fn require_scope(&self, scope: &str) -> Result<(), Status> {
        if self.scopes.iter().any(|s| s == scope) {
            Ok(())
        } else {
            Err(AuthError::InsufficientScope {
                required: scope.into(),
            }
            .into())
        }
    }

    /// Returns `true` when `kind` is [`PrincipalKind::Anonymous`].
    pub fn is_anonymous(&self) -> bool {
        matches!(self.kind, PrincipalKind::Anonymous)
    }

    /// Converts `expires_at` (fractional seconds since the Unix epoch) to a [`SystemTime`].
    ///
    /// Returns [`UNIX_EPOCH`] when there is no usable expiry. That covers `expires_at` that
    /// is zero or negative (as produced by [`AuthCtx::anonymous`]), NaN, infinite, or too
    /// large to represent as a `SystemTime` on this platform. Since the epoch is in the
    /// past, such contexts compare as already expired (fail closed). This function never
    /// panics.
    pub fn expires_at_systime(&self) -> SystemTime {
        if self.expires_at <= 0.0 {
            return UNIX_EPOCH;
        }
        // `Duration::from_secs_f64` panics on NaN/infinite/overflowing input, and
        // `UNIX_EPOCH + d` panics if the sum overflows `SystemTime`. `expires_at` can come
        // from deserialized input, so use the checked conversions instead.
        std::time::Duration::try_from_secs_f64(self.expires_at)
            .ok()
            .and_then(|offset| UNIX_EPOCH.checked_add(offset))
            .unwrap_or(UNIX_EPOCH)
    }

    /// Stores `t` into `expires_at` as fractional seconds since the Unix epoch.
    ///
    /// Times before the epoch are stored as `0.0`, i.e. "no expiry known". Precision is
    /// limited to what an `f64` can hold, so a round trip through
    /// [`AuthCtx::expires_at_systime`] may not be exact.
    pub fn set_expires_at_systime(&mut self, t: SystemTime) {
        self.expires_at = t
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_secs_f64())
            .unwrap_or(0.0);
    }
}
/// Errors raised while authenticating or authorizing a request.
///
/// Converted to a gRPC [`Status`] by the `From<AuthError> for Status` impl in this file:
/// `InsufficientScope` becomes `PERMISSION_DENIED`, `Config` and `Transport` become
/// `INTERNAL`, and every other variant becomes `UNAUTHENTICATED`. In each case the status
/// message is this error's `Display` text.
#[derive(Debug, thiserror::Error)]
pub enum AuthError {
/// The request carried no token.
#[error("no token in request")]
MissingToken,
/// The token's signature did not verify.
#[error("token signature invalid")]
Signature,
/// The token is past its expiry.
#[error("token expired")]
Expired,
/// The token's audience differs from the expected one.
#[error("audience mismatch: expected {expected}, got {got}")]
Audience { expected: String, got: String },
/// The token's issuer differs from the expected one.
#[error("issuer mismatch: expected {expected}, got {got}")]
Issuer { expected: String, got: String },
/// Any other verification failure, with a free-form description.
#[error("token verification failed: {0}")]
Verification(String),
/// The caller is authenticated but lacks the `required` scope.
#[error("insufficient scope: required {required}")]
InsufficientScope { required: String },
/// Server-side misconfiguration (surfaced as an internal error, not a client fault).
#[error("configuration error: {0}")]
Config(String),
/// Failure talking to the auth backend (surfaced as an internal error).
#[error("transport error contacting auth backend: {0}")]
Transport(String),
}
/// Maps an [`AuthError`] to the gRPC status code a client should see.
///
/// - `InsufficientScope` → `PERMISSION_DENIED` (authenticated, but not allowed)
/// - `Config`, `Transport` → `INTERNAL` (server-side failure, not the caller's fault)
/// - every credential problem → `UNAUTHENTICATED`
///
/// The status message is the error's `Display` text. The match is exhaustive on purpose,
/// so adding an `AuthError` variant forces an explicit decision about its status code.
impl From<AuthError> for Status {
    fn from(e: AuthError) -> Status {
        let message = e.to_string();
        match e {
            AuthError::InsufficientScope { .. } => Status::permission_denied(message),
            AuthError::Config(_) | AuthError::Transport(_) => Status::internal(message),
            AuthError::MissingToken
            | AuthError::Signature
            | AuthError::Expired
            | AuthError::Audience { .. }
            | AuthError::Issuer { .. }
            | AuthError::Verification(_) => Status::unauthenticated(message),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An anonymous context reports itself as anonymous and carries the matching kind.
    #[test]
    fn anonymous_authctx_is_anonymous() {
        let ctx = AuthCtx::anonymous();
        assert!(ctx.is_anonymous());
        assert_eq!(ctx.kind, PrincipalKind::Anonymous);
    }

    /// `from_bearer` keeps the token verbatim and marks the principal as a user.
    #[test]
    fn from_bearer_carries_token() {
        let ctx = AuthCtx::from_bearer("abc.def.ghi");
        assert_eq!(ctx.kind, PrincipalKind::User);
        assert_eq!(ctx.raw_token, "abc.def.ghi");
    }

    /// A non-empty token is forwarded as a `Bearer` authorization header.
    #[test]
    fn propagate_writes_authorization_header() {
        let mut outgoing = Request::new(());
        AuthCtx::from_bearer("abc.def.ghi").propagate(&mut outgoing);
        let header = outgoing.metadata().get("authorization").unwrap();
        assert_eq!(header.to_str().unwrap(), "Bearer abc.def.ghi");
    }

    /// An anonymous context (empty token) leaves the outgoing metadata untouched.
    #[test]
    fn propagate_anonymous_is_noop() {
        let mut outgoing = Request::new(());
        AuthCtx::anonymous().propagate(&mut outgoing);
        assert!(outgoing.metadata().get("authorization").is_none());
    }

    #[test]
    fn require_scope_ok_when_present() {
        let ctx = AuthCtx {
            scopes: vec!["read:billing".into()],
            ..AuthCtx::anonymous()
        };
        assert!(ctx.require_scope("read:billing").is_ok());
    }

    #[test]
    fn require_scope_err_when_missing() {
        let status = AuthCtx::anonymous().require_scope("admin").unwrap_err();
        assert_eq!(status.code(), tonic::Code::PermissionDenied);
    }

    /// Table of representative errors and the gRPC code each must map to.
    #[test]
    fn auth_error_maps_to_correct_status() {
        let cases = [
            (AuthError::Signature, tonic::Code::Unauthenticated),
            (
                AuthError::InsufficientScope {
                    required: "admin".into(),
                },
                tonic::Code::PermissionDenied,
            ),
            (
                AuthError::Config("missing env".into()),
                tonic::Code::Internal,
            ),
        ];
        for (err, expected_code) in cases {
            let status: Status = err.into();
            assert_eq!(status.code(), expected_code);
        }
    }

    /// Guards the JSON keys and value types that non-Rust consumers depend on.
    #[test]
    fn authctx_json_shape_is_stable_for_polyglot_consumers() {
        const EXPECTED_FIELDS: [&str; 8] = [
            "subject",
            "issuer",
            "audience",
            "scopes",
            "kind",
            "raw_token",
            "expires_at",
            "extra",
        ];

        let mut ctx = AuthCtx {
            subject: "alice".into(),
            issuer: "https://issuer.example".into(),
            audience: "my-svc".into(),
            scopes: vec!["read:billing".into(), "write:billing".into()],
            kind: PrincipalKind::User,
            raw_token: "abc.def.ghi".into(),
            expires_at: 1_735_689_600.0,
            ..AuthCtx::anonymous()
        };
        ctx.extra
            .insert("tenant_id".into(), serde_json::json!("acme"));

        let json = serde_json::to_value(&ctx).unwrap();
        for f in EXPECTED_FIELDS {
            assert!(
                json.get(f).is_some(),
                "missing field `{f}` in serialized AuthCtx JSON shape"
            );
        }
        assert!(json["expires_at"].is_number());
        assert_eq!(json["kind"], serde_json::json!("user"));
    }
}