1mod jws;
74
75use crate::build_errors::Error as BuilderError;
76use crate::constants::DEFAULT_SCOPE;
77use crate::credentials::dynamic::CredentialsProvider;
78use crate::credentials::{CacheableResource, Credentials};
79use crate::errors::{self, CredentialsError};
80use crate::headers_util::build_cacheable_headers;
81use crate::token::{CachedTokenProvider, Token, TokenProvider};
82use crate::token_cache::TokenCache;
83use crate::{BuildResult, Result};
84use async_trait::async_trait;
85use http::{Extensions, HeaderMap};
86use jws::{CLOCK_SKEW_FUDGE, DEFAULT_TOKEN_TIMEOUT, JwsClaims, JwsHeader};
87use rustls::crypto::CryptoProvider;
88use rustls::sign::Signer;
89use rustls_pemfile::Item;
90use serde_json::Value;
91use std::sync::Arc;
92use time::OffsetDateTime;
93use tokio::time::Instant;
94
/// Selects which claim the self-signed JWT uses to describe its intended access.
///
/// A credential carries exactly one of the two: either an audience (emitted as
/// the JWT `aud` claim) or a list of OAuth2 scopes (emitted, space-separated,
/// as the JWT `scope` claim).
#[derive(Clone, Debug, PartialEq)]
pub enum AccessSpecifier {
    /// Request a token for this audience; becomes the `aud` claim.
    Audience(String),

