qcs_api_client_common/configuration/tokens.rs

1//! Models and utilities for managing `OAuth2` sessions.
2use std::{pin::Pin, sync::Arc};
3
4use futures::Future;
5use jsonwebtoken::{Algorithm, DecodingKey, Validation};
6use oauth2::TokenResponse;
7use serde::{Deserialize, Serialize};
8use time::OffsetDateTime;
9use tokio::sync::{Mutex, Notify, RwLock};
10use tokio_util::sync::CancellationToken;
11
12use super::{
13    oidc, secrets::Secrets, settings::AuthServer, ClientConfiguration, ConfigSource, TokenError,
14};
15use crate::configuration::{
16    error::DiscoveryError,
17    pkce::{pkce_login, PkceLoginError, PkceLoginRequest},
18    secrets::{Credential, SecretAccessToken, SecretRefreshToken, TokenPayload},
19};
20#[cfg(feature = "tracing-config")]
21use crate::tracing_configuration::TracingConfiguration;
22#[cfg(feature = "tracing")]
23use urlpattern::UrlPatternMatchInput;
24
25pub use super::secret_string::ClientSecret;
26
27/// A single type containing an access token and an associated refresh token.
28#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
29#[cfg_attr(feature = "python", pyo3::pyclass)]
30pub struct RefreshToken {
31    /// The token used to refresh the access token.
32    pub refresh_token: SecretRefreshToken,
33}
34
35impl RefreshToken {
36    /// Create a new [`RefreshToken`] with the given refresh token.
37    #[must_use]
38    pub const fn new(refresh_token: SecretRefreshToken) -> Self {
39        Self { refresh_token }
40    }
41
42    /// Request and return a new access token from the given authorization server using this refresh token.
43    ///
44    /// # Errors
45    ///
46    /// See [`TokenError`]
47    pub async fn request_access_token(
48        &mut self,
49        auth_server: &AuthServer,
50    ) -> Result<SecretAccessToken, TokenError> {
51        if self.refresh_token.is_empty() {
52            return Err(TokenError::NoRefreshToken);
53        }
54
55        let client = default_http_client()?;
56        let token_url = oidc::fetch_discovery(&client, &auth_server.issuer)
57            .await?
58            .token_endpoint;
59        let data = TokenRefreshRequest::new(&auth_server.client_id, self.refresh_token.secret());
60        let resp = client.post(token_url).form(&data).send().await?;
61
62        let RefreshTokenResponse {
63            access_token,
64            refresh_token,
65        } = resp.error_for_status()?.json().await?;
66
67        self.refresh_token = refresh_token;
68        Ok(access_token)
69    }
70}
71
72#[derive(Deserialize, Debug, Serialize)]
73pub(super) struct ClientCredentialsResponse {
74    pub(super) access_token: SecretAccessToken,
75}
76
77#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
78#[cfg_attr(feature = "python", pyo3::pyclass)]
79/// A pair of Client ID and Client Secret, used to request an OAuth Client Credentials Grant
80pub struct ClientCredentials {
81    /// The client ID
82    pub client_id: String,
83    /// The client secret.
84    pub client_secret: ClientSecret,
85}
86
87impl ClientCredentials {
88    #[must_use]
89    /// Construct a new [`ClientCredentials`]
90    pub fn new(client_id: impl Into<String>, client_secret: impl Into<ClientSecret>) -> Self {
91        Self {
92            client_id: client_id.into(),
93            client_secret: client_secret.into(),
94        }
95    }
96
97    /// Get the client ID.
98    #[must_use]
99    pub fn client_id(&self) -> &str {
100        &self.client_id
101    }
102
103    /// Get the client secret.
104    #[must_use]
105    pub const fn client_secret(&self) -> &ClientSecret {
106        &self.client_secret
107    }
108
109    /// Request and return an access token from the given auth server using this set of client credentials.
110    ///
111    /// # Errors
112    ///
113    /// See [`TokenError`]
114    pub async fn request_access_token(
115        &self,
116        auth_server: &AuthServer,
117    ) -> Result<SecretAccessToken, TokenError> {
118        let request = ClientCredentialsRequest::new(None);
119        let client = default_http_client()?;
120
121        let url = oidc::fetch_discovery(&client, &auth_server.issuer)
122            .await?
123            .token_endpoint;
124        let ready_to_send = client
125            .post(url)
126            .basic_auth(&auth_server.client_id, Some(&self.client_secret.secret()))
127            .form(&request);
128        let response = ready_to_send.send().await?;
129
130        response.error_for_status_ref()?;
131
132        let ClientCredentialsResponse { access_token } = response.json().await?;
133        Ok(access_token)
134    }
135}
136
137#[derive(Clone, PartialEq, Eq, Deserialize)]
138#[expect(missing_debug_implementations, reason = "contains secret data")]
139#[cfg_attr(feature = "python", pyo3::pyclass)]
140/// The Access (Bearer) and refresh (if available) tokens from a PKCE login.
141pub struct PkceFlow {
142    /// The access token.
143    pub access_token: SecretAccessToken,
144    /// The refresh token, if available.
145    pub refresh_token: Option<RefreshToken>,
146}
147
148/// Errors that can occur when attempting to perform a PKCE login flow.
149#[derive(Debug, thiserror::Error)]
150pub enum PkceFlowError {
151    /// Error that occurred while performing the PKCE login flow.
152    #[error(transparent)]
153    PkceLogin(#[from] PkceLoginError),
154    /// Error that occurred while fetching the discovery document from the `OAuth2` issuer.
155    #[error(transparent)]
156    Discovery(#[from] DiscoveryError),
157    /// Error that occurred while making http requests.
158    #[error(transparent)]
159    Request(#[from] reqwest::Error),
160}
161
162impl PkceFlow {
163    /// Starts a new PKCE login flow to acquire a new set of tokens.
164    ///
165    /// # Errors
166    ///
167    /// See [`PkceFlowError`]
168    pub async fn new_login_flow(
169        cancel_token: CancellationToken,
170        auth_server: &AuthServer,
171    ) -> Result<Self, PkceFlowError> {
172        let issuer = auth_server.issuer.clone();
173
174        let client = default_http_client()?;
175        let discovery = oidc::fetch_discovery(&client, &issuer).await?;
176
177        let response = pkce_login(
178            cancel_token,
179            PkceLoginRequest {
180                client_id: auth_server.client_id.clone(),
181                redirect_port: None,
182                discovery,
183            },
184        )
185        .await?;
186
187        Ok(Self {
188            access_token: SecretAccessToken::from(response.access_token().secret().clone()),
189            refresh_token: response
190                .refresh_token()
191                .map(|rt| RefreshToken::new(SecretRefreshToken::from(rt.secret().clone()))),
192        })
193    }
194
195    /// Returns the access token if it is valid, otherwise requests a new access token using the refresh token if available.
196    ///
197    /// # Errors
198    ///
199    /// See [`TokenError`]
200    pub async fn request_access_token(
201        &mut self,
202        auth_server: &AuthServer,
203    ) -> Result<SecretAccessToken, TokenError> {
204        if insecure_validate_token_exp(&self.access_token).is_ok() {
205            return Ok(self.access_token.clone());
206        }
207
208        if let Some(refresh_token) = &mut self.refresh_token {
209            let access_token = refresh_token.request_access_token(auth_server).await?;
210            self.access_token.clone_from(&access_token);
211            return Ok(access_token);
212        }
213
214        Err(TokenError::NoRefreshToken)
215    }
216}
217
218impl From<PkceFlow> for Credential {
219    fn from(value: PkceFlow) -> Self {
220        let mut token_payload = TokenPayload::default();
221        token_payload.access_token = Some(value.access_token);
222        token_payload.refresh_token = value.refresh_token.map(|rt| rt.refresh_token);
223
224        Self {
225            token_payload: Some(token_payload),
226        }
227    }
228}
229
230#[derive(Clone)]
231#[cfg_attr(feature = "python", derive(pyo3::FromPyObject))]
232/// Specifies the [OAuth2 grant type](https://oauth.net/2/grant-types/) to use, along with the data
233/// needed to request said grant type.
234pub enum OAuthGrant {
235    /// Credentials that can be used to use with the [Refresh Token grant type](https://oauth.net/2/grant-types/refresh-token/).
236    RefreshToken(RefreshToken),
237    /// Payload that can be used to use the [Client Credentials grant type](https://oauth.net/2/grant-types/client-credentials/).
238    ClientCredentials(ClientCredentials),
239    /// Defers to a user provided function for access token requests.
240    ExternallyManaged(ExternallyManaged),
241    /// The tokens returned by the PKCE login that are an [Authorization Code grant type](https://oauth.net/2/pkce/).
242    PkceFlow(PkceFlow),
243}
244
245impl From<ExternallyManaged> for OAuthGrant {
246    fn from(v: ExternallyManaged) -> Self {
247        Self::ExternallyManaged(v)
248    }
249}
250
251impl From<ClientCredentials> for OAuthGrant {
252    fn from(v: ClientCredentials) -> Self {
253        Self::ClientCredentials(v)
254    }
255}
256
257impl From<RefreshToken> for OAuthGrant {
258    fn from(v: RefreshToken) -> Self {
259        Self::RefreshToken(v)
260    }
261}
262
263impl From<PkceFlow> for OAuthGrant {
264    fn from(v: PkceFlow) -> Self {
265        Self::PkceFlow(v)
266    }
267}
268
269impl OAuthGrant {
270    /// Request a new access token from the given issuer using this grant type and payload.
271    async fn request_access_token(
272        &mut self,
273        auth_server: &AuthServer,
274    ) -> Result<SecretAccessToken, TokenError> {
275        match self {
276            Self::RefreshToken(tokens) => tokens.request_access_token(auth_server).await,
277            Self::ClientCredentials(tokens) => tokens.request_access_token(auth_server).await,
278            Self::ExternallyManaged(tokens) => tokens
279                .request_access_token(auth_server)
280                .await
281                .map_err(|e| TokenError::ExternallyManaged(e.to_string())),
282            Self::PkceFlow(tokens) => tokens.request_access_token(auth_server).await,
283        }
284    }
285}
286
287impl std::fmt::Debug for OAuthGrant {
288    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
289        match self {
290            Self::RefreshToken(_) => f.write_str("RefreshToken"),
291            Self::ClientCredentials(_) => f.write_str("ClientCredentials"),
292            Self::ExternallyManaged(_) => f.write_str("ExternallyManaged"),
293            Self::PkceFlow(_) => f.write_str("PkceTokens"),
294        }
295    }
296}
297
298/// Manages the `OAuth2` authorization process and token lifecycle for accessing the QCS API.
299///
300/// This struct encapsulates the necessary information to request an access token
301/// from an authorization server, including the `OAuth2` grant type and any associated
302/// credentials or payload data.
303///
304/// # Fields
305///
306/// * `payload` - The `OAuth2` grant type and associated data that will be used to request an access token.
307/// * `access_token` - The access token currently in use, if any. If no token has been provided or requested yet, this will be `None`.
308/// * `auth_server` - The authorization server responsible for issuing tokens.
309#[derive(Clone)]
310#[cfg_attr(feature = "python", pyo3::pyclass)]
311pub struct OAuthSession {
312    /// The grant type to use to request an access token.
313    payload: OAuthGrant,
314    /// The access token that is currently in use. None if no token has been requested yet.
315    access_token: Option<SecretAccessToken>,
316    /// The [`AuthServer`] that issues the tokens.
317    auth_server: AuthServer,
318}
319
320impl OAuthSession {
321    /// Initialize a new set of [`Credentials`] using a [`GrantPayload`].
322    ///
323    /// Optionally include an `access_token`, if not included, then one can be requested
324    /// with [`Self::request_access_token`].
325    #[must_use]
326    pub const fn new(
327        payload: OAuthGrant,
328        auth_server: AuthServer,
329        access_token: Option<SecretAccessToken>,
330    ) -> Self {
331        Self {
332            payload,
333            access_token,
334            auth_server,
335        }
336    }
337
338    /// Initialize a new set of [`Credentials`] using an [`ExternallyManaged`].
339    ///
340    /// Optionally include an `access_token`, if not included, then one can be requested
341    /// with [`Self::request_access_token`].
342    #[must_use]
343    pub const fn from_externally_managed(
344        tokens: ExternallyManaged,
345        auth_server: AuthServer,
346        access_token: Option<SecretAccessToken>,
347    ) -> Self {
348        Self::new(
349            OAuthGrant::ExternallyManaged(tokens),
350            auth_server,
351            access_token,
352        )
353    }
354
355    /// Initialize a new set of [`Credentials`] using a [`RefreshToken`].
356    ///
357    /// Optionally include an `access_token`, if not included, then one can be requested
358    /// with [`Self::request_access_token`].
359    #[must_use]
360    pub const fn from_refresh_token(
361        tokens: RefreshToken,
362        auth_server: AuthServer,
363        access_token: Option<SecretAccessToken>,
364    ) -> Self {
365        Self::new(OAuthGrant::RefreshToken(tokens), auth_server, access_token)
366    }
367
368    /// Initialize a new set of [`Credentials`] using [`ClientCredentials`].
369    ///
370    /// Optionally include an `access_token`, if not included, then one can be requested
371    /// with [`Self::request_access_token`].
372    #[must_use]
373    pub const fn from_client_credentials(
374        tokens: ClientCredentials,
375        auth_server: AuthServer,
376        access_token: Option<SecretAccessToken>,
377    ) -> Self {
378        Self::new(
379            OAuthGrant::ClientCredentials(tokens),
380            auth_server,
381            access_token,
382        )
383    }
384
385    /// Initialize a new set of [`Credentials`] using [`PkceFlow`].
386    ///
387    /// Optionally include an `access_token`, if not included, then one can be requested
388    /// with [`Self::request_access_token`].
389    #[must_use]
390    pub const fn from_pkce_flow(
391        flow: PkceFlow,
392        auth_server: AuthServer,
393        access_token: Option<SecretAccessToken>,
394    ) -> Self {
395        Self::new(OAuthGrant::PkceFlow(flow), auth_server, access_token)
396    }
397
398    /// Get the current access token.
399    ///
400    /// This is an unvalidated copy of the access token. Meaning it can become stale, or may
401    /// even be already be stale. See [`Self::validate`] and [`Self::request_access_token`].
402    ///
403    /// # Errors
404    ///
405    /// - [`TokenError::NoAccessToken`] if there is no access token
406    pub fn access_token(&self) -> Result<&SecretAccessToken, TokenError> {
407        self.access_token.as_ref().ok_or(TokenError::NoAccessToken)
408    }
409
410    /// Get the payload used to request an access token.
411    #[must_use]
412    pub const fn payload(&self) -> &OAuthGrant {
413        &self.payload
414    }
415
416    /// Request and return an updated access token using these credentials.
417    ///
418    /// # Errors
419    ///
420    /// See [`TokenError`]
421    #[allow(clippy::missing_panics_doc)]
422    pub async fn request_access_token(&mut self) -> Result<&SecretAccessToken, TokenError> {
423        let access_token = self.payload.request_access_token(&self.auth_server).await?;
424        Ok(self.access_token.insert(access_token))
425    }
426
427    /// The [`AuthServer`] that issues the tokens.
428    #[must_use]
429    pub const fn auth_server(&self) -> &AuthServer {
430        &self.auth_server
431    }
432
433    /// Validate the access token, returning it if it is valid, or an error describing why it is
434    /// invalid.
435    ///
436    /// # Errors
437    ///
438    /// - [`TokenError::NoAccessToken`] if an access token has not been requested.
439    /// - [`TokenError::InvalidAccessToken`] if the access token is invalid.
440    pub fn validate(&self) -> Result<SecretAccessToken, TokenError> {
441        let access_token = self.access_token()?;
442        insecure_validate_token_exp(access_token)?;
443        Ok(access_token.clone())
444    }
445}
446
447/// Validates the access token's format and `exp` claim, but no other claims or
448/// signature. We do this only to determine if the token is expired and needs refreshing,
449/// there is no way to securely validate the token's signature on the client side.
450pub(crate) fn insecure_validate_token_exp(
451    access_token: &SecretAccessToken,
452) -> Result<(), TokenError> {
453    let placeholder_key = DecodingKey::from_secret(&[]);
454    let mut validation = Validation::new(Algorithm::RS256);
455    validation.validate_exp = true;
456    validation.leeway = 60;
457    validation.validate_aud = false;
458    validation.insecure_disable_signature_validation();
459
460    jsonwebtoken::decode::<toml::Value>(access_token.secret(), &placeholder_key, &validation)
461        .map(|_| ())
462        .map_err(TokenError::InvalidAccessToken)
463}
464
465impl std::fmt::Debug for OAuthSession {
466    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
467        let token_populated = if self.access_token.is_some() {
468            Some(())
469        } else {
470            None
471        };
472        f.debug_struct("OAuthSession")
473            .field("payload", &self.payload)
474            .field("access_token", &token_populated)
475            .field("auth_server", &self.auth_server)
476            .finish()
477    }
478}
479
480/// A wrapper for [`OAuthSession`] that provides thread-safe access to the inner tokens.
481#[derive(Clone, Debug)]
482#[cfg_attr(feature = "python", pyo3::pyclass)]
483pub struct TokenDispatcher {
484    lock: Arc<RwLock<OAuthSession>>,
485    refreshing: Arc<Mutex<bool>>,
486    notify_refreshed: Arc<Notify>,
487}
488
489impl From<OAuthSession> for TokenDispatcher {
490    fn from(value: OAuthSession) -> Self {
491        Self {
492            lock: Arc::new(RwLock::new(value)),
493            refreshing: Arc::new(Mutex::new(false)),
494            notify_refreshed: Arc::new(Notify::new()),
495        }
496    }
497}
498
499impl TokenDispatcher {
500    /// Executes a user-provided closure on a reference to the `Tokens` instance managed by the
501    /// dispatcher.
502    ///
503    /// This function locks the mutex, safely exposing the protected `Tokens` instance to the provided closure `f`.
504    /// It is designed to allow safe and controlled access to the `Tokens` instance for reading its state.
505    ///
506    /// # Parameters
507    /// - `f`: A closure that takes a reference to `Tokens` and returns a value of type `O`. The closure is called
508    ///   with the `Tokens` instance as an argument once the mutex is successfully locked.
509    pub async fn use_tokens<F, O>(&self, f: F) -> O
510    where
511        F: FnOnce(&OAuthSession) -> O + Send,
512    {
513        let tokens = self.lock.read().await;
514        f(&tokens)
515    }
516
517    /// Get a copy of the current access token.
518    #[must_use]
519    pub async fn tokens(&self) -> OAuthSession {
520        self.use_tokens(Clone::clone).await
521    }
522
523    /// Refreshes the tokens. Readers will be blocked until the refresh is complete.
524    ///
525    /// # Errors
526    ///
527    /// See [`TokenError`]
528    pub async fn refresh(
529        &self,
530        source: &ConfigSource,
531        profile: &str,
532    ) -> Result<OAuthSession, TokenError> {
533        self.managed_refresh(Self::perform_refresh, source, profile)
534            .await
535    }
536
537    /// Validate the access token, returning it if it is valid, or an error describing why it is
538    /// invalid.
539    ///
540    /// # Errors
541    ///
542    /// - [`TokenError::NoAccessToken`] if there is no access token
543    /// - [`TokenError::InvalidAccessToken`] if the access token is invalid
544    pub async fn validate(&self) -> Result<SecretAccessToken, TokenError> {
545        self.use_tokens(OAuthSession::validate).await
546    }
547
548    /// If tokens are already being refreshed, wait and return the updated tokens. Otherwise, run
549    /// ``refresh_fn``.
550    async fn managed_refresh<F, Fut>(
551        &self,
552        refresh_fn: F,
553        source: &ConfigSource,
554        profile: &str,
555    ) -> Result<OAuthSession, TokenError>
556    where
557        F: FnOnce(Arc<RwLock<OAuthSession>>) -> Fut + Send,
558        Fut: Future<Output = Result<OAuthSession, TokenError>> + Send,
559    {
560        let mut is_refreshing = self.refreshing.lock().await;
561
562        if *is_refreshing {
563            drop(is_refreshing);
564            self.notify_refreshed.notified().await;
565            return Ok(self.tokens().await);
566        }
567
568        *is_refreshing = true;
569        drop(is_refreshing);
570
571        let oauth_session = refresh_fn(self.lock.clone()).await?;
572
573        // If the config source is a file, write the new access token to the file
574        if let ConfigSource::File {
575            settings_path: _,
576            secrets_path,
577        } = source
578        {
579            if !Secrets::is_read_only(secrets_path).await? {
580                // If the payload is a PkceFlow, write the fresh refresh token if available.
581                let refresh_token = match &oauth_session.payload {
582                    OAuthGrant::PkceFlow(payload) => {
583                        payload.refresh_token.as_ref().map(|rt| &rt.refresh_token)
584                    }
585                    _ => None,
586                };
587
588                let now = OffsetDateTime::now_utc();
589                Secrets::write_tokens(
590                    secrets_path,
591                    profile,
592                    refresh_token,
593                    oauth_session.access_token()?,
594                    now,
595                )
596                .await?;
597            }
598        }
599
600        *self.refreshing.lock().await = false;
601        self.notify_refreshed.notify_waiters();
602        Ok(oauth_session)
603    }
604
605    /// Refreshes the tokens. Readers will be blocked until the refresh is complete. Returns a copy
606    /// of the updated [`Credentials`]
607    ///
608    /// # Errors
609    ///
610    /// See [`TokenError`]
611    async fn perform_refresh(lock: Arc<RwLock<OAuthSession>>) -> Result<OAuthSession, TokenError> {
612        let mut credentials = lock.write().await;
613        credentials.request_access_token().await?;
614        Ok(credentials.clone())
615    }
616}
617
618pub(crate) type RefreshResult =
619    Pin<Box<dyn Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>> + Send>>;
620
621/// A function that asynchronously refreshes a token.
622pub type RefreshFunction = Box<dyn (Fn(AuthServer) -> RefreshResult) + Send + Sync>;
623
624/// A struct that manages access tokens by utilizing a user-provided refresh function.
625///
626/// The [`ExternallyManaged`] struct allows users to define custom logic for
627/// fetching or refreshing access tokens.
628#[derive(Clone)]
629#[cfg_attr(feature = "python", pyo3::pyclass)]
630pub struct ExternallyManaged {
631    refresh_function: Arc<RefreshFunction>,
632}
633
634impl ExternallyManaged {
635    /// Creates a new [`ExternallyManaged`] instance from a [`RefreshFunction`].
636    ///
637    /// Consider using [`ExternallyManaged::from_async`], and [`ExternallyManaged::from_sync`], if
638    /// they better fit your use case.
639    ///
640    /// # Arguments
641    ///
642    /// * `refresh_function` - A function or closure that asynchronously refreshes a token.
643    ///
644    /// # Example
645    ///
646    /// ```
647    /// use qcs_api_client_common::configuration::{settings::AuthServer, tokens::ExternallyManaged, TokenError};
648    /// use std::future::Future;
649    /// use std::pin::Pin;
650    /// use std::boxed::Box;
651    /// use std::error::Error;
652    ///
653    /// async fn example_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
654    /// + Send + Sync>> {
655    ///     Ok("new_token_value".to_string())
656    /// }
657    /// let token_manager = ExternallyManaged::new(|auth_server| Box::pin(example_refresh_function(auth_server)));
658    /// ```
659    pub fn new(
660        refresh_function: impl Fn(AuthServer) -> RefreshResult + Send + Sync + 'static,
661    ) -> Self {
662        Self {
663            refresh_function: Arc::new(Box::new(refresh_function)),
664        }
665    }
666
667    /// Constructs a new [`ExternallyManaged`] instance using an async function or closure.
668    ///
669    /// This method simplifies the creation of the [`ExternallyManaged`] instance by handling
670    /// the boxing and pinning of the future internally.
671    ///
672    /// # Arguments
673    ///
674    /// * `refresh_function` - An async function or closure that returns a [`Future`] which, when awaited,
675    ///   produces a [`Result<String, TokenError>`].
676    ///
677    /// # Example
678    ///
679    /// ```
680    /// use qcs_api_client_common::configuration::{settings::AuthServer, tokens::ExternallyManaged, TokenError};
681    /// use tokio::runtime::Runtime;
682    /// use std::error::Error;
683    ///
684    /// async fn example_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
685    /// + Send + Sync>> {
686    ///     Ok("new_token_value".to_string())
687    /// }
688    ///
689    /// let token_manager = ExternallyManaged::from_async(example_refresh_function);
690    ///
691    /// let rt = Runtime::new().unwrap();
692    /// rt.block_on(async {
693    ///     match token_manager.request_access_token(&AuthServer::default()).await {
694    ///         Ok(token) => println!("Token: {token:?}"),
695    ///         Err(e) => println!("Failed to refresh token: {:?}", e),
696    ///     }
697    /// });
698    /// ```
699    pub fn from_async<F, Fut>(refresh_function: F) -> Self
700    where
701        F: Fn(AuthServer) -> Fut + Send + Sync + 'static,
702        Fut: Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>>
703            + Send
704            + 'static,
705    {
706        Self {
707            refresh_function: Arc::new(Box::new(move |auth_server| {
708                Box::pin(refresh_function(auth_server))
709            })),
710        }
711    }
712
713    /// Constructs a new [`ExternallyManaged`] instance using a synchronous function.
714    ///
715    /// The synchronous function is wrapped in an async block to fit the expected signature.
716    ///
717    /// # Arguments
718    ///
719    /// * `refresh_function` - A synchronous function that returns a [`Result<String, TokenError>`].
720    ///
721    /// # Example
722    ///
723    /// ```
724    /// use qcs_api_client_common::configuration::{settings::AuthServer, tokens::ExternallyManaged, TokenError};
725    /// use tokio::runtime::Runtime;
726    /// use std::error::Error;
727    ///
728    /// fn example_sync_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
729    /// + Send + Sync>> {
730    ///     Ok("sync_token_value".to_string())
731    /// }
732    ///
733    /// let token_manager = ExternallyManaged::from_sync(example_sync_refresh_function);
734    ///
735    /// let rt = Runtime::new().unwrap();
736    /// rt.block_on(async {
737    ///     match token_manager.request_access_token(&AuthServer::default()).await {
738    ///         Ok(token) => println!("Token: {token:?}"),
739    ///         Err(e) => println!("Failed to refresh token: {:?}", e),
740    ///     }
741    /// });
742    /// ```
743    pub fn from_sync(
744        refresh_function: impl Fn(AuthServer) -> Result<String, Box<dyn std::error::Error + Send + Sync>>
745            + Send
746            + Sync
747            + 'static,
748    ) -> Self {
749        Self {
750            refresh_function: Arc::new(Box::new(move |auth_server| {
751                let result = refresh_function(auth_server);
752                Box::pin(async move { result })
753            })),
754        }
755    }
756
757    /// Request an updated access token using the provided refresh function.
758    ///
759    /// # Errors
760    ///
761    /// Errors are propagated from the refresh function.
762    pub async fn request_access_token(
763        &self,
764        auth_server: &AuthServer,
765    ) -> Result<SecretAccessToken, Box<dyn std::error::Error + Send + Sync>> {
766        (self.refresh_function)(auth_server.clone())
767            .await
768            .map(SecretAccessToken::from)
769    }
770}
771
772impl std::fmt::Debug for ExternallyManaged {
773    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
774        f.debug_struct("ExternallyManaged")
775            .field(
776                "refresh_function",
777                &"Fn() -> Pin<Box<dyn Future<Output = Result<String, TokenError>> + Send>>",
778            )
779            .finish()
780    }
781}
782
783#[derive(Debug, Serialize, Deserialize)]
784pub(super) struct TokenRefreshRequest<'a> {
785    grant_type: &'static str,
786    client_id: &'a str,
787    refresh_token: &'a str,
788}
789
790impl<'a> TokenRefreshRequest<'a> {
791    pub(super) const fn new(client_id: &'a str, refresh_token: &'a str) -> Self {
792        Self {
793            grant_type: "refresh_token",
794            client_id,
795            refresh_token,
796        }
797    }
798}
799
800#[derive(Debug, Serialize, Deserialize)]
801pub(super) struct ClientCredentialsRequest {
802    grant_type: &'static str,
803    scope: Option<&'static str>,
804}
805
806impl ClientCredentialsRequest {
807    pub(super) const fn new(scope: Option<&'static str>) -> Self {
808        Self {
809            grant_type: "client_credentials",
810            scope,
811        }
812    }
813}
814
815#[derive(Deserialize, Debug, Serialize)]
816pub(super) struct RefreshTokenResponse {
817    pub(super) refresh_token: SecretRefreshToken,
818    pub(super) access_token: SecretAccessToken,
819}
820
821/// Get and refresh access tokens
822#[async_trait::async_trait]
823pub trait TokenRefresher: Clone + std::fmt::Debug + Send {
824    /// The type to be returned in the event of a error during getting or
825    /// refreshing an access token
826    type Error;
827
828    /// Get and validate the current access token, refreshing it if it doesn't exist or is invalid.
829    async fn validated_access_token(&self) -> Result<SecretAccessToken, Self::Error>;
830
831    /// Get the current access token, if any
832    async fn get_access_token(&self) -> Result<Option<SecretAccessToken>, Self::Error>;
833
834    /// Get a fresh access token
835    async fn refresh_access_token(&self) -> Result<SecretAccessToken, Self::Error>;
836
837    /// Get the base URL for requests
838    #[cfg(feature = "tracing")]
839    fn base_url(&self) -> &str;
840
841    /// Get the tracing configuration
842    #[cfg(feature = "tracing-config")]
843    fn tracing_configuration(&self) -> Option<&TracingConfiguration>;
844
845    /// Returns whether the given URL should be traced. Following
846    /// [`TracingConfiguration::is_enabled`], this defaults to `true`.
847    #[cfg(feature = "tracing")]
848    #[allow(clippy::needless_return)]
849    fn should_trace(&self, url: &UrlPatternMatchInput) -> bool {
850        #[cfg(not(feature = "tracing-config"))]
851        {
852            let _ = url;
853            return true;
854        }
855
856        #[cfg(feature = "tracing-config")]
857        self.tracing_configuration()
858            .is_none_or(|config| config.is_enabled(url))
859    }
860}
861
862#[async_trait::async_trait]
863impl TokenRefresher for ClientConfiguration {
864    type Error = TokenError;
865
866    async fn validated_access_token(&self) -> Result<SecretAccessToken, Self::Error> {
867        self.get_bearer_access_token().await
868    }
869
870    async fn refresh_access_token(&self) -> Result<SecretAccessToken, Self::Error> {
871        Ok(self.refresh().await?.access_token()?.clone())
872    }
873
874    async fn get_access_token(&self) -> Result<Option<SecretAccessToken>, Self::Error> {
875        Ok(Some(self.oauth_session().await?.access_token()?.clone()))
876    }
877
878    #[cfg(feature = "tracing")]
879    fn base_url(&self) -> &str {
880        &self.grpc_api_url
881    }
882
883    #[cfg(feature = "tracing-config")]
884    fn tracing_configuration(&self) -> Option<&TracingConfiguration> {
885        self.tracing_configuration.as_ref()
886    }
887}
888
889/// Get a default http client.
890fn default_http_client() -> Result<reqwest::Client, reqwest::Error> {
891    reqwest::Client::builder()
892        .timeout(std::time::Duration::from_secs(10))
893        .build()
894}
895
// Unit tests covering:
// - concurrent token refresh,
// - secrets-file persistence under read-only conditions,
// - `OAuthSession`'s `Debug` output.
#[cfg(test)]
mod test {
    use std::time::Duration;

