1use crate::build_errors::Error as BuilderError;
98use crate::constants::OAUTH2_TOKEN_SERVER_URL;
99use crate::credentials::dynamic::CredentialsProvider;
100use crate::credentials::{CacheableResource, Credentials};
101use crate::errors::{self, CredentialsError};
102use crate::headers_util::build_cacheable_headers;
103use crate::retry::Builder as RetryTokenProviderBuilder;
104use crate::token::{CachedTokenProvider, Token, TokenProvider};
105use crate::token_cache::TokenCache;
106use crate::{BuildResult, Result};
107use gax::backoff_policy::BackoffPolicyArg;
108use gax::retry_policy::RetryPolicyArg;
109use gax::retry_throttler::RetryThrottlerArg;
110use http::header::CONTENT_TYPE;
111use http::{Extensions, HeaderMap, HeaderValue};
112use reqwest::{Client, Method};
113use serde_json::Value;
114use std::sync::Arc;
115use tokio::time::{Duration, Instant};
116
117pub struct Builder {
128 authorized_user: Value,
129 scopes: Option<Vec<String>>,
130 quota_project_id: Option<String>,
131 token_uri: Option<String>,
132 retry_builder: RetryTokenProviderBuilder,
133}
134
135impl Builder {
136 pub fn new(authorized_user: Value) -> Self {
143 Self {
144 authorized_user,
145 scopes: None,
146 quota_project_id: None,
147 token_uri: None,
148 retry_builder: RetryTokenProviderBuilder::default(),
149 }
150 }
151
152 pub fn with_token_uri<S: Into<String>>(mut self, token_uri: S) -> Self {
166 self.token_uri = Some(token_uri.into());
167 self
168 }
169
170 pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
191 where
192 I: IntoIterator<Item = S>,
193 S: Into<String>,
194 {
195 self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
196 self
197 }
198
199 pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
220 self.quota_project_id = Some(quota_project_id.into());
221 self
222 }
223
224 pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
245 self.retry_builder = self.retry_builder.with_retry_policy(v.into());
246 self
247 }
248
249 pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
270 self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
271 self
272 }
273
274 pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
301 self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
302 self
303 }
304
305 pub fn build(self) -> BuildResult<Credentials> {
318 let authorized_user = serde_json::from_value::<AuthorizedUser>(self.authorized_user)
319 .map_err(BuilderError::parsing)?;
320 let endpoint = self
321 .token_uri
322 .or(authorized_user.token_uri)
323 .unwrap_or(OAUTH2_TOKEN_SERVER_URL.to_string());
324 let quota_project_id = self.quota_project_id.or(authorized_user.quota_project_id);
325
326 let token_provider = UserTokenProvider {
327 client_id: authorized_user.client_id,
328 client_secret: authorized_user.client_secret,
329 refresh_token: authorized_user.refresh_token,
330 endpoint,
331 scopes: self.scopes.map(|scopes| scopes.join(" ")),
332 source: UserTokenSource::AccessToken,
333 };
334
335 let token_provider = TokenCache::new(self.retry_builder.build(token_provider));
336
337 Ok(Credentials {
338 inner: Arc::new(UserCredentials {
339 token_provider,
340 quota_project_id,
341 }),
342 })
343 }
344}
345
346#[derive(PartialEq)]
347pub(crate) struct UserTokenProvider {
348 client_id: String,
349 client_secret: String,
350 refresh_token: String,
351 endpoint: String,
352 scopes: Option<String>,
353 source: UserTokenSource,
354}
355
/// Selects which token `UserTokenProvider` returns from the token endpoint response.
// `IdToken` is only constructed when the `google_cloud_unstable_id_token` cfg is enabled, so
// the dead-code lint is silenced only in builds where that variant is actually unused.
#[cfg_attr(not(google_cloud_unstable_id_token), allow(dead_code))]
#[derive(PartialEq)]
enum UserTokenSource {
    /// Return the response's `id_token` field.
    IdToken,
    /// Return the response's `access_token` field.
    AccessToken,
}
362
363impl std::fmt::Debug for UserTokenProvider {
364 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
365 f.debug_struct("UserCredentials")
366 .field("client_id", &self.client_id)
367 .field("client_secret", &"[censored]")
368 .field("refresh_token", &"[censored]")
369 .field("endpoint", &self.endpoint)
370 .field("scopes", &self.scopes)
371 .finish()
372 }
373}
374
375impl UserTokenProvider {
376 #[cfg(google_cloud_unstable_id_token)]
377 pub(crate) fn new_id_token_provider(
378 authorized_user: AuthorizedUser,
379 token_uri: Option<String>,
380 ) -> UserTokenProvider {
381 let endpoint = token_uri
382 .or(authorized_user.token_uri)
383 .unwrap_or(OAUTH2_TOKEN_SERVER_URL.to_string());
384 UserTokenProvider {
385 client_id: authorized_user.client_id,
386 client_secret: authorized_user.client_secret,
387 refresh_token: authorized_user.refresh_token,
388 endpoint,
389 source: UserTokenSource::IdToken,
390 scopes: None,
391 }
392 }
393}
394
395#[async_trait::async_trait]
396impl TokenProvider for UserTokenProvider {
397 async fn token(&self) -> Result<Token> {
398 let client = Client::new();
399
400 let req = Oauth2RefreshRequest {
402 grant_type: RefreshGrantType::RefreshToken,
403 client_id: self.client_id.clone(),
404 client_secret: self.client_secret.clone(),
405 refresh_token: self.refresh_token.clone(),
406 scopes: self.scopes.clone(),
407 };
408 let header = HeaderValue::from_static("application/json");
409 let builder = client
410 .request(Method::POST, self.endpoint.as_str())
411 .header(CONTENT_TYPE, header)
412 .json(&req);
413 let resp = builder
414 .send()
415 .await
416 .map_err(|e| errors::from_http_error(e, MSG))?;
417
418 if !resp.status().is_success() {
420 let err = errors::from_http_response(resp, MSG).await;
421 return Err(err);
422 }
423 let response = resp.json::<Oauth2RefreshResponse>().await.map_err(|e| {
424 let retryable = !e.is_decode();
425 CredentialsError::from_source(retryable, e)
426 })?;
427
428 let token = match self.source {
429 UserTokenSource::AccessToken => Ok(response.access_token),
430 UserTokenSource::IdToken => response
431 .id_token
432 .ok_or_else(|| CredentialsError::from_msg(false, MISSING_ID_TOKEN_MSG)),
433 }?;
434 let token = Token {
435 token,
436 token_type: response.token_type,
437 expires_at: response
438 .expires_in
439 .map(|d| Instant::now() + Duration::from_secs(d)),
440 metadata: None,
441 };
442 Ok(token)
443 }
444}
445
/// Context attached to errors raised while refreshing a user access token.
const MSG: &str = "failed to refresh user access token";
/// Error message used when an ID token was requested but the token endpoint response did not
/// contain one.
const MISSING_ID_TOKEN_MSG: &str = "UserCredentials can obtain an id token only when authenticated through \
gcloud running 'gcloud auth application-default login'";
449
450#[derive(Debug)]
454pub(crate) struct UserCredentials<T>
455where
456 T: CachedTokenProvider,
457{
458 token_provider: T,
459 quota_project_id: Option<String>,
460}
461
462#[async_trait::async_trait]
463impl<T> CredentialsProvider for UserCredentials<T>
464where
465 T: CachedTokenProvider,
466{
467 async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
468 let token = self.token_provider.token(extensions).await?;
469 build_cacheable_headers(&token, &self.quota_project_id)
470 }
471}
472
473#[derive(Debug, PartialEq, serde::Deserialize)]
474pub(crate) struct AuthorizedUser {
475 #[serde(rename = "type")]
476 cred_type: String,
477 client_id: String,
478 client_secret: String,
479 refresh_token: String,
480 #[serde(skip_serializing_if = "Option::is_none")]
481 token_uri: Option<String>,
482 #[serde(skip_serializing_if = "Option::is_none")]
483 quota_project_id: Option<String>,
484}
485
486#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
487pub(crate) enum RefreshGrantType {
488 #[serde(rename = "refresh_token")]
489 RefreshToken,
490}
491
492#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
493pub(crate) struct Oauth2RefreshRequest {
494 pub(crate) grant_type: RefreshGrantType,
495 pub(crate) client_id: String,
496 pub(crate) client_secret: String,
497 pub(crate) refresh_token: String,
498 scopes: Option<String>,
499}
500
501#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
502pub(crate) struct Oauth2RefreshResponse {
503 pub(crate) access_token: String,
504 #[serde(skip_serializing_if = "Option::is_none")]
505 pub(crate) id_token: Option<String>,
506 #[serde(skip_serializing_if = "Option::is_none")]
507 pub(crate) scope: Option<String>,
508 #[serde(skip_serializing_if = "Option::is_none")]
509 pub(crate) expires_in: Option<u64>,
510 pub(crate) token_type: String,
511 #[serde(skip_serializing_if = "Option::is_none")]
512 pub(crate) refresh_token: Option<String>,
513}
514
#[cfg(test)]
mod tests {
    use super::*;
    use crate::credentials::tests::{
        find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
        get_mock_backoff_policy, get_mock_retry_throttler, get_token_from_headers,
        get_token_type_from_headers,
    };
    use crate::credentials::{DEFAULT_UNIVERSE_DOMAIN, QUOTA_PROJECT_KEY};
    use crate::errors::CredentialsError;
    use crate::token::tests::MockTokenProvider;
    use http::StatusCode;
    use http::header::AUTHORIZATION;
    use httptest::cycle;
    use httptest::matchers::{all_of, json_decoded, request};
    use httptest::responders::{json_encoded, status_code};
    use httptest::{Expectation, Server};

