1mod jws;
74
75use crate::build_errors::Error as BuilderError;
76use crate::constants::DEFAULT_SCOPE;
77use crate::credentials::dynamic::CredentialsProvider;
78use crate::credentials::{CacheableResource, Credentials};
79use crate::errors::{self, CredentialsError};
80use crate::headers_util::build_cacheable_headers;
81use crate::token::{CachedTokenProvider, Token, TokenProvider};
82use crate::token_cache::TokenCache;
83use crate::{BuildResult, Result};
84use async_trait::async_trait;
85use http::{Extensions, HeaderMap};
86use jws::{CLOCK_SKEW_FUDGE, DEFAULT_TOKEN_TIMEOUT, JwsClaims, JwsHeader};
87use rustls::crypto::CryptoProvider;
88use rustls::sign::Signer;
89use rustls_pemfile::Item;
90use serde_json::Value;
91use std::sync::Arc;
92use time::OffsetDateTime;
93use tokio::time::Instant;
94
/// Selects how tokens minted from a service account key are scoped.
///
/// Generated tokens are self-signed JWTs that carry exactly one of the two
/// forms: an audience (`aud` claim) or a list of scopes (`scope` claim).
#[derive(Clone, Debug, PartialEq)]
pub enum AccessSpecifier {
    /// Use this value as the `aud` claim of the generated JWT.
    Audience(String),

