Skip to main content

qcs_api_client_common/configuration/
tokens.rs

1//! Models and utilities for managing `OAuth2` sessions.
2use std::{pin::Pin, sync::Arc};
3
4use futures::Future;
5use jsonwebtoken::{Algorithm, DecodingKey, Validation};
6use oauth2::TokenResponse;
7use serde::{Deserialize, Serialize};
8use time::OffsetDateTime;
9use tokio::sync::{Mutex, Notify, RwLock};
10use tokio_util::sync::CancellationToken;
11
12#[cfg(feature = "stubs")]
13use pyo3_stub_gen::derive::gen_stub_pyclass;
14
15use super::{
16    oidc, secrets::Secrets, settings::AuthServer, ClientConfiguration, ConfigSource, TokenError,
17};
18use crate::configuration::{
19    error::DiscoveryError,
20    pkce::{pkce_login, PkceLoginError, PkceLoginRequest},
21    secrets::{Credential, SecretAccessToken, SecretRefreshToken, TokenPayload},
22};
23#[cfg(feature = "tracing-config")]
24use crate::tracing_configuration::TracingConfiguration;
25#[cfg(feature = "tracing")]
26use urlpattern::UrlPatternMatchInput;
27
28pub use super::secret_string::ClientSecret;
29
30/// A single type containing an access token and an associated refresh token.
31#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
32#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
33#[cfg_attr(
34    feature = "python",
35    pyo3::pyclass(eq, get_all, set_all, module = "qcs_api_client_common.configuration")
36)]
37pub struct RefreshToken {
38    /// The token used to refresh the access token.
39    pub refresh_token: SecretRefreshToken,
40}
41
42impl RefreshToken {
43    /// Create a new [`RefreshToken`] with the given refresh token.
44    #[must_use]
45    pub const fn new(refresh_token: SecretRefreshToken) -> Self {
46        Self { refresh_token }
47    }
48
49    /// Request and return a new access token from the given authorization server using this refresh token.
50    ///
51    /// # Errors
52    ///
53    /// See [`TokenError`]
54    pub async fn request_access_token(
55        &mut self,
56        auth_server: &AuthServer,
57    ) -> Result<SecretAccessToken, TokenError> {
58        if self.refresh_token.is_empty() {
59            return Err(TokenError::NoRefreshToken);
60        }
61
62        let client = default_http_client()?;
63        let token_url = oidc::fetch_discovery(&client, &auth_server.issuer)
64            .await?
65            .token_endpoint;
66        let data = TokenRefreshRequest::new(&auth_server.client_id, self.refresh_token.secret());
67        let resp = client.post(token_url).form(&data).send().await?;
68
69        let RefreshTokenResponse {
70            access_token,
71            refresh_token,
72        } = resp.error_for_status()?.json().await?;
73
74        if let Some(refresh_token) = refresh_token {
75            self.refresh_token = refresh_token;
76        }
77        Ok(access_token)
78    }
79}
80
81#[derive(Deserialize, Debug, Serialize)]
82pub(super) struct ClientCredentialsResponse {
83    pub(super) access_token: SecretAccessToken,
84}
85
86/// A pair of Client ID and Client Secret, used to request an OAuth Client Credentials Grant
87#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
88#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
89#[cfg_attr(
90    feature = "python",
91    pyo3::pyclass(eq, get_all, frozen, module = "qcs_api_client_common.configuration")
92)]
93pub struct ClientCredentials {
94    /// The client ID
95    pub client_id: String,
96    /// The client secret.
97    pub client_secret: ClientSecret,
98}
99
100impl ClientCredentials {
101    #[must_use]
102    /// Construct a new [`ClientCredentials`]
103    pub fn new(client_id: impl Into<String>, client_secret: impl Into<ClientSecret>) -> Self {
104        Self {
105            client_id: client_id.into(),
106            client_secret: client_secret.into(),
107        }
108    }
109
110    /// Get the client ID.
111    #[must_use]
112    pub fn client_id(&self) -> &str {
113        &self.client_id
114    }
115
116    /// Get the client secret.
117    #[must_use]
118    pub const fn client_secret(&self) -> &ClientSecret {
119        &self.client_secret
120    }
121
122    /// Request and return an access token from the given auth server using this set of client credentials.
123    ///
124    /// # Errors
125    ///
126    /// See [`TokenError`]
127    pub async fn request_access_token(
128        &self,
129        auth_server: &AuthServer,
130    ) -> Result<SecretAccessToken, TokenError> {
131        let request = ClientCredentialsRequest::new(None);
132        let client = default_http_client()?;
133
134        let url = oidc::fetch_discovery(&client, &auth_server.issuer)
135            .await?
136            .token_endpoint;
137        let ready_to_send = client
138            .post(url)
139            .basic_auth(&auth_server.client_id, Some(&self.client_secret.secret()))
140            .form(&request);
141        let response = ready_to_send.send().await?;
142
143        response.error_for_status_ref()?;
144
145        let ClientCredentialsResponse { access_token } = response.json().await?;
146        Ok(access_token)
147    }
148}
149
150#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
151#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
152#[cfg_attr(
153    feature = "python",
154    pyo3::pyclass(eq, get_all, frozen, module = "qcs_api_client_common.configuration")
155)]
156/// The Access (Bearer) and refresh (if available) tokens from a PKCE login.
157pub struct PkceFlow {
158    /// The access token.
159    pub access_token: SecretAccessToken,
160    /// The refresh token, if available.
161    pub refresh_token: Option<RefreshToken>,
162}
163
164/// Errors that can occur when attempting to perform a PKCE login flow.
165#[derive(Debug, thiserror::Error)]
166pub enum PkceFlowError {
167    /// Error that occurred while performing the PKCE login flow.
168    #[error(transparent)]
169    PkceLogin(#[from] PkceLoginError),
170    /// Error that occurred while fetching the discovery document from the `OAuth2` issuer.
171    #[error(transparent)]
172    Discovery(#[from] DiscoveryError),
173    /// Error that occurred while making http requests.
174    #[error(transparent)]
175    Request(#[from] reqwest::Error),
176}
177
178impl PkceFlow {
179    /// Starts a new PKCE login flow to acquire a new set of tokens.
180    ///
181    /// # Errors
182    ///
183    /// See [`PkceFlowError`]
184    pub async fn new_login_flow(
185        cancel_token: CancellationToken,
186        auth_server: &AuthServer,
187    ) -> Result<Self, PkceFlowError> {
188        let issuer = auth_server.issuer.clone();
189
190        let client = default_http_client()?;
191        let discovery = oidc::fetch_discovery(&client, &issuer).await?;
192
193        let response = pkce_login(
194            cancel_token,
195            PkceLoginRequest {
196                client_id: auth_server.client_id.clone(),
197                redirect_port: None,
198                discovery,
199                scopes: auth_server.scopes.clone(),
200            },
201        )
202        .await?;
203
204        Ok(Self {
205            access_token: SecretAccessToken::from(response.access_token().secret().clone()),
206            refresh_token: response
207                .refresh_token()
208                .map(|rt| RefreshToken::new(SecretRefreshToken::from(rt.secret().clone()))),
209        })
210    }
211
212    /// Returns the access token if it is valid, otherwise requests a new access token using the refresh token if available.
213    ///
214    /// # Errors
215    ///
216    /// See [`TokenError`]
217    pub async fn request_access_token(
218        &mut self,
219        auth_server: &AuthServer,
220    ) -> Result<SecretAccessToken, TokenError> {
221        if insecure_validate_token_exp(&self.access_token).is_ok() {
222            return Ok(self.access_token.clone());
223        }
224
225        if let Some(refresh_token) = &mut self.refresh_token {
226            let access_token = refresh_token.request_access_token(auth_server).await?;
227            self.access_token.clone_from(&access_token);
228            return Ok(access_token);
229        }
230
231        Err(TokenError::NoRefreshToken)
232    }
233}
234
235impl From<PkceFlow> for Credential {
236    fn from(value: PkceFlow) -> Self {
237        let mut token_payload = TokenPayload::default();
238        token_payload.access_token = Some(value.access_token);
239        token_payload.refresh_token = value.refresh_token.map(|rt| rt.refresh_token);
240
241        Self {
242            token_payload: Some(token_payload),
243        }
244    }
245}
246
247#[derive(Clone)]
248#[cfg_attr(feature = "python", derive(pyo3::FromPyObject, pyo3::IntoPyObject))]
249/// Specifies the [OAuth2 grant type](https://oauth.net/2/grant-types/) to use, along with the data
250/// needed to request said grant type.
251pub enum OAuthGrant {
252    /// Credentials that can be used to use with the [Refresh Token grant type](https://oauth.net/2/grant-types/refresh-token/).
253    RefreshToken(RefreshToken),
254    /// Payload that can be used to use the [Client Credentials grant type](https://oauth.net/2/grant-types/client-credentials/).
255    ClientCredentials(ClientCredentials),
256    /// Defers to a user provided function for access token requests.
257    ExternallyManaged(ExternallyManaged),
258    /// The tokens returned by the PKCE login that are an [Authorization Code grant type](https://oauth.net/2/pkce/).
259    PkceFlow(PkceFlow),
260}
261
262impl From<ExternallyManaged> for OAuthGrant {
263    fn from(v: ExternallyManaged) -> Self {
264        Self::ExternallyManaged(v)
265    }
266}
267
268impl From<ClientCredentials> for OAuthGrant {
269    fn from(v: ClientCredentials) -> Self {
270        Self::ClientCredentials(v)
271    }
272}
273
274impl From<RefreshToken> for OAuthGrant {
275    fn from(v: RefreshToken) -> Self {
276        Self::RefreshToken(v)
277    }
278}
279
280impl From<PkceFlow> for OAuthGrant {
281    fn from(v: PkceFlow) -> Self {
282        Self::PkceFlow(v)
283    }
284}
285
286impl OAuthGrant {
287    /// Request a new access token from the given issuer using this grant type and payload.
288    async fn request_access_token(
289        &mut self,
290        auth_server: &AuthServer,
291    ) -> Result<SecretAccessToken, TokenError> {
292        match self {
293            Self::RefreshToken(tokens) => tokens.request_access_token(auth_server).await,
294            Self::ClientCredentials(tokens) => tokens.request_access_token(auth_server).await,
295            Self::ExternallyManaged(tokens) => tokens
296                .request_access_token(auth_server)
297                .await
298                .map_err(|e| TokenError::ExternallyManaged(e.to_string())),
299            Self::PkceFlow(tokens) => tokens.request_access_token(auth_server).await,
300        }
301    }
302}
303
304impl std::fmt::Debug for OAuthGrant {
305    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
306        match self {
307            Self::RefreshToken(_) => f.write_str("RefreshToken"),
308            Self::ClientCredentials(_) => f.write_str("ClientCredentials"),
309            Self::ExternallyManaged(_) => f.write_str("ExternallyManaged"),
310            Self::PkceFlow(_) => f.write_str("PkceTokens"),
311        }
312    }
313}
314
315/// Manages the `OAuth2` authorization process and token lifecycle for accessing the QCS API.
316///
317/// This struct encapsulates the necessary information to request an access token
318/// from an authorization server, including the `OAuth2` grant type and any associated
319/// credentials or payload data.
320///
321/// # Fields
322///
323/// * `payload` - The `OAuth2` grant type and associated data that will be used to request an access token.
324/// * `access_token` - The access token currently in use, if any. If no token has been provided or requested yet, this will be `None`.
325/// * `auth_server` - The authorization server responsible for issuing tokens.
326#[derive(Clone)]
327#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
328#[cfg_attr(
329    feature = "python",
330    pyo3::pyclass(module = "qcs_api_client_common.configuration", frozen, get_all)
331)]
332pub struct OAuthSession {
333    /// The grant type to use to request an access token.
334    payload: OAuthGrant,
335    /// The access token that is currently in use. None if no token has been requested yet.
336    access_token: Option<SecretAccessToken>,
337    /// The [`AuthServer`] that issues the tokens.
338    auth_server: AuthServer,
339}
340
341impl OAuthSession {
342    /// Initialize a new set of [`Credentials`] using a [`GrantPayload`].
343    ///
344    /// Optionally include an `access_token`, if not included, then one can be requested
345    /// with [`Self::request_access_token`].
346    #[must_use]
347    pub const fn new(
348        payload: OAuthGrant,
349        auth_server: AuthServer,
350        access_token: Option<SecretAccessToken>,
351    ) -> Self {
352        Self {
353            payload,
354            access_token,
355            auth_server,
356        }
357    }
358
359    /// Initialize a new set of [`Credentials`] using an [`ExternallyManaged`].
360    ///
361    /// Optionally include an `access_token`, if not included, then one can be requested
362    /// with [`Self::request_access_token`].
363    #[must_use]
364    pub const fn from_externally_managed(
365        tokens: ExternallyManaged,
366        auth_server: AuthServer,
367        access_token: Option<SecretAccessToken>,
368    ) -> Self {
369        Self::new(
370            OAuthGrant::ExternallyManaged(tokens),
371            auth_server,
372            access_token,
373        )
374    }
375
376    /// Initialize a new set of [`Credentials`] using a [`RefreshToken`].
377    ///
378    /// Optionally include an `access_token`, if not included, then one can be requested
379    /// with [`Self::request_access_token`].
380    #[must_use]
381    pub const fn from_refresh_token(
382        tokens: RefreshToken,
383        auth_server: AuthServer,
384        access_token: Option<SecretAccessToken>,
385    ) -> Self {
386        Self::new(OAuthGrant::RefreshToken(tokens), auth_server, access_token)
387    }
388
389    /// Initialize a new set of [`Credentials`] using [`ClientCredentials`].
390    ///
391    /// Optionally include an `access_token`, if not included, then one can be requested
392    /// with [`Self::request_access_token`].
393    #[must_use]
394    pub const fn from_client_credentials(
395        tokens: ClientCredentials,
396        auth_server: AuthServer,
397        access_token: Option<SecretAccessToken>,
398    ) -> Self {
399        Self::new(
400            OAuthGrant::ClientCredentials(tokens),
401            auth_server,
402            access_token,
403        )
404    }
405
406    /// Initialize a new set of [`Credentials`] using [`PkceFlow`].
407    ///
408    /// Optionally include an `access_token`, if not included, then one can be requested
409    /// with [`Self::request_access_token`].
410    #[must_use]
411    pub const fn from_pkce_flow(
412        flow: PkceFlow,
413        auth_server: AuthServer,
414        access_token: Option<SecretAccessToken>,
415    ) -> Self {
416        Self::new(OAuthGrant::PkceFlow(flow), auth_server, access_token)
417    }
418
419    /// Get the current access token.
420    ///
421    /// This is an unvalidated copy of the access token. Meaning it can become stale, or may
422    /// even be already be stale. See [`Self::validate`] and [`Self::request_access_token`].
423    ///
424    /// # Errors
425    ///
426    /// - [`TokenError::NoAccessToken`] if there is no access token
427    pub fn access_token(&self) -> Result<&SecretAccessToken, TokenError> {
428        self.access_token.as_ref().ok_or(TokenError::NoAccessToken)
429    }
430
431    /// Get the payload used to request an access token.
432    #[must_use]
433    pub const fn payload(&self) -> &OAuthGrant {
434        &self.payload
435    }
436
437    /// Request and return an updated access token using these credentials.
438    ///
439    /// # Errors
440    ///
441    /// See [`TokenError`]
442    #[allow(clippy::missing_panics_doc)]
443    pub async fn request_access_token(&mut self) -> Result<&SecretAccessToken, TokenError> {
444        let access_token = self.payload.request_access_token(&self.auth_server).await?;
445        Ok(self.access_token.insert(access_token))
446    }
447
448    /// The [`AuthServer`] that issues the tokens.
449    #[must_use]
450    pub const fn auth_server(&self) -> &AuthServer {
451        &self.auth_server
452    }
453
454    /// Validate the access token, returning it if it is valid, or an error describing why it is
455    /// invalid.
456    ///
457    /// # Errors
458    ///
459    /// - [`TokenError::NoAccessToken`] if an access token has not been requested.
460    /// - [`TokenError::InvalidAccessToken`] if the access token is invalid.
461    pub fn validate(&self) -> Result<SecretAccessToken, TokenError> {
462        let access_token = self.access_token()?;
463        insecure_validate_token_exp(access_token)?;
464        Ok(access_token.clone())
465    }
466}
467
468/// Validates the access token's format and `exp` claim, but no other claims or
469/// signature. We do this only to determine if the token is expired and needs refreshing,
470/// there is no way to securely validate the token's signature on the client side.
471pub(crate) fn insecure_validate_token_exp(
472    access_token: &SecretAccessToken,
473) -> Result<(), TokenError> {
474    let placeholder_key = DecodingKey::from_secret(&[]);
475    let mut validation = Validation::new(Algorithm::RS256);
476    validation.validate_exp = true;
477    validation.leeway = 60;
478    validation.validate_aud = false;
479    validation.insecure_disable_signature_validation();
480
481    jsonwebtoken::decode::<toml::Value>(access_token.secret(), &placeholder_key, &validation)
482        .map(|_| ())
483        .map_err(TokenError::InvalidAccessToken)
484}
485
486impl std::fmt::Debug for OAuthSession {
487    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
488        let token_populated = if self.access_token.is_some() {
489            Some(())
490        } else {
491            None
492        };
493        f.debug_struct("OAuthSession")
494            .field("payload", &self.payload)
495            .field("access_token", &token_populated)
496            .field("auth_server", &self.auth_server)
497            .finish()
498    }
499}
500
501/// A wrapper for [`OAuthSession`] that provides thread-safe access to the inner tokens.
502#[derive(Clone, Debug)]
503#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
504#[cfg_attr(
505    feature = "python",
506    pyo3::pyclass(module = "qcs_api_client_common.configuration", frozen)
507)]
508pub struct TokenDispatcher {
509    lock: Arc<RwLock<OAuthSession>>,
510    refreshing: Arc<Mutex<bool>>,
511    notify_refreshed: Arc<Notify>,
512}
513
514impl From<OAuthSession> for TokenDispatcher {
515    fn from(value: OAuthSession) -> Self {
516        Self {
517            lock: Arc::new(RwLock::new(value)),
518            refreshing: Arc::new(Mutex::new(false)),
519            notify_refreshed: Arc::new(Notify::new()),
520        }
521    }
522}
523
524impl TokenDispatcher {
525    /// Executes a user-provided closure on a reference to the `Tokens` instance managed by the
526    /// dispatcher.
527    ///
528    /// This function locks the mutex, safely exposing the protected `Tokens` instance to the provided closure `f`.
529    /// It is designed to allow safe and controlled access to the `Tokens` instance for reading its state.
530    ///
531    /// # Parameters
532    /// - `f`: A closure that takes a reference to `Tokens` and returns a value of type `O`. The closure is called
533    ///   with the `Tokens` instance as an argument once the mutex is successfully locked.
534    pub async fn use_tokens<F, O>(&self, f: F) -> O
535    where
536        F: FnOnce(&OAuthSession) -> O + Send,
537    {
538        let tokens = self.lock.read().await;
539        f(&tokens)
540    }
541
542    /// Get a copy of the current access token.
543    #[must_use]
544    pub async fn tokens(&self) -> OAuthSession {
545        self.use_tokens(Clone::clone).await
546    }
547
548    /// Refreshes the tokens. Readers will be blocked until the refresh is complete.
549    ///
550    /// # Errors
551    ///
552    /// See [`TokenError`]
553    pub async fn refresh(
554        &self,
555        source: &ConfigSource,
556        profile: &str,
557    ) -> Result<OAuthSession, TokenError> {
558        self.managed_refresh(Self::perform_refresh, source, profile)
559            .await
560    }
561
562    /// Validate the access token, returning it if it is valid, or an error describing why it is
563    /// invalid.
564    ///
565    /// # Errors
566    ///
567    /// - [`TokenError::NoAccessToken`] if there is no access token
568    /// - [`TokenError::InvalidAccessToken`] if the access token is invalid
569    pub async fn validate(&self) -> Result<SecretAccessToken, TokenError> {
570        self.use_tokens(OAuthSession::validate).await
571    }
572
573    /// If tokens are already being refreshed, wait and return the updated tokens. Otherwise, run
574    /// ``refresh_fn``.
575    async fn managed_refresh<F, Fut>(
576        &self,
577        refresh_fn: F,
578        source: &ConfigSource,
579        profile: &str,
580    ) -> Result<OAuthSession, TokenError>
581    where
582        F: FnOnce(Arc<RwLock<OAuthSession>>) -> Fut + Send,
583        Fut: Future<Output = Result<OAuthSession, TokenError>> + Send,
584    {
585        let mut is_refreshing = self.refreshing.lock().await;
586
587        if *is_refreshing {
588            drop(is_refreshing);
589            self.notify_refreshed.notified().await;
590            return Ok(self.tokens().await);
591        }
592
593        *is_refreshing = true;
594        drop(is_refreshing);
595
596        let oauth_session = refresh_fn(self.lock.clone()).await?;
597
598        // If the config source is a file, write the new access token to the file
599        let write_result = if let ConfigSource::File {
600            settings_path: _,
601            secrets_path,
602        } = source
603        {
604            match Secrets::is_read_only(secrets_path).await {
605                Ok(true) => Ok(()),
606                Ok(false) => {
607                    // If the payload is a PkceFlow, write the fresh refresh token if available.
608                    let refresh_token = match &oauth_session.payload {
609                        OAuthGrant::PkceFlow(payload) => {
610                            payload.refresh_token.as_ref().map(|rt| &rt.refresh_token)
611                        }
612                        _ => None,
613                    };
614
615                    let now = OffsetDateTime::now_utc();
616                    Secrets::write_tokens(
617                        secrets_path,
618                        profile,
619                        refresh_token,
620                        oauth_session.access_token()?,
621                        now,
622                    )
623                    .await
624                }
625                Err(e) => Err(e),
626            }
627        } else {
628            Ok(())
629        };
630
631        // Always clean up the refreshing lock, even if write failed
632        *self.refreshing.lock().await = false;
633        self.notify_refreshed.notify_waiters();
634
635        // If write failed, return error with the valid oauth_session
636        if let Err(error) = write_result {
637            return Err(TokenError::Write {
638                error,
639                oauth_session: Box::new(oauth_session),
640            });
641        }
642
643        Ok(oauth_session)
644    }
645
646    /// Refreshes the tokens. Readers will be blocked until the refresh is complete. Returns a copy
647    /// of the updated [`Credentials`]
648    ///
649    /// # Errors
650    ///
651    /// See [`TokenError`]
652    async fn perform_refresh(lock: Arc<RwLock<OAuthSession>>) -> Result<OAuthSession, TokenError> {
653        let mut credentials = lock.write().await;
654        credentials.request_access_token().await?;
655        Ok(credentials.clone())
656    }
657}
658
659pub(crate) type RefreshResult =
660    Pin<Box<dyn Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>> + Send>>;
661
662/// A function that asynchronously refreshes a token.
663pub type RefreshFunction = Box<dyn (Fn(AuthServer) -> RefreshResult) + Send + Sync>;
664
665/// A struct that manages access tokens by utilizing a user-provided refresh function.
666///
667/// The [`ExternallyManaged`] struct allows users to define custom logic for
668/// fetching or refreshing access tokens.
669#[derive(Clone)]
670#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
671#[cfg_attr(
672    feature = "python",
673    pyo3::pyclass(module = "qcs_api_client_common.configuration", frozen)
674)]
675pub struct ExternallyManaged {
676    refresh_function: Arc<RefreshFunction>,
677}
678
679impl ExternallyManaged {
680    /// Creates a new [`ExternallyManaged`] instance from a [`RefreshFunction`].
681    ///
682    /// Consider using [`ExternallyManaged::from_async`], and [`ExternallyManaged::from_sync`], if
683    /// they better fit your use case.
684    ///
685    /// # Arguments
686    ///
687    /// * `refresh_function` - A function or closure that asynchronously refreshes a token.
688    ///
689    /// # Example
690    ///
691    /// ```
692    /// use qcs_api_client_common::configuration::{settings::AuthServer, tokens::ExternallyManaged, TokenError};
693    /// use std::future::Future;
694    /// use std::pin::Pin;
695    /// use std::boxed::Box;
696    /// use std::error::Error;
697    ///
698    /// async fn example_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
699    /// + Send + Sync>> {
700    ///     Ok("new_token_value".to_string())
701    /// }
702    /// let token_manager = ExternallyManaged::new(|auth_server| Box::pin(example_refresh_function(auth_server)));
703    /// ```
704    pub fn new(
705        refresh_function: impl Fn(AuthServer) -> RefreshResult + Send + Sync + 'static,
706    ) -> Self {
707        Self {
708            refresh_function: Arc::new(Box::new(refresh_function)),
709        }
710    }
711
712    /// Constructs a new [`ExternallyManaged`] instance using an async function or closure.
713    ///
714    /// This method simplifies the creation of the [`ExternallyManaged`] instance by handling
715    /// the boxing and pinning of the future internally.
716    ///
717    /// # Arguments
718    ///
719    /// * `refresh_function` - An async function or closure that returns a [`Future`] which, when awaited,
720    ///   produces a [`Result<String, TokenError>`].
721    ///
722    /// # Example
723    ///
724    /// ```
725    /// use qcs_api_client_common::configuration::{settings::AuthServer, tokens::ExternallyManaged, TokenError};
726    /// use tokio::runtime::Runtime;
727    /// use std::error::Error;
728    ///
729    /// async fn example_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
730    /// + Send + Sync>> {
731    ///     Ok("new_token_value".to_string())
732    /// }
733    ///
734    /// let token_manager = ExternallyManaged::from_async(example_refresh_function);
735    ///
736    /// let rt = Runtime::new().unwrap();
737    /// rt.block_on(async {
738    ///     match token_manager.request_access_token(&AuthServer::default()).await {
739    ///         Ok(token) => println!("Token: {token:?}"),
740    ///         Err(e) => println!("Failed to refresh token: {:?}", e),
741    ///     }
742    /// });
743    /// ```
744    pub fn from_async<F, Fut>(refresh_function: F) -> Self
745    where
746        F: Fn(AuthServer) -> Fut + Send + Sync + 'static,
747        Fut: Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>>
748            + Send
749            + 'static,
750    {
751        Self {
752            refresh_function: Arc::new(Box::new(move |auth_server| {
753                Box::pin(refresh_function(auth_server))
754            })),
755        }
756    }
757
758    /// Constructs a new [`ExternallyManaged`] instance using a synchronous function.
759    ///
760    /// The synchronous function is wrapped in an async block to fit the expected signature.
761    ///
762    /// # Arguments
763    ///
764    /// * `refresh_function` - A synchronous function that returns a [`Result<String, TokenError>`].
765    ///
766    /// # Example
767    ///
768    /// ```
769    /// use qcs_api_client_common::configuration::{settings::AuthServer, tokens::ExternallyManaged, TokenError};
770    /// use tokio::runtime::Runtime;
771    /// use std::error::Error;
772    ///
773    /// fn example_sync_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
774    /// + Send + Sync>> {
775    ///     Ok("sync_token_value".to_string())
776    /// }
777    ///
778    /// let token_manager = ExternallyManaged::from_sync(example_sync_refresh_function);
779    ///
780    /// let rt = Runtime::new().unwrap();
781    /// rt.block_on(async {
782    ///     match token_manager.request_access_token(&AuthServer::default()).await {
783    ///         Ok(token) => println!("Token: {token:?}"),
784    ///         Err(e) => println!("Failed to refresh token: {:?}", e),
785    ///     }
786    /// });
787    /// ```
788    pub fn from_sync(
789        refresh_function: impl Fn(AuthServer) -> Result<String, Box<dyn std::error::Error + Send + Sync>>
790            + Send
791            + Sync
792            + 'static,
793    ) -> Self {
794        Self {
795            refresh_function: Arc::new(Box::new(move |auth_server| {
796                let result = refresh_function(auth_server);
797                Box::pin(async move { result })
798            })),
799        }
800    }
801
802    /// Request an updated access token using the provided refresh function.
803    ///
804    /// # Errors
805    ///
806    /// Errors are propagated from the refresh function.
807    pub async fn request_access_token(
808        &self,
809        auth_server: &AuthServer,
810    ) -> Result<SecretAccessToken, Box<dyn std::error::Error + Send + Sync>> {
811        (self.refresh_function)(auth_server.clone())
812            .await
813            .map(SecretAccessToken::from)
814    }
815}
816
817impl std::fmt::Debug for ExternallyManaged {
818    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
819        f.debug_struct("ExternallyManaged")
820            .field(
821                "refresh_function",
822                &"Fn() -> Pin<Box<dyn Future<Output = Result<String, TokenError>> + Send>>",
823            )
824            .finish()
825    }
826}
827
828#[derive(Debug, Serialize, Deserialize)]
829pub(super) struct TokenRefreshRequest<'a> {
830    grant_type: &'static str,
831    client_id: &'a str,
832    refresh_token: &'a str,
833}
834
835impl<'a> TokenRefreshRequest<'a> {
836    pub(super) const fn new(client_id: &'a str, refresh_token: &'a str) -> Self {
837        Self {
838            grant_type: "refresh_token",
839            client_id,
840            refresh_token,
841        }
842    }
843}
844
845#[derive(Debug, Serialize, Deserialize)]
846pub(super) struct ClientCredentialsRequest {
847    grant_type: &'static str,
848    scope: Option<&'static str>,
849}
850
851impl ClientCredentialsRequest {
852    pub(super) const fn new(scope: Option<&'static str>) -> Self {
853        Self {
854            grant_type: "client_credentials",
855            scope,
856        }
857    }
858}
859
860#[derive(Deserialize, Debug, Serialize)]
861pub(super) struct RefreshTokenResponse {
862    pub(super) refresh_token: Option<SecretRefreshToken>,
863    pub(super) access_token: SecretAccessToken,
864}
865
866/// Get and refresh access tokens
867#[async_trait::async_trait]
868pub trait TokenRefresher: Clone + std::fmt::Debug + Send {
869    /// The type to be returned in the event of a error during getting or
870    /// refreshing an access token
871    type Error;
872
873    /// Get and validate the current access token, refreshing it if it doesn't exist or is invalid.
874    async fn validated_access_token(&self) -> Result<SecretAccessToken, Self::Error>;
875
876    /// Get the current access token, if any
877    async fn get_access_token(&self) -> Result<Option<SecretAccessToken>, Self::Error>;
878
879    /// Get a fresh access token
880    async fn refresh_access_token(&self) -> Result<SecretAccessToken, Self::Error>;
881
882    /// Get the base URL for requests
883    #[cfg(feature = "tracing")]
884    fn base_url(&self) -> &str;
885
886    /// Get the tracing configuration
887    #[cfg(feature = "tracing-config")]
888    fn tracing_configuration(&self) -> Option<&TracingConfiguration>;
889
890    /// Returns whether the given URL should be traced. Following
891    /// [`TracingConfiguration::is_enabled`], this defaults to `true`.
892    #[cfg(feature = "tracing")]
893    #[allow(clippy::needless_return)]
894    fn should_trace(&self, url: &UrlPatternMatchInput) -> bool {
895        #[cfg(not(feature = "tracing-config"))]
896        {
897            let _ = url;
898            return true;
899        }
900
901        #[cfg(feature = "tracing-config")]
902        self.tracing_configuration()
903            .is_none_or(|config| config.is_enabled(url))
904    }
905}
906
907#[async_trait::async_trait]
908impl TokenRefresher for ClientConfiguration {
909    type Error = TokenError;
910
911    async fn validated_access_token(&self) -> Result<SecretAccessToken, Self::Error> {
912        self.get_bearer_access_token().await
913    }
914
915    async fn refresh_access_token(&self) -> Result<SecretAccessToken, Self::Error> {
916        match self.refresh().await {
917            Ok(session) => Ok(session.access_token()?.clone()),
918            Err(TokenError::Write {
919                error,
920                oauth_session,
921            }) => {
922                // Token refresh succeeded but persistence failed. Extract and return the access token from the error.
923                #[cfg(feature = "tracing")]
924                tracing::warn!(
925                    "Token refresh succeeded but failed to persist: {}. Returning access token from error.",
926                    error
927                );
928                Ok(oauth_session.access_token()?.clone())
929            }
930            Err(e) => Err(e),
931        }
932    }
933
934    async fn get_access_token(&self) -> Result<Option<SecretAccessToken>, Self::Error> {
935        Ok(Some(self.oauth_session().await?.access_token()?.clone()))
936    }
937
938    #[cfg(feature = "tracing")]
939    fn base_url(&self) -> &str {
940        &self.grpc_api_url
941    }
942
943    #[cfg(feature = "tracing-config")]
944    fn tracing_configuration(&self) -> Option<&TracingConfiguration> {
945        self.tracing_configuration.as_ref()
946    }
947}
948
949/// Get a default http client.
950pub(super) fn default_http_client() -> Result<reqwest::Client, reqwest::Error> {
951    reqwest::Client::builder()
952        .timeout(std::time::Duration::from_secs(10))
953        .build()
954}
955
#[cfg(test)]
mod test {
    #![allow(clippy::result_large_err, reason = "happens in figment tests")]

