Skip to main content

stack_auth/
token.rs

1use std::time::{SystemTime, UNIX_EPOCH};
2
3use cts_common::claims::Claims;
4use cts_common::{Crn, Region, WorkspaceId};
5use url::Url;
6
7use crate::{http_client, AuthError, SecretToken};
8
9impl stack_profile::ProfileData for Token {
10    const FILENAME: &'static str = "auth.json";
11    const MODE: Option<u32> = Some(0o600);
12}
13
/// Safety margin, in seconds, applied by [`Token::is_expired`].
///
/// A token is reported as expired this many seconds *before* its real expiry
/// timestamp. Refreshing early leaves room for the HTTP refresh round-trip to
/// finish while concurrent callers keep using the still-valid current token.
const EXPIRY_LEEWAY_SECS: u64 = 90;
20
21/// An access token returned by a successful authentication flow.
22///
23/// The token contains a [`SecretToken`] (the bearer credential), a token type
24/// (typically `"Bearer"`), and an absolute expiry timestamp.
25#[derive(Debug, serde::Serialize, serde::Deserialize)]
26pub struct Token {
27    pub(crate) access_token: SecretToken,
28    #[serde(default, skip_serializing_if = "Option::is_none")]
29    pub(crate) refresh_token: Option<SecretToken>,
30    pub(crate) token_type: String,
31    pub(crate) expires_at: u64,
32    #[serde(default, skip_serializing_if = "Option::is_none")]
33    pub(crate) region: Option<String>,
34    #[serde(default, skip_serializing_if = "Option::is_none")]
35    pub(crate) client_id: Option<String>,
36    #[serde(default, skip_serializing_if = "Option::is_none")]
37    pub(crate) device_instance_id: Option<String>,
38}
39
40impl Token {
41    /// Returns a reference to the access token credential.
42    ///
43    /// The returned [`SecretToken`] is opaque — its [`Debug`] output is masked.
44    /// Pass it to API clients that need the raw bearer token.
45    pub fn access_token(&self) -> &SecretToken {
46        &self.access_token
47    }
48
49    /// The token type (e.g. `"Bearer"`).
50    pub fn token_type(&self) -> &str {
51        &self.token_type
52    }
53
54    /// The absolute epoch timestamp when the token expires.
55    pub fn expires_at(&self) -> u64 {
56        self.expires_at
57    }
58
59    /// How many seconds until the token expires (computed from the current time).
60    pub fn expires_in(&self) -> u64 {
61        let now = SystemTime::now()
62            .duration_since(UNIX_EPOCH)
63            .unwrap_or_default()
64            .as_secs();
65        self.expires_at.saturating_sub(now)
66    }
67
68    /// Returns `true` if the token has expired (with 90 seconds of leeway).
69    ///
70    /// The 90-second leeway triggers preemptive refresh well before the token
71    /// becomes unusable, giving the HTTP refresh call plenty of time to complete
72    /// while the current token is still valid for concurrent callers.
73    ///
74    /// For checking whether the token is still usable as a bearer credential,
75    /// use [`is_usable`](Self::is_usable) instead.
76    pub fn is_expired(&self) -> bool {
77        let now = SystemTime::now()
78            .duration_since(UNIX_EPOCH)
79            .unwrap_or_default()
80            .as_secs();
81        now + EXPIRY_LEEWAY_SECS >= self.expires_at
82    }
83
84    /// Returns `true` if the token is still usable (before the actual expiry timestamp).
85    ///
86    /// Unlike [`is_expired`](Self::is_expired) which includes 90s leeway for preemptive
87    /// refresh, this only returns `false` when the token has genuinely expired.
88    pub fn is_usable(&self) -> bool {
89        let now = SystemTime::now()
90            .duration_since(UNIX_EPOCH)
91            .unwrap_or_default()
92            .as_secs();
93        now < self.expires_at
94    }
95
96    /// Returns a reference to the refresh token, if one was provided.
97    pub fn refresh_token(&self) -> Option<&SecretToken> {
98        self.refresh_token.as_ref()
99    }
100
101    /// Takes the refresh token out, leaving `None` in its place.
102    pub fn take_refresh_token(&mut self) -> Option<SecretToken> {
103        self.refresh_token.take()
104    }
105
106    /// Returns the stored region identifier, if any.
107    pub fn region(&self) -> Option<&str> {
108        self.region.as_deref()
109    }
110
111    /// Returns the stored client ID, if any.
112    pub fn client_id(&self) -> Option<&str> {
113        self.client_id.as_deref()
114    }
115
116    /// Set the region identifier on this token.
117    pub(crate) fn set_region(&mut self, region: impl Into<String>) {
118        self.region = Some(region.into());
119    }
120
121    /// Set the client ID on this token.
122    pub(crate) fn set_client_id(&mut self, client_id: impl Into<String>) {
123        self.client_id = Some(client_id.into());
124    }
125
126    /// Returns the stored device instance ID, if any.
127    pub fn device_instance_id(&self) -> Option<&str> {
128        self.device_instance_id.as_deref()
129    }
130
131    /// Set the device instance ID on this token.
132    pub(crate) fn set_device_instance_id(&mut self, id: impl Into<String>) {
133        self.device_instance_id = Some(id.into());
134    }
135
136    /// Returns the workspace ID from the JWT claims.
137    ///
138    /// The access token is decoded (without signature verification) to extract
139    /// the `workspace` claim.
140    pub fn workspace_id(&self) -> Result<WorkspaceId, AuthError> {
141        self.decode_claims().map(|c| c.workspace)
142    }
143
144    /// Returns the workspace CRN derived from the token's region and workspace ID.
145    ///
146    /// The region is set during the device code flow, and the workspace ID is
147    /// extracted from the JWT `workspace` claim.
148    pub fn workspace_crn(&self) -> Result<Crn, AuthError> {
149        let workspace_id = self.workspace_id()?;
150        let region: Region = self
151            .region()
152            .ok_or(AuthError::NotAuthenticated)?
153            .parse()
154            .map_err(|e: cts_common::RegionError| AuthError::Server(e.to_string()))?;
155        Ok(Crn::new(region, workspace_id))
156    }
157
158    /// Returns the issuer URL from the JWT claims.
159    ///
160    /// The `iss` claim in CipherStash tokens is the CTS host URL for the
161    /// workspace, so this can be used directly as the CTS base URL.
162    pub fn issuer(&self) -> Result<Url, AuthError> {
163        let claims = self.decode_claims()?;
164        claims.iss.parse().map_err(AuthError::from)
165    }
166
167    /// Decode the JWT payload into [`Claims`] without verifying the signature.
168    ///
169    /// This is safe because we already possess the token — we just need to read
170    /// the claims it contains.
171    fn decode_claims(&self) -> Result<Claims, AuthError> {
172        use jsonwebtoken::{decode, decode_header, DecodingKey, Validation};
173        use std::collections::HashSet;
174
175        let token_str = self.access_token.as_str();
176        let header = decode_header(token_str)
177            .map_err(|e| AuthError::InvalidToken(format!("invalid JWT header: {e}")))?;
178
179        let dummy_key = DecodingKey::from_secret(&[]);
180        let mut validation = Validation::new(header.alg);
181        validation.validate_exp = false;
182        validation.validate_aud = false;
183        validation.required_spec_claims = HashSet::new();
184        validation.insecure_disable_signature_validation();
185
186        decode(token_str, &dummy_key, &validation)
187            .map(|data| data.claims)
188            .map_err(|e| AuthError::InvalidToken(format!("failed to decode JWT claims: {e}")))
189    }
190
191    /// Exchange a refresh token for a new [`Token`] via the `/oauth/token`
192    /// endpoint.
193    ///
194    /// This is a static constructor — it takes a bare [`SecretToken`] (the
195    /// refresh token) rather than operating on an existing `Token`. This
196    /// allows callers to manage the refresh token lifecycle independently
197    /// (e.g. taking it out of a cached token for cascade prevention and
198    /// restoring it on failure).
199    ///
200    /// # Errors
201    ///
202    /// - [`AuthError::InvalidGrant`] — the refresh token was revoked or expired.
203    /// - [`AuthError::InvalidClient`] — the client ID is not recognized.
204    /// - [`AuthError::Request`] — a network error occurred.
205    pub async fn refresh(
206        refresh_token: &SecretToken,
207        base_url: &Url,
208        client_id: &str,
209        device_instance_id: Option<&str>,
210    ) -> Result<Token, AuthError> {
211        let token_url = base_url.join("oauth/token")?;
212
213        tracing::debug!(url = %token_url, "refreshing token");
214
215        let resp = http_client()
216            .post(token_url)
217            .form(&RefreshRequest {
218                grant_type: "refresh_token",
219                client_id,
220                refresh_token: refresh_token.as_str(),
221                device_instance_id,
222            })
223            .send()
224            .await?;
225
226        if !resp.status().is_success() {
227            let err: RefreshErrorResponse = resp.json().await?;
228            tracing::debug!(error = %err.error, "token refresh failed");
229            return Err(match err.error.as_str() {
230                "invalid_grant" => AuthError::InvalidGrant,
231                "invalid_client" => AuthError::InvalidClient,
232                "access_denied" => AuthError::AccessDenied,
233                _ => AuthError::Server(err.error_description),
234            });
235        }
236
237        let token_resp: RefreshResponse = resp.json().await?;
238        let now = SystemTime::now()
239            .duration_since(UNIX_EPOCH)
240            .unwrap_or_default()
241            .as_secs();
242
243        Ok(Token {
244            access_token: token_resp.access_token,
245            token_type: token_resp.token_type,
246            expires_at: now + token_resp.expires_in,
247            refresh_token: token_resp.refresh_token,
248            region: None,
249            client_id: None,
250            // TODO(CIP-2793): The server should include device_instance_id in the
251            // refresh response. Until then, callers (e.g. OAuthRefresher) must
252            // re-attach it manually after refresh.
253            device_instance_id: None,
254        })
255    }
256}
257
258#[derive(serde::Serialize)]
259struct RefreshRequest<'a> {
260    grant_type: &'a str,
261    client_id: &'a str,
262    refresh_token: &'a str,
263    #[serde(skip_serializing_if = "Option::is_none")]
264    device_instance_id: Option<&'a str>,
265}
266
267#[derive(serde::Deserialize)]
268struct RefreshResponse {
269    access_token: SecretToken,
270    token_type: String,
271    expires_in: u64,
272    #[serde(default)]
273    refresh_token: Option<SecretToken>,
274}
275
276#[derive(serde::Deserialize)]
277struct RefreshErrorResponse {
278    error: String,
279    #[serde(default)]
280    error_description: String,
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AuthError;
    use mocktail::prelude::*;

    /// Plain-text access token used by non-JWT fixtures.
    const ACCESS_TOKEN: &str = "test-access-token";
    /// Plain-text refresh token used by fixtures and sent to the mock server.
    const REFRESH_TOKEN: &str = "test-refresh-token";

    // ---- fixtures ----

    /// Seconds since the Unix epoch, for computing fixture expiry timestamps.
    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// A `Bearer` token with the given credential and absolute expiry, and no
    /// optional fields set.
    fn bare_token(access_token: SecretToken, expires_at: u64) -> Token {
        Token {
            access_token,
            token_type: "Bearer".to_string(),
            expires_at,
            refresh_token: None,
            region: None,
            client_id: None,
            device_instance_id: None,
        }
    }

    /// A non-JWT token expiring `expires_in` seconds from now, optionally
    /// carrying a refresh token.
    fn make_token(expires_in: u64, refresh: bool) -> Token {
        let mut token = bare_token(SecretToken::new(ACCESS_TOKEN), now_secs() + expires_in);
        if refresh {
            token.refresh_token = Some(SecretToken::new(REFRESH_TOKEN));
        }
        token
    }

    fn refresh_response_json() -> serde_json::Value {
        serde_json::json!({
            "access_token": "new-access-token",
            "token_type": "Bearer",
            "expires_in": 3600,
            "refresh_token": "new-refresh-token"
        })
    }

    fn error_json(error: &str) -> serde_json::Value {
        serde_json::json!({
            "error": error,
            "error_description": format!("{error} occurred")
        })
    }

    /// Mock set whose token endpoint answers `400 Bad Request` with the given
    /// OAuth error code.
    fn failing_endpoint(error: &'static str) -> MockSet {
        let mut mocks = MockSet::new();
        mocks.mock(move |when, then| {
            when.post().path("/oauth/token");
            then.bad_request().json(error_json(error));
        });
        mocks
    }

    async fn start_server(mocks: MockSet) -> MockServer {
        let server = MockServer::new_http("token-refresh-test").with_mocks(mocks);
        server.start().await.unwrap();
        server
    }

    /// Start a mock server with `mocks` and run [`Token::refresh`] against it.
    async fn refresh_against(mocks: MockSet) -> Result<Token, AuthError> {
        let server = start_server(mocks).await;
        let base_url = server.url("");
        let refresh_token = SecretToken::new(REFRESH_TOKEN);
        Token::refresh(&refresh_token, &base_url, "cli", None).await
    }

    // ---- secrecy tests ----

    #[test]
    fn test_secret_token_debug_does_not_leak() {
        let token = SecretToken::new("super_secret_value");
        let debug = format!("{:?}", token);
        assert!(
            !debug.contains("super_secret_value"),
            "SecretToken Debug should not contain the secret, got: {debug}"
        );
    }

    // Purely synchronous: no async runtime is needed to format a token.
    #[test]
    fn test_refresh_debug_does_not_leak_tokens() {
        let token = make_token(3600, true);
        let debug = format!("{:?}", token);
        assert!(
            !debug.contains(ACCESS_TOKEN),
            "Debug output should not contain access token, got: {debug}"
        );
        assert!(
            !debug.contains(REFRESH_TOKEN),
            "Debug output should not contain refresh token, got: {debug}"
        );
    }

    // ---- expiry tests ----

    #[test]
    fn test_fresh_token_is_usable_and_not_expired() {
        let token = make_token(3600, false);
        assert!(token.is_usable());
        assert!(!token.is_expired());
        assert!((3598..=3600).contains(&token.expires_in()));
    }

    #[test]
    fn test_token_inside_leeway_is_expired_but_still_usable() {
        // Well inside the leeway window, so a slow test run cannot flip it.
        let token = make_token(EXPIRY_LEEWAY_SECS / 3, false);
        assert!(token.is_expired(), "leeway should trigger early refresh");
        assert!(token.is_usable(), "token has not actually expired yet");
    }

    #[test]
    fn test_token_past_expiry_is_not_usable() {
        let token = bare_token(SecretToken::new(ACCESS_TOKEN), 0);
        assert!(token.is_expired());
        assert!(!token.is_usable());
        assert_eq!(token.expires_in(), 0);
    }

    #[test]
    fn test_take_refresh_token_leaves_none_behind() {
        let mut token = make_token(3600, true);
        let taken = token
            .take_refresh_token()
            .expect("fixture carries a refresh token");
        assert_eq!(taken.as_str(), REFRESH_TOKEN);
        assert!(token.refresh_token().is_none());
        assert!(token.take_refresh_token().is_none());
    }

    // ---- refresh() tests ----

    #[tokio::test]
    async fn test_refresh_success() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/oauth/token");
            then.json(refresh_response_json());
        });

        let refreshed = refresh_against(mocks).await.unwrap();

        assert_eq!(refreshed.access_token().as_str(), "new-access-token");
        assert_eq!(refreshed.token_type(), "Bearer");
        assert_eq!(
            refreshed.refresh_token().unwrap().as_str(),
            "new-refresh-token"
        );
        assert!(!refreshed.is_expired());
        assert!((3598..=3600).contains(&refreshed.expires_in()));
    }

    #[tokio::test]
    async fn test_refresh_invalid_grant() {
        let err = refresh_against(failing_endpoint("invalid_grant"))
            .await
            .unwrap_err();

        assert!(matches!(err, AuthError::InvalidGrant));
    }

    #[tokio::test]
    async fn test_refresh_invalid_client() {
        let err = refresh_against(failing_endpoint("invalid_client"))
            .await
            .unwrap_err();

        assert!(matches!(err, AuthError::InvalidClient));
    }

    #[tokio::test]
    async fn test_refresh_access_denied() {
        let err = refresh_against(failing_endpoint("access_denied"))
            .await
            .unwrap_err();

        assert!(matches!(err, AuthError::AccessDenied));
    }

    #[tokio::test]
    async fn test_refresh_unknown_error() {
        let err = refresh_against(failing_endpoint("something_unexpected"))
            .await
            .unwrap_err();

        assert!(matches!(&err, AuthError::Server(desc) if desc == "something_unexpected occurred"));
    }

    #[tokio::test]
    async fn test_refresh_response_without_new_refresh_token() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/oauth/token");
            then.json(serde_json::json!({
                "access_token": "new-access-token",
                "token_type": "Bearer",
                "expires_in": 3600
            }));
        });

        let refreshed = refresh_against(mocks).await.unwrap();

        assert_eq!(refreshed.access_token().as_str(), "new-access-token");
        assert!(refreshed.refresh_token().is_none());
    }

    // ---- decode_claims / workspace_id / issuer tests ----

    /// Build a Token whose access_token is a real JWT carrying the given
    /// claims JSON. It is signed with a throwaway secret, which is irrelevant
    /// because `decode_claims` never checks the signature.
    fn make_jwt_token(claims_json: serde_json::Value) -> Token {
        use jsonwebtoken::{encode, EncodingKey, Header};
        let jwt = encode(
            &Header::default(),
            &claims_json,
            &EncodingKey::from_secret(b"test-secret"),
        )
        .expect("failed to encode JWT");

        bare_token(SecretToken::new(jwt), now_secs() + 3600)
    }

    fn valid_claims_json() -> serde_json::Value {
        serde_json::json!({
            "workspace": "7366ITCXSAPCH5TN",
            "iss": "https://cts.example.com",
            "sub": "user-123",
            "aud": "https://cts.example.com",
            "iat": 1700000000u64,
            "exp": 1700003600u64,
            "scope": "dataset:create"
        })
    }

    #[test]
    fn test_workspace_id_extracts_from_jwt() {
        let token = make_jwt_token(valid_claims_json());
        let ws = token.workspace_id().expect("should extract workspace ID");
        assert_eq!(ws.to_string(), "7366ITCXSAPCH5TN");
    }

    #[test]
    fn test_issuer_extracts_url_from_jwt() {
        let token = make_jwt_token(valid_claims_json());
        let issuer = token.issuer().expect("should extract issuer");
        assert_eq!(issuer.as_str(), "https://cts.example.com/");
    }

    #[test]
    fn test_workspace_id_fails_on_invalid_jwt() {
        let token = bare_token(SecretToken::new("not-a-jwt"), 0);
        let err = token.workspace_id().unwrap_err();
        assert!(matches!(err, AuthError::InvalidToken(_)));
    }

    #[test]
    fn test_issuer_fails_on_missing_claims() {
        let token = make_jwt_token(serde_json::json!({"sub": "user-123"}));
        let err = token.issuer().unwrap_err();
        assert!(matches!(err, AuthError::InvalidToken(_)));
    }

    #[test]
    fn test_workspace_crn_derives_from_region_and_workspace() {
        let mut token = make_jwt_token(valid_claims_json());
        token.set_region("ap-southeast-2.aws");
        let crn = token.workspace_crn().expect("should derive workspace CRN");
        assert_eq!(crn.to_string(), "crn:ap-southeast-2.aws:7366ITCXSAPCH5TN");
    }

    #[test]
    fn test_workspace_crn_fails_without_region() {
        let token = make_jwt_token(valid_claims_json());
        let err = token.workspace_crn().unwrap_err();
        assert!(matches!(err, AuthError::NotAuthenticated));
    }

    #[test]
    fn test_workspace_crn_fails_with_invalid_region() {
        let mut token = make_jwt_token(valid_claims_json());
        token.set_region("invalid-region");
        let err = token.workspace_crn().unwrap_err();
        assert!(matches!(err, AuthError::Server(_)));
    }
}