    /// Use these scopes, joined with single spaces, as the `scope` claim.
    Scopes(Vec<String>),
}
133
134impl AccessSpecifier {
135 fn audience(&self) -> Option<&String> {
136 match self {
137 AccessSpecifier::Audience(aud) => Some(aud),
138 AccessSpecifier::Scopes(_) => None,
139 }
140 }
141
142 fn scopes(&self) -> Option<&[String]> {
143 match self {
144 AccessSpecifier::Scopes(scopes) => Some(scopes),
145 AccessSpecifier::Audience(_) => None,
146 }
147 }
148
149 pub fn from_scopes<I, S>(scopes: I) -> Self
163 where
164 I: IntoIterator<Item = S>,
165 S: Into<String>,
166 {
167 AccessSpecifier::Scopes(scopes.into_iter().map(|s| s.into()).collect())
168 }
169
170 pub fn from_audience<S: Into<String>>(audience: S) -> Self {
184 AccessSpecifier::Audience(audience.into())
185 }
186}
187
188pub struct Builder {
207 service_account_key: Value,
208 access_specifier: AccessSpecifier,
209 quota_project_id: Option<String>,
210}
211
212impl Builder {
213 pub fn new(service_account_key: Value) -> Self {
220 Self {
221 service_account_key,
222 access_specifier: AccessSpecifier::Scopes([DEFAULT_SCOPE].map(str::to_string).to_vec()),
223 quota_project_id: None,
224 }
225 }
226
227 pub fn with_access_specifier(mut self, access_specifier: AccessSpecifier) -> Self {
249 self.access_specifier = access_specifier;
250 self
251 }
252
253 pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
262 self.quota_project_id = Some(quota_project_id.into());
263 self
264 }
265
266 fn build_token_provider(self) -> BuildResult<ServiceAccountTokenProvider> {
267 let service_account_key =
268 serde_json::from_value::<ServiceAccountKey>(self.service_account_key)
269 .map_err(BuilderError::parsing)?;
270
271 Ok(ServiceAccountTokenProvider {
272 service_account_key,
273 access_specifier: self.access_specifier,
274 })
275 }
276
277 pub fn build(self) -> BuildResult<Credentials> {
290 Ok(Credentials {
291 inner: Arc::new(ServiceAccountCredentials {
292 quota_project_id: self.quota_project_id.clone(),
293 token_provider: TokenCache::new(self.build_token_provider()?),
294 }),
295 })
296 }
297}
298
299#[derive(serde::Deserialize, Default, Clone)]
303pub(crate) struct ServiceAccountKey {
304 client_email: String,
307 private_key_id: String,
309 private_key: String,
312 project_id: String,
314 universe_domain: Option<String>,
316}
317
318impl std::fmt::Debug for ServiceAccountKey {
319 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
320 f.debug_struct("ServiceAccountKey")
321 .field("client_email", &self.client_email)
322 .field("private_key_id", &self.private_key_id)
323 .field("private_key", &"[censored]")
324 .field("project_id", &self.project_id)
325 .field("universe_domain", &self.universe_domain)
326 .finish()
327 }
328}
329
330#[derive(Debug)]
331struct ServiceAccountCredentials<T>
332where
333 T: CachedTokenProvider,
334{
335 token_provider: T,
336 quota_project_id: Option<String>,
337}
338
339#[derive(Debug)]
340struct ServiceAccountTokenProvider {
341 service_account_key: ServiceAccountKey,
342 access_specifier: AccessSpecifier,
343}
344
345fn token_issue_time(current_time: OffsetDateTime) -> OffsetDateTime {
346 current_time - CLOCK_SKEW_FUDGE
347}
348
349fn token_expiry_time(current_time: OffsetDateTime) -> OffsetDateTime {
350 current_time + CLOCK_SKEW_FUDGE + DEFAULT_TOKEN_TIMEOUT
351}
352
353#[async_trait]
354impl TokenProvider for ServiceAccountTokenProvider {
355 async fn token(&self) -> Result<Token> {
356 let expires_at = Instant::now() + CLOCK_SKEW_FUDGE + DEFAULT_TOKEN_TIMEOUT;
357 let tg = ServiceAccountTokenGenerator {
358 audience: self.access_specifier.audience().cloned(),
359 scopes: self
360 .access_specifier
361 .scopes()
362 .map(|scopes| scopes.join(" ")),
363 service_account_key: self.service_account_key.clone(),
364 target_audience: None,
365 };
366
367 let token = tg.generate()?;
368
369 let token = Token {
370 token,
371 token_type: "Bearer".to_string(),
372 expires_at: Some(expires_at),
373 metadata: None,
374 };
375 Ok(token)
376 }
377}
378
379#[derive(Default, Clone)]
380pub(crate) struct ServiceAccountTokenGenerator {
381 service_account_key: ServiceAccountKey,
382 audience: Option<String>,
383 scopes: Option<String>,
384 target_audience: Option<String>,
385}
386
387impl ServiceAccountTokenGenerator {
388 #[cfg(google_cloud_unstable_id_token)]
389 pub(crate) fn new_id_token_generator(
390 target_audience: String,
391 audience: String,
392 service_account_key: ServiceAccountKey,
393 ) -> Self {
394 Self {
395 service_account_key,
396 target_audience: Some(target_audience),
397 audience: Some(audience),
398 scopes: None,
399 }
400 }
401
402 pub(crate) fn generate(&self) -> Result<String> {
403 let signer = self.signer(&self.service_account_key.private_key)?;
404
405 let current_time = OffsetDateTime::now_utc();
409
410 let claims = JwsClaims {
411 iss: self.service_account_key.client_email.clone(),
412 scope: self.scopes.clone(),
413 target_audience: self.target_audience.clone(),
414 aud: self.audience.clone(),
415 exp: token_expiry_time(current_time),
416 iat: token_issue_time(current_time),
417 typ: None,
418 sub: Some(self.service_account_key.client_email.clone()),
419 };
420
421 let header = JwsHeader {
422 alg: "RS256",
423 typ: "JWT",
424 kid: &self.service_account_key.private_key_id,
425 };
426 let encoded_header_claims = format!("{}.{}", header.encode()?, claims.encode()?);
427 let sig = signer
428 .sign(encoded_header_claims.as_bytes())
429 .map_err(errors::non_retryable)?;
430 use base64::prelude::{BASE64_URL_SAFE_NO_PAD, Engine as _};
431 let token = format!(
432 "{}.{}",
433 encoded_header_claims,
434 &BASE64_URL_SAFE_NO_PAD.encode(sig)
435 );
436
437 Ok(token)
438 }
439
440 fn signer(&self, private_key: &String) -> Result<Box<dyn Signer>> {
442 let key_provider = CryptoProvider::get_default().map_or_else(
443 || rustls::crypto::ring::default_provider().key_provider,
444 |p| p.key_provider,
445 );
446
447 let private_key = rustls_pemfile::read_one(&mut private_key.as_bytes())
448 .map_err(errors::non_retryable)?
449 .ok_or_else(|| {
450 errors::non_retryable_from_str("missing PEM section in service account key")
451 })?;
452 let pk = match private_key {
453 Item::Pkcs8Key(item) => key_provider.load_private_key(item.into()),
454 other => {
455 return Err(Self::unexpected_private_key_error(other));
456 }
457 };
458 let sk = pk.map_err(errors::non_retryable)?;
459 sk.choose_scheme(&[rustls::SignatureScheme::RSA_PKCS1_SHA256])
460 .ok_or_else(|| errors::non_retryable_from_str("Unable to choose RSA_PKCS1_SHA256 signing scheme as it is not supported by current signer"))
461 }
462
463 fn unexpected_private_key_error(private_key_format: Item) -> CredentialsError {
464 errors::non_retryable_from_str(format!(
465 "expected key to be in form of PKCS8, found {private_key_format:?}",
466 ))
467 }
468}
469
470#[async_trait::async_trait]
471impl<T> CredentialsProvider for ServiceAccountCredentials<T>
472where
473 T: CachedTokenProvider,
474{
475 async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
476 let token = self.token_provider.token(extensions).await?;
477 build_cacheable_headers(&token, &self.quota_project_id)
478 }
479}
480
#[cfg(test)]
mod tests {
    use super::*;
    use crate::credentials::QUOTA_PROJECT_KEY;
    use crate::credentials::tests::{
        PKCS8_PK, b64_decode_to_json, get_headers_from_cache, get_token_from_headers,
    };
    use crate::token::tests::MockTokenProvider;
    use http::HeaderValue;
    use http::header::AUTHORIZATION;
    use rsa::pkcs1::EncodeRsaPrivateKey;
    use rsa::pkcs8::LineEnding;
    use rustls_pemfile::Item;
    use serde_json::Value;
    use serde_json::json;
    use std::error::Error as _;
    use std::time::Duration;

