google_cloud_auth/credentials/
user_account.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! [User Account] Credentials type.
16//!
17//! User accounts represent a developer, administrator, or any other person who
18//! interacts with Google APIs and services. User accounts are managed as
19//! [Google Accounts], either via [Google Workspace] or [Cloud Identity].
20//!
21//! This module provides [Credentials] derived from user account
22//! information, specifically utilizing an OAuth 2.0 refresh token.
23//!
24//! This module is designed for refresh tokens obtained via the standard
25//! [Authorization Code grant]. Acquiring the initial refresh token (e.g., through
26//! user consent) is outside the scope of this library.
27//! See [RFC 6749 Section 4.1] for flow details.
28//!
//! The Google Cloud client libraries for Rust will typically find and use these
//! credentials automatically if a credentials file exists in one of the standard
//! Application Default Credentials (ADC) search paths. This file is often created by running:
32//! `gcloud auth application-default login`. You might instantiate these credentials
33//! directly using the [`Builder`] if you need to:
34//! * Load credentials from a non-standard location or source.
35//! * Override the OAuth 2.0 **scopes** being requested for the access token.
36//! * Override the **quota project ID** for billing and quota management.
37//! * Override the **token URI** used to fetch access tokens.
38//! * Customize the **retry behavior** when fetching access tokens.
39//!
40//! ## Example: Creating credentials from a JSON object
41//!
42//! ```
43//! # use google_cloud_auth::credentials::user_account::Builder;
44//! # use google_cloud_auth::credentials::Credentials;
45//! # use http::Extensions;
46//! # tokio_test::block_on(async {
47//! let authorized_user = serde_json::json!({
48//!     "client_id": "YOUR_CLIENT_ID.apps.googleusercontent.com", // Replace with your actual Client ID
49//!     "client_secret": "YOUR_CLIENT_SECRET", // Replace with your actual Client Secret - LOAD SECURELY!
50//!     "refresh_token": "YOUR_REFRESH_TOKEN", // Replace with the user's refresh token - LOAD SECURELY!
51//!     "type": "authorized_user",
52//!     // "quota_project_id": "your-billing-project-id", // Optional: Set if needed
53//!     // "token_uri" : "test-token-uri", // Optional: Set if needed
54//! });
55//! let credentials: Credentials = Builder::new(authorized_user).build()?;
56//! let headers = credentials.headers(Extensions::new()).await?;
57//! println!("Headers: {headers:?}");
58//! # Ok::<(), anyhow::Error>(())
59//! # });
60//! ```
61//!
62//! ## Example: Creating credentials with custom retry behavior
63//!
64//! ```
65//! # use google_cloud_auth::credentials::user_account::Builder;
66//! # use google_cloud_auth::credentials::Credentials;
67//! # use http::Extensions;
68//! # use std::time::Duration;
69//! # tokio_test::block_on(async {
70//! use gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
71//! use gax::exponential_backoff::ExponentialBackoff;
72//! let authorized_user = serde_json::json!({
73//!     "client_id": "YOUR_CLIENT_ID.apps.googleusercontent.com",
74//!     "client_secret": "YOUR_CLIENT_SECRET",
75//!     "refresh_token": "YOUR_REFRESH_TOKEN",
76//!     "type": "authorized_user",
77//! });
78//! let backoff = ExponentialBackoff::default();
79//! let credentials: Credentials = Builder::new(authorized_user)
80//!     .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
81//!     .with_backoff_policy(backoff)
82//!     .build()?;
83//! let headers = credentials.headers(Extensions::new()).await?;
84//! println!("Headers: {headers:?}");
85//! # Ok::<(), anyhow::Error>(())
86//! # });
87//! ```
88//!
89//! [Authorization Code grant]: https://tools.ietf.org/html/rfc6749#section-1.3.1
90//! [Cloud Identity]: https://cloud.google.com/identity
91//! [Google Accounts]: https://myaccount.google.com/
92//! [Google Workspace]: https://workspace.google.com/
93//! [RFC 6749 Section 4.1]: https://datatracker.ietf.org/doc/html/rfc6749#section-4.1
94//! [User Account]: https://cloud.google.com/docs/authentication#user-accounts
95//! [Workforce Identity Federation]: https://cloud.google.com/iam/docs/workforce-identity-federation
96
97use crate::build_errors::Error as BuilderError;
98use crate::constants::OAUTH2_TOKEN_SERVER_URL;
99use crate::credentials::dynamic::CredentialsProvider;
100use crate::credentials::{CacheableResource, Credentials};
101use crate::errors::{self, CredentialsError};
102use crate::headers_util::build_cacheable_headers;
103use crate::retry::Builder as RetryTokenProviderBuilder;
104use crate::token::{CachedTokenProvider, Token, TokenProvider};
105use crate::token_cache::TokenCache;
106use crate::{BuildResult, Result};
107use gax::backoff_policy::BackoffPolicyArg;
108use gax::retry_policy::RetryPolicyArg;
109use gax::retry_throttler::RetryThrottlerArg;
110use http::header::CONTENT_TYPE;
111use http::{Extensions, HeaderMap, HeaderValue};
112use reqwest::{Client, Method};
113use serde_json::Value;
114use std::sync::Arc;
115use tokio::time::{Duration, Instant};
116
117/// A builder for constructing `user_account` [Credentials] instance.
118///
119/// # Example
120/// ```
121/// # use google_cloud_auth::credentials::user_account::Builder;
122/// # tokio_test::block_on(async {
123/// let authorized_user = serde_json::json!({ /* add details here */ });
124/// let credentials = Builder::new(authorized_user).build();
125/// })
126/// ```
127pub struct Builder {
128    authorized_user: Value,
129    scopes: Option<Vec<String>>,
130    quota_project_id: Option<String>,
131    token_uri: Option<String>,
132    retry_builder: RetryTokenProviderBuilder,
133}
134
135impl Builder {
136    /// Creates a new builder using `authorized_user` JSON value.
137    ///
138    /// The `authorized_user` JSON is typically generated when a user
139    /// authenticates using the [application-default login] process.
140    ///
141    /// [application-default login]: https://cloud.google.com/sdk/gcloud/reference/auth/application-default/login
142    pub fn new(authorized_user: Value) -> Self {
143        Self {
144            authorized_user,
145            scopes: None,
146            quota_project_id: None,
147            token_uri: None,
148            retry_builder: RetryTokenProviderBuilder::default(),
149        }
150    }
151
152    /// Sets the URI for the token endpoint used to fetch access tokens.
153    ///
154    /// Any value provided here overrides a `token_uri` value from the input `authorized_user` JSON.
155    /// Defaults to `https://oauth2.googleapis.com/token` if not specified here or in the `authorized_user` JSON.
156    ///
157    /// # Example
158    /// ```
159    /// # use google_cloud_auth::credentials::user_account::Builder;
160    /// let authorized_user = serde_json::json!({ /* add details here */ });
161    /// let credentials = Builder::new(authorized_user)
162    ///     .with_token_uri("https://oauth2-FOOBAR.p.googleapis.com")
163    ///     .build();
164    /// ```
165    pub fn with_token_uri<S: Into<String>>(mut self, token_uri: S) -> Self {
166        self.token_uri = Some(token_uri.into());
167        self
168    }
169
170    /// Sets the [scopes] for these credentials.
171    ///
172    /// `scopes` define the *permissions being requested* for this specific access token
173    /// when interacting with a service. For example, `https://www.googleapis.com/auth/devstorage.read_write`.
174    /// IAM permissions, on the other hand, define the *underlying capabilities*
175    /// the user account possesses within a system. For example, `storage.buckets.delete`.
176    /// When a token generated with specific scopes is used, the request must be permitted
177    /// by both the user account's underlying IAM permissions and the scopes requested
178    /// for the token. Therefore, scopes act as an additional restriction on what the token
179    /// can be used for.
180    ///
181    /// # Example
182    /// ```
183    /// # use google_cloud_auth::credentials::user_account::Builder;
184    /// let authorized_user = serde_json::json!({ /* add details here */ });
185    /// let credentials = Builder::new(authorized_user)
186    ///     .with_scopes(["https://www.googleapis.com/auth/pubsub"])
187    ///     .build();
188    /// ```
189    /// [scopes]: https://developers.google.com/identity/protocols/oauth2/scopes
190    pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
191    where
192        I: IntoIterator<Item = S>,
193        S: Into<String>,
194    {
195        self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
196        self
197    }
198
199    /// Sets the [quota project] for these credentials.
200    ///
201    /// In some services, you can use an account in
202    /// one project for authentication and authorization, and charge
203    /// the usage to a different project. This requires that the
204    /// user has `serviceusage.services.use` permissions on the quota project.
205    ///
206    /// Any value set here overrides a `quota_project_id` value from the
207    /// input `authorized_user` JSON.
208    ///
209    /// # Example
210    /// ```
211    /// # use google_cloud_auth::credentials::user_account::Builder;
212    /// let authorized_user = serde_json::json!("{ /* add details here */ }");
213    /// let credentials = Builder::new(authorized_user)
214    ///     .with_quota_project_id("my-project")
215    ///     .build();
216    /// ```
217    ///
218    /// [quota project]: https://cloud.google.com/docs/quotas/quota-project
219    pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
220        self.quota_project_id = Some(quota_project_id.into());
221        self
222    }
223
224    /// Configure the retry policy for fetching tokens.
225    ///
226    /// The retry policy controls how to handle retries, and sets limits on
227    /// the number of attempts or the total time spent retrying.
228    ///
229    /// ```
230    /// # use google_cloud_auth::credentials::user_account::Builder;
231    /// # tokio_test::block_on(async {
232    /// use gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
233    /// let authorized_user = serde_json::json!({
234    ///     "client_id": "YOUR_CLIENT_ID.apps.googleusercontent.com",
235    ///     "client_secret": "YOUR_CLIENT_SECRET",
236    ///     "refresh_token": "YOUR_REFRESH_TOKEN",
237    ///     "type": "authorized_user",
238    /// });
239    /// let credentials = Builder::new(authorized_user)
240    ///     .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
241    ///     .build();
242    /// # });
243    /// ```
244    pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
245        self.retry_builder = self.retry_builder.with_retry_policy(v.into());
246        self
247    }
248
249    /// Configure the retry backoff policy.
250    ///
251    /// The backoff policy controls how long to wait in between retry attempts.
252    ///
253    /// ```
254    /// # use google_cloud_auth::credentials::user_account::Builder;
255    /// # use std::time::Duration;
256    /// # tokio_test::block_on(async {
257    /// use gax::exponential_backoff::ExponentialBackoff;
258    /// let authorized_user = serde_json::json!({
259    ///     "client_id": "YOUR_CLIENT_ID.apps.googleusercontent.com",
260    ///     "client_secret": "YOUR_CLIENT_SECRET",
261    ///     "refresh_token": "YOUR_REFRESH_TOKEN",
262    ///     "type": "authorized_user",
263    /// });
264    /// let credentials = Builder::new(authorized_user)
265    ///     .with_backoff_policy(ExponentialBackoff::default())
266    ///     .build();
267    /// # });
268    /// ```
269    pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
270        self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
271        self
272    }
273
274    /// Configure the retry throttler.
275    ///
276    /// Advanced applications may want to configure a retry throttler to
277    /// [Address Cascading Failures] and when [Handling Overload] conditions.
278    /// The authentication library throttles its retry loop, using a policy to
279    /// control the throttling algorithm. Use this method to fine tune or
280    /// customize the default retry throttler.
281    ///
282    /// [Handling Overload]: https://sre.google/sre-book/handling-overload/
283    /// [Address Cascading Failures]: https://sre.google/sre-book/addressing-cascading-failures/
284    ///
285    /// ```
286    /// # use google_cloud_auth::credentials::user_account::Builder;
287    /// # tokio_test::block_on(async {
288    /// use gax::retry_throttler::AdaptiveThrottler;
289    /// let authorized_user = serde_json::json!({
290    ///     "client_id": "YOUR_CLIENT_ID.apps.googleusercontent.com",
291    ///     "client_secret": "YOUR_CLIENT_SECRET",
292    ///     "refresh_token": "YOUR_REFRESH_TOKEN",
293    ///     "type": "authorized_user",
294    /// });
295    /// let credentials = Builder::new(authorized_user)
296    ///     .with_retry_throttler(AdaptiveThrottler::default())
297    ///     .build();
298    /// # });
299    /// ```
300    pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
301        self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
302        self
303    }
304
305    /// Returns a [Credentials] instance with the configured settings.
306    ///
307    /// # Errors
308    ///
309    /// Returns a [CredentialsError] if the `authorized_user`
310    /// provided to [`Builder::new`] cannot be successfully deserialized into the
311    /// expected format. This typically happens if the JSON value is malformed or
312    /// missing required fields. For more information, on how to generate
313    /// `authorized_user` json, consult the relevant section in the
314    /// [application-default credentials] guide.
315    ///
316    /// [application-default credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
317    pub fn build(self) -> BuildResult<Credentials> {
318        let authorized_user = serde_json::from_value::<AuthorizedUser>(self.authorized_user)
319            .map_err(BuilderError::parsing)?;
320        let endpoint = self
321            .token_uri
322            .or(authorized_user.token_uri)
323            .unwrap_or(OAUTH2_TOKEN_SERVER_URL.to_string());
324        let quota_project_id = self.quota_project_id.or(authorized_user.quota_project_id);
325
326        let token_provider = UserTokenProvider {
327            client_id: authorized_user.client_id,
328            client_secret: authorized_user.client_secret,
329            refresh_token: authorized_user.refresh_token,
330            endpoint,
331            scopes: self.scopes.map(|scopes| scopes.join(" ")),
332            source: UserTokenSource::AccessToken,
333        };
334
335        let token_provider = TokenCache::new(self.retry_builder.build(token_provider));
336
337        Ok(Credentials {
338            inner: Arc::new(UserCredentials {
339                token_provider,
340                quota_project_id,
341            }),
342        })
343    }
344}
345
346#[derive(PartialEq)]
347pub(crate) struct UserTokenProvider {
348    client_id: String,
349    client_secret: String,
350    refresh_token: String,
351    endpoint: String,
352    scopes: Option<String>,
353    source: UserTokenSource,
354}
355
/// Selects which token `UserTokenProvider` returns from the token endpoint's response.
// `IdToken` is only constructed by code gated on the `google_cloud_unstable_id_token`
// cfg; without that cfg the variant is never built and would trip `dead_code`.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, PartialEq)]
enum UserTokenSource {
    /// Return the response's `id_token`; an error if the response has none.
    IdToken,
    /// Return the response's `access_token`.
    AccessToken,
}
362
363impl std::fmt::Debug for UserTokenProvider {
364    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
365        f.debug_struct("UserCredentials")
366            .field("client_id", &self.client_id)
367            .field("client_secret", &"[censored]")
368            .field("refresh_token", &"[censored]")
369            .field("endpoint", &self.endpoint)
370            .field("scopes", &self.scopes)
371            .finish()
372    }
373}
374
375impl UserTokenProvider {
376    #[cfg(google_cloud_unstable_id_token)]
377    pub(crate) fn new_id_token_provider(
378        authorized_user: AuthorizedUser,
379        token_uri: Option<String>,
380    ) -> UserTokenProvider {
381        let endpoint = token_uri
382            .or(authorized_user.token_uri)
383            .unwrap_or(OAUTH2_TOKEN_SERVER_URL.to_string());
384        UserTokenProvider {
385            client_id: authorized_user.client_id,
386            client_secret: authorized_user.client_secret,
387            refresh_token: authorized_user.refresh_token,
388            endpoint,
389            source: UserTokenSource::IdToken,
390            scopes: None,
391        }
392    }
393}
394
395#[async_trait::async_trait]
396impl TokenProvider for UserTokenProvider {
397    async fn token(&self) -> Result<Token> {
398        let client = Client::new();
399
400        // Make the request
401        let req = Oauth2RefreshRequest {
402            grant_type: RefreshGrantType::RefreshToken,
403            client_id: self.client_id.clone(),
404            client_secret: self.client_secret.clone(),
405            refresh_token: self.refresh_token.clone(),
406            scopes: self.scopes.clone(),
407        };
408        let header = HeaderValue::from_static("application/json");
409        let builder = client
410            .request(Method::POST, self.endpoint.as_str())
411            .header(CONTENT_TYPE, header)
412            .json(&req);
413        let resp = builder
414            .send()
415            .await
416            .map_err(|e| errors::from_http_error(e, MSG))?;
417
418        // Process the response
419        if !resp.status().is_success() {
420            let err = errors::from_http_response(resp, MSG).await;
421            return Err(err);
422        }
423        let response = resp.json::<Oauth2RefreshResponse>().await.map_err(|e| {
424            let retryable = !e.is_decode();
425            CredentialsError::from_source(retryable, e)
426        })?;
427
428        let token = match self.source {
429            UserTokenSource::AccessToken => Ok(response.access_token),
430            UserTokenSource::IdToken => response
431                .id_token
432                .ok_or_else(|| CredentialsError::from_msg(false, MISSING_ID_TOKEN_MSG)),
433        }?;
434        let token = Token {
435            token,
436            token_type: response.token_type,
437            expires_at: response
438                .expires_in
439                .map(|d| Instant::now() + Duration::from_secs(d)),
440            metadata: None,
441        };
442        Ok(token)
443    }
444}
445
/// Message passed to the error helpers when a token refresh request fails.
const MSG: &str = "failed to refresh user access token";
/// Message for the error returned when an ID token was requested but the
/// token endpoint's response did not contain one.
const MISSING_ID_TOKEN_MSG: &str = "UserCredentials can obtain an id token only when authenticated through \
gcloud running 'gcloud auth application-default login'";
449
450/// Data model for a UserCredentials
451///
452/// See: https://cloud.google.com/docs/authentication#user-accounts
453#[derive(Debug)]
454pub(crate) struct UserCredentials<T>
455where
456    T: CachedTokenProvider,
457{
458    token_provider: T,
459    quota_project_id: Option<String>,
460}
461
462#[async_trait::async_trait]
463impl<T> CredentialsProvider for UserCredentials<T>
464where
465    T: CachedTokenProvider,
466{
467    async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
468        let token = self.token_provider.token(extensions).await?;
469        build_cacheable_headers(&token, &self.quota_project_id)
470    }
471}
472
473#[derive(Debug, PartialEq, serde::Deserialize)]
474pub(crate) struct AuthorizedUser {
475    #[serde(rename = "type")]
476    cred_type: String,
477    client_id: String,
478    client_secret: String,
479    refresh_token: String,
480    #[serde(skip_serializing_if = "Option::is_none")]
481    token_uri: Option<String>,
482    #[serde(skip_serializing_if = "Option::is_none")]
483    quota_project_id: Option<String>,
484}
485
486#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
487pub(crate) enum RefreshGrantType {
488    #[serde(rename = "refresh_token")]
489    RefreshToken,
490}
491
492#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
493pub(crate) struct Oauth2RefreshRequest {
494    pub(crate) grant_type: RefreshGrantType,
495    pub(crate) client_id: String,
496    pub(crate) client_secret: String,
497    pub(crate) refresh_token: String,
498    scopes: Option<String>,
499}
500
501#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
502pub(crate) struct Oauth2RefreshResponse {
503    pub(crate) access_token: String,
504    #[serde(skip_serializing_if = "Option::is_none")]
505    pub(crate) id_token: Option<String>,
506    #[serde(skip_serializing_if = "Option::is_none")]
507    pub(crate) scope: Option<String>,
508    #[serde(skip_serializing_if = "Option::is_none")]
509    pub(crate) expires_in: Option<u64>,
510    pub(crate) token_type: String,
511    #[serde(skip_serializing_if = "Option::is_none")]
512    pub(crate) refresh_token: Option<String>,
513}
514
515#[cfg(test)]
516mod tests {
517    use super::*;
518    use crate::credentials::tests::{
519        find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
520        get_mock_backoff_policy, get_mock_retry_throttler, get_token_from_headers,
521        get_token_type_from_headers,
522    };
523    use crate::credentials::{DEFAULT_UNIVERSE_DOMAIN, QUOTA_PROJECT_KEY};
524    use crate::errors::CredentialsError;
525    use crate::token::tests::MockTokenProvider;
526    use http::StatusCode;
527    use http::header::AUTHORIZATION;
528    use httptest::cycle;
529    use httptest::matchers::{all_of, json_decoded, request};
530    use httptest::responders::{json_encoded, status_code};
531    use httptest::{Expectation, Server};
532
533    type TestResult = anyhow::Result<()>;
534
535    fn authorized_user_json(token_uri: String) -> Value {
536        serde_json::json!({
537            "client_id": "test-client-id",
538            "client_secret": "test-client-secret",
539            "refresh_token": "test-refresh-token",
540            "type": "authorized_user",
541            "token_uri": token_uri,
542        })
543    }
544
545    #[tokio::test]
546    async fn test_user_account_retries_on_transient_failures() -> TestResult {
547        let mut server = Server::run();
548        server.expect(
549            Expectation::matching(request::path("/token"))
550                .times(3)
551                .respond_with(status_code(503)),
552        );
553
554        let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
555            .with_retry_policy(get_mock_auth_retry_policy(3))
556            .with_backoff_policy(get_mock_backoff_policy())
557            .with_retry_throttler(get_mock_retry_throttler())
558            .build()?;
559
560        let err = credentials.headers(Extensions::new()).await.unwrap_err();
561        assert!(!err.is_transient());
562        server.verify_and_clear();
563        Ok(())
564    }
565
566    #[tokio::test]
567    async fn test_user_account_does_not_retry_on_non_transient_failures() -> TestResult {
568        let mut server = Server::run();
569        server.expect(
570            Expectation::matching(request::path("/token"))
571                .times(1)
572                .respond_with(status_code(401)),
573        );
574
575        let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
576            .with_retry_policy(get_mock_auth_retry_policy(1))
577            .with_backoff_policy(get_mock_backoff_policy())
578            .with_retry_throttler(get_mock_retry_throttler())
579            .build()?;
580
581        let err = credentials.headers(Extensions::new()).await.unwrap_err();
582        assert!(!err.is_transient());
583        server.verify_and_clear();
584        Ok(())
585    }
586
587    #[tokio::test]
588    async fn test_user_account_retries_for_success() -> TestResult {
589        let mut server = Server::run();
590        let response = Oauth2RefreshResponse {
591            access_token: "test-access-token".to_string(),
592            id_token: None,
593            expires_in: Some(3600),
594            refresh_token: Some("test-refresh-token".to_string()),
595            scope: Some("scope1 scope2".to_string()),
596            token_type: "test-token-type".to_string(),
597        };
598
599        server.expect(
600            Expectation::matching(request::path("/token"))
601                .times(3)
602                .respond_with(cycle![
603                    status_code(503).body("try-again"),
604                    status_code(503).body("try-again"),
605                    status_code(200)
606                        .append_header("Content-Type", "application/json")
607                        .body(serde_json::to_string(&response).unwrap()),
608                ]),
609        );
610
611        let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
612            .with_retry_policy(get_mock_auth_retry_policy(3))
613            .with_backoff_policy(get_mock_backoff_policy())
614            .with_retry_throttler(get_mock_retry_throttler())
615            .build()?;
616
617        let token = get_token_from_headers(credentials.headers(Extensions::new()).await.unwrap());
618        assert_eq!(token.unwrap(), "test-access-token");
619
620        server.verify_and_clear();
621        Ok(())
622    }
623
624    #[test]
625    fn debug_token_provider() {
626        let expected = UserTokenProvider {
627            client_id: "test-client-id".to_string(),
628            client_secret: "test-client-secret".to_string(),
629            refresh_token: "test-refresh-token".to_string(),
630            endpoint: OAUTH2_TOKEN_SERVER_URL.to_string(),
631            scopes: Some("https://www.googleapis.com/auth/pubsub".to_string()),
632            source: UserTokenSource::AccessToken,
633        };
634        let fmt = format!("{expected:?}");
635        assert!(fmt.contains("test-client-id"), "{fmt}");
636        assert!(!fmt.contains("test-client-secret"), "{fmt}");
637        assert!(!fmt.contains("test-refresh-token"), "{fmt}");
638        assert!(fmt.contains(OAUTH2_TOKEN_SERVER_URL), "{fmt}");
639        assert!(
640            fmt.contains("https://www.googleapis.com/auth/pubsub"),
641            "{fmt}"
642        );
643    }
644
645    #[test]
646    fn authorized_user_full_from_json_success() {
647        let json = serde_json::json!({
648            "account": "",
649            "client_id": "test-client-id",
650            "client_secret": "test-client-secret",
651            "refresh_token": "test-refresh-token",
652            "type": "authorized_user",
653            "universe_domain": "googleapis.com",
654            "quota_project_id": "test-project",
655            "token_uri" : "test-token-uri",
656        });
657
658        let expected = AuthorizedUser {
659            cred_type: "authorized_user".to_string(),
660            client_id: "test-client-id".to_string(),
661            client_secret: "test-client-secret".to_string(),
662            refresh_token: "test-refresh-token".to_string(),
663            quota_project_id: Some("test-project".to_string()),
664            token_uri: Some("test-token-uri".to_string()),
665        };
666        let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
667        assert_eq!(actual, expected);
668    }
669
670    #[test]
671    fn authorized_user_partial_from_json_success() {
672        let json = serde_json::json!({
673            "client_id": "test-client-id",
674            "client_secret": "test-client-secret",
675            "refresh_token": "test-refresh-token",
676            "type": "authorized_user",
677        });
678
679        let expected = AuthorizedUser {
680            cred_type: "authorized_user".to_string(),
681            client_id: "test-client-id".to_string(),
682            client_secret: "test-client-secret".to_string(),
683            refresh_token: "test-refresh-token".to_string(),
684            quota_project_id: None,
685            token_uri: None,
686        };
687        let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
688        assert_eq!(actual, expected);
689    }
690
691    #[test]
692    fn authorized_user_from_json_parse_fail() {
693        let json_full = serde_json::json!({
694            "client_id": "test-client-id",
695            "client_secret": "test-client-secret",
696            "refresh_token": "test-refresh-token",
697            "type": "authorized_user",
698            "quota_project_id": "test-project"
699        });
700
701        for required_field in ["client_id", "client_secret", "refresh_token"] {
702            let mut json = json_full.clone();
703            // Remove a required field from the JSON
704            json[required_field].take();
705            serde_json::from_value::<AuthorizedUser>(json)
706                .err()
707                .unwrap();
708        }
709    }
710
711    #[tokio::test]
712    async fn default_universe_domain_success() {
713        let mock = TokenCache::new(MockTokenProvider::new());
714
715        let uc = UserCredentials {
716            token_provider: mock,
717            quota_project_id: None,
718        };
719        assert_eq!(uc.universe_domain().await.unwrap(), DEFAULT_UNIVERSE_DOMAIN);
720    }
721
722    #[tokio::test]
723    async fn headers_success() -> TestResult {
724        let token = Token {
725            token: "test-token".to_string(),
726            token_type: "Bearer".to_string(),
727            expires_at: None,
728            metadata: None,
729        };
730
731        let mut mock = MockTokenProvider::new();
732        mock.expect_token().times(1).return_once(|| Ok(token));
733
734        let uc = UserCredentials {
735            token_provider: TokenCache::new(mock),
736            quota_project_id: None,
737        };
738
739        let mut extensions = Extensions::new();
740        let cached_headers = uc.headers(extensions.clone()).await.unwrap();
741        let (headers, entity_tag) = match cached_headers {
742            CacheableResource::New { entity_tag, data } => (data, entity_tag),
743            CacheableResource::NotModified => unreachable!("expecting new headers"),
744        };
745        let token = headers.get(AUTHORIZATION).unwrap();
746
747        assert_eq!(headers.len(), 1, "{headers:?}");
748        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
749        assert!(token.is_sensitive());
750
751        extensions.insert(entity_tag);
752
753        let cached_headers = uc.headers(extensions).await?;
754
755        match cached_headers {
756            CacheableResource::New { .. } => unreachable!("expecting new headers"),
757            CacheableResource::NotModified => CacheableResource::<HeaderMap>::NotModified,
758        };
759        Ok(())
760    }
761
762    #[tokio::test]
763    async fn headers_failure() {
764        let mut mock = MockTokenProvider::new();
765        mock.expect_token()
766            .times(1)
767            .return_once(|| Err(errors::non_retryable_from_str("fail")));
768
769        let uc = UserCredentials {
770            token_provider: TokenCache::new(mock),
771            quota_project_id: None,
772        };
773        assert!(uc.headers(Extensions::new()).await.is_err());
774    }
775
776    #[tokio::test]
777    async fn headers_with_quota_project_success() -> TestResult {
778        let token = Token {
779            token: "test-token".to_string(),
780            token_type: "Bearer".to_string(),
781            expires_at: None,
782            metadata: None,
783        };
784
785        let mut mock = MockTokenProvider::new();
786        mock.expect_token().times(1).return_once(|| Ok(token));
787
788        let uc = UserCredentials {
789            token_provider: TokenCache::new(mock),
790            quota_project_id: Some("test-project".to_string()),
791        };
792
793        let headers = get_headers_from_cache(uc.headers(Extensions::new()).await.unwrap())?;
794        let token = headers.get(AUTHORIZATION).unwrap();
795        let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();
796
797        assert_eq!(headers.len(), 2, "{headers:?}");
798        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
799        assert!(token.is_sensitive());
800        assert_eq!(
801            quota_project_header,
802            HeaderValue::from_static("test-project")
803        );
804        assert!(!quota_project_header.is_sensitive());
805        Ok(())
806    }
807
808    #[test]
809    fn oauth2_request_serde() {
810        let request = Oauth2RefreshRequest {
811            grant_type: RefreshGrantType::RefreshToken,
812            client_id: "test-client-id".to_string(),
813            client_secret: "test-client-secret".to_string(),
814            refresh_token: "test-refresh-token".to_string(),
815            scopes: Some("scope1 scope2".to_string()),
816        };
817
818        let json = serde_json::to_value(&request).unwrap();
819        let expected = serde_json::json!({
820            "grant_type": "refresh_token",
821            "client_id": "test-client-id",
822            "client_secret": "test-client-secret",
823            "refresh_token": "test-refresh-token",
824            "scopes": "scope1 scope2",
825        });
826        assert_eq!(json, expected);
827        let roundtrip = serde_json::from_value::<Oauth2RefreshRequest>(json).unwrap();
828        assert_eq!(request, roundtrip);
829    }
830
831    #[test]
832    fn oauth2_response_serde_full() {
833        let response = Oauth2RefreshResponse {
834            access_token: "test-access-token".to_string(),
835            id_token: None,
836            scope: Some("scope1 scope2".to_string()),
837            expires_in: Some(3600),
838            token_type: "test-token-type".to_string(),
839            refresh_token: Some("test-refresh-token".to_string()),
840        };
841
842        let json = serde_json::to_value(&response).unwrap();
843        let expected = serde_json::json!({
844            "access_token": "test-access-token",
845            "scope": "scope1 scope2",
846            "expires_in": 3600,
847            "token_type": "test-token-type",
848            "refresh_token": "test-refresh-token"
849        });
850        assert_eq!(json, expected);
851        let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
852        assert_eq!(response, roundtrip);
853    }
854
855    #[test]
856    fn oauth2_response_serde_partial() {
857        let response = Oauth2RefreshResponse {
858            access_token: "test-access-token".to_string(),
859            id_token: None,
860            scope: None,
861            expires_in: None,
862            token_type: "test-token-type".to_string(),
863            refresh_token: None,
864        };
865
866        let json = serde_json::to_value(&response).unwrap();
867        let expected = serde_json::json!({
868            "access_token": "test-access-token",
869            "token_type": "test-token-type",
870        });
871        assert_eq!(json, expected);
872        let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
873        assert_eq!(response, roundtrip);
874    }
875
876    fn check_request(request: &Oauth2RefreshRequest, expected_scopes: Option<String>) -> bool {
877        request.client_id == "test-client-id"
878            && request.client_secret == "test-client-secret"
879            && request.refresh_token == "test-refresh-token"
880            && request.grant_type == RefreshGrantType::RefreshToken
881            && request.scopes == expected_scopes
882    }
883
884    #[tokio::test(start_paused = true)]
885    async fn token_provider_full() -> TestResult {
886        let server = Server::run();
887        let response = Oauth2RefreshResponse {
888            access_token: "test-access-token".to_string(),
889            id_token: None,
890            expires_in: Some(3600),
891            refresh_token: Some("test-refresh-token".to_string()),
892            scope: Some("scope1 scope2".to_string()),
893            token_type: "test-token-type".to_string(),
894        };
895        server.expect(
896            Expectation::matching(all_of![
897                request::path("/token"),
898                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
899                    check_request(req, Some("scope1 scope2".to_string()))
900                }))
901            ])
902            .respond_with(json_encoded(response)),
903        );
904
905        let tp = UserTokenProvider {
906            client_id: "test-client-id".to_string(),
907            client_secret: "test-client-secret".to_string(),
908            refresh_token: "test-refresh-token".to_string(),
909            endpoint: server.url("/token").to_string(),
910            scopes: Some("scope1 scope2".to_string()),
911            source: UserTokenSource::AccessToken,
912        };
913        let now = Instant::now();
914        let token = tp.token().await?;
915        assert_eq!(token.token, "test-access-token");
916        assert_eq!(token.token_type, "test-token-type");
917        assert!(
918            token
919                .expires_at
920                .is_some_and(|d| d == now + Duration::from_secs(3600)),
921            "now: {:?}, expires_at: {:?}",
922            now,
923            token.expires_at
924        );
925
926        Ok(())
927    }
928
929    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
930    async fn credential_full_with_quota_project() -> TestResult {
931        let server = Server::run();
932        let response = Oauth2RefreshResponse {
933            access_token: "test-access-token".to_string(),
934            id_token: None,
935            expires_in: Some(3600),
936            refresh_token: Some("test-refresh-token".to_string()),
937            scope: None,
938            token_type: "test-token-type".to_string(),
939        };
940        server.expect(
941            Expectation::matching(all_of![
942                request::path("/token"),
943                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
944                    check_request(req, None)
945                }))
946            ])
947            .respond_with(json_encoded(response)),
948        );
949
950        let authorized_user = serde_json::json!({
951            "client_id": "test-client-id",
952            "client_secret": "test-client-secret",
953            "refresh_token": "test-refresh-token",
954            "type": "authorized_user",
955            "token_uri": server.url("/token").to_string(),
956        });
957        let cred = Builder::new(authorized_user)
958            .with_quota_project_id("test-project")
959            .build()?;
960
961        let headers = get_headers_from_cache(cred.headers(Extensions::new()).await.unwrap())?;
962        let token = headers.get(AUTHORIZATION).unwrap();
963        let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();
964
965        assert_eq!(headers.len(), 2, "{headers:?}");
966        assert_eq!(
967            token,
968            HeaderValue::from_static("test-token-type test-access-token")
969        );
970        assert!(token.is_sensitive());
971        assert_eq!(
972            quota_project_header,
973            HeaderValue::from_static("test-project")
974        );
975        assert!(!quota_project_header.is_sensitive());
976
977        Ok(())
978    }
979
980    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
981    async fn creds_from_json_custom_uri_with_caching() -> TestResult {
982        let mut server = Server::run();
983        let response = Oauth2RefreshResponse {
984            access_token: "test-access-token".to_string(),
985            id_token: None,
986            expires_in: Some(3600),
987            refresh_token: Some("test-refresh-token".to_string()),
988            scope: Some("scope1 scope2".to_string()),
989            token_type: "test-token-type".to_string(),
990        };
991        server.expect(
992            Expectation::matching(all_of![
993                request::path("/token"),
994                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
995                    check_request(req, Some("scope1 scope2".to_string()))
996                }))
997            ])
998            .times(1)
999            .respond_with(json_encoded(response)),
1000        );
1001
1002        let json = serde_json::json!({
1003            "client_id": "test-client-id",
1004            "client_secret": "test-client-secret",
1005            "refresh_token": "test-refresh-token",
1006            "type": "authorized_user",
1007            "universe_domain": "googleapis.com",
1008            "quota_project_id": "test-project",
1009            "token_uri": server.url("/token").to_string(),
1010        });
1011
1012        let cred = Builder::new(json)
1013            .with_scopes(vec!["scope1", "scope2"])
1014            .build()?;
1015
1016        let token = get_token_from_headers(cred.headers(Extensions::new()).await?);
1017        assert_eq!(token.unwrap(), "test-access-token");
1018
1019        let token = get_token_from_headers(cred.headers(Extensions::new()).await?);
1020        assert_eq!(token.unwrap(), "test-access-token");
1021
1022        server.verify_and_clear();
1023
1024        Ok(())
1025    }
1026
1027    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1028    async fn credential_provider_partial() -> TestResult {
1029        let server = Server::run();
1030        let response = Oauth2RefreshResponse {
1031            access_token: "test-access-token".to_string(),
1032            id_token: None,
1033            expires_in: None,
1034            refresh_token: None,
1035            scope: None,
1036            token_type: "test-token-type".to_string(),
1037        };
1038        server.expect(
1039            Expectation::matching(all_of![
1040                request::path("/token"),
1041                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
1042                    check_request(req, None)
1043                }))
1044            ])
1045            .respond_with(json_encoded(response)),
1046        );
1047
1048        let authorized_user = serde_json::json!({
1049            "client_id": "test-client-id",
1050            "client_secret": "test-client-secret",
1051            "refresh_token": "test-refresh-token",
1052            "type": "authorized_user",
1053            "token_uri": server.url("/token").to_string()
1054        });
1055
1056        let uc = Builder::new(authorized_user).build()?;
1057        let headers = uc.headers(Extensions::new()).await?;
1058        assert_eq!(
1059            get_token_from_headers(headers.clone()).unwrap(),
1060            "test-access-token"
1061        );
1062        assert_eq!(
1063            get_token_type_from_headers(headers).unwrap(),
1064            "test-token-type"
1065        );
1066
1067        Ok(())
1068    }
1069
1070    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1071    async fn credential_provider_with_token_uri() -> TestResult {
1072        let server = Server::run();
1073        let response = Oauth2RefreshResponse {
1074            access_token: "test-access-token".to_string(),
1075            id_token: None,
1076            expires_in: None,
1077            refresh_token: None,
1078            scope: None,
1079            token_type: "test-token-type".to_string(),
1080        };
1081        server.expect(
1082            Expectation::matching(all_of![
1083                request::path("/token"),
1084                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
1085                    check_request(req, None)
1086                }))
1087            ])
1088            .respond_with(json_encoded(response)),
1089        );
1090
1091        let authorized_user = serde_json::json!({
1092            "client_id": "test-client-id",
1093            "client_secret": "test-client-secret",
1094            "refresh_token": "test-refresh-token",
1095            "type": "authorized_user",
1096            "token_uri": "test-endpoint"
1097        });
1098
1099        let uc = Builder::new(authorized_user)
1100            .with_token_uri(server.url("/token").to_string())
1101            .build()?;
1102        let headers = uc.headers(Extensions::new()).await?;
1103        assert_eq!(
1104            get_token_from_headers(headers.clone()).unwrap(),
1105            "test-access-token"
1106        );
1107        assert_eq!(
1108            get_token_type_from_headers(headers).unwrap(),
1109            "test-token-type"
1110        );
1111
1112        Ok(())
1113    }
1114
1115    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1116    async fn credential_provider_with_scopes() -> TestResult {
1117        let server = Server::run();
1118        let response = Oauth2RefreshResponse {
1119            access_token: "test-access-token".to_string(),
1120            id_token: None,
1121            expires_in: None,
1122            refresh_token: None,
1123            scope: Some("scope1 scope2".to_string()),
1124            token_type: "test-token-type".to_string(),
1125        };
1126        server.expect(
1127            Expectation::matching(all_of![
1128                request::path("/token"),
1129                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
1130                    check_request(req, Some("scope1 scope2".to_string()))
1131                }))
1132            ])
1133            .respond_with(json_encoded(response)),
1134        );
1135
1136        let authorized_user = serde_json::json!({
1137            "client_id": "test-client-id",
1138            "client_secret": "test-client-secret",
1139            "refresh_token": "test-refresh-token",
1140            "type": "authorized_user",
1141            "token_uri": "test-endpoint"
1142        });
1143
1144        let uc = Builder::new(authorized_user)
1145            .with_token_uri(server.url("/token").to_string())
1146            .with_scopes(vec!["scope1", "scope2"])
1147            .build()?;
1148        let headers = uc.headers(Extensions::new()).await?;
1149        assert_eq!(
1150            get_token_from_headers(headers.clone()).unwrap(),
1151            "test-access-token"
1152        );
1153        assert_eq!(
1154            get_token_type_from_headers(headers).unwrap(),
1155            "test-token-type"
1156        );
1157
1158        Ok(())
1159    }
1160
1161    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1162    async fn credential_provider_retryable_error() -> TestResult {
1163        let server = Server::run();
1164        server
1165            .expect(Expectation::matching(request::path("/token")).respond_with(status_code(503)));
1166
1167        let authorized_user = serde_json::json!({
1168            "client_id": "test-client-id",
1169            "client_secret": "test-client-secret",
1170            "refresh_token": "test-refresh-token",
1171            "type": "authorized_user",
1172            "token_uri": server.url("/token").to_string()
1173        });
1174
1175        let uc = Builder::new(authorized_user).build()?;
1176        let err = uc.headers(Extensions::new()).await.unwrap_err();
1177        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
1178        assert!(original_err.is_transient());
1179
1180        let source = find_source_error::<reqwest::Error>(&err);
1181        assert!(
1182            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
1183            "{err:?}"
1184        );
1185
1186        Ok(())
1187    }
1188
1189    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1190    async fn token_provider_nonretryable_error() -> TestResult {
1191        let server = Server::run();
1192        server
1193            .expect(Expectation::matching(request::path("/token")).respond_with(status_code(401)));
1194
1195        let authorized_user = serde_json::json!({
1196            "client_id": "test-client-id",
1197            "client_secret": "test-client-secret",
1198            "refresh_token": "test-refresh-token",
1199            "type": "authorized_user",
1200            "token_uri": server.url("/token").to_string()
1201        });
1202
1203        let uc = Builder::new(authorized_user).build()?;
1204        let err = uc.headers(Extensions::new()).await.unwrap_err();
1205        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
1206        assert!(!original_err.is_transient());
1207
1208        let source = find_source_error::<reqwest::Error>(&err);
1209        assert!(
1210            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
1211            "{err:?}"
1212        );
1213
1214        Ok(())
1215    }
1216
1217    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1218    async fn token_provider_malformed_response_is_nonretryable() -> TestResult {
1219        let server = Server::run();
1220        server.expect(
1221            Expectation::matching(request::path("/token"))
1222                .respond_with(json_encoded("bad json".to_string())),
1223        );
1224
1225        let authorized_user = serde_json::json!({
1226            "client_id": "test-client-id",
1227            "client_secret": "test-client-secret",
1228            "refresh_token": "test-refresh-token",
1229            "type": "authorized_user",
1230            "token_uri": server.url("/token").to_string()
1231        });
1232
1233        let uc = Builder::new(authorized_user).build()?;
1234        let e = uc.headers(Extensions::new()).await.err().unwrap();
1235        assert!(!e.is_transient(), "{e}");
1236
1237        Ok(())
1238    }
1239
1240    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1241    async fn builder_malformed_authorized_json_nonretryable() -> TestResult {
1242        let authorized_user = serde_json::json!({
1243            "client_secret": "test-client-secret",
1244            "refresh_token": "test-refresh-token",
1245            "type": "authorized_user",
1246        });
1247
1248        let e = Builder::new(authorized_user).build().unwrap_err();
1249        assert!(e.is_parsing(), "{e}");
1250
1251        Ok(())
1252    }
1253}