google_cloud_auth/credentials/
user_account.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! [User Account] Credentials type.
16//!
17//! User accounts represent a developer, administrator, or any other person who
18//! interacts with Google APIs and services. User accounts are managed as
19//! [Google Accounts], either via [Google Workspace] or [Cloud Identity].
20//!
21//! This module provides [Credentials] derived from user account
22//! information, specifically utilizing an OAuth 2.0 refresh token.
23//!
24//! This module is designed for refresh tokens obtained via the standard
25//! [Authorization Code grant]. Acquiring the initial refresh token (e.g., through
26//! user consent) is outside the scope of this library.
27//! See [RFC 6749 Section 4.1] for flow details.
28//!
29//! The Google Cloud client libraries for Rust will typically find and use these
30//! credentials automatically if a credentials file exists in the
31//! standard ADC search paths. This file is often created by running:
32//! `gcloud auth application-default login`. You might instantiate these credentials
33//! directly using the [`Builder`] if you need to:
34//! * Load credentials from a non-standard location or source.
35//! * Override the OAuth 2.0 **scopes** being requested for the access token.
36//! * Override the **quota project ID** for billing and quota management.
37//! * Override the **token URI** used to fetch access tokens.
38//! * Customize the **retry behavior** when fetching access tokens.
39//!
40//! ## Example: Creating credentials from a JSON object
41//!
42//! ```
43//! # use google_cloud_auth::credentials::user_account::Builder;
44//! # use google_cloud_auth::credentials::Credentials;
45//! # use http::Extensions;
46//! # tokio_test::block_on(async {
47//! let authorized_user = serde_json::json!({
48//!     "client_id": "YOUR_CLIENT_ID.apps.googleusercontent.com", // Replace with your actual Client ID
49//!     "client_secret": "YOUR_CLIENT_SECRET", // Replace with your actual Client Secret - LOAD SECURELY!
50//!     "refresh_token": "YOUR_REFRESH_TOKEN", // Replace with the user's refresh token - LOAD SECURELY!
51//!     "type": "authorized_user",
52//!     // "quota_project_id": "your-billing-project-id", // Optional: Set if needed
53//!     // "token_uri" : "test-token-uri", // Optional: Set if needed
54//! });
55//! let credentials: Credentials = Builder::new(authorized_user).build()?;
56//! let headers = credentials.headers(Extensions::new()).await?;
57//! println!("Headers: {headers:?}");
58//! # Ok::<(), anyhow::Error>(())
59//! # });
60//! ```
61//!
62//! ## Example: Creating credentials with custom retry behavior
63//!
64//! ```
65//! # use google_cloud_auth::credentials::user_account::Builder;
66//! # use google_cloud_auth::credentials::Credentials;
67//! # use http::Extensions;
68//! # use std::time::Duration;
69//! # tokio_test::block_on(async {
70//! use gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
71//! use gax::exponential_backoff::ExponentialBackoff;
72//! let authorized_user = serde_json::json!({
73//!     "client_id": "YOUR_CLIENT_ID.apps.googleusercontent.com",
74//!     "client_secret": "YOUR_CLIENT_SECRET",
75//!     "refresh_token": "YOUR_REFRESH_TOKEN",
76//!     "type": "authorized_user",
77//! });
78//! let backoff = ExponentialBackoff::default();
79//! let credentials: Credentials = Builder::new(authorized_user)
80//!     .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
81//!     .with_backoff_policy(backoff)
82//!     .build()?;
83//! let headers = credentials.headers(Extensions::new()).await?;
84//! println!("Headers: {headers:?}");
85//! # Ok::<(), anyhow::Error>(())
86//! # });
87//! ```
88//!
89//! [Authorization Code grant]: https://tools.ietf.org/html/rfc6749#section-1.3.1
90//! [Cloud Identity]: https://cloud.google.com/identity
91//! [Google Accounts]: https://myaccount.google.com/
92//! [Google Workspace]: https://workspace.google.com/
93//! [RFC 6749 Section 4.1]: https://datatracker.ietf.org/doc/html/rfc6749#section-4.1
94//! [User Account]: https://cloud.google.com/docs/authentication#user-accounts
95//! [Workforce Identity Federation]: https://cloud.google.com/iam/docs/workforce-identity-federation
96
97use crate::build_errors::Error as BuilderError;
98use crate::credentials::dynamic::CredentialsProvider;
99use crate::credentials::{CacheableResource, Credentials};
100use crate::errors::{self, CredentialsError};
101use crate::headers_util::build_cacheable_headers;
102use crate::retry::Builder as RetryTokenProviderBuilder;
103use crate::token::{CachedTokenProvider, Token, TokenProvider};
104use crate::token_cache::TokenCache;
105use crate::{BuildResult, Result};
106use gax::backoff_policy::BackoffPolicyArg;
107use gax::retry_policy::RetryPolicyArg;
108use gax::retry_throttler::RetryThrottlerArg;
109use http::header::CONTENT_TYPE;
110use http::{Extensions, HeaderMap, HeaderValue};
111use reqwest::{Client, Method};
112use serde_json::Value;
113use std::sync::Arc;
114use tokio::time::{Duration, Instant};
115
116const OAUTH2_ENDPOINT: &str = "https://oauth2.googleapis.com/token";
117
118/// A builder for constructing `user_account` [Credentials] instance.
119///
120/// # Example
121/// ```
122/// # use google_cloud_auth::credentials::user_account::Builder;
123/// # tokio_test::block_on(async {
124/// let authorized_user = serde_json::json!({ /* add details here */ });
125/// let credentials = Builder::new(authorized_user).build();
126/// })
127/// ```
128pub struct Builder {
129    authorized_user: Value,
130    scopes: Option<Vec<String>>,
131    quota_project_id: Option<String>,
132    token_uri: Option<String>,
133    retry_builder: RetryTokenProviderBuilder,
134}
135
136impl Builder {
137    /// Creates a new builder using `authorized_user` JSON value.
138    ///
139    /// The `authorized_user` JSON is typically generated when a user
140    /// authenticates using the [application-default login] process.
141    ///
142    /// [application-default login]: https://cloud.google.com/sdk/gcloud/reference/auth/application-default/login
143    pub fn new(authorized_user: Value) -> Self {
144        Self {
145            authorized_user,
146            scopes: None,
147            quota_project_id: None,
148            token_uri: None,
149            retry_builder: RetryTokenProviderBuilder::default(),
150        }
151    }
152
153    /// Sets the URI for the token endpoint used to fetch access tokens.
154    ///
155    /// Any value provided here overrides a `token_uri` value from the input `authorized_user` JSON.
156    /// Defaults to `https://oauth2.googleapis.com/token` if not specified here or in the `authorized_user` JSON.
157    ///
158    /// # Example
159    /// ```
160    /// # use google_cloud_auth::credentials::user_account::Builder;
161    /// let authorized_user = serde_json::json!({ /* add details here */ });
162    /// let credentials = Builder::new(authorized_user)
163    ///     .with_token_uri("https://oauth2-FOOBAR.p.googleapis.com")
164    ///     .build();
165    /// ```
166    pub fn with_token_uri<S: Into<String>>(mut self, token_uri: S) -> Self {
167        self.token_uri = Some(token_uri.into());
168        self
169    }
170
171    /// Sets the [scopes] for these credentials.
172    ///
173    /// `scopes` define the *permissions being requested* for this specific access token
174    /// when interacting with a service. For example, `https://www.googleapis.com/auth/devstorage.read_write`.
175    /// IAM permissions, on the other hand, define the *underlying capabilities*
176    /// the user account possesses within a system. For example, `storage.buckets.delete`.
177    /// When a token generated with specific scopes is used, the request must be permitted
178    /// by both the user account's underlying IAM permissions and the scopes requested
179    /// for the token. Therefore, scopes act as an additional restriction on what the token
180    /// can be used for.
181    ///
182    /// # Example
183    /// ```
184    /// # use google_cloud_auth::credentials::user_account::Builder;
185    /// let authorized_user = serde_json::json!({ /* add details here */ });
186    /// let credentials = Builder::new(authorized_user)
187    ///     .with_scopes(["https://www.googleapis.com/auth/pubsub"])
188    ///     .build();
189    /// ```
190    /// [scopes]: https://developers.google.com/identity/protocols/oauth2/scopes
191    pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
192    where
193        I: IntoIterator<Item = S>,
194        S: Into<String>,
195    {
196        self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
197        self
198    }
199
200    /// Sets the [quota project] for these credentials.
201    ///
202    /// In some services, you can use an account in
203    /// one project for authentication and authorization, and charge
204    /// the usage to a different project. This requires that the
205    /// user has `serviceusage.services.use` permissions on the quota project.
206    ///
207    /// Any value set here overrides a `quota_project_id` value from the
208    /// input `authorized_user` JSON.
209    ///
210    /// # Example
211    /// ```
212    /// # use google_cloud_auth::credentials::user_account::Builder;
213    /// let authorized_user = serde_json::json!("{ /* add details here */ }");
214    /// let credentials = Builder::new(authorized_user)
215    ///     .with_quota_project_id("my-project")
216    ///     .build();
217    /// ```
218    ///
219    /// [quota project]: https://cloud.google.com/docs/quotas/quota-project
220    pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
221        self.quota_project_id = Some(quota_project_id.into());
222        self
223    }
224
225    /// Configure the retry policy for fetching tokens.
226    ///
227    /// The retry policy controls how to handle retries, and sets limits on
228    /// the number of attempts or the total time spent retrying.
229    ///
230    /// ```
231    /// # use google_cloud_auth::credentials::user_account::Builder;
232    /// # tokio_test::block_on(async {
233    /// use gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
234    /// let authorized_user = serde_json::json!({
235    ///     "client_id": "YOUR_CLIENT_ID.apps.googleusercontent.com",
236    ///     "client_secret": "YOUR_CLIENT_SECRET",
237    ///     "refresh_token": "YOUR_REFRESH_TOKEN",
238    ///     "type": "authorized_user",
239    /// });
240    /// let credentials = Builder::new(authorized_user)
241    ///     .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
242    ///     .build();
243    /// # });
244    /// ```
245    pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
246        self.retry_builder = self.retry_builder.with_retry_policy(v.into());
247        self
248    }
249
250    /// Configure the retry backoff policy.
251    ///
252    /// The backoff policy controls how long to wait in between retry attempts.
253    ///
254    /// ```
255    /// # use google_cloud_auth::credentials::user_account::Builder;
256    /// # use std::time::Duration;
257    /// # tokio_test::block_on(async {
258    /// use gax::exponential_backoff::ExponentialBackoff;
259    /// let authorized_user = serde_json::json!({
260    ///     "client_id": "YOUR_CLIENT_ID.apps.googleusercontent.com",
261    ///     "client_secret": "YOUR_CLIENT_SECRET",
262    ///     "refresh_token": "YOUR_REFRESH_TOKEN",
263    ///     "type": "authorized_user",
264    /// });
265    /// let credentials = Builder::new(authorized_user)
266    ///     .with_backoff_policy(ExponentialBackoff::default())
267    ///     .build();
268    /// # });
269    /// ```
270    pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
271        self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
272        self
273    }
274
275    /// Configure the retry throttler.
276    ///
277    /// Advanced applications may want to configure a retry throttler to
278    /// [Address Cascading Failures] and when [Handling Overload] conditions.
279    /// The authentication library throttles its retry loop, using a policy to
280    /// control the throttling algorithm. Use this method to fine tune or
281    /// customize the default retry throttler.
282    ///
283    /// [Handling Overload]: https://sre.google/sre-book/handling-overload/
284    /// [Address Cascading Failures]: https://sre.google/sre-book/addressing-cascading-failures/
285    ///
286    /// ```
287    /// # use google_cloud_auth::credentials::user_account::Builder;
288    /// # tokio_test::block_on(async {
289    /// use gax::retry_throttler::AdaptiveThrottler;
290    /// let authorized_user = serde_json::json!({
291    ///     "client_id": "YOUR_CLIENT_ID.apps.googleusercontent.com",
292    ///     "client_secret": "YOUR_CLIENT_SECRET",
293    ///     "refresh_token": "YOUR_REFRESH_TOKEN",
294    ///     "type": "authorized_user",
295    /// });
296    /// let credentials = Builder::new(authorized_user)
297    ///     .with_retry_throttler(AdaptiveThrottler::default())
298    ///     .build();
299    /// # });
300    /// ```
301    pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
302        self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
303        self
304    }
305
306    /// Returns a [Credentials] instance with the configured settings.
307    ///
308    /// # Errors
309    ///
310    /// Returns a [CredentialsError] if the `authorized_user`
311    /// provided to [`Builder::new`] cannot be successfully deserialized into the
312    /// expected format. This typically happens if the JSON value is malformed or
313    /// missing required fields. For more information, on how to generate
314    /// `authorized_user` json, consult the relevant section in the
315    /// [application-default credentials] guide.
316    ///
317    /// [application-default credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
318    pub fn build(self) -> BuildResult<Credentials> {
319        let authorized_user = serde_json::from_value::<AuthorizedUser>(self.authorized_user)
320            .map_err(BuilderError::parsing)?;
321        let endpoint = self
322            .token_uri
323            .or(authorized_user.token_uri)
324            .unwrap_or(OAUTH2_ENDPOINT.to_string());
325        let quota_project_id = self.quota_project_id.or(authorized_user.quota_project_id);
326
327        let token_provider = UserTokenProvider {
328            client_id: authorized_user.client_id,
329            client_secret: authorized_user.client_secret,
330            refresh_token: authorized_user.refresh_token,
331            endpoint,
332            scopes: self.scopes.map(|scopes| scopes.join(" ")),
333        };
334
335        let token_provider = TokenCache::new(self.retry_builder.build(token_provider));
336
337        Ok(Credentials {
338            inner: Arc::new(UserCredentials {
339                token_provider,
340                quota_project_id,
341            }),
342        })
343    }
344}
345
/// Fetches user access tokens by exchanging a refresh token at `endpoint`.
#[derive(PartialEq)]
struct UserTokenProvider {
    client_id: String,
    client_secret: String,
    refresh_token: String,
    endpoint: String,
    // Space-delimited scope list, if any scopes were requested.
    scopes: Option<String>,
}

