1use crate::build_errors::Error as BuilderError;
98use crate::credentials::dynamic::CredentialsProvider;
99use crate::credentials::{CacheableResource, Credentials};
100use crate::errors::{self, CredentialsError};
101use crate::headers_util::build_cacheable_headers;
102use crate::retry::Builder as RetryTokenProviderBuilder;
103use crate::token::{CachedTokenProvider, Token, TokenProvider};
104use crate::token_cache::TokenCache;
105use crate::{BuildResult, Result};
106use gax::backoff_policy::BackoffPolicyArg;
107use gax::retry_policy::RetryPolicyArg;
108use gax::retry_throttler::RetryThrottlerArg;
109use http::header::CONTENT_TYPE;
110use http::{Extensions, HeaderMap, HeaderValue};
111use reqwest::{Client, Method};
112use serde_json::Value;
113use std::sync::Arc;
114use tokio::time::{Duration, Instant};
115
/// Token endpoint used when neither the builder nor the authorized-user JSON
/// supplies a `token_uri`.
const OAUTH2_ENDPOINT: &str = "https://oauth2.googleapis.com/token";
117
118pub struct Builder {
129 authorized_user: Value,
130 scopes: Option<Vec<String>>,
131 quota_project_id: Option<String>,
132 token_uri: Option<String>,
133 retry_builder: RetryTokenProviderBuilder,
134}
135
136impl Builder {
137 pub fn new(authorized_user: Value) -> Self {
144 Self {
145 authorized_user,
146 scopes: None,
147 quota_project_id: None,
148 token_uri: None,
149 retry_builder: RetryTokenProviderBuilder::default(),
150 }
151 }
152
153 pub fn with_token_uri<S: Into<String>>(mut self, token_uri: S) -> Self {
167 self.token_uri = Some(token_uri.into());
168 self
169 }
170
171 pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
192 where
193 I: IntoIterator<Item = S>,
194 S: Into<String>,
195 {
196 self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
197 self
198 }
199
200 pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
221 self.quota_project_id = Some(quota_project_id.into());
222 self
223 }
224
225 pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
246 self.retry_builder = self.retry_builder.with_retry_policy(v.into());
247 self
248 }
249
250 pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
271 self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
272 self
273 }
274
275 pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
302 self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
303 self
304 }
305
306 pub fn build(self) -> BuildResult<Credentials> {
319 let authorized_user = serde_json::from_value::<AuthorizedUser>(self.authorized_user)
320 .map_err(BuilderError::parsing)?;
321 let endpoint = self
322 .token_uri
323 .or(authorized_user.token_uri)
324 .unwrap_or(OAUTH2_ENDPOINT.to_string());
325 let quota_project_id = self.quota_project_id.or(authorized_user.quota_project_id);
326
327 let token_provider = UserTokenProvider {
328 client_id: authorized_user.client_id,
329 client_secret: authorized_user.client_secret,
330 refresh_token: authorized_user.refresh_token,
331 endpoint,
332 scopes: self.scopes.map(|scopes| scopes.join(" ")),
333 };
334
335 let token_provider = TokenCache::new(self.retry_builder.build(token_provider));
336
337 Ok(Credentials {
338 inner: Arc::new(UserCredentials {
339 token_provider,
340 quota_project_id,
341 }),
342 })
343 }
344}
345
/// Exchanges a long-lived refresh token for short-lived access tokens.
#[derive(PartialEq)]
struct UserTokenProvider {
    /// OAuth2 client id.
    client_id: String,
    /// OAuth2 client secret; never printed by the `Debug` implementation.
    client_secret: String,
    /// Refresh token; never printed by the `Debug` implementation.
    refresh_token: String,
    /// URL of the token endpoint the refresh request is posted to.
    endpoint: String,
    /// Space-separated scopes, if any were configured.
    scopes: Option<String>,
}