    use std::time::Duration;

    use super::*;
    use httpmock::prelude::*;
    use rstest::rstest;
    use time::format_description::well_known::Rfc3339;
    use tokio::time::Instant;
    use toml_edit::DocumentMut;

    /// A read of the dispatcher's tokens issued while a refresh is in flight must wait for the
    /// refresh to finish and then observe the refreshed tokens.
    #[tokio::test]
    async fn test_tokens_blocked_during_refresh() {
        let mock_server = MockServer::start_async().await;

        let oidc_mock = mock_server
            .mock_async(|when, then| {
                when.method(GET).path("/.well-known/openid-configuration");
                then.status(200)
                    .json_body_obj(&oidc::Discovery::new_for_test(
                        mock_server.base_url().parse().unwrap(),
                    ));
            })
            .await;

        // How long the mock token endpoint waits before answering.
        let refresh_duration = Duration::from_secs(3);

        let issuer_mock = mock_server
            .mock_async(|when, then| {
                when.method(POST).path("/v1/token");

                then.status(200)
                    .delay(refresh_duration)
                    .json_body_obj(&RefreshTokenResponse {
                        access_token: SecretAccessToken::from("new_access"),
                        refresh_token: Some(SecretRefreshToken::from("new_refresh")),
                    });
            })
            .await;

        let original_tokens = OAuthSession::from_refresh_token(
            RefreshToken::new(SecretRefreshToken::from("refresh")),
            AuthServer {
                client_id: "client_id".to_string(),
                issuer: mock_server.base_url(),
                scopes: None,
            },
            None,
        );
        let dispatcher: TokenDispatcher = original_tokens.clone().into();
        let dispatcher_clone1 = dispatcher.clone();
        let dispatcher_clone2 = dispatcher.clone();

        // Each task records its own elapsed time when it finishes. Measuring after both handles
        // have been awaited (write first) would make the read duration at least the write
        // duration whether or not the read was actually blocked.
        let start_write = Instant::now();
        let write_future = tokio::spawn(async move {
            let _ = dispatcher_clone1
                .refresh(&ConfigSource::Default, "")
                .await
                .unwrap();
            start_write.elapsed()
        });

        let start_read = Instant::now();
        let read_future = tokio::spawn(async move {
            let tokens = dispatcher_clone2.tokens().await;
            (tokens, start_read.elapsed())
        });

        let write_duration = write_future.await.unwrap();
        let (read_result, read_duration) = read_future.await.unwrap();

        oidc_mock.assert_async().await;
        issuer_mock.assert_async().await;

        assert!(
            write_duration >= refresh_duration,
            "Write operation did not take enough time"
        );
        assert!(
            read_duration >= refresh_duration,
            "Read operation was not blocked by the write operation"
        );
        assert_eq!(
            read_result.access_token.unwrap(),
            SecretAccessToken::from("new_access")
        );
        if let OAuthGrant::RefreshToken(payload) = read_result.payload {
            assert_eq!(
                payload.refresh_token,
                SecretRefreshToken::from("new_refresh")
            );
        } else {
            panic!(
                "Expected RefreshToken payload, got {:?}",
                read_result.payload
            );
        }
    }