    use super::*;
    use httpmock::prelude::*;
    use rstest::rstest;
    use time::format_description::well_known::Rfc3339;
    use tokio::time::Instant;
    use toml_edit::DocumentMut;

    // Checks that a `tokens()` call issued while a `refresh()` is in flight observes
    // the refreshed tokens. Checks that both calls take at least as long as the
    // (artificially delayed) token endpoint.
    #[tokio::test]
    async fn test_tokens_blocked_during_refresh() {
        let mock_server = MockServer::start_async().await;

        // OIDC discovery document pointing back at the mock server.
        let oidc_mock = mock_server
            .mock_async(|when, then| {
                when.method(GET).path("/.well-known/openid-configuration");
                then.status(200)
                    .json_body_obj(&oidc::Discovery::new_for_test(
                        mock_server.base_url().parse().unwrap(),
                    ));
            })
            .await;

        // Token endpoint responds after a 3s delay so the refresh takes measurable time.
        let issuer_mock = mock_server
            .mock_async(|when, then| {
                when.method(POST).path("/v1/token");

                then.status(200)
                    .delay(Duration::from_secs(3))
                    .json_body_obj(&RefreshTokenResponse {
                        access_token: SecretAccessToken::from("new_access"),
                        refresh_token: SecretRefreshToken::from("new_refresh"),
                    });
            })
            .await;

        let original_tokens = OAuthSession::from_refresh_token(
            RefreshToken::new(SecretRefreshToken::from("refresh")),
            AuthServer {
                client_id: "client_id".to_string(),
                issuer: mock_server.base_url(),
            },
            None,
        );
        let dispatcher: TokenDispatcher = original_tokens.clone().into();
        let dispatcher_clone1 = dispatcher.clone();
        let dispatcher_clone2 = dispatcher.clone();

        // Must match the `.delay(...)` configured on the token mock above.
        let refresh_duration = Duration::from_secs(3);

        let start_write = Instant::now();
        let write_future = tokio::spawn(async move {
            dispatcher_clone1
                .refresh(&ConfigSource::Default, "")
                .await
                .unwrap()
        });

        // NOTE(review): this presumably relies on the write task starting its refresh
        // before the read task runs; spawn order alone may not guarantee that —
        // TODO confirm against `TokenDispatcher`'s locking.
        let start_read = Instant::now();
        let read_future = tokio::spawn(async move { dispatcher_clone2.tokens().await });

        let _ = write_future.await.unwrap();
        let read_result = read_future.await.unwrap();

        // NOTE(review): both durations are taken only after BOTH tasks have been
        // awaited, so `read_duration` also includes waiting on the write task. The
        // `read_duration` assertion below therefore does not by itself prove the
        // read was blocked. The `new_access` assertion is the stronger evidence.
        let write_duration = start_write.elapsed();
        let read_duration = start_read.elapsed();

        oidc_mock.assert_async().await;
        issuer_mock.assert_async().await;

        assert!(
            write_duration >= refresh_duration,
            "Write operation did not take enough time"
        );
        assert!(
            read_duration >= refresh_duration,
            "Read operation was not blocked by the write operation"
        );
        // The reader must see the tokens produced by the refresh, not the originals.
        assert_eq!(
            read_result.access_token.unwrap(),
            SecretAccessToken::from("new_access")
        );
        if let OAuthGrant::RefreshToken(payload) = read_result.payload {
            assert_eq!(
                payload.refresh_token,
                SecretRefreshToken::from("new_refresh")
            );
        } else {
            panic!(
                "Expected RefreshToken payload, got {:?}",
                read_result.payload
            );
        }
    }

