google_cloud_auth/credentials/
service_account.rs1mod jws;
76
77use crate::credentials::dynamic::CredentialsProvider;
78use crate::credentials::{Credentials, Result};
79use crate::errors::{self, CredentialsError};
80use crate::headers_util::build_bearer_headers;
81use crate::token::{CachedTokenProvider, Token, TokenProvider};
82use crate::token_cache::TokenCache;
83use async_trait::async_trait;
84use http::{Extensions, HeaderMap};
85use jws::{CLOCK_SKEW_FUDGE, DEFAULT_TOKEN_TIMEOUT, JwsClaims, JwsHeader};
86use rustls::crypto::CryptoProvider;
87use rustls::sign::Signer;
88use rustls_pemfile::Item;
89use serde_json::Value;
90use std::sync::Arc;
91use time::OffsetDateTime;
92use tokio::time::Instant;
93
/// OAuth 2.0 scope requested when no [`AccessSpecifier`] is configured.
///
/// [`Builder::new`] initializes its access specifier with this single scope.
const DEFAULT_SCOPE: &str = "https://www.googleapis.com/auth/cloud-platform";
95
/// Selects what a self-signed service account JWT grants access to.
///
/// A token carries either an audience (the `aud` claim) or a list of OAuth 2.0
/// scopes (the `scope` claim), never both.
#[derive(Debug, Clone, PartialEq)]
pub enum AccessSpecifier {
    /// Mint tokens for a single audience, written to the `aud` claim.
    Audience(String),

    /// Mint tokens for a list of OAuth 2.0 scopes, joined with single spaces into
    /// the `scope` claim.
    Scopes(Vec<String>),
}

impl AccessSpecifier {
    /// Returns the audience if this is [`AccessSpecifier::Audience`], else `None`.
    fn audience(&self) -> Option<&String> {
        if let AccessSpecifier::Audience(aud) = self {
            Some(aud)
        } else {
            None
        }
    }

    /// Returns the scopes if this is [`AccessSpecifier::Scopes`], else `None`.
    fn scopes(&self) -> Option<&[String]> {
        if let AccessSpecifier::Scopes(list) = self {
            Some(list.as_slice())
        } else {
            None
        }
    }

    /// Builds an [`AccessSpecifier::Scopes`] from any iterable of string-like values.
    ///
    /// The order of the input is preserved.
    pub fn from_scopes<I, S>(scopes: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let owned: Vec<String> = scopes.into_iter().map(Into::into).collect();
        AccessSpecifier::Scopes(owned)
    }

