// tonin_client/auth.rs

//! Auth types shared between server and client.
//!
//! These types are the contract between:
//! - the inbound `AuthLayer` (in `tonin_core::auth::layer`), which
//!   produces an `AuthCtx` from a verified token, and
//! - the outbound client SDK, which copies the bearer token onto
//!   downstream requests via [`AuthCtx::propagate`].
//!
//! Custom verifiers (Okta, Cognito, API keys, cookies) are implemented
//! on the server side via the `TokenVerifier` trait, but they all
//! return the same `AuthCtx` defined here.

use std::collections::HashMap;
use std::time::{Duration, SystemTime, UNIX_EPOCH};

use serde::{Deserialize, Serialize};
use tonic::{Request, Status};

/// A token as extracted from a request, before verification.
///
/// The framework doesn't assume JWT — the value could be a session ID,
/// API key, or anything else a server-side `TokenVerifier` knows how to
/// handle.
///
/// `Debug` is implemented by hand so that formatting a `RawToken` with
/// `{:?}` (logs, panic messages, `dbg!`) never prints the credential in
/// [`Self::value`]; it only reveals whether the value is empty.
#[derive(Clone)]
pub struct RawToken {
    /// The extracted token string. Treat as secret material.
    pub value: String,
    /// Hint for verifiers that handle multiple token formats.
    /// Conventions: `"bearer-jwt"`, `"api-key"`, `"session-cookie"`,
    /// `"basic-auth"`, etc.
    pub kind: &'static str,
}

impl std::fmt::Debug for RawToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // An empty value carries no secret and is useful when debugging
        // "no token" problems, so only non-empty values are masked.
        let value = if self.value.is_empty() {
            ""
        } else {
            "<redacted>"
        };
        f.debug_struct("RawToken")
            .field("value", &value)
            .field("kind", &self.kind)
            .finish()
    }
}

