qcs_api_client_common/configuration/
tokens.rs

1use std::{pin::Pin, sync::Arc};
2
3use futures::Future;
4use jsonwebtoken::{Algorithm, DecodingKey, Validation};
5use oauth2::TokenResponse;
6use serde::{Deserialize, Serialize};
7use time::OffsetDateTime;
8use tokio::sync::{Mutex, Notify, RwLock};
9use tokio_util::sync::CancellationToken;
10
11use super::{
12    oidc, secrets::Secrets, settings::AuthServer, ClientConfiguration, ConfigSource, TokenError,
13};
14use crate::configuration::{
15    error::DiscoveryError,
16    pkce::{pkce_login, PkceLoginError, PkceLoginRequest},
17};
18#[cfg(feature = "tracing-config")]
19use crate::tracing_configuration::TracingConfiguration;
20#[cfg(feature = "tracing")]
21use urlpattern::UrlPatternMatchInput;
22
23/// A single type containing an access token and an associated refresh token.
24#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
25#[cfg_attr(feature = "python", pyo3::pyclass)]
26pub struct RefreshToken {
27    /// The token used to refresh the access token.
28    pub refresh_token: String,
29}
30
31impl RefreshToken {
32    /// Create a new [`RefreshToken`] with the given refresh token.
33    #[must_use]
34    pub const fn new(refresh_token: String) -> Self {
35        Self { refresh_token }
36    }
37
38    /// Request and return a new access token from the given authorization server using this refresh token.
39    ///
40    /// # Errors
41    ///
42    /// See [`TokenError`]
43    pub async fn request_access_token(
44        &mut self,
45        auth_server: &AuthServer,
46    ) -> Result<String, TokenError> {
47        if self.refresh_token.is_empty() {
48            return Err(TokenError::NoRefreshToken);
49        }
50
51        let client = default_http_client()?;
52        let token_url = oidc::fetch_discovery(&client, auth_server.issuer())
53            .await?
54            .token_endpoint;
55        let data = TokenRefreshRequest::new(auth_server.client_id(), &self.refresh_token);
56        let resp = client.post(token_url).form(&data).send().await?;
57
58        let response_data: RefreshTokenResponse = resp.error_for_status()?.json().await?;
59        self.refresh_token = response_data.refresh_token;
60        Ok(response_data.access_token)
61    }
62}
63
// NOTE(review): the derived `Debug` prints the access token in clear text.
/// Expected JSON body of a successful client-credentials token response; only the access
/// token is read from it.
#[derive(Deserialize, Debug, Serialize)]
pub(super) struct ClientCredentialsResponse {
    pub(super) access_token: String,
}
68
#[derive(Clone, PartialEq, Eq, Deserialize)]
#[expect(missing_debug_implementations, reason = "contains secret data")]
#[cfg_attr(feature = "python", pyo3::pyclass)]
/// A pair of Client ID and Client Secret, used to request an OAuth Client Credentials Grant.
///
/// `Debug` is intentionally not implemented so the secret cannot leak through formatting.
pub struct ClientCredentials {
    /// The client ID
    pub client_id: String,
    /// The client secret. Treat as sensitive.
    pub client_secret: String,
}
79
80impl ClientCredentials {
81    #[must_use]
82    /// Construct a new [`ClientCredentials`]
83    pub const fn new(client_id: String, client_secret: String) -> Self {
84        Self {
85            client_id,
86            client_secret,
87        }
88    }
89
90    /// Get the client ID.
91    #[must_use]
92    pub fn client_id(&self) -> &str {
93        &self.client_id
94    }
95
96    /// Get the client secret.
97    #[must_use]
98    pub fn client_secret(&self) -> &str {
99        &self.client_secret
100    }
101
102    /// Request and return an access token from the given auth server using this set of client credentials.
103    ///
104    /// # Errors
105    ///
106    /// See [`TokenError`]
107    pub async fn request_access_token(
108        &self,
109        auth_server: &AuthServer,
110    ) -> Result<String, TokenError> {
111        let request = ClientCredentialsRequest::new(None);
112        let client = default_http_client()?;
113
114        let url = oidc::fetch_discovery(&client, auth_server.issuer())
115            .await?
116            .token_endpoint;
117        let ready_to_send = client
118            .post(url)
119            .basic_auth(auth_server.client_id(), Some(&self.client_secret))
120            .form(&request);
121        let response = ready_to_send.send().await?;
122
123        response.error_for_status_ref()?;
124
125        let response_body: ClientCredentialsResponse = response.json().await?;
126
127        Ok(response_body.access_token)
128    }
129}
130
#[derive(Clone, PartialEq, Eq, Deserialize)]
#[expect(missing_debug_implementations, reason = "contains secret data")]
#[cfg_attr(feature = "python", pyo3::pyclass)]
/// The access (bearer) token and, if the server issued one, the refresh token obtained from a
/// PKCE login.
pub struct PkceFlow {
    /// The access token.
    pub access_token: String,
    /// The refresh token, if available. Without it, an expired access token cannot be renewed.
    pub refresh_token: Option<RefreshToken>,
}
141
/// Errors that can occur when attempting to perform a PKCE login flow.
#[derive(Debug, thiserror::Error)]
pub enum PkceFlowError {
    /// The PKCE login itself failed.
    #[error(transparent)]
    PkceLogin(#[from] PkceLoginError),
    /// Fetching the issuer's OIDC discovery document failed.
    #[error(transparent)]
    Discovery(#[from] DiscoveryError),
    /// An error from the `reqwest` HTTP client, e.g. while building it.
    #[error(transparent)]
    Request(#[from] reqwest::Error),
}
152
153impl PkceFlow {
154    /// Starts a new PKCE login flow to acquire a new set of tokens.
155    ///
156    /// # Errors
157    ///
158    /// See [`PkceFlowError`]
159    pub async fn new_login_flow(
160        cancel_token: CancellationToken,
161        auth_server: &AuthServer,
162    ) -> Result<Self, PkceFlowError> {
163        let issuer = auth_server.issuer().to_string();
164
165        let client = default_http_client()?;
166        let discovery = oidc::fetch_discovery(&client, &issuer).await?;
167
168        let response = pkce_login(
169            cancel_token,
170            PkceLoginRequest {
171                client_id: auth_server.client_id().to_string(),
172                redirect_port: None,
173                discovery,
174            },
175        )
176        .await?;
177
178        Ok(Self {
179            access_token: response.access_token().secret().to_string(),
180            refresh_token: response
181                .refresh_token()
182                .map(|rt| RefreshToken::new(rt.secret().to_string())),
183        })
184    }
185
186    /// Returns the access token if it is valid, otherwise requests a new access token using the refresh token if available.
187    ///
188    /// # Errors
189    ///
190    /// See [`TokenError`]
191    pub async fn request_access_token(
192        &mut self,
193        auth_server: &AuthServer,
194    ) -> Result<String, TokenError> {
195        if let Ok(access_token) = insecure_validate_token_exp(&self.access_token) {
196            return Ok(access_token);
197        }
198
199        if let Some(refresh_token) = &mut self.refresh_token {
200            let access_token = refresh_token.request_access_token(auth_server).await?;
201            self.access_token.clone_from(&access_token);
202            return Ok(access_token);
203        }
204
205        Err(TokenError::NoRefreshToken)
206    }
207}
208
#[derive(Clone)]
#[cfg_attr(feature = "python", derive(pyo3::FromPyObject))]
// NOTE(review): with the `python` feature, pyo3's derived `FromPyObject` for enums tries the
// variants in declaration order — confirm before reordering them.
/// Specifies the [OAuth2 grant type](https://oauth.net/2/grant-types/) to use, along with the data
/// needed to request said grant type.
pub enum OAuthGrant {
    /// Credentials used with the [Refresh Token grant type](https://oauth.net/2/grant-types/refresh-token/).
    RefreshToken(RefreshToken),
    /// Payload used with the [Client Credentials grant type](https://oauth.net/2/grant-types/client-credentials/).
    ClientCredentials(ClientCredentials),
    /// Defers to a user provided function for access token requests.
    ExternallyManaged(ExternallyManaged),
    /// The tokens returned by a PKCE login, which is an [Authorization Code grant type](https://oauth.net/2/pkce/).
    PkceFlow(PkceFlow),
}
223
224impl From<ExternallyManaged> for OAuthGrant {
225    fn from(v: ExternallyManaged) -> Self {
226        Self::ExternallyManaged(v)
227    }
228}
229
230impl From<ClientCredentials> for OAuthGrant {
231    fn from(v: ClientCredentials) -> Self {
232        Self::ClientCredentials(v)
233    }
234}
235
236impl From<RefreshToken> for OAuthGrant {
237    fn from(v: RefreshToken) -> Self {
238        Self::RefreshToken(v)
239    }
240}
241
242impl From<PkceFlow> for OAuthGrant {
243    fn from(v: PkceFlow) -> Self {
244        Self::PkceFlow(v)
245    }
246}
247
248impl OAuthGrant {
249    /// Request a new access token from the given issuer using this grant type and payload.
250    async fn request_access_token(
251        &mut self,
252        auth_server: &AuthServer,
253    ) -> Result<String, TokenError> {
254        match self {
255            Self::RefreshToken(tokens) => tokens.request_access_token(auth_server).await,
256            Self::ClientCredentials(tokens) => tokens.request_access_token(auth_server).await,
257            Self::ExternallyManaged(tokens) => tokens
258                .request_access_token(auth_server)
259                .await
260                .map_err(|e| TokenError::ExternallyManaged(e.to_string())),
261            Self::PkceFlow(tokens) => tokens.request_access_token(auth_server).await,
262        }
263    }
264}
265
266impl std::fmt::Debug for OAuthGrant {
267    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
268        match self {
269            Self::RefreshToken(_) => f.write_str("RefreshToken"),
270            Self::ClientCredentials(_) => f.write_str("ClientCredentials"),
271            Self::ExternallyManaged(_) => f.write_str("ExternallyManaged"),
272            Self::PkceFlow(_) => f.write_str("PkceTokens"),
273        }
274    }
275}
276
/// Manages the `OAuth2` authorization process and token lifecycle for accessing the QCS API.
///
/// This struct encapsulates the necessary information to request an access token
/// from an authorization server, including the `OAuth2` grant type and any associated
/// credentials or payload data. Its `Debug` output reports only whether an access token is
/// present, never the token itself.
///
/// # Fields
///
/// * `payload` - The `OAuth2` grant type and associated data that will be used to request an access token.
/// * `access_token` - The access token currently in use, if any. If no token has been provided or requested yet, this will be `None`.
/// * `auth_server` - The authorization server responsible for issuing tokens.
#[derive(Clone)]
#[cfg_attr(feature = "python", pyo3::pyclass)]
pub struct OAuthSession {
    /// The grant type to use to request an access token.
    payload: OAuthGrant,
    /// The access token that is currently in use. None if no token has been requested yet.
    access_token: Option<String>,
    /// The [`AuthServer`] that issues the tokens.
    auth_server: AuthServer,
}
298
299impl OAuthSession {
300    /// Initialize a new set of [`Credentials`] using a [`GrantPayload`].
301    ///
302    /// Optionally include an `access_token`, if not included, then one can be requested
303    /// with [`Self::request_access_token`].
304    #[must_use]
305    pub const fn new(
306        payload: OAuthGrant,
307        auth_server: AuthServer,
308        access_token: Option<String>,
309    ) -> Self {
310        Self {
311            payload,
312            access_token,
313            auth_server,
314        }
315    }
316
317    /// Initialize a new set of [`Credentials`] using an [`ExternallyManaged`].
318    ///
319    /// Optionally include an `access_token`, if not included, then one can be requested
320    /// with [`Self::request_access_token`].
321    #[must_use]
322    pub const fn from_externally_managed(
323        tokens: ExternallyManaged,
324        auth_server: AuthServer,
325        access_token: Option<String>,
326    ) -> Self {
327        Self::new(
328            OAuthGrant::ExternallyManaged(tokens),
329            auth_server,
330            access_token,
331        )
332    }
333
334    /// Initialize a new set of [`Credentials`] using a [`RefreshToken`].
335    ///
336    /// Optionally include an `access_token`, if not included, then one can be requested
337    /// with [`Self::request_access_token`].
338    #[must_use]
339    pub const fn from_refresh_token(
340        tokens: RefreshToken,
341        auth_server: AuthServer,
342        access_token: Option<String>,
343    ) -> Self {
344        Self::new(OAuthGrant::RefreshToken(tokens), auth_server, access_token)
345    }
346
347    /// Initialize a new set of [`Credentials`] using [`ClientCredentials`].
348    ///
349    /// Optionally include an `access_token`, if not included, then one can be requested
350    /// with [`Self::request_access_token`].
351    #[must_use]
352    pub const fn from_client_credentials(
353        tokens: ClientCredentials,
354        auth_server: AuthServer,
355        access_token: Option<String>,
356    ) -> Self {
357        Self::new(
358            OAuthGrant::ClientCredentials(tokens),
359            auth_server,
360            access_token,
361        )
362    }
363
364    /// Initialize a new set of [`Credentials`] using [`PkceFlow`].
365    ///
366    /// Optionally include an `access_token`, if not included, then one can be requested
367    /// with [`Self::request_access_token`].
368    #[must_use]
369    pub const fn from_pkce_flow(
370        flow: PkceFlow,
371        auth_server: AuthServer,
372        access_token: Option<String>,
373    ) -> Self {
374        Self::new(OAuthGrant::PkceFlow(flow), auth_server, access_token)
375    }
376
377    /// Get the current access token.
378    ///
379    /// This is an unvalidated copy of the access token. Meaning it can become stale, or may
380    /// even be already be stale. See [`Self::validate`] and [`Self::request_access_token`].
381    ///
382    /// # Errors
383    ///
384    /// - [`TokenError::NoAccessToken`] if there is no access token
385    pub fn access_token(&self) -> Result<&str, TokenError> {
386        self.access_token.as_ref().map_or_else(
387            || Err(TokenError::NoAccessToken),
388            |token| Ok(token.as_str()),
389        )
390    }
391
392    /// Get the payload used to request an access token.
393    #[must_use]
394    pub const fn payload(&self) -> &OAuthGrant {
395        &self.payload
396    }
397
398    /// Request and return an updated access token using these credentials.
399    ///
400    /// # Errors
401    ///
402    /// See [`TokenError`]
403    #[allow(clippy::missing_panics_doc)]
404    pub async fn request_access_token(&mut self) -> Result<&str, TokenError> {
405        let access_token = self.payload.request_access_token(&self.auth_server).await?;
406        self.access_token = Some(access_token);
407        Ok(self
408            .access_token
409            .as_ref()
410            .expect("This value is set in the previous line, so it cannot be None"))
411    }
412
413    /// The [`AuthServer`] that issues the tokens.
414    #[must_use]
415    pub const fn auth_server(&self) -> &AuthServer {
416        &self.auth_server
417    }
418
419    /// Validate the access token, returning it if it is valid, or an error describing why it is
420    /// invalid.
421    ///
422    /// # Errors
423    ///
424    /// - [`TokenError::NoAccessToken`] if an access token has not been requested.
425    /// - [`TokenError::InvalidAccessToken`] if the access token is invalid.
426    pub fn validate(&self) -> Result<String, TokenError> {
427        self.access_token().map_or_else(
428            |_| Err(TokenError::NoAccessToken),
429            insecure_validate_token_exp,
430        )
431    }
432}
433
434/// Validates the access token's format and `exp` claim, but no other claims or
435/// signature. We do this only to determine if the token is expired and needs refreshing,
436/// there is no way to securely validate the token's signature on the client side.
437pub(crate) fn insecure_validate_token_exp(access_token: &str) -> Result<String, TokenError> {
438    let placeholder_key = DecodingKey::from_secret(&[]);
439    let mut validation = Validation::new(Algorithm::RS256);
440    validation.validate_exp = true;
441    validation.leeway = 60;
442    validation.validate_aud = false;
443    validation.insecure_disable_signature_validation();
444    jsonwebtoken::decode::<toml::Value>(access_token, &placeholder_key, &validation)
445        .map(|_| access_token.to_string())
446        .map_err(TokenError::InvalidAccessToken)
447}
448
449impl std::fmt::Debug for OAuthSession {
450    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
451        let token_populated = if self.access_token.is_some() {
452            Some(())
453        } else {
454            None
455        };
456        f.debug_struct("OAuthSession")
457            .field("payload", &self.payload)
458            .field("access_token", &token_populated)
459            .field("auth_server", &self.auth_server)
460            .finish()
461    }
462}
463
/// A wrapper for [`OAuthSession`] that provides thread-safe access to the inner tokens.
///
/// All fields are reference-counted, so clones share the same session and refresh state.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "python", pyo3::pyclass)]
pub struct TokenDispatcher {
    // The shared session: readers take the read lock, a refresh holds the write lock.
    lock: Arc<RwLock<OAuthSession>>,
    // `true` while a refresh is in flight, so concurrent callers wait instead of refreshing again.
    refreshing: Arc<Mutex<bool>>,
    // Signalled with `notify_waiters` when an in-flight refresh finishes.
    notify_refreshed: Arc<Notify>,
}
472
473impl From<OAuthSession> for TokenDispatcher {
474    fn from(value: OAuthSession) -> Self {
475        Self {
476            lock: Arc::new(RwLock::new(value)),
477            refreshing: Arc::new(Mutex::new(false)),
478            notify_refreshed: Arc::new(Notify::new()),
479        }
480    }
481}
482
483impl TokenDispatcher {
484    /// Executes a user-provided closure on a reference to the `Tokens` instance managed by the
485    /// dispatcher.
486    ///
487    /// This function locks the mutex, safely exposing the protected `Tokens` instance to the provided closure `f`.
488    /// It is designed to allow safe and controlled access to the `Tokens` instance for reading its state.
489    ///
490    /// # Parameters
491    /// - `f`: A closure that takes a reference to `Tokens` and returns a value of type `O`. The closure is called
492    ///   with the `Tokens` instance as an argument once the mutex is successfully locked.
493    pub async fn use_tokens<F, O>(&self, f: F) -> O
494    where
495        F: FnOnce(&OAuthSession) -> O + Send,
496    {
497        let tokens = self.lock.read().await;
498        f(&tokens)
499    }
500
501    /// Get a copy of the current access token.
502    #[must_use]
503    pub async fn tokens(&self) -> OAuthSession {
504        self.use_tokens(Clone::clone).await
505    }
506
507    /// Refreshes the tokens. Readers will be blocked until the refresh is complete.
508    ///
509    /// # Errors
510    ///
511    /// See [`TokenError`]
512    pub async fn refresh(
513        &self,
514        source: &ConfigSource,
515        profile: &str,
516    ) -> Result<OAuthSession, TokenError> {
517        self.managed_refresh(Self::perform_refresh, source, profile)
518            .await
519    }
520
521    /// Validate the access token, returning it if it is valid, or an error describing why it is
522    /// invalid.
523    ///
524    /// # Errors
525    ///
526    /// - [`TokenError::NoAccessToken`] if there is no access token
527    /// - [`TokenError::InvalidAccessToken`] if the access token is invalid
528    pub async fn validate(&self) -> Result<String, TokenError> {
529        self.use_tokens(OAuthSession::validate).await
530    }
531
532    /// If tokens are already being refreshed, wait and return the updated tokens. Otherwise, run
533    /// ``refresh_fn``.
534    async fn managed_refresh<F, Fut>(
535        &self,
536        refresh_fn: F,
537        source: &ConfigSource,
538        profile: &str,
539    ) -> Result<OAuthSession, TokenError>
540    where
541        F: FnOnce(Arc<RwLock<OAuthSession>>) -> Fut + Send,
542        Fut: Future<Output = Result<OAuthSession, TokenError>> + Send,
543    {
544        let mut is_refreshing = self.refreshing.lock().await;
545
546        if *is_refreshing {
547            drop(is_refreshing);
548            self.notify_refreshed.notified().await;
549            return Ok(self.tokens().await);
550        }
551
552        *is_refreshing = true;
553        drop(is_refreshing);
554
555        let oauth_session = refresh_fn(self.lock.clone()).await?;
556
557        // If the config source is a file, write the new access token to the file
558        if let ConfigSource::File {
559            settings_path: _,
560            secrets_path,
561        } = source
562        {
563            if !Secrets::is_read_only(secrets_path).await? {
564                // If the payload is a PkceFlow, write the fresh refresh token if available.
565                let refresh_token = match &oauth_session.payload {
566                    OAuthGrant::PkceFlow(payload) => payload
567                        .refresh_token
568                        .as_ref()
569                        .map(|rt| rt.refresh_token.as_str()),
570                    _ => None,
571                };
572
573                let now = OffsetDateTime::now_utc();
574                Secrets::write_tokens(
575                    secrets_path,
576                    profile,
577                    refresh_token,
578                    oauth_session.access_token()?,
579                    now,
580                )
581                .await?;
582            }
583        }
584
585        *self.refreshing.lock().await = false;
586        self.notify_refreshed.notify_waiters();
587        Ok(oauth_session)
588    }
589
590    /// Refreshes the tokens. Readers will be blocked until the refresh is complete. Returns a copy
591    /// of the updated [`Credentials`]
592    ///
593    /// # Errors
594    ///
595    /// See [`TokenError`]
596    async fn perform_refresh(lock: Arc<RwLock<OAuthSession>>) -> Result<OAuthSession, TokenError> {
597        let mut credentials = lock.write().await;
598        credentials.request_access_token().await?;
599        Ok(credentials.clone())
600    }
601}
602
/// The pinned, boxed, `Send` future returned by a `RefreshFunction`. It resolves to the new
/// access token or to the boxed error the refresh function produced.
pub(crate) type RefreshResult =
    Pin<Box<dyn Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>> + Send>>;
605
/// A function that asynchronously refreshes a token.
///
/// It receives the [`AuthServer`] being used and returns a boxed future that resolves to the
/// new access token.
pub type RefreshFunction = Box<dyn (Fn(AuthServer) -> RefreshResult) + Send + Sync>;
608
/// A struct that manages access tokens by utilizing a user-provided refresh function.
///
/// The [`ExternallyManaged`] struct allows users to define custom logic for
/// fetching or refreshing access tokens.
#[derive(Clone)]
#[cfg_attr(feature = "python", pyo3::pyclass)]
pub struct ExternallyManaged {
    // Behind an `Arc` so clones share one function; it is only ever called, never replaced.
    refresh_function: Arc<RefreshFunction>,
}
618
619impl ExternallyManaged {
620    /// Creates a new [`ExternallyManaged`] instance from a [`RefreshFunction`].
621    ///
622    /// Consider using [`ExternallyManaged::from_async`], and [`ExternallyManaged::from_sync`], if
623    /// they better fit your use case.
624    ///
625    /// # Arguments
626    ///
627    /// * `refresh_function` - A function or closure that asynchronously refreshes a token.
628    ///
629    /// # Example
630    ///
631    /// ```
632    /// use qcs_api_client_common::configuration::{AuthServer, ExternallyManaged, TokenError};
633    /// use std::future::Future;
634    /// use std::pin::Pin;
635    /// use std::boxed::Box;
636    /// use std::error::Error;
637    ///
638    /// async fn example_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
639    /// + Send + Sync>> {
640    ///     Ok("new_token_value".to_string())
641    /// }
642    /// let token_manager = ExternallyManaged::new(|auth_server| Box::pin(example_refresh_function(auth_server)));
643    /// ```
644    pub fn new(
645        refresh_function: impl Fn(AuthServer) -> RefreshResult + Send + Sync + 'static,
646    ) -> Self {
647        Self {
648            refresh_function: Arc::new(Box::new(refresh_function)),
649        }
650    }
651
652    /// Constructs a new [`ExternallyManaged`] instance using an async function or closure.
653    ///
654    /// This method simplifies the creation of the [`ExternallyManaged`] instance by handling
655    /// the boxing and pinning of the future internally.
656    ///
657    /// # Arguments
658    ///
659    /// * `refresh_function` - An async function or closure that returns a [`Future`] which, when awaited,
660    ///   produces a [`Result<String, TokenError>`].
661    ///
662    /// # Example
663    ///
664    /// ```
665    /// use qcs_api_client_common::configuration::{AuthServer, ExternallyManaged, TokenError};
666    /// use tokio::runtime::Runtime;
667    /// use std::error::Error;
668    ///
669    /// async fn example_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
670    /// + Send + Sync>> {
671    ///     Ok("new_token_value".to_string())
672    /// }
673    ///
674    /// let token_manager = ExternallyManaged::from_async(example_refresh_function);
675    ///
676    /// let rt = Runtime::new().unwrap();
677    /// rt.block_on(async {
678    ///     match token_manager.request_access_token(&AuthServer::default()).await {
679    ///         Ok(token) => println!("Token: {}", token),
680    ///         Err(e) => println!("Failed to refresh token: {:?}", e),
681    ///     }
682    /// });
683    /// ```
684    pub fn from_async<F, Fut>(refresh_function: F) -> Self
685    where
686        F: Fn(AuthServer) -> Fut + Send + Sync + 'static,
687        Fut: Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>>
688            + Send
689            + 'static,
690    {
691        Self {
692            refresh_function: Arc::new(Box::new(move |auth_server| {
693                Box::pin(refresh_function(auth_server))
694            })),
695        }
696    }
697
698    /// Constructs a new [`ExternallyManaged`] instance using a synchronous function.
699    ///
700    /// The synchronous function is wrapped in an async block to fit the expected signature.
701    ///
702    /// # Arguments
703    ///
704    /// * `refresh_function` - A synchronous function that returns a [`Result<String, TokenError>`].
705    ///
706    /// # Example
707    ///
708    /// ```
709    /// use qcs_api_client_common::configuration::{AuthServer, ExternallyManaged, TokenError};
710    /// use tokio::runtime::Runtime;
711    /// use std::error::Error;
712    ///
713    /// fn example_sync_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
714    /// + Send + Sync>> {
715    ///     Ok("sync_token_value".to_string())
716    /// }
717    ///
718    /// let token_manager = ExternallyManaged::from_sync(example_sync_refresh_function);
719    ///
720    /// let rt = Runtime::new().unwrap();
721    /// rt.block_on(async {
722    ///     match token_manager.request_access_token(&AuthServer::default()).await {
723    ///         Ok(token) => println!("Token: {}", token),
724    ///         Err(e) => println!("Failed to refresh token: {:?}", e),
725    ///     }
726    /// });
727    /// ```
728    pub fn from_sync(
729        refresh_function: impl Fn(AuthServer) -> Result<String, Box<dyn std::error::Error + Send + Sync>>
730            + Send
731            + Sync
732            + 'static,
733    ) -> Self {
734        Self {
735            refresh_function: Arc::new(Box::new(move |auth_server| {
736                let result = refresh_function(auth_server);
737                Box::pin(async move { result })
738            })),
739        }
740    }
741
742    /// Request an updated access token using the provided refresh function.
743    ///
744    /// # Errors
745    ///
746    /// Errors are propagated from the refresh function.
747    pub async fn request_access_token(
748        &self,
749        auth_server: &AuthServer,
750    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
751        (self.refresh_function)(auth_server.clone()).await
752    }
753}
754
755impl std::fmt::Debug for ExternallyManaged {
756    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
757        f.debug_struct("ExternallyManaged")
758            .field(
759                "refresh_function",
760                &"Fn() -> Pin<Box<dyn Future<Output = Result<String, TokenError>> + Send>>",
761            )
762            .finish()
763    }
764}
765
// NOTE(review): the derived `Debug` prints the refresh token in clear text.
/// Form body sent to the token endpoint for an OAuth2 Refresh Token grant.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct TokenRefreshRequest<'a> {
    // Always `"refresh_token"` when built through `TokenRefreshRequest::new`.
    grant_type: &'static str,
    client_id: &'a str,
    refresh_token: &'a str,
}
772
773impl<'a> TokenRefreshRequest<'a> {
774    pub(super) const fn new(client_id: &'a str, refresh_token: &'a str) -> Self {
775        Self {
776            grant_type: "refresh_token",
777            client_id,
778            refresh_token,
779        }
780    }
781}
782
/// Form body sent to the token endpoint for an OAuth2 Client Credentials grant.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct ClientCredentialsRequest {
    // Always `"client_credentials"` when built through `ClientCredentialsRequest::new`.
    grant_type: &'static str,
    scope: Option<&'static str>,
}
788
789impl ClientCredentialsRequest {
790    pub(super) const fn new(scope: Option<&'static str>) -> Self {
791        Self {
792            grant_type: "client_credentials",
793            scope,
794        }
795    }
796}
797
// NOTE(review): `refresh_token` is required here. A token endpoint that omits it from a refresh
// response (the client is then expected to keep its old refresh token) would make decoding
// fail — confirm the QCS auth server always rotates refresh tokens.
// NOTE(review): the derived `Debug` prints both tokens in clear text.
/// Expected JSON body of a successful refresh-token grant response.
#[derive(Deserialize, Debug, Serialize)]
pub(super) struct RefreshTokenResponse {
    pub(super) refresh_token: String,
    pub(super) access_token: String,
}
803
804/// Get and refresh access tokens
805#[async_trait::async_trait]
806pub trait TokenRefresher: Clone + std::fmt::Debug + Send {
807    /// The type to be returned in the event of a error during getting or
808    /// refreshing an access token
809    type Error;
810
811    /// Get and validate the current access token, refreshing it if it doesn't exist or is invalid.
812    async fn validated_access_token(&self) -> Result<String, Self::Error>;
813
814    /// Get the current access token, if any
815    async fn get_access_token(&self) -> Result<Option<String>, Self::Error>;
816
817    /// Get a fresh access token
818    async fn refresh_access_token(&self) -> Result<String, Self::Error>;
819
820    /// Get the base URL for requests
821    #[cfg(feature = "tracing")]
822    fn base_url(&self) -> &str;
823
824    /// Get the tracing configuration
825    #[cfg(feature = "tracing-config")]
826    fn tracing_configuration(&self) -> Option<&TracingConfiguration>;
827
828    /// Returns whether the given URL should be traced. Following
829    /// [`TracingConfiguration::is_enabled`], this defaults to `true`.
830    #[cfg(feature = "tracing")]
831    #[allow(clippy::needless_return)]
832    fn should_trace(&self, url: &UrlPatternMatchInput) -> bool {
833        #[cfg(not(feature = "tracing-config"))]
834        {
835            let _ = url;
836            return true;
837        }
838
839        #[cfg(feature = "tracing-config")]
840        self.tracing_configuration()
841            .is_none_or(|config| config.is_enabled(url))
842    }
843}
844
845#[async_trait::async_trait]
846impl TokenRefresher for ClientConfiguration {
847    type Error = TokenError;
848
849    async fn validated_access_token(&self) -> Result<String, Self::Error> {
850        self.get_bearer_access_token().await
851    }
852
853    async fn refresh_access_token(&self) -> Result<String, Self::Error> {
854        Ok(self.refresh().await?.access_token()?.to_string())
855    }
856
857    async fn get_access_token(&self) -> Result<Option<String>, Self::Error> {
858        Ok(Some(
859            self.oauth_session().await?.access_token()?.to_string(),
860        ))
861    }
862
863    #[cfg(feature = "tracing")]
864    fn base_url(&self) -> &str {
865        &self.grpc_api_url
866    }
867
868    #[cfg(feature = "tracing-config")]
869    fn tracing_configuration(&self) -> Option<&TracingConfiguration> {
870        self.tracing_configuration.as_ref()
871    }
872}
873
874/// Get a default http client.
875fn default_http_client() -> Result<reqwest::Client, reqwest::Error> {
876    reqwest::Client::builder()
877        .timeout(std::time::Duration::from_secs(10))
878        .build()
879}
880
#[cfg(test)]
mod test {
    use std::time::Duration;

