Skip to main content

qcs_api_client_common/configuration/
tokens.rs

1//! Models and utilities for managing `OAuth2` sessions.
2use std::{pin::Pin, sync::Arc};
3
4use futures::Future;
5use jsonwebtoken::{Algorithm, DecodingKey, Validation};
6use oauth2::TokenResponse;
7use serde::{Deserialize, Serialize};
8use time::OffsetDateTime;
9use tokio::sync::{Mutex, Notify, RwLock};
10use tokio_util::sync::CancellationToken;
11
12#[cfg(feature = "stubs")]
13use pyo3_stub_gen::derive::gen_stub_pyclass;
14
15use super::{
16    ClientConfiguration, ConfigSource, TokenError, oidc, secrets::Secrets, settings::AuthServer,
17};
18use crate::configuration::{
19    error::DiscoveryError,
20    pkce::{PkceLoginError, PkceLoginRequest, pkce_login},
21    secrets::{Credential, SecretAccessToken, SecretRefreshToken, TokenPayload},
22};
23#[cfg(feature = "tracing-config")]
24use crate::tracing_configuration::TracingConfiguration;
25#[cfg(feature = "tracing")]
26use urlpattern::UrlPatternMatchInput;
27
28pub use super::secret_string::ClientSecret;
29
30/// A single type containing an access token and an associated refresh token.
31#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
32#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
33#[cfg_attr(
34    feature = "python",
35    pyo3::pyclass(eq, get_all, set_all, module = "qcs_api_client_common.configuration")
36)]
37pub struct RefreshToken {
38    /// The token used to refresh the access token.
39    pub refresh_token: SecretRefreshToken,
40}
41
42impl RefreshToken {
43    /// Create a new [`RefreshToken`] with the given refresh token.
44    #[must_use]
45    pub const fn new(refresh_token: SecretRefreshToken) -> Self {
46        Self { refresh_token }
47    }
48
49    /// Request and return a new access token from the given authorization server using this refresh token.
50    ///
51    /// # Errors
52    ///
53    /// See [`TokenError`]
54    pub async fn request_access_token(
55        &mut self,
56        auth_server: &AuthServer,
57    ) -> Result<SecretAccessToken, TokenError> {
58        if self.refresh_token.is_empty() {
59            return Err(TokenError::NoRefreshToken);
60        }
61
62        let client = default_http_client()?;
63        let token_url = oidc::fetch_discovery(&client, &auth_server.issuer)
64            .await?
65            .token_endpoint;
66        let data = TokenRefreshRequest::new(&auth_server.client_id, self.refresh_token.secret());
67        let resp = client.post(token_url).form(&data).send().await?;
68
69        let RefreshTokenResponse {
70            access_token,
71            refresh_token,
72        } = resp.error_for_status()?.json().await?;
73
74        if let Some(refresh_token) = refresh_token {
75            self.refresh_token = refresh_token;
76        }
77        Ok(access_token)
78    }
79}
80
81#[derive(Deserialize, Debug, Serialize)]
82pub(super) struct ClientCredentialsResponse {
83    pub(super) access_token: SecretAccessToken,
84}
85
86/// A pair of Client ID and Client Secret, used to request an OAuth Client Credentials Grant
87#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
88#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
89#[cfg_attr(
90    feature = "python",
91    pyo3::pyclass(eq, get_all, frozen, module = "qcs_api_client_common.configuration")
92)]
93pub struct ClientCredentials {
94    /// The client ID
95    pub client_id: String,
96    /// The client secret.
97    pub client_secret: ClientSecret,
98}
99
100impl ClientCredentials {
101    #[must_use]
102    /// Construct a new [`ClientCredentials`]
103    pub fn new(client_id: impl Into<String>, client_secret: impl Into<ClientSecret>) -> Self {
104        Self {
105            client_id: client_id.into(),
106            client_secret: client_secret.into(),
107        }
108    }
109
110    /// Get the client ID.
111    #[must_use]
112    pub fn client_id(&self) -> &str {
113        &self.client_id
114    }
115
116    /// Get the client secret.
117    #[must_use]
118    pub const fn client_secret(&self) -> &ClientSecret {
119        &self.client_secret
120    }
121
122    /// Request and return an access token from the given auth server using this set of client credentials.
123    ///
124    /// # Errors
125    ///
126    /// See [`TokenError`]
127    pub async fn request_access_token(
128        &self,
129        auth_server: &AuthServer,
130    ) -> Result<SecretAccessToken, TokenError> {
131        let request = ClientCredentialsRequest::new(None);
132        let client = default_http_client()?;
133
134        let url = oidc::fetch_discovery(&client, &auth_server.issuer)
135            .await?
136            .token_endpoint;
137        let ready_to_send = client
138            .post(url)
139            .basic_auth(&auth_server.client_id, Some(&self.client_secret.secret()))
140            .form(&request);
141        let response = ready_to_send.send().await?;
142
143        response.error_for_status_ref()?;
144
145        let ClientCredentialsResponse { access_token } = response.json().await?;
146        Ok(access_token)
147    }
148}
149
150#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
151#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
152#[cfg_attr(
153    feature = "python",
154    pyo3::pyclass(eq, get_all, frozen, module = "qcs_api_client_common.configuration")
155)]
156/// The Access (Bearer) and refresh (if available) tokens from a PKCE login.
157pub struct PkceFlow {
158    /// The access token.
159    pub access_token: SecretAccessToken,
160    /// The refresh token, if available.
161    pub refresh_token: Option<RefreshToken>,
162}
163
164/// Errors that can occur when attempting to perform a PKCE login flow.
165#[derive(Debug, thiserror::Error)]
166pub enum PkceFlowError {
167    /// Error that occurred while performing the PKCE login flow.
168    #[error(transparent)]
169    PkceLogin(#[from] PkceLoginError),
170    /// Error that occurred while fetching the discovery document from the `OAuth2` issuer.
171    #[error(transparent)]
172    Discovery(#[from] DiscoveryError),
173    /// Error that occurred while making http requests.
174    #[error(transparent)]
175    Request(#[from] qcs_dependencies_client::reqwest::Error),
176}
177
178impl PkceFlow {
179    /// Starts a new PKCE login flow to acquire a new set of tokens.
180    ///
181    /// # Errors
182    ///
183    /// See [`PkceFlowError`]
184    pub async fn new_login_flow(
185        cancel_token: CancellationToken,
186        auth_server: &AuthServer,
187    ) -> Result<Self, PkceFlowError> {
188        let issuer = auth_server.issuer.clone();
189
190        let client = default_http_client()?;
191        let discovery = oidc::fetch_discovery(&client, &issuer).await?;
192
193        let response = pkce_login(
194            cancel_token,
195            PkceLoginRequest {
196                client_id: auth_server.client_id.clone(),
197                redirect_port: None,
198                discovery,
199                scopes: auth_server.scopes.clone(),
200            },
201        )
202        .await?;
203
204        Ok(Self {
205            access_token: SecretAccessToken::from(response.access_token().secret().clone()),
206            refresh_token: response
207                .refresh_token()
208                .map(|rt| RefreshToken::new(SecretRefreshToken::from(rt.secret().clone()))),
209        })
210    }
211
212    /// Returns the access token if it is valid, otherwise requests a new access token using the refresh token if available.
213    ///
214    /// # Errors
215    ///
216    /// See [`TokenError`]
217    pub async fn request_access_token(
218        &mut self,
219        auth_server: &AuthServer,
220    ) -> Result<SecretAccessToken, TokenError> {
221        if insecure_validate_token_exp(&self.access_token).is_ok() {
222            return Ok(self.access_token.clone());
223        }
224
225        if let Some(refresh_token) = &mut self.refresh_token {
226            let access_token = refresh_token.request_access_token(auth_server).await?;
227            self.access_token.clone_from(&access_token);
228            return Ok(access_token);
229        }
230
231        Err(TokenError::NoRefreshToken)
232    }
233}
234
235impl From<PkceFlow> for Credential {
236    fn from(value: PkceFlow) -> Self {
237        let mut token_payload = TokenPayload::default();
238        token_payload.access_token = Some(value.access_token);
239        token_payload.refresh_token = value.refresh_token.map(|rt| rt.refresh_token);
240
241        Self {
242            token_payload: Some(token_payload),
243        }
244    }
245}
246
247#[derive(Clone)]
248#[cfg_attr(feature = "python", derive(pyo3::FromPyObject, pyo3::IntoPyObject))]
249/// Specifies the [OAuth2 grant type](https://oauth.net/2/grant-types/) to use, along with the data
250/// needed to request said grant type.
251pub enum OAuthGrant {
252    /// Credentials that can be used to use with the [Refresh Token grant type](https://oauth.net/2/grant-types/refresh-token/).
253    RefreshToken(RefreshToken),
254    /// Payload that can be used to use the [Client Credentials grant type](https://oauth.net/2/grant-types/client-credentials/).
255    ClientCredentials(ClientCredentials),
256    /// Defers to a user provided function for access token requests.
257    ExternallyManaged(ExternallyManaged),
258    /// The tokens returned by the PKCE login that are an [Authorization Code grant type](https://oauth.net/2/pkce/).
259    PkceFlow(PkceFlow),
260}
261
262impl From<ExternallyManaged> for OAuthGrant {
263    fn from(v: ExternallyManaged) -> Self {
264        Self::ExternallyManaged(v)
265    }
266}
267
268impl From<ClientCredentials> for OAuthGrant {
269    fn from(v: ClientCredentials) -> Self {
270        Self::ClientCredentials(v)
271    }
272}
273
274impl From<RefreshToken> for OAuthGrant {
275    fn from(v: RefreshToken) -> Self {
276        Self::RefreshToken(v)
277    }
278}
279
280impl From<PkceFlow> for OAuthGrant {
281    fn from(v: PkceFlow) -> Self {
282        Self::PkceFlow(v)
283    }
284}
285
286impl OAuthGrant {
287    /// Request a new access token from the given issuer using this grant type and payload.
288    async fn request_access_token(
289        &mut self,
290        auth_server: &AuthServer,
291    ) -> Result<SecretAccessToken, TokenError> {
292        match self {
293            Self::RefreshToken(tokens) => tokens.request_access_token(auth_server).await,
294            Self::ClientCredentials(tokens) => tokens.request_access_token(auth_server).await,
295            Self::ExternallyManaged(tokens) => tokens
296                .request_access_token(auth_server)
297                .await
298                .map_err(|e| TokenError::ExternallyManaged(e.to_string())),
299            Self::PkceFlow(tokens) => tokens.request_access_token(auth_server).await,
300        }
301    }
302}
303
304impl std::fmt::Debug for OAuthGrant {
305    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
306        match self {
307            Self::RefreshToken(_) => f.write_str("RefreshToken"),
308            Self::ClientCredentials(_) => f.write_str("ClientCredentials"),
309            Self::ExternallyManaged(_) => f.write_str("ExternallyManaged"),
310            Self::PkceFlow(_) => f.write_str("PkceTokens"),
311        }
312    }
313}
314
315/// Manages the `OAuth2` authorization process and token lifecycle for accessing the QCS API.
316///
317/// This struct encapsulates the necessary information to request an access token
318/// from an authorization server, including the `OAuth2` grant type and any associated
319/// credentials or payload data.
320///
321/// # Fields
322///
323/// * `payload` - The `OAuth2` grant type and associated data that will be used to request an access token.
324/// * `access_token` - The access token currently in use, if any. If no token has been provided or requested yet, this will be `None`.
325/// * `auth_server` - The authorization server responsible for issuing tokens.
326#[derive(Clone)]
327#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
328#[cfg_attr(
329    feature = "python",
330    pyo3::pyclass(module = "qcs_api_client_common.configuration", frozen, get_all)
331)]
332pub struct OAuthSession {
333    /// The grant type to use to request an access token.
334    payload: OAuthGrant,
335    /// The access token that is currently in use. None if no token has been requested yet.
336    access_token: Option<SecretAccessToken>,
337    /// The [`AuthServer`] that issues the tokens.
338    auth_server: AuthServer,
339}
340
341impl OAuthSession {
342    /// Initialize a new set of [`Credentials`] using a [`GrantPayload`].
343    ///
344    /// Optionally include an `access_token`, if not included, then one can be requested
345    /// with [`Self::request_access_token`].
346    #[must_use]
347    pub const fn new(
348        payload: OAuthGrant,
349        auth_server: AuthServer,
350        access_token: Option<SecretAccessToken>,
351    ) -> Self {
352        Self {
353            payload,
354            access_token,
355            auth_server,
356        }
357    }
358
359    /// Initialize a new set of [`Credentials`] using an [`ExternallyManaged`].
360    ///
361    /// Optionally include an `access_token`, if not included, then one can be requested
362    /// with [`Self::request_access_token`].
363    #[must_use]
364    pub const fn from_externally_managed(
365        tokens: ExternallyManaged,
366        auth_server: AuthServer,
367        access_token: Option<SecretAccessToken>,
368    ) -> Self {
369        Self::new(
370            OAuthGrant::ExternallyManaged(tokens),
371            auth_server,
372            access_token,
373        )
374    }
375
376    /// Initialize a new set of [`Credentials`] using a [`RefreshToken`].
377    ///
378    /// Optionally include an `access_token`, if not included, then one can be requested
379    /// with [`Self::request_access_token`].
380    #[must_use]
381    pub const fn from_refresh_token(
382        tokens: RefreshToken,
383        auth_server: AuthServer,
384        access_token: Option<SecretAccessToken>,
385    ) -> Self {
386        Self::new(OAuthGrant::RefreshToken(tokens), auth_server, access_token)
387    }
388
389    /// Initialize a new set of [`Credentials`] using [`ClientCredentials`].
390    ///
391    /// Optionally include an `access_token`, if not included, then one can be requested
392    /// with [`Self::request_access_token`].
393    #[must_use]
394    pub const fn from_client_credentials(
395        tokens: ClientCredentials,
396        auth_server: AuthServer,
397        access_token: Option<SecretAccessToken>,
398    ) -> Self {
399        Self::new(
400            OAuthGrant::ClientCredentials(tokens),
401            auth_server,
402            access_token,
403        )
404    }
405
406    /// Initialize a new set of [`Credentials`] using [`PkceFlow`].
407    ///
408    /// Optionally include an `access_token`, if not included, then one can be requested
409    /// with [`Self::request_access_token`].
410    #[must_use]
411    pub const fn from_pkce_flow(
412        flow: PkceFlow,
413        auth_server: AuthServer,
414        access_token: Option<SecretAccessToken>,
415    ) -> Self {
416        Self::new(OAuthGrant::PkceFlow(flow), auth_server, access_token)
417    }
418
419    /// Get the current access token.
420    ///
421    /// This is an unvalidated copy of the access token. Meaning it can become stale, or may
422    /// even be already be stale. See [`Self::validate`] and [`Self::request_access_token`].
423    ///
424    /// # Errors
425    ///
426    /// - [`TokenError::NoAccessToken`] if there is no access token
427    pub fn access_token(&self) -> Result<&SecretAccessToken, TokenError> {
428        self.access_token.as_ref().ok_or(TokenError::NoAccessToken)
429    }
430
431    /// Get the payload used to request an access token.
432    #[must_use]
433    pub const fn payload(&self) -> &OAuthGrant {
434        &self.payload
435    }
436
437    /// Request and return an updated access token using these credentials.
438    ///
439    /// # Errors
440    ///
441    /// See [`TokenError`]
442    #[allow(clippy::missing_panics_doc)]
443    pub async fn request_access_token(&mut self) -> Result<&SecretAccessToken, TokenError> {
444        let access_token = self.payload.request_access_token(&self.auth_server).await?;
445        Ok(self.access_token.insert(access_token))
446    }
447
448    /// The [`AuthServer`] that issues the tokens.
449    #[must_use]
450    pub const fn auth_server(&self) -> &AuthServer {
451        &self.auth_server
452    }
453
454    /// Validate the access token, returning it if it is valid, or an error describing why it is
455    /// invalid.
456    ///
457    /// # Errors
458    ///
459    /// - [`TokenError::NoAccessToken`] if an access token has not been requested.
460    /// - [`TokenError::InvalidAccessToken`] if the access token is invalid.
461    pub fn validate(&self) -> Result<SecretAccessToken, TokenError> {
462        let access_token = self.access_token()?;
463        insecure_validate_token_exp(access_token)?;
464        Ok(access_token.clone())
465    }
466}
467
468/// Validates the access token's format and `exp` claim, but no other claims or
469/// signature. We do this only to determine if the token is expired and needs refreshing,
470/// there is no way to securely validate the token's signature on the client side.
471pub(crate) fn insecure_validate_token_exp(
472    access_token: &SecretAccessToken,
473) -> Result<(), TokenError> {
474    let placeholder_key = DecodingKey::from_secret(&[]);
475    let mut validation = Validation::new(Algorithm::RS256);
476    validation.validate_exp = true;
477    validation.leeway = 60;
478    validation.validate_aud = false;
479    validation.insecure_disable_signature_validation();
480
481    jsonwebtoken::decode::<toml::Value>(access_token.secret(), &placeholder_key, &validation)
482        .map(|_| ())
483        .map_err(TokenError::InvalidAccessToken)
484}
485
486impl std::fmt::Debug for OAuthSession {
487    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
488        let token_populated = if self.access_token.is_some() {
489            Some(())
490        } else {
491            None
492        };
493        f.debug_struct("OAuthSession")
494            .field("payload", &self.payload)
495            .field("access_token", &token_populated)
496            .field("auth_server", &self.auth_server)
497            .finish()
498    }
499}
500
501/// A wrapper for [`OAuthSession`] that provides thread-safe access to the inner tokens.
502#[derive(Clone, Debug)]
503#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
504#[cfg_attr(
505    feature = "python",
506    pyo3::pyclass(module = "qcs_api_client_common.configuration", frozen)
507)]
508pub struct TokenDispatcher {
509    lock: Arc<RwLock<OAuthSession>>,
510    refreshing: Arc<Mutex<bool>>,
511    notify_refreshed: Arc<Notify>,
512}
513
514impl From<OAuthSession> for TokenDispatcher {
515    fn from(value: OAuthSession) -> Self {
516        Self {
517            lock: Arc::new(RwLock::new(value)),
518            refreshing: Arc::new(Mutex::new(false)),
519            notify_refreshed: Arc::new(Notify::new()),
520        }
521    }
522}
523
524impl TokenDispatcher {
525    /// Executes a user-provided closure on a reference to the `Tokens` instance managed by the
526    /// dispatcher.
527    ///
528    /// This function locks the mutex, safely exposing the protected `Tokens` instance to the provided closure `f`.
529    /// It is designed to allow safe and controlled access to the `Tokens` instance for reading its state.
530    ///
531    /// # Parameters
532    /// - `f`: A closure that takes a reference to `Tokens` and returns a value of type `O`. The closure is called
533    ///   with the `Tokens` instance as an argument once the mutex is successfully locked.
534    pub async fn use_tokens<F, O>(&self, f: F) -> O
535    where
536        F: FnOnce(&OAuthSession) -> O + Send,
537    {
538        let tokens = self.lock.read().await;
539        f(&tokens)
540    }
541
542    /// Get a copy of the current access token.
543    #[must_use]
544    pub async fn tokens(&self) -> OAuthSession {
545        self.use_tokens(Clone::clone).await
546    }
547
548    /// Refreshes the tokens. Readers will be blocked until the refresh is complete.
549    ///
550    /// # Errors
551    ///
552    /// See [`TokenError`]
553    pub async fn refresh(
554        &self,
555        source: &ConfigSource,
556        profile: &str,
557    ) -> Result<OAuthSession, TokenError> {
558        self.managed_refresh(Self::perform_refresh, source, profile)
559            .await
560    }
561
562    /// Validate the access token, returning it if it is valid, or an error describing why it is
563    /// invalid.
564    ///
565    /// # Errors
566    ///
567    /// - [`TokenError::NoAccessToken`] if there is no access token
568    /// - [`TokenError::InvalidAccessToken`] if the access token is invalid
569    pub async fn validate(&self) -> Result<SecretAccessToken, TokenError> {
570        self.use_tokens(OAuthSession::validate).await
571    }
572
573    /// If tokens are already being refreshed, wait and return the updated tokens. Otherwise, run
574    /// ``refresh_fn``.
575    async fn managed_refresh<F, Fut>(
576        &self,
577        refresh_fn: F,
578        source: &ConfigSource,
579        profile: &str,
580    ) -> Result<OAuthSession, TokenError>
581    where
582        F: FnOnce(Arc<RwLock<OAuthSession>>) -> Fut + Send,
583        Fut: Future<Output = Result<OAuthSession, TokenError>> + Send,
584    {
585        let mut is_refreshing = self.refreshing.lock().await;
586
587        if *is_refreshing {
588            drop(is_refreshing);
589            self.notify_refreshed.notified().await;
590            return Ok(self.tokens().await);
591        }
592
593        *is_refreshing = true;
594        drop(is_refreshing);
595
596        let oauth_session = refresh_fn(self.lock.clone()).await?;
597
598        // If the config source is a file, write the new access token to the file
599        let write_result = if let ConfigSource::File {
600            settings_path: _,
601            secrets_path,
602        } = source
603        {
604            match Secrets::is_read_only(secrets_path).await {
605                Ok(true) => Ok(()),
606                Ok(false) => {
607                    // Persist the fresh refresh token if the grant carries one, so that a rotated
608                    // refresh token isn't lost on the next load. Both the PKCE and refresh-token
609                    // grants can hold a refresh token that the auth server may have rotated.
610                    let refresh_token = match &oauth_session.payload {
611                        OAuthGrant::PkceFlow(payload) => {
612                            payload.refresh_token.as_ref().map(|rt| &rt.refresh_token)
613                        }
614                        OAuthGrant::RefreshToken(payload) => Some(&payload.refresh_token),
615                        OAuthGrant::ExternallyManaged(_) | OAuthGrant::ClientCredentials(_) => None,
616                    };
617
618                    let now = OffsetDateTime::now_utc();
619                    Secrets::write_tokens(
620                        secrets_path,
621                        profile,
622                        refresh_token,
623                        oauth_session.access_token()?,
624                        now,
625                    )
626                    .await
627                }
628                Err(e) => Err(e),
629            }
630        } else {
631            Ok(())
632        };
633
634        // Always clean up the refreshing lock, even if write failed
635        *self.refreshing.lock().await = false;
636        self.notify_refreshed.notify_waiters();
637
638        // If write failed, return error with the valid oauth_session
639        if let Err(error) = write_result {
640            return Err(TokenError::Write {
641                error,
642                oauth_session: Box::new(oauth_session),
643            });
644        }
645
646        Ok(oauth_session)
647    }
648
649    /// Refreshes the tokens. Readers will be blocked until the refresh is complete. Returns a copy
650    /// of the updated [`Credentials`]
651    ///
652    /// # Errors
653    ///
654    /// See [`TokenError`]
655    async fn perform_refresh(lock: Arc<RwLock<OAuthSession>>) -> Result<OAuthSession, TokenError> {
656        let mut credentials = lock.write().await;
657        credentials.request_access_token().await?;
658        Ok(credentials.clone())
659    }
660}
661
662pub(crate) type RefreshResult =
663    Pin<Box<dyn Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>> + Send>>;
664
665/// A function that asynchronously refreshes a token.
666pub type RefreshFunction = Box<dyn (Fn(AuthServer) -> RefreshResult) + Send + Sync>;
667
668/// A struct that manages access tokens by utilizing a user-provided refresh function.
669///
670/// The [`ExternallyManaged`] struct allows users to define custom logic for
671/// fetching or refreshing access tokens.
672#[derive(Clone)]
673#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
674#[cfg_attr(
675    feature = "python",
676    pyo3::pyclass(module = "qcs_api_client_common.configuration", frozen)
677)]
678pub struct ExternallyManaged {
679    refresh_function: Arc<RefreshFunction>,
680}
681
682impl ExternallyManaged {
683    /// Creates a new [`ExternallyManaged`] instance from a [`RefreshFunction`].
684    ///
685    /// Consider using [`ExternallyManaged::from_async`], and [`ExternallyManaged::from_sync`], if
686    /// they better fit your use case.
687    ///
688    /// # Arguments
689    ///
690    /// * `refresh_function` - A function or closure that asynchronously refreshes a token.
691    ///
692    /// # Example
693    ///
694    /// ```
695    /// use qcs_api_client_common::configuration::{settings::AuthServer, tokens::ExternallyManaged, TokenError};
696    /// use std::future::Future;
697    /// use std::pin::Pin;
698    /// use std::boxed::Box;
699    /// use std::error::Error;
700    ///
701    /// async fn example_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
702    /// + Send + Sync>> {
703    ///     Ok("new_token_value".to_string())
704    /// }
705    /// let token_manager = ExternallyManaged::new(|auth_server| Box::pin(example_refresh_function(auth_server)));
706    /// ```
707    pub fn new(
708        refresh_function: impl Fn(AuthServer) -> RefreshResult + Send + Sync + 'static,
709    ) -> Self {
710        Self {
711            refresh_function: Arc::new(Box::new(refresh_function)),
712        }
713    }
714
715    /// Constructs a new [`ExternallyManaged`] instance using an async function or closure.
716    ///
717    /// This method simplifies the creation of the [`ExternallyManaged`] instance by handling
718    /// the boxing and pinning of the future internally.
719    ///
720    /// # Arguments
721    ///
722    /// * `refresh_function` - An async function or closure that returns a [`Future`] which, when awaited,
723    ///   produces a [`Result<String, TokenError>`].
724    ///
725    /// # Example
726    ///
727    /// ```
728    /// use qcs_api_client_common::configuration::{settings::AuthServer, tokens::ExternallyManaged, TokenError};
729    /// use tokio::runtime::Runtime;
730    /// use std::error::Error;
731    ///
732    /// async fn example_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
733    /// + Send + Sync>> {
734    ///     Ok("new_token_value".to_string())
735    /// }
736    ///
737    /// let token_manager = ExternallyManaged::from_async(example_refresh_function);
738    ///
739    /// let rt = Runtime::new().unwrap();
740    /// rt.block_on(async {
741    ///     match token_manager.request_access_token(&AuthServer::default()).await {
742    ///         Ok(token) => println!("Token: {token:?}"),
743    ///         Err(e) => println!("Failed to refresh token: {:?}", e),
744    ///     }
745    /// });
746    /// ```
747    pub fn from_async<F, Fut>(refresh_function: F) -> Self
748    where
749        F: Fn(AuthServer) -> Fut + Send + Sync + 'static,
750        Fut: Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>>
751            + Send
752            + 'static,
753    {
754        Self {
755            refresh_function: Arc::new(Box::new(move |auth_server| {
756                Box::pin(refresh_function(auth_server))
757            })),
758        }
759    }
760
761    /// Constructs a new [`ExternallyManaged`] instance using a synchronous function.
762    ///
763    /// The synchronous function is wrapped in an async block to fit the expected signature.
764    ///
765    /// # Arguments
766    ///
767    /// * `refresh_function` - A synchronous function that returns a [`Result<String, TokenError>`].
768    ///
769    /// # Example
770    ///
771    /// ```
772    /// use qcs_api_client_common::configuration::{settings::AuthServer, tokens::ExternallyManaged, TokenError};
773    /// use tokio::runtime::Runtime;
774    /// use std::error::Error;
775    ///
776    /// fn example_sync_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
777    /// + Send + Sync>> {
778    ///     Ok("sync_token_value".to_string())
779    /// }
780    ///
781    /// let token_manager = ExternallyManaged::from_sync(example_sync_refresh_function);
782    ///
783    /// let rt = Runtime::new().unwrap();
784    /// rt.block_on(async {
785    ///     match token_manager.request_access_token(&AuthServer::default()).await {
786    ///         Ok(token) => println!("Token: {token:?}"),
787    ///         Err(e) => println!("Failed to refresh token: {:?}", e),
788    ///     }
789    /// });
790    /// ```
791    pub fn from_sync(
792        refresh_function: impl Fn(
793            AuthServer,
794        ) -> Result<String, Box<dyn std::error::Error + Send + Sync>>
795        + Send
796        + Sync
797        + 'static,
798    ) -> Self {
799        Self {
800            refresh_function: Arc::new(Box::new(move |auth_server| {
801                let result = refresh_function(auth_server);
802                Box::pin(async move { result })
803            })),
804        }
805    }
806
807    /// Request an updated access token using the provided refresh function.
808    ///
809    /// # Errors
810    ///
811    /// Errors are propagated from the refresh function.
812    pub async fn request_access_token(
813        &self,
814        auth_server: &AuthServer,
815    ) -> Result<SecretAccessToken, Box<dyn std::error::Error + Send + Sync>> {
816        (self.refresh_function)(auth_server.clone())
817            .await
818            .map(SecretAccessToken::from)
819    }
820}
821
822impl std::fmt::Debug for ExternallyManaged {
823    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
824        f.debug_struct("ExternallyManaged")
825            .field(
826                "refresh_function",
827                &"Fn() -> Pin<Box<dyn Future<Output = Result<String, TokenError>> + Send>>",
828            )
829            .finish()
830    }
831}
832
833#[derive(Debug, Serialize, Deserialize)]
834pub(super) struct TokenRefreshRequest<'a> {
835    grant_type: &'static str,
836    client_id: &'a str,
837    refresh_token: &'a str,
838}
839
840impl<'a> TokenRefreshRequest<'a> {
841    pub(super) const fn new(client_id: &'a str, refresh_token: &'a str) -> Self {
842        Self {
843            grant_type: "refresh_token",
844            client_id,
845            refresh_token,
846        }
847    }
848}
849
/// Serializable parameters of an `OAuth2` client-credentials grant request.
///
/// Construct it with [`ClientCredentialsRequest::new`], which fixes `grant_type` to
/// `"client_credentials"`.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct ClientCredentialsRequest {
    /// The `OAuth2` grant type; always `"client_credentials"` when built via `new`.
    grant_type: &'static str,
    /// Optional `scope` parameter for the request (`None` when no scope is supplied).
    scope: Option<&'static str>,
}
855
856impl ClientCredentialsRequest {
857    pub(super) const fn new(scope: Option<&'static str>) -> Self {
858        Self {
859            grant_type: "client_credentials",
860            scope,
861        }
862    }
863}
864
/// Token-endpoint response to a refresh request.
#[derive(Deserialize, Debug, Serialize)]
pub(super) struct RefreshTokenResponse {
    /// A new refresh token, if the server issued one (e.g. when it rotates refresh tokens).
    /// Presumably `None` means the previous refresh token remains valid — confirm with the
    /// code that consumes this response.
    pub(super) refresh_token: Option<SecretRefreshToken>,
    /// The newly issued access token.
    pub(super) access_token: SecretAccessToken,
}
870
/// Get and refresh access tokens.
///
/// Implementors expose the current access token, a validated (refreshed-if-needed) token,
/// and an explicit refresh. With the `tracing` feature they also report a base URL and
/// decide which URLs should be traced.
#[async_trait::async_trait]
pub trait TokenRefresher: Clone + std::fmt::Debug + Send {
    /// The type to be returned in the event of an error during getting or
    /// refreshing an access token.
    type Error;

