Skip to main content

qcs_api_client_common/configuration/
tokens.rs

1//! Models and utilities for managing `OAuth2` sessions.
2use std::{pin::Pin, sync::Arc};
3
4use futures::Future;
5use jsonwebtoken::{Algorithm, DecodingKey, Validation};
6use oauth2::TokenResponse;
7use serde::{Deserialize, Serialize};
8use time::OffsetDateTime;
9use tokio::sync::{Mutex, Notify, RwLock};
10use tokio_util::sync::CancellationToken;
11
12#[cfg(feature = "stubs")]
13use pyo3_stub_gen::derive::gen_stub_pyclass;
14
15use super::{
16    ClientConfiguration, ConfigSource, TokenError, oidc, secrets::Secrets, settings::AuthServer,
17};
18use crate::configuration::{
19    error::DiscoveryError,
20    pkce::{PkceLoginError, PkceLoginRequest, pkce_login},
21    secrets::{Credential, SecretAccessToken, SecretRefreshToken, TokenPayload},
22};
23#[cfg(feature = "tracing-config")]
24use crate::tracing_configuration::TracingConfiguration;
25#[cfg(feature = "tracing")]
26use urlpattern::UrlPatternMatchInput;
27
28pub use super::secret_string::ClientSecret;
29
30/// A single type containing an access token and an associated refresh token.
31#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
32#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
33#[cfg_attr(
34    feature = "python",
35    pyo3::pyclass(eq, get_all, set_all, module = "qcs_api_client_common.configuration")
36)]
37pub struct RefreshToken {
38    /// The token used to refresh the access token.
39    pub refresh_token: SecretRefreshToken,
40}
41
42impl RefreshToken {
43    /// Create a new [`RefreshToken`] with the given refresh token.
44    #[must_use]
45    pub const fn new(refresh_token: SecretRefreshToken) -> Self {
46        Self { refresh_token }
47    }
48
49    /// Request and return a new access token from the given authorization server using this refresh token.
50    ///
51    /// # Errors
52    ///
53    /// See [`TokenError`]
54    pub async fn request_access_token(
55        &mut self,
56        auth_server: &AuthServer,
57    ) -> Result<SecretAccessToken, TokenError> {
58        if self.refresh_token.is_empty() {
59            return Err(TokenError::NoRefreshToken);
60        }
61
62        let client = default_http_client()?;
63        let token_url = oidc::fetch_discovery(&client, &auth_server.issuer)
64            .await?
65            .token_endpoint;
66        let data = TokenRefreshRequest::new(&auth_server.client_id, self.refresh_token.secret());
67        let resp = client.post(token_url).form(&data).send().await?;
68
69        let RefreshTokenResponse {
70            access_token,
71            refresh_token,
72        } = resp.error_for_status()?.json().await?;
73
74        if let Some(refresh_token) = refresh_token {
75            self.refresh_token = refresh_token;
76        }
77        Ok(access_token)
78    }
79}
80
81#[derive(Deserialize, Debug, Serialize)]
82pub(super) struct ClientCredentialsResponse {
83    pub(super) access_token: SecretAccessToken,
84}
85
86/// A pair of Client ID and Client Secret, used to request an OAuth Client Credentials Grant
87#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
88#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
89#[cfg_attr(
90    feature = "python",
91    pyo3::pyclass(eq, get_all, frozen, module = "qcs_api_client_common.configuration")
92)]
93pub struct ClientCredentials {
94    /// The client ID
95    pub client_id: String,
96    /// The client secret.
97    pub client_secret: ClientSecret,
98}
99
100impl ClientCredentials {
101    #[must_use]
102    /// Construct a new [`ClientCredentials`]
103    pub fn new(client_id: impl Into<String>, client_secret: impl Into<ClientSecret>) -> Self {
104        Self {
105            client_id: client_id.into(),
106            client_secret: client_secret.into(),
107        }
108    }
109
110    /// Get the client ID.
111    #[must_use]
112    pub fn client_id(&self) -> &str {
113        &self.client_id
114    }
115
116    /// Get the client secret.
117    #[must_use]
118    pub const fn client_secret(&self) -> &ClientSecret {
119        &self.client_secret
120    }
121
122    /// Request and return an access token from the given auth server using this set of client credentials.
123    ///
124    /// # Errors
125    ///
126    /// See [`TokenError`]
127    pub async fn request_access_token(
128        &self,
129        auth_server: &AuthServer,
130    ) -> Result<SecretAccessToken, TokenError> {
131        let request = ClientCredentialsRequest::new(None);
132        let client = default_http_client()?;
133
134        let url = oidc::fetch_discovery(&client, &auth_server.issuer)
135            .await?
136            .token_endpoint;
137        let ready_to_send = client
138            .post(url)
139            .basic_auth(&auth_server.client_id, Some(&self.client_secret.secret()))
140            .form(&request);
141        let response = ready_to_send.send().await?;
142
143        response.error_for_status_ref()?;
144
145        let ClientCredentialsResponse { access_token } = response.json().await?;
146        Ok(access_token)
147    }
148}
149
150#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
151#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
152#[cfg_attr(
153    feature = "python",
154    pyo3::pyclass(eq, get_all, frozen, module = "qcs_api_client_common.configuration")
155)]
156/// The Access (Bearer) and refresh (if available) tokens from a PKCE login.
157pub struct PkceFlow {
158    /// The access token.
159    pub access_token: SecretAccessToken,
160    /// The refresh token, if available.
161    pub refresh_token: Option<RefreshToken>,
162}
163
164/// Errors that can occur when attempting to perform a PKCE login flow.
165#[derive(Debug, thiserror::Error)]
166pub enum PkceFlowError {
167    /// Error that occurred while performing the PKCE login flow.
168    #[error(transparent)]
169    PkceLogin(#[from] PkceLoginError),
170    /// Error that occurred while fetching the discovery document from the `OAuth2` issuer.
171    #[error(transparent)]
172    Discovery(#[from] DiscoveryError),
173    /// Error that occurred while making http requests.
174    #[error(transparent)]
175    Request(#[from] qcs_dependencies_client::reqwest::Error),
176}
177
178impl PkceFlow {
179    /// Starts a new PKCE login flow to acquire a new set of tokens.
180    ///
181    /// # Errors
182    ///
183    /// See [`PkceFlowError`]
184    pub async fn new_login_flow(
185        cancel_token: CancellationToken,
186        auth_server: &AuthServer,
187    ) -> Result<Self, PkceFlowError> {
188        let issuer = auth_server.issuer.clone();
189
190        let client = default_http_client()?;
191        let discovery = oidc::fetch_discovery(&client, &issuer).await?;
192
193        let response = pkce_login(
194            cancel_token,
195            PkceLoginRequest {
196                client_id: auth_server.client_id.clone(),
197                redirect_port: None,
198                discovery,
199                scopes: auth_server.scopes.clone(),
200            },
201        )
202        .await?;
203
204        Ok(Self {
205            access_token: SecretAccessToken::from(response.access_token().secret().clone()),
206            refresh_token: response
207                .refresh_token()
208                .map(|rt| RefreshToken::new(SecretRefreshToken::from(rt.secret().clone()))),
209        })
210    }
211
212    /// Returns the access token if it is valid, otherwise requests a new access token using the refresh token if available.
213    ///
214    /// # Errors
215    ///
216    /// See [`TokenError`]
217    pub async fn request_access_token(
218        &mut self,
219        auth_server: &AuthServer,
220    ) -> Result<SecretAccessToken, TokenError> {
221        if insecure_validate_token_exp(&self.access_token).is_ok() {
222            return Ok(self.access_token.clone());
223        }
224
225        if let Some(refresh_token) = &mut self.refresh_token {
226            let access_token = refresh_token.request_access_token(auth_server).await?;
227            self.access_token.clone_from(&access_token);
228            return Ok(access_token);
229        }
230
231        Err(TokenError::NoRefreshToken)
232    }
233}
234
235impl From<PkceFlow> for Credential {
236    fn from(value: PkceFlow) -> Self {
237        let mut token_payload = TokenPayload::default();
238        token_payload.access_token = Some(value.access_token);
239        token_payload.refresh_token = value.refresh_token.map(|rt| rt.refresh_token);
240
241        Self {
242            token_payload: Some(token_payload),
243        }
244    }
245}
246
247#[derive(Clone)]
248#[cfg_attr(feature = "python", derive(pyo3::FromPyObject, pyo3::IntoPyObject))]
249/// Specifies the [OAuth2 grant type](https://oauth.net/2/grant-types/) to use, along with the data
250/// needed to request said grant type.
251pub enum OAuthGrant {
252    /// Credentials that can be used to use with the [Refresh Token grant type](https://oauth.net/2/grant-types/refresh-token/).
253    RefreshToken(RefreshToken),
254    /// Payload that can be used to use the [Client Credentials grant type](https://oauth.net/2/grant-types/client-credentials/).
255    ClientCredentials(ClientCredentials),
256    /// Defers to a user provided function for access token requests.
257    ExternallyManaged(ExternallyManaged),
258    /// The tokens returned by the PKCE login that are an [Authorization Code grant type](https://oauth.net/2/pkce/).
259    PkceFlow(PkceFlow),
260}
261
262impl From<ExternallyManaged> for OAuthGrant {
263    fn from(v: ExternallyManaged) -> Self {
264        Self::ExternallyManaged(v)
265    }
266}
267
268impl From<ClientCredentials> for OAuthGrant {
269    fn from(v: ClientCredentials) -> Self {
270        Self::ClientCredentials(v)
271    }
272}
273
274impl From<RefreshToken> for OAuthGrant {
275    fn from(v: RefreshToken) -> Self {
276        Self::RefreshToken(v)
277    }
278}
279
280impl From<PkceFlow> for OAuthGrant {
281    fn from(v: PkceFlow) -> Self {
282        Self::PkceFlow(v)
283    }
284}
285
286impl OAuthGrant {
287    /// Request a new access token from the given issuer using this grant type and payload.
288    async fn request_access_token(
289        &mut self,
290        auth_server: &AuthServer,
291    ) -> Result<SecretAccessToken, TokenError> {
292        match self {
293            Self::RefreshToken(tokens) => tokens.request_access_token(auth_server).await,
294            Self::ClientCredentials(tokens) => tokens.request_access_token(auth_server).await,
295            Self::ExternallyManaged(tokens) => tokens
296                .request_access_token(auth_server)
297                .await
298                .map_err(|e| TokenError::ExternallyManaged(e.to_string())),
299            Self::PkceFlow(tokens) => tokens.request_access_token(auth_server).await,
300        }
301    }
302}
303
304impl std::fmt::Debug for OAuthGrant {
305    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
306        match self {
307            Self::RefreshToken(_) => f.write_str("RefreshToken"),
308            Self::ClientCredentials(_) => f.write_str("ClientCredentials"),
309            Self::ExternallyManaged(_) => f.write_str("ExternallyManaged"),
310            Self::PkceFlow(_) => f.write_str("PkceTokens"),
311        }
312    }
313}
314
315/// Manages the `OAuth2` authorization process and token lifecycle for accessing the QCS API.
316///
317/// This struct encapsulates the necessary information to request an access token
318/// from an authorization server, including the `OAuth2` grant type and any associated
319/// credentials or payload data.
320///
321/// # Fields
322///
323/// * `payload` - The `OAuth2` grant type and associated data that will be used to request an access token.
324/// * `access_token` - The access token currently in use, if any. If no token has been provided or requested yet, this will be `None`.
325/// * `auth_server` - The authorization server responsible for issuing tokens.
326#[derive(Clone)]
327#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
328#[cfg_attr(
329    feature = "python",
330    pyo3::pyclass(module = "qcs_api_client_common.configuration", frozen, get_all)
331)]
332pub struct OAuthSession {
333    /// The grant type to use to request an access token.
334    payload: OAuthGrant,
335    /// The access token that is currently in use. None if no token has been requested yet.
336    access_token: Option<SecretAccessToken>,
337    /// The [`AuthServer`] that issues the tokens.
338    auth_server: AuthServer,
339}
340
341impl OAuthSession {
342    /// Initialize a new set of [`Credentials`] using a [`GrantPayload`].
343    ///
344    /// Optionally include an `access_token`, if not included, then one can be requested
345    /// with [`Self::request_access_token`].
346    #[must_use]
347    pub const fn new(
348        payload: OAuthGrant,
349        auth_server: AuthServer,
350        access_token: Option<SecretAccessToken>,
351    ) -> Self {
352        Self {
353            payload,
354            access_token,
355            auth_server,
356        }
357    }
358
359    /// Initialize a new set of [`Credentials`] using an [`ExternallyManaged`].
360    ///
361    /// Optionally include an `access_token`, if not included, then one can be requested
362    /// with [`Self::request_access_token`].
363    #[must_use]
364    pub const fn from_externally_managed(
365        tokens: ExternallyManaged,
366        auth_server: AuthServer,
367        access_token: Option<SecretAccessToken>,
368    ) -> Self {
369        Self::new(
370            OAuthGrant::ExternallyManaged(tokens),
371            auth_server,
372            access_token,
373        )
374    }
375
376    /// Initialize a new set of [`Credentials`] using a [`RefreshToken`].
377    ///
378    /// Optionally include an `access_token`, if not included, then one can be requested
379    /// with [`Self::request_access_token`].
380    #[must_use]
381    pub const fn from_refresh_token(
382        tokens: RefreshToken,
383        auth_server: AuthServer,
384        access_token: Option<SecretAccessToken>,
385    ) -> Self {
386        Self::new(OAuthGrant::RefreshToken(tokens), auth_server, access_token)
387    }
388
389    /// Initialize a new set of [`Credentials`] using [`ClientCredentials`].
390    ///
391    /// Optionally include an `access_token`, if not included, then one can be requested
392    /// with [`Self::request_access_token`].
393    #[must_use]
394    pub const fn from_client_credentials(
395        tokens: ClientCredentials,
396        auth_server: AuthServer,
397        access_token: Option<SecretAccessToken>,
398    ) -> Self {
399        Self::new(
400            OAuthGrant::ClientCredentials(tokens),
401            auth_server,
402            access_token,
403        )
404    }
405
406    /// Initialize a new set of [`Credentials`] using [`PkceFlow`].
407    ///
408    /// Optionally include an `access_token`, if not included, then one can be requested
409    /// with [`Self::request_access_token`].
410    #[must_use]
411    pub const fn from_pkce_flow(
412        flow: PkceFlow,
413        auth_server: AuthServer,
414        access_token: Option<SecretAccessToken>,
415    ) -> Self {
416        Self::new(OAuthGrant::PkceFlow(flow), auth_server, access_token)
417    }
418
419    /// Get the current access token.
420    ///
421    /// This is an unvalidated copy of the access token. Meaning it can become stale, or may
422    /// even be already be stale. See [`Self::validate`] and [`Self::request_access_token`].
423    ///
424    /// # Errors
425    ///
426    /// - [`TokenError::NoAccessToken`] if there is no access token
427    pub fn access_token(&self) -> Result<&SecretAccessToken, TokenError> {
428        self.access_token.as_ref().ok_or(TokenError::NoAccessToken)
429    }
430
431    /// Get the payload used to request an access token.
432    #[must_use]
433    pub const fn payload(&self) -> &OAuthGrant {
434        &self.payload
435    }
436
437    /// Request and return an updated access token using these credentials.
438    ///
439    /// # Errors
440    ///
441    /// See [`TokenError`]
442    #[allow(clippy::missing_panics_doc)]
443    pub async fn request_access_token(&mut self) -> Result<&SecretAccessToken, TokenError> {
444        let access_token = self.payload.request_access_token(&self.auth_server).await?;
445        Ok(self.access_token.insert(access_token))
446    }
447
448    /// The [`AuthServer`] that issues the tokens.
449    #[must_use]
450    pub const fn auth_server(&self) -> &AuthServer {
451        &self.auth_server
452    }
453
454    /// Validate the access token, returning it if it is valid, or an error describing why it is
455    /// invalid.
456    ///
457    /// # Errors
458    ///
459    /// - [`TokenError::NoAccessToken`] if an access token has not been requested.
460    /// - [`TokenError::InvalidAccessToken`] if the access token is invalid.
461    pub fn validate(&self) -> Result<SecretAccessToken, TokenError> {
462        let access_token = self.access_token()?;
463        insecure_validate_token_exp(access_token)?;
464        Ok(access_token.clone())
465    }
466}
467
468/// Validates the access token's format and `exp` claim, but no other claims or
469/// signature. We do this only to determine if the token is expired and needs refreshing,
470/// there is no way to securely validate the token's signature on the client side.
471pub(crate) fn insecure_validate_token_exp(
472    access_token: &SecretAccessToken,
473) -> Result<(), TokenError> {
474    let placeholder_key = DecodingKey::from_secret(&[]);
475    let mut validation = Validation::new(Algorithm::RS256);
476    validation.validate_exp = true;
477    validation.leeway = 60;
478    validation.validate_aud = false;
479    validation.insecure_disable_signature_validation();
480
481    jsonwebtoken::decode::<toml::Value>(access_token.secret(), &placeholder_key, &validation)
482        .map(|_| ())
483        .map_err(TokenError::InvalidAccessToken)
484}
485
486impl std::fmt::Debug for OAuthSession {
487    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
488        let token_populated = if self.access_token.is_some() {
489            Some(())
490        } else {
491            None
492        };
493        f.debug_struct("OAuthSession")
494            .field("payload", &self.payload)
495            .field("access_token", &token_populated)
496            .field("auth_server", &self.auth_server)
497            .finish()
498    }
499}
500
501/// A wrapper for [`OAuthSession`] that provides thread-safe access to the inner tokens.
502#[derive(Clone, Debug)]
503#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
504#[cfg_attr(
505    feature = "python",
506    pyo3::pyclass(module = "qcs_api_client_common.configuration", frozen)
507)]
508pub struct TokenDispatcher {
509    lock: Arc<RwLock<OAuthSession>>,
510    refreshing: Arc<Mutex<bool>>,
511    notify_refreshed: Arc<Notify>,
512}
513
514impl From<OAuthSession> for TokenDispatcher {
515    fn from(value: OAuthSession) -> Self {
516        Self {
517            lock: Arc::new(RwLock::new(value)),
518            refreshing: Arc::new(Mutex::new(false)),
519            notify_refreshed: Arc::new(Notify::new()),
520        }
521    }
522}
523
524impl TokenDispatcher {
525    /// Executes a user-provided closure on a reference to the `Tokens` instance managed by the
526    /// dispatcher.
527    ///
528    /// This function locks the mutex, safely exposing the protected `Tokens` instance to the provided closure `f`.
529    /// It is designed to allow safe and controlled access to the `Tokens` instance for reading its state.
530    ///
531    /// # Parameters
532    /// - `f`: A closure that takes a reference to `Tokens` and returns a value of type `O`. The closure is called
533    ///   with the `Tokens` instance as an argument once the mutex is successfully locked.
534    pub async fn use_tokens<F, O>(&self, f: F) -> O
535    where
536        F: FnOnce(&OAuthSession) -> O + Send,
537    {
538        let tokens = self.lock.read().await;
539        f(&tokens)
540    }
541
542    /// Get a copy of the current access token.
543    #[must_use]
544    pub async fn tokens(&self) -> OAuthSession {
545        self.use_tokens(Clone::clone).await
546    }
547
548    /// Refreshes the tokens. Readers will be blocked until the refresh is complete.
549    ///
550    /// # Errors
551    ///
552    /// See [`TokenError`]
553    pub async fn refresh(
554        &self,
555        source: &ConfigSource,
556        profile: &str,
557    ) -> Result<OAuthSession, TokenError> {
558        self.managed_refresh(Self::perform_refresh, source, profile)
559            .await
560    }
561
562    /// Validate the access token, returning it if it is valid, or an error describing why it is
563    /// invalid.
564    ///
565    /// # Errors
566    ///
567    /// - [`TokenError::NoAccessToken`] if there is no access token
568    /// - [`TokenError::InvalidAccessToken`] if the access token is invalid
569    pub async fn validate(&self) -> Result<SecretAccessToken, TokenError> {
570        self.use_tokens(OAuthSession::validate).await
571    }
572
573    /// If tokens are already being refreshed, wait and return the updated tokens. Otherwise, run
574    /// ``refresh_fn``.
575    async fn managed_refresh<F, Fut>(
576        &self,
577        refresh_fn: F,
578        source: &ConfigSource,
579        profile: &str,
580    ) -> Result<OAuthSession, TokenError>
581    where
582        F: FnOnce(Arc<RwLock<OAuthSession>>) -> Fut + Send,
583        Fut: Future<Output = Result<OAuthSession, TokenError>> + Send,
584    {
585        let mut is_refreshing = self.refreshing.lock().await;
586
587        if *is_refreshing {
588            drop(is_refreshing);
589            self.notify_refreshed.notified().await;
590            return Ok(self.tokens().await);
591        }
592
593        *is_refreshing = true;
594        drop(is_refreshing);
595
596        let oauth_session = refresh_fn(self.lock.clone()).await?;
597
598        // If the config source is a file, write the new access token to the file
599        let write_result = if let ConfigSource::File {
600            settings_path: _,
601            secrets_path,
602        } = source
603        {
604            match Secrets::is_read_only(secrets_path).await {
605                Ok(true) => Ok(()),
606                Ok(false) => {
607                    // If the payload is a PkceFlow, write the fresh refresh token if available.
608                    let refresh_token = match &oauth_session.payload {
609                        OAuthGrant::PkceFlow(payload) => {
610                            payload.refresh_token.as_ref().map(|rt| &rt.refresh_token)
611                        }
612                        _ => None,
613                    };
614
615                    let now = OffsetDateTime::now_utc();
616                    Secrets::write_tokens(
617                        secrets_path,
618                        profile,
619                        refresh_token,
620                        oauth_session.access_token()?,
621                        now,
622                    )
623                    .await
624                }
625                Err(e) => Err(e),
626            }
627        } else {
628            Ok(())
629        };
630
631        // Always clean up the refreshing lock, even if write failed
632        *self.refreshing.lock().await = false;
633        self.notify_refreshed.notify_waiters();
634
635        // If write failed, return error with the valid oauth_session
636        if let Err(error) = write_result {
637            return Err(TokenError::Write {
638                error,
639                oauth_session: Box::new(oauth_session),
640            });
641        }
642
643        Ok(oauth_session)
644    }
645
646    /// Refreshes the tokens. Readers will be blocked until the refresh is complete. Returns a copy
647    /// of the updated [`Credentials`]
648    ///
649    /// # Errors
650    ///
651    /// See [`TokenError`]
652    async fn perform_refresh(lock: Arc<RwLock<OAuthSession>>) -> Result<OAuthSession, TokenError> {
653        let mut credentials = lock.write().await;
654        credentials.request_access_token().await?;
655        Ok(credentials.clone())
656    }
657}
658
659pub(crate) type RefreshResult =
660    Pin<Box<dyn Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>> + Send>>;
661
662/// A function that asynchronously refreshes a token.
663pub type RefreshFunction = Box<dyn (Fn(AuthServer) -> RefreshResult) + Send + Sync>;
664
665/// A struct that manages access tokens by utilizing a user-provided refresh function.
666///
667/// The [`ExternallyManaged`] struct allows users to define custom logic for
668/// fetching or refreshing access tokens.
669#[derive(Clone)]
670#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
671#[cfg_attr(
672    feature = "python",
673    pyo3::pyclass(module = "qcs_api_client_common.configuration", frozen)
674)]
675pub struct ExternallyManaged {
676    refresh_function: Arc<RefreshFunction>,
677}
678
679impl ExternallyManaged {
680    /// Creates a new [`ExternallyManaged`] instance from a [`RefreshFunction`].
681    ///
682    /// Consider using [`ExternallyManaged::from_async`], and [`ExternallyManaged::from_sync`], if
683    /// they better fit your use case.
684    ///
685    /// # Arguments
686    ///
687    /// * `refresh_function` - A function or closure that asynchronously refreshes a token.
688    ///
689    /// # Example
690    ///
691    /// ```
692    /// use qcs_api_client_common::configuration::{settings::AuthServer, tokens::ExternallyManaged, TokenError};
693    /// use std::future::Future;
694    /// use std::pin::Pin;
695    /// use std::boxed::Box;
696    /// use std::error::Error;
697    ///
698    /// async fn example_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
699    /// + Send + Sync>> {
700    ///     Ok("new_token_value".to_string())
701    /// }
702    /// let token_manager = ExternallyManaged::new(|auth_server| Box::pin(example_refresh_function(auth_server)));
703    /// ```
704    pub fn new(
705        refresh_function: impl Fn(AuthServer) -> RefreshResult + Send + Sync + 'static,
706    ) -> Self {
707        Self {
708            refresh_function: Arc::new(Box::new(refresh_function)),
709        }
710    }
711
712    /// Constructs a new [`ExternallyManaged`] instance using an async function or closure.
713    ///
714    /// This method simplifies the creation of the [`ExternallyManaged`] instance by handling
715    /// the boxing and pinning of the future internally.
716    ///
717    /// # Arguments
718    ///
719    /// * `refresh_function` - An async function or closure that returns a [`Future`] which, when awaited,
720    ///   produces a [`Result<String, TokenError>`].
721    ///
722    /// # Example
723    ///
724    /// ```
725    /// use qcs_api_client_common::configuration::{settings::AuthServer, tokens::ExternallyManaged, TokenError};
726    /// use tokio::runtime::Runtime;
727    /// use std::error::Error;
728    ///
729    /// async fn example_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
730    /// + Send + Sync>> {
731    ///     Ok("new_token_value".to_string())
732    /// }
733    ///
734    /// let token_manager = ExternallyManaged::from_async(example_refresh_function);
735    ///
736    /// let rt = Runtime::new().unwrap();
737    /// rt.block_on(async {
738    ///     match token_manager.request_access_token(&AuthServer::default()).await {
739    ///         Ok(token) => println!("Token: {token:?}"),
740    ///         Err(e) => println!("Failed to refresh token: {:?}", e),
741    ///     }
742    /// });
743    /// ```
744    pub fn from_async<F, Fut>(refresh_function: F) -> Self
745    where
746        F: Fn(AuthServer) -> Fut + Send + Sync + 'static,
747        Fut: Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>>
748            + Send
749            + 'static,
750    {
751        Self {
752            refresh_function: Arc::new(Box::new(move |auth_server| {
753                Box::pin(refresh_function(auth_server))
754            })),
755        }
756    }
757
758    /// Constructs a new [`ExternallyManaged`] instance using a synchronous function.
759    ///
760    /// The synchronous function is wrapped in an async block to fit the expected signature.
761    ///
762    /// # Arguments
763    ///
764    /// * `refresh_function` - A synchronous function that returns a [`Result<String, TokenError>`].
765    ///
766    /// # Example
767    ///
768    /// ```
769    /// use qcs_api_client_common::configuration::{settings::AuthServer, tokens::ExternallyManaged, TokenError};
770    /// use tokio::runtime::Runtime;
771    /// use std::error::Error;
772    ///
773    /// fn example_sync_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
774    /// + Send + Sync>> {
775    ///     Ok("sync_token_value".to_string())
776    /// }
777    ///
778    /// let token_manager = ExternallyManaged::from_sync(example_sync_refresh_function);
779    ///
780    /// let rt = Runtime::new().unwrap();
781    /// rt.block_on(async {
782    ///     match token_manager.request_access_token(&AuthServer::default()).await {
783    ///         Ok(token) => println!("Token: {token:?}"),
784    ///         Err(e) => println!("Failed to refresh token: {:?}", e),
785    ///     }
786    /// });
787    /// ```
788    pub fn from_sync(
789        refresh_function: impl Fn(
790            AuthServer,
791        ) -> Result<String, Box<dyn std::error::Error + Send + Sync>>
792        + Send
793        + Sync
794        + 'static,
795    ) -> Self {
796        Self {
797            refresh_function: Arc::new(Box::new(move |auth_server| {
798                let result = refresh_function(auth_server);
799                Box::pin(async move { result })
800            })),
801        }
802    }
803
804    /// Request an updated access token using the provided refresh function.
805    ///
806    /// # Errors
807    ///
808    /// Errors are propagated from the refresh function.
809    pub async fn request_access_token(
810        &self,
811        auth_server: &AuthServer,
812    ) -> Result<SecretAccessToken, Box<dyn std::error::Error + Send + Sync>> {
813        (self.refresh_function)(auth_server.clone())
814            .await
815            .map(SecretAccessToken::from)
816    }
817}
818
819impl std::fmt::Debug for ExternallyManaged {
820    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
821        f.debug_struct("ExternallyManaged")
822            .field(
823                "refresh_function",
824                &"Fn() -> Pin<Box<dyn Future<Output = Result<String, TokenError>> + Send>>",
825            )
826            .finish()
827    }
828}
829
830#[derive(Debug, Serialize, Deserialize)]
831pub(super) struct TokenRefreshRequest<'a> {
832    grant_type: &'static str,
833    client_id: &'a str,
834    refresh_token: &'a str,
835}
836
837impl<'a> TokenRefreshRequest<'a> {
838    pub(super) const fn new(client_id: &'a str, refresh_token: &'a str) -> Self {
839        Self {
840            grant_type: "refresh_token",
841            client_id,
842            refresh_token,
843        }
844    }
845}
846
847#[derive(Debug, Serialize, Deserialize)]
848pub(super) struct ClientCredentialsRequest {
849    grant_type: &'static str,
850    scope: Option<&'static str>,
851}
852
853impl ClientCredentialsRequest {
854    pub(super) const fn new(scope: Option<&'static str>) -> Self {
855        Self {
856            grant_type: "client_credentials",
857            scope,
858        }
859    }
860}
861
862#[derive(Deserialize, Debug, Serialize)]
863pub(super) struct RefreshTokenResponse {
864    pub(super) refresh_token: Option<SecretRefreshToken>,
865    pub(super) access_token: SecretAccessToken,
866}
867
868/// Get and refresh access tokens
869#[async_trait::async_trait]
870pub trait TokenRefresher: Clone + std::fmt::Debug + Send {
871    /// The type to be returned in the event of a error during getting or
872    /// refreshing an access token
873    type Error;
874
875    /// Get and validate the current access token, refreshing it if it doesn't exist or is invalid.
876    async fn validated_access_token(&self) -> Result<SecretAccessToken, Self::Error>;
877
878    /// Get the current access token, if any
879    async fn get_access_token(&self) -> Result<Option<SecretAccessToken>, Self::Error>;
880
881    /// Get a fresh access token
882    async fn refresh_access_token(&self) -> Result<SecretAccessToken, Self::Error>;
883
884    /// Get the base URL for requests
885    #[cfg(feature = "tracing")]
886    fn base_url(&self) -> &str;
887
888    /// Get the tracing configuration
889    #[cfg(feature = "tracing-config")]
890    fn tracing_configuration(&self) -> Option<&TracingConfiguration>;
891
892    /// Returns whether the given URL should be traced. Following
893    /// [`TracingConfiguration::is_enabled`], this defaults to `true`.
894    #[cfg(feature = "tracing")]
895    #[allow(clippy::needless_return)]
896    fn should_trace(&self, url: &UrlPatternMatchInput) -> bool {
897        #[cfg(not(feature = "tracing-config"))]
898        {
899            let _ = url;
900            return true;
901        }
902
903        #[cfg(feature = "tracing-config")]
904        self.tracing_configuration()
905            .is_none_or(|config| config.is_enabled(url))
906    }
907}
908
909#[async_trait::async_trait]
910impl TokenRefresher for ClientConfiguration {
911    type Error = TokenError;
912
913    async fn validated_access_token(&self) -> Result<SecretAccessToken, Self::Error> {
914        self.get_bearer_access_token().await
915    }
916
917    async fn refresh_access_token(&self) -> Result<SecretAccessToken, Self::Error> {
918        match self.refresh().await {
919            Ok(session) => Ok(session.access_token()?.clone()),
920            Err(TokenError::Write {
921                error,
922                oauth_session,
923            }) => {
924                // Token refresh succeeded but persistence failed. Extract and return the access token from the error.
925                #[cfg(feature = "tracing")]
926                tracing::warn!(
927                    "Token refresh succeeded but failed to persist: {}. Returning access token from error.",
928                    error
929                );
930                Ok(oauth_session.access_token()?.clone())
931            }
932            Err(e) => Err(e),
933        }
934    }
935
936    async fn get_access_token(&self) -> Result<Option<SecretAccessToken>, Self::Error> {
937        Ok(Some(self.oauth_session().await?.access_token()?.clone()))
938    }
939
940    #[cfg(feature = "tracing")]
941    fn base_url(&self) -> &str {
942        &self.grpc_api_url
943    }
944
945    #[cfg(feature = "tracing-config")]
946    fn tracing_configuration(&self) -> Option<&TracingConfiguration> {
947        self.tracing_configuration.as_ref()
948    }
949}
950
951/// Get a default http client.
952pub(super) fn default_http_client()
953-> Result<qcs_dependencies_client::reqwest::Client, qcs_dependencies_client::reqwest::Error> {
954    qcs_dependencies_client::reqwest::Client::builder()
955        .timeout(std::time::Duration::from_secs(10))
956        .build()
957}
958
#[cfg(test)]
mod test {
    #![allow(clippy::result_large_err, reason = "happens in figment tests")]

