1use crate::build_errors::Error as BuilderError;
98use crate::constants::OAUTH2_TOKEN_SERVER_URL;
99use crate::credentials::dynamic::CredentialsProvider;
100use crate::credentials::{CacheableResource, Credentials};
101use crate::errors::{self, CredentialsError};
102use crate::headers_util::build_cacheable_headers;
103use crate::retry::Builder as RetryTokenProviderBuilder;
104use crate::token::{CachedTokenProvider, Token, TokenProvider};
105use crate::token_cache::TokenCache;
106use crate::{BuildResult, Result};
107use gax::backoff_policy::BackoffPolicyArg;
108use gax::retry_policy::RetryPolicyArg;
109use gax::retry_throttler::RetryThrottlerArg;
110use http::header::CONTENT_TYPE;
111use http::{Extensions, HeaderMap, HeaderValue};
112use reqwest::{Client, Method};
113use serde_json::Value;
114use std::sync::Arc;
115use tokio::time::{Duration, Instant};
116
117pub struct Builder {
128 authorized_user: Value,
129 scopes: Option<Vec<String>>,
130 quota_project_id: Option<String>,
131 token_uri: Option<String>,
132 retry_builder: RetryTokenProviderBuilder,
133}
134
135impl Builder {
136 pub fn new(authorized_user: Value) -> Self {
143 Self {
144 authorized_user,
145 scopes: None,
146 quota_project_id: None,
147 token_uri: None,
148 retry_builder: RetryTokenProviderBuilder::default(),
149 }
150 }
151
152 pub fn with_token_uri<S: Into<String>>(mut self, token_uri: S) -> Self {
166 self.token_uri = Some(token_uri.into());
167 self
168 }
169
170 pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
191 where
192 I: IntoIterator<Item = S>,
193 S: Into<String>,
194 {
195 self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
196 self
197 }
198
199 pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
220 self.quota_project_id = Some(quota_project_id.into());
221 self
222 }
223
224 pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
245 self.retry_builder = self.retry_builder.with_retry_policy(v.into());
246 self
247 }
248
249 pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
270 self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
271 self
272 }
273
274 pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
301 self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
302 self
303 }
304
305 pub fn build(self) -> BuildResult<Credentials> {
318 let authorized_user = serde_json::from_value::<AuthorizedUser>(self.authorized_user)
319 .map_err(BuilderError::parsing)?;
320 let endpoint = self
321 .token_uri
322 .or(authorized_user.token_uri)
323 .unwrap_or(OAUTH2_TOKEN_SERVER_URL.to_string());
324 let quota_project_id = self.quota_project_id.or(authorized_user.quota_project_id);
325
326 let token_provider = UserTokenProvider {
327 client_id: authorized_user.client_id,
328 client_secret: authorized_user.client_secret,
329 refresh_token: authorized_user.refresh_token,
330 endpoint,
331 scopes: self.scopes.map(|scopes| scopes.join(" ")),
332 source: UserTokenSource::AccessToken,
333 };
334
335 let token_provider = TokenCache::new(self.retry_builder.build(token_provider));
336
337 Ok(Credentials {
338 inner: Arc::new(UserCredentials {
339 token_provider,
340 quota_project_id,
341 }),
342 })
343 }
344}
345
/// Exchanges a user-account refresh token for a fresh token at `endpoint`.
#[derive(PartialEq)]
struct UserTokenProvider {
    client_id: String,
    client_secret: String,
    refresh_token: String,
    /// URL that receives the refresh request.
    endpoint: String,
    /// Space-separated scopes, if any were configured.
    scopes: Option<String>,
    /// Which field of the refresh response is returned as the token.
    source: UserTokenSource,
}

/// Selects the field of the refresh response that becomes the token.
// `IdToken` is only constructed by the cfg-gated `idtoken` module, so it is
// unused in builds where that module is compiled out.
#[allow(dead_code)]
#[derive(Debug, PartialEq)]
enum UserTokenSource {
    IdToken,
    AccessToken,
}