    /// Request a token with these scopes; joined with `' '` into the `scope` claim.
    Scopes(Vec<String>),
}
133
134impl AccessSpecifier {
135 fn audience(&self) -> Option<&String> {
136 match self {
137 AccessSpecifier::Audience(aud) => Some(aud),
138 AccessSpecifier::Scopes(_) => None,
139 }
140 }
141
142 fn scopes(&self) -> Option<&[String]> {
143 match self {
144 AccessSpecifier::Scopes(scopes) => Some(scopes),
145 AccessSpecifier::Audience(_) => None,
146 }
147 }
148
149 pub fn from_scopes<I, S>(scopes: I) -> Self
163 where
164 I: IntoIterator<Item = S>,
165 S: Into<String>,
166 {
167 AccessSpecifier::Scopes(scopes.into_iter().map(|s| s.into()).collect())
168 }
169
170 pub fn from_audience<S: Into<String>>(audience: S) -> Self {
184 AccessSpecifier::Audience(audience.into())
185 }
186}
187
/// Builder for service account [`Credentials`].
///
/// Stores the raw JSON key; it is parsed (and can fail) only when the
/// credentials or the token provider are built.
pub struct Builder {
    /// Raw JSON value of the service account key.
    service_account_key: Value,
    /// Whether minted JWTs carry an `aud` claim or a `scope` claim.
    access_specifier: AccessSpecifier,
    /// Optional quota project, emitted as an extra header next to the token.
    quota_project_id: Option<String>,
}
211
212impl Builder {
213 pub fn new(service_account_key: Value) -> Self {
220 Self {
221 service_account_key,
222 access_specifier: AccessSpecifier::Scopes([DEFAULT_SCOPE].map(str::to_string).to_vec()),
223 quota_project_id: None,
224 }
225 }
226
227 pub fn with_access_specifier(mut self, access_specifier: AccessSpecifier) -> Self {
249 self.access_specifier = access_specifier;
250 self
251 }
252
253 pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
262 self.quota_project_id = Some(quota_project_id.into());
263 self
264 }
265
266 fn build_token_provider(self) -> BuildResult<ServiceAccountTokenProvider> {
267 let service_account_key =
268 serde_json::from_value::<ServiceAccountKey>(self.service_account_key)
269 .map_err(BuilderError::parsing)?;
270
271 Ok(ServiceAccountTokenProvider {
272 service_account_key,
273 access_specifier: self.access_specifier,
274 })
275 }
276
277 pub fn build(self) -> BuildResult<Credentials> {
290 Ok(Credentials {
291 inner: Arc::new(ServiceAccountCredentials {
292 quota_project_id: self.quota_project_id.clone(),
293 token_provider: TokenCache::new(self.build_token_provider()?),
294 }),
295 })
296 }
297}
298
/// The subset of a service account key JSON document needed to mint JWTs.
///
/// `Debug` is implemented by hand so the private key is never printed.
#[derive(serde::Deserialize, Default, Clone)]
struct ServiceAccountKey {
    /// Service account email; used as both the `iss` and `sub` JWT claims.
    client_email: String,
    /// Key identifier; emitted as the `kid` JWT header.
    private_key_id: String,
    /// PEM-encoded private key. Only PKCS#8 (`BEGIN PRIVATE KEY`) is accepted
    /// by the signer.
    private_key: String,
    /// Project the service account belongs to.
    project_id: String,
    /// Universe domain, when present in the key JSON.
    universe_domain: Option<String>,
}
317
318impl std::fmt::Debug for ServiceAccountKey {
319 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
320 f.debug_struct("ServiceAccountKey")
321 .field("client_email", &self.client_email)
322 .field("private_key_id", &self.private_key_id)
323 .field("private_key", &"[censored]")
324 .field("project_id", &self.project_id)
325 .field("universe_domain", &self.universe_domain)
326 .finish()
327 }
328}
329
/// Credentials that authenticate with self-signed JWTs minted from a service
/// account key.
///
/// Tokens are obtained from `token_provider` (wrapped in a `TokenCache` by
/// `Builder::build`) and converted into request headers.
#[derive(Debug)]
struct ServiceAccountCredentials<T>
where
    T: CachedTokenProvider,
{
    /// Source of (cached) access tokens.
    token_provider: T,
    /// When set, also emitted as the quota project header.
    quota_project_id: Option<String>,
}
338
/// Uncached token provider that mints a new self-signed JWT on every call,
/// without any network round trip.
#[derive(Debug)]
struct ServiceAccountTokenProvider {
    /// Parsed key used to sign the JWT.
    service_account_key: ServiceAccountKey,
    /// Decides between the `aud` and `scope` claims.
    access_specifier: AccessSpecifier,
}
344
345fn token_issue_time(current_time: OffsetDateTime) -> OffsetDateTime {
346 current_time - CLOCK_SKEW_FUDGE
347}
348
349fn token_expiry_time(current_time: OffsetDateTime) -> OffsetDateTime {
350 current_time + CLOCK_SKEW_FUDGE + DEFAULT_TOKEN_TIMEOUT
351}
352
353#[async_trait]
354impl TokenProvider for ServiceAccountTokenProvider {
355 async fn token(&self) -> Result<Token> {
356 let expires_at = Instant::now() + CLOCK_SKEW_FUDGE + DEFAULT_TOKEN_TIMEOUT;
357 let tg = ServiceAccountTokenGenerator {
358 audience: self.access_specifier.audience().cloned(),
359 scopes: self
360 .access_specifier
361 .scopes()
362 .map(|scopes| scopes.join(" ")),
363 service_account_key: self.service_account_key.clone(),
364 target_audience: None,
365 };
366
367 let token = tg.generate()?;
368
369 let token = Token {
370 token,
371 token_type: "Bearer".to_string(),
372 expires_at: Some(expires_at),
373 metadata: None,
374 };
375 Ok(token)
376 }
377}
378
/// Inputs for signing a single JWT with a service account key.
#[derive(Default, Clone)]
struct ServiceAccountTokenGenerator {
    /// Key providing the issuer, key id, and signing key.
    service_account_key: ServiceAccountKey,
    /// Value for the `aud` claim, if any.
    audience: Option<String>,
    /// Space-separated value for the `scope` claim, if any.
    scopes: Option<String>,
    /// Value for the `target_audience` claim; set by the ID token flow.
    target_audience: Option<String>,
}
386
387impl ServiceAccountTokenGenerator {
388 fn generate(&self) -> Result<String> {
389 let signer = self.signer(&self.service_account_key.private_key)?;
390
391 let current_time = OffsetDateTime::now_utc();
395
396 let claims = JwsClaims {
397 iss: self.service_account_key.client_email.clone(),
398 scope: self.scopes.clone(),
399 target_audience: self.target_audience.clone(),
400 aud: self.audience.clone(),
401 exp: token_expiry_time(current_time),
402 iat: token_issue_time(current_time),
403 typ: None,
404 sub: Some(self.service_account_key.client_email.clone()),
405 };
406
407 let header = JwsHeader {
408 alg: "RS256",
409 typ: "JWT",
410 kid: &self.service_account_key.private_key_id,
411 };
412 let encoded_header_claims = format!("{}.{}", header.encode()?, claims.encode()?);
413 let sig = signer
414 .sign(encoded_header_claims.as_bytes())
415 .map_err(errors::non_retryable)?;
416 use base64::prelude::{BASE64_URL_SAFE_NO_PAD, Engine as _};
417 let token = format!(
418 "{}.{}",
419 encoded_header_claims,
420 &BASE64_URL_SAFE_NO_PAD.encode(sig)
421 );
422
423 Ok(token)
424 }
425
426 fn signer(&self, private_key: &String) -> Result<Box<dyn Signer>> {
428 let key_provider = CryptoProvider::get_default().map_or_else(
429 || rustls::crypto::ring::default_provider().key_provider,
430 |p| p.key_provider,
431 );
432
433 let private_key = rustls_pemfile::read_one(&mut private_key.as_bytes())
434 .map_err(errors::non_retryable)?
435 .ok_or_else(|| {
436 errors::non_retryable_from_str("missing PEM section in service account key")
437 })?;
438 let pk = match private_key {
439 Item::Pkcs8Key(item) => key_provider.load_private_key(item.into()),
440 other => {
441 return Err(Self::unexpected_private_key_error(other));
442 }
443 };
444 let sk = pk.map_err(errors::non_retryable)?;
445 sk.choose_scheme(&[rustls::SignatureScheme::RSA_PKCS1_SHA256])
446 .ok_or_else(|| errors::non_retryable_from_str("Unable to choose RSA_PKCS1_SHA256 signing scheme as it is not supported by current signer"))
447 }
448
449 fn unexpected_private_key_error(private_key_format: Item) -> CredentialsError {
450 errors::non_retryable_from_str(format!(
451 "expected key to be in form of PKCS8, found {private_key_format:?}",
452 ))
453 }
454}
455
456#[async_trait::async_trait]
457impl<T> CredentialsProvider for ServiceAccountCredentials<T>
458where
459 T: CachedTokenProvider,
460{
461 async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
462 let token = self.token_provider.token(extensions).await?;
463 build_cacheable_headers(&token, &self.quota_project_id)
464 }
465}
466
#[cfg(google_cloud_unstable_id_token)]
pub mod idtoken {
    //! ID token credentials derived from a service account key.
    //!
    //! The flow has two steps:
    //! 1. A self-signed JWT carrying a `target_audience` claim is POSTed, with
    //!    the JWT bearer grant type, to the token server.
    //! 2. The response body is used as the ID token.

