1use crate::build_errors::Error as BuilderError;
98use crate::constants::OAUTH2_TOKEN_SERVER_URL;
99use crate::credentials::dynamic::{AccessTokenCredentialsProvider, CredentialsProvider};
100use crate::credentials::{AccessToken, AccessTokenCredentials, CacheableResource, Credentials};
101use crate::errors::{self, CredentialsError};
102use crate::headers_util::AuthHeadersBuilder;
103use crate::retry::Builder as RetryTokenProviderBuilder;
104use crate::token::{CachedTokenProvider, Token, TokenProvider};
105use crate::token_cache::TokenCache;
106use crate::{BuildResult, Result};
107use google_cloud_gax::backoff_policy::BackoffPolicyArg;
108use google_cloud_gax::retry_policy::RetryPolicyArg;
109use google_cloud_gax::retry_throttler::RetryThrottlerArg;
110use http::header::CONTENT_TYPE;
111use http::{Extensions, HeaderMap, HeaderValue};
112use reqwest::{Client, Method};
113use serde_json::Value;
114use std::sync::Arc;
115use tokio::time::{Duration, Instant};
116
117pub struct Builder {
128 authorized_user: Value,
129 scopes: Option<Vec<String>>,
130 quota_project_id: Option<String>,
131 token_uri: Option<String>,
132 retry_builder: RetryTokenProviderBuilder,
133}
134
135impl Builder {
136 pub fn new(authorized_user: Value) -> Self {
143 Self {
144 authorized_user,
145 scopes: None,
146 quota_project_id: None,
147 token_uri: None,
148 retry_builder: RetryTokenProviderBuilder::default(),
149 }
150 }
151
152 pub fn with_token_uri<S: Into<String>>(mut self, token_uri: S) -> Self {
166 self.token_uri = Some(token_uri.into());
167 self
168 }
169
170 pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
191 where
192 I: IntoIterator<Item = S>,
193 S: Into<String>,
194 {
195 self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
196 self
197 }
198
199 pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
220 self.quota_project_id = Some(quota_project_id.into());
221 self
222 }
223
224 pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
245 self.retry_builder = self.retry_builder.with_retry_policy(v.into());
246 self
247 }
248
249 pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
270 self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
271 self
272 }
273
274 pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
301 self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
302 self
303 }
304
305 pub fn build(self) -> BuildResult<Credentials> {
319 Ok(self.build_access_token_credentials()?.into())
320 }
321
322 pub fn build_access_token_credentials(self) -> BuildResult<AccessTokenCredentials> {
357 let authorized_user = serde_json::from_value::<AuthorizedUser>(self.authorized_user)
358 .map_err(BuilderError::parsing)?;
359 let endpoint = self
360 .token_uri
361 .or(authorized_user.token_uri)
362 .unwrap_or(OAUTH2_TOKEN_SERVER_URL.to_string());
363 let quota_project_id = self.quota_project_id.or(authorized_user.quota_project_id);
364
365 let token_provider = UserTokenProvider {
366 client_id: authorized_user.client_id,
367 client_secret: authorized_user.client_secret,
368 refresh_token: authorized_user.refresh_token,
369 endpoint,
370 scopes: self.scopes.map(|scopes| scopes.join(" ")),
371 source: UserTokenSource::AccessToken,
372 };
373
374 let token_provider = TokenCache::new(self.retry_builder.build(token_provider));
375
376 Ok(AccessTokenCredentials {
377 inner: Arc::new(UserCredentials {
378 token_provider,
379 quota_project_id,
380 }),
381 })
382 }
383}
384
385#[derive(PartialEq)]
386pub(crate) struct UserTokenProvider {
387 client_id: String,
388 client_secret: String,
389 refresh_token: String,
390 endpoint: String,
391 scopes: Option<String>,
392 source: UserTokenSource,
393}
394
/// Selects which token from the OAuth2 refresh response is returned.
// `IdToken` is only constructed when the `idtoken` feature is enabled, so the
// dead-code lint is silenced only for builds without that feature.
#[cfg_attr(not(feature = "idtoken"), allow(dead_code))]
#[derive(PartialEq)]
enum UserTokenSource {
    /// Return the response's `id_token`; an error if it is absent.
    IdToken,
    /// Return the response's `access_token`.
    AccessToken,
}
401
402impl std::fmt::Debug for UserTokenProvider {
403 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
404 f.debug_struct("UserCredentials")
405 .field("client_id", &self.client_id)
406 .field("client_secret", &"[censored]")
407 .field("refresh_token", &"[censored]")
408 .field("endpoint", &self.endpoint)
409 .field("scopes", &self.scopes)
410 .finish()
411 }
412}
413
414impl UserTokenProvider {
415 #[cfg(feature = "idtoken")]
416 pub(crate) fn new_id_token_provider(
417 authorized_user: AuthorizedUser,
418 token_uri: Option<String>,
419 ) -> UserTokenProvider {
420 let endpoint = token_uri
421 .or(authorized_user.token_uri)
422 .unwrap_or(OAUTH2_TOKEN_SERVER_URL.to_string());
423 UserTokenProvider {
424 client_id: authorized_user.client_id,
425 client_secret: authorized_user.client_secret,
426 refresh_token: authorized_user.refresh_token,
427 endpoint,
428 source: UserTokenSource::IdToken,
429 scopes: None,
430 }
431 }
432}
433
434#[async_trait::async_trait]
435impl TokenProvider for UserTokenProvider {
436 async fn token(&self) -> Result<Token> {
437 let client = Client::new();
438
439 let req = Oauth2RefreshRequest {
441 grant_type: RefreshGrantType::RefreshToken,
442 client_id: self.client_id.clone(),
443 client_secret: self.client_secret.clone(),
444 refresh_token: self.refresh_token.clone(),
445 scopes: self.scopes.clone(),
446 };
447 let header = HeaderValue::from_static("application/json");
448 let builder = client
449 .request(Method::POST, self.endpoint.as_str())
450 .header(CONTENT_TYPE, header)
451 .json(&req);
452 let resp = builder
453 .send()
454 .await
455 .map_err(|e| errors::from_http_error(e, MSG))?;
456
457 if !resp.status().is_success() {
459 let err = errors::from_http_response(resp, MSG).await;
460 return Err(err);
461 }
462 let response = resp.json::<Oauth2RefreshResponse>().await.map_err(|e| {
463 let retryable = !e.is_decode();
464 CredentialsError::from_source(retryable, e)
465 })?;
466
467 let token = match self.source {
468 UserTokenSource::AccessToken => Ok(response.access_token),
469 UserTokenSource::IdToken => response
470 .id_token
471 .ok_or_else(|| CredentialsError::from_msg(false, MISSING_ID_TOKEN_MSG)),
472 }?;
473 let token = Token {
474 token,
475 token_type: response.token_type,
476 expires_at: response
477 .expires_in
478 .map(|d| Instant::now() + Duration::from_secs(d)),
479 metadata: None,
480 };
481 Ok(token)
482 }
483}
484
/// Context attached to errors raised while refreshing a user token.
const MSG: &str = "failed to refresh user access token";
/// Error message returned when an ID token was requested but the token
/// endpoint response did not contain one.
const MISSING_ID_TOKEN_MSG: &str = "UserCredentials can obtain an id token only when authenticated through \
    gcloud running `gcloud auth application-default login`";
