1use crate::build_errors::Error as BuilderError;
98use crate::credentials::dynamic::CredentialsProvider;
99use crate::credentials::{CacheableResource, Credentials};
100use crate::errors::{self, CredentialsError};
101use crate::headers_util::build_cacheable_headers;
102use crate::retry::Builder as RetryTokenProviderBuilder;
103use crate::token::{CachedTokenProvider, Token, TokenProvider};
104use crate::token_cache::TokenCache;
105use crate::{BuildResult, Result};
106use gax::backoff_policy::BackoffPolicyArg;
107use gax::retry_policy::RetryPolicyArg;
108use gax::retry_throttler::RetryThrottlerArg;
109use http::header::CONTENT_TYPE;
110use http::{Extensions, HeaderMap, HeaderValue};
111use reqwest::{Client, Method};
112use serde_json::Value;
113use std::sync::Arc;
114use tokio::time::{Duration, Instant};
115
/// Default OAuth2 token endpoint, used when neither the builder nor the
/// `authorized_user` JSON provides a `token_uri`.
const OAUTH2_ENDPOINT: &str = "https://oauth2.googleapis.com/token";
117
/// Builder for user-account (`authorized_user`) [`Credentials`].
///
/// Holds the raw JSON plus optional overrides; nothing is parsed or validated
/// until `build()` is called.
pub struct Builder {
    // Raw `authorized_user` JSON; deserialized in `build()`.
    authorized_user: Value,
    // Requested scopes; joined with single spaces in `build()`.
    scopes: Option<Vec<String>>,
    // When set, overrides `quota_project_id` from the JSON.
    quota_project_id: Option<String>,
    // When set, overrides `token_uri` from the JSON.
    token_uri: Option<String>,
    // Retry configuration applied around the token provider in `build()`.
    retry_builder: RetryTokenProviderBuilder,
}
135
136impl Builder {
137 pub fn new(authorized_user: Value) -> Self {
144 Self {
145 authorized_user,
146 scopes: None,
147 quota_project_id: None,
148 token_uri: None,
149 retry_builder: RetryTokenProviderBuilder::default(),
150 }
151 }
152
153 pub fn with_token_uri<S: Into<String>>(mut self, token_uri: S) -> Self {
167 self.token_uri = Some(token_uri.into());
168 self
169 }
170
171 pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
192 where
193 I: IntoIterator<Item = S>,
194 S: Into<String>,
195 {
196 self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
197 self
198 }
199
200 pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
221 self.quota_project_id = Some(quota_project_id.into());
222 self
223 }
224
225 pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
246 self.retry_builder = self.retry_builder.with_retry_policy(v.into());
247 self
248 }
249
250 pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
271 self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
272 self
273 }
274
275 pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
302 self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
303 self
304 }
305
306 pub fn build(self) -> BuildResult<Credentials> {
319 let authorized_user = serde_json::from_value::<AuthorizedUser>(self.authorized_user)
320 .map_err(BuilderError::parsing)?;
321 let endpoint = self
322 .token_uri
323 .or(authorized_user.token_uri)
324 .unwrap_or(OAUTH2_ENDPOINT.to_string());
325 let quota_project_id = self.quota_project_id.or(authorized_user.quota_project_id);
326
327 let token_provider = UserTokenProvider {
328 client_id: authorized_user.client_id,
329 client_secret: authorized_user.client_secret,
330 refresh_token: authorized_user.refresh_token,
331 endpoint,
332 scopes: self.scopes.map(|scopes| scopes.join(" ")),
333 };
334
335 let token_provider = TokenCache::new(self.retry_builder.build(token_provider));
336
337 Ok(Credentials {
338 inner: Arc::new(UserCredentials {
339 token_provider,
340 quota_project_id,
341 }),
342 })
343 }
344}
345
/// Fetches access tokens by exchanging a refresh token at an OAuth2 endpoint.
///
/// `Debug` is implemented by hand for this type so the secrets are not printed.
#[derive(PartialEq)]
struct UserTokenProvider {
    client_id: String,
    // Secret: redacted in the `Debug` output.
    client_secret: String,
    // Secret: redacted in the `Debug` output.
    refresh_token: String,
    // Token endpoint URL the refresh request is POSTed to.
    endpoint: String,
    // Space-separated scopes sent with the refresh request, if any.
    scopes: Option<String>,
}
354
355impl std::fmt::Debug for UserTokenProvider {
356 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
357 f.debug_struct("UserCredentials")
358 .field("client_id", &self.client_id)
359 .field("client_secret", &"[censored]")
360 .field("refresh_token", &"[censored]")
361 .field("endpoint", &self.endpoint)
362 .field("scopes", &self.scopes)
363 .finish()
364 }
365}
366
367#[async_trait::async_trait]
368impl TokenProvider for UserTokenProvider {
369 async fn token(&self) -> Result<Token> {
370 let client = Client::new();
371
372 let req = Oauth2RefreshRequest {
374 grant_type: RefreshGrantType::RefreshToken,
375 client_id: self.client_id.clone(),
376 client_secret: self.client_secret.clone(),
377 refresh_token: self.refresh_token.clone(),
378 scopes: self.scopes.clone(),
379 };
380 let header = HeaderValue::from_static("application/json");
381 let builder = client
382 .request(Method::POST, self.endpoint.as_str())
383 .header(CONTENT_TYPE, header)
384 .json(&req);
385 let resp = builder
386 .send()
387 .await
388 .map_err(|e| errors::from_http_error(e, MSG))?;
389
390 if !resp.status().is_success() {
392 let err = errors::from_http_response(resp, MSG).await;
393 return Err(err);
394 }
395 let response = resp.json::<Oauth2RefreshResponse>().await.map_err(|e| {
396 let retryable = !e.is_decode();
397 CredentialsError::from_source(retryable, e)
398 })?;
399 let token = Token {
400 token: response.access_token,
401 token_type: response.token_type,
402 expires_at: response
403 .expires_in
404 .map(|d| Instant::now() + Duration::from_secs(d)),
405 metadata: None,
406 };
407 Ok(token)
408 }
409}
410
/// Context message passed to the error helpers when a token refresh fails.
const MSG: &str = "failed to refresh user access token";
412
413#[derive(Debug)]
417pub(crate) struct UserCredentials<T>
418where
419 T: CachedTokenProvider,
420{
421 token_provider: T,
422 quota_project_id: Option<String>,
423}
424
425#[async_trait::async_trait]
426impl<T> CredentialsProvider for UserCredentials<T>
427where
428 T: CachedTokenProvider,
429{
430 async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
431 let token = self.token_provider.token(extensions).await?;
432 build_cacheable_headers(&token, &self.quota_project_id)
433 }
434}
435
436#[derive(Debug, PartialEq, serde::Deserialize)]
437pub(crate) struct AuthorizedUser {
438 #[serde(rename = "type")]
439 cred_type: String,
440 client_id: String,
441 client_secret: String,
442 refresh_token: String,
443 #[serde(skip_serializing_if = "Option::is_none")]
444 token_uri: Option<String>,
445 #[serde(skip_serializing_if = "Option::is_none")]
446 quota_project_id: Option<String>,
447}
448
/// OAuth2 `grant_type` values used by this module.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
enum RefreshGrantType {
    // Serialized as the JSON string "refresh_token".
    #[serde(rename = "refresh_token")]
    RefreshToken,
}
454
455#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
456struct Oauth2RefreshRequest {
457 grant_type: RefreshGrantType,
458 client_id: String,
459 client_secret: String,
460 refresh_token: String,
461 scopes: Option<String>,
462}
463
464#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
465struct Oauth2RefreshResponse {
466 access_token: String,
467 #[serde(skip_serializing_if = "Option::is_none")]
468 scope: Option<String>,
469 #[serde(skip_serializing_if = "Option::is_none")]
470 expires_in: Option<u64>,
471 token_type: String,
472 #[serde(skip_serializing_if = "Option::is_none")]
473 refresh_token: Option<String>,
474}
475
476#[cfg(test)]
477mod tests {
478 use super::*;
479 use crate::credentials::tests::{
480 find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
481 get_mock_backoff_policy, get_mock_retry_throttler, get_token_from_headers,
482 get_token_type_from_headers,
483 };
484 use crate::credentials::{DEFAULT_UNIVERSE_DOMAIN, QUOTA_PROJECT_KEY};
485 use crate::errors::CredentialsError;
486 use crate::token::tests::MockTokenProvider;
487 use http::StatusCode;
488 use http::header::AUTHORIZATION;
489 use httptest::cycle;
490 use httptest::matchers::{all_of, json_decoded, request};
491 use httptest::responders::{json_encoded, status_code};
492 use httptest::{Expectation, Server};
493
    /// Result type for tests that propagate errors with `?`.
    type TestResult = anyhow::Result<()>;
495
496 fn authorized_user_json(token_uri: String) -> Value {
497 serde_json::json!({
498 "client_id": "test-client-id",
499 "client_secret": "test-client-secret",
500 "refresh_token": "test-refresh-token",
501 "type": "authorized_user",
502 "token_uri": token_uri,
503 })
504 }
505
    /// The endpoint answers 503 for all three expected requests. After the
    /// mock retry policy (configured with 3) gives up, the surfaced error is
    /// not transient.
    #[tokio::test]
    async fn test_user_account_retries_on_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path("/token"))
                .times(3)
                .respond_with(status_code(503)),
        );

        let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build()?;

        let err = credentials.headers(Extensions::new()).await.unwrap_err();
        assert!(!err.is_transient());
        server.verify_and_clear();
        Ok(())
    }