    use crate::Result;
    use crate::build_errors::Error as BuilderError;
    use crate::constants::{JWT_BEARER_GRANT_TYPE, OAUTH2_TOKEN_SERVER_URL};
    use crate::credentials::CacheableResource;
    use crate::credentials::idtoken::dynamic::IDTokenCredentialsProvider;
    use crate::credentials::service_account::{ServiceAccountKey, ServiceAccountTokenGenerator};
    use crate::token::{CachedTokenProvider, Token, TokenProvider};
    use crate::token_cache::TokenCache;
    use crate::{BuildResult, credentials::idtoken::IDTokenCredentials};
    use async_trait::async_trait;
    use gax::error::CredentialsError;
    use http::Extensions;
    use reqwest::Client;
    use serde_json::Value;
    use std::sync::Arc;
    use std::time::Duration;
    use time::OffsetDateTime;
    use tokio::time::Instant;

    /// ID token credentials whose tokens come from a caching provider.
    #[derive(Debug)]
    struct ServiceAccountCredentials<T>
    where
        T: CachedTokenProvider,
    {
        token_provider: T,
    }

    #[async_trait]
    impl<T> IDTokenCredentialsProvider for ServiceAccountCredentials<T>
    where
        T: CachedTokenProvider,
    {
        /// Returns the current ID token string.
        ///
        /// No entity tag is supplied, so a `NotModified` answer is unexpected
        /// and reported as a non-retryable error.
        async fn id_token(&self) -> Result<String> {
            let cached_token = self.token_provider.token(Extensions::new()).await?;
            match cached_token {
                CacheableResource::New { data, .. } => Ok(data.token),
                CacheableResource::NotModified => {
                    Err(CredentialsError::from_msg(false, "failed to fetch token"))
                }
            }
        }
    }