/// Hand-written so that secrets are replaced by a placeholder in debug output.
impl std::fmt::Debug for UserTokenProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Use this type's own name so the output is not mistaken for the
        // unrelated `UserCredentials` struct.
        f.debug_struct("UserTokenProvider")
            .field("client_id", &self.client_id)
            .field("client_secret", &"[censored]")
            .field("refresh_token", &"[censored]")
            .field("endpoint", &self.endpoint)
            .field("scopes", &self.scopes)
            .finish()
    }
}
366
367#[async_trait::async_trait]
368impl TokenProvider for UserTokenProvider {
369 async fn token(&self) -> Result<Token> {
370 let client = Client::new();
371
372 let req = Oauth2RefreshRequest {
374 grant_type: RefreshGrantType::RefreshToken,
375 client_id: self.client_id.clone(),
376 client_secret: self.client_secret.clone(),
377 refresh_token: self.refresh_token.clone(),
378 scopes: self.scopes.clone(),
379 };
380 let header = HeaderValue::from_static("application/json");
381 let builder = client
382 .request(Method::POST, self.endpoint.as_str())
383 .header(CONTENT_TYPE, header)
384 .json(&req);
385 let resp = builder
386 .send()
387 .await
388 .map_err(|e| errors::from_http_error(e, MSG))?;
389
390 if !resp.status().is_success() {
392 let err = errors::from_http_response(resp, MSG).await;
393 return Err(err);
394 }
395 let response = resp.json::<Oauth2RefreshResponse>().await.map_err(|e| {
396 let retryable = !e.is_decode();
397 CredentialsError::from_source(retryable, e)
398 })?;
399 let token = Token {
400 token: response.access_token,
401 token_type: response.token_type,
402 expires_at: response
403 .expires_in
404 .map(|d| Instant::now() + Duration::from_secs(d)),
405 metadata: None,
406 };
407 Ok(token)
408 }
409}
410
/// Context message attached to errors raised while refreshing the access token.
const MSG: &str = "failed to refresh user access token";
412
413#[derive(Debug)]
417pub(crate) struct UserCredentials<T>
418where
419 T: CachedTokenProvider,
420{
421 token_provider: T,
422 quota_project_id: Option<String>,
423}
424
425#[async_trait::async_trait]
426impl<T> CredentialsProvider for UserCredentials<T>
427where
428 T: CachedTokenProvider,
429{
430 async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
431 let token = self.token_provider.token(extensions).await?;
432 build_cacheable_headers(&token, &self.quota_project_id)
433 }
434}
435
436#[derive(Debug, PartialEq, serde::Deserialize)]
437pub(crate) struct AuthorizedUser {
438 #[serde(rename = "type")]
439 cred_type: String,
440 client_id: String,
441 client_secret: String,
442 refresh_token: String,
443 #[serde(skip_serializing_if = "Option::is_none")]
444 token_uri: Option<String>,
445 #[serde(skip_serializing_if = "Option::is_none")]
446 quota_project_id: Option<String>,
447}
448
449#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
450enum RefreshGrantType {
451 #[serde(rename = "refresh_token")]
452 RefreshToken,
453}
454
455#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
456struct Oauth2RefreshRequest {
457 grant_type: RefreshGrantType,
458 client_id: String,
459 client_secret: String,
460 refresh_token: String,
461 scopes: Option<String>,
462}
463
464#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
465struct Oauth2RefreshResponse {
466 access_token: String,
467 #[serde(skip_serializing_if = "Option::is_none")]
468 scope: Option<String>,
469 #[serde(skip_serializing_if = "Option::is_none")]
470 expires_in: Option<u64>,
471 token_type: String,
472 #[serde(skip_serializing_if = "Option::is_none")]
473 refresh_token: Option<String>,
474}
475
476#[cfg(test)]
477mod tests {
478 use super::*;
479 use crate::credentials::tests::{
480 get_headers_from_cache, get_mock_auth_retry_policy, get_mock_backoff_policy,
481 get_mock_retry_throttler, get_token_from_headers, get_token_type_from_headers,
482 };
483 use crate::credentials::{DEFAULT_UNIVERSE_DOMAIN, QUOTA_PROJECT_KEY};
484 use crate::token::tests::MockTokenProvider;
485 use http::StatusCode;
486 use http::header::AUTHORIZATION;
487 use httptest::matchers::{all_of, json_decoded, request};
488 use httptest::responders::{json_encoded, status_code};
489 use httptest::{Expectation, Server, cycle};
490 use std::error::Error;
491
492 type TestResult = anyhow::Result<()>;
493
494 fn authorized_user_json(token_uri: String) -> Value {
495 serde_json::json!({
496 "client_id": "test-client-id",
497 "client_secret": "test-client-secret",
498 "refresh_token": "test-refresh-token",
499 "type": "authorized_user",
500 "token_uri": token_uri,
501 })
502 }
503
504 #[tokio::test]
505 async fn test_user_account_retries_on_transient_failures() -> TestResult {
506 let mut server = Server::run();
507 server.expect(
508 Expectation::matching(request::path("/token"))
509 .times(3)
510 .respond_with(status_code(503)),
511 );
512
513 let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
514 .with_retry_policy(get_mock_auth_retry_policy(3))
515 .with_backoff_policy(get_mock_backoff_policy())
516 .with_retry_throttler(get_mock_retry_throttler())
517 .build()?;
518
519 let err = credentials.headers(Extensions::new()).await.unwrap_err();
520 assert!(err.is_transient());
521 server.verify_and_clear();
522 Ok(())
523 }
524
525 #[tokio::test]
526 async fn test_user_account_does_not_retry_on_non_transient_failures() -> TestResult {
527 let mut server = Server::run();
528 server.expect(
529 Expectation::matching(request::path("/token"))
530 .times(1)
531 .respond_with(status_code(401)),
532 );
533
534 let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
535 .with_retry_policy(get_mock_auth_retry_policy(1))
536 .with_backoff_policy(get_mock_backoff_policy())
537 .with_retry_throttler(get_mock_retry_throttler())
538 .build()?;
539
540 let err = credentials.headers(Extensions::new()).await.unwrap_err();
541 assert!(!err.is_transient());
542 server.verify_and_clear();
543 Ok(())
544 }
545
546 #[tokio::test]
547 async fn test_user_account_retries_for_success() -> TestResult {
548 let mut server = Server::run();
549 let response = Oauth2RefreshResponse {
550 access_token: "test-access-token".to_string(),
551 expires_in: Some(3600),
552 refresh_token: Some("test-refresh-token".to_string()),
553 scope: Some("scope1 scope2".to_string()),
554 token_type: "test-token-type".to_string(),
555 };
556
557 server.expect(
558 Expectation::matching(request::path("/token"))
559 .times(3)
560 .respond_with(cycle![
561 status_code(503).body("try-again"),
562 status_code(503).body("try-again"),
563 status_code(200)
564 .append_header("Content-Type", "application/json")
565 .body(serde_json::to_string(&response).unwrap()),
566 ]),
567 );
568
569 let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
570 .with_retry_policy(get_mock_auth_retry_policy(3))
571 .with_backoff_policy(get_mock_backoff_policy())
572 .with_retry_throttler(get_mock_retry_throttler())
573 .build()?;
574
575 let token = get_token_from_headers(credentials.headers(Extensions::new()).await.unwrap());
576 assert_eq!(token.unwrap(), "test-access-token");
577
578 server.verify_and_clear();
579 Ok(())
580 }
581
582 #[test]
583 fn debug_token_provider() {
584 let expected = UserTokenProvider {
585 client_id: "test-client-id".to_string(),
586 client_secret: "test-client-secret".to_string(),
587 refresh_token: "test-refresh-token".to_string(),
588 endpoint: OAUTH2_ENDPOINT.to_string(),
589 scopes: Some("https://www.googleapis.com/auth/pubsub".to_string()),
590 };
591 let fmt = format!("{expected:?}");
592 assert!(fmt.contains("test-client-id"), "{fmt}");
593 assert!(!fmt.contains("test-client-secret"), "{fmt}");
594 assert!(!fmt.contains("test-refresh-token"), "{fmt}");
595 assert!(fmt.contains(OAUTH2_ENDPOINT), "{fmt}");
596 assert!(
597 fmt.contains("https://www.googleapis.com/auth/pubsub"),
598 "{fmt}"
599 );
600 }
601
602 #[test]
603 fn authorized_user_full_from_json_success() {
604 let json = serde_json::json!({
605 "account": "",
606 "client_id": "test-client-id",
607 "client_secret": "test-client-secret",
608 "refresh_token": "test-refresh-token",
609 "type": "authorized_user",
610 "universe_domain": "googleapis.com",
611 "quota_project_id": "test-project",
612 "token_uri" : "test-token-uri",
613 });
614
615 let expected = AuthorizedUser {
616 cred_type: "authorized_user".to_string(),
617 client_id: "test-client-id".to_string(),
618 client_secret: "test-client-secret".to_string(),
619 refresh_token: "test-refresh-token".to_string(),
620 quota_project_id: Some("test-project".to_string()),
621 token_uri: Some("test-token-uri".to_string()),
622 };
623 let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
624 assert_eq!(actual, expected);
625 }
626
627 #[test]
628 fn authorized_user_partial_from_json_success() {
629 let json = serde_json::json!({
630 "client_id": "test-client-id",
631 "client_secret": "test-client-secret",
632 "refresh_token": "test-refresh-token",
633 "type": "authorized_user",
634 });
635
636 let expected = AuthorizedUser {
637 cred_type: "authorized_user".to_string(),
638 client_id: "test-client-id".to_string(),
639 client_secret: "test-client-secret".to_string(),
640 refresh_token: "test-refresh-token".to_string(),
641 quota_project_id: None,
642 token_uri: None,
643 };
644 let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
645 assert_eq!(actual, expected);
646 }
647
648 #[test]
649 fn authorized_user_from_json_parse_fail() {
650 let json_full = serde_json::json!({
651 "client_id": "test-client-id",
652 "client_secret": "test-client-secret",
653 "refresh_token": "test-refresh-token",
654 "type": "authorized_user",
655 "quota_project_id": "test-project"
656 });
657
658 for required_field in ["client_id", "client_secret", "refresh_token"] {
659 let mut json = json_full.clone();
660 json[required_field].take();
662 serde_json::from_value::<AuthorizedUser>(json)
663 .err()
664 .unwrap();
665 }
666 }
667
668 #[tokio::test]
669 async fn default_universe_domain_success() {
670 let mock = TokenCache::new(MockTokenProvider::new());
671
672 let uc = UserCredentials {
673 token_provider: mock,
674 quota_project_id: None,
675 };
676 assert_eq!(uc.universe_domain().await.unwrap(), DEFAULT_UNIVERSE_DOMAIN);
677 }
678
679 #[tokio::test]
680 async fn headers_success() -> TestResult {
681 let token = Token {
682 token: "test-token".to_string(),
683 token_type: "Bearer".to_string(),
684 expires_at: None,
685 metadata: None,
686 };
687
688 let mut mock = MockTokenProvider::new();
689 mock.expect_token().times(1).return_once(|| Ok(token));
690
691 let uc = UserCredentials {
692 token_provider: TokenCache::new(mock),
693 quota_project_id: None,
694 };
695
696 let mut extensions = Extensions::new();
697 let cached_headers = uc.headers(extensions.clone()).await.unwrap();
698 let (headers, entity_tag) = match cached_headers {
699 CacheableResource::New { entity_tag, data } => (data, entity_tag),
700 CacheableResource::NotModified => unreachable!("expecting new headers"),
701 };
702 let token = headers.get(AUTHORIZATION).unwrap();
703
704 assert_eq!(headers.len(), 1, "{headers:?}");
705 assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
706 assert!(token.is_sensitive());
707
708 extensions.insert(entity_tag);
709
710 let cached_headers = uc.headers(extensions).await?;
711
712 match cached_headers {
713 CacheableResource::New { .. } => unreachable!("expecting new headers"),
714 CacheableResource::NotModified => CacheableResource::<HeaderMap>::NotModified,
715 };
716 Ok(())
717 }
718
719 #[tokio::test]
720 async fn headers_failure() {
721 let mut mock = MockTokenProvider::new();
722 mock.expect_token()
723 .times(1)
724 .return_once(|| Err(errors::non_retryable_from_str("fail")));
725
726 let uc = UserCredentials {
727 token_provider: TokenCache::new(mock),
728 quota_project_id: None,
729 };
730 assert!(uc.headers(Extensions::new()).await.is_err());
731 }
732
733 #[tokio::test]
734 async fn headers_with_quota_project_success() -> TestResult {
735 let token = Token {
736 token: "test-token".to_string(),
737 token_type: "Bearer".to_string(),
738 expires_at: None,
739 metadata: None,
740 };
741
742 let mut mock = MockTokenProvider::new();
743 mock.expect_token().times(1).return_once(|| Ok(token));
744
745 let uc = UserCredentials {
746 token_provider: TokenCache::new(mock),
747 quota_project_id: Some("test-project".to_string()),
748 };
749
750 let headers = get_headers_from_cache(uc.headers(Extensions::new()).await.unwrap())?;
751 let token = headers.get(AUTHORIZATION).unwrap();
752 let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();
753
754 assert_eq!(headers.len(), 2, "{headers:?}");
755 assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
756 assert!(token.is_sensitive());
757 assert_eq!(
758 quota_project_header,
759 HeaderValue::from_static("test-project")
760 );
761 assert!(!quota_project_header.is_sensitive());
762 Ok(())
763 }
764
765 #[test]
766 fn oauth2_request_serde() {
767 let request = Oauth2RefreshRequest {
768 grant_type: RefreshGrantType::RefreshToken,
769 client_id: "test-client-id".to_string(),
770 client_secret: "test-client-secret".to_string(),
771 refresh_token: "test-refresh-token".to_string(),
772 scopes: Some("scope1 scope2".to_string()),
773 };
774
775 let json = serde_json::to_value(&request).unwrap();
776 let expected = serde_json::json!({
777 "grant_type": "refresh_token",
778 "client_id": "test-client-id",
779 "client_secret": "test-client-secret",
780 "refresh_token": "test-refresh-token",
781 "scopes": "scope1 scope2",
782 });
783 assert_eq!(json, expected);
784 let roundtrip = serde_json::from_value::<Oauth2RefreshRequest>(json).unwrap();
785 assert_eq!(request, roundtrip);
786 }
787
788 #[test]
789 fn oauth2_response_serde_full() {
790 let response = Oauth2RefreshResponse {
791 access_token: "test-access-token".to_string(),
792 scope: Some("scope1 scope2".to_string()),
793 expires_in: Some(3600),
794 token_type: "test-token-type".to_string(),
795 refresh_token: Some("test-refresh-token".to_string()),
796 };
797
798 let json = serde_json::to_value(&response).unwrap();
799 let expected = serde_json::json!({
800 "access_token": "test-access-token",
801 "scope": "scope1 scope2",
802 "expires_in": 3600,
803 "token_type": "test-token-type",
804 "refresh_token": "test-refresh-token"
805 });
806 assert_eq!(json, expected);
807 let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
808 assert_eq!(response, roundtrip);
809 }
810
811 #[test]
812 fn oauth2_response_serde_partial() {
813 let response = Oauth2RefreshResponse {
814 access_token: "test-access-token".to_string(),
815 scope: None,
816 expires_in: None,
817 token_type: "test-token-type".to_string(),
818 refresh_token: None,
819 };
820
821 let json = serde_json::to_value(&response).unwrap();
822 let expected = serde_json::json!({
823 "access_token": "test-access-token",
824 "token_type": "test-token-type",
825 });
826 assert_eq!(json, expected);
827 let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
828 assert_eq!(response, roundtrip);
829 }
830
831 fn check_request(request: &Oauth2RefreshRequest, expected_scopes: Option<String>) -> bool {
832 request.client_id == "test-client-id"
833 && request.client_secret == "test-client-secret"
834 && request.refresh_token == "test-refresh-token"
835 && request.grant_type == RefreshGrantType::RefreshToken
836 && request.scopes == expected_scopes
837 }
838
839 #[tokio::test(start_paused = true)]
840 async fn token_provider_full() -> TestResult {
841 let server = Server::run();
842 let response = Oauth2RefreshResponse {
843 access_token: "test-access-token".to_string(),
844 expires_in: Some(3600),
845 refresh_token: Some("test-refresh-token".to_string()),
846 scope: Some("scope1 scope2".to_string()),
847 token_type: "test-token-type".to_string(),
848 };
849 server.expect(
850 Expectation::matching(all_of![
851 request::path("/token"),
852 request::body(json_decoded(|req: &Oauth2RefreshRequest| {
853 check_request(req, Some("scope1 scope2".to_string()))
854 }))
855 ])
856 .respond_with(json_encoded(response)),
857 );
858
859 let tp = UserTokenProvider {
860 client_id: "test-client-id".to_string(),
861 client_secret: "test-client-secret".to_string(),
862 refresh_token: "test-refresh-token".to_string(),
863 endpoint: server.url("/token").to_string(),
864 scopes: Some("scope1 scope2".to_string()),
865 };
866 let now = Instant::now();
867 let token = tp.token().await?;
868 assert_eq!(token.token, "test-access-token");
869 assert_eq!(token.token_type, "test-token-type");
870 assert!(
871 token
872 .expires_at
873 .is_some_and(|d| d == now + Duration::from_secs(3600)),
874 "now: {:?}, expires_at: {:?}",
875 now,
876 token.expires_at
877 );
878
879 Ok(())
880 }
881
882 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
883 async fn credential_full_with_quota_project() -> TestResult {
884 let server = Server::run();
885 let response = Oauth2RefreshResponse {
886 access_token: "test-access-token".to_string(),
887 expires_in: Some(3600),
888 refresh_token: Some("test-refresh-token".to_string()),
889 scope: None,
890 token_type: "test-token-type".to_string(),
891 };
892 server.expect(
893 Expectation::matching(all_of![
894 request::path("/token"),
895 request::body(json_decoded(|req: &Oauth2RefreshRequest| {
896 check_request(req, None)
897 }))
898 ])
899 .respond_with(json_encoded(response)),
900 );
901
902 let authorized_user = serde_json::json!({
903 "client_id": "test-client-id",
904 "client_secret": "test-client-secret",
905 "refresh_token": "test-refresh-token",
906 "type": "authorized_user",
907 "token_uri": server.url("/token").to_string(),
908 });
909 let cred = Builder::new(authorized_user)
910 .with_quota_project_id("test-project")
911 .build()?;
912
913 let headers = get_headers_from_cache(cred.headers(Extensions::new()).await.unwrap())?;
914 let token = headers.get(AUTHORIZATION).unwrap();
915 let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();
916
917 assert_eq!(headers.len(), 2, "{headers:?}");
918 assert_eq!(
919 token,
920 HeaderValue::from_static("test-token-type test-access-token")
921 );
922 assert!(token.is_sensitive());
923 assert_eq!(
924 quota_project_header,
925 HeaderValue::from_static("test-project")
926 );
927 assert!(!quota_project_header.is_sensitive());
928
929 Ok(())
930 }
931
932 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
933 async fn creds_from_json_custom_uri_with_caching() -> TestResult {
934 let mut server = Server::run();
935 let response = Oauth2RefreshResponse {
936 access_token: "test-access-token".to_string(),
937 expires_in: Some(3600),
938 refresh_token: Some("test-refresh-token".to_string()),
939 scope: Some("scope1 scope2".to_string()),
940 token_type: "test-token-type".to_string(),
941 };
942 server.expect(
943 Expectation::matching(all_of![
944 request::path("/token"),
945 request::body(json_decoded(|req: &Oauth2RefreshRequest| {
946 check_request(req, Some("scope1 scope2".to_string()))
947 }))
948 ])
949 .times(1)
950 .respond_with(json_encoded(response)),
951 );
952
953 let json = serde_json::json!({
954 "client_id": "test-client-id",
955 "client_secret": "test-client-secret",
956 "refresh_token": "test-refresh-token",
957 "type": "authorized_user",
958 "universe_domain": "googleapis.com",
959 "quota_project_id": "test-project",
960 "token_uri": server.url("/token").to_string(),
961 });
962
963 let cred = Builder::new(json)
964 .with_scopes(vec!["scope1", "scope2"])
965 .build()?;
966
967 let token = get_token_from_headers(cred.headers(Extensions::new()).await?);
968 assert_eq!(token.unwrap(), "test-access-token");
969
970 let token = get_token_from_headers(cred.headers(Extensions::new()).await?);
971 assert_eq!(token.unwrap(), "test-access-token");
972
973 server.verify_and_clear();
974
975 Ok(())
976 }
977
978 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
979 async fn credential_provider_partial() -> TestResult {
980 let server = Server::run();
981 let response = Oauth2RefreshResponse {
982 access_token: "test-access-token".to_string(),
983 expires_in: None,
984 refresh_token: None,
985 scope: None,
986 token_type: "test-token-type".to_string(),
987 };
988 server.expect(
989 Expectation::matching(all_of![
990 request::path("/token"),
991 request::body(json_decoded(|req: &Oauth2RefreshRequest| {
992 check_request(req, None)
993 }))
994 ])
995 .respond_with(json_encoded(response)),
996 );
997
998 let authorized_user = serde_json::json!({
999 "client_id": "test-client-id",
1000 "client_secret": "test-client-secret",
1001 "refresh_token": "test-refresh-token",
1002 "type": "authorized_user",
1003 "token_uri": server.url("/token").to_string()
1004 });
1005
1006 let uc = Builder::new(authorized_user).build()?;
1007 let headers = uc.headers(Extensions::new()).await?;
1008 assert_eq!(
1009 get_token_from_headers(headers.clone()).unwrap(),
1010 "test-access-token"
1011 );
1012 assert_eq!(
1013 get_token_type_from_headers(headers).unwrap(),
1014 "test-token-type"
1015 );
1016
1017 Ok(())
1018 }
1019
1020 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1021 async fn credential_provider_with_token_uri() -> TestResult {
1022 let server = Server::run();
1023 let response = Oauth2RefreshResponse {
1024 access_token: "test-access-token".to_string(),
1025 expires_in: None,
1026 refresh_token: None,
1027 scope: None,
1028 token_type: "test-token-type".to_string(),
1029 };
1030 server.expect(
1031 Expectation::matching(all_of![
1032 request::path("/token"),
1033 request::body(json_decoded(|req: &Oauth2RefreshRequest| {
1034 check_request(req, None)
1035 }))
1036 ])
1037 .respond_with(json_encoded(response)),
1038 );
1039
1040 let authorized_user = serde_json::json!({
1041 "client_id": "test-client-id",
1042 "client_secret": "test-client-secret",
1043 "refresh_token": "test-refresh-token",
1044 "type": "authorized_user",
1045 "token_uri": "test-endpoint"
1046 });
1047
1048 let uc = Builder::new(authorized_user)
1049 .with_token_uri(server.url("/token").to_string())
1050 .build()?;
1051 let headers = uc.headers(Extensions::new()).await?;
1052 assert_eq!(
1053 get_token_from_headers(headers.clone()).unwrap(),
1054 "test-access-token"
1055 );
1056 assert_eq!(
1057 get_token_type_from_headers(headers).unwrap(),
1058 "test-token-type"
1059 );
1060
1061 Ok(())
1062 }
1063
1064 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1065 async fn credential_provider_with_scopes() -> TestResult {
1066 let server = Server::run();
1067 let response = Oauth2RefreshResponse {
1068 access_token: "test-access-token".to_string(),
1069 expires_in: None,
1070 refresh_token: None,
1071 scope: Some("scope1 scope2".to_string()),
1072 token_type: "test-token-type".to_string(),
1073 };
1074 server.expect(
1075 Expectation::matching(all_of![
1076 request::path("/token"),
1077 request::body(json_decoded(|req: &Oauth2RefreshRequest| {
1078 check_request(req, Some("scope1 scope2".to_string()))
1079 }))
1080 ])
1081 .respond_with(json_encoded(response)),
1082 );
1083
1084 let authorized_user = serde_json::json!({
1085 "client_id": "test-client-id",
1086 "client_secret": "test-client-secret",
1087 "refresh_token": "test-refresh-token",
1088 "type": "authorized_user",
1089 "token_uri": "test-endpoint"
1090 });
1091
1092 let uc = Builder::new(authorized_user)
1093 .with_token_uri(server.url("/token").to_string())
1094 .with_scopes(vec!["scope1", "scope2"])
1095 .build()?;
1096 let headers = uc.headers(Extensions::new()).await?;
1097 assert_eq!(
1098 get_token_from_headers(headers.clone()).unwrap(),
1099 "test-access-token"
1100 );
1101 assert_eq!(
1102 get_token_type_from_headers(headers).unwrap(),
1103 "test-token-type"
1104 );
1105
1106 Ok(())
1107 }
1108
1109 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1110 async fn credential_provider_retryable_error() -> TestResult {
1111 let server = Server::run();
1112 server
1113 .expect(Expectation::matching(request::path("/token")).respond_with(status_code(503)));
1114
1115 let authorized_user = serde_json::json!({
1116 "client_id": "test-client-id",
1117 "client_secret": "test-client-secret",
1118 "refresh_token": "test-refresh-token",
1119 "type": "authorized_user",
1120 "token_uri": server.url("/token").to_string()
1121 });
1122
1123 let uc = Builder::new(authorized_user).build()?;
1124 let err = uc.headers(Extensions::new()).await.unwrap_err();
1125 assert!(err.is_transient(), "{err}");
1126 let source = err
1127 .source()
1128 .and_then(|e| e.downcast_ref::<reqwest::Error>());
1129 assert!(
1130 matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
1131 "{err:?}"
1132 );
1133
1134 Ok(())
1135 }
1136
1137 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1138 async fn token_provider_nonretryable_error() -> TestResult {
1139 let server = Server::run();
1140 server
1141 .expect(Expectation::matching(request::path("/token")).respond_with(status_code(401)));
1142
1143 let authorized_user = serde_json::json!({
1144 "client_id": "test-client-id",
1145 "client_secret": "test-client-secret",
1146 "refresh_token": "test-refresh-token",
1147 "type": "authorized_user",
1148 "token_uri": server.url("/token").to_string()
1149 });
1150
1151 let uc = Builder::new(authorized_user).build()?;
1152 let err = uc.headers(Extensions::new()).await.unwrap_err();
1153 assert!(!err.is_transient(), "{err:?}");
1154 let source = err
1155 .source()
1156 .and_then(|e| e.downcast_ref::<reqwest::Error>());
1157 assert!(
1158 matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
1159 "{err:?}"
1160 );
1161
1162 Ok(())
1163 }
1164
1165 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1166 async fn token_provider_malformed_response_is_nonretryable() -> TestResult {
1167 let server = Server::run();
1168 server.expect(
1169 Expectation::matching(request::path("/token"))
1170 .respond_with(json_encoded("bad json".to_string())),
1171 );
1172
1173 let authorized_user = serde_json::json!({
1174 "client_id": "test-client-id",
1175 "client_secret": "test-client-secret",
1176 "refresh_token": "test-refresh-token",
1177 "type": "authorized_user",
1178 "token_uri": server.url("/token").to_string()
1179 });
1180
1181 let uc = Builder::new(authorized_user).build()?;
1182 let e = uc.headers(Extensions::new()).await.err().unwrap();
1183 assert!(!e.is_transient(), "{e}");
1184
1185 Ok(())
1186 }
1187
1188 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1189 async fn builder_malformed_authorized_json_nonretryable() -> TestResult {
1190 let authorized_user = serde_json::json!({
1191 "client_secret": "test-client-secret",
1192 "refresh_token": "test-refresh-token",
1193 "type": "authorized_user",
1194 });
1195
1196 let e = Builder::new(authorized_user).build().unwrap_err();
1197 assert!(e.is_parsing(), "{e}");
1198
1199 Ok(())
1200 }
1201}