1use crate::credentials::dynamic::CredentialsTrait;
70use crate::credentials::{Credentials, Result};
71use crate::errors::{self, CredentialsError, is_retryable};
72use crate::headers_util::build_bearer_headers;
73use crate::token::{Token, TokenProvider};
74use crate::token_cache::TokenCache;
75use http::header::{CONTENT_TYPE, HeaderName, HeaderValue};
76use reqwest::{Client, Method};
77use serde_json::Value;
78use std::sync::Arc;
79use std::time::Duration;
80
/// Default Google OAuth 2.0 token endpoint.
///
/// Used by [`Builder::build`] only when neither the builder nor the
/// authorized-user JSON provides a `token_uri`.
const OAUTH2_ENDPOINT: &str = "https://oauth2.googleapis.com/token";
82
83pub(crate) fn creds_from(js: Value) -> Result<Credentials> {
84 Builder::new(js).build()
85}
86
87pub struct Builder {
98 authorized_user: Value,
99 scopes: Option<Vec<String>>,
100 quota_project_id: Option<String>,
101 token_uri: Option<String>,
102}
103
104impl Builder {
105 pub fn new(authorized_user: Value) -> Self {
112 Self {
113 authorized_user,
114 scopes: None,
115 quota_project_id: None,
116 token_uri: None,
117 }
118 }
119
120 pub fn with_token_uri<S: Into<String>>(mut self, token_uri: S) -> Self {
134 self.token_uri = Some(token_uri.into());
135 self
136 }
137
138 pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
159 where
160 I: IntoIterator<Item = S>,
161 S: Into<String>,
162 {
163 self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
164 self
165 }
166
167 pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
188 self.quota_project_id = Some(quota_project_id.into());
189 self
190 }
191
192 pub fn build(self) -> Result<Credentials> {
205 let authorized_user = serde_json::from_value::<AuthorizedUser>(self.authorized_user)
206 .map_err(errors::non_retryable)?;
207 let endpoint = self
208 .token_uri
209 .or(authorized_user.token_uri)
210 .unwrap_or(OAUTH2_ENDPOINT.to_string());
211 let quota_project_id = self.quota_project_id.or(authorized_user.quota_project_id);
212
213 let token_provider = UserTokenProvider {
214 client_id: authorized_user.client_id,
215 client_secret: authorized_user.client_secret,
216 refresh_token: authorized_user.refresh_token,
217 endpoint,
218 scopes: self.scopes.map(|scopes| scopes.join(" ")),
219 };
220 let token_provider = TokenCache::new(token_provider);
221
222 Ok(Credentials {
223 inner: Arc::new(UserCredentials {
224 token_provider,
225 quota_project_id,
226 }),
227 })
228 }
229}
230
/// Fetches access tokens by exchanging a long-lived refresh token at an
/// OAuth2 token endpoint.
#[derive(PartialEq)]
struct UserTokenProvider {
    /// OAuth2 client identifier from the authorized-user JSON.
    client_id: String,
    /// OAuth2 client secret; redacted from `Debug` output.
    client_secret: String,
    /// Refresh token exchanged for access tokens; redacted from `Debug` output.
    refresh_token: String,
    /// URL of the token endpoint the refresh request is POSTed to.
    endpoint: String,
    /// Space-separated scopes to request, if any.
    scopes: Option<String>,
}
239
240impl std::fmt::Debug for UserTokenProvider {
241 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
242 f.debug_struct("UserCredentials")
243 .field("client_id", &self.client_id)
244 .field("client_secret", &"[censored]")
245 .field("refresh_token", &"[censored]")
246 .field("endpoint", &self.endpoint)
247 .field("scopes", &self.scopes)
248 .finish()
249 }
250}
251
252#[async_trait::async_trait]
253impl TokenProvider for UserTokenProvider {
254 async fn token(&self) -> Result<Token> {
255 let client = Client::new();
256
257 let req = Oauth2RefreshRequest {
259 grant_type: RefreshGrantType::RefreshToken,
260 client_id: self.client_id.clone(),
261 client_secret: self.client_secret.clone(),
262 refresh_token: self.refresh_token.clone(),
263 scopes: self.scopes.clone(),
264 };
265 let header = HeaderValue::from_static("application/json");
266 let builder = client
267 .request(Method::POST, self.endpoint.as_str())
268 .header(CONTENT_TYPE, header)
269 .json(&req);
270 let resp = builder.send().await.map_err(errors::retryable)?;
271
272 if !resp.status().is_success() {
274 let status = resp.status();
275 let body = resp
276 .text()
277 .await
278 .map_err(|e| CredentialsError::new(is_retryable(status), e))?;
279 return Err(CredentialsError::from_str(
280 is_retryable(status),
281 format!("Failed to fetch token. {body}"),
282 ));
283 }
284 let response = resp.json::<Oauth2RefreshResponse>().await.map_err(|e| {
285 let retryable = !e.is_decode();
286 CredentialsError::new(retryable, e)
287 })?;
288 let token = Token {
289 token: response.access_token,
290 token_type: response.token_type,
291 expires_at: response
292 .expires_in
293 .map(|d| std::time::Instant::now() + Duration::from_secs(d)),
294 metadata: None,
295 };
296 Ok(token)
297 }
298}
299
300#[derive(Debug)]
304pub(crate) struct UserCredentials<T>
305where
306 T: TokenProvider,
307{
308 token_provider: T,
309 quota_project_id: Option<String>,
310}
311
312#[async_trait::async_trait]
313impl<T> CredentialsTrait for UserCredentials<T>
314where
315 T: TokenProvider,
316{
317 async fn token(&self) -> Result<Token> {
318 self.token_provider.token().await
319 }
320
321 async fn headers(&self) -> Result<Vec<(HeaderName, HeaderValue)>> {
322 let token = self.token().await?;
323 build_bearer_headers(&token, &self.quota_project_id)
324 }
325}
326
327#[derive(Debug, PartialEq, serde::Deserialize)]
328pub(crate) struct AuthorizedUser {
329 #[serde(rename = "type")]
330 cred_type: String,
331 client_id: String,
332 client_secret: String,
333 refresh_token: String,
334 #[serde(skip_serializing_if = "Option::is_none")]
335 token_uri: Option<String>,
336 #[serde(skip_serializing_if = "Option::is_none")]
337 quota_project_id: Option<String>,
338}
339
340#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
341enum RefreshGrantType {
342 #[serde(rename = "refresh_token")]
343 RefreshToken,
344}
345
346#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
347struct Oauth2RefreshRequest {
348 grant_type: RefreshGrantType,
349 client_id: String,
350 client_secret: String,
351 refresh_token: String,
352 scopes: Option<String>,
353}
354
355#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
356struct Oauth2RefreshResponse {
357 access_token: String,
358 #[serde(skip_serializing_if = "Option::is_none")]
359 scope: Option<String>,
360 #[serde(skip_serializing_if = "Option::is_none")]
361 expires_in: Option<u64>,
362 token_type: String,
363 #[serde(skip_serializing_if = "Option::is_none")]
364 refresh_token: Option<String>,
365}
366
// Tests for authorized-user credentials: JSON parsing, header generation with
// a mocked token provider, serde round-trips of the OAuth2 wire types, and
// token refresh against a local axum server.
#[cfg(test)]
mod test {
    use super::*;
    use crate::credentials::QUOTA_PROJECT_KEY;
    use crate::credentials::test::HV;
    use crate::token::test::MockTokenProvider;
    use axum::extract::Json;
    use http::StatusCode;
    use http::header::AUTHORIZATION;
    use std::error::Error;
    use std::sync::Mutex;
    use tokio::task::JoinHandle;

