1use crate::build_errors::Error as BuilderError;
98use crate::constants::OAUTH2_TOKEN_SERVER_URL;
99use crate::credentials::dynamic::{AccessTokenCredentialsProvider, CredentialsProvider};
100use crate::credentials::{AccessToken, AccessTokenCredentials, CacheableResource, Credentials};
101use crate::errors::{self, CredentialsError};
102use crate::headers_util::build_cacheable_headers;
103use crate::retry::Builder as RetryTokenProviderBuilder;
104use crate::token::{CachedTokenProvider, Token, TokenProvider};
105use crate::token_cache::TokenCache;
106use crate::{BuildResult, Result};
107use gax::backoff_policy::BackoffPolicyArg;
108use gax::retry_policy::RetryPolicyArg;
109use gax::retry_throttler::RetryThrottlerArg;
110use http::header::CONTENT_TYPE;
111use http::{Extensions, HeaderMap, HeaderValue};
112use reqwest::{Client, Method};
113use serde_json::Value;
114use std::sync::Arc;
115use tokio::time::{Duration, Instant};
116
/// Builds credentials from an `authorized_user` JSON object.
///
/// Values set on the builder (token URI, quota project) take precedence over
/// the corresponding fields in the JSON when the credentials are built.
pub struct Builder {
    /// The raw `authorized_user` JSON; it is only parsed when building.
    authorized_user: Value,
    /// Scopes to request; joined with spaces when sent in the refresh request.
    scopes: Option<Vec<String>>,
    /// When set, overrides the `quota_project_id` field from the JSON.
    quota_project_id: Option<String>,
    /// When set, overrides the `token_uri` field from the JSON.
    token_uri: Option<String>,
    /// Retry, backoff and throttling configuration wrapped around the token provider.
    retry_builder: RetryTokenProviderBuilder,
}
134
135impl Builder {
136 pub fn new(authorized_user: Value) -> Self {
143 Self {
144 authorized_user,
145 scopes: None,
146 quota_project_id: None,
147 token_uri: None,
148 retry_builder: RetryTokenProviderBuilder::default(),
149 }
150 }
151
152 pub fn with_token_uri<S: Into<String>>(mut self, token_uri: S) -> Self {
166 self.token_uri = Some(token_uri.into());
167 self
168 }
169
170 pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
191 where
192 I: IntoIterator<Item = S>,
193 S: Into<String>,
194 {
195 self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
196 self
197 }
198
199 pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
220 self.quota_project_id = Some(quota_project_id.into());
221 self
222 }
223
224 pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
245 self.retry_builder = self.retry_builder.with_retry_policy(v.into());
246 self
247 }
248
249 pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
270 self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
271 self
272 }
273
274 pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
301 self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
302 self
303 }
304
305 pub fn build(self) -> BuildResult<Credentials> {
319 Ok(self.build_access_token_credentials()?.into())
320 }
321
322 pub fn build_access_token_credentials(self) -> BuildResult<AccessTokenCredentials> {
357 let authorized_user = serde_json::from_value::<AuthorizedUser>(self.authorized_user)
358 .map_err(BuilderError::parsing)?;
359 let endpoint = self
360 .token_uri
361 .or(authorized_user.token_uri)
362 .unwrap_or(OAUTH2_TOKEN_SERVER_URL.to_string());
363 let quota_project_id = self.quota_project_id.or(authorized_user.quota_project_id);
364
365 let token_provider = UserTokenProvider {
366 client_id: authorized_user.client_id,
367 client_secret: authorized_user.client_secret,
368 refresh_token: authorized_user.refresh_token,
369 endpoint,
370 scopes: self.scopes.map(|scopes| scopes.join(" ")),
371 source: UserTokenSource::AccessToken,
372 };
373
374 let token_provider = TokenCache::new(self.retry_builder.build(token_provider));
375
376 Ok(AccessTokenCredentials {
377 inner: Arc::new(UserCredentials {
378 token_provider,
379 quota_project_id,
380 }),
381 })
382 }
383}
384
/// Fetches tokens by exchanging a user's refresh token at an OAuth2 endpoint.
///
/// `Debug` is implemented by hand so that the secrets are not printed.
#[derive(PartialEq)]
pub(crate) struct UserTokenProvider {
    client_id: String,
    /// Secret; censored in `Debug` output.
    client_secret: String,
    /// Secret; censored in `Debug` output.
    refresh_token: String,
    /// The token endpoint the refresh request is POSTed to.
    endpoint: String,
    /// Space-separated scopes sent with the refresh request, if any.
    scopes: Option<String>,
    /// Which token from the refresh response to return.
    source: UserTokenSource,
}
394
/// Selects which token from the refresh response the provider returns.
// `IdToken` is only constructed by `new_id_token_provider`, which is gated on
// the `idtoken` feature; the allow keeps builds without that feature quiet.
#[derive(PartialEq)]
#[allow(dead_code)]
enum UserTokenSource {
    /// Return the `access_token` field of the response.
    AccessToken,
    /// Return the `id_token` field; a response without one is an error.
    IdToken,
}
401
402impl std::fmt::Debug for UserTokenProvider {
403 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
404 f.debug_struct("UserCredentials")
405 .field("client_id", &self.client_id)
406 .field("client_secret", &"[censored]")
407 .field("refresh_token", &"[censored]")
408 .field("endpoint", &self.endpoint)
409 .field("scopes", &self.scopes)
410 .finish()
411 }
412}
413
414impl UserTokenProvider {
415 #[cfg(feature = "idtoken")]
416 pub(crate) fn new_id_token_provider(
417 authorized_user: AuthorizedUser,
418 token_uri: Option<String>,
419 ) -> UserTokenProvider {
420 let endpoint = token_uri
421 .or(authorized_user.token_uri)
422 .unwrap_or(OAUTH2_TOKEN_SERVER_URL.to_string());
423 UserTokenProvider {
424 client_id: authorized_user.client_id,
425 client_secret: authorized_user.client_secret,
426 refresh_token: authorized_user.refresh_token,
427 endpoint,
428 source: UserTokenSource::IdToken,
429 scopes: None,
430 }
431 }
432}
433
434#[async_trait::async_trait]
435impl TokenProvider for UserTokenProvider {
436 async fn token(&self) -> Result<Token> {
437 let client = Client::new();
438
439 let req = Oauth2RefreshRequest {
441 grant_type: RefreshGrantType::RefreshToken,
442 client_id: self.client_id.clone(),
443 client_secret: self.client_secret.clone(),
444 refresh_token: self.refresh_token.clone(),
445 scopes: self.scopes.clone(),
446 };
447 let header = HeaderValue::from_static("application/json");
448 let builder = client
449 .request(Method::POST, self.endpoint.as_str())
450 .header(CONTENT_TYPE, header)
451 .json(&req);
452 let resp = builder
453 .send()
454 .await
455 .map_err(|e| errors::from_http_error(e, MSG))?;
456
457 if !resp.status().is_success() {
459 let err = errors::from_http_response(resp, MSG).await;
460 return Err(err);
461 }
462 let response = resp.json::<Oauth2RefreshResponse>().await.map_err(|e| {
463 let retryable = !e.is_decode();
464 CredentialsError::from_source(retryable, e)
465 })?;
466
467 let token = match self.source {
468 UserTokenSource::AccessToken => Ok(response.access_token),
469 UserTokenSource::IdToken => response
470 .id_token
471 .ok_or_else(|| CredentialsError::from_msg(false, MISSING_ID_TOKEN_MSG)),
472 }?;
473 let token = Token {
474 token,
475 token_type: response.token_type,
476 expires_at: response
477 .expires_in
478 .map(|d| Instant::now() + Duration::from_secs(d)),
479 metadata: None,
480 };
481 Ok(token)
482 }
483}
484
/// Error message used when the token refresh request fails.
const MSG: &str = "failed to refresh user access token";
/// Error message used when an id token is requested but the refresh response
/// does not contain one.
// The quoted command previously opened with `'` and closed with a backtick.
const MISSING_ID_TOKEN_MSG: &str = "UserCredentials can obtain an id token only when authenticated through \
    gcloud running 'gcloud auth application-default login'";