    /// Get and validate the current access token, refreshing it if it doesn't exist or is invalid.
    async fn validated_access_token(&self) -> Result<SecretAccessToken, Self::Error>;

    /// Get the current access token, if any, without forcing a refresh.
    async fn get_access_token(&self) -> Result<Option<SecretAccessToken>, Self::Error>;

    /// Unconditionally obtain a fresh access token.
    async fn refresh_access_token(&self) -> Result<SecretAccessToken, Self::Error>;

    /// Get the base URL for requests.
    #[cfg(feature = "tracing")]
    fn base_url(&self) -> &str;

    /// Get the tracing configuration, if one is set.
    #[cfg(feature = "tracing-config")]
    fn tracing_configuration(&self) -> Option<&TracingConfiguration>;

    /// Returns whether the given URL should be traced. Following
    /// [`TracingConfiguration::is_enabled`], this defaults to `true`.
    ///
    /// Without the `tracing-config` feature every URL is traced. With it, a URL is traced
    /// when no tracing configuration is set, or when the configuration enables it.
    #[cfg(feature = "tracing")]
    // The explicit `return` lets the first cfg block compile in both feature
    // configurations (it must not be a trailing expression when the second block is
    // present); clippy flags it as needless when `tracing-config` is disabled.
    #[allow(clippy::needless_return)]
    fn should_trace(&self, url: &UrlPatternMatchInput) -> bool {
        #[cfg(not(feature = "tracing-config"))]
        {
            // `url` is only consulted with `tracing-config`; silence the unused warning.
            let _ = url;
            return true;
        }

        #[cfg(feature = "tracing-config")]
        self.tracing_configuration()
            .is_none_or(|config| config.is_enabled(url))
    }
}
911
912#[async_trait::async_trait]
913impl TokenRefresher for ClientConfiguration {
914    type Error = TokenError;
915
916    async fn validated_access_token(&self) -> Result<SecretAccessToken, Self::Error> {
917        self.get_bearer_access_token().await
918    }
919
920    async fn refresh_access_token(&self) -> Result<SecretAccessToken, Self::Error> {
921        match self.refresh().await {
922            Ok(session) => Ok(session.access_token()?.clone()),
923            Err(TokenError::Write {
924                error,
925                oauth_session,
926            }) => {
927                // Token refresh succeeded but persistence failed. Extract and return the access token from the error.
928                #[cfg(feature = "tracing")]
929                tracing::warn!(
930                    "Token refresh succeeded but failed to persist: {}. Returning access token from error.",
931                    error
932                );
933                Ok(oauth_session.access_token()?.clone())
934            }
935            Err(e) => Err(e),
936        }
937    }
938
939    async fn get_access_token(&self) -> Result<Option<SecretAccessToken>, Self::Error> {
940        Ok(Some(self.oauth_session().await?.access_token()?.clone()))
941    }
942
943    #[cfg(feature = "tracing")]
944    fn base_url(&self) -> &str {
945        &self.grpc_api_url
946    }
947
948    #[cfg(feature = "tracing-config")]
949    fn tracing_configuration(&self) -> Option<&TracingConfiguration> {
950        self.tracing_configuration.as_ref()
951    }
952}
953
954/// Get a default http client.
955pub(super) fn default_http_client()
956-> Result<qcs_dependencies_client::reqwest::Client, qcs_dependencies_client::reqwest::Error> {
957    qcs_dependencies_client::reqwest::Client::builder()
958        .timeout(std::time::Duration::from_secs(10))
959        .build()
960}
961
#[cfg(test)]
mod test {
    #![allow(clippy::result_large_err, reason = "happens in figment tests")]