    use std::time::Duration;

    use super::*;
    use httpmock::prelude::*;
    use rstest::rstest;
    use time::format_description::well_known::Rfc3339;
    use tokio::time::Instant;
    use toml_edit::DocumentMut;

    /// While a refresh is in flight, a reader of the dispatcher must wait for it to finish
    /// and then observe the refreshed tokens rather than the stale ones.
    #[tokio::test]
    async fn test_tokens_blocked_during_refresh() {
        // How long the mocked token endpoint stalls. The timing assertions below compare
        // against this same value, so the mock delay and the expectation cannot drift apart.
        let refresh_duration = Duration::from_secs(3);

        let mock_server = MockServer::start_async().await;

        let oidc_mock = mock_server
            .mock_async(|when, then| {
                when.method(GET).path("/.well-known/openid-configuration");
                then.status(200)
                    .json_body_obj(&oidc::Discovery::new_for_test(
                        mock_server.base_url().parse().unwrap(),
                    ));
            })
            .await;

        let issuer_mock = mock_server
            .mock_async(|when, then| {
                when.method(POST).path("/v1/token");

                then.status(200)
                    .delay(refresh_duration)
                    .json_body_obj(&RefreshTokenResponse {
                        access_token: SecretAccessToken::from("new_access"),
                        refresh_token: Some(SecretRefreshToken::from("new_refresh")),
                    });
            })
            .await;

        let original_tokens = OAuthSession::from_refresh_token(
            RefreshToken::new(SecretRefreshToken::from("refresh")),
            AuthServer {
                client_id: "client_id".to_string(),
                issuer: mock_server.base_url(),
                scopes: None,
            },
            None,
        );
        let dispatcher: TokenDispatcher = original_tokens.into();
        let dispatcher_clone1 = dispatcher.clone();
        let dispatcher_clone2 = dispatcher.clone();

        // `#[tokio::test]` uses a current-thread runtime. The spawned tasks therefore only
        // start once this task awaits, and the refresh task (spawned first) should be polled
        // before the reader. NOTE(review): this assumes `refresh` takes the dispatcher's lock
        // before its first pending `.await`; revisit if the dispatcher's locking changes.
        let start_write = Instant::now();
        let write_future = tokio::spawn(async move {
            dispatcher_clone1
                .refresh(&ConfigSource::Default, "")
                .await
                .unwrap()
        });

        let start_read = Instant::now();
        let read_future = tokio::spawn(async move { dispatcher_clone2.tokens().await });

        let _ = write_future.await.unwrap();
        let read_result = read_future.await.unwrap();

        let write_duration = start_write.elapsed();
        let read_duration = start_read.elapsed();

        oidc_mock.assert_async().await;
        issuer_mock.assert_async().await;

        assert!(
            write_duration >= refresh_duration,
            "Write operation did not take enough time"
        );
        assert!(
            read_duration >= refresh_duration,
            "Read operation was not blocked by the write operation"
        );
        assert_eq!(
            read_result.access_token.unwrap(),
            SecretAccessToken::from("new_access")
        );
        if let OAuthGrant::RefreshToken(payload) = read_result.payload {
            assert_eq!(
                payload.refresh_token,
                SecretRefreshToken::from("new_refresh")
            );
        } else {
            panic!(
                "Expected RefreshToken payload, got {:?}",
                read_result.payload
            );
        }
    }

