Skip to main content

qcs_api_client_common/configuration/
tokens.rs

1//! Models and utilities for managing `OAuth2` sessions.
2use std::{pin::Pin, sync::Arc};
3
4use futures::Future;
5use jsonwebtoken::{Algorithm, DecodingKey, Validation};
6use oauth2::TokenResponse;
7use serde::{Deserialize, Serialize};
8use time::OffsetDateTime;
9use tokio::sync::{Mutex, Notify, RwLock};
10use tokio_util::sync::CancellationToken;
11
12#[cfg(feature = "stubs")]
13use pyo3_stub_gen::derive::gen_stub_pyclass;
14
15use super::{
16    oidc, secrets::Secrets, settings::AuthServer, ClientConfiguration, ConfigSource, TokenError,
17};
18use crate::configuration::{
19    error::DiscoveryError,
20    pkce::{pkce_login, PkceLoginError, PkceLoginRequest},
21    secrets::{Credential, SecretAccessToken, SecretRefreshToken, TokenPayload},
22};
23#[cfg(feature = "tracing-config")]
24use crate::tracing_configuration::TracingConfiguration;
25#[cfg(feature = "tracing")]
26use urlpattern::UrlPatternMatchInput;
27
28pub use super::secret_string::ClientSecret;
29
30/// A single type containing an access token and an associated refresh token.
31#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
32#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
33#[cfg_attr(
34    feature = "python",
35    pyo3::pyclass(eq, get_all, set_all, module = "qcs_api_client_common.configuration")
36)]
37pub struct RefreshToken {
38    /// The token used to refresh the access token.
39    pub refresh_token: SecretRefreshToken,
40}
41
42impl RefreshToken {
43    /// Create a new [`RefreshToken`] with the given refresh token.
44    #[must_use]
45    pub const fn new(refresh_token: SecretRefreshToken) -> Self {
46        Self { refresh_token }
47    }
48
49    /// Request and return a new access token from the given authorization server using this refresh token.
50    ///
51    /// # Errors
52    ///
53    /// See [`TokenError`]
54    pub async fn request_access_token(
55        &mut self,
56        auth_server: &AuthServer,
57    ) -> Result<SecretAccessToken, TokenError> {
58        if self.refresh_token.is_empty() {
59            return Err(TokenError::NoRefreshToken);
60        }
61
62        let client = default_http_client()?;
63        let token_url = oidc::fetch_discovery(&client, &auth_server.issuer)
64            .await?
65            .token_endpoint;
66        let data = TokenRefreshRequest::new(&auth_server.client_id, self.refresh_token.secret());
67        let resp = client.post(token_url).form(&data).send().await?;
68
69        let RefreshTokenResponse {
70            access_token,
71            refresh_token,
72        } = resp.error_for_status()?.json().await?;
73
74        if let Some(refresh_token) = refresh_token {
75            self.refresh_token = refresh_token;
76        }
77        Ok(access_token)
78    }
79}
80
81#[derive(Deserialize, Debug, Serialize)]
82pub(super) struct ClientCredentialsResponse {
83    pub(super) access_token: SecretAccessToken,
84}
85
86/// A pair of Client ID and Client Secret, used to request an OAuth Client Credentials Grant
87#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
88#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
89#[cfg_attr(
90    feature = "python",
91    pyo3::pyclass(eq, get_all, frozen, module = "qcs_api_client_common.configuration")
92)]
93pub struct ClientCredentials {
94    /// The client ID
95    pub client_id: String,
96    /// The client secret.
97    pub client_secret: ClientSecret,
98}
99
100impl ClientCredentials {
101    #[must_use]
102    /// Construct a new [`ClientCredentials`]
103    pub fn new(client_id: impl Into<String>, client_secret: impl Into<ClientSecret>) -> Self {
104        Self {
105            client_id: client_id.into(),
106            client_secret: client_secret.into(),
107        }
108    }
109
110    /// Get the client ID.
111    #[must_use]
112    pub fn client_id(&self) -> &str {
113        &self.client_id
114    }
115
116    /// Get the client secret.
117    #[must_use]
118    pub const fn client_secret(&self) -> &ClientSecret {
119        &self.client_secret
120    }
121
122    /// Request and return an access token from the given auth server using this set of client credentials.
123    ///
124    /// # Errors
125    ///
126    /// See [`TokenError`]
127    pub async fn request_access_token(
128        &self,
129        auth_server: &AuthServer,
130    ) -> Result<SecretAccessToken, TokenError> {
131        let request = ClientCredentialsRequest::new(None);
132        let client = default_http_client()?;
133
134        let url = oidc::fetch_discovery(&client, &auth_server.issuer)
135            .await?
136            .token_endpoint;
137        let ready_to_send = client
138            .post(url)
139            .basic_auth(&auth_server.client_id, Some(&self.client_secret.secret()))
140            .form(&request);
141        let response = ready_to_send.send().await?;
142
143        response.error_for_status_ref()?;
144
145        let ClientCredentialsResponse { access_token } = response.json().await?;
146        Ok(access_token)
147    }
148}
149
150#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
151#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
152#[cfg_attr(
153    feature = "python",
154    pyo3::pyclass(eq, get_all, frozen, module = "qcs_api_client_common.configuration")
155)]
156/// The Access (Bearer) and refresh (if available) tokens from a PKCE login.
157pub struct PkceFlow {
158    /// The access token.
159    pub access_token: SecretAccessToken,
160    /// The refresh token, if available.
161    pub refresh_token: Option<RefreshToken>,
162}
163
164/// Errors that can occur when attempting to perform a PKCE login flow.
165#[derive(Debug, thiserror::Error)]
166pub enum PkceFlowError {
167    /// Error that occurred while performing the PKCE login flow.
168    #[error(transparent)]
169    PkceLogin(#[from] PkceLoginError),
170    /// Error that occurred while fetching the discovery document from the `OAuth2` issuer.
171    #[error(transparent)]
172    Discovery(#[from] DiscoveryError),
173    /// Error that occurred while making http requests.
174    #[error(transparent)]
175    Request(#[from] reqwest::Error),
176}
177
178impl PkceFlow {
179    /// Starts a new PKCE login flow to acquire a new set of tokens.
180    ///
181    /// # Errors
182    ///
183    /// See [`PkceFlowError`]
184    pub async fn new_login_flow(
185        cancel_token: CancellationToken,
186        auth_server: &AuthServer,
187    ) -> Result<Self, PkceFlowError> {
188        let issuer = auth_server.issuer.clone();
189
190        let client = default_http_client()?;
191        let discovery = oidc::fetch_discovery(&client, &issuer).await?;
192
193        let response = pkce_login(
194            cancel_token,
195            PkceLoginRequest {
196                client_id: auth_server.client_id.clone(),
197                redirect_port: None,
198                discovery,
199                scopes: auth_server.scopes.clone(),
200            },
201        )
202        .await?;
203
204        Ok(Self {
205            access_token: SecretAccessToken::from(response.access_token().secret().clone()),
206            refresh_token: response
207                .refresh_token()
208                .map(|rt| RefreshToken::new(SecretRefreshToken::from(rt.secret().clone()))),
209        })
210    }
211
212    /// Returns the access token if it is valid, otherwise requests a new access token using the refresh token if available.
213    ///
214    /// # Errors
215    ///
216    /// See [`TokenError`]
217    pub async fn request_access_token(
218        &mut self,
219        auth_server: &AuthServer,
220    ) -> Result<SecretAccessToken, TokenError> {
221        if insecure_validate_token_exp(&self.access_token).is_ok() {
222            return Ok(self.access_token.clone());
223        }
224
225        if let Some(refresh_token) = &mut self.refresh_token {
226            let access_token = refresh_token.request_access_token(auth_server).await?;
227            self.access_token.clone_from(&access_token);
228            return Ok(access_token);
229        }
230
231        Err(TokenError::NoRefreshToken)
232    }
233}
234
235impl From<PkceFlow> for Credential {
236    fn from(value: PkceFlow) -> Self {
237        let mut token_payload = TokenPayload::default();
238        token_payload.access_token = Some(value.access_token);
239        token_payload.refresh_token = value.refresh_token.map(|rt| rt.refresh_token);
240
241        Self {
242            token_payload: Some(token_payload),
243        }
244    }
245}
246
247#[derive(Clone)]
248#[cfg_attr(feature = "python", derive(pyo3::FromPyObject, pyo3::IntoPyObject))]
249/// Specifies the [OAuth2 grant type](https://oauth.net/2/grant-types/) to use, along with the data
250/// needed to request said grant type.
251pub enum OAuthGrant {
252    /// Credentials that can be used to use with the [Refresh Token grant type](https://oauth.net/2/grant-types/refresh-token/).
253    RefreshToken(RefreshToken),
254    /// Payload that can be used to use the [Client Credentials grant type](https://oauth.net/2/grant-types/client-credentials/).
255    ClientCredentials(ClientCredentials),
256    /// Defers to a user provided function for access token requests.
257    ExternallyManaged(ExternallyManaged),
258    /// The tokens returned by the PKCE login that are an [Authorization Code grant type](https://oauth.net/2/pkce/).
259    PkceFlow(PkceFlow),
260}
261
262impl From<ExternallyManaged> for OAuthGrant {
263    fn from(v: ExternallyManaged) -> Self {
264        Self::ExternallyManaged(v)
265    }
266}
267
268impl From<ClientCredentials> for OAuthGrant {
269    fn from(v: ClientCredentials) -> Self {
270        Self::ClientCredentials(v)
271    }
272}
273
274impl From<RefreshToken> for OAuthGrant {
275    fn from(v: RefreshToken) -> Self {
276        Self::RefreshToken(v)
277    }
278}
279
280impl From<PkceFlow> for OAuthGrant {
281    fn from(v: PkceFlow) -> Self {
282        Self::PkceFlow(v)
283    }
284}
285
286impl OAuthGrant {
287    /// Request a new access token from the given issuer using this grant type and payload.
288    async fn request_access_token(
289        &mut self,
290        auth_server: &AuthServer,
291    ) -> Result<SecretAccessToken, TokenError> {
292        match self {
293            Self::RefreshToken(tokens) => tokens.request_access_token(auth_server).await,
294            Self::ClientCredentials(tokens) => tokens.request_access_token(auth_server).await,
295            Self::ExternallyManaged(tokens) => tokens
296                .request_access_token(auth_server)
297                .await
298                .map_err(|e| TokenError::ExternallyManaged(e.to_string())),
299            Self::PkceFlow(tokens) => tokens.request_access_token(auth_server).await,
300        }
301    }
302}
303
304impl std::fmt::Debug for OAuthGrant {
305    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
306        match self {
307            Self::RefreshToken(_) => f.write_str("RefreshToken"),
308            Self::ClientCredentials(_) => f.write_str("ClientCredentials"),
309            Self::ExternallyManaged(_) => f.write_str("ExternallyManaged"),
310            Self::PkceFlow(_) => f.write_str("PkceTokens"),
311        }
312    }
313}
314
315/// Manages the `OAuth2` authorization process and token lifecycle for accessing the QCS API.
316///
317/// This struct encapsulates the necessary information to request an access token
318/// from an authorization server, including the `OAuth2` grant type and any associated
319/// credentials or payload data.
320///
321/// # Fields
322///
323/// * `payload` - The `OAuth2` grant type and associated data that will be used to request an access token.
324/// * `access_token` - The access token currently in use, if any. If no token has been provided or requested yet, this will be `None`.
325/// * `auth_server` - The authorization server responsible for issuing tokens.
326#[derive(Clone)]
327#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
328#[cfg_attr(
329    feature = "python",
330    pyo3::pyclass(module = "qcs_api_client_common.configuration", frozen, get_all)
331)]
332pub struct OAuthSession {
333    /// The grant type to use to request an access token.
334    payload: OAuthGrant,
335    /// The access token that is currently in use. None if no token has been requested yet.
336    access_token: Option<SecretAccessToken>,
337    /// The [`AuthServer`] that issues the tokens.
338    auth_server: AuthServer,
339}
340
341impl OAuthSession {
342    /// Initialize a new set of [`Credentials`] using a [`GrantPayload`].
343    ///
344    /// Optionally include an `access_token`, if not included, then one can be requested
345    /// with [`Self::request_access_token`].
346    #[must_use]
347    pub const fn new(
348        payload: OAuthGrant,
349        auth_server: AuthServer,
350        access_token: Option<SecretAccessToken>,
351    ) -> Self {
352        Self {
353            payload,
354            access_token,
355            auth_server,
356        }
357    }
358
359    /// Initialize a new set of [`Credentials`] using an [`ExternallyManaged`].
360    ///
361    /// Optionally include an `access_token`, if not included, then one can be requested
362    /// with [`Self::request_access_token`].
363    #[must_use]
364    pub const fn from_externally_managed(
365        tokens: ExternallyManaged,
366        auth_server: AuthServer,
367        access_token: Option<SecretAccessToken>,
368    ) -> Self {
369        Self::new(
370            OAuthGrant::ExternallyManaged(tokens),
371            auth_server,
372            access_token,
373        )
374    }
375
376    /// Initialize a new set of [`Credentials`] using a [`RefreshToken`].
377    ///
378    /// Optionally include an `access_token`, if not included, then one can be requested
379    /// with [`Self::request_access_token`].
380    #[must_use]
381    pub const fn from_refresh_token(
382        tokens: RefreshToken,
383        auth_server: AuthServer,
384        access_token: Option<SecretAccessToken>,
385    ) -> Self {
386        Self::new(OAuthGrant::RefreshToken(tokens), auth_server, access_token)
387    }
388
389    /// Initialize a new set of [`Credentials`] using [`ClientCredentials`].
390    ///
391    /// Optionally include an `access_token`, if not included, then one can be requested
392    /// with [`Self::request_access_token`].
393    #[must_use]
394    pub const fn from_client_credentials(
395        tokens: ClientCredentials,
396        auth_server: AuthServer,
397        access_token: Option<SecretAccessToken>,
398    ) -> Self {
399        Self::new(
400            OAuthGrant::ClientCredentials(tokens),
401            auth_server,
402            access_token,
403        )
404    }
405
406    /// Initialize a new set of [`Credentials`] using [`PkceFlow`].
407    ///
408    /// Optionally include an `access_token`, if not included, then one can be requested
409    /// with [`Self::request_access_token`].
410    #[must_use]
411    pub const fn from_pkce_flow(
412        flow: PkceFlow,
413        auth_server: AuthServer,
414        access_token: Option<SecretAccessToken>,
415    ) -> Self {
416        Self::new(OAuthGrant::PkceFlow(flow), auth_server, access_token)
417    }
418
419    /// Get the current access token.
420    ///
421    /// This is an unvalidated copy of the access token. Meaning it can become stale, or may
422    /// even be already be stale. See [`Self::validate`] and [`Self::request_access_token`].
423    ///
424    /// # Errors
425    ///
426    /// - [`TokenError::NoAccessToken`] if there is no access token
427    pub fn access_token(&self) -> Result<&SecretAccessToken, TokenError> {
428        self.access_token.as_ref().ok_or(TokenError::NoAccessToken)
429    }
430
431    /// Get the payload used to request an access token.
432    #[must_use]
433    pub const fn payload(&self) -> &OAuthGrant {
434        &self.payload
435    }
436
437    /// Request and return an updated access token using these credentials.
438    ///
439    /// # Errors
440    ///
441    /// See [`TokenError`]
442    #[allow(clippy::missing_panics_doc)]
443    pub async fn request_access_token(&mut self) -> Result<&SecretAccessToken, TokenError> {
444        let access_token = self.payload.request_access_token(&self.auth_server).await?;
445        Ok(self.access_token.insert(access_token))
446    }
447
448    /// The [`AuthServer`] that issues the tokens.
449    #[must_use]
450    pub const fn auth_server(&self) -> &AuthServer {
451        &self.auth_server
452    }
453
454    /// Validate the access token, returning it if it is valid, or an error describing why it is
455    /// invalid.
456    ///
457    /// # Errors
458    ///
459    /// - [`TokenError::NoAccessToken`] if an access token has not been requested.
460    /// - [`TokenError::InvalidAccessToken`] if the access token is invalid.
461    pub fn validate(&self) -> Result<SecretAccessToken, TokenError> {
462        let access_token = self.access_token()?;
463        insecure_validate_token_exp(access_token)?;
464        Ok(access_token.clone())
465    }
466}
467
468/// Validates the access token's format and `exp` claim, but no other claims or
469/// signature. We do this only to determine if the token is expired and needs refreshing,
470/// there is no way to securely validate the token's signature on the client side.
471pub(crate) fn insecure_validate_token_exp(
472    access_token: &SecretAccessToken,
473) -> Result<(), TokenError> {
474    let placeholder_key = DecodingKey::from_secret(&[]);
475    let mut validation = Validation::new(Algorithm::RS256);
476    validation.validate_exp = true;
477    validation.leeway = 60;
478    validation.validate_aud = false;
479    validation.insecure_disable_signature_validation();
480
481    jsonwebtoken::decode::<toml::Value>(access_token.secret(), &placeholder_key, &validation)
482        .map(|_| ())
483        .map_err(TokenError::InvalidAccessToken)
484}
485
486impl std::fmt::Debug for OAuthSession {
487    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
488        let token_populated = if self.access_token.is_some() {
489            Some(())
490        } else {
491            None
492        };
493        f.debug_struct("OAuthSession")
494            .field("payload", &self.payload)
495            .field("access_token", &token_populated)
496            .field("auth_server", &self.auth_server)
497            .finish()
498    }
499}
500
501/// A wrapper for [`OAuthSession`] that provides thread-safe access to the inner tokens.
502#[derive(Clone, Debug)]
503#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
504#[cfg_attr(
505    feature = "python",
506    pyo3::pyclass(module = "qcs_api_client_common.configuration")
507)]
508pub struct TokenDispatcher {
509    lock: Arc<RwLock<OAuthSession>>,
510    refreshing: Arc<Mutex<bool>>,
511    notify_refreshed: Arc<Notify>,
512}
513
514impl From<OAuthSession> for TokenDispatcher {
515    fn from(value: OAuthSession) -> Self {
516        Self {
517            lock: Arc::new(RwLock::new(value)),
518            refreshing: Arc::new(Mutex::new(false)),
519            notify_refreshed: Arc::new(Notify::new()),
520        }
521    }
522}
523
524impl TokenDispatcher {
525    /// Executes a user-provided closure on a reference to the `Tokens` instance managed by the
526    /// dispatcher.
527    ///
528    /// This function locks the mutex, safely exposing the protected `Tokens` instance to the provided closure `f`.
529    /// It is designed to allow safe and controlled access to the `Tokens` instance for reading its state.
530    ///
531    /// # Parameters
532    /// - `f`: A closure that takes a reference to `Tokens` and returns a value of type `O`. The closure is called
533    ///   with the `Tokens` instance as an argument once the mutex is successfully locked.
534    pub async fn use_tokens<F, O>(&self, f: F) -> O
535    where
536        F: FnOnce(&OAuthSession) -> O + Send,
537    {
538        let tokens = self.lock.read().await;
539        f(&tokens)
540    }
541
542    /// Get a copy of the current access token.
543    #[must_use]
544    pub async fn tokens(&self) -> OAuthSession {
545        self.use_tokens(Clone::clone).await
546    }
547
548    /// Refreshes the tokens. Readers will be blocked until the refresh is complete.
549    ///
550    /// # Errors
551    ///
552    /// See [`TokenError`]
553    pub async fn refresh(
554        &self,
555        source: &ConfigSource,
556        profile: &str,
557    ) -> Result<OAuthSession, TokenError> {
558        self.managed_refresh(Self::perform_refresh, source, profile)
559            .await
560    }
561
562    /// Validate the access token, returning it if it is valid, or an error describing why it is
563    /// invalid.
564    ///
565    /// # Errors
566    ///
567    /// - [`TokenError::NoAccessToken`] if there is no access token
568    /// - [`TokenError::InvalidAccessToken`] if the access token is invalid
569    pub async fn validate(&self) -> Result<SecretAccessToken, TokenError> {
570        self.use_tokens(OAuthSession::validate).await
571    }
572
573    /// If tokens are already being refreshed, wait and return the updated tokens. Otherwise, run
574    /// ``refresh_fn``.
575    async fn managed_refresh<F, Fut>(
576        &self,
577        refresh_fn: F,
578        source: &ConfigSource,
579        profile: &str,
580    ) -> Result<OAuthSession, TokenError>
581    where
582        F: FnOnce(Arc<RwLock<OAuthSession>>) -> Fut + Send,
583        Fut: Future<Output = Result<OAuthSession, TokenError>> + Send,
584    {
585        let mut is_refreshing = self.refreshing.lock().await;
586
587        if *is_refreshing {
588            drop(is_refreshing);
589            self.notify_refreshed.notified().await;
590            return Ok(self.tokens().await);
591        }
592
593        *is_refreshing = true;
594        drop(is_refreshing);
595
596        let oauth_session = refresh_fn(self.lock.clone()).await?;
597
598        // If the config source is a file, write the new access token to the file
599        if let ConfigSource::File {
600            settings_path: _,
601            secrets_path,
602        } = source
603        {
604            if !Secrets::is_read_only(secrets_path).await? {
605                // If the payload is a PkceFlow, write the fresh refresh token if available.
606                let refresh_token = match &oauth_session.payload {
607                    OAuthGrant::PkceFlow(payload) => {
608                        payload.refresh_token.as_ref().map(|rt| &rt.refresh_token)
609                    }
610                    _ => None,
611                };
612
613                let now = OffsetDateTime::now_utc();
614                Secrets::write_tokens(
615                    secrets_path,
616                    profile,
617                    refresh_token,
618                    oauth_session.access_token()?,
619                    now,
620                )
621                .await?;
622            }
623        }
624
625        *self.refreshing.lock().await = false;
626        self.notify_refreshed.notify_waiters();
627        Ok(oauth_session)
628    }
629
630    /// Refreshes the tokens. Readers will be blocked until the refresh is complete. Returns a copy
631    /// of the updated [`Credentials`]
632    ///
633    /// # Errors
634    ///
635    /// See [`TokenError`]
636    async fn perform_refresh(lock: Arc<RwLock<OAuthSession>>) -> Result<OAuthSession, TokenError> {
637        let mut credentials = lock.write().await;
638        credentials.request_access_token().await?;
639        Ok(credentials.clone())
640    }
641}
642
643pub(crate) type RefreshResult =
644    Pin<Box<dyn Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>> + Send>>;
645
646/// A function that asynchronously refreshes a token.
647pub type RefreshFunction = Box<dyn (Fn(AuthServer) -> RefreshResult) + Send + Sync>;
648
649/// A struct that manages access tokens by utilizing a user-provided refresh function.
650///
651/// The [`ExternallyManaged`] struct allows users to define custom logic for
652/// fetching or refreshing access tokens.
653#[derive(Clone)]
654#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
655#[cfg_attr(
656    feature = "python",
657    pyo3::pyclass(module = "qcs_api_client_common.configuration")
658)]
659pub struct ExternallyManaged {
660    refresh_function: Arc<RefreshFunction>,
661}
662
663impl ExternallyManaged {
664    /// Creates a new [`ExternallyManaged`] instance from a [`RefreshFunction`].
665    ///
666    /// Consider using [`ExternallyManaged::from_async`], and [`ExternallyManaged::from_sync`], if
667    /// they better fit your use case.
668    ///
669    /// # Arguments
670    ///
671    /// * `refresh_function` - A function or closure that asynchronously refreshes a token.
672    ///
673    /// # Example
674    ///
675    /// ```
676    /// use qcs_api_client_common::configuration::{settings::AuthServer, tokens::ExternallyManaged, TokenError};
677    /// use std::future::Future;
678    /// use std::pin::Pin;
679    /// use std::boxed::Box;
680    /// use std::error::Error;
681    ///
682    /// async fn example_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
683    /// + Send + Sync>> {
684    ///     Ok("new_token_value".to_string())
685    /// }
686    /// let token_manager = ExternallyManaged::new(|auth_server| Box::pin(example_refresh_function(auth_server)));
687    /// ```
688    pub fn new(
689        refresh_function: impl Fn(AuthServer) -> RefreshResult + Send + Sync + 'static,
690    ) -> Self {
691        Self {
692            refresh_function: Arc::new(Box::new(refresh_function)),
693        }
694    }
695
696    /// Constructs a new [`ExternallyManaged`] instance using an async function or closure.
697    ///
698    /// This method simplifies the creation of the [`ExternallyManaged`] instance by handling
699    /// the boxing and pinning of the future internally.
700    ///
701    /// # Arguments
702    ///
703    /// * `refresh_function` - An async function or closure that returns a [`Future`] which, when awaited,
704    ///   produces a [`Result<String, TokenError>`].
705    ///
706    /// # Example
707    ///
708    /// ```
709    /// use qcs_api_client_common::configuration::{settings::AuthServer, tokens::ExternallyManaged, TokenError};
710    /// use tokio::runtime::Runtime;
711    /// use std::error::Error;
712    ///
713    /// async fn example_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
714    /// + Send + Sync>> {
715    ///     Ok("new_token_value".to_string())
716    /// }
717    ///
718    /// let token_manager = ExternallyManaged::from_async(example_refresh_function);
719    ///
720    /// let rt = Runtime::new().unwrap();
721    /// rt.block_on(async {
722    ///     match token_manager.request_access_token(&AuthServer::default()).await {
723    ///         Ok(token) => println!("Token: {token:?}"),
724    ///         Err(e) => println!("Failed to refresh token: {:?}", e),
725    ///     }
726    /// });
727    /// ```
728    pub fn from_async<F, Fut>(refresh_function: F) -> Self
729    where
730        F: Fn(AuthServer) -> Fut + Send + Sync + 'static,
731        Fut: Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>>
732            + Send
733            + 'static,
734    {
735        Self {
736            refresh_function: Arc::new(Box::new(move |auth_server| {
737                Box::pin(refresh_function(auth_server))
738            })),
739        }
740    }
741
742    /// Constructs a new [`ExternallyManaged`] instance using a synchronous function.
743    ///
744    /// The synchronous function is wrapped in an async block to fit the expected signature.
745    ///
746    /// # Arguments
747    ///
748    /// * `refresh_function` - A synchronous function that returns a [`Result<String, TokenError>`].
749    ///
750    /// # Example
751    ///
752    /// ```
753    /// use qcs_api_client_common::configuration::{settings::AuthServer, tokens::ExternallyManaged, TokenError};
754    /// use tokio::runtime::Runtime;
755    /// use std::error::Error;
756    ///
757    /// fn example_sync_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
758    /// + Send + Sync>> {
759    ///     Ok("sync_token_value".to_string())
760    /// }
761    ///
762    /// let token_manager = ExternallyManaged::from_sync(example_sync_refresh_function);
763    ///
764    /// let rt = Runtime::new().unwrap();
765    /// rt.block_on(async {
766    ///     match token_manager.request_access_token(&AuthServer::default()).await {
767    ///         Ok(token) => println!("Token: {token:?}"),
768    ///         Err(e) => println!("Failed to refresh token: {:?}", e),
769    ///     }
770    /// });
771    /// ```
772    pub fn from_sync(
773        refresh_function: impl Fn(AuthServer) -> Result<String, Box<dyn std::error::Error + Send + Sync>>
774            + Send
775            + Sync
776            + 'static,
777    ) -> Self {
778        Self {
779            refresh_function: Arc::new(Box::new(move |auth_server| {
780                let result = refresh_function(auth_server);
781                Box::pin(async move { result })
782            })),
783        }
784    }
785
786    /// Request an updated access token using the provided refresh function.
787    ///
788    /// # Errors
789    ///
790    /// Errors are propagated from the refresh function.
791    pub async fn request_access_token(
792        &self,
793        auth_server: &AuthServer,
794    ) -> Result<SecretAccessToken, Box<dyn std::error::Error + Send + Sync>> {
795        (self.refresh_function)(auth_server.clone())
796            .await
797            .map(SecretAccessToken::from)
798    }
799}
800
801impl std::fmt::Debug for ExternallyManaged {
802    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
803        f.debug_struct("ExternallyManaged")
804            .field(
805                "refresh_function",
806                &"Fn() -> Pin<Box<dyn Future<Output = Result<String, TokenError>> + Send>>",
807            )
808            .finish()
809    }
810}
811
812#[derive(Debug, Serialize, Deserialize)]
813pub(super) struct TokenRefreshRequest<'a> {
814    grant_type: &'static str,
815    client_id: &'a str,
816    refresh_token: &'a str,
817}
818
819impl<'a> TokenRefreshRequest<'a> {
820    pub(super) const fn new(client_id: &'a str, refresh_token: &'a str) -> Self {
821        Self {
822            grant_type: "refresh_token",
823            client_id,
824            refresh_token,
825        }
826    }
827}
828
829#[derive(Debug, Serialize, Deserialize)]
830pub(super) struct ClientCredentialsRequest {
831    grant_type: &'static str,
832    scope: Option<&'static str>,
833}
834
835impl ClientCredentialsRequest {
836    pub(super) const fn new(scope: Option<&'static str>) -> Self {
837        Self {
838            grant_type: "client_credentials",
839            scope,
840        }
841    }
842}
843
844#[derive(Deserialize, Debug, Serialize)]
845pub(super) struct RefreshTokenResponse {
846    pub(super) refresh_token: Option<SecretRefreshToken>,
847    pub(super) access_token: SecretAccessToken,
848}
849
/// Get and refresh access tokens
///
/// Implementors must also be [`Clone`], [`std::fmt::Debug`] and [`Send`].
#[async_trait::async_trait]
pub trait TokenRefresher: Clone + std::fmt::Debug + Send {
    /// The type to be returned in the event of an error during getting or
    /// refreshing an access token
    type Error;