    /// Uncached provider that exchanges a signed assertion for an ID token.
    #[derive(Debug)]
    struct ServiceAccountTokenProvider {
        service_account_key: ServiceAccountKey,
        /// `aud` claim of the self-signed assertion.
        audience: String,
        /// Audience the resulting ID token is issued for.
        target_audience: String,
        /// Endpoint the assertion is POSTed to.
        token_server_url: String,
    }

    /// Best-effort extraction of a cache deadline from a JWT-formatted token.
    ///
    /// Reads the payload's numeric `exp` claim (seconds since the Unix epoch)
    /// and converts the remaining lifetime into a monotonic [`Instant`]. An
    /// already-expired token maps to "now".
    ///
    /// Returns `None` when the token is not a decodable JWT or has no numeric
    /// `exp`. That case preserves the previous "no expiry" behavior.
    fn id_token_expiry(token: &str) -> Option<Instant> {
        use base64::prelude::{BASE64_URL_SAFE_NO_PAD, Engine as _};

        let payload = token.split('.').nth(1)?;
        let bytes = BASE64_URL_SAFE_NO_PAD.decode(payload).ok()?;
        let claims: Value = serde_json::from_slice(&bytes).ok()?;
        let exp = claims.get("exp")?.as_i64()?;
        let remaining = exp.saturating_sub(OffsetDateTime::now_utc().unix_timestamp());
        let remaining = u64::try_from(remaining).unwrap_or(0);
        Some(Instant::now() + Duration::from_secs(remaining))
    }

    #[async_trait]
    impl TokenProvider for ServiceAccountTokenProvider {
        /// Signs an assertion and exchanges it for an ID token.
        ///
        /// # Errors
        /// Returns an error in these cases:
        /// - signing the assertion fails;
        /// - the HTTP request fails;
        /// - the server returns a non-success status;
        /// - the response body cannot be read.
        async fn token(&self) -> Result<Token> {
            let tg = ServiceAccountTokenGenerator {
                audience: Some(self.audience.clone()),
                service_account_key: self.service_account_key.clone(),
                target_audience: Some(self.target_audience.clone()),
                scopes: None,
            };
            let assertion = tg.generate()?;

            // NOTE(review): a new `reqwest::Client` is created per token
            // request; consider keeping one in the provider for reuse.
            let client = Client::new();
            let request = client.post(&self.token_server_url).form(&[
                ("grant_type", JWT_BEARER_GRANT_TYPE.to_string()),
                ("assertion", assertion),
            ]);

            let response = request
                .send()
                .await
                .map_err(|e| crate::errors::from_http_error(e, "failed to exchange id token"))?;

            if !response.status().is_success() {
                let err =
                    crate::errors::from_http_response(response, "failed to fetch id token").await;
                return Err(err);
            }

            let token = response
                .text()
                .await
                .map_err(|e| CredentialsError::from_source(!e.is_decode(), e))?;

            // ID tokens are short-lived. With `expires_at: None` the cache
            // presumably never refreshes, so derive a deadline from the
            // token's own `exp` claim whenever it can be decoded.
            let expires_at = id_token_expiry(&token);

            Ok(Token {
                token,
                token_type: "Bearer".to_string(),
                expires_at,
                metadata: None,
            })
        }
    }

    /// Builder for service account ID token credentials.
    pub struct Builder {
        service_account_key: Value,
        target_audience: String,
        token_server_url: String,
    }

    impl Builder {
        /// Creates a builder that issues ID tokens for `target_audience`.
        ///
        /// `service_account_key` is the JSON value of the key; it is parsed in
        /// [`Builder::build`]. The token server defaults to
        /// `OAUTH2_TOKEN_SERVER_URL`.
        pub fn new<S: Into<String>>(target_audience: S, service_account_key: Value) -> Self {
            Self {
                service_account_key,
                target_audience: target_audience.into(),
                token_server_url: OAUTH2_TOKEN_SERVER_URL.to_string(),
            }
        }

        /// Overrides the endpoint the assertion is POSTed to (tests only).
        #[cfg(test)]
        pub(crate) fn with_token_server_url<S: Into<String>>(mut self, url: S) -> Self {
            self.token_server_url = url.into();
            self
        }

