qcs_api_client_common/configuration/
tokens.rs

1use std::{pin::Pin, sync::Arc};
2
3use futures::Future;
4use jsonwebtoken::{Algorithm, DecodingKey, Validation};
5use serde::{Deserialize, Serialize};
6use time::OffsetDateTime;
7use tokio::sync::{Mutex, Notify, RwLock};
8
9use super::{
10    secrets::Secrets, settings::AuthServer, ClientConfiguration, ConfigSource, TokenError,
11    QCS_AUDIENCE,
12};
13#[cfg(feature = "tracing-config")]
14use crate::tracing_configuration::TracingConfiguration;
15#[cfg(feature = "tracing")]
16use urlpattern::UrlPatternMatchInput;
17
/// A refresh token that can be exchanged for a new access token.
///
/// This type holds only the refresh token; the access token obtained from it is returned by
/// [`RefreshToken::request_access_token`] rather than stored here.
// NOTE(review): the derived `Debug` prints the refresh token verbatim, unlike
// `ClientCredentials`, which deliberately has no `Debug` implementation.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "python", pyo3::pyclass)]
pub struct RefreshToken {
    /// The token used to refresh the access token.
    pub refresh_token: String,
}
25
26impl RefreshToken {
27    /// Create a new [`RefreshToken`] with the given refresh token.
28    #[must_use]
29    pub const fn new(refresh_token: String) -> Self {
30        Self { refresh_token }
31    }
32
33    /// Request and return a new access token from the given authorization server using this refresh token.
34    ///
35    /// # Errors
36    ///
37    /// See [`TokenError`]
38    pub async fn request_access_token(
39        &mut self,
40        auth_server: &AuthServer,
41    ) -> Result<String, TokenError> {
42        if self.refresh_token.is_empty() {
43            return Err(TokenError::NoRefreshToken);
44        }
45        let token_url = format!("{}/v1/token", auth_server.issuer());
46        let data = TokenRefreshRequest::new(auth_server.client_id(), &self.refresh_token);
47        let resp = reqwest::Client::builder()
48            .timeout(std::time::Duration::from_secs(10))
49            .build()?
50            .post(token_url)
51            .form(&data)
52            .send()
53            .await?;
54
55        let response_data: RefreshTokenResponse = resp.error_for_status()?.json().await?;
56        self.refresh_token = response_data.refresh_token;
57        Ok(response_data.access_token)
58    }
59}
60
/// The subset of the token endpoint's response body that is read after a client credentials
/// request: only the access token.
#[derive(Deserialize, Debug, Serialize)]
pub(super) struct ClientCredentialsResponse {
    /// The access token issued by the authorization server.
    pub(super) access_token: String,
}
65
#[derive(Clone, PartialEq, Eq, Deserialize)]
#[expect(missing_debug_implementations, reason = "contains secret data")]
#[cfg_attr(feature = "python", pyo3::pyclass)]
/// A pair of Client ID and Client Secret, used to request an OAuth Client Credentials Grant.
///
/// `Debug` is intentionally not implemented because the struct contains secret data.
pub struct ClientCredentials {
    /// The client ID.
    pub client_id: String,
    /// The client secret.
    pub client_secret: String,
}
76
77impl ClientCredentials {
78    #[must_use]
79    /// Construct a new [`ClientCredentials`]
80    pub const fn new(client_id: String, client_secret: String) -> Self {
81        Self {
82            client_id,
83            client_secret,
84        }
85    }
86
87    /// Get the client ID.
88    #[must_use]
89    pub fn client_id(&self) -> &str {
90        &self.client_id
91    }
92
93    /// Get the client secret.
94    #[must_use]
95    pub fn client_secret(&self) -> &str {
96        &self.client_secret
97    }
98
99    /// Request and return an access token from the given auth server using this set of client credentials.
100    ///
101    /// # Errors
102    ///
103    /// See [`TokenError`]
104    pub async fn request_access_token(
105        &self,
106        auth_server: &AuthServer,
107    ) -> Result<String, TokenError> {
108        let request = ClientCredentialsRequest::new(None);
109        let url = format!("{}/v1/token", auth_server.issuer());
110
111        let client = reqwest::Client::builder()
112            .timeout(std::time::Duration::from_secs(10))
113            .build()?;
114
115        let ready_to_send = client
116            .post(url)
117            .basic_auth(auth_server.client_id(), Some(&self.client_secret))
118            .form(&request);
119        let response = ready_to_send.send().await?;
120
121        response.error_for_status_ref()?;
122
123        let response_body: ClientCredentialsResponse = response.json().await?;
124
125        Ok(response_body.access_token)
126    }
127}
128
#[derive(Clone)]
#[cfg_attr(feature = "python", derive(pyo3::FromPyObject))]
/// Specifies the [OAuth2 grant type](https://oauth.net/2/grant-types/) to use, along with the data
/// needed to request said grant type.
pub enum OAuthGrant {
    /// Credentials that can be used with the [Refresh Token grant type](https://oauth.net/2/grant-types/refresh-token/).
    RefreshToken(RefreshToken),
    /// Payload that can be used with the [Client Credentials grant type](https://oauth.net/2/grant-types/client-credentials/).
    ClientCredentials(ClientCredentials),
    /// Defers to a user-provided function for access token requests.
    ExternallyManaged(ExternallyManaged),
}
141
142impl From<ExternallyManaged> for OAuthGrant {
143    fn from(v: ExternallyManaged) -> Self {
144        Self::ExternallyManaged(v)
145    }
146}
147
148impl From<ClientCredentials> for OAuthGrant {
149    fn from(v: ClientCredentials) -> Self {
150        Self::ClientCredentials(v)
151    }
152}
153
154impl From<RefreshToken> for OAuthGrant {
155    fn from(v: RefreshToken) -> Self {
156        Self::RefreshToken(v)
157    }
158}
159
160impl OAuthGrant {
161    /// Request a new access token from the given issuer using this grant type and payload.
162    async fn request_access_token(
163        &mut self,
164        auth_server: &AuthServer,
165    ) -> Result<String, TokenError> {
166        match self {
167            Self::RefreshToken(tokens) => tokens.request_access_token(auth_server).await,
168            Self::ClientCredentials(tokens) => tokens.request_access_token(auth_server).await,
169            Self::ExternallyManaged(tokens) => tokens
170                .request_access_token(auth_server)
171                .await
172                .map_err(|e| TokenError::ExternallyManaged(e.to_string())),
173        }
174    }
175}
176
177impl std::fmt::Debug for OAuthGrant {
178    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179        match self {
180            Self::RefreshToken(_) => f.write_str("RefreshToken"),
181            Self::ClientCredentials(_) => f.write_str("ClientCredentials"),
182            Self::ExternallyManaged(_) => f.write_str("ExternallyManaged"),
183        }
184    }
185}
186
/// Manages the `OAuth2` authorization process and token lifecycle for accessing the QCS API.
///
/// This struct encapsulates the necessary information to request an access token
/// from an authorization server, including the `OAuth2` grant type and any associated
/// credentials or payload data.
///
/// # Fields
///
/// * `payload` - The `OAuth2` grant type and associated data that will be used to request an access token.
/// * `access_token` - The access token currently in use, if any. If no token has been provided or requested yet, this will be `None`.
/// * `auth_server` - The authorization server responsible for issuing tokens.
///
/// All fields are private; use the accessor methods on [`OAuthSession`] to read them.
#[derive(Clone)]
#[cfg_attr(feature = "python", pyo3::pyclass)]
pub struct OAuthSession {
    /// The grant type to use to request an access token.
    payload: OAuthGrant,
    /// The access token that is currently in use. None if no token has been requested yet.
    access_token: Option<String>,
    /// The [`AuthServer`] that issues the tokens.
    auth_server: AuthServer,
}
208
209impl OAuthSession {
210    /// Initialize a new set of [`Credentials`] using a [`GrantPayload`].
211    ///
212    /// Optionally include an `access_token`, if not included, then one can be requested
213    /// with [`Self::request_access_token`].
214    #[must_use]
215    pub const fn new(
216        payload: OAuthGrant,
217        auth_server: AuthServer,
218        access_token: Option<String>,
219    ) -> Self {
220        Self {
221            payload,
222            access_token,
223            auth_server,
224        }
225    }
226
227    /// Initialize a new set of [`Credentials`] using an [`ExternallyManaged`].
228    ///
229    /// Optionally include an `access_token`, if not included, then one can be requested
230    /// with [`Self::request_access_token`].
231    #[must_use]
232    pub const fn from_externally_managed(
233        tokens: ExternallyManaged,
234        auth_server: AuthServer,
235        access_token: Option<String>,
236    ) -> Self {
237        Self::new(
238            OAuthGrant::ExternallyManaged(tokens),
239            auth_server,
240            access_token,
241        )
242    }
243
244    /// Initialize a new set of [`Credentials`] using a [`RefreshToken`].
245    ///
246    /// Optionally include an `access_token`, if not included, then one can be requested
247    /// with [`Self::request_access_token`].
248    #[must_use]
249    pub const fn from_refresh_token(
250        tokens: RefreshToken,
251        auth_server: AuthServer,
252        access_token: Option<String>,
253    ) -> Self {
254        Self::new(OAuthGrant::RefreshToken(tokens), auth_server, access_token)
255    }
256
257    /// Initialize a new set of [`Credentials`] using [`ClientCredentials`].
258    ///
259    /// Optionally include an `access_token`, if not included, then one can be requested
260    /// with [`Self::request_access_token`].
261    #[must_use]
262    pub const fn from_client_credentials(
263        tokens: ClientCredentials,
264        auth_server: AuthServer,
265        access_token: Option<String>,
266    ) -> Self {
267        Self::new(
268            OAuthGrant::ClientCredentials(tokens),
269            auth_server,
270            access_token,
271        )
272    }
273
274    /// Get the current access token.
275    ///
276    /// This is an unvalidated copy of the access token. Meaning it can become stale, or may
277    /// even be already be stale. See [`Self::validate`] and [`Self::request_access_token`].
278    ///
279    /// # Errors
280    ///
281    /// - [`TokenError::NoAccessToken`] if there is no access token
282    pub fn access_token(&self) -> Result<&str, TokenError> {
283        self.access_token.as_ref().map_or_else(
284            || Err(TokenError::NoAccessToken),
285            |token| Ok(token.as_str()),
286        )
287    }
288
289    /// Get the payload used to request an access token.
290    #[must_use]
291    pub const fn payload(&self) -> &OAuthGrant {
292        &self.payload
293    }
294
295    /// Request and return an updated access token using these credentials.
296    ///
297    /// # Errors
298    ///
299    /// See [`TokenError`]
300    #[allow(clippy::missing_panics_doc)]
301    pub async fn request_access_token(&mut self) -> Result<&str, TokenError> {
302        let access_token = self.payload.request_access_token(&self.auth_server).await?;
303        self.access_token = Some(access_token);
304        Ok(self
305            .access_token
306            .as_ref()
307            .expect("This value is set in the previous line, so it cannot be None"))
308    }
309
310    /// The [`AuthServer`] that issues the tokens.
311    #[must_use]
312    pub const fn auth_server(&self) -> &AuthServer {
313        &self.auth_server
314    }
315
316    /// Validate the access token, returning it if it is valid, or an error describing why it is
317    /// invalid.
318    ///
319    /// # Errors
320    ///
321    /// - [`TokenError::NoAccessToken`] if an access token has not been requested.
322    /// - [`TokenError::InvalidAccessToken`] if the access token is invalid.
323    pub fn validate(&self) -> Result<String, TokenError> {
324        self.access_token().map_or_else(
325            |_| Err(TokenError::NoAccessToken),
326            |access_token| {
327                let placeholder_key = DecodingKey::from_secret(&[]);
328                let mut validation = Validation::new(Algorithm::RS256);
329                validation.validate_exp = true;
330                validation.leeway = 60;
331                validation.set_audience(&[QCS_AUDIENCE]);
332                validation.insecure_disable_signature_validation();
333                jsonwebtoken::decode::<toml::Value>(access_token, &placeholder_key, &validation)
334                    .map(|_| access_token.to_string())
335                    .map_err(TokenError::InvalidAccessToken)
336            },
337        )
338    }
339}
340
341impl std::fmt::Debug for OAuthSession {
342    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
343        let token_populated = if self.access_token.is_some() {
344            Some(())
345        } else {
346            None
347        };
348        f.debug_struct("OAuthSession")
349            .field("payload", &self.payload)
350            .field("access_token", &token_populated)
351            .field("auth_server", &self.auth_server)
352            .finish()
353    }
354}
355
/// A wrapper for [`OAuthSession`] that provides thread-safe access to the inner tokens.
///
/// All fields are reference counted, so clones of a dispatcher share the same session,
/// refresh flag, and notifier.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "python", pyo3::pyclass)]
pub struct TokenDispatcher {
    /// The shared session, guarded by a read-write lock.
    lock: Arc<RwLock<OAuthSession>>,
    /// `true` while a refresh started through this dispatcher is in progress.
    refreshing: Arc<Mutex<bool>>,
    /// Used to wake tasks that are waiting for an in-progress refresh to finish.
    notify_refreshed: Arc<Notify>,
}
364
365impl From<OAuthSession> for TokenDispatcher {
366    fn from(value: OAuthSession) -> Self {
367        Self {
368            lock: Arc::new(RwLock::new(value)),
369            refreshing: Arc::new(Mutex::new(false)),
370            notify_refreshed: Arc::new(Notify::new()),
371        }
372    }
373}
374
375impl TokenDispatcher {
376    /// Executes a user-provided closure on a reference to the `Tokens` instance managed by the
377    /// dispatcher.
378    ///
379    /// This function locks the mutex, safely exposing the protected `Tokens` instance to the provided closure `f`.
380    /// It is designed to allow safe and controlled access to the `Tokens` instance for reading its state.
381    ///
382    /// # Parameters
383    /// - `f`: A closure that takes a reference to `Tokens` and returns a value of type `O`. The closure is called
384    ///   with the `Tokens` instance as an argument once the mutex is successfully locked.
385    pub async fn use_tokens<F, O>(&self, f: F) -> O
386    where
387        F: FnOnce(&OAuthSession) -> O + Send,
388    {
389        let tokens = self.lock.read().await;
390        f(&tokens)
391    }
392
393    /// Get a copy of the current access token.
394    #[must_use]
395    pub async fn tokens(&self) -> OAuthSession {
396        self.use_tokens(Clone::clone).await
397    }
398
399    /// Refreshes the tokens. Readers will be blocked until the refresh is complete.
400    ///
401    /// # Errors
402    ///
403    /// See [`TokenError`]
404    pub async fn refresh(
405        &self,
406        source: &ConfigSource,
407        profile: &str,
408    ) -> Result<OAuthSession, TokenError> {
409        self.managed_refresh(Self::perform_refresh, source, profile)
410            .await
411    }
412
413    /// Validate the access token, returning it if it is valid, or an error describing why it is
414    /// invalid.
415    ///
416    /// # Errors
417    ///
418    /// - [`TokenError::NoAccessToken`] if there is no access token
419    /// - [`TokenError::InvalidAccessToken`] if the access token is invalid
420    pub async fn validate(&self) -> Result<String, TokenError> {
421        self.use_tokens(OAuthSession::validate).await
422    }
423
424    /// If tokens are already being refreshed, wait and return the updated tokens. Otherwise, run
425    /// ``refresh_fn``.
426    async fn managed_refresh<F, Fut>(
427        &self,
428        refresh_fn: F,
429        source: &ConfigSource,
430        profile: &str,
431    ) -> Result<OAuthSession, TokenError>
432    where
433        F: FnOnce(Arc<RwLock<OAuthSession>>) -> Fut + Send,
434        Fut: Future<Output = Result<OAuthSession, TokenError>> + Send,
435    {
436        let mut is_refreshing = self.refreshing.lock().await;
437
438        if *is_refreshing {
439            drop(is_refreshing);
440            self.notify_refreshed.notified().await;
441            return Ok(self.tokens().await);
442        }
443
444        *is_refreshing = true;
445        drop(is_refreshing);
446
447        let oauth_session = refresh_fn(self.lock.clone()).await?;
448
449        // If the config source is a file, write the new access token to the file
450        if let ConfigSource::File {
451            settings_path: _,
452            secrets_path,
453        } = source
454        {
455            if !Secrets::is_read_only(secrets_path).await? {
456                let now = OffsetDateTime::now_utc();
457                Secrets::write_access_token(
458                    secrets_path,
459                    profile,
460                    oauth_session.access_token()?,
461                    now,
462                )
463                .await?;
464            }
465        }
466
467        *self.refreshing.lock().await = false;
468        self.notify_refreshed.notify_waiters();
469        Ok(oauth_session)
470    }
471
472    /// Refreshes the tokens. Readers will be blocked until the refresh is complete. Returns a copy
473    /// of the updated [`Credentials`]
474    ///
475    /// # Errors
476    ///
477    /// See [`TokenError`]
478    async fn perform_refresh(lock: Arc<RwLock<OAuthSession>>) -> Result<OAuthSession, TokenError> {
479        let mut credentials = lock.write().await;
480        credentials.request_access_token().await?;
481        Ok(credentials.clone())
482    }
483}
484
/// The boxed, pinned future returned by a [`RefreshFunction`]. It resolves to the new access
/// token, or to a boxed error.
pub(crate) type RefreshResult =
    Pin<Box<dyn Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>> + Send>>;