    use std::time::Duration;

    use super::*;
    use httpmock::prelude::*;
    use rstest::rstest;
    use time::format_description::well_known::Rfc3339;
    use tokio::time::Instant;
    use toml_edit::DocumentMut;

    /// A reader calling `TokenDispatcher::tokens` while a refresh is in flight must wait for
    /// that refresh to finish and then observe the refreshed tokens.
    ///
    /// The mock token endpoint delays its response by `refresh_duration`, so both the writer
    /// and the blocked reader should take at least that long.
    #[tokio::test]
    async fn test_tokens_blocked_during_refresh() {
        // Single source of truth for the simulated token-endpoint latency.
        let refresh_duration = Duration::from_secs(3);

        let mock_server = MockServer::start_async().await;

        let oidc_mock = mock_server
            .mock_async(|when, then| {
                when.method(GET).path("/.well-known/openid-configuration");
                then.status(200)
                    .json_body_obj(&oidc::Discovery::new_for_test(
                        mock_server.base_url().parse().unwrap(),
                    ));
            })
            .await;

        let issuer_mock = mock_server
            .mock_async(|when, then| {
                when.method(POST).path("/v1/token");

                then.status(200)
                    .delay(refresh_duration)
                    .json_body_obj(&RefreshTokenResponse {
                        access_token: SecretAccessToken::from("new_access"),
                        refresh_token: Some(SecretRefreshToken::from("new_refresh")),
                    });
            })
            .await;

        let original_tokens = OAuthSession::from_refresh_token(
            RefreshToken::new(SecretRefreshToken::from("refresh")),
            AuthServer {
                client_id: "client_id".to_string(),
                issuer: mock_server.base_url(),
                scopes: None,
            },
            None,
        );
        let dispatcher: TokenDispatcher = original_tokens.clone().into();
        let dispatcher_clone1 = dispatcher.clone();
        let dispatcher_clone2 = dispatcher.clone();

        let start_write = Instant::now();
        let write_future = tokio::spawn(async move {
            dispatcher_clone1
                .refresh(&ConfigSource::Default, "")
                .await
                .unwrap()
        });

        let start_read = Instant::now();
        let read_future = tokio::spawn(async move { dispatcher_clone2.tokens().await });

        let _ = write_future.await.unwrap();
        let read_result = read_future.await.unwrap();

        let write_duration = start_write.elapsed();
        let read_duration = start_read.elapsed();

        oidc_mock.assert_async().await;
        issuer_mock.assert_async().await;

        assert!(
            write_duration >= refresh_duration,
            "Write operation did not take enough time"
        );
        assert!(
            read_duration >= refresh_duration,
            "Read operation was not blocked by the write operation"
        );
        assert_eq!(
            read_result.access_token.unwrap(),
            SecretAccessToken::from("new_access")
        );
        if let OAuthGrant::RefreshToken(payload) = read_result.payload {
            assert_eq!(
                payload.refresh_token,
                SecretRefreshToken::from("new_refresh")
            );
        } else {
            panic!(
                "Expected RefreshToken payload, got {:?}",
                read_result.payload
            );
        }
    }

