qcs_api_client_common/configuration/
tokens.rs

1//! Models and utilities for managing `OAuth2` sessions.
2use std::{pin::Pin, sync::Arc};
3
4use futures::Future;
5use jsonwebtoken::{Algorithm, DecodingKey, Validation};
6use oauth2::TokenResponse;
7use serde::{Deserialize, Serialize};
8use time::OffsetDateTime;
9use tokio::sync::{Mutex, Notify, RwLock};
10use tokio_util::sync::CancellationToken;
11
12use super::{
13    oidc, secrets::Secrets, settings::AuthServer, ClientConfiguration, ConfigSource, TokenError,
14};
15use crate::configuration::{
16    error::DiscoveryError,
17    pkce::{pkce_login, PkceLoginError, PkceLoginRequest},
18    secrets::{Credential, SecretAccessToken, SecretRefreshToken, TokenPayload},
19};
20#[cfg(feature = "tracing-config")]
21use crate::tracing_configuration::TracingConfiguration;
22#[cfg(feature = "tracing")]
23use urlpattern::UrlPatternMatchInput;
24
25pub use super::secret_string::ClientSecret;
26
27/// A single type containing an access token and an associated refresh token.
28#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
29#[cfg_attr(feature = "python", pyo3::pyclass)]
30pub struct RefreshToken {
31    /// The token used to refresh the access token.
32    pub refresh_token: SecretRefreshToken,
33}
34
35impl RefreshToken {
36    /// Create a new [`RefreshToken`] with the given refresh token.
37    #[must_use]
38    pub const fn new(refresh_token: SecretRefreshToken) -> Self {
39        Self { refresh_token }
40    }
41
42    /// Request and return a new access token from the given authorization server using this refresh token.
43    ///
44    /// # Errors
45    ///
46    /// See [`TokenError`]
47    pub async fn request_access_token(
48        &mut self,
49        auth_server: &AuthServer,
50    ) -> Result<SecretAccessToken, TokenError> {
51        if self.refresh_token.is_empty() {
52            return Err(TokenError::NoRefreshToken);
53        }
54
55        let client = default_http_client()?;
56        let token_url = oidc::fetch_discovery(&client, &auth_server.issuer)
57            .await?
58            .token_endpoint;
59        let data = TokenRefreshRequest::new(&auth_server.client_id, self.refresh_token.secret());
60        let resp = client.post(token_url).form(&data).send().await?;
61
62        let RefreshTokenResponse {
63            access_token,
64            refresh_token,
65        } = resp.error_for_status()?.json().await?;
66
67        if let Some(refresh_token) = refresh_token {
68            self.refresh_token = refresh_token;
69        }
70        Ok(access_token)
71    }
72}
73
74#[derive(Deserialize, Debug, Serialize)]
75pub(super) struct ClientCredentialsResponse {
76    pub(super) access_token: SecretAccessToken,
77}
78
79#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
80#[cfg_attr(feature = "python", pyo3::pyclass)]
81/// A pair of Client ID and Client Secret, used to request an OAuth Client Credentials Grant
82pub struct ClientCredentials {
83    /// The client ID
84    pub client_id: String,
85    /// The client secret.
86    pub client_secret: ClientSecret,
87}
88
89impl ClientCredentials {
90    #[must_use]
91    /// Construct a new [`ClientCredentials`]
92    pub fn new(client_id: impl Into<String>, client_secret: impl Into<ClientSecret>) -> Self {
93        Self {
94            client_id: client_id.into(),
95            client_secret: client_secret.into(),
96        }
97    }
98
99    /// Get the client ID.
100    #[must_use]
101    pub fn client_id(&self) -> &str {
102        &self.client_id
103    }
104
105    /// Get the client secret.
106    #[must_use]
107    pub const fn client_secret(&self) -> &ClientSecret {
108        &self.client_secret
109    }
110
111    /// Request and return an access token from the given auth server using this set of client credentials.
112    ///
113    /// # Errors
114    ///
115    /// See [`TokenError`]
116    pub async fn request_access_token(
117        &self,
118        auth_server: &AuthServer,
119    ) -> Result<SecretAccessToken, TokenError> {
120        let request = ClientCredentialsRequest::new(None);
121        let client = default_http_client()?;
122
123        let url = oidc::fetch_discovery(&client, &auth_server.issuer)
124            .await?
125            .token_endpoint;
126        let ready_to_send = client
127            .post(url)
128            .basic_auth(&auth_server.client_id, Some(&self.client_secret.secret()))
129            .form(&request);
130        let response = ready_to_send.send().await?;
131
132        response.error_for_status_ref()?;
133
134        let ClientCredentialsResponse { access_token } = response.json().await?;
135        Ok(access_token)
136    }
137}
138
139#[derive(Clone, PartialEq, Eq, Deserialize)]
140#[expect(missing_debug_implementations, reason = "contains secret data")]
141#[cfg_attr(feature = "python", pyo3::pyclass)]
142/// The Access (Bearer) and refresh (if available) tokens from a PKCE login.
143pub struct PkceFlow {
144    /// The access token.
145    pub access_token: SecretAccessToken,
146    /// The refresh token, if available.
147    pub refresh_token: Option<RefreshToken>,
148}
149
150/// Errors that can occur when attempting to perform a PKCE login flow.
151#[derive(Debug, thiserror::Error)]
152pub enum PkceFlowError {
153    /// Error that occurred while performing the PKCE login flow.
154    #[error(transparent)]
155    PkceLogin(#[from] PkceLoginError),
156    /// Error that occurred while fetching the discovery document from the `OAuth2` issuer.
157    #[error(transparent)]
158    Discovery(#[from] DiscoveryError),
159    /// Error that occurred while making http requests.
160    #[error(transparent)]
161    Request(#[from] reqwest::Error),
162}
163
164impl PkceFlow {
165    /// Starts a new PKCE login flow to acquire a new set of tokens.
166    ///
167    /// # Errors
168    ///
169    /// See [`PkceFlowError`]
170    pub async fn new_login_flow(
171        cancel_token: CancellationToken,
172        auth_server: &AuthServer,
173    ) -> Result<Self, PkceFlowError> {
174        let issuer = auth_server.issuer.clone();
175
176        let client = default_http_client()?;
177        let discovery = oidc::fetch_discovery(&client, &issuer).await?;
178
179        let response = pkce_login(
180            cancel_token,
181            PkceLoginRequest {
182                client_id: auth_server.client_id.clone(),
183                redirect_port: None,
184                discovery,
185                scopes: auth_server.scopes.clone(),
186            },
187        )
188        .await?;
189
190        Ok(Self {
191            access_token: SecretAccessToken::from(response.access_token().secret().clone()),
192            refresh_token: response
193                .refresh_token()
194                .map(|rt| RefreshToken::new(SecretRefreshToken::from(rt.secret().clone()))),
195        })
196    }
197
198    /// Returns the access token if it is valid, otherwise requests a new access token using the refresh token if available.
199    ///
200    /// # Errors
201    ///
202    /// See [`TokenError`]
203    pub async fn request_access_token(
204        &mut self,
205        auth_server: &AuthServer,
206    ) -> Result<SecretAccessToken, TokenError> {
207        if insecure_validate_token_exp(&self.access_token).is_ok() {
208            return Ok(self.access_token.clone());
209        }
210
211        if let Some(refresh_token) = &mut self.refresh_token {
212            let access_token = refresh_token.request_access_token(auth_server).await?;
213            self.access_token.clone_from(&access_token);
214            return Ok(access_token);
215        }
216
217        Err(TokenError::NoRefreshToken)
218    }
219}
220
221impl From<PkceFlow> for Credential {
222    fn from(value: PkceFlow) -> Self {
223        let mut token_payload = TokenPayload::default();
224        token_payload.access_token = Some(value.access_token);
225        token_payload.refresh_token = value.refresh_token.map(|rt| rt.refresh_token);
226
227        Self {
228            token_payload: Some(token_payload),
229        }
230    }
231}
232
233#[derive(Clone)]
234#[cfg_attr(feature = "python", derive(pyo3::FromPyObject))]
235/// Specifies the [OAuth2 grant type](https://oauth.net/2/grant-types/) to use, along with the data
236/// needed to request said grant type.
237pub enum OAuthGrant {
238    /// Credentials that can be used to use with the [Refresh Token grant type](https://oauth.net/2/grant-types/refresh-token/).
239    RefreshToken(RefreshToken),
240    /// Payload that can be used to use the [Client Credentials grant type](https://oauth.net/2/grant-types/client-credentials/).
241    ClientCredentials(ClientCredentials),
242    /// Defers to a user provided function for access token requests.
243    ExternallyManaged(ExternallyManaged),
244    /// The tokens returned by the PKCE login that are an [Authorization Code grant type](https://oauth.net/2/pkce/).
245    PkceFlow(PkceFlow),
246}
247
248impl From<ExternallyManaged> for OAuthGrant {
249    fn from(v: ExternallyManaged) -> Self {
250        Self::ExternallyManaged(v)
251    }
252}
253
254impl From<ClientCredentials> for OAuthGrant {
255    fn from(v: ClientCredentials) -> Self {
256        Self::ClientCredentials(v)
257    }
258}
259
260impl From<RefreshToken> for OAuthGrant {
261    fn from(v: RefreshToken) -> Self {
262        Self::RefreshToken(v)
263    }
264}
265
266impl From<PkceFlow> for OAuthGrant {
267    fn from(v: PkceFlow) -> Self {
268        Self::PkceFlow(v)
269    }
270}
271
272impl OAuthGrant {
273    /// Request a new access token from the given issuer using this grant type and payload.
274    async fn request_access_token(
275        &mut self,
276        auth_server: &AuthServer,
277    ) -> Result<SecretAccessToken, TokenError> {
278        match self {
279            Self::RefreshToken(tokens) => tokens.request_access_token(auth_server).await,
280            Self::ClientCredentials(tokens) => tokens.request_access_token(auth_server).await,
281            Self::ExternallyManaged(tokens) => tokens
282                .request_access_token(auth_server)
283                .await
284                .map_err(|e| TokenError::ExternallyManaged(e.to_string())),
285            Self::PkceFlow(tokens) => tokens.request_access_token(auth_server).await,
286        }
287    }
288}
289
290impl std::fmt::Debug for OAuthGrant {
291    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
292        match self {
293            Self::RefreshToken(_) => f.write_str("RefreshToken"),
294            Self::ClientCredentials(_) => f.write_str("ClientCredentials"),
295            Self::ExternallyManaged(_) => f.write_str("ExternallyManaged"),
296            Self::PkceFlow(_) => f.write_str("PkceTokens"),
297        }
298    }
299}
300
301/// Manages the `OAuth2` authorization process and token lifecycle for accessing the QCS API.
302///
303/// This struct encapsulates the necessary information to request an access token
304/// from an authorization server, including the `OAuth2` grant type and any associated
305/// credentials or payload data.
306///
307/// # Fields
308///
309/// * `payload` - The `OAuth2` grant type and associated data that will be used to request an access token.
310/// * `access_token` - The access token currently in use, if any. If no token has been provided or requested yet, this will be `None`.
311/// * `auth_server` - The authorization server responsible for issuing tokens.
312#[derive(Clone)]
313#[cfg_attr(feature = "python", pyo3::pyclass)]
314pub struct OAuthSession {
315    /// The grant type to use to request an access token.
316    payload: OAuthGrant,
317    /// The access token that is currently in use. None if no token has been requested yet.
318    access_token: Option<SecretAccessToken>,
319    /// The [`AuthServer`] that issues the tokens.
320    auth_server: AuthServer,
321}
322
323impl OAuthSession {
324    /// Initialize a new set of [`Credentials`] using a [`GrantPayload`].
325    ///
326    /// Optionally include an `access_token`, if not included, then one can be requested
327    /// with [`Self::request_access_token`].
328    #[must_use]
329    pub const fn new(
330        payload: OAuthGrant,
331        auth_server: AuthServer,
332        access_token: Option<SecretAccessToken>,
333    ) -> Self {
334        Self {
335            payload,
336            access_token,
337            auth_server,
338        }
339    }
340
341    /// Initialize a new set of [`Credentials`] using an [`ExternallyManaged`].
342    ///
343    /// Optionally include an `access_token`, if not included, then one can be requested
344    /// with [`Self::request_access_token`].
345    #[must_use]
346    pub const fn from_externally_managed(
347        tokens: ExternallyManaged,
348        auth_server: AuthServer,
349        access_token: Option<SecretAccessToken>,
350    ) -> Self {
351        Self::new(
352            OAuthGrant::ExternallyManaged(tokens),
353            auth_server,
354            access_token,
355        )
356    }
357
358    /// Initialize a new set of [`Credentials`] using a [`RefreshToken`].
359    ///
360    /// Optionally include an `access_token`, if not included, then one can be requested
361    /// with [`Self::request_access_token`].
362    #[must_use]
363    pub const fn from_refresh_token(
364        tokens: RefreshToken,
365        auth_server: AuthServer,
366        access_token: Option<SecretAccessToken>,
367    ) -> Self {
368        Self::new(OAuthGrant::RefreshToken(tokens), auth_server, access_token)
369    }
370
371    /// Initialize a new set of [`Credentials`] using [`ClientCredentials`].
372    ///
373    /// Optionally include an `access_token`, if not included, then one can be requested
374    /// with [`Self::request_access_token`].
375    #[must_use]
376    pub const fn from_client_credentials(
377        tokens: ClientCredentials,
378        auth_server: AuthServer,
379        access_token: Option<SecretAccessToken>,
380    ) -> Self {
381        Self::new(
382            OAuthGrant::ClientCredentials(tokens),
383            auth_server,
384            access_token,
385        )
386    }
387
388    /// Initialize a new set of [`Credentials`] using [`PkceFlow`].
389    ///
390    /// Optionally include an `access_token`, if not included, then one can be requested
391    /// with [`Self::request_access_token`].
392    #[must_use]
393    pub const fn from_pkce_flow(
394        flow: PkceFlow,
395        auth_server: AuthServer,
396        access_token: Option<SecretAccessToken>,
397    ) -> Self {
398        Self::new(OAuthGrant::PkceFlow(flow), auth_server, access_token)
399    }
400
401    /// Get the current access token.
402    ///
403    /// This is an unvalidated copy of the access token. Meaning it can become stale, or may
404    /// even be already be stale. See [`Self::validate`] and [`Self::request_access_token`].
405    ///
406    /// # Errors
407    ///
408    /// - [`TokenError::NoAccessToken`] if there is no access token
409    pub fn access_token(&self) -> Result<&SecretAccessToken, TokenError> {
410        self.access_token.as_ref().ok_or(TokenError::NoAccessToken)
411    }
412
413    /// Get the payload used to request an access token.
414    #[must_use]
415    pub const fn payload(&self) -> &OAuthGrant {
416        &self.payload
417    }
418
419    /// Request and return an updated access token using these credentials.
420    ///
421    /// # Errors
422    ///
423    /// See [`TokenError`]
424    #[allow(clippy::missing_panics_doc)]
425    pub async fn request_access_token(&mut self) -> Result<&SecretAccessToken, TokenError> {
426        let access_token = self.payload.request_access_token(&self.auth_server).await?;
427        Ok(self.access_token.insert(access_token))
428    }
429
430    /// The [`AuthServer`] that issues the tokens.
431    #[must_use]
432    pub const fn auth_server(&self) -> &AuthServer {
433        &self.auth_server
434    }
435
436    /// Validate the access token, returning it if it is valid, or an error describing why it is
437    /// invalid.
438    ///
439    /// # Errors
440    ///
441    /// - [`TokenError::NoAccessToken`] if an access token has not been requested.
442    /// - [`TokenError::InvalidAccessToken`] if the access token is invalid.
443    pub fn validate(&self) -> Result<SecretAccessToken, TokenError> {
444        let access_token = self.access_token()?;
445        insecure_validate_token_exp(access_token)?;
446        Ok(access_token.clone())
447    }
448}
449
450/// Validates the access token's format and `exp` claim, but no other claims or
451/// signature. We do this only to determine if the token is expired and needs refreshing,
452/// there is no way to securely validate the token's signature on the client side.
453pub(crate) fn insecure_validate_token_exp(
454    access_token: &SecretAccessToken,
455) -> Result<(), TokenError> {
456    let placeholder_key = DecodingKey::from_secret(&[]);
457    let mut validation = Validation::new(Algorithm::RS256);
458    validation.validate_exp = true;
459    validation.leeway = 60;
460    validation.validate_aud = false;
461    validation.insecure_disable_signature_validation();
462
463    jsonwebtoken::decode::<toml::Value>(access_token.secret(), &placeholder_key, &validation)
464        .map(|_| ())
465        .map_err(TokenError::InvalidAccessToken)
466}
467
468impl std::fmt::Debug for OAuthSession {
469    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
470        let token_populated = if self.access_token.is_some() {
471            Some(())
472        } else {
473            None
474        };
475        f.debug_struct("OAuthSession")
476            .field("payload", &self.payload)
477            .field("access_token", &token_populated)
478            .field("auth_server", &self.auth_server)
479            .finish()
480    }
481}
482
483/// A wrapper for [`OAuthSession`] that provides thread-safe access to the inner tokens.
484#[derive(Clone, Debug)]
485#[cfg_attr(feature = "python", pyo3::pyclass)]
486pub struct TokenDispatcher {
487    lock: Arc<RwLock<OAuthSession>>,
488    refreshing: Arc<Mutex<bool>>,
489    notify_refreshed: Arc<Notify>,
490}
491
492impl From<OAuthSession> for TokenDispatcher {
493    fn from(value: OAuthSession) -> Self {
494        Self {
495            lock: Arc::new(RwLock::new(value)),
496            refreshing: Arc::new(Mutex::new(false)),
497            notify_refreshed: Arc::new(Notify::new()),
498        }
499    }
500}
501
502impl TokenDispatcher {
503    /// Executes a user-provided closure on a reference to the `Tokens` instance managed by the
504    /// dispatcher.
505    ///
506    /// This function locks the mutex, safely exposing the protected `Tokens` instance to the provided closure `f`.
507    /// It is designed to allow safe and controlled access to the `Tokens` instance for reading its state.
508    ///
509    /// # Parameters
510    /// - `f`: A closure that takes a reference to `Tokens` and returns a value of type `O`. The closure is called
511    ///   with the `Tokens` instance as an argument once the mutex is successfully locked.
512    pub async fn use_tokens<F, O>(&self, f: F) -> O
513    where
514        F: FnOnce(&OAuthSession) -> O + Send,
515    {
516        let tokens = self.lock.read().await;
517        f(&tokens)
518    }
519
520    /// Get a copy of the current access token.
521    #[must_use]
522    pub async fn tokens(&self) -> OAuthSession {
523        self.use_tokens(Clone::clone).await
524    }
525
526    /// Refreshes the tokens. Readers will be blocked until the refresh is complete.
527    ///
528    /// # Errors
529    ///
530    /// See [`TokenError`]
531    pub async fn refresh(
532        &self,
533        source: &ConfigSource,
534        profile: &str,
535    ) -> Result<OAuthSession, TokenError> {
536        self.managed_refresh(Self::perform_refresh, source, profile)
537            .await
538    }
539
540    /// Validate the access token, returning it if it is valid, or an error describing why it is
541    /// invalid.
542    ///
543    /// # Errors
544    ///
545    /// - [`TokenError::NoAccessToken`] if there is no access token
546    /// - [`TokenError::InvalidAccessToken`] if the access token is invalid
547    pub async fn validate(&self) -> Result<SecretAccessToken, TokenError> {
548        self.use_tokens(OAuthSession::validate).await
549    }
550
551    /// If tokens are already being refreshed, wait and return the updated tokens. Otherwise, run
552    /// ``refresh_fn``.
553    async fn managed_refresh<F, Fut>(
554        &self,
555        refresh_fn: F,
556        source: &ConfigSource,
557        profile: &str,
558    ) -> Result<OAuthSession, TokenError>
559    where
560        F: FnOnce(Arc<RwLock<OAuthSession>>) -> Fut + Send,
561        Fut: Future<Output = Result<OAuthSession, TokenError>> + Send,
562    {
563        let mut is_refreshing = self.refreshing.lock().await;
564
565        if *is_refreshing {
566            drop(is_refreshing);
567            self.notify_refreshed.notified().await;
568            return Ok(self.tokens().await);
569        }
570
571        *is_refreshing = true;
572        drop(is_refreshing);
573
574        let oauth_session = refresh_fn(self.lock.clone()).await?;
575
576        // If the config source is a file, write the new access token to the file
577        if let ConfigSource::File {
578            settings_path: _,
579            secrets_path,
580        } = source
581        {
582            if !Secrets::is_read_only(secrets_path).await? {
583                // If the payload is a PkceFlow, write the fresh refresh token if available.
584                let refresh_token = match &oauth_session.payload {
585                    OAuthGrant::PkceFlow(payload) => {
586                        payload.refresh_token.as_ref().map(|rt| &rt.refresh_token)
587                    }
588                    _ => None,
589                };
590
591                let now = OffsetDateTime::now_utc();
592                Secrets::write_tokens(
593                    secrets_path,
594                    profile,
595                    refresh_token,
596                    oauth_session.access_token()?,
597                    now,
598                )
599                .await?;
600            }
601        }
602
603        *self.refreshing.lock().await = false;
604        self.notify_refreshed.notify_waiters();
605        Ok(oauth_session)
606    }
607
608    /// Refreshes the tokens. Readers will be blocked until the refresh is complete. Returns a copy
609    /// of the updated [`Credentials`]
610    ///
611    /// # Errors
612    ///
613    /// See [`TokenError`]
614    async fn perform_refresh(lock: Arc<RwLock<OAuthSession>>) -> Result<OAuthSession, TokenError> {
615        let mut credentials = lock.write().await;
616        credentials.request_access_token().await?;
617        Ok(credentials.clone())
618    }
619}
620
621pub(crate) type RefreshResult =
622    Pin<Box<dyn Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>> + Send>>;
623
624/// A function that asynchronously refreshes a token.
625pub type RefreshFunction = Box<dyn (Fn(AuthServer) -> RefreshResult) + Send + Sync>;
626
627/// A struct that manages access tokens by utilizing a user-provided refresh function.
628///
629/// The [`ExternallyManaged`] struct allows users to define custom logic for
630/// fetching or refreshing access tokens.
631#[derive(Clone)]
632#[cfg_attr(feature = "python", pyo3::pyclass)]
633pub struct ExternallyManaged {
634    refresh_function: Arc<RefreshFunction>,
635}
636
637impl ExternallyManaged {
638    /// Creates a new [`ExternallyManaged`] instance from a [`RefreshFunction`].
639    ///
640    /// Consider using [`ExternallyManaged::from_async`], and [`ExternallyManaged::from_sync`], if
641    /// they better fit your use case.
642    ///
643    /// # Arguments
644    ///
645    /// * `refresh_function` - A function or closure that asynchronously refreshes a token.
646    ///
647    /// # Example
648    ///
649    /// ```
650    /// use qcs_api_client_common::configuration::{settings::AuthServer, tokens::ExternallyManaged, TokenError};
651    /// use std::future::Future;
652    /// use std::pin::Pin;
653    /// use std::boxed::Box;
654    /// use std::error::Error;
655    ///
656    /// async fn example_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
657    /// + Send + Sync>> {
658    ///     Ok("new_token_value".to_string())
659    /// }
660    /// let token_manager = ExternallyManaged::new(|auth_server| Box::pin(example_refresh_function(auth_server)));
661    /// ```
662    pub fn new(
663        refresh_function: impl Fn(AuthServer) -> RefreshResult + Send + Sync + 'static,
664    ) -> Self {
665        Self {
666            refresh_function: Arc::new(Box::new(refresh_function)),
667        }
668    }
669
670    /// Constructs a new [`ExternallyManaged`] instance using an async function or closure.
671    ///
672    /// This method simplifies the creation of the [`ExternallyManaged`] instance by handling
673    /// the boxing and pinning of the future internally.
674    ///
675    /// # Arguments
676    ///
677    /// * `refresh_function` - An async function or closure that returns a [`Future`] which, when awaited,
678    ///   produces a [`Result<String, TokenError>`].
679    ///
680    /// # Example
681    ///
682    /// ```
683    /// use qcs_api_client_common::configuration::{settings::AuthServer, tokens::ExternallyManaged, TokenError};
684    /// use tokio::runtime::Runtime;
685    /// use std::error::Error;
686    ///
687    /// async fn example_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
688    /// + Send + Sync>> {
689    ///     Ok("new_token_value".to_string())
690    /// }
691    ///
692    /// let token_manager = ExternallyManaged::from_async(example_refresh_function);
693    ///
694    /// let rt = Runtime::new().unwrap();
695    /// rt.block_on(async {
696    ///     match token_manager.request_access_token(&AuthServer::default()).await {
697    ///         Ok(token) => println!("Token: {token:?}"),
698    ///         Err(e) => println!("Failed to refresh token: {:?}", e),
699    ///     }
700    /// });
701    /// ```
702    pub fn from_async<F, Fut>(refresh_function: F) -> Self
703    where
704        F: Fn(AuthServer) -> Fut + Send + Sync + 'static,
705        Fut: Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>>
706            + Send
707            + 'static,
708    {
709        Self {
710            refresh_function: Arc::new(Box::new(move |auth_server| {
711                Box::pin(refresh_function(auth_server))
712            })),
713        }
714    }
715
716    /// Constructs a new [`ExternallyManaged`] instance using a synchronous function.
717    ///
718    /// The synchronous function is wrapped in an async block to fit the expected signature.
719    ///
720    /// # Arguments
721    ///
722    /// * `refresh_function` - A synchronous function that returns a [`Result<String, TokenError>`].
723    ///
724    /// # Example
725    ///
726    /// ```
727    /// use qcs_api_client_common::configuration::{settings::AuthServer, tokens::ExternallyManaged, TokenError};
728    /// use tokio::runtime::Runtime;
729    /// use std::error::Error;
730    ///
731    /// fn example_sync_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
732    /// + Send + Sync>> {
733    ///     Ok("sync_token_value".to_string())
734    /// }
735    ///
736    /// let token_manager = ExternallyManaged::from_sync(example_sync_refresh_function);
737    ///
738    /// let rt = Runtime::new().unwrap();
739    /// rt.block_on(async {
740    ///     match token_manager.request_access_token(&AuthServer::default()).await {
741    ///         Ok(token) => println!("Token: {token:?}"),
742    ///         Err(e) => println!("Failed to refresh token: {:?}", e),
743    ///     }
744    /// });
745    /// ```
746    pub fn from_sync(
747        refresh_function: impl Fn(AuthServer) -> Result<String, Box<dyn std::error::Error + Send + Sync>>
748            + Send
749            + Sync
750            + 'static,
751    ) -> Self {
752        Self {
753            refresh_function: Arc::new(Box::new(move |auth_server| {
754                let result = refresh_function(auth_server);
755                Box::pin(async move { result })
756            })),
757        }
758    }
759
760    /// Request an updated access token using the provided refresh function.
761    ///
762    /// # Errors
763    ///
764    /// Errors are propagated from the refresh function.
765    pub async fn request_access_token(
766        &self,
767        auth_server: &AuthServer,
768    ) -> Result<SecretAccessToken, Box<dyn std::error::Error + Send + Sync>> {
769        (self.refresh_function)(auth_server.clone())
770            .await
771            .map(SecretAccessToken::from)
772    }
773}
774
775impl std::fmt::Debug for ExternallyManaged {
776    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
777        f.debug_struct("ExternallyManaged")
778            .field(
779                "refresh_function",
780                &"Fn() -> Pin<Box<dyn Future<Output = Result<String, TokenError>> + Send>>",
781            )
782            .finish()
783    }
784}
785
786#[derive(Debug, Serialize, Deserialize)]
787pub(super) struct TokenRefreshRequest<'a> {
788    grant_type: &'static str,
789    client_id: &'a str,
790    refresh_token: &'a str,
791}
792
793impl<'a> TokenRefreshRequest<'a> {
794    pub(super) const fn new(client_id: &'a str, refresh_token: &'a str) -> Self {
795        Self {
796            grant_type: "refresh_token",
797            client_id,
798            refresh_token,
799        }
800    }
801}
802
803#[derive(Debug, Serialize, Deserialize)]
804pub(super) struct ClientCredentialsRequest {
805    grant_type: &'static str,
806    scope: Option<&'static str>,
807}
808
809impl ClientCredentialsRequest {
810    pub(super) const fn new(scope: Option<&'static str>) -> Self {
811        Self {
812            grant_type: "client_credentials",
813            scope,
814        }
815    }
816}
817
818#[derive(Deserialize, Debug, Serialize)]
819pub(super) struct RefreshTokenResponse {
820    pub(super) refresh_token: Option<SecretRefreshToken>,
821    pub(super) access_token: SecretAccessToken,
822}
823
/// Get and refresh access tokens.
///
/// Implementors are cheaply shareable handles (`Clone + Debug + Send`) that expose the current
/// access token, can force a refresh, and — when the `tracing` features are enabled — report
/// how outgoing requests should be traced.
#[async_trait::async_trait]
pub trait TokenRefresher: Clone + std::fmt::Debug + Send {
    /// The type to be returned in the event of an error while getting or
    /// refreshing an access token.
    type Error;