    /// Builds an [`AccessSpecifier::Audience`] for the given audience.
    pub fn from_audience<S: Into<String>>(audience: S) -> Self {
        let audience: String = audience.into();
        AccessSpecifier::Audience(audience)
    }
}
188
189pub struct Builder {
208 service_account_key: Value,
209 access_specifier: AccessSpecifier,
210 quota_project_id: Option<String>,
211}
212
213impl Builder {
214 pub fn new(service_account_key: Value) -> Self {
221 Self {
222 service_account_key,
223 access_specifier: AccessSpecifier::Scopes([DEFAULT_SCOPE].map(str::to_string).to_vec()),
224 quota_project_id: None,
225 }
226 }
227
228 pub fn with_access_specifier(mut self, access_specifier: AccessSpecifier) -> Self {
250 self.access_specifier = access_specifier;
251 self
252 }
253
254 pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
263 self.quota_project_id = Some(quota_project_id.into());
264 self
265 }
266
267 fn build_token_provider(self) -> Result<ServiceAccountTokenProvider> {
268 let service_account_key =
269 serde_json::from_value::<ServiceAccountKey>(self.service_account_key)
270 .map_err(errors::non_retryable)?;
271
272 Ok(ServiceAccountTokenProvider {
273 service_account_key,
274 access_specifier: self.access_specifier,
275 })
276 }
277
278 pub fn build(self) -> Result<Credentials> {
291 Ok(Credentials {
292 inner: Arc::new(ServiceAccountCredentials {
293 quota_project_id: self.quota_project_id.clone(),
294 token_provider: TokenCache::new(self.build_token_provider()?),
295 }),
296 })
297 }
298}
299
300#[derive(serde::Deserialize, Default, Clone)]
304struct ServiceAccountKey {
305 client_email: String,
308 private_key_id: String,
310 private_key: String,
313 project_id: String,
315 universe_domain: Option<String>,
317}
318
319impl std::fmt::Debug for ServiceAccountKey {
320 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
321 f.debug_struct("ServiceAccountKey")
322 .field("client_email", &self.client_email)
323 .field("private_key_id", &self.private_key_id)
324 .field("private_key", &"[censored]")
325 .field("project_id", &self.project_id)
326 .field("universe_domain", &self.universe_domain)
327 .finish()
328 }
329}
330
331#[derive(Debug)]
332struct ServiceAccountCredentials<T>
333where
334 T: CachedTokenProvider,
335{
336 token_provider: T,
337 quota_project_id: Option<String>,
338}
339
/// Mints self-signed JWTs from a parsed service account key.
#[derive(Debug)]
struct ServiceAccountTokenProvider {
    /// Parsed key; its `Debug` impl censors the private key.
    service_account_key: ServiceAccountKey,
    /// Selects whether minted tokens carry an `aud` or a `scope` claim.
    access_specifier: AccessSpecifier,
}
345
346fn token_issue_time(current_time: OffsetDateTime) -> OffsetDateTime {
347 current_time - CLOCK_SKEW_FUDGE
348}
349
350fn token_expiry_time(current_time: OffsetDateTime) -> OffsetDateTime {
351 current_time + CLOCK_SKEW_FUDGE + DEFAULT_TOKEN_TIMEOUT
352}
353
354#[async_trait]
355impl TokenProvider for ServiceAccountTokenProvider {
356 async fn token(&self) -> Result<Token> {
357 let signer = self.signer(&self.service_account_key.private_key)?;
358
359 let expires_at = Instant::now() + CLOCK_SKEW_FUDGE + DEFAULT_TOKEN_TIMEOUT;
360 let current_time = OffsetDateTime::now_utc();
364
365 let claims = JwsClaims {
366 iss: self.service_account_key.client_email.clone(),
367 scope: self
368 .access_specifier
369 .scopes()
370 .map(|scopes| scopes.join(" ")),
371 aud: self.access_specifier.audience().cloned(),
372 exp: token_expiry_time(current_time),
373 iat: token_issue_time(current_time),
374 typ: None,
375 sub: Some(self.service_account_key.client_email.clone()),
376 };
377
378 let header = JwsHeader {
379 alg: "RS256",
380 typ: "JWT",
381 kid: &self.service_account_key.private_key_id,
382 };
383 let encoded_header_claims = format!("{}.{}", header.encode()?, claims.encode()?);
384 let sig = signer
385 .sign(encoded_header_claims.as_bytes())
386 .map_err(errors::non_retryable)?;
387 use base64::prelude::{BASE64_URL_SAFE_NO_PAD, Engine as _};
388 let token = format!(
389 "{}.{}",
390 encoded_header_claims,
391 &BASE64_URL_SAFE_NO_PAD.encode(sig)
392 );
393
394 let token = Token {
395 token,
396 token_type: "Bearer".to_string(),
397 expires_at: Some(expires_at),
398 metadata: None,
399 };
400 Ok(token)
401 }
402}
403
404impl ServiceAccountTokenProvider {
405 fn signer(&self, private_key: &String) -> Result<Box<dyn Signer>> {
407 let key_provider = CryptoProvider::get_default().map_or_else(
408 || rustls::crypto::ring::default_provider().key_provider,
409 |p| p.key_provider,
410 );
411
412 let private_key = rustls_pemfile::read_one(&mut private_key.as_bytes())
413 .map_err(errors::non_retryable)?
414 .ok_or_else(|| {
415 errors::non_retryable_from_str("missing PEM section in service account key")
416 })?;
417 let pk = match private_key {
418 Item::Pkcs8Key(item) => key_provider.load_private_key(item.into()),
419 other => {
420 return Err(Self::unexpected_private_key_error(other));
421 }
422 };
423 let sk = pk.map_err(errors::non_retryable)?;
424 sk.choose_scheme(&[rustls::SignatureScheme::RSA_PKCS1_SHA256])
425 .ok_or_else(|| errors::non_retryable_from_str("Unable to choose RSA_PKCS1_SHA256 signing scheme as it is not supported by current signer"))
426 }
427
428 fn unexpected_private_key_error(private_key_format: Item) -> CredentialsError {
429 errors::non_retryable_from_str(format!(
430 "expected key to be in form of PKCS8, found {:?}",
431 private_key_format
432 ))
433 }
434}
435
436#[async_trait::async_trait]
437impl<T> CredentialsProvider for ServiceAccountCredentials<T>
438where
439 T: CachedTokenProvider,
440{
441 async fn headers(&self, extensions: Extensions) -> Result<HeaderMap> {
442 let token = self.token_provider.token(extensions).await?;
443 build_bearer_headers(&token, &self.quota_project_id)
444 }
445}
446
#[cfg(test)]
mod test {
    use super::*;
    use crate::credentials::QUOTA_PROJECT_KEY;
    use crate::credentials::test::{PKCS8_PK, b64_decode_to_json, get_token_from_headers};
    use crate::token::test::MockTokenProvider;
    use http::HeaderValue;
    use http::header::AUTHORIZATION;
    use rsa::RsaPrivateKey;
    use rsa::pkcs1::EncodeRsaPrivateKey;
    use rsa::pkcs8::LineEnding;
    use rustls_pemfile::Item;
    use serde_json::json;
    use std::sync::LazyLock;
    use std::time::Duration;