    use super::*;
    use httpmock::prelude::*;
    use rstest::rstest;
    use time::format_description::well_known::Rfc3339;
    use tokio::time::Instant;
    use toml_edit::DocumentMut;

    /// A `tokens()` call issued while a `refresh()` is in flight must wait for
    /// the refresh to finish and then observe the refreshed tokens.
    #[tokio::test]
    async fn test_tokens_blocked_during_refresh() {
        // Simulated latency of the token endpoint. The timing assertions below
        // use this same value, so the mock delay and the expectations cannot
        // drift apart.
        let refresh_duration = Duration::from_secs(3);

        let mock_server = MockServer::start_async().await;

        let oidc_mock = mock_server
            .mock_async(|when, then| {
                when.method(GET).path("/.well-known/openid-configuration");
                then.status(200)
                    .json_body_obj(&oidc::Discovery::new_for_test(
                        mock_server.base_url().parse().unwrap(),
                    ));
            })
            .await;

        let issuer_mock = mock_server
            .mock_async(|when, then| {
                when.method(POST).path("/v1/token");

                then.status(200)
                    .delay(refresh_duration)
                    .json_body_obj(&RefreshTokenResponse {
                        access_token: "new_access".to_string(),
                        refresh_token: "new_refresh".to_string(),
                    });
            })
            .await;

        let original_tokens = OAuthSession::from_refresh_token(
            RefreshToken::new("refresh".to_string()),
            AuthServer::new("client_id".to_string(), mock_server.base_url()),
            None,
        );
        let dispatcher: TokenDispatcher = original_tokens.into();
        let writer = dispatcher.clone();
        let reader = dispatcher.clone();

        let start_write = Instant::now();
        let write_future = tokio::spawn(async move {
            writer
                .refresh(&ConfigSource::Default, "")
                .await
                .unwrap()
        });

        let start_read = Instant::now();
        let read_future = tokio::spawn(async move { reader.tokens().await });

        let _ = write_future.await.unwrap();
        let read_result = read_future.await.unwrap();

        let write_duration = start_write.elapsed();
        let read_duration = start_read.elapsed();

        oidc_mock.assert_async().await;
        issuer_mock.assert_async().await;

        assert!(
            write_duration >= refresh_duration,
            "Write operation did not take enough time"
        );
        assert!(
            read_duration >= refresh_duration,
            "Read operation was not blocked by the write operation"
        );
        assert_eq!(read_result.access_token.as_ref().unwrap(), "new_access");
        if let OAuthGrant::RefreshToken(payload) = read_result.payload {
            assert_eq!(&payload.refresh_token, "new_refresh");
        } else {
            panic!(
                "Expected RefreshToken payload, got {:?}",
                read_result.payload
            );
        }
    }