526
    /// A 401 must be requested exactly once (no retry) and surface as a
    /// non-transient error.
    #[tokio::test]
    async fn test_user_account_does_not_retry_on_non_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path("/token"))
                .times(1)
                .respond_with(status_code(401)),
        );

        let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
            .with_retry_policy(get_mock_auth_retry_policy(1))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build()?;

        let err = credentials.headers(Extensions::new()).await.unwrap_err();
        assert!(!err.is_transient());
        server.verify_and_clear();
        Ok(())
    }
547
    /// Two 503s followed by a 200: the retry loop must keep going and return
    /// the token from the third response.
    #[tokio::test]
    async fn test_user_account_retries_for_success() -> TestResult {
        let mut server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };

        // `cycle!` hands out the responders in order, one per request.
        server.expect(
            Expectation::matching(request::path("/token"))
                .times(3)
                .respond_with(cycle![
                    status_code(503).body("try-again"),
                    status_code(503).body("try-again"),
                    status_code(200)
                        .append_header("Content-Type", "application/json")
                        .body(serde_json::to_string(&response).unwrap()),
                ]),
        );

        let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build()?;

        let token = get_token_from_headers(credentials.headers(Extensions::new()).await.unwrap());
        assert_eq!(token.unwrap(), "test-access-token");

        server.verify_and_clear();
        Ok(())
    }