    /// The secrets file is rewritten after a refresh only when neither `QCS_SECRETS_READ_ONLY`
    /// (truthy values: case-insensitive `true`/`yes`, or `1`) nor read-only file permissions
    /// forbid it.
    #[rstest]
    fn test_qcs_secrets_readonly(
        #[values(
            (Some("TRUE"), true),
            (Some("tRue"), true),
            (Some("true"), true),
            (Some("YES"), true),
            (Some("yEs"), true),
            (Some("yes"), true),
            (Some("1"), true),
            (Some("2"), false),
            (Some("other"), false),
            (Some(""), false),
            (None, false),
        )]
        read_only_values: (Option<&str>, bool),
        #[values(true, false)] read_only_perm: bool,
    ) {
        let (maybe_read_only_env, env_is_read_only) = read_only_values;
        let expected_update = !env_is_read_only && !read_only_perm;
        figment::Jail::expect_with(|jail| {
            let profile_name = "test";
            let initial_access_token = "initial_access_token";
            let initial_refresh_token = "initial_refresh_token";

            let initial_secrets_file_contents = format!(
                r#"
[credentials]
[credentials.{profile_name}]
[credentials.{profile_name}.token_payload]
access_token = "{initial_access_token}"
expires_in = 3600
id_token = "id_token"
refresh_token = "{initial_refresh_token}"
scope = "offline_access openid profile email"
token_type = "Bearer"
updated_at = "2024-01-01T00:00:00Z"
"#
            );

            // Ignore any existing environment variables.
            jail.clear_env();

            // Create a temporary secrets file
            let secrets_path = "secrets.toml";
            jail.create_file(secrets_path, initial_secrets_file_contents.as_str())
                .expect("should create test secrets.toml");

            if read_only_perm {
                let mut permissions = std::fs::metadata(secrets_path)
                    .expect("Should be able to get file metadata")
                    .permissions();
                permissions.set_readonly(true);
                std::fs::set_permissions(secrets_path, permissions)
                    .expect("Should be able to set file permissions");
            }

            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async {
                let mock_server = MockServer::start_async().await;

                let oidc_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(GET).path("/.well-known/openid-configuration");
                        then.status(200).json_body_obj(&oidc::Discovery::new_for_test(
                            mock_server.base_url().parse().unwrap(),
                        ));
                    })
                    .await;

                // Set up the mock token endpoint
                let new_access_token = SecretAccessToken::from("new_access_token");
                let issuer_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(POST).path("/v1/token");
                        then.status(200).json_body_obj(&RefreshTokenResponse {
                            access_token: new_access_token.clone(),
                            refresh_token: Some(SecretRefreshToken::from(initial_refresh_token)),
                        });
                    })
                    .await;

                // Create tokens and dispatcher, seeded with the same tokens as the secrets file.
                let original_tokens = OAuthSession::from_refresh_token(
                    RefreshToken::new(SecretRefreshToken::from(initial_refresh_token)),
                    AuthServer {
                        client_id: "client_id".to_string(),
                        issuer: mock_server.base_url(),
                        scopes: None,
                    },
                    Some(SecretAccessToken::from(initial_access_token)),
                );
                let dispatcher: TokenDispatcher = original_tokens.into();

                // Test with QCS_SECRETS_READ_ONLY set first
                jail.set_env("QCS_SECRETS_FILE_PATH", "secrets.toml");
                jail.set_env("QCS_PROFILE_NAME", "test");
                if let Some(read_only_env) = maybe_read_only_env {
                    jail.set_env("QCS_SECRETS_READ_ONLY", read_only_env);
                }

                let before_refresh = OffsetDateTime::now_utc();

                dispatcher
                    .refresh(
                        &ConfigSource::File {
                            settings_path: "".into(),
                            secrets_path: "secrets.toml".into(),
                        },
                        profile_name,
                    )
                    .await
                    .unwrap();

                oidc_mock.assert_async().await;
                issuer_mock.assert_async().await;

                // Verify the file was not updated if QCS_SECRETS_READ_ONLY is set truthy
                let content = std::fs::read_to_string(secrets_path).unwrap();
                if !expected_update {
                    assert_eq!(
                        content, initial_secrets_file_contents,
                        "File should not be updated when QCS_SECRETS_READ_ONLY is set or file permissions are read-only"
                    );
                    return;
                }

                // Verify the file was updated
                let mut toml = content.parse::<DocumentMut>().unwrap();

                let token_payload = toml
                    .get_mut("credentials")
                    .and_then(|credentials| {
                        credentials.get_mut(profile_name)?.get_mut("token_payload")
                    })
                    .expect("Should be able to get token_payload table");

                let access_token = token_payload
                    .get("access_token")
                    .unwrap()
                    .as_str()
                    .map(str::to_string)
                    .map(SecretAccessToken::from);

                assert_eq!(access_token, Some(new_access_token));

                assert!(
                    OffsetDateTime::parse(
                        token_payload.get("updated_at").unwrap().as_str().unwrap(),
                        &Rfc3339
                    )
                    .unwrap()
                        > before_refresh
                );

                assert!(
                    content.contains("new_access_token"),
                    "File should be updated with new access token when QCS_SECRETS_READ_ONLY is not set or is set but disabled, and file permissions allow writing"
                );
            });
            Ok(())
        });
    }

    /// `Debug` output of a session must not reveal the client credentials or the access token.
    #[test]
    fn test_auth_session_debug_fmt() {
        let session = OAuthSession {
            payload: OAuthGrant::ClientCredentials(ClientCredentials::new(
                "hidden_id",
                "hidden_secret",
            )),
            access_token: Some(SecretAccessToken::from("token")),
            auth_server: AuthServer {
                client_id: "some_id".into(),
                issuer: "some_url".into(),
                scopes: None,
            },
        };

        assert_eq!(
            "OAuthSession { payload: ClientCredentials, access_token: Some(()), auth_server: AuthServer { client_id: \"some_id\", issuer: \"some_url\", scopes: None } }",
            &format!("{session:?}")
        );
    }
}