33/// Identity + claims for the current request. The single concrete type
34/// that flows through the framework — outbound clients accept this, the
35/// server-side auth layer produces it. Custom claims live in
36/// [`Self::extra`].
37#[derive(Clone, Debug, Serialize, Deserialize)]
38pub struct AuthCtx {
39    /// `sub` claim. User ID for users, service ID for services.
40    pub subject: String,
41    pub issuer: String,
42    pub audience: String,
43    pub scopes: Vec<String>,
44    pub kind: PrincipalKind,
45    /// The verbatim token. Used by [`AuthCtx::propagate`] for outbound calls.
46    pub raw_token: String,
47    /// Unix-seconds expiry. f64 to stay JSON-compatible with the
48    /// Python and TS sides (which use `number`). `0.0` means "no
49    /// expiry recorded" (e.g. an anonymous context). Use
50    /// [`AuthCtx::expires_at_systime`] / [`AuthCtx::set_expires_at_systime`]
51    /// when interop with `std::time::SystemTime` is convenient.
52    pub expires_at: f64,
53    /// Claims not mapped to typed fields. Verifiers populate this with
54    /// anything custom (e.g. tenant_id, role, agent_on_behalf_of).
55    #[serde(default)]
56    pub extra: HashMap<String, serde_json::Value>,
57}
58
59#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
60#[serde(rename_all = "lowercase")]
61pub enum PrincipalKind {
62    User,
63    Service,
64    Agent,
65    /// No auth was attempted (service is in opt-out mode).
66    Anonymous,
67}
68
69impl AuthCtx {
70    /// Returns an empty `AuthCtx` for opt-out / no-auth flows.
71    pub fn anonymous() -> Self {
72        Self {
73            subject: String::new(),
74            issuer: String::new(),
75            audience: String::new(),
76            scopes: Vec::new(),
77            kind: PrincipalKind::Anonymous,
78            raw_token: String::new(),
79            expires_at: 0.0,
80            extra: HashMap::new(),
81        }
82    }
83
84    /// Wrap a bearer token without verification. For client-side code
85    /// that already has a token (e.g., from a login flow) and wants to
86    /// hand it to the framework's outbound propagation.
87    pub fn from_bearer(token: impl Into<String>) -> Self {
88        let token = token.into();
89        Self {
90            raw_token: token,
91            kind: PrincipalKind::User,
92            ..Self::anonymous()
93        }
94    }
95
96    /// Pull `AuthCtx` from a tonic request's extensions, populated by
97    /// the inbound auth layer. Returns [`Self::anonymous`] if no layer
98    /// ran.
99    pub fn from<T>(req: &Request<T>) -> Self {
100        req.extensions()
101            .get::<AuthCtx>()
102            .cloned()
103            .unwrap_or_else(Self::anonymous)
104    }
105
106    /// Copy the bearer token onto an outbound request so the caller's
107    /// identity rides along to the next service.
108    pub fn propagate<T>(&self, req: &mut Request<T>) {
109        if self.raw_token.is_empty() {
110            return;
111        }
112        if let Ok(value) = format!("Bearer {}", self.raw_token).parse() {
113            req.metadata_mut().insert("authorization", value);
114        }
115    }
116
117    /// Authorize a single scope. Returns `PermissionDenied` if missing.
118    /// Convenient for `Status` returns from handlers.
119    #[allow(clippy::result_large_err)] // tonic::Status is the canonical error type for gRPC handlers
120    pub fn require_scope(&self, scope: &str) -> Result<(), Status> {
121        if self.scopes.iter().any(|s| s == scope) {
122            Ok(())
123        } else {
124            Err(AuthError::InsufficientScope {
125                required: scope.into(),
126            }
127            .into())
128        }
129    }
130
131    pub fn is_anonymous(&self) -> bool {
132        matches!(self.kind, PrincipalKind::Anonymous)
133    }
134
135    /// Convert `expires_at` (unix seconds) into a `SystemTime`. Returns
136    /// `UNIX_EPOCH` for an anonymous / unset context (`expires_at == 0.0`).
137    pub fn expires_at_systime(&self) -> SystemTime {
138        if self.expires_at <= 0.0 {
139            UNIX_EPOCH
140        } else {
141            UNIX_EPOCH + std::time::Duration::from_secs_f64(self.expires_at)
142        }
143    }
144
145    /// Set `expires_at` from a `SystemTime`. Convenience for verifiers
146    /// that already hold a `SystemTime` (e.g. JWT `iat + max_age`).
147    pub fn set_expires_at_systime(&mut self, t: SystemTime) {
148        self.expires_at = t
149            .duration_since(UNIX_EPOCH)
150            .map(|d| d.as_secs_f64())
151            .unwrap_or(0.0);
152    }
153}
154
155#[derive(Debug, thiserror::Error)]
156pub enum AuthError {
157    #[error("no token in request")]
158    MissingToken,
159    #[error("token signature invalid")]
160    Signature,
161    #[error("token expired")]
162    Expired,
163    #[error("audience mismatch: expected {expected}, got {got}")]
164    Audience { expected: String, got: String },
165    #[error("issuer mismatch: expected {expected}, got {got}")]
166    Issuer { expected: String, got: String },
167    #[error("token verification failed: {0}")]
168    Verification(String),
169    #[error("insufficient scope: required {required}")]
170    InsufficientScope { required: String },
171    #[error("configuration error: {0}")]
172    Config(String),
173    #[error("transport error contacting auth backend: {0}")]
174    Transport(String),
175}
176
177impl From<AuthError> for Status {
178    fn from(e: AuthError) -> Status {
179        match e {
180            AuthError::InsufficientScope { .. } => Status::permission_denied(e.to_string()),
181            AuthError::Config(_) | AuthError::Transport(_) => Status::internal(e.to_string()),
182            _ => Status::unauthenticated(e.to_string()),
183        }
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn anonymous_authctx_is_anonymous() {
        let a = AuthCtx::anonymous();
        assert!(a.is_anonymous());
        assert_eq!(a.kind, PrincipalKind::Anonymous);
    }

    #[test]
    fn from_bearer_carries_token() {
        let a = AuthCtx::from_bearer("abc.def.ghi");
        assert_eq!(a.raw_token, "abc.def.ghi");
        assert_eq!(a.kind, PrincipalKind::User);
    }

    #[test]
    fn from_request_defaults_to_anonymous_without_layer() {
        let req = Request::new(());
        assert!(AuthCtx::from(&req).is_anonymous());
    }

    #[test]
    fn from_request_reads_ctx_installed_by_layer() {
        let mut ctx = AuthCtx::from_bearer("tok");
        ctx.subject = "alice".into();
        let mut req = Request::new(());
        req.extensions_mut().insert(ctx);

        let got = AuthCtx::from(&req);
        assert_eq!(got.subject, "alice");
        assert_eq!(got.raw_token, "tok");
        assert_eq!(got.kind, PrincipalKind::User);
    }

    #[test]
    fn propagate_writes_authorization_header() {
        let a = AuthCtx::from_bearer("abc.def.ghi");
        let mut req = Request::new(());
        a.propagate(&mut req);
        let v = req.metadata().get("authorization").unwrap();
        assert_eq!(v.to_str().unwrap(), "Bearer abc.def.ghi");
    }

    #[test]
    fn propagate_overwrites_existing_authorization_header() {
        let a = AuthCtx::from_bearer("new.token");
        let mut req = Request::new(());
        req.metadata_mut()
            .insert("authorization", "Bearer old.token".parse().unwrap());
        a.propagate(&mut req);
        let v = req.metadata().get("authorization").unwrap();
        assert_eq!(v.to_str().unwrap(), "Bearer new.token");
    }

    #[test]
    fn propagate_anonymous_is_noop() {
        let a = AuthCtx::anonymous();
        let mut req = Request::new(());
        a.propagate(&mut req);
        assert!(req.metadata().get("authorization").is_none());
    }

    #[test]
    fn require_scope_ok_when_present() {
        let mut a = AuthCtx::anonymous();
        a.scopes = vec!["read:billing".into()];
        assert!(a.require_scope("read:billing").is_ok());
    }

    #[test]
    fn require_scope_err_when_missing() {
        let a = AuthCtx::anonymous();
        let err = a.require_scope("admin").unwrap_err();
        assert_eq!(err.code(), tonic::Code::PermissionDenied);
    }

    #[test]
    fn auth_error_maps_to_correct_status() {
        let s: Status = AuthError::Signature.into();
        assert_eq!(s.code(), tonic::Code::Unauthenticated);

        let s: Status = AuthError::InsufficientScope {
            required: "admin".into(),
        }
        .into();
        assert_eq!(s.code(), tonic::Code::PermissionDenied);

        let s: Status = AuthError::Config("missing env".into()).into();
        assert_eq!(s.code(), tonic::Code::Internal);
    }

    #[test]
    fn auth_error_status_mapping_covers_remaining_variants() {
        let s: Status = AuthError::Transport("backend down".into()).into();
        assert_eq!(s.code(), tonic::Code::Internal);

        let s: Status = AuthError::Expired.into();
        assert_eq!(s.code(), tonic::Code::Unauthenticated);

        let s: Status = AuthError::MissingToken.into();
        assert_eq!(s.code(), tonic::Code::Unauthenticated);
    }

    #[test]
    fn expires_at_systime_round_trips_whole_seconds() {
        let t = UNIX_EPOCH + Duration::from_secs(1_735_689_600);
        let mut a = AuthCtx::anonymous();
        a.set_expires_at_systime(t);
        assert_eq!(a.expires_at, 1_735_689_600.0);
        assert_eq!(a.expires_at_systime(), t);
    }

    #[test]
    fn unset_expiry_maps_to_unix_epoch() {
        assert_eq!(AuthCtx::anonymous().expires_at_systime(), UNIX_EPOCH);
    }

    /// Lock in the JSON wire shape across language boundaries.
    ///
    /// If this test breaks, the Python/TS sides need an update too —
    /// run `cargo run --bin gen-shared-types` and re-check
    /// `python/tonin-client/tests/test_wire_compat.py`.
    #[test]
    fn authctx_json_shape_is_stable_for_polyglot_consumers() {
        let mut ctx = AuthCtx::anonymous();
        ctx.subject = "alice".into();
        ctx.issuer = "https://issuer.example".into();
        ctx.audience = "my-svc".into();
        ctx.scopes = vec!["read:billing".into(), "write:billing".into()];
        ctx.kind = PrincipalKind::User;
        ctx.raw_token = "abc.def.ghi".into();
        ctx.expires_at = 1_735_689_600.0;
        ctx.extra
            .insert("tenant_id".into(), serde_json::json!("acme"));

        let v = serde_json::to_value(&ctx).unwrap();
        // Field-name presence (snake_case, no rename).
        for f in [
            "subject",
            "issuer",
            "audience",
            "scopes",
            "kind",
            "raw_token",
            "expires_at",
            "extra",
        ] {
            assert!(
                v.get(f).is_some(),
                "missing field `{f}` in serialized AuthCtx JSON shape"
            );
        }
        // Critical contract: expires_at is a JSON number (not a struct).
        assert!(v["expires_at"].is_number());
        // Critical contract: kind serializes as lowercase string.
        assert_eq!(v["kind"], serde_json::json!("user"));
    }

    /// The reverse direction of the wire contract: payloads produced by the
    /// Python/TS SDKs may omit `extra` and may encode `expires_at` as a JSON
    /// integer. Both must still deserialize.
    #[test]
    fn authctx_deserializes_polyglot_payloads() {
        let v = serde_json::json!({
            "subject": "svc-billing",
            "issuer": "https://issuer.example",
            "audience": "my-svc",
            "scopes": ["read:billing"],
            "kind": "service",
            "raw_token": "abc.def.ghi",
            "expires_at": 1735689600
        });
        let ctx: AuthCtx = serde_json::from_value(v).unwrap();
        assert_eq!(ctx.subject, "svc-billing");
        assert_eq!(ctx.kind, PrincipalKind::Service);
        assert_eq!(ctx.scopes, vec!["read:billing".to_string()]);
        assert_eq!(ctx.expires_at, 1_735_689_600.0);
        assert!(ctx.extra.is_empty());
    }
}