583
584 #[test]
585 fn debug_token_provider() {
586 let expected = UserTokenProvider {
587 client_id: "test-client-id".to_string(),
588 client_secret: "test-client-secret".to_string(),
589 refresh_token: "test-refresh-token".to_string(),
590 endpoint: OAUTH2_ENDPOINT.to_string(),
591 scopes: Some("https://www.googleapis.com/auth/pubsub".to_string()),
592 };
593 let fmt = format!("{expected:?}");
594 assert!(fmt.contains("test-client-id"), "{fmt}");
595 assert!(!fmt.contains("test-client-secret"), "{fmt}");
596 assert!(!fmt.contains("test-refresh-token"), "{fmt}");
597 assert!(fmt.contains(OAUTH2_ENDPOINT), "{fmt}");
598 assert!(
599 fmt.contains("https://www.googleapis.com/auth/pubsub"),
600 "{fmt}"
601 );
602 }
603
604 #[test]
605 fn authorized_user_full_from_json_success() {
606 let json = serde_json::json!({
607 "account": "",
608 "client_id": "test-client-id",
609 "client_secret": "test-client-secret",
610 "refresh_token": "test-refresh-token",
611 "type": "authorized_user",
612 "universe_domain": "googleapis.com",
613 "quota_project_id": "test-project",
614 "token_uri" : "test-token-uri",
615 });
616
617 let expected = AuthorizedUser {
618 cred_type: "authorized_user".to_string(),
619 client_id: "test-client-id".to_string(),
620 client_secret: "test-client-secret".to_string(),
621 refresh_token: "test-refresh-token".to_string(),
622 quota_project_id: Some("test-project".to_string()),
623 token_uri: Some("test-token-uri".to_string()),
624 };
625 let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
626 assert_eq!(actual, expected);
627 }
628
629 #[test]
630 fn authorized_user_partial_from_json_success() {
631 let json = serde_json::json!({
632 "client_id": "test-client-id",
633 "client_secret": "test-client-secret",
634 "refresh_token": "test-refresh-token",
635 "type": "authorized_user",
636 });
637
638 let expected = AuthorizedUser {
639 cred_type: "authorized_user".to_string(),
640 client_id: "test-client-id".to_string(),
641 client_secret: "test-client-secret".to_string(),
642 refresh_token: "test-refresh-token".to_string(),
643 quota_project_id: None,
644 token_uri: None,
645 };
646 let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
647 assert_eq!(actual, expected);
648 }
649
650 #[test]
651 fn authorized_user_from_json_parse_fail() {
652 let json_full = serde_json::json!({
653 "client_id": "test-client-id",
654 "client_secret": "test-client-secret",
655 "refresh_token": "test-refresh-token",
656 "type": "authorized_user",
657 "quota_project_id": "test-project"
658 });
659
660 for required_field in ["client_id", "client_secret", "refresh_token"] {
661 let mut json = json_full.clone();
662 json[required_field].take();
664 serde_json::from_value::<AuthorizedUser>(json)
665 .err()
666 .unwrap();
667 }
668 }
669
670 #[tokio::test]
671 async fn default_universe_domain_success() {
672 let mock = TokenCache::new(MockTokenProvider::new());
673
674 let uc = UserCredentials {
675 token_provider: mock,
676 quota_project_id: None,
677 };
678 assert_eq!(uc.universe_domain().await.unwrap(), DEFAULT_UNIVERSE_DOMAIN);
679 }
680
681 #[tokio::test]
682 async fn headers_success() -> TestResult {
683 let token = Token {
684 token: "test-token".to_string(),
685 token_type: "Bearer".to_string(),
686 expires_at: None,
687 metadata: None,
688 };
689
690 let mut mock = MockTokenProvider::new();
691 mock.expect_token().times(1).return_once(|| Ok(token));
692
693 let uc = UserCredentials {
694 token_provider: TokenCache::new(mock),
695 quota_project_id: None,
696 };
697
698 let mut extensions = Extensions::new();
699 let cached_headers = uc.headers(extensions.clone()).await.unwrap();
700 let (headers, entity_tag) = match cached_headers {
701 CacheableResource::New { entity_tag, data } => (data, entity_tag),
702 CacheableResource::NotModified => unreachable!("expecting new headers"),
703 };
704 let token = headers.get(AUTHORIZATION).unwrap();
705
706 assert_eq!(headers.len(), 1, "{headers:?}");
707 assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
708 assert!(token.is_sensitive());
709
710 extensions.insert(entity_tag);
711
712 let cached_headers = uc.headers(extensions).await?;
713
714 match cached_headers {
715 CacheableResource::New { .. } => unreachable!("expecting new headers"),
716 CacheableResource::NotModified => CacheableResource::<HeaderMap>::NotModified,
717 };
718 Ok(())
719 }
720
721 #[tokio::test]
722 async fn headers_failure() {
723 let mut mock = MockTokenProvider::new();
724 mock.expect_token()
725 .times(1)
726 .return_once(|| Err(errors::non_retryable_from_str("fail")));
727
728 let uc = UserCredentials {
729 token_provider: TokenCache::new(mock),
730 quota_project_id: None,
731 };
732 assert!(uc.headers(Extensions::new()).await.is_err());
733 }
734
735 #[tokio::test]
736 async fn headers_with_quota_project_success() -> TestResult {
737 let token = Token {
738 token: "test-token".to_string(),
739 token_type: "Bearer".to_string(),
740 expires_at: None,
741 metadata: None,
742 };
743
744 let mut mock = MockTokenProvider::new();
745 mock.expect_token().times(1).return_once(|| Ok(token));
746
747 let uc = UserCredentials {
748 token_provider: TokenCache::new(mock),
749 quota_project_id: Some("test-project".to_string()),
750 };
751
752 let headers = get_headers_from_cache(uc.headers(Extensions::new()).await.unwrap())?;
753 let token = headers.get(AUTHORIZATION).unwrap();
754 let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();
755
756 assert_eq!(headers.len(), 2, "{headers:?}");
757 assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
758 assert!(token.is_sensitive());
759 assert_eq!(
760 quota_project_header,
761 HeaderValue::from_static("test-project")
762 );
763 assert!(!quota_project_header.is_sensitive());
764 Ok(())
765 }
766
767 #[test]
768 fn oauth2_request_serde() {
769 let request = Oauth2RefreshRequest {
770 grant_type: RefreshGrantType::RefreshToken,
771 client_id: "test-client-id".to_string(),
772 client_secret: "test-client-secret".to_string(),
773 refresh_token: "test-refresh-token".to_string(),
774 scopes: Some("scope1 scope2".to_string()),
775 };
776
777 let json = serde_json::to_value(&request).unwrap();
778 let expected = serde_json::json!({
779 "grant_type": "refresh_token",
780 "client_id": "test-client-id",
781 "client_secret": "test-client-secret",
782 "refresh_token": "test-refresh-token",
783 "scopes": "scope1 scope2",
784 });
785 assert_eq!(json, expected);
786 let roundtrip = serde_json::from_value::<Oauth2RefreshRequest>(json).unwrap();
787 assert_eq!(request, roundtrip);
788 }
789
790 #[test]
791 fn oauth2_response_serde_full() {
792 let response = Oauth2RefreshResponse {
793 access_token: "test-access-token".to_string(),
794 scope: Some("scope1 scope2".to_string()),
795 expires_in: Some(3600),
796 token_type: "test-token-type".to_string(),
797 refresh_token: Some("test-refresh-token".to_string()),
798 };
799
800 let json = serde_json::to_value(&response).unwrap();
801 let expected = serde_json::json!({
802 "access_token": "test-access-token",
803 "scope": "scope1 scope2",
804 "expires_in": 3600,
805 "token_type": "test-token-type",
806 "refresh_token": "test-refresh-token"
807 });
808 assert_eq!(json, expected);
809 let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
810 assert_eq!(response, roundtrip);
811 }
812
813 #[test]
814 fn oauth2_response_serde_partial() {
815 let response = Oauth2RefreshResponse {
816 access_token: "test-access-token".to_string(),
817 scope: None,
818 expires_in: None,
819 token_type: "test-token-type".to_string(),
820 refresh_token: None,
821 };
822
823 let json = serde_json::to_value(&response).unwrap();
824 let expected = serde_json::json!({
825 "access_token": "test-access-token",
826 "token_type": "test-token-type",
827 });
828 assert_eq!(json, expected);
829 let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
830 assert_eq!(response, roundtrip);
831 }
832
833 fn check_request(request: &Oauth2RefreshRequest, expected_scopes: Option<String>) -> bool {
834 request.client_id == "test-client-id"
835 && request.client_secret == "test-client-secret"
836 && request.refresh_token == "test-refresh-token"
837 && request.grant_type == RefreshGrantType::RefreshToken
838 && request.scopes == expected_scopes
839 }
840
    /// The raw provider sends the expected request and maps every response
    /// field into the `Token`. With tokio's clock paused, `Instant::now()`
    /// should not advance between here and the provider's own computation,
    /// which is what makes the exact expiry comparison below valid.
    #[tokio::test(start_paused = true)]
    async fn token_provider_full() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, Some("scope1 scope2".to_string()))
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let tp = UserTokenProvider {
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            endpoint: server.url("/token").to_string(),
            scopes: Some("scope1 scope2".to_string()),
        };
        let now = Instant::now();
        let token = tp.token().await?;
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d == now + Duration::from_secs(3600)),
            "now: {:?}, expires_at: {:?}",
            now,
            token.expires_at
        );

        Ok(())
    }