        /// Parses the key and creates the uncached token provider.
        fn build_token_provider(self) -> BuildResult<ServiceAccountTokenProvider> {
            let service_account_key =
                serde_json::from_value::<ServiceAccountKey>(self.service_account_key)
                    .map_err(BuilderError::parsing)?;
            Ok(ServiceAccountTokenProvider {
                service_account_key,
                // The assertion's `aud` is always the constant token endpoint,
                // even when `token_server_url` has been overridden.
                audience: OAUTH2_TOKEN_SERVER_URL.to_string(),
                target_audience: self.target_audience,
                token_server_url: self.token_server_url,
            })
        }

        /// Builds [`IDTokenCredentials`] backed by a cached token provider.
        ///
        /// # Errors
        /// Returns a parsing error when the service account key is malformed.
        pub fn build(self) -> BuildResult<IDTokenCredentials> {
            let creds = ServiceAccountCredentials {
                token_provider: TokenCache::new(self.build_token_provider()?),
            };
            Ok(IDTokenCredentials {
                inner: Arc::new(creds),
            })
        }
    }
}
616
#[cfg(test)]
mod tests {
    use super::*;
    use crate::credentials::QUOTA_PROJECT_KEY;
    use crate::credentials::tests::{
        PKCS8_PK, b64_decode_to_json, get_headers_from_cache, get_token_from_headers,
    };
    use crate::token::tests::MockTokenProvider;
    use http::HeaderValue;
    use http::header::AUTHORIZATION;
    use rsa::pkcs1::EncodeRsaPrivateKey;
    use rsa::pkcs8::LineEnding;
    use rustls_pemfile::Item;
    use serde_json::Value;
    use serde_json::json;
    use std::error::Error as _;
    use std::time::Duration;

    type TestResult = std::result::Result<(), Box<dyn std::error::Error>>;

    /// Splits a compact JWS `<header>.<claims>.<sig>` into named captures.
    const SSJ_REGEX: &str = r"(?<header>[^\.]+)\.(?<claims>[^\.]+)\.(?<sig>[^\.]+)";

    #[test]
    fn debug_token_provider() {
        let expected = ServiceAccountKey {
            client_email: "test-client-email".to_string(),
            private_key_id: "test-private-key-id".to_string(),
            private_key: "super-duper-secret-private-key".to_string(),
            project_id: "test-project-id".to_string(),
            universe_domain: Some("test-universe-domain".to_string()),
        };
        let fmt = format!("{expected:?}");
        assert!(fmt.contains("test-client-email"), "{fmt}");
        assert!(fmt.contains("test-private-key-id"), "{fmt}");
        assert!(!fmt.contains("super-duper-secret-private-key"), "{fmt}");
        assert!(fmt.contains("test-project-id"), "{fmt}");
        assert!(fmt.contains("test-universe-domain"), "{fmt}");
    }

    #[test]
    fn validate_token_issue_time() {
        let current_time = OffsetDateTime::now_utc();
        let issue_time = token_issue_time(current_time);
        assert_eq!(issue_time, current_time - CLOCK_SKEW_FUDGE);
    }

    #[test]
    fn validate_token_expiry_time() {
        let current_time = OffsetDateTime::now_utc();
        let expiry_time = token_expiry_time(current_time);
        assert_eq!(
            expiry_time,
            current_time + CLOCK_SKEW_FUDGE + DEFAULT_TOKEN_TIMEOUT
        );
    }

    #[tokio::test]
    async fn headers_success_without_quota_project() -> TestResult {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let sac = ServiceAccountCredentials {
            token_provider: TokenCache::new(mock),
            quota_project_id: None,
        };

        let mut extensions = Extensions::new();
        let cached_headers = sac.headers(extensions.clone()).await?;
        let (headers, entity_tag) = match cached_headers {
            CacheableResource::New { entity_tag, data } => (data, entity_tag),
            CacheableResource::NotModified => unreachable!("expecting new headers"),
        };
        let token = headers.get(AUTHORIZATION).unwrap();

        assert_eq!(headers.len(), 1, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());

        // Presenting the entity tag of the headers we already hold must
        // produce `NotModified` instead of a fresh copy.
        extensions.insert(entity_tag);

        let cached_headers = sac.headers(extensions).await?;
        assert!(
            matches!(cached_headers, CacheableResource::NotModified),
            "expecting NotModified when the entity tag matches"
        );
        Ok(())
    }