    /// A refresh must persist the new tokens to the secrets file unless the file is
    /// read-only. The file counts as read-only when `QCS_SECRETS_READ_ONLY` holds a truthy
    /// value or when its file permissions forbid writing.
    #[rstest]
    fn test_qcs_secrets_readonly(
        #[values(
            (Some("TRUE"), true),
            (Some("tRue"), true),
            (Some("true"), true),
            (Some("YES"), true),
            (Some("yEs"), true),
            (Some("yes"), true),
            (Some("1"), true),
            (Some("2"), false),
            (Some("other"), false),
            (Some(""), false),
            (None, false),
        )]
        read_only_values: (Option<&str>, bool),
        #[values(true, false)] read_only_perm: bool,
    ) {
        let (maybe_read_only_env, env_is_read_only) = read_only_values;
        // The file should only change when neither form of read-only protection applies.
        let expected_update = !env_is_read_only && !read_only_perm;
        figment::Jail::expect_with(|jail| {
            let profile_name = "test";
            let initial_access_token = "initial_access_token";
            let initial_refresh_token = "initial_refresh_token";

            let initial_secrets_file_contents = format!(
                r#"
[credentials]
[credentials.{profile_name}]
[credentials.{profile_name}.token_payload]
access_token = "{initial_access_token}"
expires_in = 3600
id_token = "id_token"
refresh_token = "{initial_refresh_token}"
scope = "offline_access openid profile email"
token_type = "Bearer"
updated_at = "2024-01-01T00:00:00Z"
"#
            );

            // Ignore any existing environment variables.
            jail.clear_env();

            // Create a temporary secrets file
            let secrets_path = "secrets.toml";
            jail.create_file(secrets_path, initial_secrets_file_contents.as_str())
                .expect("should create test secrets.toml");

            // NOTE(review): a read-only permission bit does not stop root from writing, so the
            // `read_only_perm` cases are expected to fail if the suite runs as root.
            if read_only_perm {
                let mut permissions = std::fs::metadata(secrets_path)
                    .expect("Should be able to get file metadata")
                    .permissions();
                permissions.set_readonly(true);
                std::fs::set_permissions(secrets_path, permissions)
                    .expect("Should be able to set file permissions");
            }

            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async {
                let mock_server = MockServer::start_async().await;

                let oidc_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(GET).path("/.well-known/openid-configuration");
                        then.status(200).json_body_obj(&oidc::Discovery::new_for_test(
                            mock_server.base_url().parse().unwrap(),
                        ));
                    })
                    .await;

                // Set up the mock token endpoint
                let new_access_token = SecretAccessToken::from("new_access_token");
                let issuer_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(POST).path("/v1/token");
                        then.status(200).json_body_obj(&RefreshTokenResponse {
                            access_token: new_access_token.clone(),
                            refresh_token: Some(SecretRefreshToken::from(initial_refresh_token)),
                        });
                    })
                    .await;

                // Create tokens and dispatcher. The in-memory session starts from the same
                // access and refresh tokens that the secrets file holds.
                let original_tokens = OAuthSession::from_refresh_token(
                    RefreshToken::new(SecretRefreshToken::from(initial_refresh_token)),
                    AuthServer {
                        client_id: "client_id".to_string(),
                        issuer: mock_server.base_url(),
                        scopes: None,
                    },
                    Some(SecretAccessToken::from(initial_access_token)),
                );
                let dispatcher: TokenDispatcher = original_tokens.into();

                // Test with QCS_SECRETS_READ_ONLY set first
                jail.set_env("QCS_SECRETS_FILE_PATH", "secrets.toml");
                jail.set_env("QCS_PROFILE_NAME", "test");
                if let Some(read_only_env) = maybe_read_only_env {
                    jail.set_env("QCS_SECRETS_READ_ONLY", read_only_env);
                }

                let before_refresh = OffsetDateTime::now_utc();

                dispatcher
                    .refresh(
                        &ConfigSource::File {
                            settings_path: "".into(),
                            secrets_path: "secrets.toml".into(),
                        },
                        profile_name,
                    )
                    .await
                    .unwrap();

                oidc_mock.assert_async().await;
                issuer_mock.assert_async().await;

                // Verify the file was not updated if QCS_SECRETS_READ_ONLY is set truthy
                let content = std::fs::read_to_string("secrets.toml").unwrap();
                if !expected_update {
                    assert!(
                        content.eq(initial_secrets_file_contents.as_str()),
                        "File should not be updated when QCS_SECRETS_READ_ONLY is set or file permissions are read-only"
                    );
                    return;
                }

                // Verify the file was updated
                let mut toml = std::fs::read_to_string(secrets_path)
                    .unwrap()
                    .parse::<DocumentMut>()
                    .unwrap();

                let token_payload = toml
                    .get_mut("credentials")
                    .and_then(|credentials| {
                        credentials.get_mut(profile_name)?.get_mut("token_payload")
                    })
                    .expect("Should be able to get token_payload table");

                let access_token = token_payload
                    .get("access_token")
                    .unwrap()
                    .as_str()
                    .map(str::to_string)
                    .map(SecretAccessToken::from);

                assert_eq!(access_token, Some(new_access_token));

                // The refresh must also stamp `updated_at` with a time after it started.
                assert!(
                    OffsetDateTime::parse(
                        token_payload.get("updated_at").unwrap().as_str().unwrap(),
                        &Rfc3339
                    )
                    .unwrap()
                        > before_refresh
                );

                let content = std::fs::read_to_string("secrets.toml").unwrap();
                assert!(
                    content.contains("new_access_token"),
                    "File should be updated with new access token when QCS_SECRETS_READ_ONLY is not set or is set but disabled, and file permissions allow writing"
                );
            });
            Ok(())
        });
    }

    /// A session's `Debug` output must hide both the client credentials and the access token.
    #[test]
    fn test_auth_session_debug_fmt() {
        let session = OAuthSession {
            payload: OAuthGrant::ClientCredentials(ClientCredentials::new(
                "hidden_id",
                "hidden_secret",
            )),
            access_token: Some(SecretAccessToken::from("token")),
            auth_server: AuthServer {
                client_id: "some_id".into(),
                issuer: "some_url".into(),
                scopes: None,
            },
        };

        assert_eq!(
            "OAuthSession { payload: ClientCredentials, access_token: Some(()), auth_server: AuthServer { client_id: \"some_id\", issuer: \"some_url\", scopes: None } }",
            &format!("{session:?}")
        );
    }
}