    type TestResult = std::result::Result<(), Box<dyn std::error::Error>>;

    /// Matches a compact JWS `<header>.<claims>.<sig>` and captures each segment.
    const SSJ_REGEX: &str = r"(?<header>[^\.]+)\.(?<claims>[^\.]+)\.(?<sig>[^\.]+)";

    #[test]
    fn debug_token_provider() {
        let expected = ServiceAccountKey {
            client_email: "test-client-email".to_string(),
            private_key_id: "test-private-key-id".to_string(),
            private_key: "super-duper-secret-private-key".to_string(),
            project_id: "test-project-id".to_string(),
            universe_domain: Some("test-universe-domain".to_string()),
        };
        let fmt = format!("{expected:?}");
        assert!(fmt.contains("test-client-email"), "{fmt}");
        assert!(fmt.contains("test-private-key-id"), "{fmt}");
        assert!(!fmt.contains("super-duper-secret-private-key"), "{fmt}");
        assert!(fmt.contains("test-project-id"), "{fmt}");
        assert!(fmt.contains("test-universe-domain"), "{fmt}");
    }

    #[test]
    fn validate_token_issue_time() {
        let current_time = OffsetDateTime::now_utc();
        let issue_time = token_issue_time(current_time);
        assert_eq!(issue_time, current_time - CLOCK_SKEW_FUDGE);
    }