    // Checks that a refresh rewrites the secrets file only when neither
    // `QCS_SECRETS_READ_ONLY` is truthy nor the file itself is read-only.
    // Each tuple is (env var value, whether that value is expected to count as read-only).
    // Per the table: "true"/"yes" in any case and "1" are truthy; anything else
    // (including unset) is not.
    #[rstest]
    fn test_qcs_secrets_readonly(
        #[values(
            (Some("TRUE"), true),
            (Some("tRue"), true),
            (Some("true"), true),
            (Some("YES"), true),
            (Some("yEs"), true),
            (Some("yes"), true),
            (Some("1"), true),
            (Some("2"), false),
            (Some("other"), false),
            (Some(""), false),
            (None, false),
        )]
        read_only_values: (Option<&str>, bool),
        #[values(true, false)] read_only_perm: bool,
    ) {
        let (maybe_read_only_env, env_is_read_only) = read_only_values;
        // The file should change only if nothing marks it read-only.
        let expected_update = !env_is_read_only && !read_only_perm;
        // Files and env vars below are created through the jail (`create_file`, `set_env`).
        figment::Jail::expect_with(|jail| {
            let profile_name = "test";
            let initial_access_token = "initial_access_token";
            let initial_refresh_token = "initial_refresh_token";

            let initial_secrets_file_contents = format!(
                r#"
[credentials]
[credentials.{profile_name}]
[credentials.{profile_name}.token_payload]
access_token = "{initial_access_token}"
expires_in = 3600
id_token = "id_token"
refresh_token = "{initial_refresh_token}"
scope = "offline_access openid profile email"
token_type = "Bearer"
updated_at = "2024-01-01T00:00:00Z"
"#
            );

            // Ignore any existing environment variables.
            jail.clear_env();

            // Create a temporary secrets file
            let secrets_path = "secrets.toml";
            jail.create_file(secrets_path, initial_secrets_file_contents.as_str())
                .expect("should create test secrets.toml");

            // Optionally make the secrets file read-only at the filesystem level.
            if read_only_perm {
                let mut permissions = std::fs::metadata(secrets_path)
                    .expect("Should be able to get file metadata")
                    .permissions();
                permissions.set_readonly(true);
                std::fs::set_permissions(secrets_path, permissions)
                    .expect("Should be able to set file permissions");
            }

            // `Jail::expect_with` takes a synchronous closure, so drive the async parts
            // on a dedicated runtime.
            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async {
                let mock_server = MockServer::start_async().await;

                let oidc_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(GET).path("/.well-known/openid-configuration");
                        then.status(200)
                            .json_body_obj(&oidc::Discovery::new_for_test(mock_server.base_url().parse().unwrap()));
                    })
                    .await;

                // Set up the mock token endpoint.
                // It echoes the initial refresh token, so only the access token
                // (and `updated_at`) are expected to change in the file.
                let new_access_token = SecretAccessToken::from("new_access_token");
                let issuer_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(POST).path("/v1/token");
                        then.status(200).json_body_obj(&RefreshTokenResponse {
                            access_token: new_access_token.clone(),
                            refresh_token: SecretRefreshToken::from(initial_refresh_token),
                        });
                    })
                    .await;

                // Create tokens and dispatcher
                // NOTE(review): the seeded access token is built from
                // `initial_refresh_token`, not `initial_access_token`. This looks like
                // a typo; confirm whether it is intentional.
                let original_tokens = OAuthSession::from_refresh_token(
                    RefreshToken::new(SecretRefreshToken::from(initial_refresh_token)),
                    AuthServer { client_id: "client_id".to_string(), issuer: mock_server.base_url() },
                    Some(SecretAccessToken::from(initial_refresh_token)),
                );
                let dispatcher: TokenDispatcher = original_tokens.into();

                // Point configuration at the jailed secrets file and profile.
                // Optionally mark secrets read-only via the environment.
                jail.set_env("QCS_SECRETS_FILE_PATH", "secrets.toml");
                jail.set_env("QCS_PROFILE_NAME", "test");
                if let Some(read_only_env) = maybe_read_only_env {
                    jail.set_env("QCS_SECRETS_READ_ONLY", read_only_env);
                }

                // Lower bound for the `updated_at` timestamp written by the refresh.
                let before_refresh = OffsetDateTime::now_utc();

                dispatcher
                    .refresh(
                        &ConfigSource::File {
                            settings_path: "".into(),
                            secrets_path: "secrets.toml".into(),
                        },
                        profile_name,
                    )
                    .await
                    .unwrap();

                oidc_mock.assert_async().await;
                issuer_mock.assert_async().await;

                // Verify the file was not updated if QCS_SECRETS_READ_ONLY is set truthy
                // or the file is read-only. In that case the check ends here:
                // `return` exits the async block.
                let content = std::fs::read_to_string("secrets.toml").unwrap();
                if !expected_update {
                    assert!(
                        content.eq(initial_secrets_file_contents.as_str()),
                        "File should not be updated when QCS_SECRETS_READ_ONLY is set or file permissions are read-only"
                    );
                    return;
                }

                // Verify the file was updated
                let mut toml = std::fs::read_to_string(secrets_path)
                    .unwrap()
                    .parse::<DocumentMut>()
                    .unwrap();

                let token_payload = toml
                    .get_mut("credentials")
                    .and_then(|credentials| {
                        credentials.get_mut(profile_name)?.get_mut("token_payload")
                    })
                    .expect("Should be able to get token_payload table");

                let access_token = token_payload.get("access_token").unwrap().as_str().map(str::to_string).map(SecretAccessToken::from);

                assert_eq!(
                    access_token,
                    Some(new_access_token)
                );

                // The refresh must have stamped a newer `updated_at` than the fixture's.
                assert!(
                    OffsetDateTime::parse(
                        token_payload.get("updated_at").unwrap().as_str().unwrap(),
                        &Rfc3339
                    )
                    .unwrap()
                        > before_refresh
                );

                let content = std::fs::read_to_string("secrets.toml").unwrap();
                assert!(
                content.contains("new_access_token"),
                "File should be updated with new access token when QCS_SECRETS_READ_ONLY is not set or is set but disabled, and file permissions allow writing"
                );
            });
            Ok(())
        });
    }

    // `OAuthSession`'s `Debug` output must not reveal the grant's client credentials
    // or the access token. The grant prints only its variant name, and the token
    // prints as `()`.
    #[test]
    fn test_auth_session_debug_fmt() {
        let session = OAuthSession {
            payload: OAuthGrant::ClientCredentials(ClientCredentials::new(
                "hidden_id",
                "hidden_secret",
            )),
            access_token: Some(SecretAccessToken::from("token")),
            auth_server: AuthServer {
                client_id: "some_id".into(),
                issuer: "some_url".into(),
            },
        };

        assert_eq!("OAuthSession { payload: ClientCredentials, access_token: Some(()), auth_server: AuthServer { client_id: \"some_id\", issuer: \"some_url\" } }", &format!("{session:?}"));
    }
}