    /// Get and validate the current access token, refreshing it if it doesn't exist or is invalid.
    async fn validated_access_token(&self) -> Result<SecretAccessToken, Self::Error>;

    /// Get the current access token, if any
    async fn get_access_token(&self) -> Result<Option<SecretAccessToken>, Self::Error>;

    /// Get a fresh access token
    async fn refresh_access_token(&self) -> Result<SecretAccessToken, Self::Error>;

    /// Get the base URL for requests
    ///
    /// Only available when the `tracing` feature is enabled.
    #[cfg(feature = "tracing")]
    fn base_url(&self) -> &str;

    /// Get the tracing configuration
    ///
    /// Only available when the `tracing-config` feature is enabled.
    #[cfg(feature = "tracing-config")]
    fn tracing_configuration(&self) -> Option<&TracingConfiguration>;

    /// Returns whether the given URL should be traced. Following
    /// [`TracingConfiguration::is_enabled`], this defaults to `true`.
    ///
    /// Without the `tracing-config` feature the answer is always `true`. With it, the answer is
    /// `true` when [`Self::tracing_configuration`] returns `None`, and otherwise whatever the
    /// configuration's `is_enabled` reports for `url`.
    #[cfg(feature = "tracing")]
    // The explicit `return` in the cfg-gated block below is deliberate.
    #[allow(clippy::needless_return)]
    fn should_trace(&self, url: &UrlPatternMatchInput) -> bool {
        #[cfg(not(feature = "tracing-config"))]
        {
            // `url` is not consulted in this configuration; bind it so it still counts as used.
            let _ = url;
            return true;
        }

        // No configuration means "trace everything".
        #[cfg(feature = "tracing-config")]
        self.tracing_configuration()
            .is_none_or(|config| config.is_enabled(url))
    }
}
890
891#[async_trait::async_trait]
892impl TokenRefresher for ClientConfiguration {
893    type Error = TokenError;
894
895    async fn validated_access_token(&self) -> Result<SecretAccessToken, Self::Error> {
896        self.get_bearer_access_token().await
897    }
898
899    async fn refresh_access_token(&self) -> Result<SecretAccessToken, Self::Error> {
900        Ok(self.refresh().await?.access_token()?.clone())
901    }
902
903    async fn get_access_token(&self) -> Result<Option<SecretAccessToken>, Self::Error> {
904        Ok(Some(self.oauth_session().await?.access_token()?.clone()))
905    }
906
907    #[cfg(feature = "tracing")]
908    fn base_url(&self) -> &str {
909        &self.grpc_api_url
910    }
911
912    #[cfg(feature = "tracing-config")]
913    fn tracing_configuration(&self) -> Option<&TracingConfiguration> {
914        self.tracing_configuration.as_ref()
915    }
916}
917
918/// Get a default http client.
919pub(super) fn default_http_client() -> Result<reqwest::Client, reqwest::Error> {
920    reqwest::Client::builder()
921        .timeout(std::time::Duration::from_secs(10))
922        .build()
923}
924
925#[cfg(test)]
926mod test {
927    #![allow(clippy::result_large_err, reason = "happens in figment tests")]
928
929    use std::time::Duration;
930
931    use super::*;
932    use httpmock::prelude::*;
933    use rstest::rstest;
934    use time::format_description::well_known::Rfc3339;
935    use tokio::time::Instant;
936    use toml_edit::DocumentMut;
937
938    #[tokio::test]
939    async fn test_tokens_blocked_during_refresh() {
940        let mock_server = MockServer::start_async().await;
941
942        let oidc_mock = mock_server
943            .mock_async(|when, then| {
944                when.method(GET).path("/.well-known/openid-configuration");
945                then.status(200)
946                    .json_body_obj(&oidc::Discovery::new_for_test(
947                        mock_server.base_url().parse().unwrap(),
948                    ));
949            })
950            .await;
951
952        let issuer_mock = mock_server
953            .mock_async(|when, then| {
954                when.method(POST).path("/v1/token");
955
956                then.status(200)
957                    .delay(Duration::from_secs(3))
958                    .json_body_obj(&RefreshTokenResponse {
959                        access_token: SecretAccessToken::from("new_access"),
960                        refresh_token: Some(SecretRefreshToken::from("new_refresh")),
961                    });
962            })
963            .await;
964
965        let original_tokens = OAuthSession::from_refresh_token(
966            RefreshToken::new(SecretRefreshToken::from("refresh")),
967            AuthServer {
968                client_id: "client_id".to_string(),
969                issuer: mock_server.base_url(),
970                scopes: None,
971            },
972            None,
973        );
974        let dispatcher: TokenDispatcher = original_tokens.clone().into();
975        let dispatcher_clone1 = dispatcher.clone();
976        let dispatcher_clone2 = dispatcher.clone();
977
978        let refresh_duration = Duration::from_secs(3);
979
980        let start_write = Instant::now();
981        let write_future = tokio::spawn(async move {
982            dispatcher_clone1
983                .refresh(&ConfigSource::Default, "")
984                .await
985                .unwrap()
986        });
987
988        let start_read = Instant::now();
989        let read_future = tokio::spawn(async move { dispatcher_clone2.tokens().await });
990
991        let _ = write_future.await.unwrap();
992        let read_result = read_future.await.unwrap();
993
994        let write_duration = start_write.elapsed();
995        let read_duration = start_read.elapsed();
996
997        oidc_mock.assert_async().await;
998        issuer_mock.assert_async().await;
999
1000        assert!(
1001            write_duration >= refresh_duration,
1002            "Write operation did not take enough time"
1003        );
1004        assert!(
1005            read_duration >= refresh_duration,
1006            "Read operation was not blocked by the write operation"
1007        );
1008        assert_eq!(
1009            read_result.access_token.unwrap(),
1010            SecretAccessToken::from("new_access")
1011        );
1012        if let OAuthGrant::RefreshToken(payload) = read_result.payload {
1013            assert_eq!(
1014                payload.refresh_token,
1015                SecretRefreshToken::from("new_refresh")
1016            );
1017        } else {
1018            panic!(
1019                "Expected RefreshToken payload, got {:?}",
1020                read_result.payload
1021            );
1022        }
1023    }
1024
1025    #[rstest]
1026    fn test_qcs_secrets_readonly(
1027        #[values(
1028            (Some("TRUE"), true),
1029            (Some("tRue"), true),
1030            (Some("true"), true),
1031            (Some("YES"), true),
1032            (Some("yEs"), true),
1033            (Some("yes"), true),
1034            (Some("1"), true),
1035            (Some("2"), false),
1036            (Some("other"), false),
1037            (Some(""), false),
1038            (None, false),
1039        )]
1040        read_only_values: (Option<&str>, bool),
1041        #[values(true, false)] read_only_perm: bool,
1042    ) {
1043        let (maybe_read_only_env, env_is_read_only) = read_only_values;
1044        let expected_update = !env_is_read_only && !read_only_perm;
1045        figment::Jail::expect_with(|jail| {
1046            let profile_name = "test";
1047            let initial_access_token = "initial_access_token";
1048            let initial_refresh_token = "initial_refresh_token";
1049
1050            let initial_secrets_file_contents = format!(
1051                r#"
1052[credentials]
1053[credentials.{profile_name}]
1054[credentials.{profile_name}.token_payload]
1055access_token = "{initial_access_token}"
1056expires_in = 3600
1057id_token = "id_token"
1058refresh_token = "{initial_refresh_token}"
1059scope = "offline_access openid profile email"
1060token_type = "Bearer"
1061updated_at = "2024-01-01T00:00:00Z"
1062"#
1063            );
1064
1065            // Ignore any existing environment variables.
1066            jail.clear_env();
1067
1068            // Create a temporary secrets file
1069            let secrets_path = "secrets.toml";
1070            jail.create_file(secrets_path, initial_secrets_file_contents.as_str())
1071                .expect("should create test secrets.toml");
1072
1073            if read_only_perm {
1074                let mut permissions = std::fs::metadata(secrets_path)
1075                    .expect("Should be able to get file metadata")
1076                    .permissions();
1077                permissions.set_readonly(true);
1078                std::fs::set_permissions(secrets_path, permissions)
1079                    .expect("Should be able to set file permissions");
1080            }
1081
1082            let rt = tokio::runtime::Runtime::new().unwrap();
1083            rt.block_on(async {
1084                let mock_server = MockServer::start_async().await;
1085
1086                let oidc_mock = mock_server
1087                    .mock_async(|when, then| {
1088                        when.method(GET).path("/.well-known/openid-configuration");
1089                        then.status(200)
1090                            .json_body_obj(&oidc::Discovery::new_for_test(mock_server.base_url().parse().unwrap()));
1091                    })
1092                    .await;
1093
1094                // Set up the mock token endpoint
1095                let new_access_token = SecretAccessToken::from("new_access_token");
1096                let issuer_mock = mock_server
1097                    .mock_async(|when, then| {
1098                        when.method(POST).path("/v1/token");
1099                        then.status(200).json_body_obj(&RefreshTokenResponse {
1100                            access_token: new_access_token.clone(),
1101                            refresh_token: Some(SecretRefreshToken::from(initial_refresh_token)),
1102                        });
1103                    })
1104                    .await;
1105
1106                // Create tokens and dispatcher
1107                let original_tokens = OAuthSession::from_refresh_token(
1108                    RefreshToken::new(SecretRefreshToken::from(initial_refresh_token)),
1109                    AuthServer { client_id: "client_id".to_string(), issuer: mock_server.base_url(), scopes: None },
1110                    Some(SecretAccessToken::from(initial_refresh_token)),
1111                );
1112                let dispatcher: TokenDispatcher = original_tokens.into();
1113
1114                // Test with QCS_SECRETS_READ_ONLY set first
1115                jail.set_env("QCS_SECRETS_FILE_PATH", "secrets.toml");
1116                jail.set_env("QCS_PROFILE_NAME", "test");
1117                if let Some(read_only_env) = maybe_read_only_env {
1118                    jail.set_env("QCS_SECRETS_READ_ONLY", read_only_env);
1119                }
1120
1121                let before_refresh = OffsetDateTime::now_utc();
1122
1123                dispatcher
1124                    .refresh(
1125                        &ConfigSource::File {
1126                            settings_path: "".into(),
1127                            secrets_path: "secrets.toml".into(),
1128                        },
1129                        profile_name,
1130                    )
1131                    .await
1132                    .unwrap();
1133
1134                oidc_mock.assert_async().await;
1135                issuer_mock.assert_async().await;
1136
1137                // Verify the file was not updated if QCS_SECRETS_READ_ONLY is set truthy
1138                let content = std::fs::read_to_string("secrets.toml").unwrap();
1139                if !expected_update {
1140                    assert!(
1141                        content.eq(initial_secrets_file_contents.as_str()),
1142                        "File should not be updated when QCS_SECRETS_READ_ONLY is set or file permissions are read-only"
1143                    );
1144                    return;
1145                }
1146
1147                // Verify the file was updated
1148                let mut toml = std::fs::read_to_string(secrets_path)
1149                    .unwrap()
1150                    .parse::<DocumentMut>()
1151                    .unwrap();
1152
1153                let token_payload = toml
1154                    .get_mut("credentials")
1155                    .and_then(|credentials| {
1156                        credentials.get_mut(profile_name)?.get_mut("token_payload")
1157                    })
1158                    .expect("Should be able to get token_payload table");
1159
1160                let access_token = token_payload.get("access_token").unwrap().as_str().map(str::to_string).map(SecretAccessToken::from);
1161
1162                assert_eq!(
1163                    access_token,
1164                    Some(new_access_token)
1165                );
1166
1167                assert!(
1168                    OffsetDateTime::parse(
1169                        token_payload.get("updated_at").unwrap().as_str().unwrap(),
1170                        &Rfc3339
1171                    )
1172                    .unwrap()
1173                        > before_refresh
1174                );
1175
1176                let content = std::fs::read_to_string("secrets.toml").unwrap();
1177                assert!(
1178                content.contains("new_access_token"),
1179                "File should be updated with new access token when QCS_SECRETS_READ_ONLY is not set or is set but disabled, and file permissions allow writing"
1180                );
1181            });
1182            Ok(())
1183        });
1184    }
1185
1186    #[test]
1187    fn test_auth_session_debug_fmt() {
1188        let session = OAuthSession {
1189            payload: OAuthGrant::ClientCredentials(ClientCredentials::new(
1190                "hidden_id",
1191                "hidden_secret",
1192            )),
1193            access_token: Some(SecretAccessToken::from("token")),
1194            auth_server: AuthServer {
1195                client_id: "some_id".into(),
1196                issuer: "some_url".into(),
1197                scopes: None,
1198            },
1199        };
1200
1201        assert_eq!("OAuthSession { payload: ClientCredentials, access_token: Some(()), auth_server: AuthServer { client_id: \"some_id\", issuer: \"some_url\", scopes: None } }", &format!("{session:?}"));
1202    }
1203}