Skip to main content

qcs_api_client_common/configuration/
tokens.rs

1//! Models and utilities for managing `OAuth2` sessions.
2use std::{pin::Pin, sync::Arc};
3
4use futures::Future;
5use jsonwebtoken::{Algorithm, DecodingKey, Validation};
6use oauth2::TokenResponse;
7use serde::{Deserialize, Serialize};
8use time::OffsetDateTime;
9use tokio::sync::{Mutex, Notify, RwLock};
10use tokio_util::sync::CancellationToken;
11
12#[cfg(feature = "stubs")]
13use pyo3_stub_gen::derive::gen_stub_pyclass;
14
15use super::{
16    oidc, secrets::Secrets, settings::AuthServer, ClientConfiguration, ConfigSource, TokenError,
17};
18use crate::configuration::{
19    error::DiscoveryError,
20    pkce::{pkce_login, PkceLoginError, PkceLoginRequest},
21    secrets::{Credential, SecretAccessToken, SecretRefreshToken, TokenPayload},
22};
23#[cfg(feature = "tracing-config")]
24use crate::tracing_configuration::TracingConfiguration;
25#[cfg(feature = "tracing")]
26use urlpattern::UrlPatternMatchInput;
27
28pub use super::secret_string::ClientSecret;
29
30/// A single type containing an access token and an associated refresh token.
31#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
32#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
33#[cfg_attr(
34    feature = "python",
35    pyo3::pyclass(eq, get_all, set_all, module = "qcs_api_client_common.configuration")
36)]
37pub struct RefreshToken {
38    /// The token used to refresh the access token.
39    pub refresh_token: SecretRefreshToken,
40}
41
42impl RefreshToken {
43    /// Create a new [`RefreshToken`] with the given refresh token.
44    #[must_use]
45    pub const fn new(refresh_token: SecretRefreshToken) -> Self {
46        Self { refresh_token }
47    }
48
49    /// Request and return a new access token from the given authorization server using this refresh token.
50    ///
51    /// # Errors
52    ///
53    /// See [`TokenError`]
54    pub async fn request_access_token(
55        &mut self,
56        auth_server: &AuthServer,
57    ) -> Result<SecretAccessToken, TokenError> {
58        if self.refresh_token.is_empty() {
59            return Err(TokenError::NoRefreshToken);
60        }
61
62        let client = default_http_client()?;
63        let token_url = oidc::fetch_discovery(&client, &auth_server.issuer)
64            .await?
65            .token_endpoint;
66        let data = TokenRefreshRequest::new(&auth_server.client_id, self.refresh_token.secret());
67        let resp = client.post(token_url).form(&data).send().await?;
68
69        let RefreshTokenResponse {
70            access_token,
71            refresh_token,
72        } = resp.error_for_status()?.json().await?;
73
74        if let Some(refresh_token) = refresh_token {
75            self.refresh_token = refresh_token;
76        }
77        Ok(access_token)
78    }
79}
80
81#[derive(Deserialize, Debug, Serialize)]
82pub(super) struct ClientCredentialsResponse {
83    pub(super) access_token: SecretAccessToken,
84}
85
86/// A pair of Client ID and Client Secret, used to request an OAuth Client Credentials Grant
87#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
88#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
89#[cfg_attr(
90    feature = "python",
91    pyo3::pyclass(eq, get_all, frozen, module = "qcs_api_client_common.configuration")
92)]
93pub struct ClientCredentials {
94    /// The client ID
95    pub client_id: String,
96    /// The client secret.
97    pub client_secret: ClientSecret,
98}
99
100impl ClientCredentials {
101    #[must_use]
102    /// Construct a new [`ClientCredentials`]
103    pub fn new(client_id: impl Into<String>, client_secret: impl Into<ClientSecret>) -> Self {
104        Self {
105            client_id: client_id.into(),
106            client_secret: client_secret.into(),
107        }
108    }
109
110    /// Get the client ID.
111    #[must_use]
112    pub fn client_id(&self) -> &str {
113        &self.client_id
114    }
115
116    /// Get the client secret.
117    #[must_use]
118    pub const fn client_secret(&self) -> &ClientSecret {
119        &self.client_secret
120    }
121
122    /// Request and return an access token from the given auth server using this set of client credentials.
123    ///
124    /// # Errors
125    ///
126    /// See [`TokenError`]
127    pub async fn request_access_token(
128        &self,
129        auth_server: &AuthServer,
130    ) -> Result<SecretAccessToken, TokenError> {
131        let request = ClientCredentialsRequest::new(None);
132        let client = default_http_client()?;
133
134        let url = oidc::fetch_discovery(&client, &auth_server.issuer)
135            .await?
136            .token_endpoint;
137        let ready_to_send = client
138            .post(url)
139            .basic_auth(&auth_server.client_id, Some(&self.client_secret.secret()))
140            .form(&request);
141        let response = ready_to_send.send().await?;
142
143        response.error_for_status_ref()?;
144
145        let ClientCredentialsResponse { access_token } = response.json().await?;
146        Ok(access_token)
147    }
148}
149
150#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
151#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
152#[cfg_attr(
153    feature = "python",
154    pyo3::pyclass(eq, get_all, frozen, module = "qcs_api_client_common.configuration")
155)]
156/// The Access (Bearer) and refresh (if available) tokens from a PKCE login.
157pub struct PkceFlow {
158    /// The access token.
159    pub access_token: SecretAccessToken,
160    /// The refresh token, if available.
161    pub refresh_token: Option<RefreshToken>,
162}
163
164/// Errors that can occur when attempting to perform a PKCE login flow.
165#[derive(Debug, thiserror::Error)]
166pub enum PkceFlowError {
167    /// Error that occurred while performing the PKCE login flow.
168    #[error(transparent)]
169    PkceLogin(#[from] PkceLoginError),
170    /// Error that occurred while fetching the discovery document from the `OAuth2` issuer.
171    #[error(transparent)]
172    Discovery(#[from] DiscoveryError),
173    /// Error that occurred while making http requests.
174    #[error(transparent)]
175    Request(#[from] reqwest::Error),
176}
177
178impl PkceFlow {
179    /// Starts a new PKCE login flow to acquire a new set of tokens.
180    ///
181    /// # Errors
182    ///
183    /// See [`PkceFlowError`]
184    pub async fn new_login_flow(
185        cancel_token: CancellationToken,
186        auth_server: &AuthServer,
187    ) -> Result<Self, PkceFlowError> {
188        let issuer = auth_server.issuer.clone();
189
190        let client = default_http_client()?;
191        let discovery = oidc::fetch_discovery(&client, &issuer).await?;
192
193        let response = pkce_login(
194            cancel_token,
195            PkceLoginRequest {
196                client_id: auth_server.client_id.clone(),
197                redirect_port: None,
198                discovery,
199                scopes: auth_server.scopes.clone(),
200            },
201        )
202        .await?;
203
204        Ok(Self {
205            access_token: SecretAccessToken::from(response.access_token().secret().clone()),
206            refresh_token: response
207                .refresh_token()
208                .map(|rt| RefreshToken::new(SecretRefreshToken::from(rt.secret().clone()))),
209        })
210    }
211
212    /// Returns the access token if it is valid, otherwise requests a new access token using the refresh token if available.
213    ///
214    /// # Errors
215    ///
216    /// See [`TokenError`]
217    pub async fn request_access_token(
218        &mut self,
219        auth_server: &AuthServer,
220    ) -> Result<SecretAccessToken, TokenError> {
221        if insecure_validate_token_exp(&self.access_token).is_ok() {
222            return Ok(self.access_token.clone());
223        }
224
225        if let Some(refresh_token) = &mut self.refresh_token {
226            let access_token = refresh_token.request_access_token(auth_server).await?;
227            self.access_token.clone_from(&access_token);
228            return Ok(access_token);
229        }
230
231        Err(TokenError::NoRefreshToken)
232    }
233}
234
235impl From<PkceFlow> for Credential {
236    fn from(value: PkceFlow) -> Self {
237        let mut token_payload = TokenPayload::default();
238        token_payload.access_token = Some(value.access_token);
239        token_payload.refresh_token = value.refresh_token.map(|rt| rt.refresh_token);
240
241        Self {
242            token_payload: Some(token_payload),
243        }
244    }
245}
246
247#[derive(Clone)]
248#[cfg_attr(feature = "python", derive(pyo3::FromPyObject, pyo3::IntoPyObject))]
249/// Specifies the [OAuth2 grant type](https://oauth.net/2/grant-types/) to use, along with the data
250/// needed to request said grant type.
251pub enum OAuthGrant {
252    /// Credentials that can be used to use with the [Refresh Token grant type](https://oauth.net/2/grant-types/refresh-token/).
253    RefreshToken(RefreshToken),
254    /// Payload that can be used to use the [Client Credentials grant type](https://oauth.net/2/grant-types/client-credentials/).
255    ClientCredentials(ClientCredentials),
256    /// Defers to a user provided function for access token requests.
257    ExternallyManaged(ExternallyManaged),
258    /// The tokens returned by the PKCE login that are an [Authorization Code grant type](https://oauth.net/2/pkce/).
259    PkceFlow(PkceFlow),
260}
261
262impl From<ExternallyManaged> for OAuthGrant {
263    fn from(v: ExternallyManaged) -> Self {
264        Self::ExternallyManaged(v)
265    }
266}
267
268impl From<ClientCredentials> for OAuthGrant {
269    fn from(v: ClientCredentials) -> Self {
270        Self::ClientCredentials(v)
271    }
272}
273
274impl From<RefreshToken> for OAuthGrant {
275    fn from(v: RefreshToken) -> Self {
276        Self::RefreshToken(v)
277    }
278}
279
280impl From<PkceFlow> for OAuthGrant {
281    fn from(v: PkceFlow) -> Self {
282        Self::PkceFlow(v)
283    }
284}
285
286impl OAuthGrant {
287    /// Request a new access token from the given issuer using this grant type and payload.
288    async fn request_access_token(
289        &mut self,
290        auth_server: &AuthServer,
291    ) -> Result<SecretAccessToken, TokenError> {
292        match self {
293            Self::RefreshToken(tokens) => tokens.request_access_token(auth_server).await,
294            Self::ClientCredentials(tokens) => tokens.request_access_token(auth_server).await,
295            Self::ExternallyManaged(tokens) => tokens
296                .request_access_token(auth_server)
297                .await
298                .map_err(|e| TokenError::ExternallyManaged(e.to_string())),
299            Self::PkceFlow(tokens) => tokens.request_access_token(auth_server).await,
300        }
301    }
302}
303
304impl std::fmt::Debug for OAuthGrant {
305    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
306        match self {
307            Self::RefreshToken(_) => f.write_str("RefreshToken"),
308            Self::ClientCredentials(_) => f.write_str("ClientCredentials"),
309            Self::ExternallyManaged(_) => f.write_str("ExternallyManaged"),
310            Self::PkceFlow(_) => f.write_str("PkceTokens"),
311        }
312    }
313}
314
315/// Manages the `OAuth2` authorization process and token lifecycle for accessing the QCS API.
316///
317/// This struct encapsulates the necessary information to request an access token
318/// from an authorization server, including the `OAuth2` grant type and any associated
319/// credentials or payload data.
320///
321/// # Fields
322///
323/// * `payload` - The `OAuth2` grant type and associated data that will be used to request an access token.
324/// * `access_token` - The access token currently in use, if any. If no token has been provided or requested yet, this will be `None`.
325/// * `auth_server` - The authorization server responsible for issuing tokens.
326#[derive(Clone)]
327#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
328#[cfg_attr(
329    feature = "python",
330    pyo3::pyclass(module = "qcs_api_client_common.configuration", frozen, get_all)
331)]
332pub struct OAuthSession {
333    /// The grant type to use to request an access token.
334    payload: OAuthGrant,
335    /// The access token that is currently in use. None if no token has been requested yet.
336    access_token: Option<SecretAccessToken>,
337    /// The [`AuthServer`] that issues the tokens.
338    auth_server: AuthServer,
339}
340
341impl OAuthSession {
342    /// Initialize a new set of [`Credentials`] using a [`GrantPayload`].
343    ///
344    /// Optionally include an `access_token`, if not included, then one can be requested
345    /// with [`Self::request_access_token`].
346    #[must_use]
347    pub const fn new(
348        payload: OAuthGrant,
349        auth_server: AuthServer,
350        access_token: Option<SecretAccessToken>,
351    ) -> Self {
352        Self {
353            payload,
354            access_token,
355            auth_server,
356        }
357    }
358
359    /// Initialize a new set of [`Credentials`] using an [`ExternallyManaged`].
360    ///
361    /// Optionally include an `access_token`, if not included, then one can be requested
362    /// with [`Self::request_access_token`].
363    #[must_use]
364    pub const fn from_externally_managed(
365        tokens: ExternallyManaged,
366        auth_server: AuthServer,
367        access_token: Option<SecretAccessToken>,
368    ) -> Self {
369        Self::new(
370            OAuthGrant::ExternallyManaged(tokens),
371            auth_server,
372            access_token,
373        )
374    }
375
376    /// Initialize a new set of [`Credentials`] using a [`RefreshToken`].
377    ///
378    /// Optionally include an `access_token`, if not included, then one can be requested
379    /// with [`Self::request_access_token`].
380    #[must_use]
381    pub const fn from_refresh_token(
382        tokens: RefreshToken,
383        auth_server: AuthServer,
384        access_token: Option<SecretAccessToken>,
385    ) -> Self {
386        Self::new(OAuthGrant::RefreshToken(tokens), auth_server, access_token)
387    }
388
389    /// Initialize a new set of [`Credentials`] using [`ClientCredentials`].
390    ///
391    /// Optionally include an `access_token`, if not included, then one can be requested
392    /// with [`Self::request_access_token`].
393    #[must_use]
394    pub const fn from_client_credentials(
395        tokens: ClientCredentials,
396        auth_server: AuthServer,
397        access_token: Option<SecretAccessToken>,
398    ) -> Self {
399        Self::new(
400            OAuthGrant::ClientCredentials(tokens),
401            auth_server,
402            access_token,
403        )
404    }
405
406    /// Initialize a new set of [`Credentials`] using [`PkceFlow`].
407    ///
408    /// Optionally include an `access_token`, if not included, then one can be requested
409    /// with [`Self::request_access_token`].
410    #[must_use]
411    pub const fn from_pkce_flow(
412        flow: PkceFlow,
413        auth_server: AuthServer,
414        access_token: Option<SecretAccessToken>,
415    ) -> Self {
416        Self::new(OAuthGrant::PkceFlow(flow), auth_server, access_token)
417    }
418
419    /// Get the current access token.
420    ///
421    /// This is an unvalidated copy of the access token. Meaning it can become stale, or may
422    /// even be already be stale. See [`Self::validate`] and [`Self::request_access_token`].
423    ///
424    /// # Errors
425    ///
426    /// - [`TokenError::NoAccessToken`] if there is no access token
427    pub fn access_token(&self) -> Result<&SecretAccessToken, TokenError> {
428        self.access_token.as_ref().ok_or(TokenError::NoAccessToken)
429    }
430
431    /// Get the payload used to request an access token.
432    #[must_use]
433    pub const fn payload(&self) -> &OAuthGrant {
434        &self.payload
435    }
436
437    /// Request and return an updated access token using these credentials.
438    ///
439    /// # Errors
440    ///
441    /// See [`TokenError`]
442    #[allow(clippy::missing_panics_doc)]
443    pub async fn request_access_token(&mut self) -> Result<&SecretAccessToken, TokenError> {
444        let access_token = self.payload.request_access_token(&self.auth_server).await?;
445        Ok(self.access_token.insert(access_token))
446    }
447
448    /// The [`AuthServer`] that issues the tokens.
449    #[must_use]
450    pub const fn auth_server(&self) -> &AuthServer {
451        &self.auth_server
452    }
453
454    /// Validate the access token, returning it if it is valid, or an error describing why it is
455    /// invalid.
456    ///
457    /// # Errors
458    ///
459    /// - [`TokenError::NoAccessToken`] if an access token has not been requested.
460    /// - [`TokenError::InvalidAccessToken`] if the access token is invalid.
461    pub fn validate(&self) -> Result<SecretAccessToken, TokenError> {
462        let access_token = self.access_token()?;
463        insecure_validate_token_exp(access_token)?;
464        Ok(access_token.clone())
465    }
466}
467
468/// Validates the access token's format and `exp` claim, but no other claims or
469/// signature. We do this only to determine if the token is expired and needs refreshing,
470/// there is no way to securely validate the token's signature on the client side.
471pub(crate) fn insecure_validate_token_exp(
472    access_token: &SecretAccessToken,
473) -> Result<(), TokenError> {
474    let placeholder_key = DecodingKey::from_secret(&[]);
475    let mut validation = Validation::new(Algorithm::RS256);
476    validation.validate_exp = true;
477    validation.leeway = 60;
478    validation.validate_aud = false;
479    validation.insecure_disable_signature_validation();
480
481    jsonwebtoken::decode::<toml::Value>(access_token.secret(), &placeholder_key, &validation)
482        .map(|_| ())
483        .map_err(TokenError::InvalidAccessToken)
484}
485
486impl std::fmt::Debug for OAuthSession {
487    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
488        let token_populated = if self.access_token.is_some() {
489            Some(())
490        } else {
491            None
492        };
493        f.debug_struct("OAuthSession")
494            .field("payload", &self.payload)
495            .field("access_token", &token_populated)
496            .field("auth_server", &self.auth_server)
497            .finish()
498    }
499}
500
501/// A wrapper for [`OAuthSession`] that provides thread-safe access to the inner tokens.
502#[derive(Clone, Debug)]
503#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
504#[cfg_attr(
505    feature = "python",
506    pyo3::pyclass(module = "qcs_api_client_common.configuration")
507)]
508pub struct TokenDispatcher {
509    lock: Arc<RwLock<OAuthSession>>,
510    refreshing: Arc<Mutex<bool>>,
511    notify_refreshed: Arc<Notify>,
512}
513
514impl From<OAuthSession> for TokenDispatcher {
515    fn from(value: OAuthSession) -> Self {
516        Self {
517            lock: Arc::new(RwLock::new(value)),
518            refreshing: Arc::new(Mutex::new(false)),
519            notify_refreshed: Arc::new(Notify::new()),
520        }
521    }
522}
523
524impl TokenDispatcher {
525    /// Executes a user-provided closure on a reference to the `Tokens` instance managed by the
526    /// dispatcher.
527    ///
528    /// This function locks the mutex, safely exposing the protected `Tokens` instance to the provided closure `f`.
529    /// It is designed to allow safe and controlled access to the `Tokens` instance for reading its state.
530    ///
531    /// # Parameters
532    /// - `f`: A closure that takes a reference to `Tokens` and returns a value of type `O`. The closure is called
533    ///   with the `Tokens` instance as an argument once the mutex is successfully locked.
534    pub async fn use_tokens<F, O>(&self, f: F) -> O
535    where
536        F: FnOnce(&OAuthSession) -> O + Send,
537    {
538        let tokens = self.lock.read().await;
539        f(&tokens)
540    }
541
542    /// Get a copy of the current access token.
543    #[must_use]
544    pub async fn tokens(&self) -> OAuthSession {
545        self.use_tokens(Clone::clone).await
546    }
547
548    /// Refreshes the tokens. Readers will be blocked until the refresh is complete.
549    ///
550    /// # Errors
551    ///
552    /// See [`TokenError`]
553    pub async fn refresh(
554        &self,
555        source: &ConfigSource,
556        profile: &str,
557    ) -> Result<OAuthSession, TokenError> {
558        self.managed_refresh(Self::perform_refresh, source, profile)
559            .await
560    }
561
562    /// Validate the access token, returning it if it is valid, or an error describing why it is
563    /// invalid.
564    ///
565    /// # Errors
566    ///
567    /// - [`TokenError::NoAccessToken`] if there is no access token
568    /// - [`TokenError::InvalidAccessToken`] if the access token is invalid
569    pub async fn validate(&self) -> Result<SecretAccessToken, TokenError> {
570        self.use_tokens(OAuthSession::validate).await
571    }
572
573    /// If tokens are already being refreshed, wait and return the updated tokens. Otherwise, run
574    /// ``refresh_fn``.
575    async fn managed_refresh<F, Fut>(
576        &self,
577        refresh_fn: F,
578        source: &ConfigSource,
579        profile: &str,
580    ) -> Result<OAuthSession, TokenError>
581    where
582        F: FnOnce(Arc<RwLock<OAuthSession>>) -> Fut + Send,
583        Fut: Future<Output = Result<OAuthSession, TokenError>> + Send,
584    {
585        let mut is_refreshing = self.refreshing.lock().await;
586
587        if *is_refreshing {
588            drop(is_refreshing);
589            self.notify_refreshed.notified().await;
590            return Ok(self.tokens().await);
591        }
592
593        *is_refreshing = true;
594        drop(is_refreshing);
595
596        let oauth_session = refresh_fn(self.lock.clone()).await?;
597
598        // If the config source is a file, write the new access token to the file
599        if let ConfigSource::File {
600            settings_path: _,
601            secrets_path,
602        } = source
603        {
604            if !Secrets::is_read_only(secrets_path).await? {
605                // If the payload is a PkceFlow, write the fresh refresh token if available.
606                let refresh_token = match &oauth_session.payload {
607                    OAuthGrant::PkceFlow(payload) => {
608                        payload.refresh_token.as_ref().map(|rt| &rt.refresh_token)
609                    }
610                    _ => None,
611                };
612
613                let now = OffsetDateTime::now_utc();
614                Secrets::write_tokens(
615                    secrets_path,
616                    profile,
617                    refresh_token,
618                    oauth_session.access_token()?,
619                    now,
620                )
621                .await?;
622            }
623        }
624
625        *self.refreshing.lock().await = false;
626        self.notify_refreshed.notify_waiters();
627        Ok(oauth_session)
628    }
629
630    /// Refreshes the tokens. Readers will be blocked until the refresh is complete. Returns a copy
631    /// of the updated [`Credentials`]
632    ///
633    /// # Errors
634    ///
635    /// See [`TokenError`]
636    async fn perform_refresh(lock: Arc<RwLock<OAuthSession>>) -> Result<OAuthSession, TokenError> {
637        let mut credentials = lock.write().await;
638        credentials.request_access_token().await?;
639        Ok(credentials.clone())
640    }
641}
642
643pub(crate) type RefreshResult =
644    Pin<Box<dyn Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>> + Send>>;
645
646/// A function that asynchronously refreshes a token.
647pub type RefreshFunction = Box<dyn (Fn(AuthServer) -> RefreshResult) + Send + Sync>;
648
649/// A struct that manages access tokens by utilizing a user-provided refresh function.
650///
651/// The [`ExternallyManaged`] struct allows users to define custom logic for
652/// fetching or refreshing access tokens.
653#[derive(Clone)]
654#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
655#[cfg_attr(
656    feature = "python",
657    pyo3::pyclass(module = "qcs_api_client_common.configuration")
658)]
659pub struct ExternallyManaged {
660    refresh_function: Arc<RefreshFunction>,
661}
662
663impl ExternallyManaged {
664    /// Creates a new [`ExternallyManaged`] instance from a [`RefreshFunction`].
665    ///
666    /// Consider using [`ExternallyManaged::from_async`], and [`ExternallyManaged::from_sync`], if
667    /// they better fit your use case.
668    ///
669    /// # Arguments
670    ///
671    /// * `refresh_function` - A function or closure that asynchronously refreshes a token.
672    ///
673    /// # Example
674    ///
675    /// ```
676    /// use qcs_api_client_common::configuration::{settings::AuthServer, tokens::ExternallyManaged, TokenError};
677    /// use std::future::Future;
678    /// use std::pin::Pin;
679    /// use std::boxed::Box;
680    /// use std::error::Error;
681    ///
682    /// async fn example_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
683    /// + Send + Sync>> {
684    ///     Ok("new_token_value".to_string())
685    /// }
686    /// let token_manager = ExternallyManaged::new(|auth_server| Box::pin(example_refresh_function(auth_server)));
687    /// ```
688    pub fn new(
689        refresh_function: impl Fn(AuthServer) -> RefreshResult + Send + Sync + 'static,
690    ) -> Self {
691        Self {
692            refresh_function: Arc::new(Box::new(refresh_function)),
693        }
694    }
695
696    /// Constructs a new [`ExternallyManaged`] instance using an async function or closure.
697    ///
698    /// This method simplifies the creation of the [`ExternallyManaged`] instance by handling
699    /// the boxing and pinning of the future internally.
700    ///
701    /// # Arguments
702    ///
703    /// * `refresh_function` - An async function or closure that returns a [`Future`] which, when awaited,
704    ///   produces a [`Result<String, TokenError>`].
705    ///
706    /// # Example
707    ///
708    /// ```
709    /// use qcs_api_client_common::configuration::{settings::AuthServer, tokens::ExternallyManaged, TokenError};
710    /// use tokio::runtime::Runtime;
711    /// use std::error::Error;
712    ///
713    /// async fn example_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
714    /// + Send + Sync>> {
715    ///     Ok("new_token_value".to_string())
716    /// }
717    ///
718    /// let token_manager = ExternallyManaged::from_async(example_refresh_function);
719    ///
720    /// let rt = Runtime::new().unwrap();
721    /// rt.block_on(async {
722    ///     match token_manager.request_access_token(&AuthServer::default()).await {
723    ///         Ok(token) => println!("Token: {token:?}"),
724    ///         Err(e) => println!("Failed to refresh token: {:?}", e),
725    ///     }
726    /// });
727    /// ```
728    pub fn from_async<F, Fut>(refresh_function: F) -> Self
729    where
730        F: Fn(AuthServer) -> Fut + Send + Sync + 'static,
731        Fut: Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>>
732            + Send
733            + 'static,
734    {
735        Self {
736            refresh_function: Arc::new(Box::new(move |auth_server| {
737                Box::pin(refresh_function(auth_server))
738            })),
739        }
740    }
741
742    /// Constructs a new [`ExternallyManaged`] instance using a synchronous function.
743    ///
744    /// The synchronous function is wrapped in an async block to fit the expected signature.
745    ///
746    /// # Arguments
747    ///
748    /// * `refresh_function` - A synchronous function that returns a [`Result<String, TokenError>`].
749    ///
750    /// # Example
751    ///
752    /// ```
753    /// use qcs_api_client_common::configuration::{settings::AuthServer, tokens::ExternallyManaged, TokenError};
754    /// use tokio::runtime::Runtime;
755    /// use std::error::Error;
756    ///
757    /// fn example_sync_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
758    /// + Send + Sync>> {
759    ///     Ok("sync_token_value".to_string())
760    /// }
761    ///
762    /// let token_manager = ExternallyManaged::from_sync(example_sync_refresh_function);
763    ///
764    /// let rt = Runtime::new().unwrap();
765    /// rt.block_on(async {
766    ///     match token_manager.request_access_token(&AuthServer::default()).await {
767    ///         Ok(token) => println!("Token: {token:?}"),
768    ///         Err(e) => println!("Failed to refresh token: {:?}", e),
769    ///     }
770    /// });
771    /// ```
772    pub fn from_sync(
773        refresh_function: impl Fn(AuthServer) -> Result<String, Box<dyn std::error::Error + Send + Sync>>
774            + Send
775            + Sync
776            + 'static,
777    ) -> Self {
778        Self {
779            refresh_function: Arc::new(Box::new(move |auth_server| {
780                let result = refresh_function(auth_server);
781                Box::pin(async move { result })
782            })),
783        }
784    }
785
786    /// Request an updated access token using the provided refresh function.
787    ///
788    /// # Errors
789    ///
790    /// Errors are propagated from the refresh function.
791    pub async fn request_access_token(
792        &self,
793        auth_server: &AuthServer,
794    ) -> Result<SecretAccessToken, Box<dyn std::error::Error + Send + Sync>> {
795        (self.refresh_function)(auth_server.clone())
796            .await
797            .map(SecretAccessToken::from)
798    }
799}
800
801impl std::fmt::Debug for ExternallyManaged {
802    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
803        f.debug_struct("ExternallyManaged")
804            .field(
805                "refresh_function",
806                &"Fn() -> Pin<Box<dyn Future<Output = Result<String, TokenError>> + Send>>",
807            )
808            .finish()
809    }
810}
811
812#[derive(Debug, Serialize, Deserialize)]
813pub(super) struct TokenRefreshRequest<'a> {
814    grant_type: &'static str,
815    client_id: &'a str,
816    refresh_token: &'a str,
817}
818
819impl<'a> TokenRefreshRequest<'a> {
820    pub(super) const fn new(client_id: &'a str, refresh_token: &'a str) -> Self {
821        Self {
822            grant_type: "refresh_token",
823            client_id,
824            refresh_token,
825        }
826    }
827}
828
829#[derive(Debug, Serialize, Deserialize)]
830pub(super) struct ClientCredentialsRequest {
831    grant_type: &'static str,
832    scope: Option<&'static str>,
833}
834
835impl ClientCredentialsRequest {
836    pub(super) const fn new(scope: Option<&'static str>) -> Self {
837        Self {
838            grant_type: "client_credentials",
839            scope,
840        }
841    }
842}
843
844#[derive(Deserialize, Debug, Serialize)]
845pub(super) struct RefreshTokenResponse {
846    pub(super) refresh_token: Option<SecretRefreshToken>,
847    pub(super) access_token: SecretAccessToken,
848}
849
850/// Get and refresh access tokens
851#[async_trait::async_trait]
852pub trait TokenRefresher: Clone + std::fmt::Debug + Send {
853    /// The type to be returned in the event of a error during getting or
854    /// refreshing an access token
855    type Error;
856
857    /// Get and validate the current access token, refreshing it if it doesn't exist or is invalid.
858    async fn validated_access_token(&self) -> Result<SecretAccessToken, Self::Error>;
859
860    /// Get the current access token, if any
861    async fn get_access_token(&self) -> Result<Option<SecretAccessToken>, Self::Error>;
862
863    /// Get a fresh access token
864    async fn refresh_access_token(&self) -> Result<SecretAccessToken, Self::Error>;
865
866    /// Get the base URL for requests
867    #[cfg(feature = "tracing")]
868    fn base_url(&self) -> &str;
869
870    /// Get the tracing configuration
871    #[cfg(feature = "tracing-config")]
872    fn tracing_configuration(&self) -> Option<&TracingConfiguration>;
873
874    /// Returns whether the given URL should be traced. Following
875    /// [`TracingConfiguration::is_enabled`], this defaults to `true`.
876    #[cfg(feature = "tracing")]
877    #[allow(clippy::needless_return)]
878    fn should_trace(&self, url: &UrlPatternMatchInput) -> bool {
879        #[cfg(not(feature = "tracing-config"))]
880        {
881            let _ = url;
882            return true;
883        }
884
885        #[cfg(feature = "tracing-config")]
886        self.tracing_configuration()
887            .is_none_or(|config| config.is_enabled(url))
888    }
889}
890
891#[async_trait::async_trait]
892impl TokenRefresher for ClientConfiguration {
893    type Error = TokenError;
894
895    async fn validated_access_token(&self) -> Result<SecretAccessToken, Self::Error> {
896        self.get_bearer_access_token().await
897    }
898
899    async fn refresh_access_token(&self) -> Result<SecretAccessToken, Self::Error> {
900        Ok(self.refresh().await?.access_token()?.clone())
901    }
902
903    async fn get_access_token(&self) -> Result<Option<SecretAccessToken>, Self::Error> {
904        Ok(Some(self.oauth_session().await?.access_token()?.clone()))
905    }
906
907    #[cfg(feature = "tracing")]
908    fn base_url(&self) -> &str {
909        &self.grpc_api_url
910    }
911
912    #[cfg(feature = "tracing-config")]
913    fn tracing_configuration(&self) -> Option<&TracingConfiguration> {
914        self.tracing_configuration.as_ref()
915    }
916}
917
918/// Get a default http client.
919pub(super) fn default_http_client() -> Result<reqwest::Client, reqwest::Error> {
920    reqwest::Client::builder()
921        .timeout(std::time::Duration::from_secs(10))
922        .build()
923}
924
925#[cfg(test)]
926mod test {
927    use std::time::Duration;
928
929    use super::*;
930    use httpmock::prelude::*;
931    use rstest::rstest;
932    use time::format_description::well_known::Rfc3339;
933    use tokio::time::Instant;
934    use toml_edit::DocumentMut;
935
936    #[tokio::test]
937    async fn test_tokens_blocked_during_refresh() {
938        let mock_server = MockServer::start_async().await;
939
940        let oidc_mock = mock_server
941            .mock_async(|when, then| {
942                when.method(GET).path("/.well-known/openid-configuration");
943                then.status(200)
944                    .json_body_obj(&oidc::Discovery::new_for_test(
945                        mock_server.base_url().parse().unwrap(),
946                    ));
947            })
948            .await;
949
950        let issuer_mock = mock_server
951            .mock_async(|when, then| {
952                when.method(POST).path("/v1/token");
953
954                then.status(200)
955                    .delay(Duration::from_secs(3))
956                    .json_body_obj(&RefreshTokenResponse {
957                        access_token: SecretAccessToken::from("new_access"),
958                        refresh_token: Some(SecretRefreshToken::from("new_refresh")),
959                    });
960            })
961            .await;
962
963        let original_tokens = OAuthSession::from_refresh_token(
964            RefreshToken::new(SecretRefreshToken::from("refresh")),
965            AuthServer {
966                client_id: "client_id".to_string(),
967                issuer: mock_server.base_url(),
968                scopes: None,
969            },
970            None,
971        );
972        let dispatcher: TokenDispatcher = original_tokens.clone().into();
973        let dispatcher_clone1 = dispatcher.clone();
974        let dispatcher_clone2 = dispatcher.clone();
975
976        let refresh_duration = Duration::from_secs(3);
977
978        let start_write = Instant::now();
979        let write_future = tokio::spawn(async move {
980            dispatcher_clone1
981                .refresh(&ConfigSource::Default, "")
982                .await
983                .unwrap()
984        });
985
986        let start_read = Instant::now();
987        let read_future = tokio::spawn(async move { dispatcher_clone2.tokens().await });
988
989        let _ = write_future.await.unwrap();
990        let read_result = read_future.await.unwrap();
991
992        let write_duration = start_write.elapsed();
993        let read_duration = start_read.elapsed();
994
995        oidc_mock.assert_async().await;
996        issuer_mock.assert_async().await;
997
998        assert!(
999            write_duration >= refresh_duration,
1000            "Write operation did not take enough time"
1001        );
1002        assert!(
1003            read_duration >= refresh_duration,
1004            "Read operation was not blocked by the write operation"
1005        );
1006        assert_eq!(
1007            read_result.access_token.unwrap(),
1008            SecretAccessToken::from("new_access")
1009        );
1010        if let OAuthGrant::RefreshToken(payload) = read_result.payload {
1011            assert_eq!(
1012                payload.refresh_token,
1013                SecretRefreshToken::from("new_refresh")
1014            );
1015        } else {
1016            panic!(
1017                "Expected RefreshToken payload, got {:?}",
1018                read_result.payload
1019            );
1020        }
1021    }
1022
1023    #[rstest]
1024    fn test_qcs_secrets_readonly(
1025        #[values(
1026            (Some("TRUE"), true),
1027            (Some("tRue"), true),
1028            (Some("true"), true),
1029            (Some("YES"), true),
1030            (Some("yEs"), true),
1031            (Some("yes"), true),
1032            (Some("1"), true),
1033            (Some("2"), false),
1034            (Some("other"), false),
1035            (Some(""), false),
1036            (None, false),
1037        )]
1038        read_only_values: (Option<&str>, bool),
1039        #[values(true, false)] read_only_perm: bool,
1040    ) {
1041        let (maybe_read_only_env, env_is_read_only) = read_only_values;
1042        let expected_update = !env_is_read_only && !read_only_perm;
1043        figment::Jail::expect_with(|jail| {
1044            let profile_name = "test";
1045            let initial_access_token = "initial_access_token";
1046            let initial_refresh_token = "initial_refresh_token";
1047
1048            let initial_secrets_file_contents = format!(
1049                r#"
1050[credentials]
1051[credentials.{profile_name}]
1052[credentials.{profile_name}.token_payload]
1053access_token = "{initial_access_token}"
1054expires_in = 3600
1055id_token = "id_token"
1056refresh_token = "{initial_refresh_token}"
1057scope = "offline_access openid profile email"
1058token_type = "Bearer"
1059updated_at = "2024-01-01T00:00:00Z"
1060"#
1061            );
1062
1063            // Ignore any existing environment variables.
1064            jail.clear_env();
1065
1066            // Create a temporary secrets file
1067            let secrets_path = "secrets.toml";
1068            jail.create_file(secrets_path, initial_secrets_file_contents.as_str())
1069                .expect("should create test secrets.toml");
1070
1071            if read_only_perm {
1072                let mut permissions = std::fs::metadata(secrets_path)
1073                    .expect("Should be able to get file metadata")
1074                    .permissions();
1075                permissions.set_readonly(true);
1076                std::fs::set_permissions(secrets_path, permissions)
1077                    .expect("Should be able to set file permissions");
1078            }
1079
1080            let rt = tokio::runtime::Runtime::new().unwrap();
1081            rt.block_on(async {
1082                let mock_server = MockServer::start_async().await;
1083
1084                let oidc_mock = mock_server
1085                    .mock_async(|when, then| {
1086                        when.method(GET).path("/.well-known/openid-configuration");
1087                        then.status(200)
1088                            .json_body_obj(&oidc::Discovery::new_for_test(mock_server.base_url().parse().unwrap()));
1089                    })
1090                    .await;
1091
1092                // Set up the mock token endpoint
1093                let new_access_token = SecretAccessToken::from("new_access_token");
1094                let issuer_mock = mock_server
1095                    .mock_async(|when, then| {
1096                        when.method(POST).path("/v1/token");
1097                        then.status(200).json_body_obj(&RefreshTokenResponse {
1098                            access_token: new_access_token.clone(),
1099                            refresh_token: Some(SecretRefreshToken::from(initial_refresh_token)),
1100                        });
1101                    })
1102                    .await;
1103
1104                // Create tokens and dispatcher
1105                let original_tokens = OAuthSession::from_refresh_token(
1106                    RefreshToken::new(SecretRefreshToken::from(initial_refresh_token)),
1107                    AuthServer { client_id: "client_id".to_string(), issuer: mock_server.base_url(), scopes: None },
1108                    Some(SecretAccessToken::from(initial_refresh_token)),
1109                );
1110                let dispatcher: TokenDispatcher = original_tokens.into();
1111
1112                // Test with QCS_SECRETS_READ_ONLY set first
1113                jail.set_env("QCS_SECRETS_FILE_PATH", "secrets.toml");
1114                jail.set_env("QCS_PROFILE_NAME", "test");
1115                if let Some(read_only_env) = maybe_read_only_env {
1116                    jail.set_env("QCS_SECRETS_READ_ONLY", read_only_env);
1117                }
1118
1119                let before_refresh = OffsetDateTime::now_utc();
1120
1121                dispatcher
1122                    .refresh(
1123                        &ConfigSource::File {
1124                            settings_path: "".into(),
1125                            secrets_path: "secrets.toml".into(),
1126                        },
1127                        profile_name,
1128                    )
1129                    .await
1130                    .unwrap();
1131
1132                oidc_mock.assert_async().await;
1133                issuer_mock.assert_async().await;
1134
1135                // Verify the file was not updated if QCS_SECRETS_READ_ONLY is set truthy
1136                let content = std::fs::read_to_string("secrets.toml").unwrap();
1137                if !expected_update {
1138                    assert!(
1139                        content.eq(initial_secrets_file_contents.as_str()),
1140                        "File should not be updated when QCS_SECRETS_READ_ONLY is set or file permissions are read-only"
1141                    );
1142                    return;
1143                }
1144
1145                // Verify the file was updated
1146                let mut toml = std::fs::read_to_string(secrets_path)
1147                    .unwrap()
1148                    .parse::<DocumentMut>()
1149                    .unwrap();
1150
1151                let token_payload = toml
1152                    .get_mut("credentials")
1153                    .and_then(|credentials| {
1154                        credentials.get_mut(profile_name)?.get_mut("token_payload")
1155                    })
1156                    .expect("Should be able to get token_payload table");
1157
1158                let access_token = token_payload.get("access_token").unwrap().as_str().map(str::to_string).map(SecretAccessToken::from);
1159
1160                assert_eq!(
1161                    access_token,
1162                    Some(new_access_token)
1163                );
1164
1165                assert!(
1166                    OffsetDateTime::parse(
1167                        token_payload.get("updated_at").unwrap().as_str().unwrap(),
1168                        &Rfc3339
1169                    )
1170                    .unwrap()
1171                        > before_refresh
1172                );
1173
1174                let content = std::fs::read_to_string("secrets.toml").unwrap();
1175                assert!(
1176                content.contains("new_access_token"),
1177                "File should be updated with new access token when QCS_SECRETS_READ_ONLY is not set or is set but disabled, and file permissions allow writing"
1178                );
1179            });
1180            Ok(())
1181        });
1182    }
1183
1184    #[test]
1185    fn test_auth_session_debug_fmt() {
1186        let session = OAuthSession {
1187            payload: OAuthGrant::ClientCredentials(ClientCredentials::new(
1188                "hidden_id",
1189                "hidden_secret",
1190            )),
1191            access_token: Some(SecretAccessToken::from("token")),
1192            auth_server: AuthServer {
1193                client_id: "some_id".into(),
1194                issuer: "some_url".into(),
1195                scopes: None,
1196            },
1197        };
1198
1199        assert_eq!("OAuthSession { payload: ClientCredentials, access_token: Some(()), auth_server: AuthServer { client_id: \"some_id\", issuer: \"some_url\", scopes: None } }", &format!("{session:?}"));
1200    }
1201}