    type TestResult = anyhow::Result<()>;

    /// Returns a minimal `authorized_user` JSON document whose token endpoint is `token_uri`.
    fn authorized_user_json(token_uri: String) -> Value {
        serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": token_uri,
        })
    }

    // Repeated 503s are retried up to the policy limit; the final error is expected to be
    // non-transient.
    #[tokio::test]
    async fn test_user_account_retries_on_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path("/token"))
                .times(3)
                .respond_with(status_code(503)),
        );

        let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build()?;

        let err = credentials.headers(Extensions::new()).await.unwrap_err();
        assert!(!err.is_transient());
        server.verify_and_clear();
        Ok(())
    }

    // A 401 is not retried: the server must see exactly one request.
    #[tokio::test]
    async fn test_user_account_does_not_retry_on_non_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path("/token"))
                .times(1)
                .respond_with(status_code(401)),
        );

        let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
            .with_retry_policy(get_mock_auth_retry_policy(1))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build()?;

        let err = credentials.headers(Extensions::new()).await.unwrap_err();
        assert!(!err.is_transient());
        server.verify_and_clear();
        Ok(())
    }

    // Two transient failures followed by a success yield the token.
    #[tokio::test]
    async fn test_user_account_retries_for_success() -> TestResult {
        let mut server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };

        server.expect(
            Expectation::matching(request::path("/token"))
                .times(3)
                .respond_with(cycle![
                    status_code(503).body("try-again"),
                    status_code(503).body("try-again"),
                    status_code(200)
                        .append_header("Content-Type", "application/json")
                        .body(serde_json::to_string(&response).unwrap()),
                ]),
        );

        let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build()?;

        let token = get_token_from_headers(credentials.headers(Extensions::new()).await.unwrap());
        assert_eq!(token.unwrap(), "test-access-token");

        server.verify_and_clear();
        Ok(())
    }

    // The Debug output shows non-secret fields and censors the secrets.
    #[test]
    fn debug_token_provider() {
        let expected = UserTokenProvider {
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            endpoint: OAUTH2_TOKEN_SERVER_URL.to_string(),
            scopes: Some("https://www.googleapis.com/auth/pubsub".to_string()),
            source: UserTokenSource::AccessToken,
        };
        let fmt = format!("{expected:?}");
        assert!(fmt.contains("test-client-id"), "{fmt}");
        assert!(!fmt.contains("test-client-secret"), "{fmt}");
        assert!(!fmt.contains("test-refresh-token"), "{fmt}");
        assert!(fmt.contains(OAUTH2_TOKEN_SERVER_URL), "{fmt}");
        assert!(
            fmt.contains("https://www.googleapis.com/auth/pubsub"),
            "{fmt}"
        );
    }

    // All known fields are parsed; unknown fields (account, universe_domain) are ignored.
    #[test]
    fn authorized_user_full_from_json_success() {
        let json = serde_json::json!({
            "account": "",
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "universe_domain": "googleapis.com",
            "quota_project_id": "test-project",
            "token_uri" : "test-token-uri",
        });

        let expected = AuthorizedUser {
            cred_type: "authorized_user".to_string(),
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            quota_project_id: Some("test-project".to_string()),
            token_uri: Some("test-token-uri".to_string()),
        };
        let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
        assert_eq!(actual, expected);
    }

    // Optional fields default to `None` when absent.
    #[test]
    fn authorized_user_partial_from_json_success() {
        let json = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
        });

        let expected = AuthorizedUser {
            cred_type: "authorized_user".to_string(),
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            quota_project_id: None,
            token_uri: None,
        };
        let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
        assert_eq!(actual, expected);
    }

    // Removing any required field makes deserialization fail.
    #[test]
    fn authorized_user_from_json_parse_fail() {
        let json_full = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "quota_project_id": "test-project"
        });

        for required_field in ["client_id", "client_secret", "refresh_token", "type"] {
            let mut json = json_full.clone();
            // Drop the key entirely so the test covers a *missing* field, not a null one.
            let removed = json
                .as_object_mut()
                .expect("the fixture is a JSON object")
                .remove(required_field);
            assert!(removed.is_some(), "fixture lacks {required_field}");
            let result = serde_json::from_value::<AuthorizedUser>(json);
            assert!(
                result.is_err(),
                "expected a parse error when {required_field} is missing"
            );
        }
    }

    #[tokio::test]
    async fn default_universe_domain_success() {
        let mock = TokenCache::new(MockTokenProvider::new());

        let uc = UserCredentials {
            token_provider: mock,
            quota_project_id: None,
        };
        assert_eq!(uc.universe_domain().await.unwrap(), DEFAULT_UNIVERSE_DOMAIN);
    }

    // The first call returns new headers; replaying its entity tag yields `NotModified`.
    #[tokio::test]
    async fn headers_success() -> TestResult {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let uc = UserCredentials {
            token_provider: TokenCache::new(mock),
            quota_project_id: None,
        };

        let mut extensions = Extensions::new();
        let cached_headers = uc.headers(extensions.clone()).await.unwrap();
        let (headers, entity_tag) = match cached_headers {
            CacheableResource::New { entity_tag, data } => (data, entity_tag),
            CacheableResource::NotModified => unreachable!("expecting new headers"),
        };
        let token = headers.get(AUTHORIZATION).unwrap();

        assert_eq!(headers.len(), 1, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());

        extensions.insert(entity_tag);

        let cached_headers = uc.headers(extensions).await?;
        assert!(
            matches!(cached_headers, CacheableResource::NotModified),
            "expecting not-modified headers"
        );
        Ok(())
    }

    #[tokio::test]
    async fn headers_failure() {
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Err(errors::non_retryable_from_str("fail")));

        let uc = UserCredentials {
            token_provider: TokenCache::new(mock),
            quota_project_id: None,
        };
        assert!(uc.headers(Extensions::new()).await.is_err());
    }

    // A quota project adds a second, non-sensitive header.
    #[tokio::test]
    async fn headers_with_quota_project_success() -> TestResult {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let uc = UserCredentials {
            token_provider: TokenCache::new(mock),
            quota_project_id: Some("test-project".to_string()),
        };

        let headers = get_headers_from_cache(uc.headers(Extensions::new()).await.unwrap())?;
        let token = headers.get(AUTHORIZATION).unwrap();
        let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());
        assert_eq!(
            quota_project_header,
            HeaderValue::from_static("test-project")
        );
        assert!(!quota_project_header.is_sensitive());
        Ok(())
    }

    // Pins the wire format of the refresh request.
    #[test]
    fn oauth2_request_serde() {
        let request = Oauth2RefreshRequest {
            grant_type: RefreshGrantType::RefreshToken,
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            scopes: Some("scope1 scope2".to_string()),
        };

        let json = serde_json::to_value(&request).unwrap();
        let expected = serde_json::json!({
            "grant_type": "refresh_token",
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "scopes": "scope1 scope2",
        });
        assert_eq!(json, expected);
        let roundtrip = serde_json::from_value::<Oauth2RefreshRequest>(json).unwrap();
        assert_eq!(request, roundtrip);
    }

    #[test]
    fn oauth2_response_serde_full() {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            scope: Some("scope1 scope2".to_string()),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
            refresh_token: Some("test-refresh-token".to_string()),
        };

        let json = serde_json::to_value(&response).unwrap();
        let expected = serde_json::json!({
            "access_token": "test-access-token",
            "scope": "scope1 scope2",
            "expires_in": 3600,
            "token_type": "test-token-type",
            "refresh_token": "test-refresh-token"
        });
        assert_eq!(json, expected);
        let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
        assert_eq!(response, roundtrip);
    }

    // `None` fields are omitted when serializing the response.
    #[test]
    fn oauth2_response_serde_partial() {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            scope: None,
            expires_in: None,
            token_type: "test-token-type".to_string(),
            refresh_token: None,
        };

        let json = serde_json::to_value(&response).unwrap();
        let expected = serde_json::json!({
            "access_token": "test-access-token",
            "token_type": "test-token-type",
        });
        assert_eq!(json, expected);
        let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
        assert_eq!(response, roundtrip);
    }

    /// Returns true if `request` carries the test credentials and the expected scopes.
    fn check_request(request: &Oauth2RefreshRequest, expected_scopes: Option<String>) -> bool {
        request.client_id == "test-client-id"
            && request.client_secret == "test-client-secret"
            && request.refresh_token == "test-refresh-token"
            && request.grant_type == RefreshGrantType::RefreshToken
            && request.scopes == expected_scopes
    }

    // With paused time, `expires_at` is exactly now + expires_in.
    #[tokio::test(start_paused = true)]
    async fn token_provider_full() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, Some("scope1 scope2".to_string()))
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let tp = UserTokenProvider {
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            endpoint: server.url("/token").to_string(),
            scopes: Some("scope1 scope2".to_string()),
            source: UserTokenSource::AccessToken,
        };
        let now = Instant::now();
        let token = tp.token().await?;
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d == now + Duration::from_secs(3600)),
            "now: {:?}, expires_at: {:?}",
            now,
            token.expires_at
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_full_with_quota_project() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: None,
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, None)
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": server.url("/token").to_string(),
        });
        let cred = Builder::new(authorized_user)
            .with_quota_project_id("test-project")
            .build()?;

        let headers = get_headers_from_cache(cred.headers(Extensions::new()).await.unwrap())?;
        let token = headers.get(AUTHORIZATION).unwrap();
        let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(
            token,
            HeaderValue::from_static("test-token-type test-access-token")
        );
        assert!(token.is_sensitive());
        assert_eq!(
            quota_project_header,
            HeaderValue::from_static("test-project")
        );
        assert!(!quota_project_header.is_sensitive());

        Ok(())
    }

    // The token endpoint is contacted once; the second call is served from the cache.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn creds_from_json_custom_uri_with_caching() -> TestResult {
        let mut server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, Some("scope1 scope2".to_string()))
                }))
            ])
            .times(1)
            .respond_with(json_encoded(response)),
        );

        let json = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "universe_domain": "googleapis.com",
            "quota_project_id": "test-project",
            "token_uri": server.url("/token").to_string(),
        });

        let cred = Builder::new(json)
            .with_scopes(vec!["scope1", "scope2"])
            .build()?;

        let token = get_token_from_headers(cred.headers(Extensions::new()).await?);
        assert_eq!(token.unwrap(), "test-access-token");

        let token = get_token_from_headers(cred.headers(Extensions::new()).await?);
        assert_eq!(token.unwrap(), "test-access-token");

        server.verify_and_clear();

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_partial() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: None,
            refresh_token: None,
            scope: None,
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, None)
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": server.url("/token").to_string()
        });

        let uc = Builder::new(authorized_user).build()?;
        let headers = uc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    // `with_token_uri` overrides the `token_uri` from the JSON.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_with_token_uri() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: None,
            refresh_token: None,
            scope: None,
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, None)
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": "test-endpoint"
        });

        let uc = Builder::new(authorized_user)
            .with_token_uri(server.url("/token").to_string())
            .build()?;
        let headers = uc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    // Scopes from `with_scopes` are sent space-separated.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_with_scopes() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token: None,
            expires_in: None,
            refresh_token: None,
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, Some("scope1 scope2".to_string()))
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": "test-endpoint"
        });

        let uc = Builder::new(authorized_user)
            .with_token_uri(server.url("/token").to_string())
            .with_scopes(vec!["scope1", "scope2"])
            .build()?;
        let headers = uc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    // A 503 surfaces as a transient error that wraps the original HTTP error.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_retryable_error() -> TestResult {
        let server = Server::run();
        server
            .expect(Expectation::matching(request::path("/token")).respond_with(status_code(503)));

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": server.url("/token").to_string()
        });

        let uc = Builder::new(authorized_user).build()?;
        let err = uc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(original_err.is_transient());

        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
            "{err:?}"
        );

        Ok(())
    }

    // A 401 surfaces as a non-transient error that wraps the original HTTP error.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn token_provider_nonretryable_error() -> TestResult {
        let server = Server::run();
        server
            .expect(Expectation::matching(request::path("/token")).respond_with(status_code(401)));

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": server.url("/token").to_string()
        });

        let uc = Builder::new(authorized_user).build()?;
        let err = uc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(!original_err.is_transient());

        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
            "{err:?}"
        );

        Ok(())
    }

    // A body that does not decode as a token response is not retried.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn token_provider_malformed_response_is_nonretryable() -> TestResult {
        let server = Server::run();
        server.expect(
            Expectation::matching(request::path("/token"))
                .respond_with(json_encoded("bad json".to_string())),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": server.url("/token").to_string()
        });

        let uc = Builder::new(authorized_user).build()?;
        let e = uc.headers(Extensions::new()).await.err().unwrap();
        assert!(!e.is_transient(), "{e}");

        Ok(())
    }

    // A JSON document missing `client_id` is rejected at build time as a parsing error.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn builder_malformed_authorized_json_nonretryable() -> TestResult {
        let authorized_user = serde_json::json!({
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
        });

        let e = Builder::new(authorized_user).build().unwrap_err();
        assert!(e.is_parsing(), "{e}");

        Ok(())
    }
}