/// Hand-written so that the client secret and the refresh token never reach
/// logs or panic messages.
impl std::fmt::Debug for UserTokenProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("UserTokenProvider")
            .field("client_id", &self.client_id)
            .field("client_secret", &"[censored]")
            .field("refresh_token", &"[censored]")
            .field("endpoint", &self.endpoint)
            .field("scopes", &self.scopes)
            .field("source", &self.source)
            .finish()
    }
}
374
375#[async_trait::async_trait]
376impl TokenProvider for UserTokenProvider {
377 async fn token(&self) -> Result<Token> {
378 let client = Client::new();
379
380 let req = Oauth2RefreshRequest {
382 grant_type: RefreshGrantType::RefreshToken,
383 client_id: self.client_id.clone(),
384 client_secret: self.client_secret.clone(),
385 refresh_token: self.refresh_token.clone(),
386 scopes: self.scopes.clone(),
387 };
388 let header = HeaderValue::from_static("application/json");
389 let builder = client
390 .request(Method::POST, self.endpoint.as_str())
391 .header(CONTENT_TYPE, header)
392 .json(&req);
393 let resp = builder
394 .send()
395 .await
396 .map_err(|e| errors::from_http_error(e, MSG))?;
397
398 if !resp.status().is_success() {
400 let err = errors::from_http_response(resp, MSG).await;
401 return Err(err);
402 }
403 let response = resp.json::<Oauth2RefreshResponse>().await.map_err(|e| {
404 let retryable = !e.is_decode();
405 CredentialsError::from_source(retryable, e)
406 })?;
407
408 let token = match self.source {
409 UserTokenSource::AccessToken => Ok(response.access_token),
410 UserTokenSource::IdToken => response
411 .id_token
412 .ok_or_else(|| CredentialsError::from_msg(false, MISSING_ID_TOKEN_MSG)),
413 }?;
414 let token = Token {
415 token,
416 token_type: response.token_type,
417 expires_at: response
418 .expires_in
419 .map(|d| Instant::now() + Duration::from_secs(d)),
420 metadata: None,
421 };
422 Ok(token)
423 }
424}
425
/// Context message attached to errors raised while calling the token endpoint.
const MSG: &str = "failed to refresh user access token";
/// Error text used when an ID token is requested but the refresh response
/// does not contain one.
const MISSING_ID_TOKEN_MSG: &str = "UserCredentials can obtain an id token only when authenticated through \
gcloud running `gcloud auth application-default login`";
429
430#[derive(Debug)]
434pub(crate) struct UserCredentials<T>
435where
436 T: CachedTokenProvider,
437{
438 token_provider: T,
439 quota_project_id: Option<String>,
440}
441
442#[async_trait::async_trait]
443impl<T> CredentialsProvider for UserCredentials<T>
444where
445 T: CachedTokenProvider,
446{
447 async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
448 let token = self.token_provider.token(extensions).await?;
449 build_cacheable_headers(&token, &self.quota_project_id)
450 }
451}
452
453#[derive(Debug, PartialEq, serde::Deserialize)]
454pub(crate) struct AuthorizedUser {
455 #[serde(rename = "type")]
456 cred_type: String,
457 client_id: String,
458 client_secret: String,
459 refresh_token: String,
460 #[serde(skip_serializing_if = "Option::is_none")]
461 token_uri: Option<String>,
462 #[serde(skip_serializing_if = "Option::is_none")]
463 quota_project_id: Option<String>,
464}
465
466#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
467enum RefreshGrantType {
468 #[serde(rename = "refresh_token")]
469 RefreshToken,
470}
471
472#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
473struct Oauth2RefreshRequest {
474 grant_type: RefreshGrantType,
475 client_id: String,
476 client_secret: String,
477 refresh_token: String,
478 scopes: Option<String>,
479}
480
481#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
482struct Oauth2RefreshResponse {
483 access_token: String,
484 #[serde(skip_serializing_if = "Option::is_none")]
485 id_token: Option<String>,
486 #[serde(skip_serializing_if = "Option::is_none")]
487 scope: Option<String>,
488 #[serde(skip_serializing_if = "Option::is_none")]
489 expires_in: Option<u64>,
490 token_type: String,
491 #[serde(skip_serializing_if = "Option::is_none")]
492 refresh_token: Option<String>,
493}
494
#[cfg(google_cloud_unstable_id_token)]
pub mod idtoken {
    //! ID token credentials backed by a user account.
    //!
    //! The credentials reuse the parent module's `UserTokenProvider`,
    //! configured to return the `id_token` field of the refresh response.