883
    /// End-to-end through `Builder`: the token type from the response prefixes
    /// the Authorization value, and the builder's quota project is emitted as
    /// a separate non-sensitive header.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_full_with_quota_project() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: None,
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, None)
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": server.url("/token").to_string(),
        });
        let cred = Builder::new(authorized_user)
            .with_quota_project_id("test-project")
            .build()?;

        let headers = get_headers_from_cache(cred.headers(Extensions::new()).await.unwrap())?;
        let token = headers.get(AUTHORIZATION).unwrap();
        let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(
            token,
            HeaderValue::from_static("test-token-type test-access-token")
        );
        assert!(token.is_sensitive());
        assert_eq!(
            quota_project_header,
            HeaderValue::from_static("test-project")
        );
        assert!(!quota_project_header.is_sensitive());

        Ok(())
    }
933
    /// The endpoint expects exactly one request, yet `headers()` is called
    /// twice. The second token must therefore come from the cache.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn creds_from_json_custom_uri_with_caching() -> TestResult {
        let mut server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, Some("scope1 scope2".to_string()))
                }))
            ])
            .times(1)
            .respond_with(json_encoded(response)),
        );

        let json = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "universe_domain": "googleapis.com",
            "quota_project_id": "test-project",
            "token_uri": server.url("/token").to_string(),
        });

        let cred = Builder::new(json)
            .with_scopes(vec!["scope1", "scope2"])
            .build()?;

        let token = get_token_from_headers(cred.headers(Extensions::new()).await?);
        assert_eq!(token.unwrap(), "test-access-token");

        let token = get_token_from_headers(cred.headers(Extensions::new()).await?);
        assert_eq!(token.unwrap(), "test-access-token");

        server.verify_and_clear();

        Ok(())
    }
