1use crate::credentials::dynamic::CredentialsProvider;
71use crate::credentials::{Credentials, Result};
72use crate::errors::{self, CredentialsError, is_retryable};
73use crate::headers_util::build_bearer_headers;
74use crate::token::{CachedTokenProvider, Token, TokenProvider};
75use crate::token_cache::TokenCache;
76use http::header::CONTENT_TYPE;
77use http::{Extensions, HeaderMap, HeaderValue};
78use reqwest::{Client, Method};
79use serde_json::Value;
80use std::sync::Arc;
81use tokio::time::{Duration, Instant};
82
/// Default Google OAuth 2.0 token endpoint.
///
/// Used when neither [`Builder::with_token_uri`] nor the credentials JSON
/// (`token_uri`) provides an endpoint.
const OAUTH2_ENDPOINT: &'static str = "https://oauth2.googleapis.com/token";
84
85pub struct Builder {
96 authorized_user: Value,
97 scopes: Option<Vec<String>>,
98 quota_project_id: Option<String>,
99 token_uri: Option<String>,
100}
101
102impl Builder {
103 pub fn new(authorized_user: Value) -> Self {
110 Self {
111 authorized_user,
112 scopes: None,
113 quota_project_id: None,
114 token_uri: None,
115 }
116 }
117
118 pub fn with_token_uri<S: Into<String>>(mut self, token_uri: S) -> Self {
132 self.token_uri = Some(token_uri.into());
133 self
134 }
135
136 pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
157 where
158 I: IntoIterator<Item = S>,
159 S: Into<String>,
160 {
161 self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
162 self
163 }
164
165 pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
186 self.quota_project_id = Some(quota_project_id.into());
187 self
188 }
189
190 pub fn build(self) -> Result<Credentials> {
203 let authorized_user = serde_json::from_value::<AuthorizedUser>(self.authorized_user)
204 .map_err(errors::non_retryable)?;
205 let endpoint = self
206 .token_uri
207 .or(authorized_user.token_uri)
208 .unwrap_or(OAUTH2_ENDPOINT.to_string());
209 let quota_project_id = self.quota_project_id.or(authorized_user.quota_project_id);
210
211 let token_provider = UserTokenProvider {
212 client_id: authorized_user.client_id,
213 client_secret: authorized_user.client_secret,
214 refresh_token: authorized_user.refresh_token,
215 endpoint,
216 scopes: self.scopes.map(|scopes| scopes.join(" ")),
217 };
218 let token_provider = TokenCache::new(token_provider);
219
220 Ok(Credentials {
221 inner: Arc::new(UserCredentials {
222 token_provider,
223 quota_project_id,
224 }),
225 })
226 }
227}
228
/// Fetches access tokens by exchanging a user refresh token at `endpoint`.
#[derive(PartialEq)]
struct UserTokenProvider {
    /// OAuth client id; not secret and shown in `Debug` output.
    client_id: String,
    /// OAuth client secret; censored in `Debug` output.
    client_secret: String,
    /// Long-lived refresh token; censored in `Debug` output.
    refresh_token: String,
    /// Token endpoint URL the refresh request is POSTed to.
    endpoint: String,
    /// Space-separated scopes to request, if any.
    scopes: Option<String>,
}
237
238impl std::fmt::Debug for UserTokenProvider {
239 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
240 f.debug_struct("UserCredentials")
241 .field("client_id", &self.client_id)
242 .field("client_secret", &"[censored]")
243 .field("refresh_token", &"[censored]")
244 .field("endpoint", &self.endpoint)
245 .field("scopes", &self.scopes)
246 .finish()
247 }
248}
249
250#[async_trait::async_trait]
251impl TokenProvider for UserTokenProvider {
252 async fn token(&self) -> Result<Token> {
253 let client = Client::new();
254
255 let req = Oauth2RefreshRequest {
257 grant_type: RefreshGrantType::RefreshToken,
258 client_id: self.client_id.clone(),
259 client_secret: self.client_secret.clone(),
260 refresh_token: self.refresh_token.clone(),
261 scopes: self.scopes.clone(),
262 };
263 let header = HeaderValue::from_static("application/json");
264 let builder = client
265 .request(Method::POST, self.endpoint.as_str())
266 .header(CONTENT_TYPE, header)
267 .json(&req);
268 let resp = builder.send().await.map_err(errors::retryable)?;
269
270 if !resp.status().is_success() {
272 let status = resp.status();
273 let body = resp
274 .text()
275 .await
276 .map_err(|e| CredentialsError::new(is_retryable(status), e))?;
277 return Err(CredentialsError::from_str(
278 is_retryable(status),
279 format!("Failed to fetch token. {body}"),
280 ));
281 }
282 let response = resp.json::<Oauth2RefreshResponse>().await.map_err(|e| {
283 let retryable = !e.is_decode();
284 CredentialsError::new(retryable, e)
285 })?;
286 let token = Token {
287 token: response.access_token,
288 token_type: response.token_type,
289 expires_at: response
290 .expires_in
291 .map(|d| Instant::now() + Duration::from_secs(d)),
292 metadata: None,
293 };
294 Ok(token)
295 }
296}
297
298#[derive(Debug)]
302pub(crate) struct UserCredentials<T>
303where
304 T: CachedTokenProvider,
305{
306 token_provider: T,
307 quota_project_id: Option<String>,
308}
309
310#[async_trait::async_trait]
311impl<T> CredentialsProvider for UserCredentials<T>
312where
313 T: CachedTokenProvider,
314{
315 async fn headers(&self, extensions: Extensions) -> Result<HeaderMap> {
316 let token = self.token_provider.token(extensions).await?;
317 build_bearer_headers(&token, &self.quota_project_id)
318 }
319}
320
321#[derive(Debug, PartialEq, serde::Deserialize)]
322pub(crate) struct AuthorizedUser {
323 #[serde(rename = "type")]
324 cred_type: String,
325 client_id: String,
326 client_secret: String,
327 refresh_token: String,
328 #[serde(skip_serializing_if = "Option::is_none")]
329 token_uri: Option<String>,
330 #[serde(skip_serializing_if = "Option::is_none")]
331 quota_project_id: Option<String>,
332}
333
334#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
335enum RefreshGrantType {
336 #[serde(rename = "refresh_token")]
337 RefreshToken,
338}
339
340#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
341struct Oauth2RefreshRequest {
342 grant_type: RefreshGrantType,
343 client_id: String,
344 client_secret: String,
345 refresh_token: String,
346 scopes: Option<String>,
347}
348
349#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
350struct Oauth2RefreshResponse {
351 access_token: String,
352 #[serde(skip_serializing_if = "Option::is_none")]
353 scope: Option<String>,
354 #[serde(skip_serializing_if = "Option::is_none")]
355 expires_in: Option<u64>,
356 token_type: String,
357 #[serde(skip_serializing_if = "Option::is_none")]
358 refresh_token: Option<String>,
359}
360
#[cfg(test)]
mod test {
    use super::*;
    use crate::credentials::test::{get_token_from_headers, get_token_type_from_headers};
    use crate::credentials::{DEFAULT_UNIVERSE_DOMAIN, QUOTA_PROJECT_KEY};
    use crate::token::test::MockTokenProvider;
    use axum::extract::Json;
    use http::StatusCode;
    use http::header::AUTHORIZATION;
    use std::error::Error;
    use std::sync::Mutex;
    use tokio::task::JoinHandle;