    // Lets tests use `?` on any error type.
    type TestResult = std::result::Result<(), Box<dyn std::error::Error>>;

    // The Debug output must show non-secret fields and hide the client secret
    // and refresh token.
    #[test]
    fn debug_token_provider() {
        let expected = UserTokenProvider {
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            endpoint: OAUTH2_ENDPOINT.to_string(),
            scopes: Some("https://www.googleapis.com/auth/pubsub".to_string()),
        };
        let fmt = format!("{expected:?}");
        assert!(fmt.contains("test-client-id"), "{fmt}");
        assert!(!fmt.contains("test-client-secret"), "{fmt}");
        assert!(!fmt.contains("test-refresh-token"), "{fmt}");
        assert!(fmt.contains(OAUTH2_ENDPOINT), "{fmt}");
        assert!(
            fmt.contains("https://www.googleapis.com/auth/pubsub"),
            "{fmt}"
        );
    }

    // Extra keys (`account`, `universe_domain`) are ignored; optional keys are
    // picked up when present.
    #[test]
    fn authorized_user_full_from_json_success() {
        let json = serde_json::json!({
            "account": "",
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "universe_domain": "googleapis.com",
            "quota_project_id": "test-project",
            "token_uri" : "test-token-uri",
        });

        let expected = AuthorizedUser {
            cred_type: "authorized_user".to_string(),
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            quota_project_id: Some("test-project".to_string()),
            token_uri: Some("test-token-uri".to_string()),
        };
        let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
        assert_eq!(actual, expected);
    }

