1use crate::build_errors::Error as BuilderError;
70use crate::credentials::dynamic::CredentialsProvider;
71use crate::credentials::{CacheableResource, Credentials};
72use crate::errors::{self, CredentialsError};
73use crate::headers_util::build_cacheable_headers;
74use crate::token::{CachedTokenProvider, Token, TokenProvider};
75use crate::token_cache::TokenCache;
76use crate::{BuildResult, Result};
77use http::header::CONTENT_TYPE;
78use http::{Extensions, HeaderMap, HeaderValue};
79use reqwest::{Client, Method};
80use serde_json::Value;
81use std::sync::Arc;
82use tokio::time::{Duration, Instant};
83
/// Default token endpoint, used by [`Builder::build`] when neither the builder
/// nor the `authorized_user` JSON provides a `token_uri`.
const OAUTH2_ENDPOINT: &str = "https://oauth2.googleapis.com/token";
85
/// Builds user-account [Credentials] from an `authorized_user` JSON object.
///
/// Settings applied through the builder (`with_token_uri`,
/// `with_quota_project_id`) take precedence over the corresponding values in
/// the JSON when [`Builder::build`] is called.
pub struct Builder {
    // Raw JSON; deserialized into `AuthorizedUser` only in `build()`.
    authorized_user: Value,
    // Scopes requested on refresh; joined with spaces in `build()`.
    scopes: Option<Vec<String>>,
    // Overrides the JSON `quota_project_id` when set.
    quota_project_id: Option<String>,
    // Overrides the JSON `token_uri` when set.
    token_uri: Option<String>,
}
102
103impl Builder {
104 pub fn new(authorized_user: Value) -> Self {
111 Self {
112 authorized_user,
113 scopes: None,
114 quota_project_id: None,
115 token_uri: None,
116 }
117 }
118
119 pub fn with_token_uri<S: Into<String>>(mut self, token_uri: S) -> Self {
133 self.token_uri = Some(token_uri.into());
134 self
135 }
136
137 pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
158 where
159 I: IntoIterator<Item = S>,
160 S: Into<String>,
161 {
162 self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
163 self
164 }
165
166 pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
187 self.quota_project_id = Some(quota_project_id.into());
188 self
189 }
190
191 pub fn build(self) -> BuildResult<Credentials> {
204 let authorized_user = serde_json::from_value::<AuthorizedUser>(self.authorized_user)
205 .map_err(BuilderError::parsing)?;
206 let endpoint = self
207 .token_uri
208 .or(authorized_user.token_uri)
209 .unwrap_or(OAUTH2_ENDPOINT.to_string());
210 let quota_project_id = self.quota_project_id.or(authorized_user.quota_project_id);
211
212 let token_provider = UserTokenProvider {
213 client_id: authorized_user.client_id,
214 client_secret: authorized_user.client_secret,
215 refresh_token: authorized_user.refresh_token,
216 endpoint,
217 scopes: self.scopes.map(|scopes| scopes.join(" ")),
218 };
219 let token_provider = TokenCache::new(token_provider);
220
221 Ok(Credentials {
222 inner: Arc::new(UserCredentials {
223 token_provider,
224 quota_project_id,
225 }),
226 })
227 }
228}
229
/// Token provider that exchanges a refresh token for an access token at an
/// OAuth 2.0 token endpoint.
#[derive(PartialEq)]
struct UserTokenProvider {
    client_id: String,
    // Secret: redacted by the `Debug` implementation.
    client_secret: String,
    // Secret: redacted by the `Debug` implementation.
    refresh_token: String,
    // URL the refresh request is POSTed to.
    endpoint: String,
    // Space-separated scopes (see `Builder::build`), if any were requested.
    scopes: Option<String>,
}
238
239impl std::fmt::Debug for UserTokenProvider {
240 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
241 f.debug_struct("UserCredentials")
242 .field("client_id", &self.client_id)
243 .field("client_secret", &"[censored]")
244 .field("refresh_token", &"[censored]")
245 .field("endpoint", &self.endpoint)
246 .field("scopes", &self.scopes)
247 .finish()
248 }
249}
250
251#[async_trait::async_trait]
252impl TokenProvider for UserTokenProvider {
253 async fn token(&self) -> Result<Token> {
254 let client = Client::new();
255
256 let req = Oauth2RefreshRequest {
258 grant_type: RefreshGrantType::RefreshToken,
259 client_id: self.client_id.clone(),
260 client_secret: self.client_secret.clone(),
261 refresh_token: self.refresh_token.clone(),
262 scopes: self.scopes.clone(),
263 };
264 let header = HeaderValue::from_static("application/json");
265 let builder = client
266 .request(Method::POST, self.endpoint.as_str())
267 .header(CONTENT_TYPE, header)
268 .json(&req);
269 let resp = builder
270 .send()
271 .await
272 .map_err(|e| errors::from_http_error(e, MSG))?;
273
274 if !resp.status().is_success() {
276 let err = errors::from_http_response(resp, MSG).await;
277 return Err(err);
278 }
279 let response = resp.json::<Oauth2RefreshResponse>().await.map_err(|e| {
280 let retryable = !e.is_decode();
281 CredentialsError::from_source(retryable, e)
282 })?;
283 let token = Token {
284 token: response.access_token,
285 token_type: response.token_type,
286 expires_at: response
287 .expires_in
288 .map(|d| Instant::now() + Duration::from_secs(d)),
289 metadata: None,
290 };
291 Ok(token)
292 }
293}
294
/// Message attached to the errors returned when refreshing the access token fails.
const MSG: &str = "failed to refresh user access token";
296
/// Credentials for an authorized user, as produced by [`Builder::build`].
///
/// Pairs a cached token provider with an optional quota project id; the
/// [`CredentialsProvider`] implementation combines both into request headers.
#[derive(Debug)]
pub(crate) struct UserCredentials<T>
where
    T: CachedTokenProvider,
{
    token_provider: T,
    quota_project_id: Option<String>,
}
308
309#[async_trait::async_trait]
310impl<T> CredentialsProvider for UserCredentials<T>
311where
312 T: CachedTokenProvider,
313{
314 async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
315 let token = self.token_provider.token(extensions).await?;
316 build_cacheable_headers(&token, &self.quota_project_id)
317 }
318}
319
320#[derive(Debug, PartialEq, serde::Deserialize)]
321pub(crate) struct AuthorizedUser {
322 #[serde(rename = "type")]
323 cred_type: String,
324 client_id: String,
325 client_secret: String,
326 refresh_token: String,
327 #[serde(skip_serializing_if = "Option::is_none")]
328 token_uri: Option<String>,
329 #[serde(skip_serializing_if = "Option::is_none")]
330 quota_project_id: Option<String>,
331}
332
333#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
334enum RefreshGrantType {
335 #[serde(rename = "refresh_token")]
336 RefreshToken,
337}
338
339#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
340struct Oauth2RefreshRequest {
341 grant_type: RefreshGrantType,
342 client_id: String,
343 client_secret: String,
344 refresh_token: String,
345 scopes: Option<String>,
346}
347
348#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
349struct Oauth2RefreshResponse {
350 access_token: String,
351 #[serde(skip_serializing_if = "Option::is_none")]
352 scope: Option<String>,
353 #[serde(skip_serializing_if = "Option::is_none")]
354 expires_in: Option<u64>,
355 token_type: String,
356 #[serde(skip_serializing_if = "Option::is_none")]
357 refresh_token: Option<String>,
358}
359
#[cfg(test)]
mod test {
    use super::*;
    use crate::credentials::test::{
        get_headers_from_cache, get_token_from_headers, get_token_type_from_headers,
    };
    use crate::credentials::{DEFAULT_UNIVERSE_DOMAIN, QUOTA_PROJECT_KEY};
    use crate::token::test::MockTokenProvider;
    use axum::extract::Json;
    use http::StatusCode;
    use http::header::AUTHORIZATION;
    use std::error::Error;
    use std::sync::Mutex;
    use tokio::task::JoinHandle;