    /// A refresh rewrites the secrets file only when neither a truthy
    /// `QCS_SECRETS_READ_ONLY` nor read-only file permissions forbid it.
    #[rstest]
    fn test_qcs_secrets_readonly(
        #[values(
            (Some("TRUE"), true),
            (Some("tRue"), true),
            (Some("true"), true),
            (Some("YES"), true),
            (Some("yEs"), true),
            (Some("yes"), true),
            (Some("1"), true),
            (Some("2"), false),
            (Some("other"), false),
            (Some(""), false),
            (None, false),
        )]
        read_only_values: (Option<&str>, bool),
        #[values(true, false)] read_only_perm: bool,
    ) {
        let (maybe_read_only_env, env_is_read_only) = read_only_values;
        let expected_update = !env_is_read_only && !read_only_perm;
        figment::Jail::expect_with(|jail| {
            let profile_name = "test";
            let initial_access_token = "initial_access_token";
            let initial_refresh_token = "initial_refresh_token";

            let initial_secrets_file_contents = format!(
                r#"
[credentials]
[credentials.{profile_name}]
[credentials.{profile_name}.token_payload]
access_token = "{initial_access_token}"
expires_in = 3600
id_token = "id_token"
refresh_token = "{initial_refresh_token}"
scope = "offline_access openid profile email"
token_type = "Bearer"
updated_at = "2024-01-01T00:00:00Z"
"#
            );

            // Ignore any existing environment variables.
            jail.clear_env();

            // Create a temporary secrets file.
            let secrets_path = "secrets.toml";
            jail.create_file(secrets_path, initial_secrets_file_contents.as_str())
                .expect("should create test secrets.toml");

            // NOTE(review): read-only permissions do not stop writes when the
            // tests run as root, so these cases presumably assume a non-root user.
            if read_only_perm {
                let mut permissions = std::fs::metadata(secrets_path)
                    .expect("Should be able to get file metadata")
                    .permissions();
                permissions.set_readonly(true);
                std::fs::set_permissions(secrets_path, permissions)
                    .expect("Should be able to set file permissions");
            }

            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async {
                let mock_server = MockServer::start_async().await;

                let oidc_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(GET).path("/.well-known/openid-configuration");
                        then.status(200).json_body_obj(&oidc::Discovery::new_for_test(
                            mock_server.base_url().parse().unwrap(),
                        ));
                    })
                    .await;

                // Set up the mock token endpoint.
                let new_access_token = "new_access_token";
                let issuer_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(POST).path("/v1/token");
                        then.status(200).json_body_obj(&RefreshTokenResponse {
                            access_token: new_access_token.to_string(),
                            refresh_token: initial_refresh_token.to_string(),
                        });
                    })
                    .await;

                // Create tokens and dispatcher.
                // NOTE(review): the third argument receives the refresh-token
                // string; confirm that is intended.
                let original_tokens = OAuthSession::from_refresh_token(
                    RefreshToken::new(initial_refresh_token.to_string()),
                    AuthServer::new("client_id".to_string(), mock_server.base_url()),
                    Some(initial_refresh_token.to_string()),
                );
                let dispatcher: TokenDispatcher = original_tokens.into();

                // Point configuration at the jailed secrets file. The read-only
                // override is only set for the `Some(_)` cases; `None` leaves it
                // unset.
                jail.set_env("QCS_SECRETS_FILE_PATH", secrets_path);
                jail.set_env("QCS_PROFILE_NAME", profile_name);
                if let Some(read_only_env) = maybe_read_only_env {
                    jail.set_env("QCS_SECRETS_READ_ONLY", read_only_env);
                }

                let before_refresh = OffsetDateTime::now_utc();

                dispatcher
                    .refresh(
                        &ConfigSource::File {
                            settings_path: "".into(),
                            secrets_path: secrets_path.into(),
                        },
                        profile_name,
                    )
                    .await
                    .unwrap();

                oidc_mock.assert_async().await;
                issuer_mock.assert_async().await;

                // Nothing writes the file after the refresh, so one read serves
                // every check below.
                let content = std::fs::read_to_string(secrets_path).unwrap();

                // When writes are disabled by the environment or by file
                // permissions, the file must be byte-for-byte unchanged.
                if !expected_update {
                    assert_eq!(
                        content, initial_secrets_file_contents,
                        "File should not be updated when QCS_SECRETS_READ_ONLY is set or file permissions are read-only"
                    );
                    return;
                }

                // Otherwise the token payload must have been rewritten.
                let mut toml = content.parse::<DocumentMut>().unwrap();

                let token_payload = toml
                    .get_mut("credentials")
                    .and_then(|credentials| {
                        credentials.get_mut(profile_name)?.get_mut("token_payload")
                    })
                    .expect("Should be able to get token_payload table");

                assert_eq!(
                    token_payload.get("access_token").unwrap().as_str().unwrap(),
                    new_access_token
                );

                assert!(
                    OffsetDateTime::parse(
                        token_payload.get("updated_at").unwrap().as_str().unwrap(),
                        &Rfc3339
                    )
                    .unwrap()
                        > before_refresh
                );

                assert!(
                    content.contains(new_access_token),
                    "File should be updated with new access token when QCS_SECRETS_READ_ONLY is not set or is set but disabled, and file permissions allow writing"
                );
            });
            Ok(())
        });
    }

    /// `Debug` output of an `OAuthSession` must not expose client credentials or
    /// the access token.
    #[test]
    fn test_auth_session_debug_fmt() {
        let session = OAuthSession {
            payload: OAuthGrant::ClientCredentials(ClientCredentials {
                client_id: "hidden_id".into(),
                client_secret: "hidden_secret".into(),
            }),
            access_token: Some("token".into()),
            auth_server: AuthServer::new("some_id".into(), "some_url".into()),
        };

        assert_eq!("OAuthSession { payload: ClientCredentials, access_token: Some(()), auth_server: AuthServer { client_id: \"some_id\", issuer: \"some_url\" } }", &format!("{session:?}"));
    }
}