488
/// Credentials backed by a cached user token provider.
///
/// Produces authorization headers (plus the quota project header, when
/// `quota_project_id` is set) from tokens returned by `token_provider`.
#[derive(Debug)]
pub(crate) struct UserCredentials<T>
where
    T: CachedTokenProvider,
{
    /// The caching token provider that supplies tokens for each request.
    token_provider: T,
    /// Optional quota project forwarded to `build_cacheable_headers`.
    quota_project_id: Option<String>,
}
500
501#[async_trait::async_trait]
502impl<T> CredentialsProvider for UserCredentials<T>
503where
504 T: CachedTokenProvider,
505{
506 async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
507 let token = self.token_provider.token(extensions).await?;
508 build_cacheable_headers(&token, &self.quota_project_id)
509 }
510}
511
512#[async_trait::async_trait]
513impl<T> AccessTokenCredentialsProvider for UserCredentials<T>
514where
515 T: CachedTokenProvider,
516{
517 async fn access_token(&self) -> Result<AccessToken> {
518 let token = self.token_provider.token(Extensions::new()).await?;
519 token.into()
520 }
521}
522
523#[derive(Debug, PartialEq, serde::Deserialize)]
524pub(crate) struct AuthorizedUser {
525 #[serde(rename = "type")]
526 cred_type: String,
527 client_id: String,
528 client_secret: String,
529 refresh_token: String,
530 #[serde(skip_serializing_if = "Option::is_none")]
531 token_uri: Option<String>,
532 #[serde(skip_serializing_if = "Option::is_none")]
533 quota_project_id: Option<String>,
534}
535
/// The OAuth2 `grant_type` sent in a refresh request.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
pub(crate) enum RefreshGrantType {
    /// Serialized as the string `"refresh_token"`.
    #[serde(rename = "refresh_token")]
    RefreshToken,
}
541
542#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
543pub(crate) struct Oauth2RefreshRequest {
544 pub(crate) grant_type: RefreshGrantType,
545 pub(crate) client_id: String,
546 pub(crate) client_secret: String,
547 pub(crate) refresh_token: String,
548 scopes: Option<String>,
549}
550
551#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
552pub(crate) struct Oauth2RefreshResponse {
553 pub(crate) access_token: String,
554 #[serde(skip_serializing_if = "Option::is_none")]
555 pub(crate) id_token: Option<String>,
556 #[serde(skip_serializing_if = "Option::is_none")]
557 pub(crate) scope: Option<String>,
558 #[serde(skip_serializing_if = "Option::is_none")]
559 pub(crate) expires_in: Option<u64>,
560 pub(crate) token_type: String,
561 #[serde(skip_serializing_if = "Option::is_none")]
562 pub(crate) refresh_token: Option<String>,
563}
564
#[cfg(test)]
mod tests {
    use super::*;
    use crate::credentials::tests::{
        find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
        get_mock_backoff_policy, get_mock_retry_throttler, get_token_from_headers,
        get_token_type_from_headers,
    };
    use crate::credentials::{DEFAULT_UNIVERSE_DOMAIN, QUOTA_PROJECT_KEY};
    use crate::errors::CredentialsError;
    use crate::token::tests::MockTokenProvider;
    use http::StatusCode;
    use http::header::AUTHORIZATION;
    use httptest::cycle;
    use httptest::matchers::{all_of, json_decoded, request};
    use httptest::responders::{json_encoded, status_code};
    use httptest::{Expectation, Server};