    type TestResult = std::result::Result<(), Box<dyn std::error::Error>>;

    // The `Debug` output shows the non-secret fields and hides the secrets.
    #[test]
    fn debug_token_provider() {
        let expected = UserTokenProvider {
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            endpoint: OAUTH2_ENDPOINT.to_string(),
            scopes: Some("https://www.googleapis.com/auth/pubsub".to_string()),
        };
        let fmt = format!("{expected:?}");
        assert!(fmt.contains("test-client-id"), "{fmt}");
        assert!(!fmt.contains("test-client-secret"), "{fmt}");
        assert!(!fmt.contains("test-refresh-token"), "{fmt}");
        assert!(fmt.contains(OAUTH2_ENDPOINT), "{fmt}");
        assert!(
            fmt.contains("https://www.googleapis.com/auth/pubsub"),
            "{fmt}"
        );
    }

    // Every known field is parsed; extra fields such as `account` and
    // `universe_domain` are ignored.
    #[test]
    fn authorized_user_full_from_json_success() {
        let json = serde_json::json!({
            "account": "",
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "universe_domain": "googleapis.com",
            "quota_project_id": "test-project",
            "token_uri" : "test-token-uri",
        });

        let expected = AuthorizedUser {
            cred_type: "authorized_user".to_string(),
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            quota_project_id: Some("test-project".to_string()),
            token_uri: Some("test-token-uri".to_string()),
        };
        let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
        assert_eq!(actual, expected);
    }