/// Debug output that redacts `client_secret` and `refresh_token`.
///
/// The struct name reported is this type's own name, so nested `Debug` output
/// (e.g. inside `UserCredentials`) is not ambiguous.
impl std::fmt::Debug for UserTokenProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("UserTokenProvider")
            .field("client_id", &self.client_id)
            .field("client_secret", &"[censored]")
            .field("refresh_token", &"[censored]")
            .field("endpoint", &self.endpoint)
            .field("scopes", &self.scopes)
            .finish()
    }
}
366
367#[async_trait::async_trait]
368impl TokenProvider for UserTokenProvider {
369    async fn token(&self) -> Result<Token> {
370        let client = Client::new();
371
372        // Make the request
373        let req = Oauth2RefreshRequest {
374            grant_type: RefreshGrantType::RefreshToken,
375            client_id: self.client_id.clone(),
376            client_secret: self.client_secret.clone(),
377            refresh_token: self.refresh_token.clone(),
378            scopes: self.scopes.clone(),
379        };
380        let header = HeaderValue::from_static("application/json");
381        let builder = client
382            .request(Method::POST, self.endpoint.as_str())
383            .header(CONTENT_TYPE, header)
384            .json(&req);
385        let resp = builder
386            .send()
387            .await
388            .map_err(|e| errors::from_http_error(e, MSG))?;
389
390        // Process the response
391        if !resp.status().is_success() {
392            let err = errors::from_http_response(resp, MSG).await;
393            return Err(err);
394        }
395        let response = resp.json::<Oauth2RefreshResponse>().await.map_err(|e| {
396            let retryable = !e.is_decode();
397            CredentialsError::from_source(retryable, e)
398        })?;
399        let token = Token {
400            token: response.access_token,
401            token_type: response.token_type,
402            expires_at: response
403                .expires_in
404                .map(|d| Instant::now() + Duration::from_secs(d)),
405            metadata: None,
406        };
407        Ok(token)
408    }
409}
410
411const MSG: &str = "failed to refresh user access token";
412
413/// Data model for a UserCredentials
414///
415/// See: https://cloud.google.com/docs/authentication#user-accounts
416#[derive(Debug)]
417pub(crate) struct UserCredentials<T>
418where
419    T: CachedTokenProvider,
420{
421    token_provider: T,
422    quota_project_id: Option<String>,
423}
424
425#[async_trait::async_trait]
426impl<T> CredentialsProvider for UserCredentials<T>
427where
428    T: CachedTokenProvider,
429{
430    async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
431        let token = self.token_provider.token(extensions).await?;
432        build_cacheable_headers(&token, &self.quota_project_id)
433    }
434}
435
436#[derive(Debug, PartialEq, serde::Deserialize)]
437pub(crate) struct AuthorizedUser {
438    #[serde(rename = "type")]
439    cred_type: String,
440    client_id: String,
441    client_secret: String,
442    refresh_token: String,
443    #[serde(skip_serializing_if = "Option::is_none")]
444    token_uri: Option<String>,
445    #[serde(skip_serializing_if = "Option::is_none")]
446    quota_project_id: Option<String>,
447}
448
449#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
450enum RefreshGrantType {
451    #[serde(rename = "refresh_token")]
452    RefreshToken,
453}
454
455#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
456struct Oauth2RefreshRequest {
457    grant_type: RefreshGrantType,
458    client_id: String,
459    client_secret: String,
460    refresh_token: String,
461    scopes: Option<String>,
462}
463
464#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
465struct Oauth2RefreshResponse {
466    access_token: String,
467    #[serde(skip_serializing_if = "Option::is_none")]
468    scope: Option<String>,
469    #[serde(skip_serializing_if = "Option::is_none")]
470    expires_in: Option<u64>,
471    token_type: String,
472    #[serde(skip_serializing_if = "Option::is_none")]
473    refresh_token: Option<String>,
474}
475
476#[cfg(test)]
477mod tests {
478    use super::*;
479    use crate::credentials::tests::{
480        get_headers_from_cache, get_mock_auth_retry_policy, get_mock_backoff_policy,
481        get_mock_retry_throttler, get_token_from_headers, get_token_type_from_headers,
482    };
483    use crate::credentials::{DEFAULT_UNIVERSE_DOMAIN, QUOTA_PROJECT_KEY};
484    use crate::token::tests::MockTokenProvider;
485    use http::StatusCode;
486    use http::header::AUTHORIZATION;
487    use httptest::matchers::{all_of, json_decoded, request};
488    use httptest::responders::{json_encoded, status_code};
489    use httptest::{Expectation, Server, cycle};
490    use std::error::Error;
491
492    type TestResult = anyhow::Result<()>;
493
494    fn authorized_user_json(token_uri: String) -> Value {
495        serde_json::json!({
496            "client_id": "test-client-id",
497            "client_secret": "test-client-secret",
498            "refresh_token": "test-refresh-token",
499            "type": "authorized_user",
500            "token_uri": token_uri,
501        })
502    }
503
504    #[tokio::test]
505    async fn test_user_account_retries_on_transient_failures() -> TestResult {
506        let mut server = Server::run();
507        server.expect(
508            Expectation::matching(request::path("/token"))
509                .times(3)
510                .respond_with(status_code(503)),
511        );
512
513        let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
514            .with_retry_policy(get_mock_auth_retry_policy(3))
515            .with_backoff_policy(get_mock_backoff_policy())
516            .with_retry_throttler(get_mock_retry_throttler())
517            .build()?;
518
519        let err = credentials.headers(Extensions::new()).await.unwrap_err();
520        assert!(err.is_transient());
521        server.verify_and_clear();
522        Ok(())
523    }
524
525    #[tokio::test]
526    async fn test_user_account_does_not_retry_on_non_transient_failures() -> TestResult {
527        let mut server = Server::run();
528        server.expect(
529            Expectation::matching(request::path("/token"))
530                .times(1)
531                .respond_with(status_code(401)),
532        );
533
534        let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
535            .with_retry_policy(get_mock_auth_retry_policy(1))
536            .with_backoff_policy(get_mock_backoff_policy())
537            .with_retry_throttler(get_mock_retry_throttler())
538            .build()?;
539
540        let err = credentials.headers(Extensions::new()).await.unwrap_err();
541        assert!(!err.is_transient());
542        server.verify_and_clear();
543        Ok(())
544    }
545
546    #[tokio::test]
547    async fn test_user_account_retries_for_success() -> TestResult {
548        let mut server = Server::run();
549        let response = Oauth2RefreshResponse {
550            access_token: "test-access-token".to_string(),
551            expires_in: Some(3600),
552            refresh_token: Some("test-refresh-token".to_string()),
553            scope: Some("scope1 scope2".to_string()),
554            token_type: "test-token-type".to_string(),
555        };
556
557        server.expect(
558            Expectation::matching(request::path("/token"))
559                .times(3)
560                .respond_with(cycle![
561                    status_code(503).body("try-again"),
562                    status_code(503).body("try-again"),
563                    status_code(200)
564                        .append_header("Content-Type", "application/json")
565                        .body(serde_json::to_string(&response).unwrap()),
566                ]),
567        );
568
569        let credentials = Builder::new(authorized_user_json(server.url("/token").to_string()))
570            .with_retry_policy(get_mock_auth_retry_policy(3))
571            .with_backoff_policy(get_mock_backoff_policy())
572            .with_retry_throttler(get_mock_retry_throttler())
573            .build()?;
574
575        let token = get_token_from_headers(credentials.headers(Extensions::new()).await.unwrap());
576        assert_eq!(token.unwrap(), "test-access-token");
577
578        server.verify_and_clear();
579        Ok(())
580    }
581
582    #[test]
583    fn debug_token_provider() {
584        let expected = UserTokenProvider {
585            client_id: "test-client-id".to_string(),
586            client_secret: "test-client-secret".to_string(),
587            refresh_token: "test-refresh-token".to_string(),
588            endpoint: OAUTH2_ENDPOINT.to_string(),
589            scopes: Some("https://www.googleapis.com/auth/pubsub".to_string()),
590        };
591        let fmt = format!("{expected:?}");
592        assert!(fmt.contains("test-client-id"), "{fmt}");
593        assert!(!fmt.contains("test-client-secret"), "{fmt}");
594        assert!(!fmt.contains("test-refresh-token"), "{fmt}");
595        assert!(fmt.contains(OAUTH2_ENDPOINT), "{fmt}");
596        assert!(
597            fmt.contains("https://www.googleapis.com/auth/pubsub"),
598            "{fmt}"
599        );
600    }
601
602    #[test]
603    fn authorized_user_full_from_json_success() {
604        let json = serde_json::json!({
605            "account": "",
606            "client_id": "test-client-id",
607            "client_secret": "test-client-secret",
608            "refresh_token": "test-refresh-token",
609            "type": "authorized_user",
610            "universe_domain": "googleapis.com",
611            "quota_project_id": "test-project",
612            "token_uri" : "test-token-uri",
613        });
614
615        let expected = AuthorizedUser {
616            cred_type: "authorized_user".to_string(),
617            client_id: "test-client-id".to_string(),
618            client_secret: "test-client-secret".to_string(),
619            refresh_token: "test-refresh-token".to_string(),
620            quota_project_id: Some("test-project".to_string()),
621            token_uri: Some("test-token-uri".to_string()),
622        };
623        let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
624        assert_eq!(actual, expected);
625    }
626
627    #[test]
628    fn authorized_user_partial_from_json_success() {
629        let json = serde_json::json!({
630            "client_id": "test-client-id",
631            "client_secret": "test-client-secret",
632            "refresh_token": "test-refresh-token",
633            "type": "authorized_user",
634        });
635
636        let expected = AuthorizedUser {
637            cred_type: "authorized_user".to_string(),
638            client_id: "test-client-id".to_string(),
639            client_secret: "test-client-secret".to_string(),
640            refresh_token: "test-refresh-token".to_string(),
641            quota_project_id: None,
642            token_uri: None,
643        };
644        let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
645        assert_eq!(actual, expected);
646    }
647
648    #[test]
649    fn authorized_user_from_json_parse_fail() {
650        let json_full = serde_json::json!({
651            "client_id": "test-client-id",
652            "client_secret": "test-client-secret",
653            "refresh_token": "test-refresh-token",
654            "type": "authorized_user",
655            "quota_project_id": "test-project"
656        });
657
658        for required_field in ["client_id", "client_secret", "refresh_token"] {
659            let mut json = json_full.clone();
660            // Remove a required field from the JSON
661            json[required_field].take();
662            serde_json::from_value::<AuthorizedUser>(json)
663                .err()
664                .unwrap();
665        }
666    }
667
668    #[tokio::test]
669    async fn default_universe_domain_success() {
670        let mock = TokenCache::new(MockTokenProvider::new());
671
672        let uc = UserCredentials {
673            token_provider: mock,
674            quota_project_id: None,
675        };
676        assert_eq!(uc.universe_domain().await.unwrap(), DEFAULT_UNIVERSE_DOMAIN);
677    }
678
679    #[tokio::test]
680    async fn headers_success() -> TestResult {
681        let token = Token {
682            token: "test-token".to_string(),
683            token_type: "Bearer".to_string(),
684            expires_at: None,
685            metadata: None,
686        };
687
688        let mut mock = MockTokenProvider::new();
689        mock.expect_token().times(1).return_once(|| Ok(token));
690
691        let uc = UserCredentials {
692            token_provider: TokenCache::new(mock),
693            quota_project_id: None,
694        };
695
696        let mut extensions = Extensions::new();
697        let cached_headers = uc.headers(extensions.clone()).await.unwrap();
698        let (headers, entity_tag) = match cached_headers {
699            CacheableResource::New { entity_tag, data } => (data, entity_tag),
700            CacheableResource::NotModified => unreachable!("expecting new headers"),
701        };
702        let token = headers.get(AUTHORIZATION).unwrap();
703
704        assert_eq!(headers.len(), 1, "{headers:?}");
705        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
706        assert!(token.is_sensitive());
707
708        extensions.insert(entity_tag);
709
710        let cached_headers = uc.headers(extensions).await?;
711
712        match cached_headers {
713            CacheableResource::New { .. } => unreachable!("expecting new headers"),
714            CacheableResource::NotModified => CacheableResource::<HeaderMap>::NotModified,
715        };
716        Ok(())
717    }
718
719    #[tokio::test]
720    async fn headers_failure() {
721        let mut mock = MockTokenProvider::new();
722        mock.expect_token()
723            .times(1)
724            .return_once(|| Err(errors::non_retryable_from_str("fail")));
725
726        let uc = UserCredentials {
727            token_provider: TokenCache::new(mock),
728            quota_project_id: None,
729        };
730        assert!(uc.headers(Extensions::new()).await.is_err());
731    }
732
733    #[tokio::test]
734    async fn headers_with_quota_project_success() -> TestResult {
735        let token = Token {
736            token: "test-token".to_string(),
737            token_type: "Bearer".to_string(),
738            expires_at: None,
739            metadata: None,
740        };
741
742        let mut mock = MockTokenProvider::new();
743        mock.expect_token().times(1).return_once(|| Ok(token));
744
745        let uc = UserCredentials {
746            token_provider: TokenCache::new(mock),
747            quota_project_id: Some("test-project".to_string()),
748        };
749
750        let headers = get_headers_from_cache(uc.headers(Extensions::new()).await.unwrap())?;
751        let token = headers.get(AUTHORIZATION).unwrap();
752        let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();
753
754        assert_eq!(headers.len(), 2, "{headers:?}");
755        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
756        assert!(token.is_sensitive());
757        assert_eq!(
758            quota_project_header,
759            HeaderValue::from_static("test-project")
760        );
761        assert!(!quota_project_header.is_sensitive());
762        Ok(())
763    }
764
765    #[test]
766    fn oauth2_request_serde() {
767        let request = Oauth2RefreshRequest {
768            grant_type: RefreshGrantType::RefreshToken,
769            client_id: "test-client-id".to_string(),
770            client_secret: "test-client-secret".to_string(),
771            refresh_token: "test-refresh-token".to_string(),
772            scopes: Some("scope1 scope2".to_string()),
773        };
774
775        let json = serde_json::to_value(&request).unwrap();
776        let expected = serde_json::json!({
777            "grant_type": "refresh_token",
778            "client_id": "test-client-id",
779            "client_secret": "test-client-secret",
780            "refresh_token": "test-refresh-token",
781            "scopes": "scope1 scope2",
782        });
783        assert_eq!(json, expected);
784        let roundtrip = serde_json::from_value::<Oauth2RefreshRequest>(json).unwrap();
785        assert_eq!(request, roundtrip);
786    }
787
788    #[test]
789    fn oauth2_response_serde_full() {
790        let response = Oauth2RefreshResponse {
791            access_token: "test-access-token".to_string(),
792            scope: Some("scope1 scope2".to_string()),
793            expires_in: Some(3600),
794            token_type: "test-token-type".to_string(),
795            refresh_token: Some("test-refresh-token".to_string()),
796        };
797
798        let json = serde_json::to_value(&response).unwrap();
799        let expected = serde_json::json!({
800            "access_token": "test-access-token",
801            "scope": "scope1 scope2",
802            "expires_in": 3600,
803            "token_type": "test-token-type",
804            "refresh_token": "test-refresh-token"
805        });
806        assert_eq!(json, expected);
807        let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
808        assert_eq!(response, roundtrip);
809    }
810
811    #[test]
812    fn oauth2_response_serde_partial() {
813        let response = Oauth2RefreshResponse {
814            access_token: "test-access-token".to_string(),
815            scope: None,
816            expires_in: None,
817            token_type: "test-token-type".to_string(),
818            refresh_token: None,
819        };
820
821        let json = serde_json::to_value(&response).unwrap();
822        let expected = serde_json::json!({
823            "access_token": "test-access-token",
824            "token_type": "test-token-type",
825        });
826        assert_eq!(json, expected);
827        let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
828        assert_eq!(response, roundtrip);
829    }
830
831    fn check_request(request: &Oauth2RefreshRequest, expected_scopes: Option<String>) -> bool {
832        request.client_id == "test-client-id"
833            && request.client_secret == "test-client-secret"
834            && request.refresh_token == "test-refresh-token"
835            && request.grant_type == RefreshGrantType::RefreshToken
836            && request.scopes == expected_scopes
837    }
838
839    #[tokio::test(start_paused = true)]
840    async fn token_provider_full() -> TestResult {
841        let server = Server::run();
842        let response = Oauth2RefreshResponse {
843            access_token: "test-access-token".to_string(),
844            expires_in: Some(3600),
845            refresh_token: Some("test-refresh-token".to_string()),
846            scope: Some("scope1 scope2".to_string()),
847            token_type: "test-token-type".to_string(),
848        };
849        server.expect(
850            Expectation::matching(all_of![
851                request::path("/token"),
852                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
853                    check_request(req, Some("scope1 scope2".to_string()))
854                }))
855            ])
856            .respond_with(json_encoded(response)),
857        );
858
859        let tp = UserTokenProvider {
860            client_id: "test-client-id".to_string(),
861            client_secret: "test-client-secret".to_string(),
862            refresh_token: "test-refresh-token".to_string(),
863            endpoint: server.url("/token").to_string(),
864            scopes: Some("scope1 scope2".to_string()),
865        };
866        let now = Instant::now();
867        let token = tp.token().await?;
868        assert_eq!(token.token, "test-access-token");
869        assert_eq!(token.token_type, "test-token-type");
870        assert!(
871            token
872                .expires_at
873                .is_some_and(|d| d == now + Duration::from_secs(3600)),
874            "now: {:?}, expires_at: {:?}",
875            now,
876            token.expires_at
877        );
878
879        Ok(())
880    }
881
882    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
883    async fn credential_full_with_quota_project() -> TestResult {
884        let server = Server::run();
885        let response = Oauth2RefreshResponse {
886            access_token: "test-access-token".to_string(),
887            expires_in: Some(3600),
888            refresh_token: Some("test-refresh-token".to_string()),
889            scope: None,
890            token_type: "test-token-type".to_string(),
891        };
892        server.expect(
893            Expectation::matching(all_of![
894                request::path("/token"),
895                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
896                    check_request(req, None)
897                }))
898            ])
899            .respond_with(json_encoded(response)),
900        );
901
902        let authorized_user = serde_json::json!({
903            "client_id": "test-client-id",
904            "client_secret": "test-client-secret",
905            "refresh_token": "test-refresh-token",
906            "type": "authorized_user",
907            "token_uri": server.url("/token").to_string(),
908        });
909        let cred = Builder::new(authorized_user)
910            .with_quota_project_id("test-project")
911            .build()?;
912
913        let headers = get_headers_from_cache(cred.headers(Extensions::new()).await.unwrap())?;
914        let token = headers.get(AUTHORIZATION).unwrap();
915        let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();
916
917        assert_eq!(headers.len(), 2, "{headers:?}");
918        assert_eq!(
919            token,
920            HeaderValue::from_static("test-token-type test-access-token")
921        );
922        assert!(token.is_sensitive());
923        assert_eq!(
924            quota_project_header,
925            HeaderValue::from_static("test-project")
926        );
927        assert!(!quota_project_header.is_sensitive());
928
929        Ok(())
930    }
931
932    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
933    async fn creds_from_json_custom_uri_with_caching() -> TestResult {
934        let mut server = Server::run();
935        let response = Oauth2RefreshResponse {
936            access_token: "test-access-token".to_string(),
937            expires_in: Some(3600),
938            refresh_token: Some("test-refresh-token".to_string()),
939            scope: Some("scope1 scope2".to_string()),
940            token_type: "test-token-type".to_string(),
941        };
942        server.expect(
943            Expectation::matching(all_of![
944                request::path("/token"),
945                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
946                    check_request(req, Some("scope1 scope2".to_string()))
947                }))
948            ])
949            .times(1)
950            .respond_with(json_encoded(response)),
951        );
952
953        let json = serde_json::json!({
954            "client_id": "test-client-id",
955            "client_secret": "test-client-secret",
956            "refresh_token": "test-refresh-token",
957            "type": "authorized_user",
958            "universe_domain": "googleapis.com",
959            "quota_project_id": "test-project",
960            "token_uri": server.url("/token").to_string(),
961        });
962
963        let cred = Builder::new(json)
964            .with_scopes(vec!["scope1", "scope2"])
965            .build()?;
966
967        let token = get_token_from_headers(cred.headers(Extensions::new()).await?);
968        assert_eq!(token.unwrap(), "test-access-token");
969
970        let token = get_token_from_headers(cred.headers(Extensions::new()).await?);
971        assert_eq!(token.unwrap(), "test-access-token");
972
973        server.verify_and_clear();
974
975        Ok(())
976    }
977
978    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
979    async fn credential_provider_partial() -> TestResult {
980        let server = Server::run();
981        let response = Oauth2RefreshResponse {
982            access_token: "test-access-token".to_string(),
983            expires_in: None,
984            refresh_token: None,
985            scope: None,
986            token_type: "test-token-type".to_string(),
987        };
988        server.expect(
989            Expectation::matching(all_of![
990                request::path("/token"),
991                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
992                    check_request(req, None)
993                }))
994            ])
995            .respond_with(json_encoded(response)),
996        );
997
998        let authorized_user = serde_json::json!({
999            "client_id": "test-client-id",
1000            "client_secret": "test-client-secret",
1001            "refresh_token": "test-refresh-token",
1002            "type": "authorized_user",
1003            "token_uri": server.url("/token").to_string()
1004        });
1005
1006        let uc = Builder::new(authorized_user).build()?;
1007        let headers = uc.headers(Extensions::new()).await?;
1008        assert_eq!(
1009            get_token_from_headers(headers.clone()).unwrap(),
1010            "test-access-token"
1011        );
1012        assert_eq!(
1013            get_token_type_from_headers(headers).unwrap(),
1014            "test-token-type"
1015        );
1016
1017        Ok(())
1018    }
1019
1020    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1021    async fn credential_provider_with_token_uri() -> TestResult {
1022        let server = Server::run();
1023        let response = Oauth2RefreshResponse {
1024            access_token: "test-access-token".to_string(),
1025            expires_in: None,
1026            refresh_token: None,
1027            scope: None,
1028            token_type: "test-token-type".to_string(),
1029        };
1030        server.expect(
1031            Expectation::matching(all_of![
1032                request::path("/token"),
1033                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
1034                    check_request(req, None)
1035                }))
1036            ])
1037            .respond_with(json_encoded(response)),
1038        );
1039
1040        let authorized_user = serde_json::json!({
1041            "client_id": "test-client-id",
1042            "client_secret": "test-client-secret",
1043            "refresh_token": "test-refresh-token",
1044            "type": "authorized_user",
1045            "token_uri": "test-endpoint"
1046        });
1047
1048        let uc = Builder::new(authorized_user)
1049            .with_token_uri(server.url("/token").to_string())
1050            .build()?;
1051        let headers = uc.headers(Extensions::new()).await?;
1052        assert_eq!(
1053            get_token_from_headers(headers.clone()).unwrap(),
1054            "test-access-token"
1055        );
1056        assert_eq!(
1057            get_token_type_from_headers(headers).unwrap(),
1058            "test-token-type"
1059        );
1060
1061        Ok(())
1062    }
1063
1064    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1065    async fn credential_provider_with_scopes() -> TestResult {
1066        let server = Server::run();
1067        let response = Oauth2RefreshResponse {
1068            access_token: "test-access-token".to_string(),
1069            expires_in: None,
1070            refresh_token: None,
1071            scope: Some("scope1 scope2".to_string()),
1072            token_type: "test-token-type".to_string(),
1073        };
1074        server.expect(
1075            Expectation::matching(all_of![
1076                request::path("/token"),
1077                request::body(json_decoded(|req: &Oauth2RefreshRequest| {
1078                    check_request(req, Some("scope1 scope2".to_string()))
1079                }))
1080            ])
1081            .respond_with(json_encoded(response)),
1082        );
1083
1084        let authorized_user = serde_json::json!({
1085            "client_id": "test-client-id",
1086            "client_secret": "test-client-secret",
1087            "refresh_token": "test-refresh-token",
1088            "type": "authorized_user",
1089            "token_uri": "test-endpoint"
1090        });
1091
1092        let uc = Builder::new(authorized_user)
1093            .with_token_uri(server.url("/token").to_string())
1094            .with_scopes(vec!["scope1", "scope2"])
1095            .build()?;
1096        let headers = uc.headers(Extensions::new()).await?;
1097        assert_eq!(
1098            get_token_from_headers(headers.clone()).unwrap(),
1099            "test-access-token"
1100        );
1101        assert_eq!(
1102            get_token_type_from_headers(headers).unwrap(),
1103            "test-token-type"
1104        );
1105
1106        Ok(())
1107    }
1108
1109    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1110    async fn credential_provider_retryable_error() -> TestResult {
1111        let server = Server::run();
1112        server
1113            .expect(Expectation::matching(request::path("/token")).respond_with(status_code(503)));
1114
1115        let authorized_user = serde_json::json!({
1116            "client_id": "test-client-id",
1117            "client_secret": "test-client-secret",
1118            "refresh_token": "test-refresh-token",
1119            "type": "authorized_user",
1120            "token_uri": server.url("/token").to_string()
1121        });
1122
1123        let uc = Builder::new(authorized_user).build()?;
1124        let err = uc.headers(Extensions::new()).await.unwrap_err();
1125        assert!(err.is_transient(), "{err}");
1126        let source = err
1127            .source()
1128            .and_then(|e| e.downcast_ref::<reqwest::Error>());
1129        assert!(
1130            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
1131            "{err:?}"
1132        );
1133
1134        Ok(())
1135    }
1136
1137    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1138    async fn token_provider_nonretryable_error() -> TestResult {
1139        let server = Server::run();
1140        server
1141            .expect(Expectation::matching(request::path("/token")).respond_with(status_code(401)));
1142
1143        let authorized_user = serde_json::json!({
1144            "client_id": "test-client-id",
1145            "client_secret": "test-client-secret",
1146            "refresh_token": "test-refresh-token",
1147            "type": "authorized_user",
1148            "token_uri": server.url("/token").to_string()
1149        });
1150
1151        let uc = Builder::new(authorized_user).build()?;
1152        let err = uc.headers(Extensions::new()).await.unwrap_err();
1153        assert!(!err.is_transient(), "{err:?}");
1154        let source = err
1155            .source()
1156            .and_then(|e| e.downcast_ref::<reqwest::Error>());
1157        assert!(
1158            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
1159            "{err:?}"
1160        );
1161
1162        Ok(())
1163    }
1164
1165    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1166    async fn token_provider_malformed_response_is_nonretryable() -> TestResult {
1167        let server = Server::run();
1168        server.expect(
1169            Expectation::matching(request::path("/token"))
1170                .respond_with(json_encoded("bad json".to_string())),
1171        );
1172
1173        let authorized_user = serde_json::json!({
1174            "client_id": "test-client-id",
1175            "client_secret": "test-client-secret",
1176            "refresh_token": "test-refresh-token",
1177            "type": "authorized_user",
1178            "token_uri": server.url("/token").to_string()
1179        });
1180
1181        let uc = Builder::new(authorized_user).build()?;
1182        let e = uc.headers(Extensions::new()).await.err().unwrap();
1183        assert!(!e.is_transient(), "{e}");
1184
1185        Ok(())
1186    }
1187
1188    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1189    async fn builder_malformed_authorized_json_nonretryable() -> TestResult {
1190        let authorized_user = serde_json::json!({
1191            "client_secret": "test-client-secret",
1192            "refresh_token": "test-refresh-token",
1193            "type": "authorized_user",
1194        });
1195
1196        let e = Builder::new(authorized_user).build().unwrap_err();
1197        assert!(e.is_parsing(), "{e}");
1198
1199        Ok(())
1200    }
1201}