google_cloud_auth/credentials/
mds.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! [Metadata Service] Credentials type.
16//!
17//! Google Cloud environments such as [Google Compute Engine (GCE)][gce-link],
18//! [Google Kubernetes Engine (GKE)][gke-link], or [Cloud Run] provide a metadata service.
19//! This is a local service to the VM (or pod) which (as the name implies) provides
20//! metadata information about the VM. The service also provides access
21//! tokens associated with the [default service account] for the corresponding
22//! VM.
23//!
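//! The default host name of the metadata service is `metadata.google.internal`.
//! If you would like to use a different host, you can set it using the
//! `GCE_METADATA_HOST` environment variable, or configure a full base URL with
//! [Builder::with_endpoint]. When both are set, the environment variable takes
//! precedence.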
27//!
//! You can use these access tokens to securely authenticate with Google Cloud,
//! without having to download secrets or other credentials. The types in this
//! module allow you to retrieve these access tokens, and can be used with
//! the Google Cloud client libraries for Rust.
32//!
33//! ## Example: Creating credentials with a custom quota project
34//!
35//! ```
36//! # use google_cloud_auth::credentials::mds::Builder;
37//! # use google_cloud_auth::credentials::Credentials;
38//! # use http::Extensions;
39//! # tokio_test::block_on(async {
40//! let credentials: Credentials = Builder::default()
41//!     .with_quota_project_id("my-quota-project")
42//!     .build()?;
43//! let headers = credentials.headers(Extensions::new()).await?;
44//! println!("Headers: {headers:?}");
45//! # Ok::<(), anyhow::Error>(())
46//! # });
47//! ```
48//!
49//! ## Example: Creating credentials with custom retry behavior
50//!
51//! ```
52//! # use google_cloud_auth::credentials::mds::Builder;
53//! # use google_cloud_auth::credentials::Credentials;
54//! # use http::Extensions;
55//! # use std::time::Duration;
56//! # tokio_test::block_on(async {
57//! use gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
58//! use gax::exponential_backoff::ExponentialBackoff;
59//! let backoff = ExponentialBackoff::default();
60//! let credentials: Credentials = Builder::default()
61//!     .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
62//!     .with_backoff_policy(backoff)
63//!     .build()?;
64//! let headers = credentials.headers(Extensions::new()).await?;
65//! println!("Headers: {headers:?}");
66//! # Ok::<(), anyhow::Error>(())
67//! # });
68//! ```
69//!
70//! [Application Default Credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
71//! [Cloud Run]: https://cloud.google.com/run
72//! [default service account]: https://cloud.google.com/iam/docs/service-account-types#default
73//! [gce-link]: https://cloud.google.com/products/compute
74//! [gke-link]: https://cloud.google.com/kubernetes-engine
75//! [Metadata Service]: https://cloud.google.com/compute/docs/metadata/overview
76
77use crate::credentials::dynamic::{AccessTokenCredentialsProvider, CredentialsProvider};
78use crate::credentials::{AccessToken, AccessTokenCredentials, CacheableResource, Credentials};
79use crate::errors::CredentialsError;
80use crate::headers_util::build_cacheable_headers;
81use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry};
82use crate::token::{CachedTokenProvider, Token, TokenProvider};
83use crate::token_cache::TokenCache;
84use crate::{BuildResult, Result};
85use async_trait::async_trait;
86use gax::backoff_policy::BackoffPolicyArg;
87use gax::retry_policy::RetryPolicyArg;
88use gax::retry_throttler::RetryThrottlerArg;
89use http::{Extensions, HeaderMap, HeaderValue};
90use reqwest::Client;
91use std::default::Default;
92use std::sync::Arc;
93use std::time::Duration;
94use tokio::time::Instant;
95
/// Value sent in the [METADATA_FLAVOR] request header.
pub(crate) const METADATA_FLAVOR_VALUE: &str = "Google";
/// Name of the request header attached to metadata service requests.
pub(crate) const METADATA_FLAVOR: &str = "metadata-flavor";
/// Base URL used when no endpoint override is configured.
pub(crate) const METADATA_ROOT: &str = "http://metadata.google.internal";
/// Path, relative to the endpoint, of the default service account resources.
pub(crate) const MDS_DEFAULT_URI: &str = "/computeMetadata/v1/instance/service-accounts/default";
/// Name of the environment variable that overrides the metadata service host.
pub(crate) const GCE_METADATA_HOST_ENV_VAR: &str = "GCE_METADATA_HOST";
// TODO(#2235) - Improve this message by talking about retries when really running with MDS
/// Detailed error message for credentials that were created through ADC and
/// could not fetch a token from the (non-overridden) metadata service.
const MDS_NOT_FOUND_ERROR: &str = concat!(
    "Could not fetch an auth token to authenticate with Google Cloud. ",
    "The most common reason for this problem is that you are not running in a Google Cloud Environment ",
    "and you have not configured local credentials for development and testing. ",
    "To setup local credentials, run `gcloud auth application-default login`. ",
    "More information on how to authenticate client libraries can be found at https://cloud.google.com/docs/authentication/client-libraries"
);
109
110#[derive(Debug)]
111struct MDSCredentials<T>
112where
113    T: CachedTokenProvider,
114{
115    quota_project_id: Option<String>,
116    token_provider: T,
117}
118
119/// Creates [Credentials] instances backed by the [Metadata Service].
120///
121/// While the Google Cloud client libraries for Rust default to credentials
122/// backed by the metadata service, some applications may need to:
123/// * Customize the metadata service credentials in some way
124/// * Bypass the [Application Default Credentials] lookup and only
125///   use the metadata server credentials
126/// * Use the credentials directly outside the client libraries
127///
128/// [Application Default Credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
129/// [Metadata Service]: https://cloud.google.com/compute/docs/metadata/overview
130#[derive(Debug, Default)]
131pub struct Builder {
132    endpoint: Option<String>,
133    quota_project_id: Option<String>,
134    scopes: Option<Vec<String>>,
135    created_by_adc: bool,
136    retry_builder: RetryTokenProviderBuilder,
137}
138
139impl Builder {
140    /// Sets the endpoint for this credentials.
141    ///
142    /// A trailing slash is significant, so specify the base URL without a trailing
143    /// slash. If not set, the credentials use `http://metadata.google.internal`.
144    ///
145    /// # Example
146    /// ```
147    /// # use google_cloud_auth::credentials::mds::Builder;
148    /// # tokio_test::block_on(async {
149    /// let credentials = Builder::default()
150    ///     .with_endpoint("https://metadata.google.foobar")
151    ///     .build();
152    /// # });
153    /// ```
154    pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
155        self.endpoint = Some(endpoint.into());
156        self
157    }
158
159    /// Set the [quota project] for this credentials.
160    ///
161    /// In some services, you can use a service account in
162    /// one project for authentication and authorization, and charge
163    /// the usage to a different project. This may require that the
164    /// service account has `serviceusage.services.use` permissions on the quota project.
165    ///
166    /// [quota project]: https://cloud.google.com/docs/quotas/quota-project
167    pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
168        self.quota_project_id = Some(quota_project_id.into());
169        self
170    }
171
172    /// Sets the [scopes] for this credentials.
173    ///
174    /// Metadata server issues tokens based on the requested scopes.
175    /// If no scopes are specified, the credentials defaults to all
176    /// scopes configured for the [default service account] on the instance.
177    ///
178    /// [default service account]: https://cloud.google.com/iam/docs/service-account-types#default
179    /// [scopes]: https://developers.google.com/identity/protocols/oauth2/scopes
180    pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
181    where
182        I: IntoIterator<Item = S>,
183        S: Into<String>,
184    {
185        self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
186        self
187    }
188
189    /// Configure the retry policy for fetching tokens.
190    ///
191    /// The retry policy controls how to handle retries, and sets limits on
192    /// the number of attempts or the total time spent retrying.
193    ///
194    /// ```
195    /// # use google_cloud_auth::credentials::mds::Builder;
196    /// # tokio_test::block_on(async {
197    /// use gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
198    /// let credentials = Builder::default()
199    ///     .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
200    ///     .build();
201    /// # });
202    /// ```
203    pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
204        self.retry_builder = self.retry_builder.with_retry_policy(v.into());
205        self
206    }
207
208    /// Configure the retry backoff policy.
209    ///
210    /// The backoff policy controls how long to wait in between retry attempts.
211    ///
212    /// ```
213    /// # use google_cloud_auth::credentials::mds::Builder;
214    /// # use std::time::Duration;
215    /// # tokio_test::block_on(async {
216    /// use gax::exponential_backoff::ExponentialBackoff;
217    /// let policy = ExponentialBackoff::default();
218    /// let credentials = Builder::default()
219    ///     .with_backoff_policy(policy)
220    ///     .build();
221    /// # });
222    /// ```
223    pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
224        self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
225        self
226    }
227
228    /// Configure the retry throttler.
229    ///
230    /// Advanced applications may want to configure a retry throttler to
231    /// [Address Cascading Failures] and when [Handling Overload] conditions.
232    /// The authentication library throttles its retry loop, using a policy to
233    /// control the throttling algorithm. Use this method to fine tune or
234    /// customize the default retry throttler.
235    ///
236    /// [Handling Overload]: https://sre.google/sre-book/handling-overload/
237    /// [Address Cascading Failures]: https://sre.google/sre-book/addressing-cascading-failures/
238    ///
239    /// ```
240    /// # use google_cloud_auth::credentials::mds::Builder;
241    /// # tokio_test::block_on(async {
242    /// use gax::retry_throttler::AdaptiveThrottler;
243    /// let credentials = Builder::default()
244    ///     .with_retry_throttler(AdaptiveThrottler::default())
245    ///     .build();
246    /// # });
247    /// ```
248    pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
249        self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
250        self
251    }
252
253    // This method is used to build mds credentials from ADC
254    pub(crate) fn from_adc() -> Self {
255        Self {
256            created_by_adc: true,
257            ..Default::default()
258        }
259    }
260
261    fn resolve_endpoint(&self) -> (String, bool) {
262        let final_endpoint: String;
263        let endpoint_overridden: bool;
264        // Determine the endpoint and whether it was overridden
265        if let Ok(host_from_env) = std::env::var(GCE_METADATA_HOST_ENV_VAR) {
266            // Check GCE_METADATA_HOST environment variable first
267            final_endpoint = format!("http://{host_from_env}");
268            endpoint_overridden = true;
269        } else if let Some(builder_endpoint) = self.endpoint.clone() {
270            // Else, check if an endpoint was provided to the mds::Builder
271            final_endpoint = builder_endpoint;
272            endpoint_overridden = true;
273        } else {
274            // Else, use the default metadata root
275            final_endpoint = METADATA_ROOT.to_string();
276            endpoint_overridden = false;
277        };
278        (final_endpoint, endpoint_overridden)
279    }
280
281    fn build_token_provider(self) -> TokenProviderWithRetry<MDSAccessTokenProvider> {
282        let (final_endpoint, endpoint_overridden) = self.resolve_endpoint();
283
284        let tp = MDSAccessTokenProvider::builder()
285            .endpoint(final_endpoint)
286            .maybe_scopes(self.scopes)
287            .endpoint_overridden(endpoint_overridden)
288            .created_by_adc(self.created_by_adc)
289            .build();
290        self.retry_builder.build(tp)
291    }
292
293    /// Returns a [Credentials] instance with the configured settings.
294    pub fn build(self) -> BuildResult<Credentials> {
295        Ok(self.build_access_token_credentials()?.into())
296    }
297
298    /// Returns an [AccessTokenCredentials] instance with the configured settings.
299    ///
300    /// # Example
301    ///
302    /// ```
303    /// # use google_cloud_auth::credentials::mds::Builder;
304    /// # use google_cloud_auth::credentials::{AccessTokenCredentials, AccessTokenCredentialsProvider};
305    /// # tokio_test::block_on(async {
306    /// let credentials: AccessTokenCredentials = Builder::default()
307    ///     .with_quota_project_id("my-quota-project")
308    ///     .build_access_token_credentials()?;
309    /// let access_token = credentials.access_token().await?;
310    /// println!("Token: {}", access_token.token);
311    /// # Ok::<(), anyhow::Error>(())
312    /// # });
313    /// ```
314    pub fn build_access_token_credentials(self) -> BuildResult<AccessTokenCredentials> {
315        let mdsc = MDSCredentials {
316            quota_project_id: self.quota_project_id.clone(),
317            token_provider: TokenCache::new(self.build_token_provider()),
318        };
319        Ok(AccessTokenCredentials {
320            inner: Arc::new(mdsc),
321        })
322    }
323
324    /// Returns a [crate::signer::Signer] instance with the configured settings.
325    ///
326    /// The returned [crate::signer::Signer] uses the [IAM signBlob API] to sign content. This API
327    /// requires a network request for each signing operation.
328    ///
329    /// # Example
330    ///
331    /// ```
332    /// # use google_cloud_auth::credentials::mds::Builder;
333    /// # use google_cloud_auth::signer::Signer;
334    /// # tokio_test::block_on(async {
335    /// let signer: Signer = Builder::default().build_signer()?;
336    /// # Ok::<(), anyhow::Error>(())
337    /// # });
338    /// ```
339    ///
340    /// [IAM signBlob API]: https://cloud.google.com/iam/docs/reference/credentials/rest/v1/projects.serviceAccounts/signBlob
341    pub fn build_signer(self) -> BuildResult<crate::signer::Signer> {
342        self.build_signer_with_iam_endpoint_override(None)
343    }
344
345    // only used for testing
346    fn build_signer_with_iam_endpoint_override(
347        self,
348        iam_endpoint: Option<String>,
349    ) -> BuildResult<crate::signer::Signer> {
350        let (endpoint, _) = self.resolve_endpoint();
351        let credentials = self.build()?;
352        let signing_provider = crate::signer::mds::MDSSigner::new(endpoint, credentials);
353        let signing_provider = iam_endpoint
354            .iter()
355            .fold(signing_provider, |signing_provider, endpoint| {
356                signing_provider.with_iam_endpoint_override(endpoint)
357            });
358        Ok(crate::signer::Signer {
359            inner: Arc::new(signing_provider),
360        })
361    }
362}
363
364#[async_trait::async_trait]
365impl<T> CredentialsProvider for MDSCredentials<T>
366where
367    T: CachedTokenProvider,
368{
369    async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
370        let cached_token = self.token_provider.token(extensions).await?;
371        build_cacheable_headers(&cached_token, &self.quota_project_id)
372    }
373}
374
375#[async_trait::async_trait]
376impl<T> AccessTokenCredentialsProvider for MDSCredentials<T>
377where
378    T: CachedTokenProvider,
379{
380    async fn access_token(&self) -> Result<AccessToken> {
381        let token = self.token_provider.token(Extensions::new()).await?;
382        token.into()
383    }
384}
385
386#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
387struct MDSTokenResponse {
388    access_token: String,
389    #[serde(skip_serializing_if = "Option::is_none")]
390    expires_in: Option<u64>,
391    token_type: String,
392}
393
394#[derive(Debug, Default)]
395struct MDSAccessTokenProviderBuilder {
396    target: MDSAccessTokenProvider,
397}
398
399impl MDSAccessTokenProviderBuilder {
400    fn build(self) -> MDSAccessTokenProvider {
401        self.target
402    }
403
404    fn maybe_scopes(mut self, v: Option<Vec<String>>) -> Self {
405        self.target.scopes = v;
406        self
407    }
408
409    fn endpoint<T>(mut self, v: T) -> Self
410    where
411        T: Into<String>,
412    {
413        self.target.endpoint = v.into();
414        self
415    }
416
417    fn endpoint_overridden(mut self, v: bool) -> Self {
418        self.target.endpoint_overridden = v;
419        self
420    }
421
422    fn created_by_adc(mut self, v: bool) -> Self {
423        self.target.created_by_adc = v;
424        self
425    }
426}
427
/// Token provider that fetches access tokens from the metadata service.
#[derive(Debug, Clone, Default)]
struct MDSAccessTokenProvider {
    /// Scopes to request; `None` lets the metadata service use its defaults.
    scopes: Option<Vec<String>>,
    /// Base URL of the metadata service; the token path is appended to it.
    endpoint: String,
    /// Whether `endpoint` differs from the default metadata root.
    endpoint_overridden: bool,
    /// Whether the provider was created by the ADC lookup.
    created_by_adc: bool,
}
435
436impl MDSAccessTokenProvider {
437    fn builder() -> MDSAccessTokenProviderBuilder {
438        MDSAccessTokenProviderBuilder::default()
439    }
440
441    // During ADC, if no credentials are found in the well-known location and the GOOGLE_APPLICATION_CREDENTIALS
442    // environment variable is not set, we default to MDS credentials without checking if the code is really
443    // running in an environment with MDS. To help users who got to this state because of lack of credentials
444    // setup on their machines, we provide a detailed error message to them talking about local setup and other
445    // auth mechanisms available to them.
446    // If the endpoint is overridden, even if ADC was used to create the MDS credentials, we do not give a detailed
447    // error message because they deliberately wanted to use an MDS.
448    fn error_message(&self) -> &str {
449        if self.use_adc_message() {
450            MDS_NOT_FOUND_ERROR
451        } else {
452            "failed to fetch token"
453        }
454    }
455
456    fn use_adc_message(&self) -> bool {
457        self.created_by_adc && !self.endpoint_overridden
458    }
459}
460
461#[async_trait]
462impl TokenProvider for MDSAccessTokenProvider {
463    async fn token(&self) -> Result<Token> {
464        let client = Client::new();
465        let request = client
466            .get(format!("{}{}/token", self.endpoint, MDS_DEFAULT_URI))
467            .header(
468                METADATA_FLAVOR,
469                HeaderValue::from_static(METADATA_FLAVOR_VALUE),
470            );
471        // Use the `scopes` option if set, otherwise let the MDS use the default
472        // scopes.
473        let scopes = self.scopes.as_ref().map(|v| v.join(","));
474        let request = scopes
475            .into_iter()
476            .fold(request, |r, s| r.query(&[("scopes", s)]));
477
478        // If the connection to MDS was not successful, it is useful to retry when really
479        // running on MDS environments and not useful if there is no MDS. We will mark the error
480        // as retryable and let the retry policy determine whether to retry or not. Whenever we
481        // define a default retry policy, we can skip retrying this case.
482        let response = request
483            .send()
484            .await
485            .map_err(|e| crate::errors::from_http_error(e, self.error_message()))?;
486        // Process the response
487        if !response.status().is_success() {
488            let err = crate::errors::from_http_response(response, self.error_message()).await;
489            return Err(err);
490        }
491        let response = response.json::<MDSTokenResponse>().await.map_err(|e| {
492            // Decoding errors are not transient. Typically they indicate a badly
493            // configured MDS endpoint, or DNS redirecting the request to a random
494            // server, e.g., ISPs that redirect unknown services to HTTP.
495            CredentialsError::from_source(!e.is_decode(), e)
496        })?;
497        let token = Token {
498            token: response.access_token,
499            token_type: response.token_type,
500            expires_at: response
501                .expires_in
502                .map(|d| Instant::now() + Duration::from_secs(d)),
503            metadata: None,
504        };
505        Ok(token)
506    }
507}
508
509#[cfg(test)]
510mod tests {
511    use super::*;
512    use crate::credentials::DEFAULT_UNIVERSE_DOMAIN;
513    use crate::credentials::QUOTA_PROJECT_KEY;
514    use crate::credentials::tests::{
515        find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
516        get_mock_backoff_policy, get_mock_retry_throttler, get_token_from_headers,
517        get_token_type_from_headers,
518    };
519    use crate::errors;
520    use crate::errors::CredentialsError;
521    use crate::token::tests::MockTokenProvider;
522    use http::HeaderValue;
523    use http::header::AUTHORIZATION;
524    use httptest::cycle;
525    use httptest::matchers::{all_of, contains, request, url_decoded};
526    use httptest::responders::{json_encoded, status_code};
527    use httptest::{Expectation, Server};
528    use reqwest::StatusCode;
529    use scoped_env::ScopedEnv;
530    use serial_test::{parallel, serial};
531    use std::error::Error;
532    use test_case::test_case;
533    use url::Url;
534
535    type TestResult = anyhow::Result<()>;
536
537    #[tokio::test]
538    #[parallel]
539    async fn test_mds_retries_on_transient_failures() -> TestResult {
540        let mut server = Server::run();
541        server.expect(
542            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
543                .times(3)
544                .respond_with(status_code(503)),
545        );
546
547        let provider = Builder::default()
548            .with_endpoint(format!("http://{}", server.addr()))
549            .with_retry_policy(get_mock_auth_retry_policy(3))
550            .with_backoff_policy(get_mock_backoff_policy())
551            .with_retry_throttler(get_mock_retry_throttler())
552            .build_token_provider();
553
554        let err = provider.token().await.unwrap_err();
555        assert!(!err.is_transient());
556        server.verify_and_clear();
557        Ok(())
558    }
559
560    #[tokio::test]
561    #[parallel]
562    async fn test_mds_does_not_retry_on_non_transient_failures() -> TestResult {
563        let mut server = Server::run();
564        server.expect(
565            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
566                .times(1)
567                .respond_with(status_code(401)),
568        );
569
570        let provider = Builder::default()
571            .with_endpoint(format!("http://{}", server.addr()))
572            .with_retry_policy(get_mock_auth_retry_policy(1))
573            .with_backoff_policy(get_mock_backoff_policy())
574            .with_retry_throttler(get_mock_retry_throttler())
575            .build_token_provider();
576
577        let err = provider.token().await.unwrap_err();
578        assert!(!err.is_transient());
579        server.verify_and_clear();
580        Ok(())
581    }
582
583    #[tokio::test]
584    #[parallel]
585    async fn test_mds_retries_for_success() -> TestResult {
586        let mut server = Server::run();
587        let response = MDSTokenResponse {
588            access_token: "test-access-token".to_string(),
589            expires_in: Some(3600),
590            token_type: "test-token-type".to_string(),
591        };
592
593        server.expect(
594            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
595                .times(3)
596                .respond_with(cycle![
597                    status_code(503).body("try-again"),
598                    status_code(503).body("try-again"),
599                    status_code(200)
600                        .append_header("Content-Type", "application/json")
601                        .body(serde_json::to_string(&response).unwrap()),
602                ]),
603        );
604
605        let provider = Builder::default()
606            .with_endpoint(format!("http://{}", server.addr()))
607            .with_retry_policy(get_mock_auth_retry_policy(3))
608            .with_backoff_policy(get_mock_backoff_policy())
609            .with_retry_throttler(get_mock_retry_throttler())
610            .build_token_provider();
611
612        let token = provider.token().await?;
613        assert_eq!(token.token, "test-access-token");
614
615        server.verify_and_clear();
616        Ok(())
617    }
618
619    #[test]
620    fn validate_default_endpoint_urls() {
621        let default_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}"));
622        assert!(default_endpoint_address.is_ok());
623
624        let token_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}/token"));
625        assert!(token_endpoint_address.is_ok());
626    }
627
628    #[tokio::test]
629    async fn headers_success() -> TestResult {
630        let token = Token {
631            token: "test-token".to_string(),
632            token_type: "Bearer".to_string(),
633            expires_at: None,
634            metadata: None,
635        };
636
637        let mut mock = MockTokenProvider::new();
638        mock.expect_token().times(1).return_once(|| Ok(token));
639
640        let mdsc = MDSCredentials {
641            quota_project_id: None,
642            token_provider: TokenCache::new(mock),
643        };
644
645        let mut extensions = Extensions::new();
646        let cached_headers = mdsc.headers(extensions.clone()).await.unwrap();
647        let (headers, entity_tag) = match cached_headers {
648            CacheableResource::New { entity_tag, data } => (data, entity_tag),
649            CacheableResource::NotModified => unreachable!("expecting new headers"),
650        };
651        let token = headers.get(AUTHORIZATION).unwrap();
652        assert_eq!(headers.len(), 1, "{headers:?}");
653        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
654        assert!(token.is_sensitive());
655
656        extensions.insert(entity_tag);
657
658        let cached_headers = mdsc.headers(extensions).await?;
659
660        match cached_headers {
661            CacheableResource::New { .. } => unreachable!("expecting new headers"),
662            CacheableResource::NotModified => CacheableResource::<HeaderMap>::NotModified,
663        };
664        Ok(())
665    }
666
667    #[tokio::test]
668    async fn access_token_success() -> TestResult {
669        let server = Server::run();
670        let response = MDSTokenResponse {
671            access_token: "test-access-token".to_string(),
672            expires_in: Some(3600),
673            token_type: "Bearer".to_string(),
674        };
675        server.expect(
676            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/token")),])
677                .respond_with(json_encoded(response)),
678        );
679
680        let creds = Builder::default()
681            .with_endpoint(format!("http://{}", server.addr()))
682            .build_access_token_credentials()
683            .unwrap();
684
685        let access_token = creds.access_token().await.unwrap();
686        assert_eq!(access_token.token, "test-access-token");
687
688        Ok(())
689    }
690
691    #[tokio::test]
692    async fn headers_failure() {
693        let mut mock = MockTokenProvider::new();
694        mock.expect_token()
695            .times(1)
696            .return_once(|| Err(errors::non_retryable_from_str("fail")));
697
698        let mdsc = MDSCredentials {
699            quota_project_id: None,
700            token_provider: TokenCache::new(mock),
701        };
702        assert!(mdsc.headers(Extensions::new()).await.is_err());
703    }
704
705    #[test]
706    fn error_message_with_adc() {
707        let provider = MDSAccessTokenProvider::builder()
708            .endpoint("http://127.0.0.1")
709            .created_by_adc(true)
710            .endpoint_overridden(false)
711            .build();
712
713        let want = MDS_NOT_FOUND_ERROR;
714        let got = provider.error_message();
715        assert!(got.contains(want), "{got}, {provider:?}");
716    }
717
718    #[test_case(false, false)]
719    #[test_case(false, true)]
720    #[test_case(true, true)]
721    fn error_message_without_adc(adc: bool, overridden: bool) {
722        let provider = MDSAccessTokenProvider::builder()
723            .endpoint("http://127.0.0.1")
724            .created_by_adc(adc)
725            .endpoint_overridden(overridden)
726            .build();
727
728        let not_want = MDS_NOT_FOUND_ERROR;
729        let got = provider.error_message();
730        assert!(!got.contains(not_want), "{got}, {provider:?}");
731    }
732
733    #[tokio::test]
734    #[serial]
735    async fn adc_no_mds() -> TestResult {
736        let Err(err) = Builder::from_adc().build_token_provider().token().await else {
737            // The environment has an MDS, skip the test.
738            return Ok(());
739        };
740
741        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
742        assert!(
743            original_err.to_string().contains("application-default"),
744            "display={err}, debug={err:?}"
745        );
746
747        Ok(())
748    }
749
750    #[tokio::test]
751    #[serial]
752    async fn adc_overridden_mds() -> TestResult {
753        let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, "metadata.overridden");
754
755        let err = Builder::from_adc()
756            .build_token_provider()
757            .token()
758            .await
759            .unwrap_err();
760
761        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);
762
763        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
764        assert!(original_err.is_transient());
765        assert!(
766            !original_err.to_string().contains("application-default"),
767            "display={err}, debug={err:?}"
768        );
769        let source = find_source_error::<reqwest::Error>(&err);
770        assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");
771
772        Ok(())
773    }
774
775    #[tokio::test]
776    #[serial]
777    async fn builder_no_mds() -> TestResult {
778        let Err(e) = Builder::default().build_token_provider().token().await else {
779            // The environment has an MDS, skip the test.
780            return Ok(());
781        };
782
783        let original_err = find_source_error::<CredentialsError>(&e).unwrap();
784        assert!(
785            !format!("{:?}", original_err.source()).contains("application-default"),
786            "{e:?}"
787        );
788
789        Ok(())
790    }
791
792    #[tokio::test]
793    #[serial]
794    async fn test_gce_metadata_host_env_var() -> TestResult {
795        let server = Server::run();
796        let scopes = ["scope1", "scope2"];
797        let response = MDSTokenResponse {
798            access_token: "test-access-token".to_string(),
799            expires_in: Some(3600),
800            token_type: "test-token-type".to_string(),
801        };
802        server.expect(
803            Expectation::matching(all_of![
804                request::path(format!("{MDS_DEFAULT_URI}/token")),
805                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
806            ])
807            .respond_with(json_encoded(response)),
808        );
809
810        let addr = server.addr().to_string();
811        let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, &addr);
812        let mdsc = Builder::default()
813            .with_scopes(["scope1", "scope2"])
814            .build()
815            .unwrap();
816        let headers = mdsc.headers(Extensions::new()).await.unwrap();
817        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);
818
819        assert_eq!(
820            get_token_from_headers(headers).unwrap(),
821            "test-access-token"
822        );
823        Ok(())
824    }
825
826    #[tokio::test]
827    #[parallel]
828    async fn headers_success_with_quota_project() -> TestResult {
829        let server = Server::run();
830        let scopes = ["scope1", "scope2"];
831        let response = MDSTokenResponse {
832            access_token: "test-access-token".to_string(),
833            expires_in: Some(3600),
834            token_type: "test-token-type".to_string(),
835        };
836        server.expect(
837            Expectation::matching(all_of![
838                request::path(format!("{MDS_DEFAULT_URI}/token")),
839                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
840            ])
841            .respond_with(json_encoded(response)),
842        );
843
844        let mdsc = Builder::default()
845            .with_scopes(["scope1", "scope2"])
846            .with_endpoint(format!("http://{}", server.addr()))
847            .with_quota_project_id("test-project")
848            .build()?;
849
850        let headers = get_headers_from_cache(mdsc.headers(Extensions::new()).await.unwrap())?;
851        let token = headers.get(AUTHORIZATION).unwrap();
852        let quota_project = headers.get(QUOTA_PROJECT_KEY).unwrap();
853
854        assert_eq!(headers.len(), 2, "{headers:?}");
855        assert_eq!(
856            token,
857            HeaderValue::from_static("test-token-type test-access-token")
858        );
859        assert!(token.is_sensitive());
860        assert_eq!(quota_project, HeaderValue::from_static("test-project"));
861        assert!(!quota_project.is_sensitive());
862
863        Ok(())
864    }
865
866    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
867    #[parallel]
868    async fn token_caching() -> TestResult {
869        let mut server = Server::run();
870        let scopes = vec!["scope1".to_string()];
871        let response = MDSTokenResponse {
872            access_token: "test-access-token".to_string(),
873            expires_in: Some(3600),
874            token_type: "test-token-type".to_string(),
875        };
876        server.expect(
877            Expectation::matching(all_of![
878                request::path(format!("{MDS_DEFAULT_URI}/token")),
879                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
880            ])
881            .times(1)
882            .respond_with(json_encoded(response)),
883        );
884
885        let mdsc = Builder::default()
886            .with_scopes(scopes)
887            .with_endpoint(format!("http://{}", server.addr()))
888            .build()?;
889        let headers = mdsc.headers(Extensions::new()).await?;
890        assert_eq!(
891            get_token_from_headers(headers).unwrap(),
892            "test-access-token"
893        );
894        let headers = mdsc.headers(Extensions::new()).await?;
895        assert_eq!(
896            get_token_from_headers(headers).unwrap(),
897            "test-access-token"
898        );
899
900        // validate that the inner token provider is called only once
901        server.verify_and_clear();
902
903        Ok(())
904    }
905
906    #[tokio::test(start_paused = true)]
907    #[parallel]
908    async fn token_provider_full() -> TestResult {
909        let server = Server::run();
910        let scopes = vec!["scope1".to_string()];
911        let response = MDSTokenResponse {
912            access_token: "test-access-token".to_string(),
913            expires_in: Some(3600),
914            token_type: "test-token-type".to_string(),
915        };
916        server.expect(
917            Expectation::matching(all_of![
918                request::path(format!("{MDS_DEFAULT_URI}/token")),
919                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
920            ])
921            .respond_with(json_encoded(response)),
922        );
923
924        let token = Builder::default()
925            .with_endpoint(format!("http://{}", server.addr()))
926            .with_scopes(scopes)
927            .build_token_provider()
928            .token()
929            .await?;
930
931        let now = tokio::time::Instant::now();
932        assert_eq!(token.token, "test-access-token");
933        assert_eq!(token.token_type, "test-token-type");
934        assert!(
935            token
936                .expires_at
937                .is_some_and(|d| d >= now + Duration::from_secs(3600))
938        );
939
940        Ok(())
941    }
942
943    #[tokio::test(start_paused = true)]
944    #[parallel]
945    async fn token_provider_full_no_scopes() -> TestResult {
946        let server = Server::run();
947        let response = MDSTokenResponse {
948            access_token: "test-access-token".to_string(),
949            expires_in: Some(3600),
950            token_type: "test-token-type".to_string(),
951        };
952        server.expect(
953            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
954                .respond_with(json_encoded(response)),
955        );
956
957        let token = Builder::default()
958            .with_endpoint(format!("http://{}", server.addr()))
959            .build_token_provider()
960            .token()
961            .await?;
962
963        let now = Instant::now();
964        assert_eq!(token.token, "test-access-token");
965        assert_eq!(token.token_type, "test-token-type");
966        assert!(
967            token
968                .expires_at
969                .is_some_and(|d| d == now + Duration::from_secs(3600))
970        );
971
972        Ok(())
973    }
974
975    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
976    #[parallel]
977    async fn credential_provider_full() -> TestResult {
978        let server = Server::run();
979        let scopes = vec!["scope1".to_string()];
980        let response = MDSTokenResponse {
981            access_token: "test-access-token".to_string(),
982            expires_in: None,
983            token_type: "test-token-type".to_string(),
984        };
985        server.expect(
986            Expectation::matching(all_of![
987                request::path(format!("{MDS_DEFAULT_URI}/token")),
988                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
989            ])
990            .respond_with(json_encoded(response)),
991        );
992
993        let mdsc = Builder::default()
994            .with_endpoint(format!("http://{}", server.addr()))
995            .with_scopes(scopes)
996            .build()?;
997        let headers = mdsc.headers(Extensions::new()).await?;
998        assert_eq!(
999            get_token_from_headers(headers.clone()).unwrap(),
1000            "test-access-token"
1001        );
1002        assert_eq!(
1003            get_token_type_from_headers(headers).unwrap(),
1004            "test-token-type"
1005        );
1006
1007        Ok(())
1008    }
1009
1010    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1011    #[parallel]
1012    async fn credentials_headers_retryable_error() -> TestResult {
1013        let server = Server::run();
1014        let scopes = vec!["scope1".to_string()];
1015        server.expect(
1016            Expectation::matching(all_of![
1017                request::path(format!("{MDS_DEFAULT_URI}/token")),
1018                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
1019            ])
1020            .respond_with(status_code(503)),
1021        );
1022
1023        let mdsc = Builder::default()
1024            .with_endpoint(format!("http://{}", server.addr()))
1025            .with_scopes(scopes)
1026            .build()?;
1027        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
1028        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
1029        assert!(original_err.is_transient());
1030        let source = find_source_error::<reqwest::Error>(&err);
1031        assert!(
1032            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
1033            "{err:?}"
1034        );
1035
1036        Ok(())
1037    }
1038
1039    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1040    #[parallel]
1041    async fn credentials_headers_nonretryable_error() -> TestResult {
1042        let server = Server::run();
1043        let scopes = vec!["scope1".to_string()];
1044        server.expect(
1045            Expectation::matching(all_of![
1046                request::path(format!("{MDS_DEFAULT_URI}/token")),
1047                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
1048            ])
1049            .respond_with(status_code(401)),
1050        );
1051
1052        let mdsc = Builder::default()
1053            .with_endpoint(format!("http://{}", server.addr()))
1054            .with_scopes(scopes)
1055            .build()?;
1056
1057        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
1058        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
1059        assert!(!original_err.is_transient());
1060        let source = find_source_error::<reqwest::Error>(&err);
1061        assert!(
1062            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
1063            "{err:?}"
1064        );
1065
1066        Ok(())
1067    }
1068
1069    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1070    #[parallel]
1071    async fn credentials_headers_malformed_response_is_nonretryable() -> TestResult {
1072        let server = Server::run();
1073        let scopes = vec!["scope1".to_string()];
1074        server.expect(
1075            Expectation::matching(all_of![
1076                request::path(format!("{MDS_DEFAULT_URI}/token")),
1077                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
1078            ])
1079            .respond_with(json_encoded("bad json")),
1080        );
1081
1082        let mdsc = Builder::default()
1083            .with_endpoint(format!("http://{}", server.addr()))
1084            .with_scopes(scopes)
1085            .build()?;
1086
1087        let e = mdsc.headers(Extensions::new()).await.err().unwrap();
1088        assert!(!e.is_transient());
1089
1090        Ok(())
1091    }
1092
1093    #[tokio::test]
1094    async fn get_default_universe_domain_success() -> TestResult {
1095        let universe_domain_response = Builder::default().build()?.universe_domain().await.unwrap();
1096        assert_eq!(universe_domain_response, DEFAULT_UNIVERSE_DOMAIN);
1097        Ok(())
1098    }
1099
1100    #[tokio::test]
1101    async fn get_mds_signer() -> TestResult {
1102        use base64::{Engine, prelude::BASE64_STANDARD};
1103        use serde_json::json;
1104
1105        let server = Server::run();
1106        server.expect(
1107            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/token")),])
1108                .respond_with(json_encoded(MDSTokenResponse {
1109                    access_token: "test-access-token".to_string(),
1110                    expires_in: None,
1111                    token_type: "Bearer".to_string(),
1112                })),
1113        );
1114        server.expect(
1115            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/email")),])
1116                .respond_with(status_code(200).body("test-client-email")),
1117        );
1118        server.expect(
1119            Expectation::matching(all_of![
1120                request::method_path(
1121                    "POST",
1122                    "/v1/projects/-/serviceAccounts/test-client-email:signBlob"
1123                ),
1124                request::headers(contains(("authorization", "Bearer test-access-token"))),
1125            ])
1126            .respond_with(json_encoded(json!({
1127                "signedBlob": BASE64_STANDARD.encode("signed_blob"),
1128            }))),
1129        );
1130
1131        let endpoint = server.url("").to_string().trim_end_matches('/').to_string();
1132
1133        let signer = Builder::default()
1134            .with_endpoint(&endpoint)
1135            .build_signer_with_iam_endpoint_override(Some(endpoint))?;
1136
1137        let client_email = signer.client_email().await?;
1138        assert_eq!(client_email, "test-client-email");
1139
1140        let signature = signer.sign(b"test").await?;
1141        assert_eq!(signature.as_ref(), b"signed_blob");
1142
1143        Ok(())
1144    }
1145}