    use crate::build_errors::Error as BuilderError;
    use crate::constants::OAUTH2_TOKEN_SERVER_URL;
    use crate::{
        BuildResult, Result,
        credentials::{
            idtoken::{IDTokenCredentials, dynamic::IDTokenCredentialsProvider},
            user_account::AuthorizedUser,
        },
        token::TokenProvider,
    };
    use async_trait::async_trait;
    use serde_json::Value;
    use std::sync::Arc;

    /// Adapts a [`TokenProvider`] to the ID token credentials interface.
    #[derive(Debug)]
    struct UserAccountCredentials<T>
    where
        T: TokenProvider,
    {
        token_provider: T,
    }

    #[async_trait]
    impl<T> IDTokenCredentialsProvider for UserAccountCredentials<T>
    where
        T: TokenProvider,
    {
        /// Asks the token provider for a token and returns its string value.
        async fn id_token(&self) -> Result<String> {
            self.token_provider.token().await.map(|token| token.token)
        }
    }

    /// Builder for [`IDTokenCredentials`] backed by an `authorized_user`
    /// JSON document.
    pub struct Builder {
        /// Raw JSON document; parsed when the credentials are built.
        authorized_user: Value,
        /// When set, wins over the `token_uri` found in the JSON document.
        token_uri: Option<String>,
    }

    impl Builder {
        /// Creates a builder from the `authorized_user` JSON document.
        pub fn new(authorized_user: Value) -> Self {
            Self {
                authorized_user,
                token_uri: None,
            }
        }

        /// Sets the token endpoint, overriding the JSON document and the
        /// default endpoint.
        pub fn with_token_uri<S: Into<String>>(mut self, token_uri: S) -> Self {
            self.token_uri = Some(token_uri.into());
            self
        }

        /// Parses the JSON document and creates a token provider that
        /// returns ID tokens. No scopes are requested.
        fn build_token_provider(self) -> BuildResult<super::UserTokenProvider> {
            let authorized_user = serde_json::from_value::<AuthorizedUser>(self.authorized_user)
                .map_err(BuilderError::parsing)?;
            // Precedence: explicit builder setting, then the JSON document,
            // then the default. The default is only allocated when needed.
            let endpoint = self
                .token_uri
                .or(authorized_user.token_uri)
                .unwrap_or_else(|| OAUTH2_TOKEN_SERVER_URL.to_string());
            Ok(super::UserTokenProvider {
                client_id: authorized_user.client_id,
                client_secret: authorized_user.client_secret,
                refresh_token: authorized_user.refresh_token,
                endpoint,
                source: super::UserTokenSource::IdToken,
                scopes: None,
            })
        }

        /// Consumes the builder and returns the [`IDTokenCredentials`].
        ///
        /// # Errors
        ///
        /// Returns a parsing error when the JSON document cannot be
        /// deserialized into an `AuthorizedUser`.
        pub fn build(self) -> BuildResult<IDTokenCredentials> {
            let creds = UserAccountCredentials {
                token_provider: self.build_token_provider()?,
            };
            Ok(IDTokenCredentials {
                inner: Arc::new(creds),
            })
        }
    }
}
613
#[cfg(test)]
mod tests {
    use super::*;
    use crate::credentials::tests::{
        find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
        get_mock_backoff_policy, get_mock_retry_throttler, get_token_from_headers,
        get_token_type_from_headers,
    };
    use crate::credentials::{DEFAULT_UNIVERSE_DOMAIN, QUOTA_PROJECT_KEY};
    use crate::errors::CredentialsError;
    use crate::token::tests::MockTokenProvider;
    use http::StatusCode;
    use http::header::AUTHORIZATION;
    use httptest::cycle;
    use httptest::matchers::{all_of, json_decoded, request};
    use httptest::responders::{json_encoded, status_code};
    use httptest::{Expectation, Server};

    type TestResult = anyhow::Result<()>;