    #[test]
    fn validate_token_expiry_time() {
        let current_time = OffsetDateTime::now_utc();
        let expiry_time = token_expiry_time(current_time);
        assert_eq!(
            expiry_time,
            current_time + CLOCK_SKEW_FUDGE + DEFAULT_TOKEN_TIMEOUT
        );
    }

    #[tokio::test]
    async fn headers_success_without_quota_project() -> TestResult {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let sac = ServiceAccountCredentials {
            token_provider: TokenCache::new(mock),
            quota_project_id: None,
        };

        let mut extensions = Extensions::new();
        let cached_headers = sac.headers(extensions.clone()).await?;
        let (headers, entity_tag) = match cached_headers {
            CacheableResource::New { entity_tag, data } => (data, entity_tag),
            CacheableResource::NotModified => unreachable!("expecting new headers"),
        };
        let token = headers.get(AUTHORIZATION).unwrap();

        assert_eq!(headers.len(), 1, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());

        // Presenting the entity tag from the first response must yield
        // `NotModified`: the mock is only allowed to produce one token.
        extensions.insert(entity_tag);

        let cached_headers = sac.headers(extensions).await?;
        assert!(
            matches!(cached_headers, CacheableResource::NotModified),
            "expecting headers to be not modified"
        );
        Ok(())
    }

