1use crate::build_errors::Error as BuilderError;
96use crate::constants::OAUTH2_TOKEN_SERVER_URL;
97use crate::credentials::dynamic::{AccessTokenCredentialsProvider, CredentialsProvider};
98use crate::credentials::{AccessToken, AccessTokenCredentials, CacheableResource, Credentials};
99use crate::errors::{self, CredentialsError};
100use crate::headers_util::AuthHeadersBuilder;
101use crate::retry::Builder as RetryTokenProviderBuilder;
102use crate::token::{CachedTokenProvider, Token, TokenProvider};
103use crate::token_cache::TokenCache;
104use crate::{BuildResult, Result};
105use google_cloud_gax::backoff_policy::BackoffPolicyArg;
106use google_cloud_gax::retry_policy::RetryPolicyArg;
107use google_cloud_gax::retry_throttler::RetryThrottlerArg;
108use http::header::CONTENT_TYPE;
109use http::{Extensions, HeaderMap, HeaderValue};
110use reqwest::{Client, Method};
111use serde_json::Value;
112use std::sync::Arc;
113use tokio::time::{Duration, Instant};
114
/// Builder for user account [Credentials] created from an `authorized_user` JSON object.
///
/// The JSON is stored as-is and only parsed when one of the `build*` methods is called.
pub struct Builder {
    /// Raw `authorized_user` JSON, deserialized into [AuthorizedUser] at build time.
    authorized_user: Value,
    /// Scopes to request; joined with spaces into the refresh request at build time.
    scopes: Option<Vec<String>>,
    /// When set, overrides the `quota_project_id` field of the JSON.
    quota_project_id: Option<String>,
    /// When set, overrides the `token_uri` field of the JSON.
    token_uri: Option<String>,
    /// Retry configuration wrapped around the token provider at build time.
    retry_builder: RetryTokenProviderBuilder,
}
132
133impl Builder {
134 pub fn new(authorized_user: Value) -> Self {
141 Self {
142 authorized_user,
143 scopes: None,
144 quota_project_id: None,
145 token_uri: None,
146 retry_builder: RetryTokenProviderBuilder::default(),
147 }
148 }
149
150 pub fn with_token_uri<S: Into<String>>(mut self, token_uri: S) -> Self {
164 self.token_uri = Some(token_uri.into());
165 self
166 }
167
168 pub(crate) fn with_universe_domain<S: Into<String>>(self, _universe_domain: S) -> Self {
170 self
171 }
172
173 pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
194 where
195 I: IntoIterator<Item = S>,
196 S: Into<String>,
197 {
198 self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
199 self
200 }
201
202 pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
223 self.quota_project_id = Some(quota_project_id.into());
224 self
225 }
226
227 pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
248 self.retry_builder = self.retry_builder.with_retry_policy(v.into());
249 self
250 }
251
252 pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
273 self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
274 self
275 }
276
277 pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
304 self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
305 self
306 }
307
308 pub fn build(self) -> BuildResult<Credentials> {
322 Ok(Credentials {
323 inner: Arc::new(self.build_credentials()?),
324 })
325 }
326
327 pub fn build_access_token_credentials(self) -> BuildResult<AccessTokenCredentials> {
361 Ok(AccessTokenCredentials {
362 inner: Arc::new(self.build_credentials()?),
363 })
364 }
365
366 fn build_credentials(self) -> BuildResult<UserCredentials<TokenCache>> {
367 let authorized_user = serde_json::from_value::<AuthorizedUser>(self.authorized_user)
368 .map_err(BuilderError::parsing)?;
369
370 let universe_domain = authorized_user.universe_domain.as_deref();
371 if !crate::universe_domain::is_default_universe_domain(universe_domain) {
372 return Err(BuilderError::not_supported(
373 "User Account Credentials are not supported in universes other than googleapis.com",
374 ));
375 }
376
377 let endpoint = self
378 .token_uri
379 .or(authorized_user.token_uri)
380 .unwrap_or(OAUTH2_TOKEN_SERVER_URL.to_string());
381 let quota_project_id = self.quota_project_id.or(authorized_user.quota_project_id);
382
383 let token_provider = UserTokenProvider {
384 client_id: authorized_user.client_id,
385 client_secret: authorized_user.client_secret,
386 refresh_token: authorized_user.refresh_token,
387 endpoint,
388 scopes: self.scopes.map(|scopes| scopes.join(" ")),
389 source: UserTokenSource::AccessToken,
390 };
391
392 let token_provider = TokenCache::new(self.retry_builder.build(token_provider));
393
394 Ok(UserCredentials {
395 token_provider,
396 quota_project_id,
397 })
398 }
399}
400
/// Fetches tokens by POSTing a refresh-token grant to an OAuth 2.0 token endpoint.
///
/// `source` selects whether the resulting [Token] carries the response's
/// `access_token` or its `id_token`.
#[derive(PartialEq)]
pub(crate) struct UserTokenProvider {
    client_id: String,
    /// Secret value; redacted by the `Debug` implementation.
    client_secret: String,
    /// Secret value; redacted by the `Debug` implementation.
    refresh_token: String,
    /// URL the refresh request is POSTed to.
    endpoint: String,
    /// Space-separated scopes sent in the refresh request, if any.
    scopes: Option<String>,
    /// Which token from the response is returned.
    source: UserTokenSource,
}
410
/// Which token [UserTokenProvider] extracts from the refresh response.
#[derive(PartialEq)]
// `IdToken` is only constructed by `new_id_token_provider`, which is gated on the
// `idtoken` feature; without that feature the variant would be reported as dead code.
#[allow(dead_code)]
enum UserTokenSource {
    /// Return the response's `id_token`; its absence is a non-retryable error.
    IdToken,
    /// Return the response's `access_token`.
    AccessToken,
}
417
418impl std::fmt::Debug for UserTokenProvider {
419 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
420 f.debug_struct("UserCredentials")
421 .field("client_id", &self.client_id)
422 .field("client_secret", &"[censored]")
423 .field("refresh_token", &"[censored]")
424 .field("endpoint", &self.endpoint)
425 .field("scopes", &self.scopes)
426 .finish()
427 }
428}
429
430impl UserTokenProvider {
431 #[cfg(feature = "idtoken")]
432 pub(crate) fn new_id_token_provider(
433 authorized_user: AuthorizedUser,
434 token_uri: Option<String>,
435 ) -> UserTokenProvider {
436 let endpoint = token_uri
437 .or(authorized_user.token_uri)
438 .unwrap_or(OAUTH2_TOKEN_SERVER_URL.to_string());
439 UserTokenProvider {
440 client_id: authorized_user.client_id,
441 client_secret: authorized_user.client_secret,
442 refresh_token: authorized_user.refresh_token,
443 endpoint,
444 source: UserTokenSource::IdToken,
445 scopes: None,
446 }
447 }
448}
449
450#[async_trait::async_trait]
451impl TokenProvider for UserTokenProvider {
452 async fn token(&self) -> Result<Token> {
453 let client = Client::new();
454
455 let req = Oauth2RefreshRequest {
457 grant_type: RefreshGrantType::RefreshToken,
458 client_id: self.client_id.clone(),
459 client_secret: self.client_secret.clone(),
460 refresh_token: self.refresh_token.clone(),
461 scopes: self.scopes.clone(),
462 };
463 let header = HeaderValue::from_static("application/json");
464 let builder = client
465 .request(Method::POST, self.endpoint.as_str())
466 .header(CONTENT_TYPE, header)
467 .json(&req);
468 let resp = builder
469 .send()
470 .await
471 .map_err(|e| errors::from_http_error(e, MSG))?;
472
473 if !resp.status().is_success() {
475 let err = errors::from_http_response(resp, MSG).await;
476 return Err(err);
477 }
478 let response = resp.json::<Oauth2RefreshResponse>().await.map_err(|e| {
479 let retryable = !e.is_decode();
480 CredentialsError::from_source(retryable, e)
481 })?;
482
483 let token = match self.source {
484 UserTokenSource::AccessToken => Ok(response.access_token),
485 UserTokenSource::IdToken => response
486 .id_token
487 .ok_or_else(|| CredentialsError::from_msg(false, MISSING_ID_TOKEN_MSG)),
488 }?;
489 let token = Token {
490 token,
491 token_type: response.token_type,
492 expires_at: response
493 .expires_in
494 .map(|d| Instant::now() + Duration::from_secs(d)),
495 metadata: None,
496 };
497 Ok(token)
498 }
499}
500
/// Context attached to errors raised while refreshing a user token.
const MSG: &str = "failed to refresh user access token";
/// Error message used when an ID token is requested but the token endpoint's
/// response has no `id_token` field. The command is quoted with matching
/// backticks (it previously opened with `'` and closed with a backtick).
const MISSING_ID_TOKEN_MSG: &str = "UserCredentials can obtain an id token only when authenticated through \
gcloud running `gcloud auth application-default login`";
504
/// User account credentials: a cached token provider plus an optional quota project.
///
/// Implements `CredentialsProvider` (authorization headers) and
/// `AccessTokenCredentialsProvider` (raw access tokens).
#[derive(Debug)]
pub(crate) struct UserCredentials<T>
where
    T: CachedTokenProvider,
{
    /// Source of (possibly cached) tokens.
    token_provider: T,
    /// Quota project passed to the header builder, if any.
    quota_project_id: Option<String>,
}
516
517#[async_trait::async_trait]
518impl<T> CredentialsProvider for UserCredentials<T>
519where
520 T: CachedTokenProvider,
521{
522 async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
523 let token = self.token_provider.token(extensions).await?;
524
525 AuthHeadersBuilder::new(&token)
526 .maybe_quota_project_id(self.quota_project_id.as_deref())
527 .build()
528 }
529}
530
531#[async_trait::async_trait]
532impl<T> AccessTokenCredentialsProvider for UserCredentials<T>
533where
534 T: CachedTokenProvider,
535{
536 async fn access_token(&self) -> Result<AccessToken> {
537 let token = self.token_provider.token(Extensions::new()).await?;
538 token.into()
539 }
540}
541
542#[derive(Debug, PartialEq, serde::Deserialize)]
543pub(crate) struct AuthorizedUser {
544 #[serde(rename = "type")]
545 cred_type: String,
546 client_id: String,
547 client_secret: String,
548 refresh_token: String,
549 #[serde(skip_serializing_if = "Option::is_none")]
550 token_uri: Option<String>,
551 #[serde(skip_serializing_if = "Option::is_none")]
552 quota_project_id: Option<String>,
553 #[serde(skip_serializing_if = "Option::is_none")]
554 universe_domain: Option<String>,
555}
556
/// OAuth 2.0 grant type sent in the refresh request.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
pub(crate) enum RefreshGrantType {
    /// The refresh-token grant; serialized as `"refresh_token"`.
    #[serde(rename = "refresh_token")]
    RefreshToken,
}
562
563#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
564pub(crate) struct Oauth2RefreshRequest {
565 pub(crate) grant_type: RefreshGrantType,
566 pub(crate) client_id: String,
567 pub(crate) client_secret: String,
568 pub(crate) refresh_token: String,
569 scopes: Option<String>,
570}
571
572#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
573pub(crate) struct Oauth2RefreshResponse {
574 pub(crate) access_token: String,
575 #[serde(skip_serializing_if = "Option::is_none")]
576 pub(crate) id_token: Option<String>,
577 #[serde(skip_serializing_if = "Option::is_none")]
578 pub(crate) scope: Option<String>,
579 #[serde(skip_serializing_if = "Option::is_none")]
580 pub(crate) expires_in: Option<u64>,
581 pub(crate) token_type: String,
582 #[serde(skip_serializing_if = "Option::is_none")]
583 pub(crate) refresh_token: Option<String>,
584}
585
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constants::DEFAULT_UNIVERSE_DOMAIN;
    use crate::credentials::QUOTA_PROJECT_KEY;
    use crate::credentials::tests::{
        find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
        get_mock_backoff_policy, get_mock_retry_throttler, get_token_from_headers,
        get_token_type_from_headers,
    };
    use crate::errors::CredentialsError;
    use crate::token::tests::MockTokenProvider;
    use http::StatusCode;
    use http::header::AUTHORIZATION;
    use httptest::cycle;
    use httptest::matchers::{all_of, json_decoded, request};
    use httptest::responders::{json_encoded, status_code};
    use httptest::{Expectation, Server};

    type TestResult = anyhow::Result<()>;

    /// Minimal valid `authorized_user` JSON whose `token_uri` is `token_uri`.
    fn authorized_user_json(token_uri: String) -> Value {
        serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": token_uri,
        })
    }

    #[tokio::test]
    async fn test_user_account_retries_on_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path("/token"))
                .times(3)
                .respond_with(status_code(503)),
        );

        let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build()?;

        let err = credentials.headers(Extensions::new()).await.unwrap_err();
        assert!(err.is_transient(), "{err:?}");
        server.verify_and_clear();
        Ok(())
    }

    #[tokio::test]
    async fn test_user_account_does_not_retry_on_non_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path("/token"))
                .times(1)
                .respond_with(status_code(401)),
        );

        let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
            .with_retry_policy(get_mock_auth_retry_policy(1))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build()?;

        let err = credentials.headers(Extensions::new()).await.unwrap_err();
        assert!(!err.is_transient());
        server.verify_and_clear();
        Ok(())
    }

    #[tokio::test]
    async fn test_user_account_retries_for_success() -> TestResult {
        let mut server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };

        server.expect(
            Expectation::matching(request::path("/token"))
                .times(3)
                .respond_with(cycle![
                    status_code(503).body("try-again"),
                    status_code(503).body("try-again"),
                    status_code(200)
                        .append_header("Content-Type", "application/json")
                        .body(serde_json::to_string(&response).unwrap()),
                ]),
        );

        let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build()?;

        let token = get_token_from_headers(credentials.headers(Extensions::new()).await.unwrap());
        assert_eq!(token.unwrap(), "test-access-token");

        server.verify_and_clear();
        Ok(())
    }

    #[test]
    fn debug_token_provider() {
        let expected = UserTokenProvider {
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            endpoint: OAUTH2_TOKEN_SERVER_URL.to_string(),
            scopes: Some("https://www.googleapis.com/auth/pubsub".to_string()),
            source: UserTokenSource::AccessToken,
        };
        let fmt = format!("{expected:?}");
        assert!(fmt.contains("test-client-id"), "{fmt}");
        assert!(!fmt.contains("test-client-secret"), "{fmt}");
        assert!(!fmt.contains("test-refresh-token"), "{fmt}");
        assert!(fmt.contains(OAUTH2_TOKEN_SERVER_URL), "{fmt}");
        assert!(
            fmt.contains("https://www.googleapis.com/auth/pubsub"),
            "{fmt}"
        );
    }

    #[test]
    fn authorized_user_full_from_json_success() {
        let json = serde_json::json!({
            "account": "",
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "universe_domain": "googleapis.com",
            "quota_project_id": "test-project",
            "token_uri" : "test-token-uri",
        });

        let expected = AuthorizedUser {
            cred_type: "authorized_user".to_string(),
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            quota_project_id: Some("test-project".to_string()),
            token_uri: Some("test-token-uri".to_string()),
            universe_domain: Some("googleapis.com".to_string()),
        };
        let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn authorized_user_partial_from_json_success() {
        let json = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
        });

        let expected = AuthorizedUser {
            cred_type: "authorized_user".to_string(),
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            quota_project_id: None,
            token_uri: None,
            universe_domain: None,
        };
        let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn authorized_user_from_json_parse_fail() {
        let json_full = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "quota_project_id": "test-project"
        });

        for required_field in ["client_id", "client_secret", "refresh_token"] {
            // A required field that is absent altogether must be rejected...
            let mut missing = json_full.clone();
            missing.as_object_mut().unwrap().remove(required_field);
            let result = serde_json::from_value::<AuthorizedUser>(missing);
            assert!(result.is_err(), "missing {required_field}: {result:?}");

            // ...and so must one that is present but `null`.
            let mut null = json_full.clone();
            null[required_field].take();
            let result = serde_json::from_value::<AuthorizedUser>(null);
            assert!(result.is_err(), "null {required_field}: {result:?}");
        }
    }

    #[tokio::test]
    async fn default_universe_domain_success() {
        let mock = TokenCache::new(MockTokenProvider::new());

        let uc = UserCredentials {
            token_provider: mock,
            quota_project_id: None,
        };
        assert_eq!(uc.universe_domain().await.unwrap(), DEFAULT_UNIVERSE_DOMAIN);
    }

    #[test]
    fn builder_rejects_non_default_universe() {
        let json = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "universe_domain": "non-default-universe.com",
        });

        let err = Builder::new(json).build().unwrap_err();
        assert!(err.is_not_supported(), "{err:?}");
    }

    #[tokio::test]
    async fn headers_success() -> TestResult {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let uc = UserCredentials {
            token_provider: TokenCache::new(mock),
            quota_project_id: None,
        };

        let mut extensions = Extensions::new();
        let cached_headers = uc.headers(extensions.clone()).await.unwrap();
        let (headers, entity_tag) = match cached_headers {
            CacheableResource::New { entity_tag, data } => (data, entity_tag),
            CacheableResource::NotModified => unreachable!("expecting new headers"),
        };
        let token = headers.get(AUTHORIZATION).unwrap();

        assert_eq!(headers.len(), 1, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());

        // Presenting the entity tag from the first call should report the
        // headers as unchanged.
        extensions.insert(entity_tag);

        let cached_headers = uc.headers(extensions).await?;
        assert!(
            matches!(cached_headers, CacheableResource::NotModified),
            "expecting cached headers to be reported as not modified"
        );
        Ok(())
    }

    #[tokio::test]
    async fn headers_failure() {
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Err(errors::non_retryable_from_str("fail")));

        let uc = UserCredentials {
            token_provider: TokenCache::new(mock),
            quota_project_id: None,
        };
        let result = uc.headers(Extensions::new()).await;
        assert!(result.is_err(), "{result:?}");
    }

    #[tokio::test]
    async fn headers_with_quota_project_success() -> TestResult {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let uc = UserCredentials {
            token_provider: TokenCache::new(mock),
            quota_project_id: Some("test-project".to_string()),
        };

        let headers = get_headers_from_cache(uc.headers(Extensions::new()).await.unwrap())?;
        let token = headers.get(AUTHORIZATION).unwrap();
        let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());
        assert_eq!(
            quota_project_header,
            HeaderValue::from_static("test-project")
        );
        assert!(!quota_project_header.is_sensitive());
        Ok(())
    }

    #[test]
    fn oauth2_request_serde() {
        let request = Oauth2RefreshRequest {
            grant_type: RefreshGrantType::RefreshToken,
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            scopes: Some("scope1 scope2".to_string()),
        };

        let json = serde_json::to_value(&request).unwrap();
        let expected = serde_json::json!({
            "grant_type": "refresh_token",
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "scopes": "scope1 scope2",
        });
        assert_eq!(json, expected);
        let roundtrip = serde_json::from_value::<Oauth2RefreshRequest>(json).unwrap();
        assert_eq!(request, roundtrip);
    }

    #[test]
    fn oauth2_response_serde_full() {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            scope: Some("scope1 scope2".to_string()),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
            refresh_token: Some("test-refresh-token".to_string()),
        };

        let json = serde_json::to_value(&response).unwrap();
        let expected = serde_json::json!({
            "access_token": "test-access-token",
            "scope": "scope1 scope2",
            "expires_in": 3600,
            "token_type": "test-token-type",
            "refresh_token": "test-refresh-token"
        });
        assert_eq!(json, expected);
        let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
        assert_eq!(response, roundtrip);
    }

    #[test]
    fn oauth2_response_serde_partial() {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            scope: None,
            expires_in: None,
            token_type: "test-token-type".to_string(),
            refresh_token: None,
        };

        let json = serde_json::to_value(&response).unwrap();
        let expected = serde_json::json!({
            "access_token": "test-access-token",
            "token_type": "test-token-type",
        });
        assert_eq!(json, expected);
        let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
        assert_eq!(response, roundtrip);
    }

    /// Returns true when `request` carries the test credentials and `expected_scopes`.
    fn check_request(request: &Oauth2RefreshRequest, expected_scopes: Option<String>) -> bool {
        request.client_id == "test-client-id"
            && request.client_secret == "test-client-secret"
            && request.refresh_token == "test-refresh-token"
            && request.grant_type == RefreshGrantType::RefreshToken
            && request.scopes == expected_scopes
    }

    #[tokio::test(start_paused = true)]
    async fn token_provider_full() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, Some("scope1 scope2".to_string()))
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let tp = UserTokenProvider {
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            endpoint: server.url("/token").to_string(),
            scopes: Some("scope1 scope2".to_string()),
            source: UserTokenSource::AccessToken,
        };
        let now = Instant::now();
        let token = tp.token().await?;
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d == now + Duration::from_secs(3600)),
            "now: {:?}, expires_at: {:?}",
            now,
            token.expires_at
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_full_with_quota_project() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: None,
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, None)
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": server.url("/token").to_string(),
        });
        let cred = Builder::new(authorized_user)
            .with_quota_project_id("test-project")
            .build()?;

        let headers = get_headers_from_cache(cred.headers(Extensions::new()).await.unwrap())?;
        let token = headers.get(AUTHORIZATION).unwrap();
        let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(
            token,
            HeaderValue::from_static("test-token-type test-access-token")
        );
        assert!(token.is_sensitive());
        assert_eq!(
            quota_project_header,
            HeaderValue::from_static("test-project")
        );
        assert!(!quota_project_header.is_sensitive());

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn creds_from_json_custom_uri_with_caching() -> TestResult {
        let mut server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, Some("scope1 scope2".to_string()))
                }))
            ])
            .times(1)
            .respond_with(json_encoded(response)),
        );

        let json = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "universe_domain": "googleapis.com",
            "quota_project_id": "test-project",
            "token_uri": server.url("/token").to_string(),
        });

        let cred = Builder::new(json)
            .with_scopes(vec!["scope1", "scope2"])
            .build()?;

        // The second call must be served from the cache: the server expects
        // exactly one request.
        let token = get_token_from_headers(cred.headers(Extensions::new()).await?);
        assert_eq!(token.unwrap(), "test-access-token");

        let token = get_token_from_headers(cred.headers(Extensions::new()).await?);
        assert_eq!(token.unwrap(), "test-access-token");

        server.verify_and_clear();

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_partial() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: None,
            refresh_token: None,
            scope: None,
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, None)
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": server.url("/token").to_string()
        });

        let uc = Builder::new(authorized_user).build()?;
        let headers = uc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_with_token_uri() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: None,
            refresh_token: None,
            scope: None,
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, None)
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": "test-endpoint"
        });

        // `with_token_uri` must override the JSON `token_uri`.
        let uc = Builder::new(authorized_user)
            .with_token_uri(server.url("/token").to_string())
            .build()?;
        let headers = uc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn access_credential_provider_with_token_uri() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: None,
            refresh_token: None,
            scope: None,
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, None)
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": "test-endpoint"
        });

        let uc = Builder::new(authorized_user)
            .with_token_uri(server.url("/token").to_string())
            .build_access_token_credentials()?;
        let access_token = uc.access_token().await?;
        assert_eq!(access_token.token, "test-access-token");

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_with_scopes() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: None,
            refresh_token: None,
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, Some("scope1 scope2".to_string()))
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": "test-endpoint"
        });

        let uc = Builder::new(authorized_user)
            .with_token_uri(server.url("/token").to_string())
            .with_scopes(vec!["scope1", "scope2"])
            .build()?;
        let headers = uc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_retryable_error() -> TestResult {
        let server = Server::run();
        server
            .expect(Expectation::matching(request::path("/token")).respond_with(status_code(503)));

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": server.url("/token").to_string()
        });

        let uc = Builder::new(authorized_user).build()?;
        let err = uc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(original_err.is_transient());

        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
            "{err:?}"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn token_provider_nonretryable_error() -> TestResult {
        let server = Server::run();
        server
            .expect(Expectation::matching(request::path("/token")).respond_with(status_code(401)));

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": server.url("/token").to_string()
        });

        let uc = Builder::new(authorized_user).build()?;
        let err = uc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(!original_err.is_transient());

        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
            "{err:?}"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn token_provider_malformed_response_is_nonretryable() -> TestResult {
        let server = Server::run();
        server.expect(
            Expectation::matching(request::path("/token"))
                .respond_with(json_encoded("bad json".to_string())),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": server.url("/token").to_string()
        });

        let uc = Builder::new(authorized_user).build()?;
        let e = uc.headers(Extensions::new()).await.err().unwrap();
        assert!(!e.is_transient(), "{e}");

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn builder_malformed_authorized_json_nonretryable() -> TestResult {
        let authorized_user = serde_json::json!({
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
        });

        let e = Builder::new(authorized_user).build().unwrap_err();
        assert!(e.is_parsing(), "{e}");

        Ok(())
    }
}