    // Optional fields default to `None` when absent.
    #[test]
    fn authorized_user_partial_from_json_success() {
        let json = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
        });

        let expected = AuthorizedUser {
            cred_type: "authorized_user".to_string(),
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            quota_project_id: None,
            token_uri: None,
        };
        let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
        assert_eq!(actual, expected);
    }

    // Parsing fails when any required field is missing.
    #[test]
    fn authorized_user_from_json_parse_fail() {
        let json_full = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "quota_project_id": "test-project"
        });

        for required_field in ["client_id", "client_secret", "refresh_token"] {
            let mut json = json_full.clone();
            // Remove the key itself; `Value::take` would only replace the value
            // with `null`, which tests a different failure mode.
            json.as_object_mut().unwrap().remove(required_field);
            let result = serde_json::from_value::<AuthorizedUser>(json);
            assert!(
                result.is_err(),
                "expected an error when `{required_field}` is missing"
            );
        }
    }

    #[tokio::test]
    async fn default_universe_domain_success() {
        let mock = TokenCache::new(MockTokenProvider::new());

        let uc = UserCredentials {
            token_provider: mock,
            quota_project_id: None,
        };
        assert_eq!(uc.universe_domain().await.unwrap(), DEFAULT_UNIVERSE_DOMAIN);
    }

    // The first call yields new headers; replaying its entity tag yields
    // `NotModified`.
    #[tokio::test]
    async fn headers_success() -> TestResult {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let uc = UserCredentials {
            token_provider: TokenCache::new(mock),
            quota_project_id: None,
        };

        let mut extensions = Extensions::new();
        let cached_headers = uc.headers(extensions.clone()).await.unwrap();
        let (headers, entity_tag) = match cached_headers {
            CacheableResource::New { entity_tag, data } => (data, entity_tag),
            CacheableResource::NotModified => unreachable!("expecting new headers"),
        };
        let token = headers.get(AUTHORIZATION).unwrap();

        assert_eq!(headers.len(), 1, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());

        extensions.insert(entity_tag);

        let cached_headers = uc.headers(extensions).await?;
        assert!(
            matches!(cached_headers, CacheableResource::NotModified),
            "expecting not-modified headers"
        );
        Ok(())
    }

    #[tokio::test]
    async fn headers_failure() {
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Err(errors::non_retryable_from_str("fail")));

        let uc = UserCredentials {
            token_provider: TokenCache::new(mock),
            quota_project_id: None,
        };
        assert!(uc.headers(Extensions::new()).await.is_err());
    }

    // A configured quota project adds a second, non-sensitive header.
    #[tokio::test]
    async fn headers_with_quota_project_success() -> TestResult {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let uc = UserCredentials {
            token_provider: TokenCache::new(mock),
            quota_project_id: Some("test-project".to_string()),
        };

        let headers = get_headers_from_cache(uc.headers(Extensions::new()).await.unwrap())?;
        let token = headers.get(AUTHORIZATION).unwrap();
        let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());
        assert_eq!(
            quota_project_header,
            HeaderValue::from_static("test-project")
        );
        assert!(!quota_project_header.is_sensitive());
        Ok(())
    }

    #[test]
    fn oauth2_request_serde() {
        let request = Oauth2RefreshRequest {
            grant_type: RefreshGrantType::RefreshToken,
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            scopes: Some("scope1 scope2".to_string()),
        };

        let json = serde_json::to_value(&request).unwrap();
        let expected = serde_json::json!({
            "grant_type": "refresh_token",
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "scopes": "scope1 scope2",
        });
        assert_eq!(json, expected);
        let roundtrip = serde_json::from_value::<Oauth2RefreshRequest>(json).unwrap();
        assert_eq!(request, roundtrip);
    }

    #[test]
    fn oauth2_response_serde_full() {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            scope: Some("scope1 scope2".to_string()),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
            refresh_token: Some("test-refresh-token".to_string()),
        };

        let json = serde_json::to_value(&response).unwrap();
        let expected = serde_json::json!({
            "access_token": "test-access-token",
            "scope": "scope1 scope2",
            "expires_in": 3600,
            "token_type": "test-token-type",
            "refresh_token": "test-refresh-token"
        });
        assert_eq!(json, expected);
        let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
        assert_eq!(response, roundtrip);
    }

    // `None` fields are omitted from the serialized JSON.
    #[test]
    fn oauth2_response_serde_partial() {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            scope: None,
            expires_in: None,
            token_type: "test-token-type".to_string(),
            refresh_token: None,
        };

        let json = serde_json::to_value(&response).unwrap();
        let expected = serde_json::json!({
            "access_token": "test-access-token",
            "token_type": "test-token-type",
        });
        assert_eq!(json, expected);
        let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
        assert_eq!(response, roundtrip);
    }

    // Starts a local axum server on an ephemeral port that answers
    // `POST /token` via `handle_token_factory`. Returns the token URL and the
    // server task handle.
    async fn start(
        response_code: StatusCode,
        response_body: Value,
        call_count: Arc<Mutex<i32>>,
    ) -> (String, JoinHandle<()>) {
        let handler = move |req| async move {
            handle_token_factory(response_code, response_body, call_count)(req)
        };
        let app = axum::Router::new().route("/token", axum::routing::post(handler));
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async {
            axum::serve(listener, app).await.unwrap();
        });

        (
            format!("http://{}:{}/token", addr.ip(), addr.port()),
            server,
        )
    }

    // Returns a handler that:
    // - counts its invocations;
    // - checks the refresh request fields (the expected `scopes` is taken from
    //   the response body's `scope` field);
    // - replies with `response_code` and `response_body`.
    fn handle_token_factory(
        response_code: StatusCode,
        response_body: Value,
        call_count: Arc<std::sync::Mutex<i32>>,
    ) -> impl Fn(Json<Oauth2RefreshRequest>) -> (StatusCode, String) {
        move |request: Json<Oauth2RefreshRequest>| -> (StatusCode, String) {
            // Release the lock before asserting so a failed assertion does not
            // poison the mutex the test reads afterwards.
            *call_count.lock().unwrap() += 1;
            assert_eq!(request.client_id, "test-client-id");
            assert_eq!(request.client_secret, "test-client-secret");
            assert_eq!(request.refresh_token, "test-refresh-token");
            assert_eq!(request.grant_type, RefreshGrantType::RefreshToken);
            assert_eq!(
                request.scopes,
                response_body["scope"].as_str().map(|s| s.to_owned())
            );

            (response_code, response_body.to_string())
        }
    }

    // With paused time, the expiry is exactly `now + expires_in`.
    #[tokio::test(start_paused = true)]
    async fn token_provider_full() -> TestResult {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };
        let response_body = serde_json::to_value(&response).unwrap();
        let (endpoint, _server) =
            start(StatusCode::OK, response_body, Arc::new(Mutex::new(0))).await;
        println!("endpoint = {endpoint}");

        let tp = UserTokenProvider {
            client_id: "test-client-id".to_string(),
            client_secret: "test-client-secret".to_string(),
            refresh_token: "test-refresh-token".to_string(),
            endpoint,
            scopes: Some("scope1 scope2".to_string()),
        };
        let now = Instant::now();
        let token = tp.token().await?;
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d == now + Duration::from_secs(3600)),
            "now: {:?}, expires_at: {:?}",
            now,
            token.expires_at
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_full_with_quota_project() -> TestResult {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: None,
            token_type: "test-token-type".to_string(),
        };
        let response_body = serde_json::to_value(&response).unwrap();
        let (endpoint, _server) =
            start(StatusCode::OK, response_body, Arc::new(Mutex::new(0))).await;
        println!("endpoint = {endpoint}");

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": endpoint,
        });
        let cred = Builder::new(authorized_user)
            .with_quota_project_id("test-project")
            .build()?;

        let headers = get_headers_from_cache(cred.headers(Extensions::new()).await.unwrap())?;
        let token = headers.get(AUTHORIZATION).unwrap();
        let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(
            token,
            HeaderValue::from_static("test-token-type test-access-token")
        );
        assert!(token.is_sensitive());
        assert_eq!(
            quota_project_header,
            HeaderValue::from_static("test-project")
        );
        assert!(!quota_project_header.is_sensitive());

        Ok(())
    }

    // Two `headers()` calls hit the token endpoint only once thanks to the
    // token cache.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn creds_from_json_custom_uri_with_caching() -> TestResult {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            refresh_token: Some("test-refresh-token".to_string()),
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };
        let response_body = serde_json::to_value(&response).unwrap();
        let call_count = Arc::new(Mutex::new(0));
        let (endpoint, _server) = start(StatusCode::OK, response_body, call_count.clone()).await;
        println!("endpoint = {endpoint}");

        let json = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "universe_domain": "googleapis.com",
            "quota_project_id": "test-project",
            "token_uri": endpoint,
        });

        let cred = Builder::new(json)
            .with_scopes(vec!["scope1", "scope2"])
            .build()?;

        let token = get_token_from_headers(cred.headers(Extensions::new()).await?);
        assert_eq!(token.unwrap(), "test-access-token");

        let token = get_token_from_headers(cred.headers(Extensions::new()).await?);
        assert_eq!(token.unwrap(), "test-access-token");

        assert_eq!(*call_count.lock().unwrap(), 1);

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_partial() -> TestResult {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            expires_in: None,
            refresh_token: None,
            scope: None,
            token_type: "test-token-type".to_string(),
        };
        let response_body = serde_json::to_value(&response).unwrap();
        let (endpoint, _server) =
            start(StatusCode::OK, response_body, Arc::new(Mutex::new(0))).await;
        println!("endpoint = {endpoint}");

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": endpoint});

        let uc = Builder::new(authorized_user).build()?;
        let headers = uc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    // `with_token_uri` overrides the `token_uri` from the JSON.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_with_token_uri() -> TestResult {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            expires_in: None,
            refresh_token: None,
            scope: None,
            token_type: "test-token-type".to_string(),
        };
        let response_body = serde_json::to_value(&response).unwrap();
        let (endpoint, _server) =
            start(StatusCode::OK, response_body, Arc::new(Mutex::new(0))).await;
        println!("endpoint = {endpoint}");

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": "test-endpoint"});

        let uc = Builder::new(authorized_user)
            .with_token_uri(endpoint)
            .build()?;
        let headers = uc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    // Scopes set on the builder reach the token endpoint (checked by the
    // handler against the response's `scope`).
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_with_scopes() -> TestResult {
        let response = Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            expires_in: None,
            refresh_token: None,
            scope: Some("scope1 scope2".to_string()),
            token_type: "test-token-type".to_string(),
        };
        let response_body = serde_json::to_value(&response).unwrap();
        let (endpoint, _server) =
            start(StatusCode::OK, response_body, Arc::new(Mutex::new(0))).await;
        println!("endpoint = {endpoint}");

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": "test-endpoint"});

        let uc = Builder::new(authorized_user)
            .with_token_uri(endpoint)
            .with_scopes(vec!["scope1", "scope2"])
            .build()?;
        let headers = uc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    // A 503 from the token endpoint surfaces as a transient error.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn credential_provider_retryable_error() -> TestResult {
        let (endpoint, _server) = start(
            StatusCode::SERVICE_UNAVAILABLE,
            serde_json::to_value("try again".to_string())?,
            Arc::new(Mutex::new(0)),
        )
        .await;
        println!("endpoint = {endpoint}");

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": endpoint});

        let uc = Builder::new(authorized_user).build()?;
        let err = uc.headers(Extensions::new()).await.unwrap_err();
        assert!(err.is_transient(), "{err}");
        assert!(err.to_string().contains("try again"), "{err:?}");
        let source = err
            .source()
            .and_then(|e| e.downcast_ref::<reqwest::Error>());
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
            "{err:?}"
        );

        Ok(())
    }

    // A 401 from the token endpoint surfaces as a permanent error.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn token_provider_nonretryable_error() -> TestResult {
        let (endpoint, _server) = start(
            StatusCode::UNAUTHORIZED,
            serde_json::to_value("epic fail".to_string())?,
            Arc::new(Mutex::new(0)),
        )
        .await;
        println!("endpoint = {endpoint}");

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": endpoint});

        let uc = Builder::new(authorized_user).build()?;
        let err = uc.headers(Extensions::new()).await.unwrap_err();
        assert!(!err.is_transient(), "{err:?}");
        assert!(err.to_string().contains("epic fail"), "{err:?}");
        let source = err
            .source()
            .and_then(|e| e.downcast_ref::<reqwest::Error>());
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
            "{err:?}"
        );

        Ok(())
    }

    // A 200 response whose body is not a token response is not retried.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn token_provider_malformed_response_is_nonretryable() -> TestResult {
        let (endpoint, _server) = start(
            StatusCode::OK,
            serde_json::to_value("bad json".to_string())?,
            Arc::new(Mutex::new(0)),
        )
        .await;
        println!("endpoint = {endpoint}");

        let authorized_user = serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": endpoint
        });

        let uc = Builder::new(authorized_user).build()?;
        let e = uc.headers(Extensions::new()).await.err().unwrap();
        assert!(!e.is_transient(), "{e}");

        Ok(())
    }

    // A JSON object missing `client_id` fails in `build()` with a parsing error.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn builder_malformed_authorized_json_nonretryable() -> TestResult {
        let authorized_user = serde_json::json!({
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
        });

        let e = Builder::new(authorized_user).build().unwrap_err();
        assert!(e.is_parsing(), "{e}");

        Ok(())
    }
}