    /// Returns a minimal `authorized_user` document pointing at `token_uri`.
    pub(crate) fn authorized_user_json(token_uri: String) -> Value {
        serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": token_uri,
        })
    }

    // The server answers 503 three times; the credentials must issue exactly
    // three requests and then surface a non-transient error.
    #[tokio::test]
    async fn test_user_account_retries_on_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path("/token"))
                .times(3)
                .respond_with(status_code(503)),
        );

        let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build()?;

        let err = credentials.headers(Extensions::new()).await.unwrap_err();
        assert!(!err.is_transient());
        server.verify_and_clear();
        Ok(())
    }

    // A 401 must be reported after a single request.
    #[tokio::test]
    async fn test_user_account_does_not_retry_on_non_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path("/token"))
                .times(1)
                .respond_with(status_code(401)),
        );

        let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
            .with_retry_policy(get_mock_auth_retry_policy(1))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build()?;

        let err = credentials.headers(Extensions::new()).await.unwrap_err();
        assert!(!err.is_transient());
        server.verify_and_clear();
        Ok(())
    }

    // Two 503 responses followed by a 200: the third attempt yields the token.
    #[tokio::test]
    async fn test_user_account_retries_for_success() -> TestResult {
        let mut server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };

        server.expect(
            Expectation::matching(request::path("/token"))
                .times(3)
                .respond_with(cycle![
                    status_code(503).body("try-again"),
                    status_code(503).body("try-again"),
                    status_code(200)
                        .append_header("Content-Type", "application/json")
                        .body(serde_json::to_string(&response).unwrap()),
                ]),
        );

        let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build()?;

        let token = get_token_from_headers(credentials.headers(Extensions::new()).await.unwrap());
        assert_eq!(token.unwrap(), "test-access-token");

        server.verify_and_clear();
        Ok(())
    }

    // The Debug output must show the non-secret fields and hide the secrets.
    #[test]
    fn debug_token_provider() {
        let expected = UserTokenProvider {
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            endpoint: OAUTH2_TOKEN_SERVER_URL.to_string(),
            scopes: Some("https://www.googleapis.com/auth/pubsub".to_string()),
            source: UserTokenSource::AccessToken,
        };
        let fmt = format!("{expected:?}");
        assert!(fmt.contains("test-client-id"), "{fmt}");
        assert!(!fmt.contains("test-client-secret"), "{fmt}");
        assert!(!fmt.contains("test-refresh-token"), "{fmt}");
        assert!(fmt.contains(OAUTH2_TOKEN_SERVER_URL), "{fmt}");
        assert!(
            fmt.contains("https://www.googleapis.com/auth/pubsub"),
            "{fmt}"
        );
    }

    // Unknown keys ("account", "universe_domain") are ignored when parsing.
    #[test]
    fn authorized_user_full_from_json_success() {
        let json = serde_json::json!({
            "account": "",
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "universe_domain": "googleapis.com",
            "quota_project_id": "test-project",
            "token_uri" : "test-token-uri",
        });

        let expected = AuthorizedUser {
            cred_type: "authorized_user".to_string(),
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            quota_project_id: Some("test-project".to_string()),
            token_uri: Some("test-token-uri".to_string()),
        };
        let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
        assert_eq!(actual, expected);
    }

    // Missing optional keys deserialize to `None`.
    #[test]
    fn authorized_user_partial_from_json_success() {
        let json = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
        });

        let expected = AuthorizedUser {
            cred_type: "authorized_user".to_string(),
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            quota_project_id: None,
            token_uri: None,
        };
        let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
        assert_eq!(actual, expected);
    }

    // Each required field, once blanked out, must make parsing fail.
    #[test]
    fn authorized_user_from_json_parse_fail() {
        let json_full = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "quota_project_id": "test-project"
        });

        for required_field in ["client_id", "client_secret", "refresh_token"] {
            let mut json = json_full.clone();
            // `take()` leaves a JSON null in place of the original value.
            json[required_field].take();
            let result = serde_json::from_value::<AuthorizedUser>(json);
            assert!(
                result.is_err(),
                "parsing should fail without a valid `{required_field}`"
            );
        }
    }

    #[tokio::test]
    async fn default_universe_domain_success() {
        let mock = TokenCache::new(MockTokenProvider::new());

        let uc = UserCredentials {
            token_provider: mock,
            quota_project_id: None,
        };
        assert_eq!(uc.universe_domain().await.unwrap(), DEFAULT_UNIVERSE_DOMAIN);
    }

    // The first call returns new headers with an entity tag; a second call
    // that presents the tag must report `NotModified`.
    #[tokio::test]
    async fn headers_success() -> TestResult {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let uc = UserCredentials {
            token_provider: TokenCache::new(mock),
            quota_project_id: None,
        };

        let mut extensions = Extensions::new();
        let cached_headers = uc.headers(extensions.clone()).await.unwrap();
        let (headers, entity_tag) = match cached_headers {
            CacheableResource::New { entity_tag, data } => (data, entity_tag),
            CacheableResource::NotModified => unreachable!("expecting new headers"),
        };
        let token = headers.get(AUTHORIZATION).unwrap();

        assert_eq!(headers.len(), 1, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());

        extensions.insert(entity_tag);

        let cached_headers = uc.headers(extensions).await?;
        assert!(
            matches!(cached_headers, CacheableResource::NotModified),
            "expecting cached (not modified) headers"
        );
        Ok(())
    }

    #[tokio::test]
    async fn headers_failure() {
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Err(errors::non_retryable_from_str("fail")));

        let uc = UserCredentials {
            token_provider: TokenCache::new(mock),
            quota_project_id: None,
        };
        assert!(uc.headers(Extensions::new()).await.is_err());
    }

    // With a quota project configured, a second, non-sensitive header appears.
    #[tokio::test]
    async fn headers_with_quota_project_success() -> TestResult {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let uc = UserCredentials {
            token_provider: TokenCache::new(mock),
            quota_project_id: Some("test-project".to_string()),
        };

        let headers = get_headers_from_cache(uc.headers(Extensions::new()).await.unwrap())?;
        let token = headers.get(AUTHORIZATION).unwrap();
        let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());
        assert_eq!(
            quota_project_header,
            HeaderValue::from_static("test-project")
        );
        assert!(!quota_project_header.is_sensitive());
        Ok(())
    }

    #[test]
    fn oauth2_request_serde() {
        let request = Oauth2RefreshRequest {
            grant_type: RefreshGrantType::RefreshToken,
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            scopes: Some("scope1 scope2".to_string()),
        };

        let json = serde_json::to_value(&request).unwrap();
        let expected = serde_json::json!({
            "grant_type": "refresh_token",
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "scopes": "scope1 scope2",
        });
        assert_eq!(json, expected);
        let roundtrip = serde_json::from_value::<Oauth2RefreshRequest>(json).unwrap();
        assert_eq!(request, roundtrip);
    }

    #[test]
    fn oauth2_response_serde_full() {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            scope: Some("scope1 scope2".to_string()),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
            refresh_token: Some("test-refresh-token".to_string()),
        };

        let json = serde_json::to_value(&response).unwrap();
        let expected = serde_json::json!({
            "access_token": "test-access-token",
            "scope": "scope1 scope2",
            "expires_in": 3600,
            "token_type": "test-token-type",
            "refresh_token": "test-refresh-token"
        });
        assert_eq!(json, expected);
        let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
        assert_eq!(response, roundtrip);
    }

    // `None` fields must be absent from the serialized form.
    #[test]
    fn oauth2_response_serde_partial() {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            scope: None,
            expires_in: None,
            token_type: "test-token-type".to_string(),
            refresh_token: None,
        };

        let json = serde_json::to_value(&response).unwrap();
        let expected = serde_json::json!({
            "access_token": "test-access-token",
            "token_type": "test-token-type",
        });
        assert_eq!(json, expected);
        let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
        assert_eq!(response, roundtrip);
    }

    /// Returns `true` when `request` carries the fixed test credentials and
    /// exactly `expected_scopes`.
    pub(crate) fn check_request(
        request: &Oauth2RefreshRequest,
        expected_scopes: Option<String>,
    ) -> bool {
        request.client_id == "test-client-id"
            && request.client_secret == "test-client-secret"
            && request.refresh_token == "test-refresh-token"
            && request.grant_type == RefreshGrantType::RefreshToken
            && request.scopes == expected_scopes
    }

    // The clock is paused so that the expiration can be compared exactly.
    #[tokio::test(start_paused = true)]
    async fn token_provider_full() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, Some("scope1 scope2".to_string()))
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let tp = UserTokenProvider {
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            endpoint: server.url("/token").to_string(),
            scopes: Some("scope1 scope2".to_string()),
            source: UserTokenSource::AccessToken,
        };
        let now = Instant::now();
        let token = tp.token().await?;
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d == now + Duration::from_secs(3600)),
            "now: {:?}, expires_at: {:?}",
            now,
            token.expires_at
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_full_with_quota_project() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: None,
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, None)
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": server.url("/token").to_string(),
        });
        let cred = Builder::new(authorized_user)
            .with_quota_project_id("test-project")
            .build()?;

        let headers = get_headers_from_cache(cred.headers(Extensions::new()).await.unwrap())?;
        let token = headers.get(AUTHORIZATION).unwrap();
        let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(
            token,
            HeaderValue::from_static("test-token-type test-access-token")
        );
        assert!(token.is_sensitive());
        assert_eq!(
            quota_project_header,
            HeaderValue::from_static("test-project")
        );
        assert!(!quota_project_header.is_sensitive());

        Ok(())
    }

    // Two `headers()` calls must hit the server only once (`times(1)`).
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn creds_from_json_custom_uri_with_caching() -> TestResult {
        let mut server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, Some("scope1 scope2".to_string()))
                }))
            ])
            .times(1)
            .respond_with(json_encoded(response)),
        );

        let json = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "universe_domain": "googleapis.com",
            "quota_project_id": "test-project",
            "token_uri": server.url("/token").to_string(),
        });

        let cred = Builder::new(json)
            .with_scopes(vec!["scope1", "scope2"])
            .build()?;

        let token = get_token_from_headers(cred.headers(Extensions::new()).await?);
        assert_eq!(token.unwrap(), "test-access-token");

        let token = get_token_from_headers(cred.headers(Extensions::new()).await?);
        assert_eq!(token.unwrap(), "test-access-token");

        server.verify_and_clear();

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_partial() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: None,
            refresh_token: None,
            scope: None,
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, None)
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": server.url("/token").to_string()
        });

        let uc = Builder::new(authorized_user).build()?;
        let headers = uc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    // The builder's token URI must win over the one in the JSON document.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_with_token_uri() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: None,
            refresh_token: None,
            scope: None,
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, None)
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": "test-endpoint"
        });

        let uc = Builder::new(authorized_user)
            .with_token_uri(server.url("/token").to_string())
            .build()?;
        let headers = uc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    // Scopes set on the builder must reach the server joined by a space.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_with_scopes() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: None,
            refresh_token: None,
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, Some("scope1 scope2".to_string()))
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": "test-endpoint"
        });

        let uc = Builder::new(authorized_user)
            .with_token_uri(server.url("/token").to_string())
            .with_scopes(vec!["scope1", "scope2"])
            .build()?;
        let headers = uc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    // A 503 must surface as a transient `CredentialsError` that still exposes
    // the HTTP status through its source chain.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_retryable_error() -> TestResult {
        let server = Server::run();
        server
            .expect(Expectation::matching(request::path("/token")).respond_with(status_code(503)));

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": server.url("/token").to_string()
        });

        let uc = Builder::new(authorized_user).build()?;
        let err = uc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(original_err.is_transient());

        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
            "{err:?}"
        );

        Ok(())
    }

    // A 401 must surface as a non-transient `CredentialsError`.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn token_provider_nonretryable_error() -> TestResult {
        let server = Server::run();
        server
            .expect(Expectation::matching(request::path("/token")).respond_with(status_code(401)));

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": server.url("/token").to_string()
        });

        let uc = Builder::new(authorized_user).build()?;
        let err = uc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(!original_err.is_transient());

        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
            "{err:?}"
        );

        Ok(())
    }

    // A 200 whose body is not a refresh response must not be retried.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn token_provider_malformed_response_is_nonretryable() -> TestResult {
        let server = Server::run();
        server.expect(
            Expectation::matching(request::path("/token"))
                .respond_with(json_encoded("bad json".to_string())),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": server.url("/token").to_string()
        });

        let uc = Builder::new(authorized_user).build()?;
        let e = uc.headers(Extensions::new()).await.err().unwrap();
        assert!(!e.is_transient(), "{e}");

        Ok(())
    }

    // The document lacks `client_id`, so `build()` must fail with a parsing
    // error.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn builder_malformed_authorized_json_nonretryable() -> TestResult {
        let authorized_user = serde_json::json!({
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
        });

        let e = Builder::new(authorized_user).build().unwrap_err();
        assert!(e.is_parsing(), "{e}");

        Ok(())
    }
}
1356
// Tests for the ID-token flavor of user-account credentials. They only build
// when the `google_cloud_unstable_id_token` cfg is set, and reuse the fixtures
// exported by the sibling `tests` module.
#[cfg(all(test, google_cloud_unstable_id_token))]
mod unstable_tests {
    use super::tests::*;
    use super::*;
    use crate::credentials::tests::find_source_error;
    use http::StatusCode;
    use httptest::matchers::{all_of, json_decoded, request};
    use httptest::responders::{json_encoded, status_code};
    use httptest::{Expectation, Server};