    /// Get and validate the current access token, refreshing it if it doesn't exist or is invalid.
    async fn validated_access_token(&self) -> Result<SecretAccessToken, Self::Error>;

    /// Get the current access token, if any.
    async fn get_access_token(&self) -> Result<Option<SecretAccessToken>, Self::Error>;

    /// Get a fresh access token.
    async fn refresh_access_token(&self) -> Result<SecretAccessToken, Self::Error>;

    /// Get the base URL for requests.
    #[cfg(feature = "tracing")]
    fn base_url(&self) -> &str;

    /// Get the tracing configuration, if one is set.
    #[cfg(feature = "tracing-config")]
    fn tracing_configuration(&self) -> Option<&TracingConfiguration>;

    /// Returns whether the given URL should be traced.
    ///
    /// Without the `tracing-config` feature this always returns `true`. With it, this returns
    /// `true` when [`TokenRefresher::tracing_configuration`] is `None`, and otherwise defers to
    /// [`TracingConfiguration::is_enabled`].
    #[cfg(feature = "tracing")]
    // Once the cfg'd tail expression below is compiled out, the `return` in the first branch
    // sits in tail position, which clippy would otherwise flag.
    #[allow(clippy::needless_return)]
    fn should_trace(&self, url: &UrlPatternMatchInput) -> bool {
        #[cfg(not(feature = "tracing-config"))]
        {
            // `url` is only consulted when `tracing-config` is enabled.
            let _ = url;
            return true;
        }

        #[cfg(feature = "tracing-config")]
        self.tracing_configuration()
            .is_none_or(|config| config.is_enabled(url))
    }
}
864
865#[async_trait::async_trait]
866impl TokenRefresher for ClientConfiguration {
867    type Error = TokenError;
868
869    async fn validated_access_token(&self) -> Result<SecretAccessToken, Self::Error> {
870        self.get_bearer_access_token().await
871    }
872
873    async fn refresh_access_token(&self) -> Result<SecretAccessToken, Self::Error> {
874        Ok(self.refresh().await?.access_token()?.clone())
875    }
876
877    async fn get_access_token(&self) -> Result<Option<SecretAccessToken>, Self::Error> {
878        Ok(Some(self.oauth_session().await?.access_token()?.clone()))
879    }
880
881    #[cfg(feature = "tracing")]
882    fn base_url(&self) -> &str {
883        &self.grpc_api_url
884    }
885
886    #[cfg(feature = "tracing-config")]
887    fn tracing_configuration(&self) -> Option<&TracingConfiguration> {
888        self.tracing_configuration.as_ref()
889    }
890}
891
892/// Get a default http client.
893pub(super) fn default_http_client() -> Result<reqwest::Client, reqwest::Error> {
894    reqwest::Client::builder()
895        .timeout(std::time::Duration::from_secs(10))
896        .build()
897}
898
#[cfg(test)]
mod test {
    use std::time::Duration;