/// A function that asynchronously refreshes a token.
///
/// It receives the [`AuthServer`] by value and returns a [`RefreshResult`].
pub type RefreshFunction = Box<dyn (Fn(AuthServer) -> RefreshResult) + Send + Sync>;
490
/// A struct that manages access tokens by utilizing a user-provided refresh function.
///
/// The [`ExternallyManaged`] struct allows users to define custom logic for
/// fetching or refreshing access tokens.
///
/// The refresh function is reference counted, so clones share the same function.
#[derive(Clone)]
#[cfg_attr(feature = "python", pyo3::pyclass)]
pub struct ExternallyManaged {
    /// The user-provided function invoked on every access token request.
    refresh_function: Arc<RefreshFunction>,
}
500
501impl ExternallyManaged {
502    /// Creates a new [`ExternallyManaged`] instance from a [`RefreshFunction`].
503    ///
504    /// Consider using [`ExternallyManaged::from_async`], and [`ExternallyManaged::from_sync`], if
505    /// they better fit your use case.
506    ///
507    /// # Arguments
508    ///
509    /// * `refresh_function` - A function or closure that asynchronously refreshes a token.
510    ///
511    /// # Example
512    ///
513    /// ```
514    /// use qcs_api_client_common::configuration::{AuthServer, ExternallyManaged, TokenError};
515    /// use std::future::Future;
516    /// use std::pin::Pin;
517    /// use std::boxed::Box;
518    /// use std::error::Error;
519    ///
520    /// async fn example_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
521    /// + Send + Sync>> {
522    ///     Ok("new_token_value".to_string())
523    /// }
524    /// let token_manager = ExternallyManaged::new(|auth_server| Box::pin(example_refresh_function(auth_server)));
525    /// ```
526    pub fn new(
527        refresh_function: impl Fn(AuthServer) -> RefreshResult + Send + Sync + 'static,
528    ) -> Self {
529        Self {
530            refresh_function: Arc::new(Box::new(refresh_function)),
531        }
532    }
533
534    /// Constructs a new [`ExternallyManaged`] instance using an async function or closure.
535    ///
536    /// This method simplifies the creation of the [`ExternallyManaged`] instance by handling
537    /// the boxing and pinning of the future internally.
538    ///
539    /// # Arguments
540    ///
541    /// * `refresh_function` - An async function or closure that returns a [`Future`] which, when awaited,
542    ///   produces a [`Result<String, TokenError>`].
543    ///
544    /// # Example
545    ///
546    /// ```
547    /// use qcs_api_client_common::configuration::{AuthServer, ExternallyManaged, TokenError};
548    /// use tokio::runtime::Runtime;
549    /// use std::error::Error;
550    ///
551    /// async fn example_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
552    /// + Send + Sync>> {
553    ///     Ok("new_token_value".to_string())
554    /// }
555    ///
556    /// let token_manager = ExternallyManaged::from_async(example_refresh_function);
557    ///
558    /// let rt = Runtime::new().unwrap();
559    /// rt.block_on(async {
560    ///     match token_manager.request_access_token(&AuthServer::default()).await {
561    ///         Ok(token) => println!("Token: {}", token),
562    ///         Err(e) => println!("Failed to refresh token: {:?}", e),
563    ///     }
564    /// });
565    /// ```
566    pub fn from_async<F, Fut>(refresh_function: F) -> Self
567    where
568        F: Fn(AuthServer) -> Fut + Send + Sync + 'static,
569        Fut: Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>>
570            + Send
571            + 'static,
572    {
573        Self {
574            refresh_function: Arc::new(Box::new(move |auth_server| {
575                Box::pin(refresh_function(auth_server))
576            })),
577        }
578    }
579
580    /// Constructs a new [`ExternallyManaged`] instance using a synchronous function.
581    ///
582    /// The synchronous function is wrapped in an async block to fit the expected signature.
583    ///
584    /// # Arguments
585    ///
586    /// * `refresh_function` - A synchronous function that returns a [`Result<String, TokenError>`].
587    ///
588    /// # Example
589    ///
590    /// ```
591    /// use qcs_api_client_common::configuration::{AuthServer, ExternallyManaged, TokenError};
592    /// use tokio::runtime::Runtime;
593    /// use std::error::Error;
594    ///
595    /// fn example_sync_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
596    /// + Send + Sync>> {
597    ///     Ok("sync_token_value".to_string())
598    /// }
599    ///
600    /// let token_manager = ExternallyManaged::from_sync(example_sync_refresh_function);
601    ///
602    /// let rt = Runtime::new().unwrap();
603    /// rt.block_on(async {
604    ///     match token_manager.request_access_token(&AuthServer::default()).await {
605    ///         Ok(token) => println!("Token: {}", token),
606    ///         Err(e) => println!("Failed to refresh token: {:?}", e),
607    ///     }
608    /// });
609    /// ```
610    pub fn from_sync(
611        refresh_function: impl Fn(AuthServer) -> Result<String, Box<dyn std::error::Error + Send + Sync>>
612            + Send
613            + Sync
614            + 'static,
615    ) -> Self {
616        Self {
617            refresh_function: Arc::new(Box::new(move |auth_server| {
618                let result = refresh_function(auth_server);
619                Box::pin(async move { result })
620            })),
621        }
622    }
623
624    /// Request an updated access token using the provided refresh function.
625    ///
626    /// # Errors
627    ///
628    /// Errors are propagated from the refresh function.
629    pub async fn request_access_token(
630        &self,
631        auth_server: &AuthServer,
632    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
633        (self.refresh_function)(auth_server.clone()).await
634    }
635}
636
637impl std::fmt::Debug for ExternallyManaged {
638    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
639        f.debug_struct("ExternallyManaged")
640            .field(
641                "refresh_function",
642                &"Fn() -> Pin<Box<dyn Future<Output = Result<String, TokenError>> + Send>>",
643            )
644            .finish()
645    }
646}
647
648#[derive(Debug, Serialize, Deserialize)]
649pub(super) struct TokenRefreshRequest<'a> {
650    grant_type: &'static str,
651    client_id: &'a str,
652    refresh_token: &'a str,
653}
654
655impl<'a> TokenRefreshRequest<'a> {
656    pub(super) const fn new(client_id: &'a str, refresh_token: &'a str) -> Self {
657        Self {
658            grant_type: "refresh_token",
659            client_id,
660            refresh_token,
661        }
662    }
663}
664
665#[derive(Debug, Serialize, Deserialize)]
666pub(super) struct ClientCredentialsRequest {
667    grant_type: &'static str,
668    scope: Option<&'static str>,
669}
670
671impl ClientCredentialsRequest {
672    pub(super) const fn new(scope: Option<&'static str>) -> Self {
673        Self {
674            grant_type: "client_credentials",
675            scope,
676        }
677    }
678}
679
/// The subset of the token endpoint's response body that is read after a refresh token
/// request.
#[derive(Deserialize, Debug, Serialize)]
pub(super) struct RefreshTokenResponse {
    /// The refresh token returned alongside the new access token.
    pub(super) refresh_token: String,
    /// The newly issued access token.
    pub(super) access_token: String,
}
685
/// Get and refresh access tokens
#[async_trait::async_trait]
pub trait TokenRefresher: Clone + std::fmt::Debug + Send {
    /// The type to be returned in the event of an error during getting or
    /// refreshing an access token
    type Error;

