google_cloud_auth/credentials/
user_account.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! [User Account] Credentials type.
16//!
17//! User accounts represent a developer, administrator, or any other person who
18//! interacts with Google APIs and services. User accounts are managed as
19//! [Google Accounts], either via [Google Workspace] or [Cloud Identity].
20//!
21//! This module provides [Credentials] derived from user account
22//! information, specifically utilizing an OAuth 2.0 refresh token.
23//!
24//! This module is designed for refresh tokens obtained via the standard
25//! [Authorization Code grant]. Acquiring the initial refresh token (e.g., through
26//! user consent) is outside the scope of this library.
27//! See [RFC 6749 Section 4.1] for flow details.
28//!
29//! The Google Cloud client libraries for Rust will typically find and use these
30//! credentials automatically if a credentials file exists in the
31//! standard ADC search paths. This file is often created by running:
32//! `gcloud auth application-default login`. You might instantiate these credentials
33//! directly using the [`Builder`] if you need to:
34//! * Load credentials from a non-standard location or source.
35//! * Override the OAuth 2.0 **scopes** being requested for the access token.
36//! * Override the **quota project ID** for billing and quota management.
37//! * Override the **token URI** used to fetch access tokens.
38//!
39//! Example usage:
40//!
41//! ```
42//! # use google_cloud_auth::credentials::user_account::Builder;
43//! # use google_cloud_auth::credentials::Credentials;
44//! # use google_cloud_auth::errors::CredentialsError;
45//! # use http::Extensions;
46//! # tokio_test::block_on(async {
47//! let authorized_user = serde_json::json!({
48//!     "client_id": "YOUR_CLIENT_ID.apps.googleusercontent.com", // Replace with your actual Client ID
49//!     "client_secret": "YOUR_CLIENT_SECRET", // Replace with your actual Client Secret - LOAD SECURELY!
50//!     "refresh_token": "YOUR_REFRESH_TOKEN", // Replace with the user's refresh token - LOAD SECURELY!
51//!     "type": "authorized_user",
52//!     // "quota_project_id": "your-billing-project-id", // Optional: Set if needed
53//!     // "token_uri" : "test-token-uri", // Optional: Set if needed
54//! });
55//! let credentials: Credentials = Builder::new(authorized_user).build()?;
56//! let headers = credentials.headers(Extensions::new()).await?;
57//! println!("Headers: {headers:?}");
58//! # Ok::<(), CredentialsError>(())
59//! # });
60//! ```
61//!
62//! [Authorization Code grant]: https://tools.ietf.org/html/rfc6749#section-1.3.1
63//! [Cloud Identity]: https://cloud.google.com/identity
64//! [Google Accounts]: https://myaccount.google.com/
65//! [Google Workspace]: https://workspace.google.com/
66//! [RFC 6749 Section 4.1]: https://datatracker.ietf.org/doc/html/rfc6749#section-4.1
67//! [User Account]: https://cloud.google.com/docs/authentication#user-accounts
68//! [Workforce Identity Federation]: https://cloud.google.com/iam/docs/workforce-identity-federation
69
70use crate::credentials::dynamic::CredentialsProvider;
71use crate::credentials::{Credentials, Result};
72use crate::errors::{self, CredentialsError, is_retryable};
73use crate::headers_util::build_bearer_headers;
74use crate::token::{CachedTokenProvider, Token, TokenProvider};
75use crate::token_cache::TokenCache;
76use http::header::CONTENT_TYPE;
77use http::{Extensions, HeaderMap, HeaderValue};
78use reqwest::{Client, Method};
79use serde_json::Value;
80use std::sync::Arc;
81use tokio::time::{Duration, Instant};
82
83const OAUTH2_ENDPOINT: &str = "https://oauth2.googleapis.com/token";
84
85/// A builder for constructing `user_account` [Credentials] instance.
86///
87/// # Example
88/// ```
89/// # use google_cloud_auth::credentials::user_account::Builder;
90/// # tokio_test::block_on(async {
91/// let authorized_user = serde_json::json!({ /* add details here */ });
92/// let credentials = Builder::new(authorized_user).build();
93/// })
94/// ```
95pub struct Builder {
96    authorized_user: Value,
97    scopes: Option<Vec<String>>,
98    quota_project_id: Option<String>,
99    token_uri: Option<String>,
100}
101
102impl Builder {
103    /// Creates a new builder using `authorized_user` JSON value.
104    ///
105    /// The `authorized_user` JSON is typically generated when a user
106    /// authenticates using the [application-default login] process.
107    ///
108    /// [application-default login]: https://cloud.google.com/sdk/gcloud/reference/auth/application-default/login
109    pub fn new(authorized_user: Value) -> Self {
110        Self {
111            authorized_user,
112            scopes: None,
113            quota_project_id: None,
114            token_uri: None,
115        }
116    }
117
118    /// Sets the URI for the token endpoint used to fetch access tokens.
119    ///
120    /// Any value provided here overrides a `token_uri` value from the input `authorized_user` JSON.
121    /// Defaults to `https://oauth2.googleapis.com/token` if not specified here or in the `authorized_user` JSON.
122    ///
123    /// # Example
124    /// ```
125    /// # use google_cloud_auth::credentials::user_account::Builder;
126    /// let authorized_user = serde_json::json!({ /* add details here */ });
127    /// let credentials = Builder::new(authorized_user)
128    ///     .with_token_uri("https://oauth2-FOOBAR.p.googleapis.com")
129    ///     .build();
130    /// ```
131    pub fn with_token_uri<S: Into<String>>(mut self, token_uri: S) -> Self {
132        self.token_uri = Some(token_uri.into());
133        self
134    }
135
136    /// Sets the [scopes] for these credentials.
137    ///
138    /// `scopes` define the *permissions being requested* for this specific access token
139    /// when interacting with a service. For example, `https://www.googleapis.com/auth/devstorage.read_write`.
140    /// IAM permissions, on the other hand, define the *underlying capabilities*
141    /// the user account possesses within a system. For example, `storage.buckets.delete`.
142    /// When a token generated with specific scopes is used, the request must be permitted
143    /// by both the user account's underlying IAM permissions and the scopes requested
144    /// for the token. Therefore, scopes act as an additional restriction on what the token
145    /// can be used for.
146    ///
147    /// # Example
148    /// ```
149    /// # use google_cloud_auth::credentials::user_account::Builder;
150    /// let authorized_user = serde_json::json!({ /* add details here */ });
151    /// let credentials = Builder::new(authorized_user)
152    ///     .with_scopes(["https://www.googleapis.com/auth/pubsub"])
153    ///     .build();
154    /// ```
155    /// [scopes]: https://developers.google.com/identity/protocols/oauth2/scopes
156    pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
157    where
158        I: IntoIterator<Item = S>,
159        S: Into<String>,
160    {
161        self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
162        self
163    }
164
165    /// Sets the [quota project] for these credentials.
166    ///
167    /// In some services, you can use an account in
168    /// one project for authentication and authorization, and charge
169    /// the usage to a different project. This requires that the
170    /// user has `serviceusage.services.use` permissions on the quota project.
171    ///
172    /// Any value set here overrides a `quota_project_id` value from the
173    /// input `authorized_user` JSON.
174    ///
175    /// # Example
176    /// ```
177    /// # use google_cloud_auth::credentials::user_account::Builder;
178    /// let authorized_user = serde_json::json!("{ /* add details here */ }");
179    /// let credentials = Builder::new(authorized_user)
180    ///     .with_quota_project_id("my-project")
181    ///     .build();
182    /// ```
183    ///
184    /// [quota project]: https://cloud.google.com/docs/quotas/quota-project
185    pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
186        self.quota_project_id = Some(quota_project_id.into());
187        self
188    }
189
190    /// Returns a [Credentials] instance with the configured settings.
191    ///
192    /// # Errors
193    ///
194    /// Returns a [CredentialsError] if the `authorized_user`
195    /// provided to [`Builder::new`] cannot be successfully deserialized into the
196    /// expected format. This typically happens if the JSON value is malformed or
197    /// missing required fields. For more information, on how to generate
198    /// `authorized_user` json, consult the relevant section in the
199    /// [application-default credentials] guide.
200    ///
201    /// [application-default credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
202    pub fn build(self) -> Result<Credentials> {
203        let authorized_user = serde_json::from_value::<AuthorizedUser>(self.authorized_user)
204            .map_err(errors::non_retryable)?;
205        let endpoint = self
206            .token_uri
207            .or(authorized_user.token_uri)
208            .unwrap_or(OAUTH2_ENDPOINT.to_string());
209        let quota_project_id = self.quota_project_id.or(authorized_user.quota_project_id);
210
211        let token_provider = UserTokenProvider {
212            client_id: authorized_user.client_id,
213            client_secret: authorized_user.client_secret,
214            refresh_token: authorized_user.refresh_token,
215            endpoint,
216            scopes: self.scopes.map(|scopes| scopes.join(" ")),
217        };
218        let token_provider = TokenCache::new(token_provider);
219
220        Ok(Credentials {
221            inner: Arc::new(UserCredentials {
222                token_provider,
223                quota_project_id,
224            }),
225        })
226    }
227}
228
/// Fetches access tokens by exchanging a user's OAuth 2.0 refresh token at
/// the configured token endpoint.
#[derive(PartialEq)]
struct UserTokenProvider {
    /// OAuth 2.0 client ID sent with each refresh request.
    client_id: String,
    /// OAuth 2.0 client secret; redacted from `Debug` output.
    client_secret: String,
    /// Refresh token exchanged for access tokens; redacted from `Debug` output.
    refresh_token: String,
    /// Token endpoint URI that refresh requests are POSTed to.
    endpoint: String,
    /// Space-separated scopes to request, if any.
    scopes: Option<String>,
}