    #[tokio::test]
    async fn headers_success_with_quota_project() -> TestResult {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        let quota_project = "test-quota-project";

        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let sac = ServiceAccountCredentials {
            token_provider: TokenCache::new(mock),
            quota_project_id: Some(quota_project.to_string()),
        };

        let headers = get_headers_from_cache(sac.headers(Extensions::new()).await.unwrap())?;
        let token = headers.get(AUTHORIZATION).unwrap();
        let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());
        assert_eq!(
            quota_project_header,
            HeaderValue::from_static(quota_project)
        );
        assert!(!quota_project_header.is_sensitive());
        Ok(())
    }

    #[tokio::test]
    async fn headers_failure() {
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Err(errors::non_retryable_from_str("fail")));

        let sac = ServiceAccountCredentials {
            token_provider: TokenCache::new(mock),
            quota_project_id: None,
        };
        assert!(sac.headers(Extensions::new()).await.is_err());
    }

    /// Minimal key JSON with an empty `private_key`; tests fill in the key.
    pub(crate) fn get_mock_service_key() -> Value {
        json!({
            "client_email": "test-client-email",
            "private_key_id": "test-private-key-id",
            "private_key": "",
            "project_id": "test-project-id",
        })
    }

    #[tokio::test]
    async fn get_service_account_headers_pkcs1_private_key_failure() -> TestResult {
        let mut service_account_key = get_mock_service_key();

        let key = crate::credentials::tests::RSA_PRIVATE_KEY
            .to_pkcs1_pem(LineEnding::LF)
            .expect("Failed to encode key to PKCS#1 PEM")
            .to_string();

        service_account_key["private_key"] = Value::from(key);
        let cred = Builder::new(service_account_key).build()?;
        let expected_error_message = "expected key to be in form of PKCS8, found Pkcs1Key";
        assert!(
            cred.headers(Extensions::new())
                .await
                .is_err_and(|e| e.to_string().contains(expected_error_message))
        );
        Ok(())
    }

    #[tokio::test]
    async fn get_service_account_token_pkcs8_key_success() -> TestResult {
        let mut service_account_key = get_mock_service_key();
        service_account_key["private_key"] = Value::from(PKCS8_PK.clone());
        let tp = Builder::new(service_account_key.clone()).build_token_provider()?;

        let token = tp.token().await?;
        let re = regex::Regex::new(SSJ_REGEX).unwrap();
        let captures = re.captures(&token.token).ok_or_else(|| {
            format!(
                r#"Expected token in form: "<header>.<claims>.<sig>". Found token: {}"#,
                token.token
            )
        })?;
        let header = b64_decode_to_json(captures["header"].to_string());
        assert_eq!(header["alg"], "RS256");
        assert_eq!(header["typ"], "JWT");
        assert_eq!(header["kid"], service_account_key["private_key_id"]);

        let claims = b64_decode_to_json(captures["claims"].to_string());
        assert_eq!(claims["iss"], service_account_key["client_email"]);
        assert_eq!(claims["scope"], DEFAULT_SCOPE);
        assert!(claims["iat"].is_number());
        assert!(claims["exp"].is_number());
        assert_eq!(claims["sub"], service_account_key["client_email"]);

        Ok(())
    }

    #[tokio::test]
    async fn header_caching() -> TestResult {
        let private_key = PKCS8_PK.clone();

        let json_value = json!({
            "client_email": "test-client-email",
            "private_key_id": "test-private-key-id",
            "private_key": private_key,
            "project_id": "test-project-id",
            "universe_domain": "test-universe-domain"
        });

        let credentials = Builder::new(json_value).build()?;

        let headers = credentials.headers(Extensions::new()).await?;

        let re = regex::Regex::new(SSJ_REGEX).unwrap();
        let token = get_token_from_headers(headers).unwrap();

        let captures = re.captures(&token).unwrap();

        let claims = b64_decode_to_json(captures["claims"].to_string());
        let first_iat = claims["iat"].as_i64().unwrap();

        // Let the wall clock advance so a freshly minted token would carry a
        // different `iat`. Use the async sleep so the runtime is not blocked.
        tokio::time::sleep(Duration::from_secs(1)).await;

        let token = get_token_from_headers(credentials.headers(Extensions::new()).await?).unwrap();
        let captures = re.captures(&token).unwrap();

        let claims = b64_decode_to_json(captures["claims"].to_string());
        let second_iat = claims["iat"].as_i64().unwrap();

        assert_eq!(first_iat, second_iat);

        Ok(())
    }

    #[tokio::test]
    async fn get_service_account_headers_invalid_key_failure() -> TestResult {
        let mut service_account_key = get_mock_service_key();
        let pem_data = "-----BEGIN PRIVATE KEY-----\nMIGkAg==\n-----END PRIVATE KEY-----";
        service_account_key["private_key"] = Value::from(pem_data);
        let cred = Builder::new(service_account_key).build()?;

        let token = cred.headers(Extensions::new()).await;
        let err = token.unwrap_err();
        assert!(!err.is_transient(), "{err:?}");
        let source = err.source().and_then(|e| e.downcast_ref::<rustls::Error>());
        assert!(matches!(source, Some(rustls::Error::General(_))), "{err:?}");
        Ok(())
    }

    #[tokio::test]
    async fn get_service_account_invalid_json_failure() -> TestResult {
        let service_account_key = Value::from(" ");
        let e = Builder::new(service_account_key).build().unwrap_err();
        assert!(e.is_parsing(), "{e:?}");

        Ok(())
    }

    #[test]
    fn signer_failure() -> TestResult {
        let tp = Builder::new(get_mock_service_key()).build_token_provider()?;
        let tg = ServiceAccountTokenGenerator {
            service_account_key: tp.service_account_key.clone(),
            ..Default::default()
        };

        let signer = tg.signer(&tg.service_account_key.private_key);
        let expected_error_message = "missing PEM section in service account key";
        assert!(signer.is_err_and(|e| e.to_string().contains(expected_error_message)));
        Ok(())
    }

    #[test]
    fn unexpected_private_key_error_message() -> TestResult {
        let expected_message = format!(
            "expected key to be in form of PKCS8, found {:?}",
            Item::Crl(Vec::new().into())
        );

        let error = ServiceAccountTokenGenerator::unexpected_private_key_error(Item::Crl(
            Vec::new().into(),
        ));
        assert!(error.to_string().contains(&expected_message));
        Ok(())
    }

    #[tokio::test]
    async fn get_service_account_headers_with_audience() -> TestResult {
        let mut service_account_key = get_mock_service_key();
        service_account_key["private_key"] = Value::from(PKCS8_PK.clone());
        let headers = Builder::new(service_account_key.clone())
            .with_access_specifier(AccessSpecifier::from_audience("test-audience"))
            .build()?
            .headers(Extensions::new())
            .await?;

        let re = regex::Regex::new(SSJ_REGEX).unwrap();
        let token = get_token_from_headers(headers).unwrap();
        let captures = re.captures(&token).ok_or_else(|| {
            format!(r#"Expected token in form: "<header>.<claims>.<sig>". Found token: {token}"#)
        })?;
        let token_header = b64_decode_to_json(captures["header"].to_string());
        assert_eq!(token_header["alg"], "RS256");
        assert_eq!(token_header["typ"], "JWT");
        assert_eq!(token_header["kid"], service_account_key["private_key_id"]);

        let claims = b64_decode_to_json(captures["claims"].to_string());
        assert_eq!(claims["iss"], service_account_key["client_email"]);
        assert_eq!(claims["scope"], Value::Null);
        assert_eq!(claims["aud"], "test-audience");
        assert!(claims["iat"].is_number());
        assert!(claims["exp"].is_number());
        assert_eq!(claims["sub"], service_account_key["client_email"]);
        Ok(())
    }

    #[tokio::test(start_paused = true)]
    async fn get_service_account_token_verify_expiry_time() -> TestResult {
        let now = Instant::now();
        let mut service_account_key = get_mock_service_key();
        service_account_key["private_key"] = Value::from(PKCS8_PK.clone());
        let token = Builder::new(service_account_key)
            .build_token_provider()?
            .token()
            .await?;

        let expected_expiry = now + CLOCK_SKEW_FUDGE + DEFAULT_TOKEN_TIMEOUT;

        assert_eq!(token.expires_at.unwrap(), expected_expiry);
        Ok(())
    }

    #[tokio::test]
    async fn get_service_account_headers_with_custom_scopes() -> TestResult {
        let mut service_account_key = get_mock_service_key();
        // Two distinct scopes, so the test actually exercises the space-join.
        let scopes = vec![
            "https://www.googleapis.com/auth/pubsub",
            "https://www.googleapis.com/auth/translate",
        ];
        service_account_key["private_key"] = Value::from(PKCS8_PK.clone());
        let headers = Builder::new(service_account_key.clone())
            .with_access_specifier(AccessSpecifier::from_scopes(scopes.clone()))
            .build()?
            .headers(Extensions::new())
            .await?;

        let re = regex::Regex::new(SSJ_REGEX).unwrap();
        let token = get_token_from_headers(headers).unwrap();
        let captures = re.captures(&token).ok_or_else(|| {
            format!(r#"Expected token in form: "<header>.<claims>.<sig>". Found token: {token}"#)
        })?;
        let token_header = b64_decode_to_json(captures["header"].to_string());
        assert_eq!(token_header["alg"], "RS256");
        assert_eq!(token_header["typ"], "JWT");
        assert_eq!(token_header["kid"], service_account_key["private_key_id"]);

        let claims = b64_decode_to_json(captures["claims"].to_string());
        assert_eq!(claims["iss"], service_account_key["client_email"]);
        assert_eq!(claims["scope"], scopes.join(" "));
        assert_eq!(claims["aud"], Value::Null);
        assert!(claims["iat"].is_number());
        assert!(claims["exp"].is_number());
        assert_eq!(claims["sub"], service_account_key["client_email"]);
        Ok(())
    }
}
991
#[cfg(all(test, google_cloud_unstable_id_token))]
mod unstable_tests {
    use super::tests::*;
    use super::*;
    use crate::constants::JWT_BEARER_GRANT_TYPE;
    use crate::credentials::tests::PKCS8_PK;
    use httptest::{
        Expectation, Server,
        matchers::{all_of, any, contains, request, url_decoded},
        responders::*,
    };
    use serde_json::Value;

    type TestResult = std::result::Result<(), Box<dyn std::error::Error>>;

    /// Returns the mock service account key with a valid PKCS#8 private key.
    fn mock_key_with_private_key() -> Value {
        let mut service_account_key = get_mock_service_key();
        service_account_key["private_key"] = Value::from(PKCS8_PK.clone());
        service_account_key
    }

    #[tokio::test]
    async fn idtoken_success() -> TestResult {
        let server = Server::run();
        server.expect(
            Expectation::matching(all_of![
                request::method("POST"),
                request::path("/"),
                request::body(url_decoded(contains(("grant_type", JWT_BEARER_GRANT_TYPE)))),
                request::body(url_decoded(contains(("assertion", any())))),
            ])
            .respond_with(status_code(200).body("test-id-token")),
        );

        let creds = idtoken::Builder::new("test-audience", mock_key_with_private_key())
            .with_token_server_url(server.url("/").to_string())
            .build()?;

        let token = creds.id_token().await?;
        assert_eq!(token, "test-id-token");
        Ok(())
    }

    #[tokio::test]
    async fn idtoken_http_error() -> TestResult {
        let server = Server::run();
        server.expect(
            Expectation::matching(all_of![request::method("POST"), request::path("/"),])
                .respond_with(status_code(501)),
        );

        let creds = idtoken::Builder::new("test-audience", mock_key_with_private_key())
            .with_token_server_url(server.url("/").to_string())
            .build()?;

        let err = creds.id_token().await.unwrap_err();
        assert!(!err.is_transient());
        Ok(())
    }

    #[tokio::test]
    async fn idtoken_caching() -> TestResult {
        let server = Server::run();
        // The server accepts exactly one exchange: the second `id_token()`
        // call must be served from the cache.
        server.expect(
            Expectation::matching(all_of![
                request::method("POST"),
                request::path("/"),
                request::body(url_decoded(contains(("grant_type", JWT_BEARER_GRANT_TYPE)))),
                request::body(url_decoded(contains(("assertion", any())))),
            ])
            .times(1)
            .respond_with(status_code(200).body("test-id-token")),
        );

        let creds = idtoken::Builder::new("test-audience", mock_key_with_private_key())
            .with_token_server_url(server.url("/").to_string())
            .build()?;

        let id_token = creds.id_token().await?;
        assert_eq!(id_token, "test-id-token");

        let id_token = creds.id_token().await?;
        assert_eq!(id_token, "test-id-token");

        Ok(())
    }
}