    /// Get and validate the current access token, refreshing it if it doesn't exist or is invalid.
    async fn validated_access_token(&self) -> Result<String, Self::Error>;

    /// Get the current access token, if any
    async fn get_access_token(&self) -> Result<Option<String>, Self::Error>;

    /// Get a fresh access token
    async fn refresh_access_token(&self) -> Result<String, Self::Error>;

    /// Get the base URL for requests
    #[cfg(feature = "tracing")]
    fn base_url(&self) -> &str;

    /// Get the tracing configuration
    #[cfg(feature = "tracing-config")]
    fn tracing_configuration(&self) -> Option<&TracingConfiguration>;

    /// Returns whether the given URL should be traced. Following
    /// [`TracingConfiguration::is_enabled`], this defaults to `true`.
    ///
    /// Without the `tracing-config` feature every URL is traced. With it, a URL is traced
    /// when there is no tracing configuration, or when the configuration enables it.
    #[cfg(feature = "tracing")]
    #[allow(clippy::needless_return)]
    fn should_trace(&self, url: &UrlPatternMatchInput) -> bool {
        // Exactly one of the two cfg-gated sections below is compiled in. The explicit
        // `return` is what the `needless_return` allowance above is for.
        #[cfg(not(feature = "tracing-config"))]
        {
            // `url` is otherwise unused in this configuration.
            let _ = url;
            return true;
        }

        #[cfg(feature = "tracing-config")]
        self.tracing_configuration()
            .is_none_or(|config| config.is_enabled(url))
    }
}
726
727#[async_trait::async_trait]
728impl TokenRefresher for ClientConfiguration {
729    type Error = TokenError;
730
731    async fn validated_access_token(&self) -> Result<String, Self::Error> {
732        self.get_bearer_access_token().await
733    }
734
735    async fn refresh_access_token(&self) -> Result<String, Self::Error> {
736        Ok(self.refresh().await?.access_token()?.to_string())
737    }
738
739    async fn get_access_token(&self) -> Result<Option<String>, Self::Error> {
740        Ok(Some(
741            self.oauth_session().await?.access_token()?.to_string(),
742        ))
743    }
744
745    #[cfg(feature = "tracing")]
746    fn base_url(&self) -> &str {
747        &self.grpc_api_url
748    }
749
750    #[cfg(feature = "tracing-config")]
751    fn tracing_configuration(&self) -> Option<&TracingConfiguration> {
752        self.tracing_configuration.as_ref()
753    }
754}
755
756#[cfg(test)]
757mod test {
758    use std::time::Duration;
759
760    use super::*;
761    use httpmock::prelude::*;
762    use rstest::rstest;
763    use time::format_description::well_known::Rfc3339;
764    use tokio::time::Instant;
765    use toml_edit::DocumentMut;
766
767    #[tokio::test]
768    async fn test_tokens_blocked_during_refresh() {
769        let mock_server = MockServer::start_async().await;
770
771        let issuer_mock = mock_server
772            .mock_async(|when, then| {
773                when.method(POST).path("/v1/token");
774
775                then.status(200)
776                    .delay(Duration::from_secs(3))
777                    .json_body_obj(&RefreshTokenResponse {
778                        access_token: "new_access".to_string(),
779                        refresh_token: "new_refresh".to_string(),
780                    });
781            })
782            .await;
783
784        let original_tokens = OAuthSession::from_refresh_token(
785            RefreshToken::new("refresh".to_string()),
786            AuthServer::new("client_id".to_string(), mock_server.base_url()),
787            None,
788        );
789        let dispatcher: TokenDispatcher = original_tokens.clone().into();
790        let dispatcher_clone1 = dispatcher.clone();
791        let dispatcher_clone2 = dispatcher.clone();
792
793        let refresh_duration = Duration::from_secs(3);
794
795        let start_write = Instant::now();
796        let write_future = tokio::spawn(async move {
797            dispatcher_clone1
798                .refresh(&ConfigSource::Default, "")
799                .await
800                .unwrap()
801        });
802
803        let start_read = Instant::now();
804        let read_future = tokio::spawn(async move { dispatcher_clone2.tokens().await });
805
806        let _ = write_future.await.unwrap();
807        let read_result = read_future.await.unwrap();
808
809        let write_duration = start_write.elapsed();
810        let read_duration = start_read.elapsed();
811
812        issuer_mock.assert_async().await;
813
814        assert!(
815            write_duration >= refresh_duration,
816            "Write operation did not take enough time"
817        );
818        assert!(
819            read_duration >= refresh_duration,
820            "Read operation was not blocked by the write operation"
821        );
822        assert_eq!(read_result.access_token.as_ref().unwrap(), "new_access");
823        if let OAuthGrant::RefreshToken(payload) = read_result.payload {
824            assert_eq!(&payload.refresh_token, "new_refresh");
825        } else {
826            panic!(
827                "Expected RefreshToken payload, got {:?}",
828                read_result.payload
829            );
830        }
831    }
832
833    #[rstest]
834    fn test_qcs_secrets_readonly(
835        #[values(
836            (Some("TRUE"), true),
837            (Some("tRue"), true),
838            (Some("true"), true),
839            (Some("YES"), true),
840            (Some("yEs"), true),
841            (Some("yes"), true),
842            (Some("1"), true),
843            (Some("2"), false),
844            (Some("other"), false),
845            (Some(""), false),
846            (None, false),
847        )]
848        read_only_values: (Option<&str>, bool),
849        #[values(true, false)] read_only_perm: bool,
850    ) {
851        let (maybe_read_only_env, env_is_read_only) = read_only_values;
852        let expected_update = !env_is_read_only && !read_only_perm;
853        figment::Jail::expect_with(|jail| {
854            let profile_name = "test";
855            let initial_access_token = "initial_access_token";
856            let initial_refresh_token = "initial_refresh_token";
857
858            let initial_secrets_file_contents = format!(
859                r#"
860[credentials]
861[credentials.{profile_name}]
862[credentials.{profile_name}.token_payload]
863access_token = "{initial_access_token}"
864expires_in = 3600
865id_token = "id_token"
866refresh_token = "{initial_refresh_token}"
867scope = "offline_access openid profile email"
868token_type = "Bearer"
869updated_at = "2024-01-01T00:00:00Z"
870"#
871            );
872
873            // Ignore any existing environment variables.
874            jail.clear_env();
875
876            // Create a temporary secrets file
877            let secrets_path = "secrets.toml";
878            jail.create_file(secrets_path, initial_secrets_file_contents.as_str())
879                .expect("should create test secrets.toml");
880
881            if read_only_perm {
882                let mut permissions = std::fs::metadata(secrets_path)
883                    .expect("Should be able to get file metadata")
884                    .permissions();
885                permissions.set_readonly(true);
886                std::fs::set_permissions(secrets_path, permissions)
887                    .expect("Should be able to set file permissions");
888            }
889
890            let rt = tokio::runtime::Runtime::new().unwrap();
891            rt.block_on(async {
892                let mock_server = MockServer::start_async().await;
893
894                // Set up the mock token endpoint
895                let new_access_token = "new_access_token";
896                let issuer_mock = mock_server
897                    .mock_async(|when, then| {
898                        when.method(POST).path("/v1/token");
899                        then.status(200).json_body_obj(&RefreshTokenResponse {
900                            access_token: new_access_token.to_string(),
901                            refresh_token: initial_refresh_token.to_string(),
902                        });
903                    })
904                    .await;
905
906                // Create tokens and dispatcher
907                let original_tokens = OAuthSession::from_refresh_token(
908                    RefreshToken::new(initial_refresh_token.to_string()),
909                    AuthServer::new("client_id".to_string(), mock_server.base_url()),
910                    Some(initial_refresh_token.to_string()),
911                );
912                let dispatcher: TokenDispatcher = original_tokens.into();
913
914                // Test with QCS_SECRETS_READ_ONLY set first
915                jail.set_env("QCS_SECRETS_FILE_PATH", "secrets.toml");
916                jail.set_env("QCS_PROFILE_NAME", "test");
917                if let Some(read_only_env) = maybe_read_only_env {
918                    jail.set_env("QCS_SECRETS_READ_ONLY", read_only_env);
919                }
920
921                let before_refresh = OffsetDateTime::now_utc();
922
923                dispatcher
924                    .refresh(
925                        &ConfigSource::File {
926                            settings_path: "".into(),
927                            secrets_path: "secrets.toml".into(),
928                        },
929                        profile_name,
930                    )
931                    .await
932                    .unwrap();
933
934                issuer_mock.assert_async().await;
935
936                // Verify the file was not updated if QCS_SECRETS_READ_ONLY is set truthy
937                let content = std::fs::read_to_string("secrets.toml").unwrap();
938                if !expected_update {
939                    assert!(
940                        content.eq(initial_secrets_file_contents.as_str()),
941                        "File should not be updated when QCS_SECRETS_READ_ONLY is set or file permissions are read-only"
942                    );
943                    return;
944                }
945
946                // Verify the file was updated
947                let mut toml = std::fs::read_to_string(secrets_path)
948                    .unwrap()
949                    .parse::<DocumentMut>()
950                    .unwrap();
951
952                let token_payload = toml
953                    .get_mut("credentials")
954                    .and_then(|credentials| {
955                        credentials.get_mut(profile_name)?.get_mut("token_payload")
956                    })
957                    .expect("Should be able to get token_payload table");
958
959                assert_eq!(
960                    token_payload.get("access_token").unwrap().as_str().unwrap(),
961                    new_access_token
962                );
963
964                assert!(
965                    OffsetDateTime::parse(
966                        token_payload.get("updated_at").unwrap().as_str().unwrap(),
967                        &Rfc3339
968                    )
969                    .unwrap()
970                        > before_refresh
971                );
972
973                let content = std::fs::read_to_string("secrets.toml").unwrap();
974                assert!(
975                content.contains("new_access_token"),
976                "File should be updated with new access token when QCS_SECRETS_READ_ONLY is not set or is set but disabled, and file permissions allow writing"
977                );
978            });
979            Ok(())
980        });
981    }
982
983    #[test]
984    fn test_auth_session_debug_fmt() {
985        let session = OAuthSession {
986            payload: OAuthGrant::ClientCredentials(ClientCredentials {
987                client_id: "hidden_id".into(),
988                client_secret: "hidden_secret".into(),
989            }),
990            access_token: Some("token".into()),
991            auth_server: AuthServer::new("some_id".into(), "some_url".into()),
992        };
993
994        assert_eq!("OAuthSession { payload: ClientCredentials, access_token: Some(()), auth_server: AuthServer { client_id: \"some_id\", issuer: \"some_url\" } }", &format!("{session:?}"));
995    }
996}