488
/// User-account credentials: a cached token provider plus an optional quota
/// project that is added to the generated headers.
///
/// The `CachedTokenProvider` bound is required only by the trait impls that
/// use it, so it is stated there rather than on the struct definition.
#[derive(Debug)]
pub(crate) struct UserCredentials<T> {
    token_provider: T,
    quota_project_id: Option<String>,
}
500
501#[async_trait::async_trait]
502impl<T> CredentialsProvider for UserCredentials<T>
503where
504 T: CachedTokenProvider,
505{
506 async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
507 let token = self.token_provider.token(extensions).await?;
508
509 AuthHeadersBuilder::new(&token)
510 .maybe_quota_project_id(self.quota_project_id.as_deref())
511 .build()
512 }
513}
514
515#[async_trait::async_trait]
516impl<T> AccessTokenCredentialsProvider for UserCredentials<T>
517where
518 T: CachedTokenProvider,
519{
520 async fn access_token(&self) -> Result<AccessToken> {
521 let token = self.token_provider.token(Extensions::new()).await?;
522 token.into()
523 }
524}
525
526#[derive(Debug, PartialEq, serde::Deserialize)]
527pub(crate) struct AuthorizedUser {
528 #[serde(rename = "type")]
529 cred_type: String,
530 client_id: String,
531 client_secret: String,
532 refresh_token: String,
533 #[serde(skip_serializing_if = "Option::is_none")]
534 token_uri: Option<String>,
535 #[serde(skip_serializing_if = "Option::is_none")]
536 quota_project_id: Option<String>,
537}
538
539#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
540pub(crate) enum RefreshGrantType {
541 #[serde(rename = "refresh_token")]
542 RefreshToken,
543}
544
545#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
546pub(crate) struct Oauth2RefreshRequest {
547 pub(crate) grant_type: RefreshGrantType,
548 pub(crate) client_id: String,
549 pub(crate) client_secret: String,
550 pub(crate) refresh_token: String,
551 scopes: Option<String>,
552}
553
554#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
555pub(crate) struct Oauth2RefreshResponse {
556 pub(crate) access_token: String,
557 #[serde(skip_serializing_if = "Option::is_none")]
558 pub(crate) id_token: Option<String>,
559 #[serde(skip_serializing_if = "Option::is_none")]
560 pub(crate) scope: Option<String>,
561 #[serde(skip_serializing_if = "Option::is_none")]
562 pub(crate) expires_in: Option<u64>,
563 pub(crate) token_type: String,
564 #[serde(skip_serializing_if = "Option::is_none")]
565 pub(crate) refresh_token: Option<String>,
566}
567
#[cfg(test)]
mod tests {
    use super::*;
    use crate::credentials::tests::{
        find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
        get_mock_backoff_policy, get_mock_retry_throttler, get_token_from_headers,
        get_token_type_from_headers,
    };
    use crate::credentials::{DEFAULT_UNIVERSE_DOMAIN, QUOTA_PROJECT_KEY};
    use crate::errors::CredentialsError;
    use crate::token::tests::MockTokenProvider;
    use http::StatusCode;
    use http::header::AUTHORIZATION;
    use httptest::cycle;
    use httptest::matchers::{all_of, json_decoded, request};
    use httptest::responders::{json_encoded, status_code};
    use httptest::{Expectation, Server};