impl std::fmt::Debug for UserTokenProvider {
    /// Formats the provider under its own type name, replacing
    /// `client_secret` and `refresh_token` with a placeholder so their values
    /// do not appear in debug output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("UserTokenProvider")
            .field("client_id", &self.client_id)
            .field("client_secret", &"[censored]")
            .field("refresh_token", &"[censored]")
            .field("endpoint", &self.endpoint)
            .field("scopes", &self.scopes)
            .finish()
    }
}
249
250#[async_trait::async_trait]
251impl TokenProvider for UserTokenProvider {
252    async fn token(&self) -> Result<Token> {
253        let client = Client::new();
254
255        // Make the request
256        let req = Oauth2RefreshRequest {
257            grant_type: RefreshGrantType::RefreshToken,
258            client_id: self.client_id.clone(),
259            client_secret: self.client_secret.clone(),
260            refresh_token: self.refresh_token.clone(),
261            scopes: self.scopes.clone(),
262        };
263        let header = HeaderValue::from_static("application/json");
264        let builder = client
265            .request(Method::POST, self.endpoint.as_str())
266            .header(CONTENT_TYPE, header)
267            .json(&req);
268        let resp = builder.send().await.map_err(errors::retryable)?;
269
270        // Process the response
271        if !resp.status().is_success() {
272            let status = resp.status();
273            let body = resp
274                .text()
275                .await
276                .map_err(|e| CredentialsError::new(is_retryable(status), e))?;
277            return Err(CredentialsError::from_str(
278                is_retryable(status),
279                format!("Failed to fetch token. {body}"),
280            ));
281        }
282        let response = resp.json::<Oauth2RefreshResponse>().await.map_err(|e| {
283            let retryable = !e.is_decode();
284            CredentialsError::new(retryable, e)
285        })?;
286        let token = Token {
287            token: response.access_token,
288            token_type: response.token_type,
289            expires_at: response
290                .expires_in
291                .map(|d| Instant::now() + Duration::from_secs(d)),
292            metadata: None,
293        };
294        Ok(token)
295    }
296}
297
298/// Data model for a UserCredentials
299///
300/// See: https://cloud.google.com/docs/authentication#user-accounts
301#[derive(Debug)]
302pub(crate) struct UserCredentials<T>
303where
304    T: CachedTokenProvider,
305{
306    token_provider: T,
307    quota_project_id: Option<String>,
308}
309
310#[async_trait::async_trait]
311impl<T> CredentialsProvider for UserCredentials<T>
312where
313    T: CachedTokenProvider,
314{
315    async fn headers(&self, extensions: Extensions) -> Result<HeaderMap> {
316        let token = self.token_provider.token(extensions).await?;
317        build_bearer_headers(&token, &self.quota_project_id)
318    }
319}
320
321#[derive(Debug, PartialEq, serde::Deserialize)]
322pub(crate) struct AuthorizedUser {
323    #[serde(rename = "type")]
324    cred_type: String,
325    client_id: String,
326    client_secret: String,
327    refresh_token: String,
328    #[serde(skip_serializing_if = "Option::is_none")]
329    token_uri: Option<String>,
330    #[serde(skip_serializing_if = "Option::is_none")]
331    quota_project_id: Option<String>,
332}
333
334#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
335enum RefreshGrantType {
336    #[serde(rename = "refresh_token")]
337    RefreshToken,
338}
339
340#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
341struct Oauth2RefreshRequest {
342    grant_type: RefreshGrantType,
343    client_id: String,
344    client_secret: String,
345    refresh_token: String,
346    scopes: Option<String>,
347}
348
349#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
350struct Oauth2RefreshResponse {
351    access_token: String,
352    #[serde(skip_serializing_if = "Option::is_none")]
353    scope: Option<String>,
354    #[serde(skip_serializing_if = "Option::is_none")]
355    expires_in: Option<u64>,
356    token_type: String,
357    #[serde(skip_serializing_if = "Option::is_none")]
358    refresh_token: Option<String>,
359}
360
361#[cfg(test)]
362mod test {
363    use super::*;
364    use crate::credentials::test::{get_token_from_headers, get_token_type_from_headers};
365    use crate::credentials::{DEFAULT_UNIVERSE_DOMAIN, QUOTA_PROJECT_KEY};
366    use crate::token::test::MockTokenProvider;
367    use axum::extract::Json;
368    use http::StatusCode;
369    use http::header::AUTHORIZATION;
370    use std::error::Error;
371    use std::sync::Mutex;
372    use tokio::task::JoinHandle;
373
374    type TestResult = std::result::Result<(), Box<dyn std::error::Error>>;
375
376    #[test]
377    fn debug_token_provider() {
378        let expected = UserTokenProvider {
379            client_id: "test-client-id".to_string(),
380            client_secret: "test-client-secret".to_string(),
381            refresh_token: "test-refresh-token".to_string(),
382            endpoint: OAUTH2_ENDPOINT.to_string(),
383            scopes: Some("https://www.googleapis.com/auth/pubsub".to_string()),
384        };
385        let fmt = format!("{expected:?}");
386        assert!(fmt.contains("test-client-id"), "{fmt}");
387        assert!(!fmt.contains("test-client-secret"), "{fmt}");
388        assert!(!fmt.contains("test-refresh-token"), "{fmt}");
389        assert!(fmt.contains(OAUTH2_ENDPOINT), "{fmt}");
390        assert!(
391            fmt.contains("https://www.googleapis.com/auth/pubsub"),
392            "{fmt}"
393        );
394    }
395
396    #[test]
397    fn authorized_user_full_from_json_success() {
398        let json = serde_json::json!({
399            "account": "",
400            "client_id": "test-client-id",
401            "client_secret": "test-client-secret",
402            "refresh_token": "test-refresh-token",
403            "type": "authorized_user",
404            "universe_domain": "googleapis.com",
405            "quota_project_id": "test-project",
406            "token_uri" : "test-token-uri",
407        });
408
409        let expected = AuthorizedUser {
410            cred_type: "authorized_user".to_string(),
411            client_id: "test-client-id".to_string(),
412            client_secret: "test-client-secret".to_string(),
413            refresh_token: "test-refresh-token".to_string(),
414            quota_project_id: Some("test-project".to_string()),
415            token_uri: Some("test-token-uri".to_string()),
416        };
417        let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
418        assert_eq!(actual, expected);
419    }
420
421    #[test]
422    fn authorized_user_partial_from_json_success() {
423        let json = serde_json::json!({
424            "client_id": "test-client-id",
425            "client_secret": "test-client-secret",
426            "refresh_token": "test-refresh-token",
427            "type": "authorized_user",
428        });
429
430        let expected = AuthorizedUser {
431            cred_type: "authorized_user".to_string(),
432            client_id: "test-client-id".to_string(),
433            client_secret: "test-client-secret".to_string(),
434            refresh_token: "test-refresh-token".to_string(),
435            quota_project_id: None,
436            token_uri: None,
437        };
438        let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
439        assert_eq!(actual, expected);
440    }
441
442    #[test]
443    fn authorized_user_from_json_parse_fail() {
444        let json_full = serde_json::json!({
445            "client_id": "test-client-id",
446            "client_secret": "test-client-secret",
447            "refresh_token": "test-refresh-token",
448            "type": "authorized_user",
449            "quota_project_id": "test-project"
450        });
451
452        for required_field in ["client_id", "client_secret", "refresh_token"] {
453            let mut json = json_full.clone();
454            // Remove a required field from the JSON
455            json[required_field].take();
456            serde_json::from_value::<AuthorizedUser>(json)
457                .err()
458                .unwrap();
459        }
460    }
461
462    #[tokio::test]
463    async fn default_universe_domain_success() {
464        let mock = TokenCache::new(MockTokenProvider::new());
465
466        let uc = UserCredentials {
467            token_provider: mock,
468            quota_project_id: None,
469        };
470        assert_eq!(uc.universe_domain().await.unwrap(), DEFAULT_UNIVERSE_DOMAIN);
471    }
472
473    #[tokio::test]
474    async fn headers_success() {
475        let token = Token {
476            token: "test-token".to_string(),
477            token_type: "Bearer".to_string(),
478            expires_at: None,
479            metadata: None,
480        };
481
482        let mut mock = MockTokenProvider::new();
483        mock.expect_token().times(1).return_once(|| Ok(token));
484
485        let uc = UserCredentials {
486            token_provider: TokenCache::new(mock),
487            quota_project_id: None,
488        };
489
490        let headers = uc.headers(Extensions::new()).await.unwrap();
491        let token = headers.get(AUTHORIZATION).unwrap();
492
493        assert_eq!(headers.len(), 1, "{headers:?}");
494        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
495        assert!(token.is_sensitive());
496    }
497
498    #[tokio::test]
499    async fn headers_failure() {
500        let mut mock = MockTokenProvider::new();
501        mock.expect_token()
502            .times(1)
503            .return_once(|| Err(errors::non_retryable_from_str("fail")));
504
505        let uc = UserCredentials {
506            token_provider: TokenCache::new(mock),
507            quota_project_id: None,
508        };
509        assert!(uc.headers(Extensions::new()).await.is_err());
510    }
511
512    #[tokio::test]
513    async fn headers_with_quota_project_success() {
514        let token = Token {
515            token: "test-token".to_string(),
516            token_type: "Bearer".to_string(),
517            expires_at: None,
518            metadata: None,
519        };
520
521        let mut mock = MockTokenProvider::new();
522        mock.expect_token().times(1).return_once(|| Ok(token));
523
524        let uc = UserCredentials {
525            token_provider: TokenCache::new(mock),
526            quota_project_id: Some("test-project".to_string()),
527        };
528
529        let headers = uc.headers(Extensions::new()).await.unwrap();
530        let token = headers.get(AUTHORIZATION).unwrap();
531        let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();
532
533        assert_eq!(headers.len(), 2, "{headers:?}");
534        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
535        assert!(token.is_sensitive());
536        assert_eq!(
537            quota_project_header,
538            HeaderValue::from_static("test-project")
539        );
540        assert!(!quota_project_header.is_sensitive());
541    }
542
543    #[test]
544    fn oauth2_request_serde() {
545        let request = Oauth2RefreshRequest {
546            grant_type: RefreshGrantType::RefreshToken,
547            client_id: "test-client-id".to_string(),
548            client_secret: "test-client-secret".to_string(),
549            refresh_token: "test-refresh-token".to_string(),
550            scopes: Some("scope1 scope2".to_string()),
551        };
552
553        let json = serde_json::to_value(&request).unwrap();
554        let expected = serde_json::json!({
555            "grant_type": "refresh_token",
556            "client_id": "test-client-id",
557            "client_secret": "test-client-secret",
558            "refresh_token": "test-refresh-token",
559            "scopes": "scope1 scope2",
560        });
561        assert_eq!(json, expected);
562        let roundtrip = serde_json::from_value::<Oauth2RefreshRequest>(json).unwrap();
563        assert_eq!(request, roundtrip);
564    }
565
566    #[test]
567    fn oauth2_response_serde_full() {
568        let response = Oauth2RefreshResponse {
569            access_token: "test-access-token".to_string(),
570            scope: Some("scope1 scope2".to_string()),
571            expires_in: Some(3600),
572            token_type: "test-token-type".to_string(),
573            refresh_token: Some("test-refresh-token".to_string()),
574        };
575
576        let json = serde_json::to_value(&response).unwrap();
577        let expected = serde_json::json!({
578            "access_token": "test-access-token",
579            "scope": "scope1 scope2",
580            "expires_in": 3600,
581            "token_type": "test-token-type",
582            "refresh_token": "test-refresh-token"
583        });
584        assert_eq!(json, expected);
585        let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
586        assert_eq!(response, roundtrip);
587    }
588
589    #[test]
590    fn oauth2_response_serde_partial() {
591        let response = Oauth2RefreshResponse {
592            access_token: "test-access-token".to_string(),
593            scope: None,
594            expires_in: None,
595            token_type: "test-token-type".to_string(),
596            refresh_token: None,
597        };
598
599        let json = serde_json::to_value(&response).unwrap();
600        let expected = serde_json::json!({
601            "access_token": "test-access-token",
602            "token_type": "test-token-type",
603        });
604        assert_eq!(json, expected);
605        let roundtrip = serde_json::from_value::<Oauth2RefreshResponse>(json).unwrap();
606        assert_eq!(response, roundtrip);
607    }
608
609    // Starts a server running locally. Returns an (endpoint, handler) pair.
610    async fn start(
611        response_code: StatusCode,
612        response_body: Value,
613        call_count: Arc<Mutex<i32>>,
614    ) -> (String, JoinHandle<()>) {
615        let code = response_code;
616        let body = response_body.clone();
617        let handler = move |req| async move { handle_token_factory(code, body, call_count)(req) };
618        let app = axum::Router::new().route("/token", axum::routing::post(handler));
619        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
620        let addr = listener.local_addr().unwrap();
621        let server = tokio::spawn(async {
622            axum::serve(listener, app).await.unwrap();
623        });
624
625        (
626            format!("http://{}:{}/token", addr.ip(), addr.port()),
627            server,
628        )
629    }
630
631    // Creates a handler that
632    // - verifies fields in an Oauth2RefreshRequest
633    // - returns a pre-canned HTTP response
634    fn handle_token_factory(
635        response_code: StatusCode,
636        response_body: Value,
637        call_count: Arc<std::sync::Mutex<i32>>,
638    ) -> impl Fn(Json<Oauth2RefreshRequest>) -> (StatusCode, String) {
639        move |request: Json<Oauth2RefreshRequest>| -> (StatusCode, String) {
640            let mut count = call_count.lock().unwrap();
641            *count += 1;
642            assert_eq!(request.client_id, "test-client-id");
643            assert_eq!(request.client_secret, "test-client-secret");
644            assert_eq!(request.refresh_token, "test-refresh-token");
645            assert_eq!(request.grant_type, RefreshGrantType::RefreshToken);
646            assert_eq!(
647                request.scopes,
648                response_body["scope"].as_str().map(|s| s.to_owned())
649            );
650
651            (response_code, response_body.to_string())
652        }
653    }
654
655    #[tokio::test(start_paused = true)]
656    async fn token_provider_full() -> TestResult {
657        let response = Oauth2RefreshResponse {
658            access_token: "test-access-token".to_string(),
659            expires_in: Some(3600),
660            refresh_token: Some("test-refresh-token".to_string()),
661            scope: Some("scope1 scope2".to_string()),
662            token_type: "test-token-type".to_string(),
663        };
664        let response_body = serde_json::to_value(&response).unwrap();
665        let (endpoint, _server) =
666            start(StatusCode::OK, response_body, Arc::new(Mutex::new(0))).await;
667        println!("endpoint = {endpoint}");
668
669        let tp = UserTokenProvider {
670            client_id: "test-client-id".to_string(),
671            client_secret: "test-client-secret".to_string(),
672            refresh_token: "test-refresh-token".to_string(),
673            endpoint,
674            scopes: Some("scope1 scope2".to_string()),
675        };
676        let now = Instant::now();
677        let token = tp.token().await?;
678        assert_eq!(token.token, "test-access-token");
679        assert_eq!(token.token_type, "test-token-type");
680        assert!(
681            token
682                .expires_at
683                .is_some_and(|d| d == now + Duration::from_secs(3600)),
684            "now: {:?}, expires_at: {:?}",
685            now,
686            token.expires_at
687        );
688
689        Ok(())
690    }
691
692    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
693    async fn credential_full_with_quota_project() -> TestResult {
694        let response = Oauth2RefreshResponse {
695            access_token: "test-access-token".to_string(),
696            expires_in: Some(3600),
697            refresh_token: Some("test-refresh-token".to_string()),
698            scope: None,
699            token_type: "test-token-type".to_string(),
700        };
701        let response_body = serde_json::to_value(&response).unwrap();
702        let (endpoint, _server) =
703            start(StatusCode::OK, response_body, Arc::new(Mutex::new(0))).await;
704        println!("endpoint = {endpoint}");
705
706        let authorized_user = serde_json::json!({
707            "client_id": "test-client-id",
708            "client_secret": "test-client-secret",
709            "refresh_token": "test-refresh-token",
710            "type": "authorized_user",
711            "token_uri": endpoint,
712        });
713        let cred = Builder::new(authorized_user)
714            .with_quota_project_id("test-project")
715            .build()?;
716
717        let headers = cred.headers(Extensions::new()).await.unwrap();
718        let token = headers.get(AUTHORIZATION).unwrap();
719        let quota_project_header = headers.get(QUOTA_PROJECT_KEY).unwrap();
720
721        assert_eq!(headers.len(), 2, "{headers:?}");
722        assert_eq!(
723            token,
724            HeaderValue::from_static("test-token-type test-access-token")
725        );
726        assert!(token.is_sensitive());
727        assert_eq!(
728            quota_project_header,
729            HeaderValue::from_static("test-project")
730        );
731        assert!(!quota_project_header.is_sensitive());
732
733        Ok(())
734    }
735
736    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
737    async fn creds_from_json_custom_uri_with_caching() -> TestResult {
738        let response = Oauth2RefreshResponse {
739            access_token: "test-access-token".to_string(),
740            expires_in: Some(3600),
741            refresh_token: Some("test-refresh-token".to_string()),
742            scope: Some("scope1 scope2".to_string()),
743            token_type: "test-token-type".to_string(),
744        };
745        let response_body = serde_json::to_value(&response).unwrap();
746        let call_count = Arc::new(Mutex::new(0));
747        let (endpoint, _server) = start(StatusCode::OK, response_body, call_count.clone()).await;
748        println!("endpoint = {endpoint}");
749
750        let json = serde_json::json!({
751            "client_id": "test-client-id",
752            "client_secret": "test-client-secret",
753            "refresh_token": "test-refresh-token",
754            "type": "authorized_user",
755            "universe_domain": "googleapis.com",
756            "quota_project_id": "test-project",
757            "token_uri": endpoint,
758        });
759
760        let cred = Builder::new(json)
761            .with_scopes(vec!["scope1", "scope2"])
762            .build()?;
763
764        let token = get_token_from_headers(&cred.headers(Extensions::new()).await?);
765        assert_eq!(token.unwrap(), "test-access-token");
766
767        let token = get_token_from_headers(&cred.headers(Extensions::new()).await?);
768        assert_eq!(token.unwrap(), "test-access-token");
769
770        // Test that the inner token provider was called only
771        // once even though token was called twice.
772        assert_eq!(*call_count.lock().unwrap(), 1);
773
774        Ok(())
775    }
776
777    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
778    async fn credential_provider_partial() -> TestResult {
779        let response = Oauth2RefreshResponse {
780            access_token: "test-access-token".to_string(),
781            expires_in: None,
782            refresh_token: None,
783            scope: None,
784            token_type: "test-token-type".to_string(),
785        };
786        let response_body = serde_json::to_value(&response).unwrap();
787        let (endpoint, _server) =
788            start(StatusCode::OK, response_body, Arc::new(Mutex::new(0))).await;
789        println!("endpoint = {endpoint}");
790
791        let authorized_user = serde_json::json!({
792            "client_id": "test-client-id",
793            "client_secret": "test-client-secret",
794            "refresh_token": "test-refresh-token",
795            "type": "authorized_user",
796            "token_uri": endpoint});
797
798        let uc = Builder::new(authorized_user).build()?;
799        let headers = uc.headers(Extensions::new()).await?;
800        assert_eq!(
801            get_token_from_headers(&headers).unwrap(),
802            "test-access-token"
803        );
804        assert_eq!(
805            get_token_type_from_headers(&headers).unwrap(),
806            "test-token-type"
807        );
808
809        Ok(())
810    }
811
812    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
813    async fn credential_provider_with_token_uri() -> TestResult {
814        let response = Oauth2RefreshResponse {
815            access_token: "test-access-token".to_string(),
816            expires_in: None,
817            refresh_token: None,
818            scope: None,
819            token_type: "test-token-type".to_string(),
820        };
821        let response_body = serde_json::to_value(&response).unwrap();
822        let (endpoint, _server) =
823            start(StatusCode::OK, response_body, Arc::new(Mutex::new(0))).await;
824        println!("endpoint = {endpoint}");
825
826        let authorized_user = serde_json::json!({
827            "client_id": "test-client-id",
828            "client_secret": "test-client-secret",
829            "refresh_token": "test-refresh-token",
830            "type": "authorized_user",
831            "token_uri": "test-endpoint"});
832
833        let uc = Builder::new(authorized_user)
834            .with_token_uri(endpoint)
835            .build()?;
836        let headers = uc.headers(Extensions::new()).await?;
837        assert_eq!(
838            get_token_from_headers(&headers).unwrap(),
839            "test-access-token"
840        );
841        assert_eq!(
842            get_token_type_from_headers(&headers).unwrap(),
843            "test-token-type"
844        );
845
846        Ok(())
847    }
848
849    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
850    async fn credential_provider_with_scopes() -> TestResult {
851        let response = Oauth2RefreshResponse {
852            access_token: "test-access-token".to_string(),
853            expires_in: None,
854            refresh_token: None,
855            scope: Some("scope1 scope2".to_string()),
856            token_type: "test-token-type".to_string(),
857        };
858        let response_body = serde_json::to_value(&response).unwrap();
859        let (endpoint, _server) =
860            start(StatusCode::OK, response_body, Arc::new(Mutex::new(0))).await;
861        println!("endpoint = {endpoint}");
862
863        let authorized_user = serde_json::json!({
864            "client_id": "test-client-id",
865            "client_secret": "test-client-secret",
866            "refresh_token": "test-refresh-token",
867            "type": "authorized_user",
868            "token_uri": "test-endpoint"});
869
870        let uc = Builder::new(authorized_user)
871            .with_token_uri(endpoint)
872            .with_scopes(vec!["scope1", "scope2"])
873            .build()?;
874        let headers = uc.headers(Extensions::new()).await?;
875        assert_eq!(
876            get_token_from_headers(&headers).unwrap(),
877            "test-access-token"
878        );
879        assert_eq!(
880            get_token_type_from_headers(&headers).unwrap(),
881            "test-token-type"
882        );
883
884        Ok(())
885    }
886
887    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
888    async fn credential_provider_retryable_error() -> TestResult {
889        let (endpoint, _server) = start(
890            StatusCode::SERVICE_UNAVAILABLE,
891            serde_json::to_value("try again".to_string())?,
892            Arc::new(Mutex::new(0)),
893        )
894        .await;
895        println!("endpoint = {endpoint}");
896
897        let authorized_user = serde_json::json!({
898            "client_id": "test-client-id",
899            "client_secret": "test-client-secret",
900            "refresh_token": "test-refresh-token",
901            "type": "authorized_user",
902            "token_uri": endpoint});
903
904        let uc = Builder::new(authorized_user).build()?;
905        let e = uc.headers(Extensions::new()).await.err().unwrap();
906        assert!(e.is_retryable(), "{e}");
907        assert!(e.source().unwrap().to_string().contains("try again"), "{e}");
908
909        Ok(())
910    }
911
912    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
913    async fn token_provider_nonretryable_error() -> TestResult {
914        let (endpoint, _server) = start(
915            StatusCode::UNAUTHORIZED,
916            serde_json::to_value("epic fail".to_string())?,
917            Arc::new(Mutex::new(0)),
918        )
919        .await;
920        println!("endpoint = {endpoint}");
921
922        let authorized_user = serde_json::json!({
923            "client_id": "test-client-id",
924            "client_secret": "test-client-secret",
925            "refresh_token": "test-refresh-token",
926            "type": "authorized_user",
927            "token_uri": endpoint});
928
929        let uc = Builder::new(authorized_user).build()?;
930        let e = uc.headers(Extensions::new()).await.err().unwrap();
931        assert!(!e.is_retryable(), "{e}");
932        assert!(e.source().unwrap().to_string().contains("epic fail"), "{e}");
933
934        Ok(())
935    }
936
937    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
938    async fn token_provider_malformed_response_is_nonretryable() -> TestResult {
939        let (endpoint, _server) = start(
940            StatusCode::OK,
941            serde_json::to_value("bad json".to_string())?,
942            Arc::new(Mutex::new(0)),
943        )
944        .await;
945        println!("endpoint = {endpoint}");
946
947        let authorized_user = serde_json::json!({
948            "client_id": "test-client-id",
949            "client_secret": "test-client-secret",
950            "refresh_token": "test-refresh-token",
951            "type": "authorized_user",
952            "token_uri": endpoint});
953
954        let uc = Builder::new(authorized_user).build()?;
955        let e = uc.headers(Extensions::new()).await.err().unwrap();
956        assert!(!e.is_retryable(), "{e}");
957
958        Ok(())
959    }
960
961    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
962    async fn builder_malformed_authorized_json_nonretryable() -> TestResult {
963        let authorized_user = serde_json::json!({
964        "client_secret": "test-client-secret",
965        "refresh_token": "test-refresh-token",
966        "type": "authorized_user",
967        });
968
969        let e = Builder::new(authorized_user).build().unwrap_err();
970        assert!(!e.is_retryable(), "{e}");
971
972        Ok(())
973    }
974}