    use super::*;
    use httpmock::prelude::*;
    use rstest::rstest;
    use time::format_description::well_known::Rfc3339;
    use tokio::time::Instant;
    use toml_edit::DocumentMut;

    /// A read of the dispatcher's tokens issued while a refresh is in flight must wait for that
    /// refresh to finish and then observe the refreshed tokens.
    #[tokio::test]
    async fn test_tokens_blocked_during_refresh() {
        // The mocked token endpoint delays its response by this long; both the refresh and the
        // concurrent read are expected to take at least this long.
        let refresh_duration = Duration::from_secs(3);

        let mock_server = MockServer::start_async().await;

        let oidc_mock = mock_server
            .mock_async(|when, then| {
                when.method(GET).path("/.well-known/openid-configuration");
                then.status(200)
                    .json_body_obj(&oidc::Discovery::new_for_test(
                        mock_server.base_url().parse().unwrap(),
                    ));
            })
            .await;

        let issuer_mock = mock_server
            .mock_async(|when, then| {
                when.method(POST).path("/v1/token");

                then.status(200)
                    .delay(refresh_duration)
                    .json_body_obj(&RefreshTokenResponse {
                        access_token: SecretAccessToken::from("new_access"),
                        refresh_token: Some(SecretRefreshToken::from("new_refresh")),
                    });
            })
            .await;

        let original_tokens = OAuthSession::from_refresh_token(
            RefreshToken::new(SecretRefreshToken::from("refresh")),
            AuthServer {
                client_id: "client_id".to_string(),
                issuer: mock_server.base_url(),
                scopes: None,
            },
            None,
        );
        let dispatcher: TokenDispatcher = original_tokens.into();
        let writer = dispatcher.clone();
        let reader = dispatcher;

        // Spawn the refresh first so it is already in flight when the read is spawned.
        let start_write = Instant::now();
        let write_future = tokio::spawn(async move {
            writer
                .refresh(&ConfigSource::Default, "")
                .await
                .unwrap()
        });

        let start_read = Instant::now();
        let read_future = tokio::spawn(async move { reader.tokens().await });

        let _ = write_future.await.unwrap();
        let read_result = read_future.await.unwrap();

        let write_duration = start_write.elapsed();
        let read_duration = start_read.elapsed();

        oidc_mock.assert_async().await;
        issuer_mock.assert_async().await;

        assert!(
            write_duration >= refresh_duration,
            "Write operation did not take enough time"
        );
        assert!(
            read_duration >= refresh_duration,
            "Read operation was not blocked by the write operation"
        );
        assert_eq!(
            read_result.access_token.unwrap(),
            SecretAccessToken::from("new_access")
        );
        if let OAuthGrant::RefreshToken(payload) = read_result.payload {
            assert_eq!(
                payload.refresh_token,
                SecretRefreshToken::from("new_refresh")
            );
        } else {
            panic!(
                "Expected RefreshToken payload, got {:?}",
                read_result.payload
            );
        }
    }