    type TestResult = std::result::Result<(), Box<dyn Error>>;

    /// Minimal `authorized_user` JSON document pointing at `token_uri`.
    fn authorized_user_json(token_uri: &str) -> Value {
        serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": token_uri,
        })
    }

    /// A non-expiring bearer token used with the mock token provider.
    fn test_token() -> Token {
        Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        }
    }

    #[test]
    fn debug_token_provider() {
        let expected = UserTokenProvider {
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            endpoint: OAUTH2_ENDPOINT.to_string(),
            scopes: Some("https://www.googleapis.com/auth/pubsub".to_string()),
        };
        let fmt = format!("{expected:?}");
        assert!(fmt.contains("test-client-id"), "{fmt}");
        assert!(!fmt.contains("test-client-secret"), "{fmt}");
        assert!(!fmt.contains("test-refresh-token"), "{fmt}");
        assert!(fmt.contains(OAUTH2_ENDPOINT), "{fmt}");
        assert!(
            fmt.contains("https://www.googleapis.com/auth/pubsub"),
            "{fmt}"
        );
    }

    #[test]
    fn authorized_user_full_from_json_success() {
        let json = serde_json::json!({
            "account": "",
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "universe_domain": "googleapis.com",
            "quota_project_id": "test-project",
            "token_uri" : "test-token-uri",
        });

        let expected = AuthorizedUser {
            cred_type: "authorized_user".to_string(),
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            quota_project_id: Some("test-project".to_string()),
            token_uri: Some("test-token-uri".to_string()),
        };
        let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn authorized_user_partial_from_json_success() {
        let json = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
        });

        let expected = AuthorizedUser {
            cred_type: "authorized_user".to_string(),
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            quota_project_id: None,
            token_uri: None,
        };
        let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn authorized_user_from_json_parse_fail() {
        let json_full = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "quota_project_id": "test-project"
        });

        for required_field in ["client_id", "client_secret", "refresh_token"] {
            let mut json = json_full.clone();
            // `take()` replaces the value with `null`, which a `String` field rejects.
            json[required_field].take();
            assert!(
                serde_json::from_value::<AuthorizedUser>(json).is_err(),
                "{required_field}"
            );
        }
    }

    #[tokio::test]
    async fn default_universe_domain_success() {
        let uc = UserCredentials {
            token_provider: TokenCache::new(MockTokenProvider::new()),
            quota_project_id: None,
        };
        assert_eq!(uc.universe_domain().await.unwrap(), DEFAULT_UNIVERSE_DOMAIN);
    }

    #[tokio::test]
    async fn headers_success() {
        let token = test_token();
        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let uc = UserCredentials {
            token_provider: TokenCache::new(mock),
            quota_project_id: None,
        };

        let headers = uc.headers(Extensions::new()).await.unwrap();
        let token = headers.get(AUTHORIZATION).unwrap();

        assert_eq!(headers.len(), 1, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());
    }

    #[tokio::test]
    async fn headers_failure() {
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Err(errors::non_retryable_from_str("fail")));

        let uc = UserCredentials {
            token_provider: TokenCache::new(mock),
            quota_project_id: None,
        };
        assert!(uc.headers(Extensions::new()).await.is_err());
    }

    #[tokio::test]
    async fn headers_with_quota_project_success() {
        let token = test_token();
        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let uc = UserCredentials {
            token_provider: TokenCache::new(mock),
            quota_project_id: Some("test-project".to_string()),
        };

        let headers = uc.headers(Extensions::new()).await.unwrap();
        let token = headers.get(AUTHORIZATION).unwrap();
        let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());
        assert_eq!(
            quota_project_header,
            HeaderValue::from_static("test-project")
        );
        assert!(!quota_project_header.is_sensitive());
    }

    #[test]
    fn oauth2_request_serde() {
        let request = Oauth2RefreshRequest {
            grant_type: RefreshGrantType::RefreshToken,
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            scopes: Some("scope1 scope2".to_string()),
        };

        let json = serde_json::to_value(&request).unwrap();
        let expected = serde_json::json!({
            "grant_type": "refresh_token",
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "scopes": "scope1 scope2",
        });
        assert_eq!(json, expected);
        let roundtrip = serde_json::from_value::<Oauth2RefreshRequest>(json).unwrap();
        assert_eq!(request, roundtrip);
    }

    #[test]
    fn oauth2_response_serde_full() {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            scope: Some("scope1 scope2".to_string()),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
            refresh_token: Some("test-refresh-token".to_string()),
        };

        let json = serde_json::to_value(&response).unwrap();
        let expected = serde_json::json!({
            "access_token": "test-access-token",
            "scope": "scope1 scope2",
            "expires_in": 3600,
            "token_type": "test-token-type",
            "refresh_token": "test-refresh-token"
        });
        assert_eq!(json, expected);
        let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
        assert_eq!(response, roundtrip);
    }

    #[test]
    fn oauth2_response_serde_partial() {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            scope: None,
            expires_in: None,
            token_type: "test-token-type".to_string(),
            refresh_token: None,
        };

        let json = serde_json::to_value(&response).unwrap();
        let expected = serde_json::json!({
            "access_token": "test-access-token",
            "token_type": "test-token-type",
        });
        assert_eq!(json, expected);
        let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
        assert_eq!(response, roundtrip);
    }

    /// Starts a local HTTP server whose `POST /token` route replies with
    /// `response_code` / `response_body` and bumps `call_count` per request.
    ///
    /// Returns the token endpoint URL and the server task handle.
    async fn start(
        response_code: StatusCode,
        response_body: Value,
        call_count: Arc<Mutex<i32>>,
    ) -> (String, JoinHandle<()>) {
        let code = response_code;
        let body = response_body;
        let handler = move |req| async move { handle_token_factory(code, body, call_count)(req) };
        let app = axum::Router::new().route("/token", axum::routing::post(handler));
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });

        (
            format!("http://{}:{}/token", addr.ip(), addr.port()),
            server,
        )
    }

    /// Returns a handler that validates the refresh request and replies with
    /// the canned response. The request's `scopes` must equal the response
    /// body's `scope` field (or both be absent).
    fn handle_token_factory(
        response_code: StatusCode,
        response_body: Value,
        call_count: Arc<Mutex<i32>>,
    ) -> impl Fn(Json<Oauth2RefreshRequest>) -> (StatusCode, String) {
        move |Json(request): Json<Oauth2RefreshRequest>| -> (StatusCode, String) {
            // Release the lock immediately so a failing assertion below cannot
            // poison it for the test that later reads the count.
            *call_count.lock().unwrap() += 1;
            assert_eq!(request.client_id, "test-client-id");
            assert_eq!(request.client_secret, "test-client-secret");
            assert_eq!(request.refresh_token, "test-refresh-token");
            assert_eq!(request.grant_type, RefreshGrantType::RefreshToken);
            assert_eq!(
                request.scopes,
                response_body["scope"].as_str().map(|s| s.to_owned())
            );

            (response_code, response_body.to_string())
        }
    }

    #[tokio::test(start_paused = true)]
    async fn token_provider_full() -> TestResult {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };
        let response_body = serde_json::to_value(&response)?;
        let (endpoint, _server) =
            start(StatusCode::OK, response_body, Arc::new(Mutex::new(0))).await;

        let tp = UserTokenProvider {
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            endpoint,
            scopes: Some("scope1 scope2".to_string()),
        };
        let now = Instant::now();
        let token = tp.token().await?;
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d == now + Duration::from_secs(3600)),
            "now: {:?}, expires_at: {:?}",
            now,
            token.expires_at
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_full_with_quota_project() -> TestResult {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: None,
            token_type: "test-token-type".to_string(),
        };
        let response_body = serde_json::to_value(&response)?;
        let (endpoint, _server) =
            start(StatusCode::OK, response_body, Arc::new(Mutex::new(0))).await;

        let cred = Builder::new(authorized_user_json(&endpoint))
            .with_quota_project_id("test-project")
            .build()?;

        let headers = cred.headers(Extensions::new()).await.unwrap();
        let token = headers.get(AUTHORIZATION).unwrap();
        let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(
            token,
            HeaderValue::from_static("test-token-type test-access-token")
        );
        assert!(token.is_sensitive());
        assert_eq!(
            quota_project_header,
            HeaderValue::from_static("test-project")
        );
        assert!(!quota_project_header.is_sensitive());

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn creds_from_json_custom_uri_with_caching() -> TestResult {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };
        let response_body = serde_json::to_value(&response)?;
        let call_count = Arc::new(Mutex::new(0));
        let (endpoint, _server) = start(StatusCode::OK, response_body, call_count.clone()).await;

        let json = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "universe_domain": "googleapis.com",
            "quota_project_id": "test-project",
            "token_uri": endpoint,
        });

        let cred = Builder::new(json)
            .with_scopes(vec!["scope1", "scope2"])
            .build()?;

        let token = get_token_from_headers(&cred.headers(Extensions::new()).await?);
        assert_eq!(token.unwrap(), "test-access-token");

        let token = get_token_from_headers(&cred.headers(Extensions::new()).await?);
        assert_eq!(token.unwrap(), "test-access-token");

        // The second call must be served from the token cache.
        assert_eq!(*call_count.lock().unwrap(), 1);

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_partial() -> TestResult {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            expires_in: None,
            refresh_token: None,
            scope: None,
            token_type: "test-token-type".to_string(),
        };
        let response_body = serde_json::to_value(&response)?;
        let (endpoint, _server) =
            start(StatusCode::OK, response_body, Arc::new(Mutex::new(0))).await;

        let uc = Builder::new(authorized_user_json(&endpoint)).build()?;
        let headers = uc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(&headers).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(&headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_with_token_uri() -> TestResult {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            expires_in: None,
            refresh_token: None,
            scope: None,
            token_type: "test-token-type".to_string(),
        };
        let response_body = serde_json::to_value(&response)?;
        let (endpoint, _server) =
            start(StatusCode::OK, response_body, Arc::new(Mutex::new(0))).await;

        // The builder's token URI must win over the (unreachable) JSON value.
        let uc = Builder::new(authorized_user_json("test-endpoint"))
            .with_token_uri(endpoint)
            .build()?;
        let headers = uc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(&headers).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(&headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_with_scopes() -> TestResult {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            expires_in: None,
            refresh_token: None,
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };
        let response_body = serde_json::to_value(&response)?;
        let (endpoint, _server) =
            start(StatusCode::OK, response_body, Arc::new(Mutex::new(0))).await;

        let uc = Builder::new(authorized_user_json("test-endpoint"))
            .with_token_uri(endpoint)
            .with_scopes(vec!["scope1", "scope2"])
            .build()?;
        let headers = uc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(&headers).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(&headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_retryable_error() -> TestResult {
        let (endpoint, _server) = start(
            StatusCode::SERVICE_UNAVAILABLE,
            serde_json::to_value("try again".to_string())?,
            Arc::new(Mutex::new(0)),
        )
        .await;

        let uc = Builder::new(authorized_user_json(&endpoint)).build()?;
        let e = uc.headers(Extensions::new()).await.err().unwrap();
        assert!(e.is_retryable(), "{e}");
        assert!(e.source().unwrap().to_string().contains("try again"), "{e}");

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn token_provider_nonretryable_error() -> TestResult {
        let (endpoint, _server) = start(
            StatusCode::UNAUTHORIZED,
            serde_json::to_value("epic fail".to_string())?,
            Arc::new(Mutex::new(0)),
        )
        .await;

        let uc = Builder::new(authorized_user_json(&endpoint)).build()?;
        let e = uc.headers(Extensions::new()).await.err().unwrap();
        assert!(!e.is_retryable(), "{e}");
        assert!(e.source().unwrap().to_string().contains("epic fail"), "{e}");

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn token_provider_malformed_response_is_nonretryable() -> TestResult {
        let (endpoint, _server) = start(
            StatusCode::OK,
            serde_json::to_value("bad json".to_string())?,
            Arc::new(Mutex::new(0)),
        )
        .await;

        let uc = Builder::new(authorized_user_json(&endpoint)).build()?;
        let e = uc.headers(Extensions::new()).await.err().unwrap();
        assert!(!e.is_retryable(), "{e}");

        Ok(())
    }

    // `build()` is synchronous, so no async runtime is needed here.
    #[test]
    fn builder_malformed_authorized_json_nonretryable() {
        let authorized_user = serde_json::json!({
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
        });

        let e = Builder::new(authorized_user).build().unwrap_err();
        assert!(!e.is_retryable(), "{e}");
    }
}