    type TestResult = std::result::Result<(), Box<dyn std::error::Error>>;

    /// Splits a compact JWS (`<header>.<claims>.<sig>`) into its three segments.
    const SSJ_REGEX: &str = r"(?<header>[^\.]+)\.(?<claims>[^\.]+)\.(?<sig>[^\.]+)";

    #[test]
    fn debug_token_provider() {
        let expected = ServiceAccountKey {
            client_email: "test-client-email".to_string(),
            private_key_id: "test-private-key-id".to_string(),
            private_key: "super-duper-secret-private-key".to_string(),
            project_id: "test-project-id".to_string(),
            universe_domain: Some("test-universe-domain".to_string()),
        };
        let fmt = format!("{expected:?}");
        assert!(fmt.contains("test-client-email"), "{fmt}");
        assert!(fmt.contains("test-private-key-id"), "{fmt}");
        assert!(!fmt.contains("super-duper-secret-private-key"), "{fmt}");
        assert!(fmt.contains("test-project-id"), "{fmt}");
        assert!(fmt.contains("test-universe-domain"), "{fmt}");
    }

    #[test]
    fn validate_token_issue_time() {
        let current_time = OffsetDateTime::now_utc();
        let token_issue_time = token_issue_time(current_time);
        assert_eq!(token_issue_time, current_time - CLOCK_SKEW_FUDGE);
    }

    #[test]
    fn validate_token_expiry_time() {
        let current_time = OffsetDateTime::now_utc();
        let token_expiry_time = token_expiry_time(current_time);
        assert_eq!(
            token_expiry_time,
            current_time + CLOCK_SKEW_FUDGE + DEFAULT_TOKEN_TIMEOUT
        );
    }

    #[tokio::test]
    async fn headers_success_without_quota_project() {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let sac = ServiceAccountCredentials {
            token_provider: TokenCache::new(mock),
            quota_project_id: None,
        };

        let headers = sac.headers(Extensions::new()).await.unwrap();
        let token = headers.get(AUTHORIZATION).unwrap();

        assert_eq!(headers.len(), 1, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());
    }