    /// After a successful refresh, the secrets file must be rewritten only when neither
    /// `QCS_SECRETS_READ_ONLY` is truthy (case-insensitive `true`/`yes`/`1`) nor the file
    /// itself is read-only; otherwise it must be left byte-for-byte unchanged.
    #[rstest]
    fn test_qcs_secrets_readonly(
        #[values(
            (Some("TRUE"), true),
            (Some("tRue"), true),
            (Some("true"), true),
            (Some("YES"), true),
            (Some("yEs"), true),
            (Some("yes"), true),
            (Some("1"), true),
            (Some("2"), false),
            (Some("other"), false),
            (Some(""), false),
            (None, false),
        )]
        read_only_values: (Option<&str>, bool),
        #[values(true, false)] read_only_perm: bool,
    ) {
        let (maybe_read_only_env, env_is_read_only) = read_only_values;
        let expected_update = !env_is_read_only && !read_only_perm;
        figment::Jail::expect_with(|jail| {
            let profile_name = "test";
            let initial_access_token = "initial_access_token";
            let initial_refresh_token = "initial_refresh_token";

            let initial_secrets_file_contents = format!(
                r#"
[credentials]
[credentials.{profile_name}]
[credentials.{profile_name}.token_payload]
access_token = "{initial_access_token}"
expires_in = 3600
id_token = "id_token"
refresh_token = "{initial_refresh_token}"
scope = "offline_access openid profile email"
token_type = "Bearer"
updated_at = "2024-01-01T00:00:00Z"
"#
            );

            // Ignore any existing environment variables.
            jail.clear_env();

            // Create a temporary secrets file
            let secrets_path = "secrets.toml";
            jail.create_file(secrets_path, initial_secrets_file_contents.as_str())
                .expect("should create test secrets.toml");

            if read_only_perm {
                let mut permissions = std::fs::metadata(secrets_path)
                    .expect("Should be able to get file metadata")
                    .permissions();
                permissions.set_readonly(true);
                std::fs::set_permissions(secrets_path, permissions)
                    .expect("Should be able to set file permissions");
            }

            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async {
                let mock_server = MockServer::start_async().await;

                let oidc_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(GET).path("/.well-known/openid-configuration");
                        then.status(200)
                            .json_body_obj(&oidc::Discovery::new_for_test(mock_server.base_url().parse().unwrap()));
                    })
                    .await;

                // Set up the mock token endpoint
                let new_access_token = SecretAccessToken::from("new_access_token");
                let issuer_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(POST).path("/v1/token");
                        then.status(200).json_body_obj(&RefreshTokenResponse {
                            access_token: new_access_token.clone(),
                            refresh_token: Some(SecretRefreshToken::from(initial_refresh_token)),
                        });
                    })
                    .await;

                // Create tokens and dispatcher, seeded with the same tokens as the file.
                let original_tokens = OAuthSession::from_refresh_token(
                    RefreshToken::new(SecretRefreshToken::from(initial_refresh_token)),
                    AuthServer { client_id: "client_id".to_string(), issuer: mock_server.base_url(), scopes: None },
                    Some(SecretAccessToken::from(initial_access_token)),
                );
                let dispatcher: TokenDispatcher = original_tokens.into();

                // Point at the jailed secrets file and, when this case requires it, set
                // QCS_SECRETS_READ_ONLY before refreshing.
                jail.set_env("QCS_SECRETS_FILE_PATH", "secrets.toml");
                jail.set_env("QCS_PROFILE_NAME", "test");
                if let Some(read_only_env) = maybe_read_only_env {
                    jail.set_env("QCS_SECRETS_READ_ONLY", read_only_env);
                }

                let before_refresh = OffsetDateTime::now_utc();

                dispatcher
                    .refresh(
                        &ConfigSource::File {
                            settings_path: "".into(),
                            secrets_path: "secrets.toml".into(),
                        },
                        profile_name,
                    )
                    .await
                    .unwrap();

                oidc_mock.assert_async().await;
                issuer_mock.assert_async().await;

                // Verify the file was not updated if QCS_SECRETS_READ_ONLY is set truthy or
                // the file itself is read-only.
                let content = std::fs::read_to_string("secrets.toml").unwrap();
                if !expected_update {
                    assert!(
                        content.eq(initial_secrets_file_contents.as_str()),
                        "File should not be updated when QCS_SECRETS_READ_ONLY is set or file permissions are read-only"
                    );
                    return;
                }

                // Verify the file was updated
                let mut toml = std::fs::read_to_string(secrets_path)
                    .unwrap()
                    .parse::<DocumentMut>()
                    .unwrap();

                let token_payload = toml
                    .get_mut("credentials")
                    .and_then(|credentials| {
                        credentials.get_mut(profile_name)?.get_mut("token_payload")
                    })
                    .expect("Should be able to get token_payload table");

                let access_token = token_payload.get("access_token").unwrap().as_str().map(str::to_string).map(SecretAccessToken::from);

                assert_eq!(
                    access_token,
                    Some(new_access_token)
                );

                assert!(
                    OffsetDateTime::parse(
                        token_payload.get("updated_at").unwrap().as_str().unwrap(),
                        &Rfc3339
                    )
                    .unwrap()
                        > before_refresh
                );

                let content = std::fs::read_to_string("secrets.toml").unwrap();
                assert!(
                    content.contains("new_access_token"),
                    "File should be updated with new access token when QCS_SECRETS_READ_ONLY is not set or is set but disabled, and file permissions allow writing"
                );
            });
            Ok(())
        });
    }

    /// When the auth server rotates the refresh token, a [`OAuthGrant::RefreshToken`] grant should
    /// persist the new refresh token to the secrets file (not just the access token).
    #[test]
    fn test_refresh_token_grant_persists_rotated_refresh_token() {
        let initial_refresh_token = "initial_refresh_token";
        let rotated_refresh_token = "rotated_refresh_token";
        let new_access_token = "new_access_token";

        figment::Jail::expect_with(|jail| {
            jail.clear_env();

            let secrets_path = "secrets.toml";
            let initial_secrets_file_contents = format!(
                r#"
[credentials]
[credentials.test]
[credentials.test.token_payload]
access_token = "initial_access_token"
refresh_token = "{initial_refresh_token}"
updated_at = "2024-01-01T00:00:00Z"
"#
            );
            jail.create_file(secrets_path, &initial_secrets_file_contents)
                .expect("should create test secrets.toml");

            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async {
                let mock_server = MockServer::start_async().await;
                let oidc_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(GET).path("/.well-known/openid-configuration");
                        then.status(200)
                            .json_body_obj(&oidc::Discovery::new_for_test(
                                mock_server.base_url().parse().unwrap(),
                            ));
                    })
                    .await;
                let issuer_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(POST).path("/v1/token");
                        then.status(200).json_body_obj(&RefreshTokenResponse {
                            access_token: SecretAccessToken::from(new_access_token),
                            refresh_token: Some(SecretRefreshToken::from(rotated_refresh_token)),
                        });
                    })
                    .await;

                let dispatcher: TokenDispatcher = OAuthSession::from_refresh_token(
                    RefreshToken::new(SecretRefreshToken::from(initial_refresh_token)),
                    AuthServer {
                        client_id: "client_id".to_string(),
                        issuer: mock_server.base_url(),
                        scopes: None,
                    },
                    Some(SecretAccessToken::from("initial_access_token")),
                )
                .into();

                dispatcher
                    .refresh(
                        &ConfigSource::File {
                            settings_path: "".into(),
                            secrets_path: secrets_path.into(),
                        },
                        "test",
                    )
                    .await
                    .expect("refresh should succeed");

                oidc_mock.assert_async().await;
                issuer_mock.assert_async().await;
            });

            // The rotated refresh token (and the new access token) should be persisted.
            let payload = Secrets::load_from_path(&secrets_path.into())
                .expect("should load secrets")
                .credentials
                .remove("test")
                .expect("should have test credentials")
                .token_payload
                .expect("should have token payload");
            assert_eq!(
                payload.refresh_token.unwrap(),
                SecretRefreshToken::from(rotated_refresh_token),
                "rotated refresh token should be persisted to the secrets file"
            );
            assert_eq!(
                payload.access_token.unwrap(),
                SecretAccessToken::from(new_access_token),
                "new access token should be persisted to the secrets file"
            );

            Ok(())
        });
    }

    /// `OAuthSession`'s `Debug` output must not reveal client credentials or the access token.
    #[test]
    fn test_auth_session_debug_fmt() {
        let session = OAuthSession {
            payload: OAuthGrant::ClientCredentials(ClientCredentials::new(
                "hidden_id",
                "hidden_secret",
            )),
            access_token: Some(SecretAccessToken::from("token")),
            auth_server: AuthServer {
                client_id: "some_id".into(),
                issuer: "some_url".into(),
                scopes: None,
            },
        };

        assert_eq!(
            "OAuthSession { payload: ClientCredentials, access_token: Some(()), auth_server: AuthServer { client_id: \"some_id\", issuer: \"some_url\", scopes: None } }",
            &format!("{session:?}")
        );
    }
}