    type TestResult = anyhow::Result<()>;

    // Builds the refresh response shared by the tests below; only the
    // `id_token` field varies between them.
    fn refresh_response(id_token: Option<&str>) -> Oauth2RefreshResponse {
        Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: id_token.map(String::from),
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: None,
            token_type: "Bearer".to_string(),
        }
    }

    // Registers an expectation on `server`: a request to `/token` whose JSON
    // body passes `check_request(req, None)` is answered with `response`.
    fn expect_refresh(server: &Server, response: Oauth2RefreshResponse) {
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, None)
                }))
            ])
            .respond_with(json_encoded(response)),
        );
    }

    // A refresh response that carries an ID token yields exactly that token.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn id_token_success() -> TestResult {
        let server = Server::run();
        expect_refresh(&server, refresh_response(Some("test-id-token")));

        let authorized_user = authorized_user_json(server.url("/token").to_string());
        let creds = super::idtoken::Builder::new(authorized_user).build()?;
        let id_token = creds.id_token().await?;
        assert_eq!(id_token, "test-id-token");
        Ok(())
    }

    // A refresh response without an ID token is a non-transient error whose
    // message contains `MISSING_ID_TOKEN_MSG`.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn id_token_missing_id_token_in_response() -> TestResult {
        let server = Server::run();
        expect_refresh(&server, refresh_response(None));

        let authorized_user = authorized_user_json(server.url("/token").to_string());
        let creds = super::idtoken::Builder::new(authorized_user).build()?;
        let err = creds.id_token().await.unwrap_err();
        assert!(!err.is_transient(), "{err:?}");
        assert!(err.to_string().contains(MISSING_ID_TOKEN_MSG), "{err}");
        Ok(())
    }

    // The fixture below has no "client_id" entry; the ID-token builder must
    // reject it with a parsing error.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn id_token_builder_malformed_authorized_json_nonretryable() -> TestResult {
        let authorized_user = serde_json::json!({
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
        });

        let e = super::idtoken::Builder::new(authorized_user)
            .build()
            .unwrap_err();
        assert!(e.is_parsing(), "{e}");
        Ok(())
    }

    // A 503 from the token endpoint must be reported as transient, with the
    // HTTP status still reachable through the error's source chain.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn id_token_retryable_error() -> TestResult {
        let server = Server::run();
        server
            .expect(Expectation::matching(request::path("/token")).respond_with(status_code(503)));

        let authorized_user = authorized_user_json(server.url("/token").to_string());
        let creds = super::idtoken::Builder::new(authorized_user).build()?;
        let err = creds.id_token().await.unwrap_err();
        assert!(err.is_transient(), "{err:?}");

        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
            "{err:?}"
        );
        Ok(())
    }

    // A 401 from the token endpoint must NOT be reported as transient, with
    // the HTTP status still reachable through the error's source chain.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn id_token_nonretryable_error() -> TestResult {
        let server = Server::run();
        server
            .expect(Expectation::matching(request::path("/token")).respond_with(status_code(401)));

        let authorized_user = authorized_user_json(server.url("/token").to_string());
        let creds = super::idtoken::Builder::new(authorized_user).build()?;
        let err = creds.id_token().await.unwrap_err();
        assert!(!err.is_transient(), "{err:?}");

        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
            "{err:?}"
        );
        Ok(())
    }
}