    #[tokio::test]
    async fn headers_success_with_quota_project() -> TestResult {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        let quota_project = "test-quota-project";

        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let sac = ServiceAccountCredentials {
            token_provider: TokenCache::new(mock),
            quota_project_id: Some(quota_project.to_string()),
        };

        let headers = get_headers_from_cache(sac.headers(Extensions::new()).await.unwrap())?;
        let token = headers.get(AUTHORIZATION).unwrap();
        let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());
        assert_eq!(
            quota_project_header,
            HeaderValue::from_static(quota_project)
        );
        assert!(!quota_project_header.is_sensitive());
        Ok(())
    }

    #[tokio::test]
    async fn headers_failure() {
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Err(errors::non_retryable_from_str("fail")));

        let sac = ServiceAccountCredentials {
            token_provider: TokenCache::new(mock),
            quota_project_id: None,
        };
        assert!(sac.headers(Extensions::new()).await.is_err());
    }

    /// Returns a minimal service account key JSON whose `private_key` is empty.
    fn get_mock_service_key() -> Value {
        json!({
            "client_email": "test-client-email",
            "private_key_id": "test-private-key-id",
            "private_key": "",
            "project_id": "test-project-id",
        })
    }

    #[tokio::test]
    async fn get_service_account_headers_pkcs1_private_key_failure() -> TestResult {
        let mut service_account_key = get_mock_service_key();

        let key = crate::credentials::tests::RSA_PRIVATE_KEY
            .to_pkcs1_pem(LineEnding::LF)
            .expect("Failed to encode key to PKCS#1 PEM")
            .to_string();

        service_account_key["private_key"] = Value::from(key);
        let cred = Builder::new(service_account_key).build()?;
        let expected_error_message = "expected key to be in form of PKCS8, found Pkcs1Key";
        assert!(
            cred.headers(Extensions::new())
                .await
                .is_err_and(|e| e.to_string().contains(expected_error_message))
        );
        Ok(())
    }

    #[tokio::test]
    async fn get_service_account_token_pkcs8_key_success() -> TestResult {
        let mut service_account_key = get_mock_service_key();
        service_account_key["private_key"] = Value::from(PKCS8_PK.clone());
        let tp = Builder::new(service_account_key.clone()).build_token_provider()?;

        let token = tp.token().await?;
        let re = regex::Regex::new(SSJ_REGEX).unwrap();
        let captures = re.captures(&token.token).ok_or_else(|| {
            format!(
                r#"Expected token in form: "<header>.<claims>.<sig>". Found token: {}"#,
                token.token
            )
        })?;
        let header = b64_decode_to_json(captures["header"].to_string());
        assert_eq!(header["alg"], "RS256");
        assert_eq!(header["typ"], "JWT");
        assert_eq!(header["kid"], service_account_key["private_key_id"]);

        let claims = b64_decode_to_json(captures["claims"].to_string());
        assert_eq!(claims["iss"], service_account_key["client_email"]);
        assert_eq!(claims["scope"], DEFAULT_SCOPE);
        assert!(claims["iat"].is_number());
        assert!(claims["exp"].is_number());
        assert_eq!(claims["sub"], service_account_key["client_email"]);

        Ok(())
    }

    #[tokio::test]
    async fn header_caching() -> TestResult {
        let private_key = PKCS8_PK.clone();

        let json_value = json!({
            "client_email": "test-client-email",
            "private_key_id": "test-private-key-id",
            "private_key": private_key,
            "project_id": "test-project-id",
            "universe_domain": "test-universe-domain"
        });

        let credentials = Builder::new(json_value).build()?;

        let headers = credentials.headers(Extensions::new()).await?;

        let re = regex::Regex::new(SSJ_REGEX).unwrap();
        let token = get_token_from_headers(headers).unwrap();

        let captures = re.captures(&token).unwrap();

        let claims = b64_decode_to_json(captures["claims"].to_string());
        let first_iat = claims["iat"].as_i64().unwrap();

        // `iat` has one-second resolution. Waiting a second guarantees a
        // freshly minted token would carry a different `iat`. Use the async
        // sleep so the runtime thread is not blocked.
        tokio::time::sleep(Duration::from_secs(1)).await;

        // Fetch the headers again; they should come from the cached token.
        let token = get_token_from_headers(credentials.headers(Extensions::new()).await?).unwrap();
        let captures = re.captures(&token).unwrap();

        let claims = b64_decode_to_json(captures["claims"].to_string());
        let second_iat = claims["iat"].as_i64().unwrap();

        // An unchanged `iat` means the token was served from the cache.
        assert_eq!(first_iat, second_iat);

        Ok(())
    }

    #[tokio::test]
    async fn get_service_account_headers_invalid_key_failure() -> TestResult {
        let mut service_account_key = get_mock_service_key();
        let pem_data = "-----BEGIN PRIVATE KEY-----\nMIGkAg==\n-----END PRIVATE KEY-----";
        service_account_key["private_key"] = Value::from(pem_data);
        let cred = Builder::new(service_account_key).build()?;

        let token = cred.headers(Extensions::new()).await;
        let err = token.unwrap_err();
        assert!(!err.is_transient(), "{err:?}");
        let source = err.source().and_then(|e| e.downcast_ref::<rustls::Error>());
        assert!(matches!(source, Some(rustls::Error::General(_))), "{err:?}");
        Ok(())
    }

    #[tokio::test]
    async fn get_service_account_invalid_json_failure() -> TestResult {
        let service_account_key = Value::from(" ");
        let e = Builder::new(service_account_key).build().unwrap_err();
        assert!(e.is_parsing(), "{e:?}");

        Ok(())
    }

    #[test]
    fn signer_failure() -> TestResult {
        let tp = Builder::new(get_mock_service_key()).build_token_provider()?;
        let tg = ServiceAccountTokenGenerator {
            service_account_key: tp.service_account_key.clone(),
            ..Default::default()
        };

        let signer = tg.signer(&tg.service_account_key.private_key);
        let expected_error_message = "missing PEM section in service account key";
        assert!(signer.is_err_and(|e| e.to_string().contains(expected_error_message)));
        Ok(())
    }

    #[test]
    fn unexpected_private_key_error_message() -> TestResult {
        let expected_message = format!(
            "expected key to be in form of PKCS8, found {:?}",
            Item::Crl(Vec::new().into())
        );

        let error = ServiceAccountTokenGenerator::unexpected_private_key_error(Item::Crl(
            Vec::new().into(),
        ));
        assert!(error.to_string().contains(&expected_message));
        Ok(())
    }

    #[tokio::test]
    async fn get_service_account_headers_with_audience() -> TestResult {
        let mut service_account_key = get_mock_service_key();
        service_account_key["private_key"] = Value::from(PKCS8_PK.clone());
        let headers = Builder::new(service_account_key.clone())
            .with_access_specifier(AccessSpecifier::from_audience("test-audience"))
            .build()?
            .headers(Extensions::new())
            .await?;

        let re = regex::Regex::new(SSJ_REGEX).unwrap();
        let token = get_token_from_headers(headers).unwrap();
        let captures = re.captures(&token).ok_or_else(|| {
            format!(r#"Expected token in form: "<header>.<claims>.<sig>". Found token: {token}"#)
        })?;
        let token_header = b64_decode_to_json(captures["header"].to_string());
        assert_eq!(token_header["alg"], "RS256");
        assert_eq!(token_header["typ"], "JWT");
        assert_eq!(token_header["kid"], service_account_key["private_key_id"]);

        let claims = b64_decode_to_json(captures["claims"].to_string());
        assert_eq!(claims["iss"], service_account_key["client_email"]);
        assert_eq!(claims["scope"], Value::Null);
        assert_eq!(claims["aud"], "test-audience");
        assert!(claims["iat"].is_number());
        assert!(claims["exp"].is_number());
        assert_eq!(claims["sub"], service_account_key["client_email"]);
        Ok(())
    }

    #[tokio::test(start_paused = true)]
    async fn get_service_account_token_verify_expiry_time() -> TestResult {
        let now = Instant::now();
        let mut service_account_key = get_mock_service_key();
        service_account_key["private_key"] = Value::from(PKCS8_PK.clone());
        let token = Builder::new(service_account_key)
            .build_token_provider()?
            .token()
            .await?;

        let expected_expiry = now + CLOCK_SKEW_FUDGE + DEFAULT_TOKEN_TIMEOUT;

        assert_eq!(token.expires_at.unwrap(), expected_expiry);
        Ok(())
    }

    #[tokio::test]
    async fn get_service_account_headers_with_custom_scopes() -> TestResult {
        let mut service_account_key = get_mock_service_key();
        // Two separate scopes, so the space-joining into the `scope` claim
        // is actually exercised.
        let scopes = vec![
            "https://www.googleapis.com/auth/pubsub",
            "https://www.googleapis.com/auth/translate",
        ];
        service_account_key["private_key"] = Value::from(PKCS8_PK.clone());
        let headers = Builder::new(service_account_key.clone())
            .with_access_specifier(AccessSpecifier::from_scopes(scopes))
            .build()?
            .headers(Extensions::new())
            .await?;

        let re = regex::Regex::new(SSJ_REGEX).unwrap();
        let token = get_token_from_headers(headers).unwrap();
        let captures = re.captures(&token).ok_or_else(|| {
            format!(r#"Expected token in form: "<header>.<claims>.<sig>". Found token: {token}"#)
        })?;
        let token_header = b64_decode_to_json(captures["header"].to_string());
        assert_eq!(token_header["alg"], "RS256");
        assert_eq!(token_header["typ"], "JWT");
        assert_eq!(token_header["kid"], service_account_key["private_key_id"]);

        let claims = b64_decode_to_json(captures["claims"].to_string());
        assert_eq!(claims["iss"], service_account_key["client_email"]);
        assert_eq!(
            claims["scope"],
            "https://www.googleapis.com/auth/pubsub https://www.googleapis.com/auth/translate"
        );
        assert_eq!(claims["aud"], Value::Null);
        assert!(claims["iat"].is_number());
        assert!(claims["exp"].is_number());
        assert_eq!(claims["sub"], service_account_key["client_email"]);
        Ok(())
    }
}