    #[tokio::test]
    async fn headers_success_with_quota_project() {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        let quota_project = "test-quota-project";

        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let sac = ServiceAccountCredentials {
            token_provider: TokenCache::new(mock),
            quota_project_id: Some(quota_project.to_string()),
        };

        let headers = sac.headers(Extensions::new()).await.unwrap();
        let token = headers.get(AUTHORIZATION).unwrap();
        let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());
        assert_eq!(
            quota_project_header,
            HeaderValue::from_static(quota_project)
        );
        assert!(!quota_project_header.is_sensitive());
    }

    #[tokio::test]
    async fn headers_failure() {
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Err(errors::non_retryable_from_str("fail")));

        let sac = ServiceAccountCredentials {
            token_provider: TokenCache::new(mock),
            quota_project_id: None,
        };
        assert!(sac.headers(Extensions::new()).await.is_err());
    }

    /// Minimal service account key JSON with an empty private key; tests overwrite
    /// `private_key` as needed.
    fn get_mock_service_key() -> Value {
        json!({
            "client_email": "test-client-email",
            "private_key_id": "test-private-key-id",
            "private_key": "",
            "project_id": "test-project-id",
        })
    }

    /// A freshly generated 2048-bit RSA key in PKCS#1 PEM form, used to check that
    /// non-PKCS#8 keys are rejected.
    static PKCS1_PK: LazyLock<String> = LazyLock::new(|| {
        let mut rng = rand::thread_rng();
        let bits = 2048;
        let priv_key = RsaPrivateKey::new(&mut rng, bits).expect("failed to generate a key");
        priv_key
            .to_pkcs1_pem(LineEnding::LF)
            .expect("Failed to encode key to PKCS#1 PEM")
            .to_string()
    });

    #[tokio::test]
    async fn get_service_account_headers_pkcs1_private_key_failure() -> TestResult {
        let mut service_account_key = get_mock_service_key();
        service_account_key["private_key"] = Value::from(PKCS1_PK.clone());
        let cred = Builder::new(service_account_key).build()?;
        let expected_error_message = "expected key to be in form of PKCS8, found Pkcs1Key";
        assert!(
            cred.headers(Extensions::new())
                .await
                .is_err_and(|e| e.to_string().contains(expected_error_message))
        );
        Ok(())
    }

    #[tokio::test]
    async fn get_service_account_token_pkcs8_key_success() -> TestResult {
        let mut service_account_key = get_mock_service_key();
        service_account_key["private_key"] = Value::from(PKCS8_PK.clone());
        let tp = Builder::new(service_account_key.clone()).build_token_provider()?;

        let token = tp.token().await?;
        let re = regex::Regex::new(SSJ_REGEX).unwrap();
        let captures = re.captures(&token.token).ok_or_else(|| {
            format!(
                r#"Expected token in form: "<header>.<claims>.<sig>". Found token: {}"#,
                token.token
            )
        })?;
        let header = b64_decode_to_json(captures["header"].to_string());
        assert_eq!(header["alg"], "RS256");
        assert_eq!(header["typ"], "JWT");
        assert_eq!(header["kid"], service_account_key["private_key_id"]);

        let claims = b64_decode_to_json(captures["claims"].to_string());
        assert_eq!(claims["iss"], service_account_key["client_email"]);
        assert_eq!(claims["scope"], DEFAULT_SCOPE);
        // The default specifier uses scopes, so no audience is set.
        assert_eq!(claims["aud"], Value::Null);
        assert!(claims["iat"].is_number());
        assert!(claims["exp"].is_number());
        assert_eq!(claims["sub"], service_account_key["client_email"]);

        Ok(())
    }

    #[tokio::test]
    async fn header_caching() -> TestResult {
        let private_key = PKCS8_PK.clone();

        let json_value = json!({
            "client_email": "test-client-email",
            "private_key_id": "test-private-key-id",
            "private_key": private_key,
            "project_id": "test-project-id",
            "universe_domain": "test-universe-domain"
        });

        let credentials = Builder::new(json_value).build()?;

        let headers = credentials.headers(Extensions::new()).await?;

        let re = regex::Regex::new(SSJ_REGEX).unwrap();
        let token = get_token_from_headers(&headers).unwrap();

        let captures = re.captures(&token).unwrap();

        let claims = b64_decode_to_json(captures["claims"].to_string());
        let first_iat = claims["iat"].as_i64().unwrap();

        // `iat` is an integer number of seconds: wait long enough that a freshly minted
        // token would carry a different value. Use the async sleep so the runtime
        // (and any background tasks) is not blocked while waiting.
        tokio::time::sleep(Duration::from_secs(1)).await;

        // The second request must be served from the cache, i.e. the same token.
        let token = get_token_from_headers(&credentials.headers(Extensions::new()).await?).unwrap();
        let captures = re.captures(&token).unwrap();

        let claims = b64_decode_to_json(captures["claims"].to_string());
        let second_iat = claims["iat"].as_i64().unwrap();

        assert_eq!(first_iat, second_iat);

        Ok(())
    }

    #[tokio::test]
    async fn get_service_account_headers_invalid_key_failure() -> TestResult {
        let mut service_account_key = get_mock_service_key();
        let pem_data = "-----BEGIN PRIVATE KEY-----\nMIGkAg==\n-----END PRIVATE KEY-----";
        service_account_key["private_key"] = Value::from(pem_data);
        let cred = Builder::new(service_account_key).build()?;

        let token = cred.headers(Extensions::new()).await;
        let expected_error_message = "failed to parse private key";
        assert!(token.is_err_and(|e| e.to_string().contains(expected_error_message)));
        Ok(())
    }

    #[tokio::test]
    async fn get_service_account_invalid_json_failure() -> TestResult {
        let service_account_key = Value::from(" ");
        let e = Builder::new(service_account_key).build().err().unwrap();

        assert!(!e.is_retryable());

        Ok(())
    }

    #[test]
    fn signer_failure() -> TestResult {
        let tp = Builder::new(get_mock_service_key()).build_token_provider()?;

        let signer = tp.signer(&tp.service_account_key.private_key);
        let expected_error_message = "missing PEM section in service account key";
        assert!(signer.is_err_and(|e| e.to_string().contains(expected_error_message)));
        Ok(())
    }

    #[test]
    fn unexpected_private_key_error_message() -> TestResult {
        let expected_message = format!(
            "expected key to be in form of PKCS8, found {:?}",
            Item::Crl(Vec::new().into())
        );

        let error =
            ServiceAccountTokenProvider::unexpected_private_key_error(Item::Crl(Vec::new().into()));
        assert!(error.to_string().contains(&expected_message));
        Ok(())
    }

    #[tokio::test]
    async fn get_service_account_headers_with_audience() -> TestResult {
        let mut service_account_key = get_mock_service_key();
        service_account_key["private_key"] = Value::from(PKCS8_PK.clone());
        let headers = Builder::new(service_account_key.clone())
            .with_access_specifier(AccessSpecifier::from_audience("test-audience"))
            .build()?
            .headers(Extensions::new())
            .await?;

        let re = regex::Regex::new(SSJ_REGEX).unwrap();
        let token = get_token_from_headers(&headers).unwrap();
        let captures = re.captures(&token).ok_or_else(|| {
            format!(
                r#"Expected token in form: "<header>.<claims>.<sig>". Found token: {}"#,
                token
            )
        })?;
        let token_header = b64_decode_to_json(captures["header"].to_string());
        assert_eq!(token_header["alg"], "RS256");
        assert_eq!(token_header["typ"], "JWT");
        assert_eq!(token_header["kid"], service_account_key["private_key_id"]);

        let claims = b64_decode_to_json(captures["claims"].to_string());
        assert_eq!(claims["iss"], service_account_key["client_email"]);
        assert_eq!(claims["scope"], Value::Null);
        assert_eq!(claims["aud"], "test-audience");
        assert!(claims["iat"].is_number());
        assert!(claims["exp"].is_number());
        assert_eq!(claims["sub"], service_account_key["client_email"]);
        Ok(())
    }

    #[tokio::test(start_paused = true)]
    async fn get_service_account_token_verify_expiry_time() -> TestResult {
        let now = Instant::now();
        let mut service_account_key = get_mock_service_key();
        service_account_key["private_key"] = Value::from(PKCS8_PK.clone());
        let token = Builder::new(service_account_key)
            .build_token_provider()?
            .token()
            .await?;

        let expected_expiry = now + CLOCK_SKEW_FUDGE + DEFAULT_TOKEN_TIMEOUT;

        assert_eq!(token.expires_at.unwrap(), expected_expiry);
        Ok(())
    }

    #[tokio::test]
    async fn get_service_account_headers_with_custom_scopes() -> TestResult {
        let mut service_account_key = get_mock_service_key();
        // Two distinct scopes, so the test verifies they are space-separated in the
        // `scope` claim.
        let scopes = vec![
            "https://www.googleapis.com/auth/pubsub",
            "https://www.googleapis.com/auth/translate",
        ];
        service_account_key["private_key"] = Value::from(PKCS8_PK.clone());
        let headers = Builder::new(service_account_key.clone())
            .with_access_specifier(AccessSpecifier::from_scopes(scopes.clone()))
            .build()?
            .headers(Extensions::new())
            .await?;

        let re = regex::Regex::new(SSJ_REGEX).unwrap();
        let token = get_token_from_headers(&headers).unwrap();
        let captures = re.captures(&token).ok_or_else(|| {
            format!(
                r#"Expected token in form: "<header>.<claims>.<sig>". Found token: {}"#,
                token
            )
        })?;
        let token_header = b64_decode_to_json(captures["header"].to_string());
        assert_eq!(token_header["alg"], "RS256");
        assert_eq!(token_header["typ"], "JWT");
        assert_eq!(token_header["kid"], service_account_key["private_key_id"]);

        let claims = b64_decode_to_json(captures["claims"].to_string());
        assert_eq!(claims["iss"], service_account_key["client_email"]);
        assert_eq!(
            claims["scope"],
            "https://www.googleapis.com/auth/pubsub https://www.googleapis.com/auth/translate"
        );
        assert_eq!(claims["aud"], Value::Null);
        assert!(claims["iat"].is_number());
        assert!(claims["exp"].is_number());
        assert_eq!(claims["sub"], service_account_key["client_email"]);
        Ok(())
    }
}