Skip to main content

google_cloud_auth/credentials/
mds.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! [Metadata Service] Credentials type.
16//!
17//! Google Cloud environments such as [Google Compute Engine (GCE)][gce-link],
18//! [Google Kubernetes Engine (GKE)][gke-link], or [Cloud Run] provide a metadata service.
19//! This is a local service to the VM (or pod) which (as the name implies) provides
20//! metadata information about the VM. The service also provides access
21//! tokens associated with the [default service account] for the corresponding
22//! VM.
23//!
24//! The default host name of the metadata service is `metadata.google.internal`.
25//! If you would like to use a different hostname, you can set it using the
26//! `GCE_METADATA_HOST` environment variable.
27//!
28//! You can use this access token to securely authenticate with Google Cloud,
29//! without having to download secrets or other credentials. The types in this
30//! module allow you to retrieve these access tokens, and can be used with
31//! the Google Cloud client libraries for Rust.
32//!
33//! ## Example: Creating credentials with a custom quota project
34//!
35//! ```
36//! # use google_cloud_auth::credentials::mds::Builder;
37//! # use google_cloud_auth::credentials::Credentials;
38//! # use http::Extensions;
39//! # tokio_test::block_on(async {
40//! let credentials: Credentials = Builder::default()
41//!     .with_quota_project_id("my-quota-project")
42//!     .build()?;
43//! let headers = credentials.headers(Extensions::new()).await?;
44//! println!("Headers: {headers:?}");
45//! # Ok::<(), anyhow::Error>(())
46//! # });
47//! ```
48//!
49//! ## Example: Creating credentials with custom retry behavior
50//!
51//! ```
52//! # use google_cloud_auth::credentials::mds::Builder;
53//! # use google_cloud_auth::credentials::Credentials;
54//! # use http::Extensions;
55//! # use std::time::Duration;
56//! # tokio_test::block_on(async {
57//! use google_cloud_gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
58//! use google_cloud_gax::exponential_backoff::ExponentialBackoff;
59//! let backoff = ExponentialBackoff::default();
60//! let credentials: Credentials = Builder::default()
61//!     .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
62//!     .with_backoff_policy(backoff)
63//!     .build()?;
64//! let headers = credentials.headers(Extensions::new()).await?;
65//! println!("Headers: {headers:?}");
66//! # Ok::<(), anyhow::Error>(())
67//! # });
68//! ```
69//!
70//! [Application Default Credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
71//! [Cloud Run]: https://cloud.google.com/run
72//! [default service account]: https://cloud.google.com/iam/docs/service-account-types#default
73//! [gce-link]: https://cloud.google.com/products/compute
74//! [gke-link]: https://cloud.google.com/kubernetes-engine
75//! [Metadata Service]: https://cloud.google.com/compute/docs/metadata/overview
76
77use crate::credentials::dynamic::{AccessTokenCredentialsProvider, CredentialsProvider};
78use crate::credentials::{AccessToken, AccessTokenCredentials, CacheableResource, Credentials};
79use crate::headers_util::AuthHeadersBuilder;
80use crate::mds::client::Client as MDSClient;
81use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry};
82use crate::token::{CachedTokenProvider, Token, TokenProvider};
83use crate::token_cache::TokenCache;
84use crate::{BuildResult, Result};
85use async_trait::async_trait;
86use google_cloud_gax::backoff_policy::BackoffPolicyArg;
87use google_cloud_gax::error::CredentialsError;
88use google_cloud_gax::retry_policy::RetryPolicyArg;
89use google_cloud_gax::retry_throttler::RetryThrottlerArg;
90use http::{Extensions, HeaderMap};
91use std::default::Default;
92use std::sync::Arc;
93
// TODO(#2235) - Improve this message by talking about retries when really running with MDS
/// Detailed error message reported when credentials created through ADC (and
/// still pointing at the default metadata endpoint) fail to fetch a token.
///
/// The text steers users toward configuring local credentials, since reaching
/// this state usually means the code is not running inside Google Cloud.
const MDS_NOT_FOUND_ERROR: &str = "Could not fetch an auth token to authenticate with Google Cloud. \
    The most common reason for this problem is that you are not running in a Google Cloud Environment \
    and you have not configured local credentials for development and testing. \
    To setup local credentials, run `gcloud auth application-default login`. \
    More information on how to authenticate client libraries can be found at https://cloud.google.com/docs/authentication/client-libraries";
