1use crate::credentials::dynamic::CredentialsTrait;
56use crate::credentials::{Credentials, DEFAULT_UNIVERSE_DOMAIN, Result};
57use crate::errors::{self, CredentialsError, is_retryable};
58use crate::headers_util::build_bearer_headers;
59use crate::token::{Token, TokenProvider};
60use async_trait::async_trait;
61use bon::Builder;
62use http::header::{HeaderName, HeaderValue};
63use reqwest::Client;
64use std::default::Default;
65use std::sync::Arc;
66use std::time::Duration;
67
/// Value sent in the `metadata-flavor` header on every metadata server request.
const METADATA_FLAVOR_VALUE: &str = "Google";
/// Name of the header attached to every metadata server request.
const METADATA_FLAVOR: &str = "metadata-flavor";
/// Base URL of the metadata server used when no endpoint override is configured.
const METADATA_ROOT: &str = "http://metadata.google.internal/";
71
72pub(crate) fn new() -> Credentials {
73 Builder::default().build()
74}
75
/// Credentials that obtain access tokens from a metadata server through a
/// [`TokenProvider`].
///
/// `T` is generic so that the token source can be swapped; the unit tests in
/// this file inject a mock provider.
#[derive(Debug)]
struct MDSCredentials<T>
where
    T: TokenProvider,
{
    /// Optional quota project passed along when the request headers are built.
    quota_project_id: Option<String>,
    /// Universe domain override; when `None` the default universe domain is reported.
    universe_domain: Option<String>,
    /// Source of the access tokens returned by these credentials.
    token_provider: T,
}
85
/// Builder for metadata-server backed [`Credentials`].
///
/// Every setting is optional; unset values fall back to defaults when the
/// credentials are built.
#[derive(Debug, Default)]
pub struct Builder {
    /// Base URL of the metadata server; `METADATA_ROOT` is used when unset.
    endpoint: Option<String>,
    /// Quota project forwarded to the credentials.
    quota_project_id: Option<String>,
    /// Scopes to request; when unset they are looked up from the metadata
    /// server at token-fetch time.
    scopes: Option<Vec<String>>,
    /// Universe domain forwarded to the credentials.
    universe_domain: Option<String>,
}
104
105impl Builder {
106 pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
110 self.endpoint = Some(endpoint.into());
111 self
112 }
113
114 pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
123 self.quota_project_id = Some(quota_project_id.into());
124 self
125 }
126
127 pub fn with_universe_domain<S: Into<String>>(mut self, universe_domain: S) -> Self {
134 self.universe_domain = Some(universe_domain.into());
135 self
136 }
137
138 pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
147 where
148 I: IntoIterator<Item = S>,
149 S: Into<String>,
150 {
151 self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
152 self
153 }
154
155 pub fn build(self) -> Credentials {
157 let endpoint = self.endpoint.clone().unwrap_or(METADATA_ROOT.to_string());
158
159 let token_provider = MDSAccessTokenProvider::builder()
160 .endpoint(endpoint)
161 .maybe_scopes(self.scopes)
162 .build();
163 let cached_token_provider = crate::token_cache::TokenCache::new(token_provider);
164
165 let mdsc = MDSCredentials {
166 quota_project_id: self.quota_project_id,
167 token_provider: cached_token_provider,
168 universe_domain: self.universe_domain,
169 };
170 Credentials {
171 inner: Arc::new(mdsc),
172 }
173 }
174}
175
176#[async_trait::async_trait]
177impl<T> CredentialsTrait for MDSCredentials<T>
178where
179 T: TokenProvider,
180{
181 async fn token(&self) -> Result<Token> {
182 self.token_provider.token().await
183 }
184
185 async fn headers(&self) -> Result<Vec<(HeaderName, HeaderValue)>> {
186 let token = self.token().await?;
187 build_bearer_headers(&token, &self.quota_project_id)
188 }
189
190 async fn universe_domain(&self) -> Option<String> {
191 if self.universe_domain.is_some() {
192 return self.universe_domain.clone();
193 }
194 return Some(DEFAULT_UNIVERSE_DOMAIN.to_string());
195 }
196}
197
/// JSON document returned by the metadata server for the default service
/// account when queried with `recursive=true`.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
struct ServiceAccountInfo {
    /// Email address of the service account.
    email: String,
    /// Scopes of the service account; used as the token scopes when the caller
    /// configured none.
    scopes: Option<Vec<String>>,
    /// Aliases of the service account, if the server reports any.
    aliases: Option<Vec<String>>,
}
204
/// JSON document returned by the metadata server's token endpoint.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
struct MDSTokenResponse {
    /// The access token itself.
    access_token: String,
    /// Token lifetime in seconds; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    expires_in: Option<u64>,
    /// Token type, copied verbatim into the resulting `Token`.
    token_type: String,
}
212
/// Token provider that fetches access tokens from a metadata server over HTTP.
#[derive(Debug, Clone, Default, Builder)]
struct MDSAccessTokenProvider {
    /// Scopes to request; when `None` they are read from the default service
    /// account's metadata before each token request.
    #[builder(into)]
    scopes: Option<Vec<String>>,
    /// Base URL of the metadata server that request paths are appended to.
    #[builder(into)]
    endpoint: String,
}
220
221impl MDSAccessTokenProvider {
222 async fn get_service_account_info(&self, client: &Client) -> Result<ServiceAccountInfo> {
223 let request = client
224 .get(format!(
225 "{}/computeMetadata/v1/instance/service-accounts/default/",
226 self.endpoint
227 ))
228 .query(&[("recursive", "true")])
229 .header(
230 METADATA_FLAVOR,
231 HeaderValue::from_static(METADATA_FLAVOR_VALUE),
232 );
233
234 let response = request.send().await.map_err(errors::retryable)?;
235
236 response
237 .json::<ServiceAccountInfo>()
238 .await
239 .map_err(errors::non_retryable)
240 }
241}
242
243#[async_trait]
244impl TokenProvider for MDSAccessTokenProvider {
245 async fn token(&self) -> Result<Token> {
246 let client = Client::new();
247 let scopes = match &self.scopes {
249 Some(s) => s.clone().join(","),
250 None => {
251 let service_account_info = self.get_service_account_info(&client).await?;
252 service_account_info.scopes.unwrap_or_default().join(",")
253 }
254 };
255
256 let request = client
257 .get(format!(
258 "{}/computeMetadata/v1/instance/service-accounts/default/token",
259 self.endpoint
260 ))
261 .query(&[("scopes", scopes)])
262 .header(
263 METADATA_FLAVOR,
264 HeaderValue::from_static(METADATA_FLAVOR_VALUE),
265 );
266
267 let response = request.send().await.map_err(errors::retryable)?;
268 if !response.status().is_success() {
270 let status = response.status();
271 let body = response
272 .text()
273 .await
274 .map_err(|e| CredentialsError::new(is_retryable(status), e))?;
275 return Err(CredentialsError::from_str(
276 is_retryable(status),
277 format!("Failed to fetch token. {body}"),
278 ));
279 }
280 let response = response.json::<MDSTokenResponse>().await.map_err(|e| {
281 let retryable = !e.is_decode();
282 CredentialsError::new(retryable, e)
283 })?;
284 let token = Token {
285 token: response.access_token,
286 token_type: response.token_type,
287 expires_at: response
288 .expires_in
289 .map(|d| std::time::Instant::now() + Duration::from_secs(d)),
290 metadata: None,
291 };
292 Ok(token)
293 }
294}
295
296#[cfg(test)]
297mod test {
298 use super::*;
299 use crate::credentials::QUOTA_PROJECT_KEY;
300 use crate::credentials::test::HV;
301 use crate::token::test::MockTokenProvider;
302 use axum::extract::Query;
303 use axum::response::IntoResponse;
304 use http::header::AUTHORIZATION;
305 use reqwest::StatusCode;
306 use reqwest::header::HeaderMap;
307 use serde::Deserialize;
308 use serde_json::Value;
309 use std::collections::HashMap;
310 use std::error::Error;
311 use std::sync::Mutex;
312 use tokio::task::JoinHandle;
313
314 type TestResult = std::result::Result<(), Box<dyn std::error::Error>>;
315 const MDS_TOKEN_URI: &str = "/computeMetadata/v1/instance/service-accounts/default/token";
316
    /// Query parameters the fake metadata server expects on a request; each
    /// handler asserts the incoming query equals the expected value.
    #[derive(Debug, Clone, Deserialize, PartialEq)]
    struct TokenQueryParams {
        /// Comma-separated scopes, as sent to the token endpoint.
        scopes: Option<String>,
        /// `recursive` flag, as sent to the service account info endpoint.
        recursive: Option<String>,
    }
323
324 #[tokio::test]
325 async fn token_success() {
326 let expected = Token {
327 token: "test-token".to_string(),
328 token_type: "Bearer".to_string(),
329 expires_at: None,
330 metadata: None,
331 };
332 let expected_clone = expected.clone();
333
334 let mut mock = MockTokenProvider::new();
335 mock.expect_token()
336 .times(1)
337 .return_once(|| Ok(expected_clone));
338
339 let mdsc = MDSCredentials {
340 quota_project_id: None,
341 universe_domain: None,
342 token_provider: mock,
343 };
344 let actual = mdsc.token().await.unwrap();
345 assert_eq!(actual, expected);
346 }
347
348 #[tokio::test]
349 async fn token_failure() {
350 let mut mock = MockTokenProvider::new();
351 mock.expect_token()
352 .times(1)
353 .return_once(|| Err(errors::non_retryable_from_str("fail")));
354
355 let mdsc = MDSCredentials {
356 quota_project_id: None,
357 universe_domain: None,
358 token_provider: mock,
359 };
360 assert!(mdsc.token().await.is_err());
361 }
362
363 #[tokio::test]
364 async fn headers_success() {
365 let token = Token {
366 token: "test-token".to_string(),
367 token_type: "Bearer".to_string(),
368 expires_at: None,
369 metadata: None,
370 };
371
372 let mut mock = MockTokenProvider::new();
373 mock.expect_token().times(1).return_once(|| Ok(token));
374
375 let mdsc = MDSCredentials {
376 quota_project_id: None,
377 universe_domain: None,
378 token_provider: mock,
379 };
380 let headers: Vec<HV> = HV::from(mdsc.headers().await.unwrap());
381
382 assert_eq!(
383 headers,
384 vec![HV {
385 header: AUTHORIZATION.to_string(),
386 value: "Bearer test-token".to_string(),
387 is_sensitive: true,
388 }]
389 );
390 }
391
392 #[tokio::test]
393 async fn headers_failure() {
394 let mut mock = MockTokenProvider::new();
395 mock.expect_token()
396 .times(1)
397 .return_once(|| Err(errors::non_retryable_from_str("fail")));
398
399 let mdsc = MDSCredentials {
400 quota_project_id: None,
401 universe_domain: None,
402 token_provider: mock,
403 };
404 assert!(mdsc.headers().await.is_err());
405 }
406
407 fn handle_token_factory(
408 response_code: StatusCode,
409 response_headers: HeaderMap,
410 response_body: Value,
411 ) -> impl IntoResponse {
412 (response_code, response_headers, response_body.to_string()).into_response()
413 }
414
    /// Maps a request path to its canned response: status code, JSON body, the
    /// query parameters the request must carry, and a shared call counter.
    type Handlers = HashMap<String, (StatusCode, Value, TokenQueryParams, Arc<Mutex<i32>>)>;
416
    /// Starts a local axum HTTP server that serves canned responses, returning
    /// its base URL together with the handle of the task running it.
    ///
    /// Each entry of `path_handlers` registers a GET route that asserts the
    /// incoming query equals the expected `TokenQueryParams`, increments the
    /// route's call counter, and replies with the configured status and body.
    async fn start(path_handlers: Handlers) -> (String, JoinHandle<()>) {
        let mut app = axum::Router::new();

        for (path, (code, body, expected_query, call_count)) in path_handlers {
            let header_map = HeaderMap::new();
            let handler = move |Query(query): Query<TokenQueryParams>| {
                let body = body.clone();
                let header_map = header_map.clone();
                async move {
                    assert_eq!(expected_query, query);
                    let mut count = call_count.lock().unwrap();
                    *count += 1;
                    handle_token_factory(code, header_map, body)
                }
            };
            app = app.route(&path, axum::routing::get(handler));
        }

        // Bind to port 0 and read the actual address back from the listener.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });
        (format!("http://{}:{}", addr.ip(), addr.port()), server)
    }
444
445 #[tokio::test]
446 async fn get_default_service_account_info_success() {
447 let service_account = "default";
448 let path = format!(
449 "/computeMetadata/v1/instance/service-accounts/{}/",
450 service_account
451 );
452 let service_account_info = ServiceAccountInfo {
453 email: "test@test.com".to_string(),
454 scopes: Some(vec!["scope 1".to_string(), "scope 2".to_string()]),
455 aliases: None,
456 };
457 let service_account_info_json = serde_json::to_value(service_account_info.clone()).unwrap();
458 let (endpoint, _server) = start(Handlers::from([(
459 path,
460 (
461 StatusCode::OK,
462 service_account_info_json,
463 TokenQueryParams {
464 scopes: None,
465 recursive: Some("true".to_string()),
466 },
467 Arc::new(Mutex::new(0)),
468 ),
469 )]))
470 .await;
471
472 let request = Client::new();
473 let token_provider = MDSAccessTokenProvider::builder().endpoint(endpoint).build();
474
475 let result = token_provider.get_service_account_info(&request).await;
476
477 assert!(result.is_ok());
478 assert_eq!(result.unwrap(), service_account_info);
479 }
480
481 #[tokio::test]
482 async fn get_service_account_info_server_error() {
483 let path = "/computeMetadata/v1/instance/service-accounts/default/".to_string();
484 let (endpoint, _server) = start(Handlers::from([(
485 path,
486 (
487 StatusCode::SERVICE_UNAVAILABLE,
488 serde_json::to_value("try again").unwrap(),
489 TokenQueryParams {
490 scopes: None,
491 recursive: Some("true".to_string()),
492 },
493 Arc::new(Mutex::new(0)),
494 ),
495 )]))
496 .await;
497
498 let request = Client::new();
499 let token_provider = MDSAccessTokenProvider::builder().endpoint(endpoint).build();
500
501 let result = token_provider.get_service_account_info(&request).await;
502 assert!(result.is_err());
503 }
504
505 #[tokio::test]
506 async fn headers_success_with_quota_project() {
507 let scopes = ["scope1".to_string(), "scope2".to_string()];
508 let response = MDSTokenResponse {
509 access_token: "test-access-token".to_string(),
510 expires_in: Some(3600),
511 token_type: "test-token-type".to_string(),
512 };
513 let response_body = serde_json::to_value(&response).unwrap();
514
515 let (endpoint, _server) = start(Handlers::from([(
516 MDS_TOKEN_URI.to_string(),
517 (
518 StatusCode::OK,
519 response_body,
520 TokenQueryParams {
521 scopes: Some(scopes.join(",")),
522 recursive: None,
523 },
524 Arc::new(Mutex::new(0)),
525 ),
526 )]))
527 .await;
528
529 let mdsc = Builder::default()
530 .with_scopes(["scope1", "scope2"])
531 .with_endpoint(endpoint)
532 .with_quota_project_id("test-project")
533 .build();
534
535 let headers: Vec<HV> = HV::from(mdsc.headers().await.unwrap());
536 assert_eq!(
537 headers,
538 vec![
539 HV {
540 header: AUTHORIZATION.to_string(),
541 value: "test-token-type test-access-token".to_string(),
542 is_sensitive: true,
543 },
544 HV {
545 header: QUOTA_PROJECT_KEY.to_string(),
546 value: "test-project".to_string(),
547 is_sensitive: false,
548 }
549 ]
550 );
551 }
552
553 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
554 async fn token_caching() -> TestResult {
555 let scopes = vec!["scope1".to_string()];
556 let response = MDSTokenResponse {
557 access_token: "test-access-token".to_string(),
558 expires_in: Some(3600),
559 token_type: "test-token-type".to_string(),
560 };
561 let response_body = serde_json::to_value(&response).unwrap();
562
563 let call_count = Arc::new(Mutex::new(0));
564 let (endpoint, _server) = start(Handlers::from([(
565 MDS_TOKEN_URI.to_string(),
566 (
567 StatusCode::OK,
568 response_body,
569 TokenQueryParams {
570 scopes: Some(scopes.join(",")),
571 recursive: None,
572 },
573 call_count.clone(),
574 ),
575 )]))
576 .await;
577
578 let mdsc = Builder::default()
579 .with_scopes(scopes)
580 .with_endpoint(endpoint)
581 .build();
582 let token = mdsc.token().await?;
583 assert_eq!(token.token, "test-access-token");
584 let token = mdsc.token().await?;
585 assert_eq!(token.token, "test-access-token");
586
587 assert_eq!(*call_count.lock().unwrap(), 1);
589
590 Ok(())
591 }
592
593 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
594 async fn token_provider_full() -> TestResult {
595 let scopes = vec!["scope1".to_string()];
596 let response = MDSTokenResponse {
597 access_token: "test-access-token".to_string(),
598 expires_in: Some(3600),
599 token_type: "test-token-type".to_string(),
600 };
601 let response_body = serde_json::to_value(&response).unwrap();
602
603 let (endpoint, _server) = start(Handlers::from([(
604 MDS_TOKEN_URI.to_string(),
605 (
606 StatusCode::OK,
607 response_body,
608 TokenQueryParams {
609 scopes: Some(scopes.join(",")),
610 recursive: None,
611 },
612 Arc::new(Mutex::new(0)),
613 ),
614 )]))
615 .await;
616 println!("endpoint = {endpoint}");
617
618 let mdsc = Builder::default()
619 .with_scopes(scopes)
620 .with_endpoint(endpoint)
621 .build();
622 let now = std::time::Instant::now();
623 let token = mdsc.token().await?;
624 assert_eq!(token.token, "test-access-token");
625 assert_eq!(token.token_type, "test-token-type");
626 assert!(
627 token
628 .expires_at
629 .is_some_and(|d| d >= now + Duration::from_secs(3600))
630 );
631
632 Ok(())
633 }
634
635 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
636 async fn token_provider_full_no_scopes() -> TestResult {
637 let service_account_info_path =
638 "/computeMetadata/v1/instance/service-accounts/default/".to_string();
639 let scopes = vec!["scope 1".to_string(), "scope 2".to_string()];
640 let service_account_info = ServiceAccountInfo {
641 email: "test@test.com".to_string(),
642 scopes: Some(scopes.clone()),
643 aliases: None,
644 };
645 let service_account_info_json = serde_json::to_value(service_account_info.clone()).unwrap();
646
647 let response = MDSTokenResponse {
648 access_token: "test-access-token".to_string(),
649 expires_in: Some(3600),
650 token_type: "test-token-type".to_string(),
651 };
652 let response_body = serde_json::to_value(&response).unwrap();
653
654 let (endpoint, _server) = start(Handlers::from([
655 (
656 service_account_info_path,
657 (
658 StatusCode::OK,
659 service_account_info_json,
660 TokenQueryParams {
661 scopes: None,
662 recursive: Some("true".to_string()),
663 },
664 Arc::new(Mutex::new(0)),
665 ),
666 ),
667 (
668 MDS_TOKEN_URI.to_string(),
669 (
670 StatusCode::OK,
671 response_body,
672 TokenQueryParams {
673 scopes: Some(scopes.join(",")),
674 recursive: None,
675 },
676 Arc::new(Mutex::new(0)),
677 ),
678 ),
679 ]))
680 .await;
681 println!("endpoint = {endpoint}");
682
683 let mdsc = Builder::default().with_endpoint(endpoint).build();
684 let now = std::time::Instant::now();
685 let token = mdsc.token().await?;
686 assert_eq!(token.token, "test-access-token");
687 assert_eq!(token.token_type, "test-token-type");
688 assert!(
689 token
690 .expires_at
691 .is_some_and(|d| d >= now + Duration::from_secs(3600))
692 );
693
694 Ok(())
695 }
696
697 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
698 async fn token_provider_partial() -> TestResult {
699 let scopes = vec!["scope1".to_string()];
700 let response = MDSTokenResponse {
701 access_token: "test-access-token".to_string(),
702 expires_in: None,
703 token_type: "test-token-type".to_string(),
704 };
705 let response_body = serde_json::to_value(&response).unwrap();
706 let (endpoint, _server) = start(Handlers::from([(
707 MDS_TOKEN_URI.to_string(),
708 (
709 StatusCode::OK,
710 response_body,
711 TokenQueryParams {
712 scopes: Some(scopes.join(",")),
713 recursive: None,
714 },
715 Arc::new(Mutex::new(0)),
716 ),
717 )]))
718 .await;
719 println!("endpoint = {endpoint}");
720
721 let mdsc = Builder::default()
722 .with_endpoint(endpoint)
723 .with_scopes(scopes)
724 .build();
725 let token = mdsc.token().await?;
726 assert_eq!(token.token, "test-access-token");
727 assert_eq!(token.token_type, "test-token-type");
728 assert_eq!(token.expires_at, None);
729
730 Ok(())
731 }
732
733 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
734 async fn token_provider_retryable_error() -> TestResult {
735 let scopes = vec!["scope1".to_string()];
736 let (endpoint, _server) = start(Handlers::from([(
737 MDS_TOKEN_URI.to_string(),
738 (
739 StatusCode::SERVICE_UNAVAILABLE,
740 serde_json::to_value("try again")?,
741 TokenQueryParams {
742 scopes: Some(scopes.join(",")),
743 recursive: None,
744 },
745 Arc::new(Mutex::new(0)),
746 ),
747 )]))
748 .await;
749
750 let mdsc = Builder::default()
751 .with_endpoint(endpoint)
752 .with_scopes(scopes)
753 .build();
754 let e = mdsc.token().await.err().unwrap();
755 assert!(e.is_retryable());
756 assert!(e.source().unwrap().to_string().contains("try again"));
757
758 Ok(())
759 }
760
761 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
762 async fn token_provider_nonretryable_error() -> TestResult {
763 let scopes = vec!["scope1".to_string()];
764 let (endpoint, _server) = start(Handlers::from([(
765 MDS_TOKEN_URI.to_string(),
766 (
767 StatusCode::UNAUTHORIZED,
768 serde_json::to_value("epic fail".to_string())?,
769 TokenQueryParams {
770 scopes: Some(scopes.join(",")),
771 recursive: None,
772 },
773 Arc::new(Mutex::new(0)),
774 ),
775 )]))
776 .await;
777
778 let mdsc = Builder::default()
779 .with_endpoint(endpoint)
780 .with_scopes(scopes)
781 .build();
782
783 let e = mdsc.token().await.err().unwrap();
784 assert!(!e.is_retryable());
785 assert!(e.source().unwrap().to_string().contains("epic fail"));
786
787 Ok(())
788 }
789
790 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
791 async fn token_provider_malformed_response_is_nonretryable() -> TestResult {
792 let scopes = vec!["scope1".to_string()];
793 let (endpoint, _server) = start(Handlers::from([(
794 MDS_TOKEN_URI.to_string(),
795 (
796 StatusCode::OK,
797 serde_json::to_value("bad json".to_string())?,
798 TokenQueryParams {
799 scopes: Some(scopes.join(",")),
800 recursive: None,
801 },
802 Arc::new(Mutex::new(0)),
803 ),
804 )]))
805 .await;
806
807 let mdsc = Builder::default()
808 .with_endpoint(endpoint)
809 .with_scopes(scopes)
810 .build();
811
812 let e = mdsc.token().await.err().unwrap();
813 assert!(!e.is_retryable());
814
815 Ok(())
816 }
817
818 #[tokio::test]
819 async fn get_default_universe_domain_success() {
820 let universe_domain_response = Builder::default().build().universe_domain().await.unwrap();
821 assert_eq!(universe_domain_response, DEFAULT_UNIVERSE_DOMAIN);
822 }
823
824 #[tokio::test]
825 async fn get_custom_universe_domain_success() {
826 let universe_domain = "test-universe";
827 let universe_domain_response = Builder::default()
828 .with_universe_domain(universe_domain)
829 .build()
830 .universe_domain()
831 .await
832 .unwrap();
833 assert_eq!(universe_domain_response, universe_domain);
834 }
835}