1use crate::credentials::dynamic::CredentialsProvider;
61use crate::credentials::{Credentials, DEFAULT_UNIVERSE_DOMAIN, Result};
62use crate::errors::{self, CredentialsError, is_retryable};
63use crate::headers_util::build_bearer_headers;
64use crate::token::{CachedTokenProvider, Token, TokenProvider};
65use crate::token_cache::TokenCache;
66use async_trait::async_trait;
67use bon::Builder;
68use http::{Extensions, HeaderMap, HeaderValue};
69use reqwest::Client;
70use std::default::Default;
71use std::sync::Arc;
72use std::time::Duration;
73use tokio::time::Instant;
74
/// Root URL of the metadata server, used when no endpoint override is configured.
const METADATA_ROOT: &str = "http://metadata.google.internal";
/// Path of the default service account's metadata, appended to the server root.
const MDS_DEFAULT_URI: &str = "/computeMetadata/v1/instance/service-accounts/default";

/// Name of the header attached to every metadata-server request.
const METADATA_FLAVOR: &str = "metadata-flavor";
/// Value sent in the `metadata-flavor` header.
const METADATA_FLAVOR_VALUE: &str = "Google";

/// Environment variable holding a `host[:port]` for the metadata server.
/// When set, it overrides the endpoint configured on the builder.
const GCE_METADATA_HOST_ENV_VAR: &str = "GCE_METADATA_HOST";
80
81#[derive(Debug)]
82struct MDSCredentials<T>
83where
84 T: CachedTokenProvider,
85{
86 quota_project_id: Option<String>,
87 universe_domain: Option<String>,
88 token_provider: T,
89}
90
/// Configures and creates metadata-server backed [`Credentials`].
///
/// Every setting is optional; `Builder::default()` yields credentials that use the
/// default metadata server, no explicit scopes, no quota project, and the default
/// universe domain.
#[derive(Default, Debug)]
pub struct Builder {
    // Metadata server base URL (including scheme); falls back to `METADATA_ROOT`.
    endpoint: Option<String>,
    // Quota project forwarded to the header builder.
    quota_project_id: Option<String>,
    // Scopes requested from the token endpoint.
    scopes: Option<Vec<String>>,
    // Universe domain override.
    universe_domain: Option<String>,
}
109
110impl Builder {
111 pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
126 self.endpoint = Some(endpoint.into());
127 self
128 }
129
130 pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
139 self.quota_project_id = Some(quota_project_id.into());
140 self
141 }
142
143 pub fn with_universe_domain<S: Into<String>>(mut self, universe_domain: S) -> Self {
150 self.universe_domain = Some(universe_domain.into());
151 self
152 }
153
154 pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
163 where
164 I: IntoIterator<Item = S>,
165 S: Into<String>,
166 {
167 self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
168 self
169 }
170
171 fn build_token_provider(self) -> MDSAccessTokenProvider {
172 let endpoint = match std::env::var(GCE_METADATA_HOST_ENV_VAR) {
173 Ok(endpoint) => format!("http://{}", endpoint),
174 _ => self.endpoint.clone().unwrap_or(METADATA_ROOT.to_string()),
175 };
176 MDSAccessTokenProvider::builder()
177 .endpoint(endpoint)
178 .maybe_scopes(self.scopes)
179 .build()
180 }
181
182 pub fn build(self) -> Result<Credentials> {
184 let mdsc = MDSCredentials {
185 quota_project_id: self.quota_project_id.clone(),
186 universe_domain: self.universe_domain.clone(),
187 token_provider: TokenCache::new(self.build_token_provider()),
188 };
189 Ok(Credentials {
190 inner: Arc::new(mdsc),
191 })
192 }
193}
194
195#[async_trait::async_trait]
196impl<T> CredentialsProvider for MDSCredentials<T>
197where
198 T: CachedTokenProvider,
199{
200 async fn headers(&self, extensions: Extensions) -> Result<HeaderMap> {
201 let token = self.token_provider.token(extensions).await?;
202 build_bearer_headers(&token, &self.quota_project_id)
203 }
204
205 async fn universe_domain(&self) -> Option<String> {
206 if self.universe_domain.is_some() {
207 return self.universe_domain.clone();
208 }
209 return Some(DEFAULT_UNIVERSE_DOMAIN.to_string());
210 }
211}
212
213#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
214struct ServiceAccountInfo {
215 email: String,
216 scopes: Option<Vec<String>>,
217 aliases: Option<Vec<String>>,
218}
219
220#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
221struct MDSTokenResponse {
222 access_token: String,
223 #[serde(skip_serializing_if = "Option::is_none")]
224 expires_in: Option<u64>,
225 token_type: String,
226}
227
228#[derive(Debug, Clone, Default, Builder)]
229struct MDSAccessTokenProvider {
230 #[builder(into)]
231 scopes: Option<Vec<String>>,
232 #[builder(into)]
233 endpoint: String,
234}
235
236#[async_trait]
237impl TokenProvider for MDSAccessTokenProvider {
238 async fn token(&self) -> Result<Token> {
239 let client = Client::new();
240 let request = client
241 .get(format!("{}{}/token", self.endpoint, MDS_DEFAULT_URI))
242 .header(
243 METADATA_FLAVOR,
244 HeaderValue::from_static(METADATA_FLAVOR_VALUE),
245 );
246 let scopes = self.scopes.as_ref().map(|v| v.join(","));
249 let request = scopes
250 .into_iter()
251 .fold(request, |r, s| r.query(&[("scopes", s)]));
252
253 let response = request.send().await.map_err(errors::retryable)?;
254 if !response.status().is_success() {
256 let status = response.status();
257 let body = response
258 .text()
259 .await
260 .map_err(|e| CredentialsError::new(is_retryable(status), e))?;
261 return Err(CredentialsError::from_str(
262 is_retryable(status),
263 format!("Failed to fetch token. {body}"),
264 ));
265 }
266 let response = response.json::<MDSTokenResponse>().await.map_err(|e| {
267 let retryable = !e.is_decode();
268 CredentialsError::new(retryable, e)
269 })?;
270 let token = Token {
271 token: response.access_token,
272 token_type: response.token_type,
273 expires_at: response
274 .expires_in
275 .map(|d| Instant::now() + Duration::from_secs(d)),
276 metadata: None,
277 };
278 Ok(token)
279 }
280}
281
#[cfg(test)]
mod test {
    use super::*;
    use crate::credentials::QUOTA_PROJECT_KEY;
    use crate::credentials::test::{get_token_from_headers, get_token_type_from_headers};
    use crate::token::test::MockTokenProvider;
    use axum::extract::Query;
    use axum::response::IntoResponse;
    use http::header::AUTHORIZATION;
    use reqwest::StatusCode;
    use reqwest::header::HeaderMap;
    use scoped_env::ScopedEnv;
    use serde::Deserialize;
    use serde_json::Value;
    use serial_test::{parallel, serial};
    use std::collections::HashMap;
    use std::error::Error;
    use std::sync::Mutex;
    use tokio::task::JoinHandle;
    use url::Url;