    type TestResult = anyhow::Result<()>;

    // Minimal valid `authorized_user` JSON pointing at `token_uri`.
    fn authorized_user_json(token_uri: String) -> Value {
        serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": token_uri,
        })
    }

    // Once the retry policy is exhausted on 503s, the final error is not transient.
    #[tokio::test]
    async fn test_user_account_retries_on_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path("/token"))
                .times(3)
                .respond_with(status_code(503)),
        );

        let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build()?;

        let err = credentials.headers(Extensions::new()).await.unwrap_err();
        assert!(!err.is_transient());
        server.verify_and_clear();
        Ok(())
    }

    // A 401 is attempted exactly once.
    #[tokio::test]
    async fn test_user_account_does_not_retry_on_non_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path("/token"))
                .times(1)
                .respond_with(status_code(401)),
        );

        let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
            .with_retry_policy(get_mock_auth_retry_policy(1))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build()?;

        let err = credentials.headers(Extensions::new()).await.unwrap_err();
        assert!(!err.is_transient());
        server.verify_and_clear();
        Ok(())
    }

    // Two 503s followed by a success yield the token.
    #[tokio::test]
    async fn test_user_account_retries_for_success() -> TestResult {
        let mut server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };

        server.expect(
            Expectation::matching(request::path("/token"))
                .times(3)
                .respond_with(cycle![
                    status_code(503).body("try-again"),
                    status_code(503).body("try-again"),
                    status_code(200)
                        .append_header("Content-Type", "application/json")
                        .body(serde_json::to_string(&response).unwrap()),
                ]),
        );

        let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build()?;

        let token = get_token_from_headers(credentials.headers(Extensions::new()).await.unwrap());
        assert_eq!(token.unwrap(), "test-access-token");

        server.verify_and_clear();
        Ok(())
    }

    // Debug output must not leak secrets but should show non-secret fields.
    #[test]
    fn debug_token_provider() {
        let expected = UserTokenProvider {
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            endpoint: OAUTH2_TOKEN_SERVER_URL.to_string(),
            scopes: Some("https://www.googleapis.com/auth/pubsub".to_string()),
            source: UserTokenSource::AccessToken,
        };
        let fmt = format!("{expected:?}");
        assert!(fmt.contains("test-client-id"), "{fmt}");
        assert!(!fmt.contains("test-client-secret"), "{fmt}");
        assert!(!fmt.contains("test-refresh-token"), "{fmt}");
        assert!(fmt.contains(OAUTH2_TOKEN_SERVER_URL), "{fmt}");
        assert!(
            fmt.contains("https://www.googleapis.com/auth/pubsub"),
            "{fmt}"
        );
    }

    // Unknown fields ("account", "universe_domain") are ignored.
    #[test]
    fn authorized_user_full_from_json_success() {
        let json = serde_json::json!({
            "account": "",
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "universe_domain": "googleapis.com",
            "quota_project_id": "test-project",
            "token_uri" : "test-token-uri",
        });

        let expected = AuthorizedUser {
            cred_type: "authorized_user".to_string(),
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            quota_project_id: Some("test-project".to_string()),
            token_uri: Some("test-token-uri".to_string()),
        };
        let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
        assert_eq!(actual, expected);
    }

    // Missing optional fields deserialize as `None`.
    #[test]
    fn authorized_user_partial_from_json_success() {
        let json = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
        });

        let expected = AuthorizedUser {
            cred_type: "authorized_user".to_string(),
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            quota_project_id: None,
            token_uri: None,
        };
        let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn authorized_user_from_json_parse_fail() {
        let json_full = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "quota_project_id": "test-project"
        });

        for required_field in ["client_id", "client_secret", "refresh_token"] {
            let mut json = json_full.clone();
            // `take()` replaces the value with `null`, which is not a valid `String`.
            json[required_field].take();
            let result = serde_json::from_value::<AuthorizedUser>(json);
            assert!(
                result.is_err(),
                "expected a parse error when {required_field} is null"
            );
        }
    }

    #[tokio::test]
    async fn default_universe_domain_success() {
        let mock = TokenCache::new(MockTokenProvider::new());

        let uc = UserCredentials {
            token_provider: mock,
            quota_project_id: None,
        };
        assert_eq!(uc.universe_domain().await.unwrap(), DEFAULT_UNIVERSE_DOMAIN);
    }

    #[tokio::test]
    async fn headers_success() -> TestResult {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let uc = UserCredentials {
            token_provider: TokenCache::new(mock),
            quota_project_id: None,
        };

        let mut extensions = Extensions::new();
        let cached_headers = uc.headers(extensions.clone()).await.unwrap();
        let (headers, entity_tag) = match cached_headers {
            CacheableResource::New { entity_tag, data } => (data, entity_tag),
            CacheableResource::NotModified => unreachable!("expecting new headers"),
        };
        let token = headers.get(AUTHORIZATION).unwrap();

        assert_eq!(headers.len(), 1, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());

        // Presenting the entity tag from the first call should report that the
        // headers have not changed.
        extensions.insert(entity_tag);

        let cached_headers = uc.headers(extensions).await?;
        assert!(
            matches!(cached_headers, CacheableResource::NotModified),
            "expecting not-modified headers"
        );
        Ok(())
    }

    #[tokio::test]
    async fn headers_failure() {
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Err(errors::non_retryable_from_str("fail")));

        let uc = UserCredentials {
            token_provider: TokenCache::new(mock),
            quota_project_id: None,
        };
        assert!(uc.headers(Extensions::new()).await.is_err());
    }

    #[tokio::test]
    async fn headers_with_quota_project_success() -> TestResult {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let uc = UserCredentials {
            token_provider: TokenCache::new(mock),
            quota_project_id: Some("test-project".to_string()),
        };

        let headers = get_headers_from_cache(uc.headers(Extensions::new()).await.unwrap())?;
        let token = headers.get(AUTHORIZATION).unwrap();
        let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());
        assert_eq!(
            quota_project_header,
            HeaderValue::from_static("test-project")
        );
        assert!(!quota_project_header.is_sensitive());
        Ok(())
    }

    #[test]
    fn oauth2_request_serde() {
        let request = Oauth2RefreshRequest {
            grant_type: RefreshGrantType::RefreshToken,
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            scopes: Some("scope1 scope2".to_string()),
        };

        let json = serde_json::to_value(&request).unwrap();
        let expected = serde_json::json!({
            "grant_type": "refresh_token",
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "scopes": "scope1 scope2",
        });
        assert_eq!(json, expected);
        let roundtrip = serde_json::from_value::<Oauth2RefreshRequest>(json).unwrap();
        assert_eq!(request, roundtrip);
    }

    #[test]
    fn oauth2_response_serde_full() {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            scope: Some("scope1 scope2".to_string()),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
            refresh_token: Some("test-refresh-token".to_string()),
        };

        let json = serde_json::to_value(&response).unwrap();
        let expected = serde_json::json!({
            "access_token": "test-access-token",
            "scope": "scope1 scope2",
            "expires_in": 3600,
            "token_type": "test-token-type",
            "refresh_token": "test-refresh-token"
        });
        assert_eq!(json, expected);
        let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
        assert_eq!(response, roundtrip);
    }

    #[test]
    fn oauth2_response_serde_partial() {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            scope: None,
            expires_in: None,
            token_type: "test-token-type".to_string(),
            refresh_token: None,
        };

        let json = serde_json::to_value(&response).unwrap();
        let expected = serde_json::json!({
            "access_token": "test-access-token",
            "token_type": "test-token-type",
        });
        assert_eq!(json, expected);
        let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
        assert_eq!(response, roundtrip);
    }

    // Matcher helper: checks the fixed test credentials and the expected scopes.
    fn check_request(request: &Oauth2RefreshRequest, expected_scopes: Option<String>) -> bool {
        request.client_id == "test-client-id"
            && request.client_secret == "test-client-secret"
            && request.refresh_token == "test-refresh-token"
            && request.grant_type == RefreshGrantType::RefreshToken
            && request.scopes == expected_scopes
    }

    // With a paused clock, `expires_at` is exactly `now + expires_in`.
    #[tokio::test(start_paused = true)]
    async fn token_provider_full() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, Some("scope1 scope2".to_string()))
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let tp = UserTokenProvider {
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            endpoint: server.url("/token").to_string(),
            scopes: Some("scope1 scope2".to_string()),
            source: UserTokenSource::AccessToken,
        };
        let now = Instant::now();
        let token = tp.token().await?;
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d == now + Duration::from_secs(3600)),
            "now: {:?}, expires_at: {:?}",
            now,
            token.expires_at
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_full_with_quota_project() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: None,
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, None)
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": server.url("/token").to_string(),
        });
        let cred = Builder::new(authorized_user)
            .with_quota_project_id("test-project")
            .build()?;

        let headers = get_headers_from_cache(cred.headers(Extensions::new()).await.unwrap())?;
        let token = headers.get(AUTHORIZATION).unwrap();
        let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(
            token,
            HeaderValue::from_static("test-token-type test-access-token")
        );
        assert!(token.is_sensitive());
        assert_eq!(
            quota_project_header,
            HeaderValue::from_static("test-project")
        );
        assert!(!quota_project_header.is_sensitive());

        Ok(())
    }

    // The second `headers()` call is served from the cache (`.times(1)`).
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn creds_from_json_custom_uri_with_caching() -> TestResult {
        let mut server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, Some("scope1 scope2".to_string()))
                }))
            ])
            .times(1)
            .respond_with(json_encoded(response)),
        );

        let json = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "universe_domain": "googleapis.com",
            "quota_project_id": "test-project",
            "token_uri": server.url("/token").to_string(),
        });

        let cred = Builder::new(json)
            .with_scopes(vec!["scope1", "scope2"])
            .build()?;

        let token = get_token_from_headers(cred.headers(Extensions::new()).await?);
        assert_eq!(token.unwrap(), "test-access-token");

        let token = get_token_from_headers(cred.headers(Extensions::new()).await?);
        assert_eq!(token.unwrap(), "test-access-token");

        server.verify_and_clear();

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_partial() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: None,
            refresh_token: None,
            scope: None,
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, None)
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": server.url("/token").to_string()
        });

        let uc = Builder::new(authorized_user).build()?;
        let headers = uc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    // `with_token_uri` overrides the JSON's (unreachable) `token_uri`.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_with_token_uri() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: None,
            refresh_token: None,
            scope: None,
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, None)
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": "test-endpoint"
        });

        let uc = Builder::new(authorized_user)
            .with_token_uri(server.url("/token").to_string())
            .build()?;
        let headers = uc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn access_credential_provider_with_token_uri() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: None,
            refresh_token: None,
            scope: None,
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, None)
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": "test-endpoint"
        });

        let uc = Builder::new(authorized_user)
            .with_token_uri(server.url("/token").to_string())
            .build_access_token_credentials()?;
        let access_token = uc.access_token().await?;
        assert_eq!(access_token.token, "test-access-token");

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_with_scopes() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: None,
            refresh_token: None,
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, Some("scope1 scope2".to_string()))
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": "test-endpoint"
        });

        let uc = Builder::new(authorized_user)
            .with_token_uri(server.url("/token").to_string())
            .with_scopes(vec!["scope1", "scope2"])
            .build()?;
        let headers = uc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_retryable_error() -> TestResult {
        let server = Server::run();
        server
            .expect(Expectation::matching(request::path("/token")).respond_with(status_code(503)));

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": server.url("/token").to_string()
        });

        let uc = Builder::new(authorized_user).build()?;
        let err = uc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(original_err.is_transient());

        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
            "{err:?}"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn token_provider_nonretryable_error() -> TestResult {
        let server = Server::run();
        server
            .expect(Expectation::matching(request::path("/token")).respond_with(status_code(401)));

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": server.url("/token").to_string()
        });

        let uc = Builder::new(authorized_user).build()?;
        let err = uc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(!original_err.is_transient());

        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
            "{err:?}"
        );

        Ok(())
    }

    // A JSON string is valid JSON but not an `Oauth2RefreshResponse`, so decoding fails.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn token_provider_malformed_response_is_nonretryable() -> TestResult {
        let server = Server::run();
        server.expect(
            Expectation::matching(request::path("/token"))
                .respond_with(json_encoded("bad json".to_string())),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": server.url("/token").to_string()
        });

        let uc = Builder::new(authorized_user).build()?;
        let e = uc.headers(Extensions::new()).await.err().unwrap();
        assert!(!e.is_transient(), "{e}");

        Ok(())
    }

    // Missing `client_id` is reported as a parsing error at build time.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn builder_malformed_authorized_json_nonretryable() -> TestResult {
        let authorized_user = serde_json::json!({
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
        });

        let e = Builder::new(authorized_user).build().unwrap_err();
        assert!(e.is_parsing(), "{e}");

        Ok(())
    }
}