102
/// Credentials backed by the metadata service.
///
/// Pairs a token provider with an optional quota project. The
/// `CachedTokenProvider` requirement lives on the `impl` blocks that actually
/// need it rather than being duplicated on the struct definition.
#[derive(Debug)]
struct MDSCredentials<T> {
    /// Quota project added to the generated headers, if configured.
    quota_project_id: Option<String>,
    /// Source of access tokens; `Builder` wraps the MDS provider in a `TokenCache`.
    token_provider: T,
}
111
112/// Creates [Credentials] instances backed by the [Metadata Service].
113///
114/// While the Google Cloud client libraries for Rust default to credentials
115/// backed by the metadata service, some applications may need to:
116/// * Customize the metadata service credentials in some way
117/// * Bypass the [Application Default Credentials] lookup and only
118///   use the metadata server credentials
119/// * Use the credentials directly outside the client libraries
120///
121/// [Application Default Credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
122/// [Metadata Service]: https://cloud.google.com/compute/docs/metadata/overview
123#[derive(Debug, Default)]
124pub struct Builder {
125    endpoint: Option<String>,
126    quota_project_id: Option<String>,
127    scopes: Option<Vec<String>>,
128    created_by_adc: bool,
129    retry_builder: RetryTokenProviderBuilder,
130}
131
132impl Builder {
133    /// Sets the endpoint for this credentials.
134    ///
135    /// A trailing slash is significant, so specify the base URL without a trailing
136    /// slash. If not set, the credentials use `http://metadata.google.internal`.
137    ///
138    /// # Example
139    /// ```
140    /// # use google_cloud_auth::credentials::mds::Builder;
141    /// # tokio_test::block_on(async {
142    /// let credentials = Builder::default()
143    ///     .with_endpoint("https://metadata.google.foobar")
144    ///     .build();
145    /// # });
146    /// ```
147    pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
148        self.endpoint = Some(endpoint.into());
149        self
150    }
151
152    /// Set the [quota project] for this credentials.
153    ///
154    /// In some services, you can use a service account in
155    /// one project for authentication and authorization, and charge
156    /// the usage to a different project. This may require that the
157    /// service account has `serviceusage.services.use` permissions on the quota project.
158    ///
159    /// [quota project]: https://cloud.google.com/docs/quotas/quota-project
160    pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
161        self.quota_project_id = Some(quota_project_id.into());
162        self
163    }
164
165    /// Sets the [scopes] for this credentials.
166    ///
167    /// Metadata server issues tokens based on the requested scopes.
168    /// If no scopes are specified, the credentials defaults to all
169    /// scopes configured for the [default service account] on the instance.
170    ///
171    /// [default service account]: https://cloud.google.com/iam/docs/service-account-types#default
172    /// [scopes]: https://developers.google.com/identity/protocols/oauth2/scopes
173    pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
174    where
175        I: IntoIterator<Item = S>,
176        S: Into<String>,
177    {
178        self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
179        self
180    }
181
182    /// Configure the retry policy for fetching tokens.
183    ///
184    /// The retry policy controls how to handle retries, and sets limits on
185    /// the number of attempts or the total time spent retrying.
186    ///
187    /// ```
188    /// # use google_cloud_auth::credentials::mds::Builder;
189    /// # tokio_test::block_on(async {
190    /// use google_cloud_gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
191    /// let credentials = Builder::default()
192    ///     .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
193    ///     .build();
194    /// # });
195    /// ```
196    pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
197        self.retry_builder = self.retry_builder.with_retry_policy(v.into());
198        self
199    }
200
201    /// Configure the retry backoff policy.
202    ///
203    /// The backoff policy controls how long to wait in between retry attempts.
204    ///
205    /// ```
206    /// # use google_cloud_auth::credentials::mds::Builder;
207    /// # use std::time::Duration;
208    /// # tokio_test::block_on(async {
209    /// use google_cloud_gax::exponential_backoff::ExponentialBackoff;
210    /// let policy = ExponentialBackoff::default();
211    /// let credentials = Builder::default()
212    ///     .with_backoff_policy(policy)
213    ///     .build();
214    /// # });
215    /// ```
216    pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
217        self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
218        self
219    }
220
221    /// Configure the retry throttler.
222    ///
223    /// Advanced applications may want to configure a retry throttler to
224    /// [Address Cascading Failures] and when [Handling Overload] conditions.
225    /// The authentication library throttles its retry loop, using a policy to
226    /// control the throttling algorithm. Use this method to fine tune or
227    /// customize the default retry throttler.
228    ///
229    /// [Handling Overload]: https://sre.google/sre-book/handling-overload/
230    /// [Address Cascading Failures]: https://sre.google/sre-book/addressing-cascading-failures/
231    ///
232    /// ```
233    /// # use google_cloud_auth::credentials::mds::Builder;
234    /// # tokio_test::block_on(async {
235    /// use google_cloud_gax::retry_throttler::AdaptiveThrottler;
236    /// let credentials = Builder::default()
237    ///     .with_retry_throttler(AdaptiveThrottler::default())
238    ///     .build();
239    /// # });
240    /// ```
241    pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
242        self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
243        self
244    }
245
246    // This method is used to build mds credentials from ADC
247    pub(crate) fn from_adc() -> Self {
248        Self {
249            created_by_adc: true,
250            ..Default::default()
251        }
252    }
253
254    fn build_token_provider(self) -> TokenProviderWithRetry<MDSAccessTokenProvider> {
255        let tp = MDSAccessTokenProvider::builder()
256            .endpoint(self.endpoint)
257            .maybe_scopes(self.scopes)
258            .created_by_adc(self.created_by_adc)
259            .build();
260        self.retry_builder.build(tp)
261    }
262
263    /// Returns a [Credentials] instance with the configured settings.
264    pub fn build(self) -> BuildResult<Credentials> {
265        Ok(self.build_access_token_credentials()?.into())
266    }
267
268    /// Returns an [AccessTokenCredentials] instance with the configured settings.
269    ///
270    /// # Example
271    ///
272    /// ```
273    /// # use google_cloud_auth::credentials::mds::Builder;
274    /// # use google_cloud_auth::credentials::{AccessTokenCredentials, AccessTokenCredentialsProvider};
275    /// # tokio_test::block_on(async {
276    /// let credentials: AccessTokenCredentials = Builder::default()
277    ///     .with_quota_project_id("my-quota-project")
278    ///     .build_access_token_credentials()?;
279    /// let access_token = credentials.access_token().await?;
280    /// println!("Token: {}", access_token.token);
281    /// # Ok::<(), anyhow::Error>(())
282    /// # });
283    /// ```
284    pub fn build_access_token_credentials(self) -> BuildResult<AccessTokenCredentials> {
285        let mdsc = MDSCredentials {
286            quota_project_id: self.quota_project_id.clone(),
287            token_provider: TokenCache::new(self.build_token_provider()),
288        };
289        Ok(AccessTokenCredentials {
290            inner: Arc::new(mdsc),
291        })
292    }
293
294    /// Returns a [crate::signer::Signer] instance with the configured settings.
295    ///
296    /// The returned [crate::signer::Signer] uses the [IAM signBlob API] to sign content. This API
297    /// requires a network request for each signing operation.
298    ///
299    /// # Example
300    ///
301    /// ```
302    /// # use google_cloud_auth::credentials::mds::Builder;
303    /// # use google_cloud_auth::signer::Signer;
304    /// # tokio_test::block_on(async {
305    /// let signer: Signer = Builder::default().build_signer()?;
306    /// # Ok::<(), anyhow::Error>(())
307    /// # });
308    /// ```
309    ///
310    /// [IAM signBlob API]: https://cloud.google.com/iam/docs/reference/credentials/rest/v1/projects.serviceAccounts/signBlob
311    pub fn build_signer(self) -> BuildResult<crate::signer::Signer> {
312        self.build_signer_with_iam_endpoint_override(None)
313    }
314
315    // only used for testing
316    fn build_signer_with_iam_endpoint_override(
317        self,
318        iam_endpoint: Option<String>,
319    ) -> BuildResult<crate::signer::Signer> {
320        let client = MDSClient::new(self.endpoint.clone());
321        let credentials = self.build()?;
322        let signing_provider = crate::signer::mds::MDSSigner::new(client, credentials);
323        let signing_provider = iam_endpoint
324            .iter()
325            .fold(signing_provider, |signing_provider, endpoint| {
326                signing_provider.with_iam_endpoint_override(endpoint)
327            });
328        Ok(crate::signer::Signer {
329            inner: Arc::new(signing_provider),
330        })
331    }
332}
333
334#[async_trait::async_trait]
335impl<T> CredentialsProvider for MDSCredentials<T>
336where
337    T: CachedTokenProvider,
338{
339    async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
340        let token = self.token_provider.token(extensions).await?;
341
342        AuthHeadersBuilder::new(&token)
343            .maybe_quota_project_id(self.quota_project_id.as_deref())
344            .build()
345    }
346}
347
348#[async_trait::async_trait]
349impl<T> AccessTokenCredentialsProvider for MDSCredentials<T>
350where
351    T: CachedTokenProvider,
352{
353    async fn access_token(&self) -> Result<AccessToken> {
354        let token = self.token_provider.token(Extensions::new()).await?;
355        token.into()
356    }
357}
358
/// Accumulates settings for an `MDSAccessTokenProvider` before it is built.
#[derive(Default, Debug)]
struct MDSAccessTokenProviderBuilder {
    // Explicit scopes to request; `None` means none are requested explicitly.
    scopes: Option<Vec<String>>,
    // Metadata service base URL; `None` is handed to the client as-is.
    endpoint: Option<String>,
    // Whether the provider is being created through ADC.
    created_by_adc: bool,
}
365
366impl MDSAccessTokenProviderBuilder {
367    fn build(self) -> MDSAccessTokenProvider {
368        MDSAccessTokenProvider {
369            client: MDSClient::new(self.endpoint),
370            scopes: self.scopes,
371            created_by_adc: self.created_by_adc,
372        }
373    }
374
375    fn maybe_scopes(mut self, v: Option<Vec<String>>) -> Self {
376        self.scopes = v;
377        self
378    }
379
380    fn endpoint<T>(mut self, v: Option<T>) -> Self
381    where
382        T: Into<String>,
383    {
384        self.endpoint = v.map(Into::into);
385        self
386    }
387
388    fn created_by_adc(mut self, v: bool) -> Self {
389        self.created_by_adc = v;
390        self
391    }
392}
393
394#[derive(Debug, Clone)]
395struct MDSAccessTokenProvider {
396    scopes: Option<Vec<String>>,
397    client: MDSClient,
398    created_by_adc: bool,
399}
400
401impl MDSAccessTokenProvider {
402    fn builder() -> MDSAccessTokenProviderBuilder {
403        MDSAccessTokenProviderBuilder::default()
404    }
405
406    // During ADC, if no credentials are found in the well-known location and the GOOGLE_APPLICATION_CREDENTIALS
407    // environment variable is not set, we default to MDS credentials without checking if the code is really
408    // running in an environment with MDS. To help users who got to this state because of lack of credentials
409    // setup on their machines, we provide a detailed error message to them talking about local setup and other
410    // auth mechanisms available to them.
411    // If the endpoint is overridden, even if ADC was used to create the MDS credentials, we do not give a detailed
412    // error message because they deliberately wanted to use an MDS.
413    fn error_message(&self) -> &str {
414        if self.use_adc_message() {
415            MDS_NOT_FOUND_ERROR
416        } else {
417            "failed to fetch token"
418        }
419    }
420
421    fn use_adc_message(&self) -> bool {
422        self.created_by_adc && self.client.is_default_endpoint
423    }
424}
425
426#[async_trait]
427impl TokenProvider for MDSAccessTokenProvider {
428    async fn token(&self) -> Result<Token> {
429        self.client
430            .access_token(self.scopes.clone())
431            .await
432            .map_err(|e| CredentialsError::new(e.is_transient(), self.error_message(), e))
433    }
434}
435
436#[cfg(test)]
437mod tests {
438    use super::*;
439    use crate::credentials::DEFAULT_UNIVERSE_DOMAIN;
440    use crate::credentials::QUOTA_PROJECT_KEY;
441    use crate::credentials::tests::{
442        find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
443        get_mock_backoff_policy, get_mock_retry_throttler, get_token_from_headers,
444        get_token_type_from_headers,
445    };
446    use crate::errors;
447    use crate::errors::CredentialsError;
448    use crate::mds::client::MDSTokenResponse;
449    use crate::mds::{GCE_METADATA_HOST_ENV_VAR, MDS_DEFAULT_URI, METADATA_ROOT};
450    use crate::token::tests::MockTokenProvider;
451    use http::HeaderValue;
452    use http::header::AUTHORIZATION;
453    use httptest::cycle;
454    use httptest::matchers::{all_of, contains, request, url_decoded};
455    use httptest::responders::{json_encoded, status_code};
456    use httptest::{Expectation, Server};
457    use reqwest::StatusCode;
458    use scoped_env::ScopedEnv;
459    use serial_test::{parallel, serial};
460    use std::error::Error;
461    use std::time::Duration;
462    use test_case::test_case;
463    use tokio::time::Instant;
464    use url::Url;
465
466    type TestResult = anyhow::Result<()>;
467
468    #[tokio::test]
469    #[parallel]
470    async fn test_mds_retries_on_transient_failures() -> TestResult {
471        let mut server = Server::run();
472        server.expect(
473            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
474                .times(3)
475                .respond_with(status_code(503)),
476        );
477
478        let provider = Builder::default()
479            .with_endpoint(format!("http://{}", server.addr()))
480            .with_retry_policy(get_mock_auth_retry_policy(3))
481            .with_backoff_policy(get_mock_backoff_policy())
482            .with_retry_throttler(get_mock_retry_throttler())
483            .build_token_provider();
484
485        let err = provider.token().await.unwrap_err();
486        assert!(err.is_transient(), "{err:?}");
487        server.verify_and_clear();
488        Ok(())
489    }
490
491    #[tokio::test]
492    #[parallel]
493    async fn test_mds_does_not_retry_on_non_transient_failures() -> TestResult {
494        let mut server = Server::run();
495        server.expect(
496            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
497                .times(1)
498                .respond_with(status_code(401)),
499        );
500
501        let provider = Builder::default()
502            .with_endpoint(format!("http://{}", server.addr()))
503            .with_retry_policy(get_mock_auth_retry_policy(1))
504            .with_backoff_policy(get_mock_backoff_policy())
505            .with_retry_throttler(get_mock_retry_throttler())
506            .build_token_provider();
507
508        let err = provider.token().await.unwrap_err();
509        assert!(!err.is_transient());
510        server.verify_and_clear();
511        Ok(())
512    }
513
514    #[tokio::test]
515    #[parallel]
516    async fn test_mds_retries_for_success() -> TestResult {
517        let mut server = Server::run();
518        let response = MDSTokenResponse {
519            access_token: "test-access-token".to_string(),
520            expires_in: Some(3600),
521            token_type: "test-token-type".to_string(),
522        };
523
524        server.expect(
525            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
526                .times(3)
527                .respond_with(cycle![
528                    status_code(503).body("try-again"),
529                    status_code(503).body("try-again"),
530                    status_code(200)
531                        .append_header("Content-Type", "application/json")
532                        .body(serde_json::to_string(&response).unwrap()),
533                ]),
534        );
535
536        let provider = Builder::default()
537            .with_endpoint(format!("http://{}", server.addr()))
538            .with_retry_policy(get_mock_auth_retry_policy(3))
539            .with_backoff_policy(get_mock_backoff_policy())
540            .with_retry_throttler(get_mock_retry_throttler())
541            .build_token_provider();
542
543        let token = provider.token().await?;
544        assert_eq!(token.token, "test-access-token");
545
546        server.verify_and_clear();
547        Ok(())
548    }
549
550    #[test]
551    #[parallel]
552    fn validate_default_endpoint_urls() {
553        let default_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}"));
554        assert!(
555            default_endpoint_address.is_ok(),
556            "{default_endpoint_address:?}"
557        );
558
559        let token_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}/token"));
560        assert!(token_endpoint_address.is_ok(), "{token_endpoint_address:?}");
561    }
562
563    #[tokio::test]
564    #[parallel]
565    async fn headers_success() -> TestResult {
566        let token = Token {
567            token: "test-token".to_string(),
568            token_type: "Bearer".to_string(),
569            expires_at: None,
570            metadata: None,
571        };
572
573        let mut mock = MockTokenProvider::new();
574        mock.expect_token().times(1).return_once(|| Ok(token));
575
576        let mdsc = MDSCredentials {
577            quota_project_id: None,
578            token_provider: TokenCache::new(mock),
579        };
580
581        let mut extensions = Extensions::new();
582        let cached_headers = mdsc.headers(extensions.clone()).await.unwrap();
583        let (headers, entity_tag) = match cached_headers {
584            CacheableResource::New { entity_tag, data } => (data, entity_tag),
585            CacheableResource::NotModified => unreachable!("expecting new headers"),
586        };
587        let token = headers.get(AUTHORIZATION).unwrap();
588        assert_eq!(headers.len(), 1, "{headers:?}");
589        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
590        assert!(token.is_sensitive());
591
592        extensions.insert(entity_tag);
593
594        let cached_headers = mdsc.headers(extensions).await?;
595
596        match cached_headers {
597            CacheableResource::New { .. } => unreachable!("expecting new headers"),
598            CacheableResource::NotModified => CacheableResource::<HeaderMap>::NotModified,
599        };
600        Ok(())
601    }
602
603    #[tokio::test]
604    #[parallel]
605    async fn access_token_success() -> TestResult {
606        let server = Server::run();
607        let response = MDSTokenResponse {
608            access_token: "test-access-token".to_string(),
609            expires_in: Some(3600),
610            token_type: "Bearer".to_string(),
611        };
612        server.expect(
613            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/token")),])
614                .respond_with(json_encoded(response)),
615        );
616
617        let creds = Builder::default()
618            .with_endpoint(format!("http://{}", server.addr()))
619            .build_access_token_credentials()
620            .unwrap();
621
622        let access_token = creds.access_token().await.unwrap();
623        assert_eq!(access_token.token, "test-access-token");
624
625        Ok(())
626    }
627
628    #[tokio::test]
629    #[parallel]
630    async fn headers_failure() {
631        let mut mock = MockTokenProvider::new();
632        mock.expect_token()
633            .times(1)
634            .return_once(|| Err(errors::non_retryable_from_str("fail")));
635
636        let mdsc = MDSCredentials {
637            quota_project_id: None,
638            token_provider: TokenCache::new(mock),
639        };
640        let result = mdsc.headers(Extensions::new()).await;
641        assert!(result.is_err(), "{result:?}");
642    }
643
644    #[test]
645    #[parallel]
646    fn error_message_with_adc() {
647        let provider = MDSAccessTokenProvider::builder()
648            .created_by_adc(true)
649            .build();
650
651        let want = MDS_NOT_FOUND_ERROR;
652        let got = provider.error_message();
653        assert!(got.contains(want), "{got}, {provider:?}");
654    }
655
656    #[test_case(false, false)]
657    #[test_case(false, true)]
658    #[test_case(true, true)]
659    fn error_message_without_adc(adc: bool, overridden: bool) {
660        let endpoint = if overridden {
661            Some("http://127.0.0.1")
662        } else {
663            None
664        };
665        let provider = MDSAccessTokenProvider::builder()
666            .endpoint(endpoint)
667            .created_by_adc(adc)
668            .build();
669
670        let not_want = MDS_NOT_FOUND_ERROR;
671        let got = provider.error_message();
672        assert!(!got.contains(not_want), "{got}, {provider:?}");
673    }
674
675    #[tokio::test]
676    #[serial]
677    async fn adc_no_mds() -> TestResult {
678        let Err(err) = Builder::from_adc().build_token_provider().token().await else {
679            // The environment has an MDS, skip the test.
680            return Ok(());
681        };
682
683        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
684        assert!(
685            original_err.to_string().contains("application-default"),
686            "display={err}, debug={err:?}"
687        );
688
689        Ok(())
690    }
691
692    #[tokio::test]
693    #[serial]
694    async fn adc_overridden_mds() -> TestResult {
695        let _e = ScopedEnv::set(GCE_METADATA_HOST_ENV_VAR, "metadata.overridden");
696
697        let err = Builder::from_adc()
698            .build_token_provider()
699            .token()
700            .await
701            .unwrap_err();
702
703        let _e = ScopedEnv::remove(GCE_METADATA_HOST_ENV_VAR);
704
705        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
706        assert!(original_err.is_transient());
707        assert!(
708            !original_err.to_string().contains("application-default"),
709            "display={err}, debug={err:?}"
710        );
711        let source = find_source_error::<reqwest::Error>(&err);
712        assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");
713
714        Ok(())
715    }
716
717    #[tokio::test]
718    #[serial]
719    async fn builder_no_mds() -> TestResult {
720        let Err(e) = Builder::default().build_token_provider().token().await else {
721            // The environment has an MDS, skip the test.
722            return Ok(());
723        };
724
725        let original_err = find_source_error::<CredentialsError>(&e).unwrap();
726        assert!(
727            !format!("{:?}", original_err.source()).contains("application-default"),
728            "{e:?}"
729        );
730
731        Ok(())
732    }
733
734    #[tokio::test]
735    #[serial]
736    async fn test_gce_metadata_host_env_var() -> TestResult {
737        let server = Server::run();
738        let scopes = ["scope1", "scope2"];
739        let response = MDSTokenResponse {
740            access_token: "test-access-token".to_string(),
741            expires_in: Some(3600),
742            token_type: "test-token-type".to_string(),
743        };
744        server.expect(
745            Expectation::matching(all_of![
746                request::path(format!("{MDS_DEFAULT_URI}/token")),
747                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
748            ])
749            .respond_with(json_encoded(response)),
750        );
751
752        let addr = server.addr().to_string();
753        let _e = ScopedEnv::set(GCE_METADATA_HOST_ENV_VAR, &addr);
754        let mdsc = Builder::default()
755            .with_scopes(["scope1", "scope2"])
756            .build()
757            .unwrap();
758        let headers = mdsc.headers(Extensions::new()).await.unwrap();
759        let _e = ScopedEnv::remove(GCE_METADATA_HOST_ENV_VAR);
760
761        assert_eq!(
762            get_token_from_headers(headers).unwrap(),
763            "test-access-token"
764        );
765        Ok(())
766    }
767
768    #[tokio::test]
769    #[parallel]
770    async fn headers_success_with_quota_project() -> TestResult {
771        let server = Server::run();
772        let scopes = ["scope1", "scope2"];
773        let response = MDSTokenResponse {
774            access_token: "test-access-token".to_string(),
775            expires_in: Some(3600),
776            token_type: "test-token-type".to_string(),
777        };
778        server.expect(
779            Expectation::matching(all_of![
780                request::path(format!("{MDS_DEFAULT_URI}/token")),
781                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
782            ])
783            .respond_with(json_encoded(response)),
784        );
785
786        let mdsc = Builder::default()
787            .with_scopes(["scope1", "scope2"])
788            .with_endpoint(format!("http://{}", server.addr()))
789            .with_quota_project_id("test-project")
790            .build()?;
791
792        let headers = get_headers_from_cache(mdsc.headers(Extensions::new()).await.unwrap())?;
793        let token = headers.get(AUTHORIZATION).unwrap();
794        let quota_project = headers.get(QUOTA_PROJECT_KEY).unwrap();
795
796        assert_eq!(headers.len(), 2, "{headers:?}");
797        assert_eq!(
798            token,
799            HeaderValue::from_static("test-token-type test-access-token")
800        );
801        assert!(token.is_sensitive());
802        assert_eq!(quota_project, HeaderValue::from_static("test-project"));
803        assert!(!quota_project.is_sensitive());
804
805        Ok(())
806    }
807
808    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
809    #[parallel]
810    async fn token_caching() -> TestResult {
811        let mut server = Server::run();
812        let scopes = vec!["scope1".to_string()];
813        let response = MDSTokenResponse {
814            access_token: "test-access-token".to_string(),
815            expires_in: Some(3600),
816            token_type: "test-token-type".to_string(),
817        };
818        server.expect(
819            Expectation::matching(all_of![
820                request::path(format!("{MDS_DEFAULT_URI}/token")),
821                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
822            ])
823            .times(1)
824            .respond_with(json_encoded(response)),
825        );
826
827        let mdsc = Builder::default()
828            .with_scopes(scopes)
829            .with_endpoint(format!("http://{}", server.addr()))
830            .build()?;
831        let headers = mdsc.headers(Extensions::new()).await?;
832        assert_eq!(
833            get_token_from_headers(headers).unwrap(),
834            "test-access-token"
835        );
836        let headers = mdsc.headers(Extensions::new()).await?;
837        assert_eq!(
838            get_token_from_headers(headers).unwrap(),
839            "test-access-token"
840        );
841
842        // validate that the inner token provider is called only once
843        server.verify_and_clear();
844
845        Ok(())
846    }
847
848    #[tokio::test(start_paused = true)]
849    #[parallel]
850    async fn token_provider_full() -> TestResult {
851        let server = Server::run();
852        let scopes = vec!["scope1".to_string()];
853        let response = MDSTokenResponse {
854            access_token: "test-access-token".to_string(),
855            expires_in: Some(3600),
856            token_type: "test-token-type".to_string(),
857        };
858        server.expect(
859            Expectation::matching(all_of![
860                request::path(format!("{MDS_DEFAULT_URI}/token")),
861                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
862            ])
863            .respond_with(json_encoded(response)),
864        );
865
866        let token = Builder::default()
867            .with_endpoint(format!("http://{}", server.addr()))
868            .with_scopes(scopes)
869            .build_token_provider()
870            .token()
871            .await?;
872
873        let now = tokio::time::Instant::now();
874        assert_eq!(token.token, "test-access-token");
875        assert_eq!(token.token_type, "test-token-type");
876        assert!(
877            token
878                .expires_at
879                .is_some_and(|d| d >= now + Duration::from_secs(3600))
880        );
881
882        Ok(())
883    }
884
885    #[tokio::test(start_paused = true)]
886    #[parallel]
887    async fn token_provider_full_no_scopes() -> TestResult {
888        let server = Server::run();
889        let response = MDSTokenResponse {
890            access_token: "test-access-token".to_string(),
891            expires_in: Some(3600),
892            token_type: "test-token-type".to_string(),
893        };
894        server.expect(
895            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
896                .respond_with(json_encoded(response)),
897        );
898
899        let token = Builder::default()
900            .with_endpoint(format!("http://{}", server.addr()))
901            .build_token_provider()
902            .token()
903            .await?;
904
905        let now = Instant::now();
906        assert_eq!(token.token, "test-access-token");
907        assert_eq!(token.token_type, "test-token-type");
908        assert!(
909            token
910                .expires_at
911                .is_some_and(|d| d == now + Duration::from_secs(3600))
912        );
913
914        Ok(())
915    }
916
917    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
918    #[parallel]
919    async fn credential_provider_full() -> TestResult {
920        let server = Server::run();
921        let scopes = vec!["scope1".to_string()];
922        let response = MDSTokenResponse {
923            access_token: "test-access-token".to_string(),
924            expires_in: None,
925            token_type: "test-token-type".to_string(),
926        };
927        server.expect(
928            Expectation::matching(all_of![
929                request::path(format!("{MDS_DEFAULT_URI}/token")),
930                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
931            ])
932            .respond_with(json_encoded(response)),
933        );
934
935        let mdsc = Builder::default()
936            .with_endpoint(format!("http://{}", server.addr()))
937            .with_scopes(scopes)
938            .build()?;
939        let headers = mdsc.headers(Extensions::new()).await?;
940        assert_eq!(
941            get_token_from_headers(headers.clone()).unwrap(),
942            "test-access-token"
943        );
944        assert_eq!(
945            get_token_type_from_headers(headers).unwrap(),
946            "test-token-type"
947        );
948
949        Ok(())
950    }
951
952    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
953    #[parallel]
954    async fn credentials_headers_retryable_error() -> TestResult {
955        let server = Server::run();
956        let scopes = vec!["scope1".to_string()];
957        server.expect(
958            Expectation::matching(all_of![
959                request::path(format!("{MDS_DEFAULT_URI}/token")),
960                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
961            ])
962            .respond_with(status_code(503)),
963        );
964
965        let mdsc = Builder::default()
966            .with_endpoint(format!("http://{}", server.addr()))
967            .with_scopes(scopes)
968            .build()?;
969        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
970        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
971        assert!(original_err.is_transient());
972        let source = find_source_error::<reqwest::Error>(&err);
973        assert!(
974            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
975            "{err:?}"
976        );
977
978        Ok(())
979    }
980
981    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
982    #[parallel]
983    async fn credentials_headers_nonretryable_error() -> TestResult {
984        let server = Server::run();
985        let scopes = vec!["scope1".to_string()];
986        server.expect(
987            Expectation::matching(all_of![
988                request::path(format!("{MDS_DEFAULT_URI}/token")),
989                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
990            ])
991            .respond_with(status_code(401)),
992        );
993
994        let mdsc = Builder::default()
995            .with_endpoint(format!("http://{}", server.addr()))
996            .with_scopes(scopes)
997            .build()?;
998
999        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
1000        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
1001        assert!(!original_err.is_transient());
1002        let source = find_source_error::<reqwest::Error>(&err);
1003        assert!(
1004            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
1005            "{err:?}"
1006        );
1007
1008        Ok(())
1009    }
1010
1011    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1012    #[parallel]
1013    async fn credentials_headers_malformed_response_is_nonretryable() -> TestResult {
1014        let server = Server::run();
1015        let scopes = vec!["scope1".to_string()];
1016        server.expect(
1017            Expectation::matching(all_of![
1018                request::path(format!("{MDS_DEFAULT_URI}/token")),
1019                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
1020            ])
1021            .respond_with(json_encoded("bad json")),
1022        );
1023
1024        let mdsc = Builder::default()
1025            .with_endpoint(format!("http://{}", server.addr()))
1026            .with_scopes(scopes)
1027            .build()?;
1028
1029        let e = mdsc.headers(Extensions::new()).await.err().unwrap();
1030        assert!(!e.is_transient());
1031
1032        Ok(())
1033    }
1034
1035    #[tokio::test]
1036    #[parallel]
1037    async fn get_default_universe_domain_success() -> TestResult {
1038        let universe_domain_response = Builder::default().build()?.universe_domain().await.unwrap();
1039        assert_eq!(universe_domain_response, DEFAULT_UNIVERSE_DOMAIN);
1040        Ok(())
1041    }
1042
1043    #[tokio::test]
1044    #[parallel]
1045    async fn get_mds_signer() -> TestResult {
1046        use base64::{Engine, prelude::BASE64_STANDARD};
1047        use serde_json::json;
1048
1049        let server = Server::run();
1050        server.expect(
1051            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/token")),])
1052                .respond_with(json_encoded(MDSTokenResponse {
1053                    access_token: "test-access-token".to_string(),
1054                    expires_in: None,
1055                    token_type: "Bearer".to_string(),
1056                })),
1057        );
1058        server.expect(
1059            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/email")),])
1060                .respond_with(status_code(200).body("test-client-email")),
1061        );
1062        server.expect(
1063            Expectation::matching(all_of![
1064                request::method_path(
1065                    "POST",
1066                    "/v1/projects/-/serviceAccounts/test-client-email:signBlob"
1067                ),
1068                request::headers(contains(("authorization", "Bearer test-access-token"))),
1069            ])
1070            .respond_with(json_encoded(json!({
1071                "signedBlob": BASE64_STANDARD.encode("signed_blob"),
1072            }))),
1073        );
1074
1075        let endpoint = server.url("").to_string().trim_end_matches('/').to_string();
1076
1077        let signer = Builder::default()
1078            .with_endpoint(&endpoint)
1079            .build_signer_with_iam_endpoint_override(Some(endpoint))?;
1080
1081        let client_email = signer.client_email().await?;
1082        assert_eq!(client_email, "test-client-email");
1083
1084        let signature = signer.sign(b"test").await?;
1085        assert_eq!(signature.as_ref(), b"signed_blob");
1086
1087        Ok(())
1088    }
1089}