    type TestResult = anyhow::Result<()>;

    /// Returns a minimal `authorized_user` JSON document whose token endpoint
    /// is `token_uri`.
    fn authorized_user_json(token_uri: String) -> Value {
        serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": token_uri,
        })
    }

    /// Transient (503) failures are retried until the retry policy gives up.
    #[tokio::test]
    async fn test_user_account_retries_on_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path("/token"))
                .times(3)
                .respond_with(status_code(503)),
        );

        let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build()?;

        let err = credentials.headers(Extensions::new()).await.unwrap_err();
        assert!(err.is_transient(), "{err:?}");
        server.verify_and_clear();
        Ok(())
    }

    /// Permanent (401) failures are not retried.
    #[tokio::test]
    async fn test_user_account_does_not_retry_on_non_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path("/token"))
                .times(1)
                .respond_with(status_code(401)),
        );

        let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
            .with_retry_policy(get_mock_auth_retry_policy(1))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build()?;

        let err = credentials.headers(Extensions::new()).await.unwrap_err();
        assert!(!err.is_transient());
        server.verify_and_clear();
        Ok(())
    }

    /// A success after transient failures yields the token.
    #[tokio::test]
    async fn test_user_account_retries_for_success() -> TestResult {
        let mut server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };

        server.expect(
            Expectation::matching(request::path("/token"))
                .times(3)
                .respond_with(cycle![
                    status_code(503).body("try-again"),
                    status_code(503).body("try-again"),
                    status_code(200)
                        .append_header("Content-Type", "application/json")
                        .body(serde_json::to_string(&response).unwrap()),
                ]),
        );

        let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build()?;

        let token = get_token_from_headers(credentials.headers(Extensions::new()).await.unwrap());
        assert_eq!(token.unwrap(), "test-access-token");

        server.verify_and_clear();
        Ok(())
    }

    /// The `Debug` output shows non-secret fields and hides the secrets.
    #[test]
    fn debug_token_provider() {
        let expected = UserTokenProvider {
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            endpoint: OAUTH2_TOKEN_SERVER_URL.to_string(),
            scopes: Some("https://www.googleapis.com/auth/pubsub".to_string()),
            source: UserTokenSource::AccessToken,
        };
        let fmt = format!("{expected:?}");
        assert!(fmt.contains("test-client-id"), "{fmt}");
        assert!(!fmt.contains("test-client-secret"), "{fmt}");
        assert!(!fmt.contains("test-refresh-token"), "{fmt}");
        assert!(fmt.contains(OAUTH2_TOKEN_SERVER_URL), "{fmt}");
        assert!(
            fmt.contains("https://www.googleapis.com/auth/pubsub"),
            "{fmt}"
        );
    }

    /// All known fields are parsed; unknown fields are ignored.
    #[test]
    fn authorized_user_full_from_json_success() {
        let json = serde_json::json!({
            "account": "",
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "universe_domain": "googleapis.com",
            "quota_project_id": "test-project",
            "token_uri" : "test-token-uri",
        });

        let expected = AuthorizedUser {
            cred_type: "authorized_user".to_string(),
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            quota_project_id: Some("test-project".to_string()),
            token_uri: Some("test-token-uri".to_string()),
        };
        let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
        assert_eq!(actual, expected);
    }

    /// Optional fields default to `None` when absent.
    #[test]
    fn authorized_user_partial_from_json_success() {
        let json = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
        });

        let expected = AuthorizedUser {
            cred_type: "authorized_user".to_string(),
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            quota_project_id: None,
            token_uri: None,
        };
        let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
        assert_eq!(actual, expected);
    }

    /// Each required field must be present and be a string.
    #[test]
    fn authorized_user_from_json_parse_fail() {
        let json_full = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "quota_project_id": "test-project"
        });

        for required_field in ["client_id", "client_secret", "refresh_token"] {
            let mut json = json_full.clone();
            // `take()` replaces the value with `null`, which is not a `String`.
            json[required_field].take();
            let result = serde_json::from_value::<AuthorizedUser>(json);
            assert!(result.is_err(), "{required_field}: {result:?}");
        }
    }

    /// Without overrides, the default universe domain is reported.
    #[tokio::test]
    async fn default_universe_domain_success() {
        let mock = TokenCache::new(MockTokenProvider::new());

        let uc = UserCredentials {
            token_provider: mock,
            quota_project_id: None,
        };
        assert_eq!(uc.universe_domain().await.unwrap(), DEFAULT_UNIVERSE_DOMAIN);
    }

    /// Headers are produced once and reported as unmodified for a matching
    /// entity tag.
    #[tokio::test]
    async fn headers_success() -> TestResult {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let uc = UserCredentials {
            token_provider: TokenCache::new(mock),
            quota_project_id: None,
        };

        let mut extensions = Extensions::new();
        let cached_headers = uc.headers(extensions.clone()).await.unwrap();
        let (headers, entity_tag) = match cached_headers {
            CacheableResource::New { entity_tag, data } => (data, entity_tag),
            CacheableResource::NotModified => unreachable!("expecting new headers"),
        };
        let token = headers.get(AUTHORIZATION).unwrap();

        assert_eq!(headers.len(), 1, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());

        extensions.insert(entity_tag);

        // The entity tag matches the cached token, so nothing new is expected.
        let cached_headers = uc.headers(extensions).await?;
        match cached_headers {
            CacheableResource::New { .. } => unreachable!("expecting NotModified headers"),
            CacheableResource::NotModified => {}
        }
        Ok(())
    }

    /// Token provider failures propagate to `headers()`.
    #[tokio::test]
    async fn headers_failure() {
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Err(errors::non_retryable_from_str("fail")));

        let uc = UserCredentials {
            token_provider: TokenCache::new(mock),
            quota_project_id: None,
        };
        let result = uc.headers(Extensions::new()).await;
        assert!(result.is_err(), "{result:?}");
    }

    /// A configured quota project adds a non-sensitive header.
    #[tokio::test]
    async fn headers_with_quota_project_success() -> TestResult {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let uc = UserCredentials {
            token_provider: TokenCache::new(mock),
            quota_project_id: Some("test-project".to_string()),
        };

        let headers = get_headers_from_cache(uc.headers(Extensions::new()).await.unwrap())?;
        let token = headers.get(AUTHORIZATION).unwrap();
        let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());
        assert_eq!(
            quota_project_header,
            HeaderValue::from_static("test-project")
        );
        assert!(!quota_project_header.is_sensitive());
        Ok(())
    }

    /// The refresh request serializes to the expected JSON and round-trips.
    #[test]
    fn oauth2_request_serde() {
        let request = Oauth2RefreshRequest {
            grant_type: RefreshGrantType::RefreshToken,
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            scopes: Some("scope1 scope2".to_string()),
        };

        let json = serde_json::to_value(&request).unwrap();
        let expected = serde_json::json!({
            "grant_type": "refresh_token",
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "scopes": "scope1 scope2",
        });
        assert_eq!(json, expected);
        let roundtrip = serde_json::from_value::<Oauth2RefreshRequest>(json).unwrap();
        assert_eq!(request, roundtrip);
    }

    /// A fully populated response serializes and round-trips.
    #[test]
    fn oauth2_response_serde_full() {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            scope: Some("scope1 scope2".to_string()),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
            refresh_token: Some("test-refresh-token".to_string()),
        };

        let json = serde_json::to_value(&response).unwrap();
        let expected = serde_json::json!({
            "access_token": "test-access-token",
            "scope": "scope1 scope2",
            "expires_in": 3600,
            "token_type": "test-token-type",
            "refresh_token": "test-refresh-token"
        });
        assert_eq!(json, expected);
        let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
        assert_eq!(response, roundtrip);
    }

    /// Optional response fields are omitted when `None`.
    #[test]
    fn oauth2_response_serde_partial() {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            scope: None,
            expires_in: None,
            token_type: "test-token-type".to_string(),
            refresh_token: None,
        };

        let json = serde_json::to_value(&response).unwrap();
        let expected = serde_json::json!({
            "access_token": "test-access-token",
            "token_type": "test-token-type",
        });
        assert_eq!(json, expected);
        let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
        assert_eq!(response, roundtrip);
    }

    /// Returns true when `request` carries the test credentials and the
    /// expected scopes.
    fn check_request(request: &Oauth2RefreshRequest, expected_scopes: Option<String>) -> bool {
        request.client_id == "test-client-id"
            && request.client_secret == "test-client-secret"
            && request.refresh_token == "test-refresh-token"
            && request.grant_type == RefreshGrantType::RefreshToken
            && request.scopes == expected_scopes
    }

    /// The provider maps the response into a token with the right expiry.
    #[tokio::test(start_paused = true)]
    async fn token_provider_full() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, Some("scope1 scope2".to_string()))
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let tp = UserTokenProvider {
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            endpoint: server.url("/token").to_string(),
            scopes: Some("scope1 scope2".to_string()),
            source: UserTokenSource::AccessToken,
        };
        let now = Instant::now();
        let token = tp.token().await?;
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d == now + Duration::from_secs(3600)),
            "now: {:?}, expires_at: {:?}",
            now,
            token.expires_at
        );

        Ok(())
    }

    /// A builder-provided quota project is added to the headers.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_full_with_quota_project() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: None,
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, None)
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": server.url("/token").to_string(),
        });
        let cred = Builder::new(authorized_user)
            .with_quota_project_id("test-project")
            .build()?;

        let headers = get_headers_from_cache(cred.headers(Extensions::new()).await.unwrap())?;
        let token = headers.get(AUTHORIZATION).unwrap();
        let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(
            token,
            HeaderValue::from_static("test-token-type test-access-token")
        );
        assert!(token.is_sensitive());
        assert_eq!(
            quota_project_header,
            HeaderValue::from_static("test-project")
        );
        assert!(!quota_project_header.is_sensitive());

        Ok(())
    }

    /// A second `headers()` call is served from the cache (one request only).
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn creds_from_json_custom_uri_with_caching() -> TestResult {
        let mut server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, Some("scope1 scope2".to_string()))
                }))
            ])
            .times(1)
            .respond_with(json_encoded(response)),
        );

        let json = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "universe_domain": "googleapis.com",
            "quota_project_id": "test-project",
            "token_uri": server.url("/token").to_string(),
        });

        let cred = Builder::new(json)
            .with_scopes(vec!["scope1", "scope2"])
            .build()?;

        let token = get_token_from_headers(cred.headers(Extensions::new()).await?);
        assert_eq!(token.unwrap(), "test-access-token");

        let token = get_token_from_headers(cred.headers(Extensions::new()).await?);
        assert_eq!(token.unwrap(), "test-access-token");

        server.verify_and_clear();

        Ok(())
    }

    /// A minimal response still produces usable headers.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_partial() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: None,
            refresh_token: None,
            scope: None,
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, None)
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": server.url("/token").to_string()
        });

        let uc = Builder::new(authorized_user).build()?;
        let headers = uc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    /// `with_token_uri` overrides the `token_uri` from the JSON.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_with_token_uri() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: None,
            refresh_token: None,
            scope: None,
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, None)
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": "test-endpoint"
        });

        let uc = Builder::new(authorized_user)
            .with_token_uri(server.url("/token").to_string())
            .build()?;
        let headers = uc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    /// Access-token credentials honour `with_token_uri` as well.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn access_credential_provider_with_token_uri() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: None,
            refresh_token: None,
            scope: None,
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, None)
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": "test-endpoint"
        });

        let uc = Builder::new(authorized_user)
            .with_token_uri(server.url("/token").to_string())
            .build_access_token_credentials()?;
        let access_token = uc.access_token().await?;
        assert_eq!(access_token.token, "test-access-token");

        Ok(())
    }

    /// Scopes set on the builder are sent as a space-separated string.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_with_scopes() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: None,
            refresh_token: None,
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, Some("scope1 scope2".to_string()))
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": "test-endpoint"
        });

        let uc = Builder::new(authorized_user)
            .with_token_uri(server.url("/token").to_string())
            .with_scopes(vec!["scope1", "scope2"])
            .build()?;
        let headers = uc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    /// A 503 surfaces as a transient error wrapping the HTTP error.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_retryable_error() -> TestResult {
        let server = Server::run();
        server
            .expect(Expectation::matching(request::path("/token")).respond_with(status_code(503)));

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": server.url("/token").to_string()
        });

        let uc = Builder::new(authorized_user).build()?;
        let err = uc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(original_err.is_transient());

        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
            "{err:?}"
        );

        Ok(())
    }

    /// A 401 surfaces as a permanent error wrapping the HTTP error.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn token_provider_nonretryable_error() -> TestResult {
        let server = Server::run();
        server
            .expect(Expectation::matching(request::path("/token")).respond_with(status_code(401)));

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": server.url("/token").to_string()
        });

        let uc = Builder::new(authorized_user).build()?;
        let err = uc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(!original_err.is_transient());

        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
            "{err:?}"
        );

        Ok(())
    }

    /// An undecodable response body is a permanent error.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn token_provider_malformed_response_is_nonretryable() -> TestResult {
        let server = Server::run();
        server.expect(
            Expectation::matching(request::path("/token"))
                .respond_with(json_encoded("bad json".to_string())),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": server.url("/token").to_string()
        });

        let uc = Builder::new(authorized_user).build()?;
        let e = uc.headers(Extensions::new()).await.err().unwrap();
        assert!(!e.is_transient(), "{e}");

        Ok(())
    }

    /// A JSON document missing `client_id` fails at build time with a parsing
    /// error.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn builder_malformed_authorized_json_nonretryable() -> TestResult {
        let authorized_user = serde_json::json!({
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
        });

        let e = Builder::new(authorized_user).build().unwrap_err();
        assert!(e.is_parsing(), "{e}");

        Ok(())
    }
}