1use std::time::{SystemTime, UNIX_EPOCH};
2
3use cts_common::claims::Claims;
4use cts_common::{Crn, Region, WorkspaceId};
5use url::Url;
6
7use crate::{http_client, AuthError, SecretToken};
8
9impl stack_profile::ProfileData for Token {
10 const FILENAME: &'static str = "auth.json";
11 const MODE: Option<u32> = Some(0o600);
12}
13
/// Safety margin, in seconds, used by `Token::is_expired`: a token counts as
/// expired once the current time is within this many seconds of its
/// `expires_at`, even though `Token::is_usable` may still return `true`.
const EXPIRY_LEEWAY_SECS: u64 = 90;
20
21#[derive(Debug, serde::Serialize, serde::Deserialize)]
26pub struct Token {
27 pub(crate) access_token: SecretToken,
28 #[serde(default, skip_serializing_if = "Option::is_none")]
29 pub(crate) refresh_token: Option<SecretToken>,
30 pub(crate) token_type: String,
31 pub(crate) expires_at: u64,
32 #[serde(default, skip_serializing_if = "Option::is_none")]
33 pub(crate) region: Option<String>,
34 #[serde(default, skip_serializing_if = "Option::is_none")]
35 pub(crate) client_id: Option<String>,
36 #[serde(default, skip_serializing_if = "Option::is_none")]
37 pub(crate) device_instance_id: Option<String>,
38}
39
40impl Token {
41 pub fn access_token(&self) -> &SecretToken {
46 &self.access_token
47 }
48
49 pub fn token_type(&self) -> &str {
51 &self.token_type
52 }
53
54 pub fn expires_at(&self) -> u64 {
56 self.expires_at
57 }
58
59 pub fn expires_in(&self) -> u64 {
61 let now = SystemTime::now()
62 .duration_since(UNIX_EPOCH)
63 .unwrap_or_default()
64 .as_secs();
65 self.expires_at.saturating_sub(now)
66 }
67
68 pub fn is_expired(&self) -> bool {
77 let now = SystemTime::now()
78 .duration_since(UNIX_EPOCH)
79 .unwrap_or_default()
80 .as_secs();
81 now + EXPIRY_LEEWAY_SECS >= self.expires_at
82 }
83
84 pub fn is_usable(&self) -> bool {
89 let now = SystemTime::now()
90 .duration_since(UNIX_EPOCH)
91 .unwrap_or_default()
92 .as_secs();
93 now < self.expires_at
94 }
95
96 pub fn refresh_token(&self) -> Option<&SecretToken> {
98 self.refresh_token.as_ref()
99 }
100
101 pub fn take_refresh_token(&mut self) -> Option<SecretToken> {
103 self.refresh_token.take()
104 }
105
106 pub fn region(&self) -> Option<&str> {
108 self.region.as_deref()
109 }
110
111 pub fn client_id(&self) -> Option<&str> {
113 self.client_id.as_deref()
114 }
115
116 pub(crate) fn set_region(&mut self, region: impl Into<String>) {
118 self.region = Some(region.into());
119 }
120
121 pub(crate) fn set_client_id(&mut self, client_id: impl Into<String>) {
123 self.client_id = Some(client_id.into());
124 }
125
126 pub fn device_instance_id(&self) -> Option<&str> {
128 self.device_instance_id.as_deref()
129 }
130
131 pub(crate) fn set_device_instance_id(&mut self, id: impl Into<String>) {
133 self.device_instance_id = Some(id.into());
134 }
135
136 pub fn workspace_id(&self) -> Result<WorkspaceId, AuthError> {
141 self.decode_claims().map(|c| c.workspace)
142 }
143
144 pub fn workspace_crn(&self) -> Result<Crn, AuthError> {
149 let workspace_id = self.workspace_id()?;
150 let region: Region = self
151 .region()
152 .ok_or(AuthError::NotAuthenticated)?
153 .parse()
154 .map_err(|e: cts_common::RegionError| AuthError::Server(e.to_string()))?;
155 Ok(Crn::new(region, workspace_id))
156 }
157
158 pub fn issuer(&self) -> Result<Url, AuthError> {
163 let claims = self.decode_claims()?;
164 claims.iss.parse().map_err(AuthError::from)
165 }
166
167 fn decode_claims(&self) -> Result<Claims, AuthError> {
172 use jsonwebtoken::{decode, decode_header, DecodingKey, Validation};
173 use std::collections::HashSet;
174
175 let token_str = self.access_token.as_str();
176 let header = decode_header(token_str)
177 .map_err(|e| AuthError::InvalidToken(format!("invalid JWT header: {e}")))?;
178
179 let dummy_key = DecodingKey::from_secret(&[]);
180 let mut validation = Validation::new(header.alg);
181 validation.validate_exp = false;
182 validation.validate_aud = false;
183 validation.required_spec_claims = HashSet::new();
184 validation.insecure_disable_signature_validation();
185
186 decode(token_str, &dummy_key, &validation)
187 .map(|data| data.claims)
188 .map_err(|e| AuthError::InvalidToken(format!("failed to decode JWT claims: {e}")))
189 }
190
191 pub async fn refresh(
206 refresh_token: &SecretToken,
207 base_url: &Url,
208 client_id: &str,
209 device_instance_id: Option<&str>,
210 ) -> Result<Token, AuthError> {
211 let token_url = base_url.join("oauth/token")?;
212
213 tracing::debug!(url = %token_url, "refreshing token");
214
215 let resp = http_client()
216 .post(token_url)
217 .form(&RefreshRequest {
218 grant_type: "refresh_token",
219 client_id,
220 refresh_token: refresh_token.as_str(),
221 device_instance_id,
222 })
223 .send()
224 .await?;
225
226 if !resp.status().is_success() {
227 let err: RefreshErrorResponse = resp.json().await?;
228 tracing::debug!(error = %err.error, "token refresh failed");
229 return Err(match err.error.as_str() {
230 "invalid_grant" => AuthError::InvalidGrant,
231 "invalid_client" => AuthError::InvalidClient,
232 "access_denied" => AuthError::AccessDenied,
233 _ => AuthError::Server(err.error_description),
234 });
235 }
236
237 let token_resp: RefreshResponse = resp.json().await?;
238 let now = SystemTime::now()
239 .duration_since(UNIX_EPOCH)
240 .unwrap_or_default()
241 .as_secs();
242
243 Ok(Token {
244 access_token: token_resp.access_token,
245 token_type: token_resp.token_type,
246 expires_at: now + token_resp.expires_in,
247 refresh_token: token_resp.refresh_token,
248 region: None,
249 client_id: None,
250 device_instance_id: None,
254 })
255 }
256}
257
/// Form body POSTed to the token endpoint for the OAuth `refresh_token` grant.
///
/// Field names are the form parameter names sent on the wire, and field order
/// is the order they are serialized in.
#[derive(serde::Serialize)]
struct RefreshRequest<'a> {
    /// Set to `"refresh_token"` by `Token::refresh`.
    grant_type: &'a str,
    client_id: &'a str,
    /// Raw refresh token. No `Debug` is derived here, so it cannot be printed via `{:?}`.
    refresh_token: &'a str,
    /// Left out of the form entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    device_instance_id: Option<&'a str>,
}
266
/// Successful (2xx) JSON body returned by the token endpoint.
#[derive(serde::Deserialize)]
struct RefreshResponse {
    access_token: SecretToken,
    token_type: String,
    /// Lifetime of `access_token` in seconds, counted from when the response is handled.
    expires_in: u64,
    /// `None` when the server does not send a new refresh token.
    #[serde(default)]
    refresh_token: Option<SecretToken>,
}
275
/// JSON error body returned by the token endpoint on a non-2xx status.
#[derive(serde::Deserialize)]
struct RefreshErrorResponse {
    /// Machine-readable OAuth error code, e.g. `invalid_grant`.
    error: String,
    /// Human-readable detail; an empty string when the server omits it.
    #[serde(default)]
    error_description: String,
}
282
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AuthError;
    use mocktail::prelude::*;

    /// Builds a non-JWT token expiring `expires_in` seconds from now,
    /// optionally carrying a refresh token.
    fn make_token(expires_in: u64, refresh: bool) -> Token {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();

        Token {
            access_token: SecretToken::new("test-access-token"),
            token_type: "Bearer".to_string(),
            expires_at: now + expires_in,
            refresh_token: refresh.then(|| SecretToken::new("test-refresh-token")),
            region: None,
            client_id: None,
            device_instance_id: None,
        }
    }

    /// A successful token-endpoint response that rotates the refresh token.
    fn refresh_response_json() -> serde_json::Value {
        serde_json::json!({
            "access_token": "new-access-token",
            "token_type": "Bearer",
            "expires_in": 3600,
            "refresh_token": "new-refresh-token"
        })
    }

    /// An OAuth error body with the given code and a derived description.
    fn error_json(error: &str) -> serde_json::Value {
        serde_json::json!({
            "error": error,
            "error_description": format!("{error} occurred")
        })
    }

    /// Starts a mock HTTP server serving `mocks`.
    async fn start_server(mocks: MockSet) -> MockServer {
        let server = MockServer::new_http("token-refresh-test").with_mocks(mocks);
        server.start().await.unwrap();
        server
    }

    #[test]
    fn test_secret_token_debug_does_not_leak() {
        let token = SecretToken::new("super_secret_value");
        let debug = format!("{:?}", token);
        assert!(
            !debug.contains("super_secret_value"),
            "SecretToken Debug should not contain the secret, got: {debug}"
        );
    }

    #[test]
    fn test_fresh_token_is_not_expired() {
        let token = make_token(3600, true);
        assert!(!token.is_expired());
        assert!(token.is_usable());
    }

    #[test]
    fn test_token_within_leeway_is_expired_but_usable() {
        // Inside the refresh leeway, but the token has not actually lapsed yet.
        let token = make_token(EXPIRY_LEEWAY_SECS / 3, false);
        assert!(token.is_expired());
        assert!(token.is_usable());
    }

    #[test]
    fn test_token_at_expiry_is_unusable() {
        let token = make_token(0, false);
        assert!(token.is_expired());
        assert!(!token.is_usable());
        assert_eq!(token.expires_in(), 0);
    }

    #[test]
    fn test_take_refresh_token_leaves_none() {
        let mut token = make_token(3600, true);
        let taken = token
            .take_refresh_token()
            .expect("token was built with a refresh token");
        assert_eq!(taken.as_str(), "test-refresh-token");
        assert!(token.refresh_token().is_none());
        assert!(token.take_refresh_token().is_none());
    }

    #[tokio::test]
    async fn test_refresh_success() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/oauth/token");
            then.json(refresh_response_json());
        });
        let server = start_server(mocks).await;
        let base_url = server.url("");

        let refresh_token = SecretToken::new("test-refresh-token");
        let refreshed = Token::refresh(&refresh_token, &base_url, "cli", None)
            .await
            .unwrap();

        assert_eq!(refreshed.access_token().as_str(), "new-access-token");
        assert_eq!(refreshed.token_type(), "Bearer");
        assert_eq!(
            refreshed.refresh_token().unwrap().as_str(),
            "new-refresh-token"
        );
        assert!(!refreshed.is_expired());
        assert!((3598..=3600).contains(&refreshed.expires_in()));
    }

    #[tokio::test]
    async fn test_refresh_invalid_grant() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/oauth/token");
            then.bad_request().json(error_json("invalid_grant"));
        });
        let server = start_server(mocks).await;
        let base_url = server.url("");

        let refresh_token = SecretToken::new("test-refresh-token");
        let err = Token::refresh(&refresh_token, &base_url, "cli", None)
            .await
            .unwrap_err();

        assert!(matches!(err, AuthError::InvalidGrant));
    }

    #[tokio::test]
    async fn test_refresh_invalid_client() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/oauth/token");
            then.bad_request().json(error_json("invalid_client"));
        });
        let server = start_server(mocks).await;
        let base_url = server.url("");

        let refresh_token = SecretToken::new("test-refresh-token");
        let err = Token::refresh(&refresh_token, &base_url, "cli", None)
            .await
            .unwrap_err();

        assert!(matches!(err, AuthError::InvalidClient));
    }

    #[tokio::test]
    async fn test_refresh_access_denied() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/oauth/token");
            then.bad_request().json(error_json("access_denied"));
        });
        let server = start_server(mocks).await;
        let base_url = server.url("");

        let refresh_token = SecretToken::new("test-refresh-token");
        let err = Token::refresh(&refresh_token, &base_url, "cli", None)
            .await
            .unwrap_err();

        assert!(matches!(err, AuthError::AccessDenied));
    }

    #[tokio::test]
    async fn test_refresh_unknown_error() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/oauth/token");
            then.bad_request().json(error_json("something_unexpected"));
        });
        let server = start_server(mocks).await;
        let base_url = server.url("");

        let refresh_token = SecretToken::new("test-refresh-token");
        let err = Token::refresh(&refresh_token, &base_url, "cli", None)
            .await
            .unwrap_err();

        assert!(matches!(&err, AuthError::Server(desc) if desc == "something_unexpected occurred"));
    }

    #[tokio::test]
    async fn test_refresh_response_without_new_refresh_token() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/oauth/token");
            then.json(serde_json::json!({
                "access_token": "new-access-token",
                "token_type": "Bearer",
                "expires_in": 3600
            }));
        });
        let server = start_server(mocks).await;
        let base_url = server.url("");

        let refresh_token = SecretToken::new("test-refresh-token");
        let refreshed = Token::refresh(&refresh_token, &base_url, "cli", None)
            .await
            .unwrap();

        assert_eq!(refreshed.access_token().as_str(), "new-access-token");
        assert!(refreshed.refresh_token().is_none());
    }

    // Purely synchronous: nothing here awaits, so no async runtime is needed.
    #[test]
    fn test_refresh_debug_does_not_leak_tokens() {
        let token = make_token(3600, true);
        let debug = format!("{:?}", token);
        assert!(
            !debug.contains("test-access-token"),
            "Debug output should not contain access token, got: {debug}"
        );
        assert!(
            !debug.contains("test-refresh-token"),
            "Debug output should not contain refresh token, got: {debug}"
        );
    }

    /// Builds a token whose access token is an HS256 JWT carrying `claims_json`.
    fn make_jwt_token(claims_json: serde_json::Value) -> Token {
        use jsonwebtoken::{encode, EncodingKey, Header};
        let jwt = encode(
            &Header::default(),
            &claims_json,
            &EncodingKey::from_secret(b"test-secret"),
        )
        .expect("failed to encode JWT");

        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();

        Token {
            access_token: SecretToken::new(jwt),
            token_type: "Bearer".to_string(),
            expires_at: now + 3600,
            refresh_token: None,
            region: None,
            client_id: None,
            device_instance_id: None,
        }
    }

    /// A complete, well-formed claims set.
    fn valid_claims_json() -> serde_json::Value {
        serde_json::json!({
            "workspace": "7366ITCXSAPCH5TN",
            "iss": "https://cts.example.com",
            "sub": "user-123",
            "aud": "https://cts.example.com",
            "iat": 1700000000u64,
            "exp": 1700003600u64,
            "scope": "dataset:create"
        })
    }

    #[test]
    fn test_workspace_id_extracts_from_jwt() {
        let token = make_jwt_token(valid_claims_json());
        let ws = token.workspace_id().expect("should extract workspace ID");
        assert_eq!(ws.to_string(), "7366ITCXSAPCH5TN");
    }

    #[test]
    fn test_issuer_extracts_url_from_jwt() {
        let token = make_jwt_token(valid_claims_json());
        let issuer = token.issuer().expect("should extract issuer");
        assert_eq!(issuer.as_str(), "https://cts.example.com/");
    }

    #[test]
    fn test_workspace_id_fails_on_invalid_jwt() {
        let token = Token {
            access_token: SecretToken::new("not-a-jwt"),
            token_type: "Bearer".to_string(),
            expires_at: 0,
            refresh_token: None,
            region: None,
            client_id: None,
            device_instance_id: None,
        };
        let err = token.workspace_id().unwrap_err();
        assert!(matches!(err, AuthError::InvalidToken(_)));
    }

    #[test]
    fn test_issuer_fails_on_missing_claims() {
        let token = make_jwt_token(serde_json::json!({"sub": "user-123"}));
        let err = token.issuer().unwrap_err();
        assert!(matches!(err, AuthError::InvalidToken(_)));
    }

    #[test]
    fn test_workspace_crn_derives_from_region_and_workspace() {
        let mut token = make_jwt_token(valid_claims_json());
        token.set_region("ap-southeast-2.aws");
        let crn = token.workspace_crn().expect("should derive workspace CRN");
        assert_eq!(crn.to_string(), "crn:ap-southeast-2.aws:7366ITCXSAPCH5TN");
    }

    #[test]
    fn test_workspace_crn_fails_without_region() {
        let token = make_jwt_token(valid_claims_json());
        let err = token.workspace_crn().unwrap_err();
        assert!(matches!(err, AuthError::NotAuthenticated));
    }

    #[test]
    fn test_workspace_crn_fails_with_invalid_region() {
        let mut token = make_jwt_token(valid_claims_json());
        token.set_region("invalid-region");
        let err = token.workspace_crn().unwrap_err();
        assert!(matches!(err, AuthError::Server(_)));
    }
}
577}