google_cloud_auth/credentials/
user_account.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! [User Account] Credentials type.
16//!
17//! User accounts represent a developer, administrator, or any other person who
18//! interacts with Google APIs and services. User accounts are managed as
19//! [Google Accounts], either via [Google Workspace] or [Cloud Identity].
20//!
21//! This module provides [Credentials] derived from user account
22//! information, specifically utilizing an OAuth 2.0 refresh token.
23//!
24//! This module is designed for refresh tokens obtained via the standard
25//! [Authorization Code grant]. Acquiring the initial refresh token (e.g., through
26//! user consent) is outside the scope of this library.
27//! See [RFC 6749 Section 4.1] for flow details.
28//!
29//! The Google Cloud client libraries for Rust will typically find and use these
30//! credentials automatically if a credentials file exists in the
31//! standard ADC search paths. This file is often created by running:
32//! `gcloud auth application-default login`. You might instantiate these credentials
33//! directly using the [`Builder`] if you need to:
34//! * Load credentials from a non-standard location or source.
35//! * Override the OAuth 2.0 **scopes** being requested for the access token.
36//! * Override the **quota project ID** for billing and quota management.
37//! * Override the **token URI** used to fetch access tokens.
38//! * Customize the **retry behavior** when fetching access tokens.
39//!
40//! ## Example: Creating credentials from a JSON object
41//!
42//! ```
43//! # use google_cloud_auth::credentials::user_account::Builder;
44//! # use google_cloud_auth::credentials::Credentials;
45//! # use http::Extensions;
46//! # tokio_test::block_on(async {
47//! let authorized_user = serde_json::json!({
48//!     "client_id": "YOUR_CLIENT_ID.apps.googleusercontent.com", // Replace with your actual Client ID
49//!     "client_secret": "YOUR_CLIENT_SECRET", // Replace with your actual Client Secret - LOAD SECURELY!
50//!     "refresh_token": "YOUR_REFRESH_TOKEN", // Replace with the user's refresh token - LOAD SECURELY!
51//!     "type": "authorized_user",
52//!     // "quota_project_id": "your-billing-project-id", // Optional: Set if needed
53//!     // "token_uri" : "test-token-uri", // Optional: Set if needed
54//! });
55//! let credentials: Credentials = Builder::new(authorized_user).build()?;
56//! let headers = credentials.headers(Extensions::new()).await?;
57//! println!("Headers: {headers:?}");
58//! # Ok::<(), anyhow::Error>(())
59//! # });
60//! ```
61//!
62//! ## Example: Creating credentials with custom retry behavior
63//!
64//! ```
65//! # use google_cloud_auth::credentials::user_account::Builder;
66//! # use google_cloud_auth::credentials::Credentials;
67//! # use http::Extensions;
68//! # use std::time::Duration;
69//! # tokio_test::block_on(async {
70//! use gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
71//! use gax::exponential_backoff::ExponentialBackoff;
72//! let authorized_user = serde_json::json!({
73//!     "client_id": "YOUR_CLIENT_ID.apps.googleusercontent.com",
74//!     "client_secret": "YOUR_CLIENT_SECRET",
75//!     "refresh_token": "YOUR_REFRESH_TOKEN",
76//!     "type": "authorized_user",
77//! });
78//! let backoff = ExponentialBackoff::default();
79//! let credentials: Credentials = Builder::new(authorized_user)
80//!     .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
81//!     .with_backoff_policy(backoff)
82//!     .build()?;
83//! let headers = credentials.headers(Extensions::new()).await?;
84//! println!("Headers: {headers:?}");
85//! # Ok::<(), anyhow::Error>(())
86//! # });
87//! ```
88//!
89//! [Authorization Code grant]: https://tools.ietf.org/html/rfc6749#section-1.3.1
90//! [Cloud Identity]: https://cloud.google.com/identity
91//! [Google Accounts]: https://myaccount.google.com/
92//! [Google Workspace]: https://workspace.google.com/
93//! [RFC 6749 Section 4.1]: https://datatracker.ietf.org/doc/html/rfc6749#section-4.1
94//! [User Account]: https://cloud.google.com/docs/authentication#user-accounts
95//! [Workforce Identity Federation]: https://cloud.google.com/iam/docs/workforce-identity-federation
96
97use crate::build_errors::Error as BuilderError;
98use crate::credentials::dynamic::CredentialsProvider;
99use crate::credentials::{CacheableResource, Credentials};
100use crate::errors::{self, CredentialsError};
101use crate::headers_util::build_cacheable_headers;
102use crate::retry::Builder as RetryTokenProviderBuilder;
103use crate::token::{CachedTokenProvider, Token, TokenProvider};
104use crate::token_cache::TokenCache;
105use crate::{BuildResult, Result};
106use gax::backoff_policy::BackoffPolicyArg;
107use gax::retry_policy::RetryPolicyArg;
108use gax::retry_throttler::RetryThrottlerArg;
109use http::header::CONTENT_TYPE;
110use http::{Extensions, HeaderMap, HeaderValue};
111use reqwest::{Client, Method};
112use serde_json::Value;
113use std::sync::Arc;
114use tokio::time::{Duration, Instant};
115
116const OAUTH2_ENDPOINT: &str = "https://oauth2.googleapis.com/token";
117
118/// A builder for constructing `user_account` [Credentials] instance.
119///
120/// # Example
121/// ```
122/// # use google_cloud_auth::credentials::user_account::Builder;
123/// # tokio_test::block_on(async {
124/// let authorized_user = serde_json::json!({ /* add details here */ });
125/// let credentials = Builder::new(authorized_user).build();
126/// })
127/// ```
128pub struct Builder {
129    authorized_user: Value,
130    scopes: Option<Vec<String>>,
131    quota_project_id: Option<String>,
132    token_uri: Option<String>,
133    retry_builder: RetryTokenProviderBuilder,
134}
135
136impl Builder {
137    /// Creates a new builder using `authorized_user` JSON value.
138    ///
139    /// The `authorized_user` JSON is typically generated when a user
140    /// authenticates using the [application-default login] process.
141    ///
142    /// [application-default login]: https://cloud.google.com/sdk/gcloud/reference/auth/application-default/login
143    pub fn new(authorized_user: Value) -> Self {
144        Self {
145            authorized_user,
146            scopes: None,
147            quota_project_id: None,
148            token_uri: None,
149            retry_builder: RetryTokenProviderBuilder::default(),
150        }
151    }
152
153    /// Sets the URI for the token endpoint used to fetch access tokens.
154    ///
155    /// Any value provided here overrides a `token_uri` value from the input `authorized_user` JSON.
156    /// Defaults to `https://oauth2.googleapis.com/token` if not specified here or in the `authorized_user` JSON.
157    ///
158    /// # Example
159    /// ```
160    /// # use google_cloud_auth::credentials::user_account::Builder;
161    /// let authorized_user = serde_json::json!({ /* add details here */ });
162    /// let credentials = Builder::new(authorized_user)
163    ///     .with_token_uri("https://oauth2-FOOBAR.p.googleapis.com")
164    ///     .build();
165    /// ```
166    pub fn with_token_uri<S: Into<String>>(mut self, token_uri: S) -> Self {
167        self.token_uri = Some(token_uri.into());
168        self
169    }
170
171    /// Sets the [scopes] for these credentials.
172    ///
173    /// `scopes` define the *permissions being requested* for this specific access token
174    /// when interacting with a service. For example, `https://www.googleapis.com/auth/devstorage.read_write`.
175    /// IAM permissions, on the other hand, define the *underlying capabilities*
176    /// the user account possesses within a system. For example, `storage.buckets.delete`.
177    /// When a token generated with specific scopes is used, the request must be permitted
178    /// by both the user account's underlying IAM permissions and the scopes requested
179    /// for the token. Therefore, scopes act as an additional restriction on what the token
180    /// can be used for.
181    ///
182    /// # Example
183    /// ```
184    /// # use google_cloud_auth::credentials::user_account::Builder;
185    /// let authorized_user = serde_json::json!({ /* add details here */ });
186    /// let credentials = Builder::new(authorized_user)
187    ///     .with_scopes(["https://www.googleapis.com/auth/pubsub"])
188    ///     .build();
189    /// ```
190    /// [scopes]: https://developers.google.com/identity/protocols/oauth2/scopes
191    pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
192    where
193        I: IntoIterator<Item = S>,
194        S: Into<String>,
195    {
196        self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
197        self
198    }
199
200    /// Sets the [quota project] for these credentials.
201    ///
202    /// In some services, you can use an account in
203    /// one project for authentication and authorization, and charge
204    /// the usage to a different project. This requires that the
205    /// user has `serviceusage.services.use` permissions on the quota project.
206    ///
207    /// Any value set here overrides a `quota_project_id` value from the
208    /// input `authorized_user` JSON.
209    ///
210    /// # Example
211    /// ```
212    /// # use google_cloud_auth::credentials::user_account::Builder;
213    /// let authorized_user = serde_json::json!("{ /* add details here */ }");
214    /// let credentials = Builder::new(authorized_user)
215    ///     .with_quota_project_id("my-project")
216    ///     .build();
217    /// ```
218    ///
219    /// [quota project]: https://cloud.google.com/docs/quotas/quota-project
220    pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
221        self.quota_project_id = Some(quota_project_id.into());
222        self
223    }
224
225    /// Configure the retry policy for fetching tokens.
226    ///
227    /// The retry policy controls how to handle retries, and sets limits on
228    /// the number of attempts or the total time spent retrying.
229    ///
230    /// ```
231    /// # use google_cloud_auth::credentials::user_account::Builder;
232    /// # tokio_test::block_on(async {
233    /// use gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
234    /// let authorized_user = serde_json::json!({
235    ///     "client_id": "YOUR_CLIENT_ID.apps.googleusercontent.com",
236    ///     "client_secret": "YOUR_CLIENT_SECRET",
237    ///     "refresh_token": "YOUR_REFRESH_TOKEN",
238    ///     "type": "authorized_user",
239    /// });
240    /// let credentials = Builder::new(authorized_user)
241    ///     .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
242    ///     .build();
243    /// # });
244    /// ```
245    pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
246        self.retry_builder = self.retry_builder.with_retry_policy(v.into());
247        self
248    }
249
250    /// Configure the retry backoff policy.
251    ///
252    /// The backoff policy controls how long to wait in between retry attempts.
253    ///
254    /// ```
255    /// # use google_cloud_auth::credentials::user_account::Builder;
256    /// # use std::time::Duration;
257    /// # tokio_test::block_on(async {
258    /// use gax::exponential_backoff::ExponentialBackoff;
259    /// let authorized_user = serde_json::json!({
260    ///     "client_id": "YOUR_CLIENT_ID.apps.googleusercontent.com",
261    ///     "client_secret": "YOUR_CLIENT_SECRET",
262    ///     "refresh_token": "YOUR_REFRESH_TOKEN",
263    ///     "type": "authorized_user",
264    /// });
265    /// let credentials = Builder::new(authorized_user)
266    ///     .with_backoff_policy(ExponentialBackoff::default())
267    ///     .build();
268    /// # });
269    /// ```
270    pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
271        self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
272        self
273    }
274
275    /// Configure the retry throttler.
276    ///
277    /// Advanced applications may want to configure a retry throttler to
278    /// [Address Cascading Failures] and when [Handling Overload] conditions.
279    /// The authentication library throttles its retry loop, using a policy to
280    /// control the throttling algorithm. Use this method to fine tune or
281    /// customize the default retry throttler.
282    ///
283    /// [Handling Overload]: https://sre.google/sre-book/handling-overload/
284    /// [Address Cascading Failures]: https://sre.google/sre-book/addressing-cascading-failures/
285    ///
286    /// ```
287    /// # use google_cloud_auth::credentials::user_account::Builder;
288    /// # tokio_test::block_on(async {
289    /// use gax::retry_throttler::AdaptiveThrottler;
290    /// let authorized_user = serde_json::json!({
291    ///     "client_id": "YOUR_CLIENT_ID.apps.googleusercontent.com",
292    ///     "client_secret": "YOUR_CLIENT_SECRET",
293    ///     "refresh_token": "YOUR_REFRESH_TOKEN",
294    ///     "type": "authorized_user",
295    /// });
296    /// let credentials = Builder::new(authorized_user)
297    ///     .with_retry_throttler(AdaptiveThrottler::default())
298    ///     .build();
299    /// # });
300    /// ```
301    pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
302        self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
303        self
304    }
305
306    /// Returns a [Credentials] instance with the configured settings.
307    ///
308    /// # Errors
309    ///
310    /// Returns a [CredentialsError] if the `authorized_user`
311    /// provided to [`Builder::new`] cannot be successfully deserialized into the
312    /// expected format. This typically happens if the JSON value is malformed or
313    /// missing required fields. For more information, on how to generate
314    /// `authorized_user` json, consult the relevant section in the
315    /// [application-default credentials] guide.
316    ///
317    /// [application-default credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
318    pub fn build(self) -> BuildResult<Credentials> {
319        let authorized_user = serde_json::from_value::<AuthorizedUser>(self.authorized_user)
320            .map_err(BuilderError::parsing)?;
321        let endpoint = self
322            .token_uri
323            .or(authorized_user.token_uri)
324            .unwrap_or(OAUTH2_ENDPOINT.to_string());
325        let quota_project_id = self.quota_project_id.or(authorized_user.quota_project_id);
326
327        let token_provider = UserTokenProvider {
328            client_id: authorized_user.client_id,
329            client_secret: authorized_user.client_secret,
330            refresh_token: authorized_user.refresh_token,
331            endpoint,
332            scopes: self.scopes.map(|scopes| scopes.join(" ")),
333        };
334
335        let token_provider = TokenCache::new(self.retry_builder.build(token_provider));
336
337        Ok(Credentials {
338            inner: Arc::new(UserCredentials {
339                token_provider,
340                quota_project_id,
341            }),
342        })
343    }
344}
345
/// Obtains access tokens by exchanging a user's OAuth 2.0 refresh token at
/// the configured token endpoint.
///
/// The client secret and refresh token are sensitive; the hand-written
/// `Debug` impl for this type censors them.
#[derive(PartialEq)]
struct UserTokenProvider {
    /// Token endpoint that receives the refresh request.
    endpoint: String,
    /// OAuth client identifier from the `authorized_user` JSON.
    client_id: String,
    /// OAuth client secret (sensitive).
    client_secret: String,
    /// Refresh token exchanged for access tokens (sensitive).
    refresh_token: String,
    /// Space-separated scopes, or `None` if no scopes were configured.
    scopes: Option<String>,
}
354
355impl std::fmt::Debug for UserTokenProvider {
356    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
357        f.debug_struct("UserCredentials")
358            .field("client_id", &self.client_id)
359            .field("client_secret", &"[censored]")
360            .field("refresh_token", &"[censored]")
361            .field("endpoint", &self.endpoint)
362            .field("scopes", &self.scopes)
363            .finish()
364    }
365}
366
367#[async_trait::async_trait]
368impl TokenProvider for UserTokenProvider {
369    async fn token(&self) -> Result<Token> {
370        let client = Client::new();
371
372        // Make the request
373        let req = Oauth2RefreshRequest {
374            grant_type: RefreshGrantType::RefreshToken,
375            client_id: self.client_id.clone(),
376            client_secret: self.client_secret.clone(),
377            refresh_token: self.refresh_token.clone(),
378            scopes: self.scopes.clone(),
379        };
380        let header = HeaderValue::from_static("application/json");
381        let builder = client
382            .request(Method::POST, self.endpoint.as_str())
383            .header(CONTENT_TYPE, header)
384            .json(&req);
385        let resp = builder
386            .send()
387            .await
388            .map_err(|e| errors::from_http_error(e, MSG))?;
389
390        // Process the response
391        if !resp.status().is_success() {
392            let err = errors::from_http_response(resp, MSG).await;
393            return Err(err);
394        }
395        let response = resp.json::<Oauth2RefreshResponse>().await.map_err(|e| {
396            let retryable = !e.is_decode();
397            CredentialsError::from_source(retryable, e)
398        })?;
399        let token = Token {
400            token: response.access_token,
401            token_type: response.token_type,
402            expires_at: response
403                .expires_in
404                .map(|d| Instant::now() + Duration::from_secs(d)),
405            metadata: None,
406        };
407        Ok(token)
408    }
409}
410
411const MSG: &str = "failed to refresh user access token";
412
413/// Data model for a UserCredentials
414///
415/// See: https://cloud.google.com/docs/authentication#user-accounts
416#[derive(Debug)]
417pub(crate) struct UserCredentials<T>
418where
419    T: CachedTokenProvider,
420{
421    token_provider: T,
422    quota_project_id: Option<String>,
423}
424
425#[async_trait::async_trait]
426impl<T> CredentialsProvider for UserCredentials<T>
427where
428    T: CachedTokenProvider,
429{
430    async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
431        let token = self.token_provider.token(extensions).await?;
432        build_cacheable_headers(&token, &self.quota_project_id)
433    }
434}
435
436#[derive(Debug, PartialEq, serde::Deserialize)]
437pub(crate) struct AuthorizedUser {
438    #[serde(rename = "type")]
439    cred_type: String,
440    client_id: String,
441    client_secret: String,
442    refresh_token: String,
443    #[serde(skip_serializing_if = "Option::is_none")]
444    token_uri: Option<String>,
445    #[serde(skip_serializing_if = "Option::is_none")]
446    quota_project_id: Option<String>,
447}
448
449#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
450enum RefreshGrantType {
451    #[serde(rename = "refresh_token")]
452    RefreshToken,
453}
454
/// JSON body sent to the token endpoint to exchange a refresh token for a new
/// access token.
///
/// The field declaration order is the key order of the serialized body.
///
/// NOTE(review): RFC 6749 §6 names the optional refresh parameter `scope`
/// (singular), but this struct serializes it as `scopes`, and a `None` value
/// is sent as `"scopes": null`. Confirm that the token endpoint honors this
/// spelling. The `oauth2_request_serde` test currently pins `scopes`.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
struct Oauth2RefreshRequest {
    grant_type: RefreshGrantType,
    client_id: String,
    client_secret: String,
    refresh_token: String,
    // Space-separated scope list, produced by `Builder::build` joining the
    // configured scopes with " ".
    scopes: Option<String>,
}
463
/// JSON body returned by the token endpoint for a successful refresh.
///
/// `UserTokenProvider::token` builds its `Token` from `access_token`,
/// `token_type`, and `expires_in`. `scope` and `refresh_token` are parsed but
/// not read there. Absent optional fields are omitted when serializing.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
struct Oauth2RefreshResponse {
    access_token: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    scope: Option<String>,
    // Token lifetime in seconds (converted with `Duration::from_secs`), if the
    // server reports one.
    #[serde(skip_serializing_if = "Option::is_none")]
    expires_in: Option<u64>,
    token_type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    refresh_token: Option<String>,
}
475
476#[cfg(test)]
477mod tests {
478    use super::*;
479    use crate::credentials::tests::{
480        find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
481        get_mock_backoff_policy, get_mock_retry_throttler, get_token_from_headers,
482        get_token_type_from_headers,
483    };
484    use crate::credentials::{DEFAULT_UNIVERSE_DOMAIN, QUOTA_PROJECT_KEY};
485    use crate::errors::CredentialsError;
486    use crate::token::tests::MockTokenProvider;
487    use http::StatusCode;
488    use http::header::AUTHORIZATION;
489    use httptest::cycle;
490    use httptest::matchers::{all_of, json_decoded, request};
491    use httptest::responders::{json_encoded, status_code};
492    use httptest::{Expectation, Server};
493
494    type TestResult = anyhow::Result<()>;
495
496    fn authorized_user_json(token_uri: String) -> Value {
497        serde_json::json!({
498            "client_id": "test-client-id",
499            "client_secret": "test-client-secret",
500            "refresh_token": "test-refresh-token",
501            "type": "authorized_user",
502            "token_uri": token_uri,
503        })
504    }
505
506    #[tokio::test]
507    async fn test_user_account_retries_on_transient_failures() -> TestResult {
508        let mut server = Server::run();
509        server.expect(
510            Expectation::matching(request::path("/token"))
511                .times(3)
512                .respond_with(status_code(503)),
513        );
514
515        let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
516            .with_retry_policy(get_mock_auth_retry_policy(3))
517            .with_backoff_policy(get_mock_backoff_policy())
518            .with_retry_throttler(get_mock_retry_throttler())
519            .build()?;
520
521        let err = credentials.headers(Extensions::new()).await.unwrap_err();
522        assert!(!err.is_transient());
523        server.verify_and_clear();
524        Ok(())
525    }
526
527    #[tokio::test]
528    async fn test_user_account_does_not_retry_on_non_transient_failures() -> TestResult {
529        let mut server = Server::run();
530        server.expect(
531            Expectation::matching(request::path("/token"))
532                .times(1)
533                .respond_with(status_code(401)),
534        );
535
536        let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
537            .with_retry_policy(get_mock_auth_retry_policy(1))
538            .with_backoff_policy(get_mock_backoff_policy())
539            .with_retry_throttler(get_mock_retry_throttler())
540            .build()?;
541
542        let err = credentials.headers(Extensions::new()).await.unwrap_err();
543        assert!(!err.is_transient());
544        server.verify_and_clear();
545        Ok(())
546    }
547
548    #[tokio::test]
549    async fn test_user_account_retries_for_success() -> TestResult {
550        let mut server = Server::run();
551        let response = Oauth2RefreshResponse {
552            access_token: "test-access-token".to_string(),
553            expires_in: Some(3600),
554            refresh_token: Some("test-refresh-token".to_string()),
555            scope: Some("scope1 scope2".to_string()),
556            token_type: "test-token-type".to_string(),
557        };
558
559        server.expect(
560            Expectation::matching(request::path("/token"))
561                .times(3)
562                .respond_with(cycle![
563                    status_code(503).body("try-again"),
564                    status_code(503).body("try-again"),
565                    status_code(200)
566                        .append_header("Content-Type", "application/json")
567                        .body(serde_json::to_string(&response).unwrap()),
568                ]),
569        );
570
571        let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
572            .with_retry_policy(get_mock_auth_retry_policy(3))
573            .with_backoff_policy(get_mock_backoff_policy())
574            .with_retry_throttler(get_mock_retry_throttler())
575            .build()?;
576
577        let token = get_token_from_headers(credentials.headers(Extensions::new()).await.unwrap());
578        assert_eq!(token.unwrap(), "test-access-token");
579
580        server.verify_and_clear();
581        Ok(())
582    }
583
584    #[test]
585    fn debug_token_provider() {
586        let expected = UserTokenProvider {
587            client_id: "test-client-id".to_string(),
588            client_secret: "test-client-secret".to_string(),
589            refresh_token: "test-refresh-token".to_string(),
590            endpoint: OAUTH2_ENDPOINT.to_string(),
591            scopes: Some("https://www.googleapis.com/auth/pubsub".to_string()),
592        };
593        let fmt = format!("{expected:?}");
594        assert!(fmt.contains("test-client-id"), "{fmt}");
595        assert!(!fmt.contains("test-client-secret"), "{fmt}");
596        assert!(!fmt.contains("test-refresh-token"), "{fmt}");
597        assert!(fmt.contains(OAUTH2_ENDPOINT), "{fmt}");
598        assert!(
599            fmt.contains("https://www.googleapis.com/auth/pubsub"),
600            "{fmt}"
601        );
602    }
603
604    #[test]
605    fn authorized_user_full_from_json_success() {
606        let json = serde_json::json!({
607            "account": "",
608            "client_id": "test-client-id",
609            "client_secret": "test-client-secret",
610            "refresh_token": "test-refresh-token",
611            "type": "authorized_user",
612            "universe_domain": "googleapis.com",
613            "quota_project_id": "test-project",
614            "token_uri" : "test-token-uri",
615        });
616
617        let expected = AuthorizedUser {
618            cred_type: "authorized_user".to_string(),
619            client_id: "test-client-id".to_string(),
620            client_secret: "test-client-secret".to_string(),
621            refresh_token: "test-refresh-token".to_string(),
622            quota_project_id: Some("test-project".to_string()),
623            token_uri: Some("test-token-uri".to_string()),
624        };
625        let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
626        assert_eq!(actual, expected);
627    }
628
629    #[test]
630    fn authorized_user_partial_from_json_success() {
631        let json = serde_json::json!({
632            "client_id": "test-client-id",
633            "client_secret": "test-client-secret",
634            "refresh_token": "test-refresh-token",
635            "type": "authorized_user",
636        });
637
638        let expected = AuthorizedUser {
639            cred_type: "authorized_user".to_string(),
640            client_id: "test-client-id".to_string(),
641            client_secret: "test-client-secret".to_string(),
642            refresh_token: "test-refresh-token".to_string(),
643            quota_project_id: None,
644            token_uri: None,
645        };
646        let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
647        assert_eq!(actual, expected);
648    }
649
650    #[test]
651    fn authorized_user_from_json_parse_fail() {
652        let json_full = serde_json::json!({
653            "client_id": "test-client-id",
654            "client_secret": "test-client-secret",
655            "refresh_token": "test-refresh-token",
656            "type": "authorized_user",
657            "quota_project_id": "test-project"
658        });
659
660        for required_field in ["client_id", "client_secret", "refresh_token"] {
661            let mut json = json_full.clone();
662            // Remove a required field from the JSON
663            json[required_field].take();
664            serde_json::from_value::<AuthorizedUser>(json)
665                .err()
666                .unwrap();
667        }
668    }
669
670    #[tokio::test]
671    async fn default_universe_domain_success() {
672        let mock = TokenCache::new(MockTokenProvider::new());
673
674        let uc = UserCredentials {
675            token_provider: mock,
676            quota_project_id: None,
677        };
678        assert_eq!(uc.universe_domain().await.unwrap(), DEFAULT_UNIVERSE_DOMAIN);
679    }
680
681    #[tokio::test]
682    async fn headers_success() -> TestResult {
683        let token = Token {
684            token: "test-token".to_string(),
685            token_type: "Bearer".to_string(),
686            expires_at: None,
687            metadata: None,
688        };
689
690        let mut mock = MockTokenProvider::new();
691        mock.expect_token().times(1).return_once(|| Ok(token));
692
693        let uc = UserCredentials {
694            token_provider: TokenCache::new(mock),
695            quota_project_id: None,
696        };
697
698        let mut extensions = Extensions::new();
699        let cached_headers = uc.headers(extensions.clone()).await.unwrap();
700        let (headers, entity_tag) = match cached_headers {
701            CacheableResource::New { entity_tag, data } => (data, entity_tag),
702            CacheableResource::NotModified => unreachable!("expecting new headers"),
703        };
704        let token = headers.get(AUTHORIZATION).unwrap();
705
706        assert_eq!(headers.len(), 1, "{headers:?}");
707        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
708        assert!(token.is_sensitive());
709
710        extensions.insert(entity_tag);
711
712        let cached_headers = uc.headers(extensions).await?;
713
714        match cached_headers {
715            CacheableResource::New { .. } => unreachable!("expecting new headers"),
716            CacheableResource::NotModified => CacheableResource::<HeaderMap>::NotModified,
717        };
718        Ok(())
719    }
720
721    #[tokio::test]
722    async fn headers_failure() {
723        let mut mock = MockTokenProvider::new();
724        mock.expect_token()
725            .times(1)
726            .return_once(|| Err(errors::non_retryable_from_str("fail")));
727
728        let uc = UserCredentials {
729            token_provider: TokenCache::new(mock),
730            quota_project_id: None,
731        };
732        assert!(uc.headers(Extensions::new()).await.is_err());
733    }
734
735    #[tokio::test]
736    async fn headers_with_quota_project_success() -> TestResult {
737        let token = Token {
738            token: "test-token".to_string(),
739            token_type: "Bearer".to_string(),
740            expires_at: None,
741            metadata: None,
742        };
743
744        let mut mock = MockTokenProvider::new();
745        mock.expect_token().times(1).return_once(|| Ok(token));
746
747        let uc = UserCredentials {
748            token_provider: TokenCache::new(mock),
749            quota_project_id: Some("test-project".to_string()),
750        };
751
752        let headers = get_headers_from_cache(uc.headers(Extensions::new()).await.unwrap())?;
753        let token = headers.get(AUTHORIZATION).unwrap();
754        let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();
755
756        assert_eq!(headers.len(), 2, "{headers:?}");
757        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
758        assert!(token.is_sensitive());
759        assert_eq!(
760            quota_project_header,
761            HeaderValue::from_static("test-project")
762        );
763        assert!(!quota_project_header.is_sensitive());
764        Ok(())
765    }
766
767    #[test]
768    fn oauth2_request_serde() {
769        let request = Oauth2RefreshRequest {
770            grant_type: RefreshGrantType::RefreshToken,
771            client_id: "test-client-id".to_string(),
772            client_secret: "test-client-secret".to_string(),
773            refresh_token: "test-refresh-token".to_string(),
774            scopes: Some("scope1 scope2".to_string()),
775        };
776
777        let json = serde_json::to_value(&request).unwrap();
778        let expected = serde_json::json!({
779            "grant_type": "refresh_token",
780            "client_id": "test-client-id",
781            "client_secret": "test-client-secret",
782            "refresh_token": "test-refresh-token",
783            "scopes": "scope1 scope2",
784        });
785        assert_eq!(json, expected);
786        let roundtrip = serde_json::from_value::<Oauth2RefreshRequest>(json).unwrap();
787        assert_eq!(request, roundtrip);
788    }
789
790    #[test]
791    fn oauth2_response_serde_full() {
792        let response = Oauth2RefreshResponse {
793            access_token: "test-access-token".to_string(),
794            scope: Some("scope1 scope2".to_string()),
795            expires_in: Some(3600),
796            token_type: "test-token-type".to_string(),
797            refresh_token: Some("test-refresh-token".to_string()),
798        };
799
800        let json = serde_json::to_value(&response).unwrap();
801        let expected = serde_json::json!({
802            "access_token": "test-access-token",
803            "scope": "scope1 scope2",
804            "expires_in": 3600,
805            "token_type": "test-token-type",
806            "refresh_token": "test-refresh-token"
807        });
808        assert_eq!(json, expected);
809        let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
810        assert_eq!(response, roundtrip);
811    }
812
813    #[test]
814    fn oauth2_response_serde_partial() {
815        let response = Oauth2RefreshResponse {
816            access_token: "test-access-token".to_string(),
817            scope: None,
818            expires_in: None,
819            token_type: "test-token-type".to_string(),
820            refresh_token: None,
821        };
822
823        let json = serde_json::to_value(&response).unwrap();
824        let expected = serde_json::json!({
825            "access_token": "test-access-token",
826            "token_type": "test-token-type",
827        });
828        assert_eq!(json, expected);
829        let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
830        assert_eq!(response, roundtrip);
831    }
832
833    fn check_request(request: &Oauth2RefreshRequest, expected_scopes: Option<String>) -> bool {
834        request.client_id == "test-client-id"
835            && request.client_secret == "test-client-secret"
836            && request.refresh_token == "test-refresh-token"
837            && request.grant_type == RefreshGrantType::RefreshToken
838            && request.scopes == expected_scopes
839    }
840
841    #[tokio::test(start_paused = true)]
842    async fn token_provider_full() -> TestResult {
843        let server = Server::run();
844        let response = Oauth2RefreshResponse {
845            access_token: "test-access-token".to_string(),
846            expires_in: Some(3600),
847            refresh_token: Some("test-refresh-token".to_string()),
848            scope: Some("scope1 scope2".to_string()),
849            token_type: "test-token-type".to_string(),
850        };
851        server.expect(
852            Expectation::matching(all_of![
853                request::path("/token"),
854                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
855                    check_request(req, Some("scope1 scope2".to_string()))
856                }))
857            ])
858            .respond_with(json_encoded(response)),
859        );
860
861        let tp = UserTokenProvider {
862            client_id: "test-client-id".to_string(),
863            client_secret: "test-client-secret".to_string(),
864            refresh_token: "test-refresh-token".to_string(),
865            endpoint: server.url("/token").to_string(),
866            scopes: Some("scope1 scope2".to_string()),
867        };
868        let now = Instant::now();
869        let token = tp.token().await?;
870        assert_eq!(token.token, "test-access-token");
871        assert_eq!(token.token_type, "test-token-type");
872        assert!(
873            token
874                .expires_at
875                .is_some_and(|d| d == now + Duration::from_secs(3600)),
876            "now: {:?}, expires_at: {:?}",
877            now,
878            token.expires_at
879        );
880
881        Ok(())
882    }
883
884    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
885    async fn credential_full_with_quota_project() -> TestResult {
886        let server = Server::run();
887        let response = Oauth2RefreshResponse {
888            access_token: "test-access-token".to_string(),
889            expires_in: Some(3600),
890            refresh_token: Some("test-refresh-token".to_string()),
891            scope: None,
892            token_type: "test-token-type".to_string(),
893        };
894        server.expect(
895            Expectation::matching(all_of![
896                request::path("/token"),
897                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
898                    check_request(req, None)
899                }))
900            ])
901            .respond_with(json_encoded(response)),
902        );
903
904        let authorized_user = serde_json::json!({
905            "client_id": "test-client-id",
906            "client_secret": "test-client-secret",
907            "refresh_token": "test-refresh-token",
908            "type": "authorized_user",
909            "token_uri": server.url("/token").to_string(),
910        });
911        let cred = Builder::new(authorized_user)
912            .with_quota_project_id("test-project")
913            .build()?;
914
915        let headers = get_headers_from_cache(cred.headers(Extensions::new()).await.unwrap())?;
916        let token = headers.get(AUTHORIZATION).unwrap();
917        let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();
918
919        assert_eq!(headers.len(), 2, "{headers:?}");
920        assert_eq!(
921            token,
922            HeaderValue::from_static("test-token-type test-access-token")
923        );
924        assert!(token.is_sensitive());
925        assert_eq!(
926            quota_project_header,
927            HeaderValue::from_static("test-project")
928        );
929        assert!(!quota_project_header.is_sensitive());
930
931        Ok(())
932    }
933
934    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
935    async fn creds_from_json_custom_uri_with_caching() -> TestResult {
936        let mut server = Server::run();
937        let response = Oauth2RefreshResponse {
938            access_token: "test-access-token".to_string(),
939            expires_in: Some(3600),
940            refresh_token: Some("test-refresh-token".to_string()),
941            scope: Some("scope1 scope2".to_string()),
942            token_type: "test-token-type".to_string(),
943        };
944        server.expect(
945            Expectation::matching(all_of![
946                request::path("/token"),
947                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
948                    check_request(req, Some("scope1 scope2".to_string()))
949                }))
950            ])
951            .times(1)
952            .respond_with(json_encoded(response)),
953        );
954
955        let json = serde_json::json!({
956            "client_id": "test-client-id",
957            "client_secret": "test-client-secret",
958            "refresh_token": "test-refresh-token",
959            "type": "authorized_user",
960            "universe_domain": "googleapis.com",
961            "quota_project_id": "test-project",
962            "token_uri": server.url("/token").to_string(),
963        });
964
965        let cred = Builder::new(json)
966            .with_scopes(vec!["scope1", "scope2"])
967            .build()?;
968
969        let token = get_token_from_headers(cred.headers(Extensions::new()).await?);
970        assert_eq!(token.unwrap(), "test-access-token");
971
972        let token = get_token_from_headers(cred.headers(Extensions::new()).await?);
973        assert_eq!(token.unwrap(), "test-access-token");
974
975        server.verify_and_clear();
976
977        Ok(())
978    }
979
980    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
981    async fn credential_provider_partial() -> TestResult {
982        let server = Server::run();
983        let response = Oauth2RefreshResponse {
984            access_token: "test-access-token".to_string(),
985            expires_in: None,
986            refresh_token: None,
987            scope: None,
988            token_type: "test-token-type".to_string(),
989        };
990        server.expect(
991            Expectation::matching(all_of![
992                request::path("/token"),
993                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
994                    check_request(req, None)
995                }))
996            ])
997            .respond_with(json_encoded(response)),
998        );
999
1000        let authorized_user = serde_json::json!({
1001            "client_id": "test-client-id",
1002            "client_secret": "test-client-secret",
1003            "refresh_token": "test-refresh-token",
1004            "type": "authorized_user",
1005            "token_uri": server.url("/token").to_string()
1006        });
1007
1008        let uc = Builder::new(authorized_user).build()?;
1009        let headers = uc.headers(Extensions::new()).await?;
1010        assert_eq!(
1011            get_token_from_headers(headers.clone()).unwrap(),
1012            "test-access-token"
1013        );
1014        assert_eq!(
1015            get_token_type_from_headers(headers).unwrap(),
1016            "test-token-type"
1017        );
1018
1019        Ok(())
1020    }
1021
1022    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1023    async fn credential_provider_with_token_uri() -> TestResult {
1024        let server = Server::run();
1025        let response = Oauth2RefreshResponse {
1026            access_token: "test-access-token".to_string(),
1027            expires_in: None,
1028            refresh_token: None,
1029            scope: None,
1030            token_type: "test-token-type".to_string(),
1031        };
1032        server.expect(
1033            Expectation::matching(all_of![
1034                request::path("/token"),
1035                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
1036                    check_request(req, None)
1037                }))
1038            ])
1039            .respond_with(json_encoded(response)),
1040        );
1041
1042        let authorized_user = serde_json::json!({
1043            "client_id": "test-client-id",
1044            "client_secret": "test-client-secret",
1045            "refresh_token": "test-refresh-token",
1046            "type": "authorized_user",
1047            "token_uri": "test-endpoint"
1048        });
1049
1050        let uc = Builder::new(authorized_user)
1051            .with_token_uri(server.url("/token").to_string())
1052            .build()?;
1053        let headers = uc.headers(Extensions::new()).await?;
1054        assert_eq!(
1055            get_token_from_headers(headers.clone()).unwrap(),
1056            "test-access-token"
1057        );
1058        assert_eq!(
1059            get_token_type_from_headers(headers).unwrap(),
1060            "test-token-type"
1061        );
1062
1063        Ok(())
1064    }
1065
1066    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1067    async fn credential_provider_with_scopes() -> TestResult {
1068        let server = Server::run();
1069        let response = Oauth2RefreshResponse {
1070            access_token: "test-access-token".to_string(),
1071            expires_in: None,
1072            refresh_token: None,
1073            scope: Some("scope1 scope2".to_string()),
1074            token_type: "test-token-type".to_string(),
1075        };
1076        server.expect(
1077            Expectation::matching(all_of![
1078                request::path("/token"),
1079                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
1080                    check_request(req, Some("scope1 scope2".to_string()))
1081                }))
1082            ])
1083            .respond_with(json_encoded(response)),
1084        );
1085
1086        let authorized_user = serde_json::json!({
1087            "client_id": "test-client-id",
1088            "client_secret": "test-client-secret",
1089            "refresh_token": "test-refresh-token",
1090            "type": "authorized_user",
1091            "token_uri": "test-endpoint"
1092        });
1093
1094        let uc = Builder::new(authorized_user)
1095            .with_token_uri(server.url("/token").to_string())
1096            .with_scopes(vec!["scope1", "scope2"])
1097            .build()?;
1098        let headers = uc.headers(Extensions::new()).await?;
1099        assert_eq!(
1100            get_token_from_headers(headers.clone()).unwrap(),
1101            "test-access-token"
1102        );
1103        assert_eq!(
1104            get_token_type_from_headers(headers).unwrap(),
1105            "test-token-type"
1106        );
1107
1108        Ok(())
1109    }
1110
1111    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1112    async fn credential_provider_retryable_error() -> TestResult {
1113        let server = Server::run();
1114        server
1115            .expect(Expectation::matching(request::path("/token")).respond_with(status_code(503)));
1116
1117        let authorized_user = serde_json::json!({
1118            "client_id": "test-client-id",
1119            "client_secret": "test-client-secret",
1120            "refresh_token": "test-refresh-token",
1121            "type": "authorized_user",
1122            "token_uri": server.url("/token").to_string()
1123        });
1124
1125        let uc = Builder::new(authorized_user).build()?;
1126        let err = uc.headers(Extensions::new()).await.unwrap_err();
1127        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
1128        assert!(original_err.is_transient());
1129
1130        let source = find_source_error::<reqwest::Error>(&err);
1131        assert!(
1132            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
1133            "{err:?}"
1134        );
1135
1136        Ok(())
1137    }
1138
1139    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1140    async fn token_provider_nonretryable_error() -> TestResult {
1141        let server = Server::run();
1142        server
1143            .expect(Expectation::matching(request::path("/token")).respond_with(status_code(401)));
1144
1145        let authorized_user = serde_json::json!({
1146            "client_id": "test-client-id",
1147            "client_secret": "test-client-secret",
1148            "refresh_token": "test-refresh-token",
1149            "type": "authorized_user",
1150            "token_uri": server.url("/token").to_string()
1151        });
1152
1153        let uc = Builder::new(authorized_user).build()?;
1154        let err = uc.headers(Extensions::new()).await.unwrap_err();
1155        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
1156        assert!(!original_err.is_transient());
1157
1158        let source = find_source_error::<reqwest::Error>(&err);
1159        assert!(
1160            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
1161            "{err:?}"
1162        );
1163
1164        Ok(())
1165    }
1166
1167    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1168    async fn token_provider_malformed_response_is_nonretryable() -> TestResult {
1169        let server = Server::run();
1170        server.expect(
1171            Expectation::matching(request::path("/token"))
1172                .respond_with(json_encoded("bad json".to_string())),
1173        );
1174
1175        let authorized_user = serde_json::json!({
1176            "client_id": "test-client-id",
1177            "client_secret": "test-client-secret",
1178            "refresh_token": "test-refresh-token",
1179            "type": "authorized_user",
1180            "token_uri": server.url("/token").to_string()
1181        });
1182
1183        let uc = Builder::new(authorized_user).build()?;
1184        let e = uc.headers(Extensions::new()).await.err().unwrap();
1185        assert!(!e.is_transient(), "{e}");
1186
1187        Ok(())
1188    }
1189
1190    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1191    async fn builder_malformed_authorized_json_nonretryable() -> TestResult {
1192        let authorized_user = serde_json::json!({
1193            "client_secret": "test-client-secret",
1194            "refresh_token": "test-refresh-token",
1195            "type": "authorized_user",
1196        });
1197
1198        let e = Builder::new(authorized_user).build().unwrap_err();
1199        assert!(e.is_parsing(), "{e}");
1200
1201        Ok(())
1202    }
1203}