979
    /// A response with only `access_token` and `token_type` is enough to
    /// build headers.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_partial() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            expires_in: None,
            refresh_token: None,
            scope: None,
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, None)
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": server.url("/token").to_string()
        });

        let uc = Builder::new(authorized_user).build()?;
        let headers = uc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }
1021
    /// The JSON's `token_uri` is a bogus value. Success proves the builder's
    /// `with_token_uri` override takes precedence.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_with_token_uri() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            expires_in: None,
            refresh_token: None,
            scope: None,
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, None)
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": "test-endpoint"
        });

        let uc = Builder::new(authorized_user)
            .with_token_uri(server.url("/token").to_string())
            .build()?;
        let headers = uc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }
1065
    /// Scopes set on the builder are sent space-joined in the refresh request.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_with_scopes() -> TestResult {
        let server = Server::run();
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            expires_in: None,
            refresh_token: None,
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path("/token"),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
                    check_request(req, Some("scope1 scope2".to_string()))
                }))
            ])
            .respond_with(json_encoded(response)),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": "test-endpoint"
        });

        let uc = Builder::new(authorized_user)
            .with_token_uri(server.url("/token").to_string())
            .with_scopes(vec!["scope1", "scope2"])
            .build()?;
        let headers = uc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }
1110
    /// With the default builder configuration, a 503 surfaces as an error whose
    /// underlying `CredentialsError` is transient and whose source chain
    /// carries the HTTP status.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_retryable_error() -> TestResult {
        let server = Server::run();
        server
            .expect(Expectation::matching(request::path("/token")).respond_with(status_code(503)));

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": server.url("/token").to_string()
        });

        let uc = Builder::new(authorized_user).build()?;
        let err = uc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(original_err.is_transient());

        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
            "{err:?}"
        );

        Ok(())
    }
