google_cloud_auth/credentials/
mds.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! [Metadata Service] Credentials type.
16//!
17//! Google Cloud environments such as [Google Compute Engine (GCE)][gce-link],
18//! [Google Kubernetes Engine (GKE)][gke-link], or [Cloud Run] provide a metadata service.
19//! This is a local service to the VM (or pod) which (as the name implies) provides
20//! metadata information about the VM. The service also provides access
21//! tokens associated with the [default service account] for the corresponding
22//! VM.
23//!
24//! The default host name of the metadata service is `metadata.google.internal`.
25//! If you would like to use a different hostname, you can set it using the
26//! `GCE_METADATA_HOST` environment variable.
27//!
28//! You can use this access token to securely authenticate with Google Cloud,
29//! without having to download secrets or other credentials. The types in this
30//! module allow you to retrieve these access tokens, and can be used with
31//! the Google Cloud client libraries for Rust.
32//!
33//! ## Example: Creating credentials with a custom quota project
34//!
35//! ```
36//! # use google_cloud_auth::credentials::mds::Builder;
37//! # use google_cloud_auth::credentials::Credentials;
38//! # use http::Extensions;
39//! # tokio_test::block_on(async {
40//! let credentials: Credentials = Builder::default()
41//!     .with_quota_project_id("my-quota-project")
42//!     .build()?;
43//! let headers = credentials.headers(Extensions::new()).await?;
44//! println!("Headers: {headers:?}");
45//! # Ok::<(), anyhow::Error>(())
46//! # });
47//! ```
48//!
49//! ## Example: Creating credentials with custom retry behavior
50//!
51//! ```
52//! # use google_cloud_auth::credentials::mds::Builder;
53//! # use google_cloud_auth::credentials::Credentials;
54//! # use http::Extensions;
55//! # use std::time::Duration;
56//! # tokio_test::block_on(async {
57//! use gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
58//! use gax::exponential_backoff::ExponentialBackoff;
59//! let backoff = ExponentialBackoff::default();
60//! let credentials: Credentials = Builder::default()
61//!     .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
62//!     .with_backoff_policy(backoff)
63//!     .build()?;
64//! let headers = credentials.headers(Extensions::new()).await?;
65//! println!("Headers: {headers:?}");
66//! # Ok::<(), anyhow::Error>(())
67//! # });
68//! ```
69//!
70//! [Application Default Credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
71//! [Cloud Run]: https://cloud.google.com/run
72//! [default service account]: https://cloud.google.com/iam/docs/service-account-types#default
73//! [gce-link]: https://cloud.google.com/products/compute
74//! [gke-link]: https://cloud.google.com/kubernetes-engine
75//! [Metadata Service]: https://cloud.google.com/compute/docs/metadata/overview
76
77use crate::credentials::dynamic::CredentialsProvider;
78use crate::credentials::{CacheableResource, Credentials, DEFAULT_UNIVERSE_DOMAIN};
79use crate::errors::CredentialsError;
80use crate::headers_util::build_cacheable_headers;
81use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry};
82use crate::token::{CachedTokenProvider, Token, TokenProvider};
83use crate::token_cache::TokenCache;
84use crate::{BuildResult, Result};
85use async_trait::async_trait;
86use bon::Builder;
87use gax::backoff_policy::BackoffPolicyArg;
88use gax::retry_policy::RetryPolicyArg;
89use gax::retry_throttler::RetryThrottlerArg;
90use http::{Extensions, HeaderMap, HeaderValue};
91use reqwest::Client;
92use std::default::Default;
93use std::sync::Arc;
94use std::time::Duration;
95use tokio::time::Instant;
96
/// Name of the header sent with every metadata server request.
const METADATA_FLAVOR: &str = "metadata-flavor";
/// Value sent in the `metadata-flavor` header.
const METADATA_FLAVOR_VALUE: &str = "Google";
/// Base URL of the metadata server used when no override is configured.
const METADATA_ROOT: &str = "http://metadata.google.internal";
/// Path of the default service account, relative to the base URL. The token
/// endpoint is this path followed by `/token`.
const MDS_DEFAULT_URI: &str = "/computeMetadata/v1/instance/service-accounts/default";
/// Environment variable holding a host name that replaces the metadata server host.
const GCE_METADATA_HOST_ENV_VAR: &str = "GCE_METADATA_HOST";
// TODO(#2235) - Improve this message by talking about retries when really running with MDS
/// Detailed guidance returned when credentials created through ADC fail to
/// reach a metadata server at the default endpoint.
///
/// Each line ends with `\`, which drops the newline and the next line's leading
/// indentation, so the text is a single line with single spaces between sentences.
const MDS_NOT_FOUND_ERROR: &str = "Could not fetch an auth token to authenticate with Google Cloud. \
    The most common reason for this problem is that you are not running in a Google Cloud Environment \
    and you have not configured local credentials for development and testing. \
    To setup local credentials, run `gcloud auth application-default login`. \
    More information on how to authenticate client libraries can be found at https://cloud.google.com/docs/authentication/client-libraries";