    type TestResult = std::result::Result<(), Box<dyn std::error::Error>>;

    /// Query parameters the fake token endpoint expects to receive.
    #[derive(Debug, Clone, Deserialize, PartialEq)]
    struct TokenQueryParams {
        scopes: Option<String>,
        recursive: Option<String>,
    }

    /// JSON body of a successful token response with the given lifetime.
    fn token_response_body(expires_in: Option<u64>) -> Value {
        let payload = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in,
            token_type: "test-token-type".to_string(),
        };
        serde_json::to_value(&payload).unwrap()
    }

    /// A route table containing only the token endpoint, which expects `scopes` and no `recursive`.
    fn token_route(
        status: StatusCode,
        body: Value,
        scopes: Option<String>,
        call_count: Arc<Mutex<i32>>,
    ) -> Handlers {
        let expected = TokenQueryParams {
            scopes,
            recursive: None,
        };
        Handlers::from([(
            format!("{}/token", MDS_DEFAULT_URI),
            (status, body, expected, call_count),
        )])
    }

    #[test]
    fn validate_default_endpoint_urls() {
        let default_url = format!("{}{}", METADATA_ROOT, MDS_DEFAULT_URI);
        assert!(Url::parse(&default_url).is_ok());

        let token_url = format!("{default_url}/token");
        assert!(Url::parse(&token_url).is_ok());
    }

    #[tokio::test]
    async fn headers_success() {
        let mut provider = MockTokenProvider::new();
        provider.expect_token().times(1).return_once(|| {
            Ok(Token {
                token: "test-token".to_string(),
                token_type: "Bearer".to_string(),
                expires_at: None,
                metadata: None,
            })
        });

        let creds = MDSCredentials {
            quota_project_id: None,
            universe_domain: None,
            token_provider: TokenCache::new(provider),
        };
        let headers = creds.headers(Extensions::new()).await.unwrap();
        let authorization = headers.get(AUTHORIZATION).unwrap();

        assert_eq!(headers.len(), 1, "{headers:?}");
        assert_eq!(
            authorization,
            HeaderValue::from_static("Bearer test-token")
        );
        assert!(authorization.is_sensitive());
    }

    #[tokio::test]
    async fn headers_failure() {
        let mut provider = MockTokenProvider::new();
        provider
            .expect_token()
            .times(1)
            .return_once(|| Err(errors::non_retryable_from_str("fail")));

        let creds = MDSCredentials {
            quota_project_id: None,
            universe_domain: None,
            token_provider: TokenCache::new(provider),
        };
        let outcome = creds.headers(Extensions::new()).await;
        assert!(outcome.is_err());
    }

    fn handle_token_factory(
        response_code: StatusCode,
        response_headers: HeaderMap,
        response_body: Value,
    ) -> impl IntoResponse {
        (response_code, response_headers, response_body.to_string()).into_response()
    }

    type Handlers = HashMap<String, (StatusCode, Value, TokenQueryParams, Arc<Mutex<i32>>)>;

    /// Serves `path_handlers` on an ephemeral local port and returns the base URL.
    async fn start(path_handlers: Handlers) -> (String, JoinHandle<()>) {
        let mut router = axum::Router::new();

        for (route, (status, payload, expected_query, hits)) in path_handlers {
            let empty_headers = HeaderMap::new();
            let handler = move |Query(query): Query<TokenQueryParams>| {
                let payload = payload.clone();
                let empty_headers = empty_headers.clone();
                async move {
                    assert_eq!(expected_query, query);
                    *hits.lock().unwrap() += 1;
                    handle_token_factory(status, empty_headers, payload)
                }
            };
            router = router.route(&route, axum::routing::get(handler));
        }

        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local = listener.local_addr().unwrap();
        let server = tokio::spawn(async move {
            axum::serve(listener, router).await.unwrap();
        });
        (format!("http://{}:{}", local.ip(), local.port()), server)
    }

    #[tokio::test]
    #[serial]
    async fn test_gce_metadata_host_env_var() {
        let scopes = ["scope1".to_string(), "scope2".to_string()];
        let (endpoint, _server) = start(token_route(
            StatusCode::OK,
            token_response_body(Some(3600)),
            Some(scopes.join(",")),
            Arc::new(Mutex::new(0)),
        ))
        .await;

        let host = endpoint.strip_prefix("http://").unwrap_or(&endpoint);
        let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, host);
        let creds = Builder::default()
            .with_scopes(["scope1", "scope2"])
            .build()
            .unwrap();
        let headers = creds.headers(Extensions::new()).await.unwrap();
        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);

        assert_eq!(
            get_token_from_headers(&headers).unwrap(),
            "test-access-token"
        );
    }

    #[tokio::test]
    #[parallel]
    async fn headers_success_with_quota_project() -> TestResult {
        let scopes = ["scope1".to_string(), "scope2".to_string()];
        let (endpoint, _server) = start(token_route(
            StatusCode::OK,
            token_response_body(Some(3600)),
            Some(scopes.join(",")),
            Arc::new(Mutex::new(0)),
        ))
        .await;

        let creds = Builder::default()
            .with_scopes(["scope1", "scope2"])
            .with_endpoint(endpoint)
            .with_quota_project_id("test-project")
            .build()?;

        let headers = creds.headers(Extensions::new()).await.unwrap();
        let authorization = headers.get(AUTHORIZATION).unwrap();
        let quota_project = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(
            authorization,
            HeaderValue::from_static("test-token-type test-access-token")
        );
        assert!(authorization.is_sensitive());
        assert_eq!(quota_project, HeaderValue::from_static("test-project"));
        assert!(!quota_project.is_sensitive());
        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn token_caching() -> TestResult {
        let scopes = vec!["scope1".to_string()];
        let hits = Arc::new(Mutex::new(0));
        let (endpoint, _server) = start(token_route(
            StatusCode::OK,
            token_response_body(Some(3600)),
            Some(scopes.join(",")),
            hits.clone(),
        ))
        .await;

        let creds = Builder::default()
            .with_scopes(scopes)
            .with_endpoint(endpoint)
            .build()?;

        // Two header requests must be served by a single token fetch.
        for _ in 0..2 {
            let headers = creds.headers(Extensions::new()).await?;
            assert_eq!(
                get_token_from_headers(&headers).unwrap(),
                "test-access-token"
            );
        }

        assert_eq!(*hits.lock().unwrap(), 1);

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    #[parallel]
    async fn token_provider_full() -> TestResult {
        let scopes = vec!["scope1".to_string()];
        let (endpoint, _server) = start(token_route(
            StatusCode::OK,
            token_response_body(Some(3600)),
            Some(scopes.join(",")),
            Arc::new(Mutex::new(0)),
        ))
        .await;
        println!("endpoint = {endpoint}");

        let token = Builder::default()
            .with_endpoint(endpoint)
            .with_scopes(scopes)
            .build_token_provider()
            .token()
            .await?;

        let now = tokio::time::Instant::now();
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|at| at >= now + Duration::from_secs(3600))
        );

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    #[parallel]
    async fn token_provider_full_no_scopes() -> TestResult {
        let (endpoint, _server) = start(token_route(
            StatusCode::OK,
            token_response_body(Some(3600)),
            None,
            Arc::new(Mutex::new(0)),
        ))
        .await;
        println!("endpoint = {endpoint}");
        let token = Builder::default()
            .with_endpoint(endpoint)
            .build_token_provider()
            .token()
            .await?;

        let now = Instant::now();
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|at| at == now + Duration::from_secs(3600))
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credential_provider_full() -> TestResult {
        let scopes = vec!["scope1".to_string()];
        let (endpoint, _server) = start(token_route(
            StatusCode::OK,
            token_response_body(None),
            Some(scopes.join(",")),
            Arc::new(Mutex::new(0)),
        ))
        .await;
        println!("endpoint = {endpoint}");

        let creds = Builder::default()
            .with_endpoint(endpoint)
            .with_scopes(scopes)
            .build()?;
        let headers = creds.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(&headers).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(&headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_retryable_error() -> TestResult {
        let scopes = vec!["scope1".to_string()];
        let (endpoint, _server) = start(token_route(
            StatusCode::SERVICE_UNAVAILABLE,
            serde_json::to_value("try again")?,
            Some(scopes.join(",")),
            Arc::new(Mutex::new(0)),
        ))
        .await;

        let creds = Builder::default()
            .with_endpoint(endpoint)
            .with_scopes(scopes)
            .build()?;
        let err = creds.headers(Extensions::new()).await.err().unwrap();
        assert!(err.is_retryable());
        assert!(err.source().unwrap().to_string().contains("try again"));

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_nonretryable_error() -> TestResult {
        let scopes = vec!["scope1".to_string()];
        let (endpoint, _server) = start(token_route(
            StatusCode::UNAUTHORIZED,
            serde_json::to_value("epic fail".to_string())?,
            Some(scopes.join(",")),
            Arc::new(Mutex::new(0)),
        ))
        .await;

        let creds = Builder::default()
            .with_endpoint(endpoint)
            .with_scopes(scopes)
            .build()?;

        let err = creds.headers(Extensions::new()).await.err().unwrap();
        assert!(!err.is_retryable());
        assert!(err.source().unwrap().to_string().contains("epic fail"));

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_malformed_response_is_nonretryable() -> TestResult {
        let scopes = vec!["scope1".to_string()];
        let (endpoint, _server) = start(token_route(
            StatusCode::OK,
            serde_json::to_value("bad json".to_string())?,
            Some(scopes.join(",")),
            Arc::new(Mutex::new(0)),
        ))
        .await;

        let creds = Builder::default()
            .with_endpoint(endpoint)
            .with_scopes(scopes)
            .build()?;

        let err = creds.headers(Extensions::new()).await.err().unwrap();
        assert!(!err.is_retryable());

        Ok(())
    }

    #[tokio::test]
    async fn get_default_universe_domain_success() -> TestResult {
        let creds = Builder::default().build()?;
        let domain = creds.universe_domain().await.unwrap();
        assert_eq!(domain, DEFAULT_UNIVERSE_DOMAIN);
        Ok(())
    }

    #[tokio::test]
    async fn get_custom_universe_domain_success() -> TestResult {
        let universe_domain = "test-universe";
        let creds = Builder::default()
            .with_universe_domain(universe_domain)
            .build()?;
        let domain = creds.universe_domain().await.unwrap();
        assert_eq!(domain, universe_domain);

        Ok(())
    }
}