    // Missing optional keys deserialize to `None`.
    #[test]
    fn authorized_user_partial_from_json_success() {
        let json = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
        });

        let expected = AuthorizedUser {
            cred_type: "authorized_user".to_string(),
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            quota_project_id: None,
            token_uri: None,
        };
        let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn authorized_user_from_json_parse_fail() {
        let json_full = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "quota_project_id": "test-project"
        });

        for required_field in ["client_id", "client_secret", "refresh_token"] {
            let mut json = json_full.clone();
            // `Value::take` leaves `null` in place (the key stays), so this
            // checks that a null required field is rejected.
            json[required_field].take();
            serde_json::from_value::<AuthorizedUser>(json)
                .err()
                .unwrap();
        }
    }

    // `UserCredentials::token` returns the provider's token unchanged.
    #[tokio::test]
    async fn token_success() {
        let expected = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };
        let expected_clone = expected.clone();

        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Ok(expected_clone));

        let uc = UserCredentials {
            token_provider: mock,
            quota_project_id: None,
        };
        let actual = uc.token().await.unwrap();
        assert_eq!(actual, expected);
    }

    #[tokio::test]
    async fn token_failure() {
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Err(errors::non_retryable_from_str("fail")));

        let uc = UserCredentials {
            token_provider: mock,
            quota_project_id: None,
        };
        assert!(uc.token().await.is_err());
    }

    // Without a quota project only the sensitive Authorization header is set.
    #[tokio::test]
    async fn headers_success() {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let uc = UserCredentials {
            token_provider: mock,
            quota_project_id: None,
        };
        let headers: Vec<HV> = HV::from(uc.headers().await.unwrap());

        assert_eq!(
            headers,
            vec![HV {
                header: AUTHORIZATION.to_string(),
                value: "Bearer test-token".to_string(),
                is_sensitive: true,
            }]
        );
    }

    #[tokio::test]
    async fn headers_failure() {
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Err(errors::non_retryable_from_str("fail")));

        let uc = UserCredentials {
            token_provider: mock,
            quota_project_id: None,
        };
        assert!(uc.headers().await.is_err());
    }

    // With a quota project, a second, non-sensitive header is appended.
    #[tokio::test]
    async fn headers_with_quota_project_success() {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let uc = UserCredentials {
            token_provider: mock,
            quota_project_id: Some("test-project".to_string()),
        };
        let headers: Vec<HV> = HV::from(uc.headers().await.unwrap());
        assert_eq!(
            headers,
            vec![
                HV {
                    header: AUTHORIZATION.to_string(),
                    value: "Bearer test-token".to_string(),
                    is_sensitive: true,
                },
                HV {
                    header: QUOTA_PROJECT_KEY.to_string(),
                    value: "test-project".to_string(),
                    is_sensitive: false,
                }
            ]
        );
    }

    // Pins the request wire format, including the `scopes` key name.
    #[test]
    fn oauth2_request_serde() {
        let request = Oauth2RefreshRequest {
            grant_type: RefreshGrantType::RefreshToken,
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            scopes: Some("scope1 scope2".to_string()),
        };

        let json = serde_json::to_value(&request).unwrap();
        let expected = serde_json::json!({
            "grant_type": "refresh_token",
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "scopes": "scope1 scope2",
        });
        assert_eq!(json, expected);
        let roundtrip = serde_json::from_value::<Oauth2RefreshRequest>(json).unwrap();
        assert_eq!(request, roundtrip);
    }

    #[test]
    fn oauth2_response_serde_full() {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            scope: Some("scope1 scope2".to_string()),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
            refresh_token: Some("test-refresh-token".to_string()),
        };

        let json = serde_json::to_value(&response).unwrap();
        let expected = serde_json::json!({
            "access_token": "test-access-token",
            "scope": "scope1 scope2",
            "expires_in": 3600,
            "token_type": "test-token-type",
            "refresh_token": "test-refresh-token"
        });
        assert_eq!(json, expected);
        let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
        assert_eq!(response, roundtrip);
    }

    // `None` fields are omitted from the JSON and restored as `None`.
    #[test]
    fn oauth2_response_serde_partial() {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            scope: None,
            expires_in: None,
            token_type: "test-token-type".to_string(),
            refresh_token: None,
        };

        let json = serde_json::to_value(&response).unwrap();
        let expected = serde_json::json!({
            "access_token": "test-access-token",
            "token_type": "test-token-type",
        });
        assert_eq!(json, expected);
        let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
        assert_eq!(response, roundtrip);
    }

    // Starts an axum server on an ephemeral 127.0.0.1 port. Every POST to
    // `/token` is answered by `handle_token_factory` with the given status and
    // body. Returns the token URL and the server task handle.
    async fn start(
        response_code: StatusCode,
        response_body: Value,
        call_count: Arc<Mutex<i32>>,
    ) -> (String, JoinHandle<()>) {
        let code = response_code;
        let body = response_body.clone();
        let handler = move |req| async move { handle_token_factory(code, body, call_count)(req) };
        let app = axum::Router::new().route("/token", axum::routing::post(handler));
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async {
            axum::serve(listener, app).await.unwrap();
        });

        (
            format!("http://{}:{}/token", addr.ip(), addr.port()),
            server,
        )
    }

    // Builds the per-request handler. It increments `call_count`, asserts the
    // request carries the test credentials, and checks that the requested
    // scopes equal the `scope` in the canned response body (both absent, or
    // equal strings). Then it replies with the canned status and body.
    fn handle_token_factory(
        response_code: StatusCode,
        response_body: Value,
        call_count: Arc<std::sync::Mutex<i32>>,
    ) -> impl Fn(Json<Oauth2RefreshRequest>) -> (StatusCode, String) {
        move |request: Json<Oauth2RefreshRequest>| -> (StatusCode, String) {
            let mut count = call_count.lock().unwrap();
            *count += 1;
            assert_eq!(request.client_id, "test-client-id");
            assert_eq!(request.client_secret, "test-client-secret");
            assert_eq!(request.refresh_token, "test-refresh-token");
            assert_eq!(request.grant_type, RefreshGrantType::RefreshToken);
            assert_eq!(
                request.scopes,
                response_body["scope"].as_str().map(|s| s.to_owned())
            );

            (response_code, response_body.to_string())
        }
    }

    // End-to-end refresh: the token fields are copied from the response, and
    // `expires_at` is at least `expires_in` seconds after the request started.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn token_provider_full() -> TestResult {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };
        let response_body = serde_json::to_value(&response).unwrap();
        let (endpoint, _server) =
            start(StatusCode::OK, response_body, Arc::new(Mutex::new(0))).await;
        println!("endpoint = {endpoint}");

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": endpoint,
        });
        let cred = Builder::new(authorized_user)
            .with_scopes(vec!["scope1", "scope2"])
            .build()?;

        let now = std::time::Instant::now();
        let token = cred.token().await?;
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d >= now + Duration::from_secs(3600)),
            "now: {:?}, expires_at: {:?}",
            now,
            token.expires_at
        );

        Ok(())
    }

    // The Authorization header uses the server-reported token type as its
    // scheme, and the builder's quota project appears as a header.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn token_provider_full_with_quota_project() -> TestResult {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: None,
            token_type: "test-token-type".to_string(),
        };
        let response_body = serde_json::to_value(&response).unwrap();
        let (endpoint, _server) =
            start(StatusCode::OK, response_body, Arc::new(Mutex::new(0))).await;
        println!("endpoint = {endpoint}");

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": endpoint,
        });
        let cred = Builder::new(authorized_user)
            .with_quota_project_id("test-project")
            .build()?;

        let headers: Vec<HV> = HV::from(cred.headers().await.unwrap());
        assert_eq!(
            headers,
            vec![
                HV {
                    header: AUTHORIZATION.to_string(),
                    value: "test-token-type test-access-token".to_string(),
                    is_sensitive: true,
                },
                HV {
                    header: QUOTA_PROJECT_KEY.to_string(),
                    value: "test-project".to_string(),
                    is_sensitive: false,
                }
            ]
        );
        Ok(())
    }

    // Two token requests hit the server once: the second is served from the
    // token cache while the first token is still valid.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn creds_from_json_custom_uri_with_caching() -> TestResult {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };
        let response_body = serde_json::to_value(&response).unwrap();
        let call_count = Arc::new(Mutex::new(0));
        let (endpoint, _server) = start(StatusCode::OK, response_body, call_count.clone()).await;
        println!("endpoint = {endpoint}");

        let json = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "universe_domain": "googleapis.com",
            "quota_project_id": "test-project",
            "token_uri": endpoint,
        });

        let cred = Builder::new(json)
            .with_scopes(vec!["scope1", "scope2"])
            .build()?;

        let token = cred.token().await?;
        assert_eq!(token.token, "test-access-token");

        let token = cred.token().await?;
        assert_eq!(token.token, "test-access-token");

        assert_eq!(*call_count.lock().unwrap(), 1);

        Ok(())
    }

    // Without `expires_in` in the response, the token has no expiry.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn token_provider_partial() -> TestResult {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            expires_in: None,
            refresh_token: None,
            scope: None,
            token_type: "test-token-type".to_string(),
        };
        let response_body = serde_json::to_value(&response).unwrap();
        let (endpoint, _server) =
            start(StatusCode::OK, response_body, Arc::new(Mutex::new(0))).await;
        println!("endpoint = {endpoint}");

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": endpoint});

        let uc = Builder::new(authorized_user).build()?;
        let token = uc.token().await?;
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert_eq!(token.expires_at, None);

        Ok(())
    }

    // The JSON `token_uri` is a bogus value; the builder override must win
    // for the request to reach the local server.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn token_provider_with_token_uri() -> TestResult {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            expires_in: None,
            refresh_token: None,
            scope: None,
            token_type: "test-token-type".to_string(),
        };
        let response_body = serde_json::to_value(&response).unwrap();
        let (endpoint, _server) =
            start(StatusCode::OK, response_body, Arc::new(Mutex::new(0))).await;
        println!("endpoint = {endpoint}");

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": "test-endpoint"});

        let uc = Builder::new(authorized_user)
            .with_token_uri(endpoint)
            .build()?;
        let token = uc.token().await?;
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert_eq!(token.expires_at, None);

        Ok(())
    }

    // The handler asserts that the space-joined scopes reach the server.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn token_provider_with_scopes() -> TestResult {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            expires_in: None,
            refresh_token: None,
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };
        let response_body = serde_json::to_value(&response).unwrap();
        let (endpoint, _server) =
            start(StatusCode::OK, response_body, Arc::new(Mutex::new(0))).await;
        println!("endpoint = {endpoint}");

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": "test-endpoint"});

        let uc = Builder::new(authorized_user)
            .with_token_uri(endpoint)
            .with_scopes(vec!["scope1", "scope2"])
            .build()?;
        let token = uc.token().await?;
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert_eq!(token.expires_at, None);

        Ok(())
    }

    // A 503 is reported as retryable, and the error carries the response body.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn token_provider_retryable_error() -> TestResult {
        let (endpoint, _server) = start(
            StatusCode::SERVICE_UNAVAILABLE,
            serde_json::to_value("try again".to_string())?,
            Arc::new(Mutex::new(0)),
        )
        .await;
        println!("endpoint = {endpoint}");

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": endpoint});

        let uc = Builder::new(authorized_user).build()?;
        let e = uc.token().await.err().unwrap();
        assert!(e.is_retryable(), "{e}");
        assert!(e.source().unwrap().to_string().contains("try again"), "{e}");

        Ok(())
    }

    // A 401 is reported as non-retryable, and the error carries the response body.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn token_provider_nonretryable_error() -> TestResult {
        let (endpoint, _server) = start(
            StatusCode::UNAUTHORIZED,
            serde_json::to_value("epic fail".to_string())?,
            Arc::new(Mutex::new(0)),
        )
        .await;
        println!("endpoint = {endpoint}");

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": endpoint});

        let uc = Builder::new(authorized_user).build()?;
        let e = uc.token().await.err().unwrap();
        assert!(!e.is_retryable(), "{e}");
        assert!(e.source().unwrap().to_string().contains("epic fail"), "{e}");

        Ok(())
    }

    // A 200 whose body is a JSON string (not an object) fails to decode,
    // which is non-retryable.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn token_provider_malformed_response_is_nonretryable() -> TestResult {
        let (endpoint, _server) = start(
            StatusCode::OK,
            serde_json::to_value("bad json".to_string())?,
            Arc::new(Mutex::new(0)),
        )
        .await;
        println!("endpoint = {endpoint}");

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": endpoint});

        let uc = Builder::new(authorized_user).build()?;
        let e = uc.token().await.err().unwrap();
        assert!(!e.is_retryable(), "{e}");

        Ok(())
    }

    // Missing `client_id` makes `build` fail with a non-retryable error.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn builder_malformed_authorized_json_nonretryable() -> TestResult {
        let authorized_user = serde_json::json!({
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
        });

        let e = Builder::new(authorized_user).build().unwrap_err();
        assert!(!e.is_retryable(), "{e}");

        Ok(())
    }
}