110
/// Credentials backed by the metadata server.
///
/// `T` supplies the (cached) access tokens. The `CachedTokenProvider` bound is
/// declared on the `impl` blocks that need it rather than on the struct, so the
/// type itself carries no unnecessary trait requirements.
#[derive(Debug)]
struct MDSCredentials<T> {
    // Optional quota project, forwarded to the header builder.
    quota_project_id: Option<String>,
    // Universe domain reported by `universe_domain()`; `None` means the default domain.
    universe_domain: Option<String>,
    // Source of the access tokens used to build request headers.
    token_provider: T,
}
120
/// Creates [Credentials] instances backed by the [Metadata Service].
///
/// While the Google Cloud client libraries for Rust default to credentials
/// backed by the metadata service, some applications may need to:
/// * Customize the metadata service credentials in some way
/// * Bypass the [Application Default Credentials] lookup and only
///   use the metadata server credentials
/// * Use the credentials directly outside the client libraries
///
/// [Application Default Credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
/// [Metadata Service]: https://cloud.google.com/compute/docs/metadata/overview
#[derive(Debug, Default)]
pub struct Builder {
    // Base URL set with `with_endpoint`. The `GCE_METADATA_HOST` environment
    // variable takes precedence over this value.
    endpoint: Option<String>,
    // Quota project set with `with_quota_project_id`, forwarded to the built credentials.
    quota_project_id: Option<String>,
    // Scopes set with `with_scopes`. `None` lets the metadata server use the
    // default service account's scopes.
    scopes: Option<Vec<String>>,
    // Universe domain set with `with_universe_domain`. `None` makes the
    // credentials report `DEFAULT_UNIVERSE_DOMAIN`.
    universe_domain: Option<String>,
    // `true` only when the builder comes from `from_adc`. Together with a
    // non-overridden endpoint, this selects the ADC-oriented error message.
    created_by_adc: bool,
    // Accumulates the retry policy, backoff policy and retry throttler
    // applied to the token provider.
    retry_builder: RetryTokenProviderBuilder,
}
141
142impl Builder {
143    /// Sets the endpoint for this credentials.
144    ///
145    /// A trailing slash is significant, so specify the base URL without a trailing  
146    /// slash. If not set, the credentials use `http://metadata.google.internal`.
147    ///
148    /// # Example
149    /// ```
150    /// # use google_cloud_auth::credentials::mds::Builder;
151    /// # tokio_test::block_on(async {
152    /// let credentials = Builder::default()
153    ///     .with_endpoint("https://metadata.google.foobar")
154    ///     .build();
155    /// # });
156    /// ```
157    pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
158        self.endpoint = Some(endpoint.into());
159        self
160    }
161
162    /// Set the [quota project] for this credentials.
163    ///
164    /// In some services, you can use a service account in
165    /// one project for authentication and authorization, and charge
166    /// the usage to a different project. This may require that the
167    /// service account has `serviceusage.services.use` permissions on the quota project.
168    ///
169    /// [quota project]: https://cloud.google.com/docs/quotas/quota-project
170    pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
171        self.quota_project_id = Some(quota_project_id.into());
172        self
173    }
174
175    /// Sets the universe domain for this credentials.
176    ///
177    /// Client libraries use `universe_domain` to determine
178    /// the API endpoints to use for making requests.
179    /// If not set, then credentials use `${service}.googleapis.com`,
180    /// otherwise they use `${service}.${universe_domain}.
181    pub fn with_universe_domain<S: Into<String>>(mut self, universe_domain: S) -> Self {
182        self.universe_domain = Some(universe_domain.into());
183        self
184    }
185
186    /// Sets the [scopes] for this credentials.
187    ///
188    /// Metadata server issues tokens based on the requested scopes.
189    /// If no scopes are specified, the credentials defaults to all
190    /// scopes configured for the [default service account] on the instance.
191    ///
192    /// [default service account]: https://cloud.google.com/iam/docs/service-account-types#default
193    /// [scopes]: https://developers.google.com/identity/protocols/oauth2/scopes
194    pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
195    where
196        I: IntoIterator<Item = S>,
197        S: Into<String>,
198    {
199        self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
200        self
201    }
202
203    /// Configure the retry policy for fetching tokens.
204    ///
205    /// The retry policy controls how to handle retries, and sets limits on
206    /// the number of attempts or the total time spent retrying.
207    ///
208    /// ```
209    /// # use google_cloud_auth::credentials::mds::Builder;
210    /// # tokio_test::block_on(async {
211    /// use gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
212    /// let credentials = Builder::default()
213    ///     .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
214    ///     .build();
215    /// # });
216    /// ```
217    pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
218        self.retry_builder = self.retry_builder.with_retry_policy(v.into());
219        self
220    }
221
222    /// Configure the retry backoff policy.
223    ///
224    /// The backoff policy controls how long to wait in between retry attempts.
225    ///
226    /// ```
227    /// # use google_cloud_auth::credentials::mds::Builder;
228    /// # use std::time::Duration;
229    /// # tokio_test::block_on(async {
230    /// use gax::exponential_backoff::ExponentialBackoff;
231    /// let policy = ExponentialBackoff::default();
232    /// let credentials = Builder::default()
233    ///     .with_backoff_policy(policy)
234    ///     .build();
235    /// # });
236    /// ```
237    pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
238        self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
239        self
240    }
241
242    /// Configure the retry throttler.
243    ///
244    /// Advanced applications may want to configure a retry throttler to
245    /// [Address Cascading Failures] and when [Handling Overload] conditions.
246    /// The authentication library throttles its retry loop, using a policy to
247    /// control the throttling algorithm. Use this method to fine tune or
248    /// customize the default retry throttler.
249    ///
250    /// [Handling Overload]: https://sre.google/sre-book/handling-overload/
251    /// [Address Cascading Failures]: https://sre.google/sre-book/addressing-cascading-failures/
252    ///
253    /// ```
254    /// # use google_cloud_auth::credentials::mds::Builder;
255    /// # tokio_test::block_on(async {
256    /// use gax::retry_throttler::AdaptiveThrottler;
257    /// let credentials = Builder::default()
258    ///     .with_retry_throttler(AdaptiveThrottler::default())
259    ///     .build();
260    /// # });
261    /// ```
262    pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
263        self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
264        self
265    }
266
267    // This method is used to build mds credentials from ADC
268    pub(crate) fn from_adc() -> Self {
269        Self {
270            created_by_adc: true,
271            ..Default::default()
272        }
273    }
274
275    fn build_token_provider(self) -> TokenProviderWithRetry<MDSAccessTokenProvider> {
276        let final_endpoint: String;
277        let endpoint_overridden: bool;
278
279        // Determine the endpoint and whether it was overridden
280        if let Ok(host_from_env) = std::env::var(GCE_METADATA_HOST_ENV_VAR) {
281            // Check GCE_METADATA_HOST environment variable first
282            final_endpoint = format!("http://{host_from_env}");
283            endpoint_overridden = true;
284        } else if let Some(builder_endpoint) = self.endpoint {
285            // Else, check if an endpoint was provided to the mds::Builder
286            final_endpoint = builder_endpoint;
287            endpoint_overridden = true;
288        } else {
289            // Else, use the default metadata root
290            final_endpoint = METADATA_ROOT.to_string();
291            endpoint_overridden = false;
292        };
293
294        let tp = MDSAccessTokenProvider::builder()
295            .endpoint(final_endpoint)
296            .maybe_scopes(self.scopes)
297            .endpoint_overridden(endpoint_overridden)
298            .created_by_adc(self.created_by_adc)
299            .build();
300        self.retry_builder.build(tp)
301    }
302
303    /// Returns a [Credentials] instance with the configured settings.
304    pub fn build(self) -> BuildResult<Credentials> {
305        let mdsc = MDSCredentials {
306            quota_project_id: self.quota_project_id.clone(),
307            universe_domain: self.universe_domain.clone(),
308            token_provider: TokenCache::new(self.build_token_provider()),
309        };
310        Ok(Credentials {
311            inner: Arc::new(mdsc),
312        })
313    }
314}
315
316#[async_trait::async_trait]
317impl<T> CredentialsProvider for MDSCredentials<T>
318where
319    T: CachedTokenProvider,
320{
321    async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
322        let cached_token = self.token_provider.token(extensions).await?;
323        build_cacheable_headers(&cached_token, &self.quota_project_id)
324    }
325
326    async fn universe_domain(&self) -> Option<String> {
327        if self.universe_domain.is_some() {
328            return self.universe_domain.clone();
329        }
330        return Some(DEFAULT_UNIVERSE_DOMAIN.to_string());
331    }
332}
333
/// JSON body returned by the metadata server's token endpoint.
///
/// NOTE(review): `Serialize` appears to be needed only to build fake responses
/// in tests; confirm before removing.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
struct MDSTokenResponse {
    // The access token value; copied into `Token::token`.
    access_token: String,
    // Token lifetime in seconds. When present it is converted into an absolute
    // `Token::expires_at`. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    expires_in: Option<u64>,
    // Token type; copied into `Token::token_type`.
    token_type: String,
}
341
/// Fetches access tokens directly from the metadata server.
///
/// This type performs neither caching nor retries itself. `Builder` wraps it in
/// a retrying provider and then in a `TokenCache`.
#[derive(Debug, Clone, Default, Builder)]
struct MDSAccessTokenProvider {
    // Requested scopes. `None` omits the `scopes` query parameter.
    #[builder(into)]
    scopes: Option<Vec<String>>,
    // Base URL of the metadata server; the token path is appended to it verbatim.
    #[builder(into)]
    endpoint: String,
    // `true` when the endpoint came from `GCE_METADATA_HOST` or `with_endpoint`.
    endpoint_overridden: bool,
    // `true` when these credentials were created by the ADC lookup (`Builder::from_adc`).
    created_by_adc: bool,
}
351
352impl MDSAccessTokenProvider {
353    // During ADC, if no credentials are found in the well-known location and the GOOGLE_APPLICATION_CREDENTIALS
354    // environment variable is not set, we default to MDS credentials without checking if the code is really
355    // running in an environment with MDS. To help users who got to this state because of lack of credentials
356    // setup on their machines, we provide a detailed error message to them talking about local setup and other
357    // auth mechanisms available to them.
358    // If the endpoint is overridden, even if ADC was used to create the MDS credentials, we do not give a detailed
359    // error message because they deliberately wanted to use an MDS.
360    fn error_message(&self) -> &str {
361        if self.use_adc_message() {
362            MDS_NOT_FOUND_ERROR
363        } else {
364            "failed to fetch token"
365        }
366    }
367
368    fn use_adc_message(&self) -> bool {
369        self.created_by_adc && !self.endpoint_overridden
370    }
371}
372
373#[async_trait]
374impl TokenProvider for MDSAccessTokenProvider {
375    async fn token(&self) -> Result<Token> {
376        let client = Client::new();
377        let request = client
378            .get(format!("{}{}/token", self.endpoint, MDS_DEFAULT_URI))
379            .header(
380                METADATA_FLAVOR,
381                HeaderValue::from_static(METADATA_FLAVOR_VALUE),
382            );
383        // Use the `scopes` option if set, otherwise let the MDS use the default
384        // scopes.
385        let scopes = self.scopes.as_ref().map(|v| v.join(","));
386        let request = scopes
387            .into_iter()
388            .fold(request, |r, s| r.query(&[("scopes", s)]));
389
390        // If the connection to MDS was not successful, it is useful to retry when really
391        // running on MDS environments and not useful if there is no MDS. We will mark the error
392        // as retryable and let the retry policy determine whether to retry or not. Whenever we
393        // define a default retry policy, we can skip retrying this case.
394        let response = request
395            .send()
396            .await
397            .map_err(|e| crate::errors::from_http_error(e, self.error_message()))?;
398        // Process the response
399        if !response.status().is_success() {
400            let err = crate::errors::from_http_response(response, self.error_message()).await;
401            return Err(err);
402        }
403        let response = response.json::<MDSTokenResponse>().await.map_err(|e| {
404            // Decoding errors are not transient. Typically they indicate a badly
405            // configured MDS endpoint, or DNS redirecting the request to a random
406            // server, e.g., ISPs that redirect unknown services to HTTP.
407            CredentialsError::from_source(!e.is_decode(), e)
408        })?;
409        let token = Token {
410            token: response.access_token,
411            token_type: response.token_type,
412            expires_at: response
413                .expires_in
414                .map(|d| Instant::now() + Duration::from_secs(d)),
415            metadata: None,
416        };
417        Ok(token)
418    }
419}
420
421#[cfg(test)]
422mod tests {
423    use super::*;
424    use crate::credentials::QUOTA_PROJECT_KEY;
425    use crate::credentials::tests::{
426        get_headers_from_cache, get_mock_auth_retry_policy, get_mock_backoff_policy,
427        get_mock_retry_throttler, get_token_from_headers, get_token_type_from_headers,
428    };
429    use crate::errors;
430    use crate::token::tests::MockTokenProvider;
431    use http::HeaderValue;
432    use http::header::AUTHORIZATION;
433    use httptest::cycle;
434    use httptest::matchers::{all_of, contains, request, url_decoded};
435    use httptest::responders::{json_encoded, status_code};
436    use httptest::{Expectation, Server};
437    use reqwest::StatusCode;
438    use scoped_env::ScopedEnv;
439    use serial_test::{parallel, serial};
440    use std::error::Error;
441    use test_case::test_case;
442    use url::Url;
443
444    type TestResult = anyhow::Result<()>;
445
446    #[tokio::test]
447    #[parallel]
448    async fn test_mds_retries_on_transient_failures() -> TestResult {
449        let mut server = Server::run();
450        server.expect(
451            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
452                .times(3)
453                .respond_with(status_code(503)),
454        );
455
456        let provider = Builder::default()
457            .with_endpoint(format!("http://{}", server.addr()))
458            .with_retry_policy(get_mock_auth_retry_policy(3))
459            .with_backoff_policy(get_mock_backoff_policy())
460            .with_retry_throttler(get_mock_retry_throttler())
461            .build_token_provider();
462
463        let err = provider.token().await.unwrap_err();
464        assert!(err.is_transient());
465        server.verify_and_clear();
466        Ok(())
467    }
468
469    #[tokio::test]
470    #[parallel]
471    async fn test_mds_does_not_retry_on_non_transient_failures() -> TestResult {
472        let mut server = Server::run();
473        server.expect(
474            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
475                .times(1)
476                .respond_with(status_code(401)),
477        );
478
479        let provider = Builder::default()
480            .with_endpoint(format!("http://{}", server.addr()))
481            .with_retry_policy(get_mock_auth_retry_policy(1))
482            .with_backoff_policy(get_mock_backoff_policy())
483            .with_retry_throttler(get_mock_retry_throttler())
484            .build_token_provider();
485
486        let err = provider.token().await.unwrap_err();
487        assert!(!err.is_transient());
488        server.verify_and_clear();
489        Ok(())
490    }
491
492    #[tokio::test]
493    #[parallel]
494    async fn test_mds_retries_for_success() -> TestResult {
495        let mut server = Server::run();
496        let response = MDSTokenResponse {
497            access_token: "test-access-token".to_string(),
498            expires_in: Some(3600),
499            token_type: "test-token-type".to_string(),
500        };
501
502        server.expect(
503            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
504                .times(3)
505                .respond_with(cycle![
506                    status_code(503).body("try-again"),
507                    status_code(503).body("try-again"),
508                    status_code(200)
509                        .append_header("Content-Type", "application/json")
510                        .body(serde_json::to_string(&response).unwrap()),
511                ]),
512        );
513
514        let provider = Builder::default()
515            .with_endpoint(format!("http://{}", server.addr()))
516            .with_retry_policy(get_mock_auth_retry_policy(3))
517            .with_backoff_policy(get_mock_backoff_policy())
518            .with_retry_throttler(get_mock_retry_throttler())
519            .build_token_provider();
520
521        let token = provider.token().await?;
522        assert_eq!(token.token, "test-access-token");
523
524        server.verify_and_clear();
525        Ok(())
526    }
527
528    #[test]
529    fn validate_default_endpoint_urls() {
530        let default_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}"));
531        assert!(default_endpoint_address.is_ok());
532
533        let token_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}/token"));
534        assert!(token_endpoint_address.is_ok());
535    }
536
537    #[tokio::test]
538    async fn headers_success() -> TestResult {
539        let token = Token {
540            token: "test-token".to_string(),
541            token_type: "Bearer".to_string(),
542            expires_at: None,
543            metadata: None,
544        };
545
546        let mut mock = MockTokenProvider::new();
547        mock.expect_token().times(1).return_once(|| Ok(token));
548
549        let mdsc = MDSCredentials {
550            quota_project_id: None,
551            universe_domain: None,
552            token_provider: TokenCache::new(mock),
553        };
554
555        let mut extensions = Extensions::new();
556        let cached_headers = mdsc.headers(extensions.clone()).await.unwrap();
557        let (headers, entity_tag) = match cached_headers {
558            CacheableResource::New { entity_tag, data } => (data, entity_tag),
559            CacheableResource::NotModified => unreachable!("expecting new headers"),
560        };
561        let token = headers.get(AUTHORIZATION).unwrap();
562        assert_eq!(headers.len(), 1, "{headers:?}");
563        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
564        assert!(token.is_sensitive());
565
566        extensions.insert(entity_tag);
567
568        let cached_headers = mdsc.headers(extensions).await?;
569
570        match cached_headers {
571            CacheableResource::New { .. } => unreachable!("expecting new headers"),
572            CacheableResource::NotModified => CacheableResource::<HeaderMap>::NotModified,
573        };
574        Ok(())
575    }
576
577    #[tokio::test]
578    async fn headers_failure() {
579        let mut mock = MockTokenProvider::new();
580        mock.expect_token()
581            .times(1)
582            .return_once(|| Err(errors::non_retryable_from_str("fail")));
583
584        let mdsc = MDSCredentials {
585            quota_project_id: None,
586            universe_domain: None,
587            token_provider: TokenCache::new(mock),
588        };
589        assert!(mdsc.headers(Extensions::new()).await.is_err());
590    }
591
592    #[test]
593    fn error_message_with_adc() {
594        let provider = MDSAccessTokenProvider::builder()
595            .endpoint("http://127.0.0.1")
596            .created_by_adc(true)
597            .endpoint_overridden(false)
598            .build();
599
600        let want = MDS_NOT_FOUND_ERROR;
601        let got = provider.error_message();
602        assert!(got.contains(want), "{got}, {provider:?}");
603    }
604
605    #[test_case(false, false)]
606    #[test_case(false, true)]
607    #[test_case(true, true)]
608    fn error_message_without_adc(adc: bool, overridden: bool) {
609        let provider = MDSAccessTokenProvider::builder()
610            .endpoint("http://127.0.0.1")
611            .created_by_adc(adc)
612            .endpoint_overridden(overridden)
613            .build();
614
615        let not_want = MDS_NOT_FOUND_ERROR;
616        let got = provider.error_message();
617        assert!(!got.contains(not_want), "{got}, {provider:?}");
618    }
619
620    #[tokio::test]
621    #[serial]
622    async fn adc_no_mds() -> TestResult {
623        let err = Builder::from_adc()
624            .build_token_provider()
625            .token()
626            .await
627            .unwrap_err();
628
629        assert!(err.is_transient(), "{err:?}");
630        assert!(
631            err.to_string().contains("application-default"),
632            "display={err}, debug={err:?}"
633        );
634        let source = err
635            .source()
636            .and_then(|e| e.downcast_ref::<reqwest::Error>());
637        assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");
638
639        Ok(())
640    }
641
642    #[tokio::test]
643    #[serial]
644    async fn adc_overridden_mds() -> TestResult {
645        let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, "metadata.overridden");
646
647        let err = Builder::from_adc()
648            .build_token_provider()
649            .token()
650            .await
651            .unwrap_err();
652
653        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);
654
655        assert!(err.is_transient(), "{err:?}");
656        assert!(
657            !err.to_string().contains("application-default"),
658            "display={err}, debug={err:?}"
659        );
660        let source = err
661            .source()
662            .and_then(|e| e.downcast_ref::<reqwest::Error>());
663        assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");
664
665        Ok(())
666    }
667
668    #[tokio::test]
669    #[serial]
670    async fn builder_no_mds() -> TestResult {
671        let e = Builder::default()
672            .build_token_provider()
673            .token()
674            .await
675            .err()
676            .unwrap();
677
678        assert!(e.is_transient(), "{e:?}");
679        assert!(
680            !format!("{:?}", e.source()).contains("application-default"),
681            "{e:?}"
682        );
683
684        Ok(())
685    }
686
687    #[tokio::test]
688    #[serial]
689    async fn test_gce_metadata_host_env_var() -> TestResult {
690        let server = Server::run();
691        let scopes = ["scope1", "scope2"];
692        let response = MDSTokenResponse {
693            access_token: "test-access-token".to_string(),
694            expires_in: Some(3600),
695            token_type: "test-token-type".to_string(),
696        };
697        server.expect(
698            Expectation::matching(all_of![
699                request::path(format!("{MDS_DEFAULT_URI}/token")),
700                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
701            ])
702            .respond_with(json_encoded(response)),
703        );
704
705        let addr = server.addr().to_string();
706        let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, &addr);
707        let mdsc = Builder::default()
708            .with_scopes(["scope1", "scope2"])
709            .build()
710            .unwrap();
711        let headers = mdsc.headers(Extensions::new()).await.unwrap();
712        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);
713
714        assert_eq!(
715            get_token_from_headers(headers).unwrap(),
716            "test-access-token"
717        );
718        Ok(())
719    }
720
721    #[tokio::test]
722    #[parallel]
723    async fn headers_success_with_quota_project() -> TestResult {
724        let server = Server::run();
725        let scopes = ["scope1", "scope2"];
726        let response = MDSTokenResponse {
727            access_token: "test-access-token".to_string(),
728            expires_in: Some(3600),
729            token_type: "test-token-type".to_string(),
730        };
731        server.expect(
732            Expectation::matching(all_of![
733                request::path(format!("{MDS_DEFAULT_URI}/token")),
734                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
735            ])
736            .respond_with(json_encoded(response)),
737        );
738
739        let mdsc = Builder::default()
740            .with_scopes(["scope1", "scope2"])
741            .with_endpoint(format!("http://{}", server.addr()))
742            .with_quota_project_id("test-project")
743            .build()?;
744
745        let headers = get_headers_from_cache(mdsc.headers(Extensions::new()).await.unwrap())?;
746        let token = headers.get(AUTHORIZATION).unwrap();
747        let quota_project = headers.get(QUOTA_PROJECT_KEY).unwrap();
748
749        assert_eq!(headers.len(), 2, "{headers:?}");
750        assert_eq!(
751            token,
752            HeaderValue::from_static("test-token-type test-access-token")
753        );
754        assert!(token.is_sensitive());
755        assert_eq!(quota_project, HeaderValue::from_static("test-project"));
756        assert!(!quota_project.is_sensitive());
757
758        Ok(())
759    }
760
761    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
762    #[parallel]
763    async fn token_caching() -> TestResult {
764        let mut server = Server::run();
765        let scopes = vec!["scope1".to_string()];
766        let response = MDSTokenResponse {
767            access_token: "test-access-token".to_string(),
768            expires_in: Some(3600),
769            token_type: "test-token-type".to_string(),
770        };
771        server.expect(
772            Expectation::matching(all_of![
773                request::path(format!("{MDS_DEFAULT_URI}/token")),
774                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
775            ])
776            .times(1)
777            .respond_with(json_encoded(response)),
778        );
779
780        let mdsc = Builder::default()
781            .with_scopes(scopes)
782            .with_endpoint(format!("http://{}", server.addr()))
783            .build()?;
784        let headers = mdsc.headers(Extensions::new()).await?;
785        assert_eq!(
786            get_token_from_headers(headers).unwrap(),
787            "test-access-token"
788        );
789        let headers = mdsc.headers(Extensions::new()).await?;
790        assert_eq!(
791            get_token_from_headers(headers).unwrap(),
792            "test-access-token"
793        );
794
795        // validate that the inner token provider is called only once
796        server.verify_and_clear();
797
798        Ok(())
799    }
800
801    #[tokio::test(start_paused = true)]
802    #[parallel]
803    async fn token_provider_full() -> TestResult {
804        let server = Server::run();
805        let scopes = vec!["scope1".to_string()];
806        let response = MDSTokenResponse {
807            access_token: "test-access-token".to_string(),
808            expires_in: Some(3600),
809            token_type: "test-token-type".to_string(),
810        };
811        server.expect(
812            Expectation::matching(all_of![
813                request::path(format!("{MDS_DEFAULT_URI}/token")),
814                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
815            ])
816            .respond_with(json_encoded(response)),
817        );
818
819        let token = Builder::default()
820            .with_endpoint(format!("http://{}", server.addr()))
821            .with_scopes(scopes)
822            .build_token_provider()
823            .token()
824            .await?;
825
826        let now = tokio::time::Instant::now();
827        assert_eq!(token.token, "test-access-token");
828        assert_eq!(token.token_type, "test-token-type");
829        assert!(
830            token
831                .expires_at
832                .is_some_and(|d| d >= now + Duration::from_secs(3600))
833        );
834
835        Ok(())
836    }
837
838    #[tokio::test(start_paused = true)]
839    #[parallel]
840    async fn token_provider_full_no_scopes() -> TestResult {
841        let server = Server::run();
842        let response = MDSTokenResponse {
843            access_token: "test-access-token".to_string(),
844            expires_in: Some(3600),
845            token_type: "test-token-type".to_string(),
846        };
847        server.expect(
848            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
849                .respond_with(json_encoded(response)),
850        );
851
852        let token = Builder::default()
853            .with_endpoint(format!("http://{}", server.addr()))
854            .build_token_provider()
855            .token()
856            .await?;
857
858        let now = Instant::now();
859        assert_eq!(token.token, "test-access-token");
860        assert_eq!(token.token_type, "test-token-type");
861        assert!(
862            token
863                .expires_at
864                .is_some_and(|d| d == now + Duration::from_secs(3600))
865        );
866
867        Ok(())
868    }
869
870    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
871    #[parallel]
872    async fn credential_provider_full() -> TestResult {
873        let server = Server::run();
874        let scopes = vec!["scope1".to_string()];
875        let response = MDSTokenResponse {
876            access_token: "test-access-token".to_string(),
877            expires_in: None,
878            token_type: "test-token-type".to_string(),
879        };
880        server.expect(
881            Expectation::matching(all_of![
882                request::path(format!("{MDS_DEFAULT_URI}/token")),
883                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
884            ])
885            .respond_with(json_encoded(response)),
886        );
887
888        let mdsc = Builder::default()
889            .with_endpoint(format!("http://{}", server.addr()))
890            .with_scopes(scopes)
891            .build()?;
892        let headers = mdsc.headers(Extensions::new()).await?;
893        assert_eq!(
894            get_token_from_headers(headers.clone()).unwrap(),
895            "test-access-token"
896        );
897        assert_eq!(
898            get_token_type_from_headers(headers).unwrap(),
899            "test-token-type"
900        );
901
902        Ok(())
903    }
904
905    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
906    #[parallel]
907    async fn credentials_headers_retryable_error() -> TestResult {
908        let server = Server::run();
909        let scopes = vec!["scope1".to_string()];
910        server.expect(
911            Expectation::matching(all_of![
912                request::path(format!("{MDS_DEFAULT_URI}/token")),
913                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
914            ])
915            .respond_with(status_code(503)),
916        );
917
918        let mdsc = Builder::default()
919            .with_endpoint(format!("http://{}", server.addr()))
920            .with_scopes(scopes)
921            .build()?;
922        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
923        assert!(err.is_transient());
924        let source = err
925            .source()
926            .and_then(|e| e.downcast_ref::<reqwest::Error>());
927        assert!(
928            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
929            "{err:?}"
930        );
931
932        Ok(())
933    }
934
935    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
936    #[parallel]
937    async fn credentials_headers_nonretryable_error() -> TestResult {
938        let server = Server::run();
939        let scopes = vec!["scope1".to_string()];
940        server.expect(
941            Expectation::matching(all_of![
942                request::path(format!("{MDS_DEFAULT_URI}/token")),
943                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
944            ])
945            .respond_with(status_code(401)),
946        );
947
948        let mdsc = Builder::default()
949            .with_endpoint(format!("http://{}", server.addr()))
950            .with_scopes(scopes)
951            .build()?;
952
953        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
954        assert!(!err.is_transient());
955        let source = err
956            .source()
957            .and_then(|e| e.downcast_ref::<reqwest::Error>());
958        assert!(
959            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
960            "{err:?}"
961        );
962
963        Ok(())
964    }
965
966    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
967    #[parallel]
968    async fn credentials_headers_malformed_response_is_nonretryable() -> TestResult {
969        let server = Server::run();
970        let scopes = vec!["scope1".to_string()];
971        server.expect(
972            Expectation::matching(all_of![
973                request::path(format!("{MDS_DEFAULT_URI}/token")),
974                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
975            ])
976            .respond_with(json_encoded("bad json")),
977        );
978
979        let mdsc = Builder::default()
980            .with_endpoint(format!("http://{}", server.addr()))
981            .with_scopes(scopes)
982            .build()?;
983
984        let e = mdsc.headers(Extensions::new()).await.err().unwrap();
985        assert!(!e.is_transient());
986
987        Ok(())
988    }
989
990    #[tokio::test]
991    async fn get_default_universe_domain_success() -> TestResult {
992        let universe_domain_response = Builder::default().build()?.universe_domain().await.unwrap();
993        assert_eq!(universe_domain_response, DEFAULT_UNIVERSE_DOMAIN);
994        Ok(())
995    }
996
997    #[tokio::test]
998    async fn get_custom_universe_domain_success() -> TestResult {
999        let universe_domain = "test-universe";
1000        let universe_domain_response = Builder::default()
1001            .with_universe_domain(universe_domain)
1002            .build()?
1003            .universe_domain()
1004            .await
1005            .unwrap();
1006        assert_eq!(universe_domain_response, universe_domain);
1007
1008        Ok(())
1009    }
1010}