    /// After a refresh, the secrets file is rewritten only when neither `QCS_SECRETS_READ_ONLY`
    /// (with a truthy value) nor the file's permissions mark it read-only.
    #[rstest]
    fn test_qcs_secrets_readonly(
        #[values(
            (Some("TRUE"), true),
            (Some("tRue"), true),
            (Some("true"), true),
            (Some("YES"), true),
            (Some("yEs"), true),
            (Some("yes"), true),
            (Some("1"), true),
            (Some("2"), false),
            (Some("other"), false),
            (Some(""), false),
            (None, false),
        )]
        read_only_values: (Option<&str>, bool),
        #[values(true, false)] read_only_perm: bool,
    ) {
        let (maybe_read_only_env, env_is_read_only) = read_only_values;
        // Either source of read-only-ness must suppress the write-back.
        let expected_update = !env_is_read_only && !read_only_perm;
        figment::Jail::expect_with(|jail| {
            let profile_name = "test";
            let initial_access_token = "initial_access_token";
            let initial_refresh_token = "initial_refresh_token";

            let initial_secrets_file_contents = format!(
                r#"
[credentials]
[credentials.{profile_name}]
[credentials.{profile_name}.token_payload]
access_token = "{initial_access_token}"
expires_in = 3600
id_token = "id_token"
refresh_token = "{initial_refresh_token}"
scope = "offline_access openid profile email"
token_type = "Bearer"
updated_at = "2024-01-01T00:00:00Z"
"#
            );

            // Ignore any existing environment variables.
            jail.clear_env();

            // Create a temporary secrets file.
            let secrets_path = "secrets.toml";
            jail.create_file(secrets_path, initial_secrets_file_contents.as_str())
                .expect("should create test secrets.toml");

            if read_only_perm {
                let mut permissions = std::fs::metadata(secrets_path)
                    .expect("Should be able to get file metadata")
                    .permissions();
                permissions.set_readonly(true);
                std::fs::set_permissions(secrets_path, permissions)
                    .expect("Should be able to set file permissions");
            }

            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async {
                let mock_server = MockServer::start_async().await;

                let oidc_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(GET).path("/.well-known/openid-configuration");
                        then.status(200).json_body_obj(&oidc::Discovery::new_for_test(
                            mock_server.base_url().parse().unwrap(),
                        ));
                    })
                    .await;

                // Set up the mock token endpoint.
                let new_access_token = SecretAccessToken::from("new_access_token");
                let issuer_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(POST).path("/v1/token");
                        then.status(200).json_body_obj(&RefreshTokenResponse {
                            access_token: new_access_token.clone(),
                            refresh_token: Some(SecretRefreshToken::from(initial_refresh_token)),
                        });
                    })
                    .await;

                // Create tokens and dispatcher, seeded with the same tokens as the secrets file.
                let original_tokens = OAuthSession::from_refresh_token(
                    RefreshToken::new(SecretRefreshToken::from(initial_refresh_token)),
                    AuthServer {
                        client_id: "client_id".to_string(),
                        issuer: mock_server.base_url(),
                        scopes: None,
                    },
                    Some(SecretAccessToken::from(initial_access_token)),
                );
                let dispatcher: TokenDispatcher = original_tokens.into();

                // Point the configuration at the jailed secrets file and, when this case
                // provides one, set the QCS_SECRETS_READ_ONLY value under test.
                jail.set_env("QCS_SECRETS_FILE_PATH", secrets_path);
                jail.set_env("QCS_PROFILE_NAME", profile_name);
                if let Some(read_only_env) = maybe_read_only_env {
                    jail.set_env("QCS_SECRETS_READ_ONLY", read_only_env);
                }

                let before_refresh = OffsetDateTime::now_utc();

                dispatcher
                    .refresh(
                        &ConfigSource::File {
                            settings_path: "".into(),
                            secrets_path: secrets_path.into(),
                        },
                        profile_name,
                    )
                    .await
                    .unwrap();

                oidc_mock.assert_async().await;
                issuer_mock.assert_async().await;

                let content = std::fs::read_to_string(secrets_path).unwrap();

                // The file must be left byte-for-byte untouched when the env var is truthy or
                // the file is read-only.
                if !expected_update {
                    assert_eq!(
                        content,
                        initial_secrets_file_contents,
                        "File should not be updated when QCS_SECRETS_READ_ONLY is set or file permissions are read-only"
                    );
                    return;
                }

                // Otherwise the stored token payload must reflect the refresh.
                let mut toml = content.parse::<DocumentMut>().unwrap();

                let token_payload = toml
                    .get_mut("credentials")
                    .and_then(|credentials| {
                        credentials.get_mut(profile_name)?.get_mut("token_payload")
                    })
                    .expect("Should be able to get token_payload table");

                let access_token = token_payload
                    .get("access_token")
                    .unwrap()
                    .as_str()
                    .map(str::to_string)
                    .map(SecretAccessToken::from);

                assert_eq!(access_token, Some(new_access_token));

                assert!(
                    OffsetDateTime::parse(
                        token_payload.get("updated_at").unwrap().as_str().unwrap(),
                        &Rfc3339
                    )
                    .unwrap()
                        > before_refresh
                );

                assert!(
                    content.contains("new_access_token"),
                    "File should be updated with new access token when QCS_SECRETS_READ_ONLY is not set or is set but disabled, and file permissions allow writing"
                );
            });
            Ok(())
        });
    }

    /// `Debug` output of an [`OAuthSession`] must not leak credentials or tokens.
    #[test]
    fn test_auth_session_debug_fmt() {
        let session = OAuthSession {
            payload: OAuthGrant::ClientCredentials(ClientCredentials::new(
                "hidden_id",
                "hidden_secret",
            )),
            access_token: Some(SecretAccessToken::from("token")),
            auth_server: AuthServer {
                client_id: "some_id".into(),
                issuer: "some_url".into(),
                scopes: None,
            },
        };

        assert_eq!("OAuthSession { payload: ClientCredentials, access_token: Some(()), auth_server: AuthServer { client_id: \"some_id\", issuer: \"some_url\", scopes: None } }", &format!("{session:?}"));
    }
}