1138
    /// A 401 surfaces as a non-transient error whose source chain carries the
    /// HTTP status.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn token_provider_nonretryable_error() -> TestResult {
        let server = Server::run();
        server
            .expect(Expectation::matching(request::path("/token")).respond_with(status_code(401)));

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": server.url("/token").to_string()
        });

        let uc = Builder::new(authorized_user).build()?;
        let err = uc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(!original_err.is_transient());

        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
            "{err:?}"
        );

        Ok(())
    }
1166
    /// A 200 whose body is a JSON string rather than a token object fails to
    /// decode, which must be reported as non-transient.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn token_provider_malformed_response_is_nonretryable() -> TestResult {
        let server = Server::run();
        server.expect(
            Expectation::matching(request::path("/token"))
                .respond_with(json_encoded("bad json".to_string())),
        );

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": server.url("/token").to_string()
        });

        let uc = Builder::new(authorized_user).build()?;
        let e = uc.headers(Extensions::new()).await.err().unwrap();
        assert!(!e.is_transient(), "{e}");

        Ok(())
    }
1189
1190 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1191 async fn builder_malformed_authorized_json_nonretryable() -> TestResult {
1192 let authorized_user = serde_json::json!({
1193 "client_secret": "test-client-secret",
1194 "refresh_token": "test-refresh-token",
1195 "type": "authorized_user",
1196 });
1197
1198 let e = Builder::new(authorized_user).build().unwrap_err();
1199 assert!(e.is_parsing(), "{e}");
1200
1201 Ok(())
1202 }
1203}