// qcs_api_client_common/configuration/tokens.rs

1use std::{pin::Pin, sync::Arc};
2
3use futures::Future;
4use http::{header::CONTENT_TYPE, HeaderMap, HeaderValue};
5use jsonwebtoken::{Algorithm, DecodingKey, Validation};
6use serde::{Deserialize, Serialize};
7use time::OffsetDateTime;
8use tokio::sync::{Mutex, Notify, RwLock};
9
10use super::{
11    secrets::Secrets, settings::AuthServer, ClientConfiguration, ConfigSource, TokenError,
12    QCS_AUDIENCE,
13};
14#[cfg(feature = "tracing-config")]
15use crate::tracing_configuration::TracingConfiguration;
16#[cfg(feature = "tracing")]
17use urlpattern::UrlPatternMatchInput;
18
19/// A single type containing an access token and an associated refresh token.
20#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
21#[cfg_attr(feature = "python", pyo3::pyclass)]
22pub struct RefreshToken {
23    /// The token used to refresh the access token.
24    pub refresh_token: String,
25}
26
27impl RefreshToken {
28    /// Create a new [`RefreshToken`] with the given refresh token.
29    #[must_use]
30    pub const fn new(refresh_token: String) -> Self {
31        Self { refresh_token }
32    }
33
34    /// Request and return a new access token from the given authorization server using this refresh token.
35    ///
36    /// # Errors
37    ///
38    /// See [`TokenError`]
39    pub async fn request_access_token(
40        &mut self,
41        auth_server: &AuthServer,
42    ) -> Result<String, TokenError> {
43        if self.refresh_token.is_empty() {
44            return Err(TokenError::NoRefreshToken);
45        }
46        let token_url = format!("{}/v1/token", auth_server.issuer());
47        let data = TokenRefreshRequest::new(auth_server.client_id(), &self.refresh_token);
48        let resp = reqwest::Client::builder()
49            .timeout(std::time::Duration::from_secs(10))
50            .build()?
51            .post(token_url)
52            .form(&data)
53            .send()
54            .await?;
55
56        let response_data: TokenResponse = resp.error_for_status()?.json().await?;
57        self.refresh_token = response_data.refresh_token;
58        Ok(response_data.access_token)
59    }
60}
61
/// A pair of Client ID and Client Secret, used to request an OAuth Client Credentials Grant.
///
/// `Debug` is implemented by hand so that the client secret is never written out when the
/// value (or anything containing it, such as an `OAuthGrant`) is logged or printed in a panic.
#[derive(Clone, PartialEq, Eq)]
#[cfg_attr(feature = "python", pyo3::pyclass)]
pub struct ClientCredentials {
    /// The client ID.
    pub client_id: String,
    /// The client secret.
    pub client_secret: String,
}

impl std::fmt::Debug for ClientCredentials {
    /// Formats the client ID normally but replaces the client secret with a placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClientCredentials")
            .field("client_id", &self.client_id)
            .field("client_secret", &"<redacted>")
            .finish()
    }
}
71
72impl ClientCredentials {
73    #[must_use]
74    /// Construct a new [`ClientCredentials`]
75    pub const fn new(client_id: String, client_secret: String) -> Self {
76        Self {
77            client_id,
78            client_secret,
79        }
80    }
81
82    /// Get the client ID.
83    #[must_use]
84    pub fn client_id(&self) -> &str {
85        &self.client_id
86    }
87
88    /// Get the client secret.
89    #[must_use]
90    pub fn client_secret(&self) -> &str {
91        &self.client_secret
92    }
93
94    /// Request and return an access token from the given auth server using this set of client credentials.
95    ///
96    /// # Errors
97    ///
98    /// See [`TokenError`]
99    pub async fn request_access_token(
100        &self,
101        auth_server: &AuthServer,
102    ) -> Result<String, TokenError> {
103        let request = ClientCredentialsRequest::new(&self.client_id, &self.client_secret);
104        let url = format!("{}/v1/token", auth_server.issuer());
105
106        // let credentials = format!("{}:{}", self.auth_server.client_id(), self.client_secret);
107        // let encoded_credentials = base64::encode(credentials);
108        // let authorization_value = format!("Basic {}", encoded_credentials);
109        let mut headers = HeaderMap::new();
110        // headers.insert(AUTHORIZATION, HeaderValue::from_str(&authorization_value)?);
111        headers.insert(
112            CONTENT_TYPE,
113            HeaderValue::from_static("application/x-www-form-urlencoded"),
114        );
115
116        let client = reqwest::Client::builder()
117            .timeout(std::time::Duration::from_secs(10))
118            .build()?;
119
120        let response = client
121            .post(url)
122            .headers(headers)
123            .form(&request)
124            .send()
125            .await?;
126
127        response.error_for_status_ref()?;
128
129        let response_body: TokenResponse = response.json().await?;
130
131        Ok(response_body.access_token)
132    }
133}
134
135#[derive(Clone, Debug)]
136#[cfg_attr(feature = "python", derive(pyo3::FromPyObject))]
137/// Specifies the [OAuth2 grant type](https://oauth.net/2/grant-types/) to use, along with the data
138/// needed to request said grant type.
139pub enum OAuthGrant {
140    /// Credentials that can be used to use with the [Refresh Token grant type](https://oauth.net/2/grant-types/refresh-token/).
141    RefreshToken(RefreshToken),
142    /// Payload that can be used to use the [Client Credentials grant type](https://oauth.net/2/grant-types/client-credentials/).
143    ClientCredentials(ClientCredentials),
144    /// Defers to a user provided function for access token requests.
145    ExternallyManaged(ExternallyManaged),
146}
147
148impl From<ExternallyManaged> for OAuthGrant {
149    fn from(v: ExternallyManaged) -> Self {
150        Self::ExternallyManaged(v)
151    }
152}
153
154impl From<ClientCredentials> for OAuthGrant {
155    fn from(v: ClientCredentials) -> Self {
156        Self::ClientCredentials(v)
157    }
158}
159
160impl From<RefreshToken> for OAuthGrant {
161    fn from(v: RefreshToken) -> Self {
162        Self::RefreshToken(v)
163    }
164}
165
166impl OAuthGrant {
167    /// Request a new access token from the given issuer using this grant type and payload.
168    async fn request_access_token(
169        &mut self,
170        auth_server: &AuthServer,
171    ) -> Result<String, TokenError> {
172        match self {
173            Self::RefreshToken(tokens) => tokens.request_access_token(auth_server).await,
174            Self::ClientCredentials(tokens) => tokens.request_access_token(auth_server).await,
175            Self::ExternallyManaged(tokens) => tokens
176                .request_access_token(auth_server)
177                .await
178                .map_err(|e| TokenError::ExternallyManaged(e.to_string())),
179        }
180    }
181}
182
183/// Manages the `OAuth2` authorization process and token lifecycle for accessing the QCS API.
184///
185/// This struct encapsulates the necessary information to request an access token
186/// from an authorization server, including the `OAuth2` grant type and any associated
187/// credentials or payload data.
188///
189/// # Fields
190///
191/// * `payload` - The `OAuth2` grant type and associated data that will be used to request an access token.
192/// * `access_token` - The access token currently in use, if any. If no token has been provided or requested yet, this will be `None`.
193/// * `auth_server` - The authorization server responsible for issuing tokens.
194#[derive(Clone, Debug)]
195#[cfg_attr(feature = "python", pyo3::pyclass)]
196pub struct OAuthSession {
197    /// The grant type to use to request an access token.
198    payload: OAuthGrant,
199    /// The access token that is currently in use. None if no token has been requested yet.
200    access_token: Option<String>,
201    /// The [`AuthServer`] that issues the tokens.
202    auth_server: AuthServer,
203}
204
205impl OAuthSession {
206    /// Initialize a new set of [`Credentials`] using a [`GrantPayload`].
207    ///
208    /// Optionally include an `access_token`, if not included, then one can be requested
209    /// with [`Self::request_access_token`].
210    #[must_use]
211    pub const fn new(
212        payload: OAuthGrant,
213        auth_server: AuthServer,
214        access_token: Option<String>,
215    ) -> Self {
216        Self {
217            payload,
218            access_token,
219            auth_server,
220        }
221    }
222
223    /// Initialize a new set of [`Credentials`] using an [`ExternallyManaged`].
224    ///
225    /// Optionally include an `access_token`, if not included, then one can be requested
226    /// with [`Self::request_access_token`].
227    #[must_use]
228    pub const fn from_externally_managed(
229        tokens: ExternallyManaged,
230        auth_server: AuthServer,
231        access_token: Option<String>,
232    ) -> Self {
233        Self::new(
234            OAuthGrant::ExternallyManaged(tokens),
235            auth_server,
236            access_token,
237        )
238    }
239
240    /// Initialize a new set of [`Credentials`] using a [`RefreshToken`].
241    ///
242    /// Optionally include an `access_token`, if not included, then one can be requested
243    /// with [`Self::request_access_token`].
244    #[must_use]
245    pub const fn from_refresh_token(
246        tokens: RefreshToken,
247        auth_server: AuthServer,
248        access_token: Option<String>,
249    ) -> Self {
250        Self::new(OAuthGrant::RefreshToken(tokens), auth_server, access_token)
251    }
252
253    /// Initialize a new set of [`Credentials`] using [`ClientCredentials`].
254    ///
255    /// Optionally include an `access_token`, if not included, then one can be requested
256    /// with [`Self::request_access_token`].
257    #[must_use]
258    pub const fn from_client_credentials(
259        tokens: ClientCredentials,
260        auth_server: AuthServer,
261        access_token: Option<String>,
262    ) -> Self {
263        Self::new(
264            OAuthGrant::ClientCredentials(tokens),
265            auth_server,
266            access_token,
267        )
268    }
269
270    /// Get the current access token.
271    ///
272    /// This is an unvalidated copy of the access token. Meaning it can become stale, or may
273    /// even be already be stale. See [`Self::validate`] and [`Self::request_access_token`].
274    ///
275    /// # Errors
276    ///
277    /// - [`TokenError::NoAccessToken`] if there is no access token
278    pub fn access_token(&self) -> Result<&str, TokenError> {
279        self.access_token.as_ref().map_or_else(
280            || Err(TokenError::NoAccessToken),
281            |token| Ok(token.as_str()),
282        )
283    }
284
285    /// Get the payload used to request an access token.
286    #[must_use]
287    pub const fn payload(&self) -> &OAuthGrant {
288        &self.payload
289    }
290
291    /// Request and return an updated access token using these credentials.
292    ///
293    /// # Errors
294    ///
295    /// See [`TokenError`]
296    #[allow(clippy::missing_panics_doc)]
297    pub async fn request_access_token(&mut self) -> Result<&str, TokenError> {
298        let access_token = self.payload.request_access_token(&self.auth_server).await?;
299        self.access_token = Some(access_token);
300        Ok(self
301            .access_token
302            .as_ref()
303            .expect("This value is set in the previous line, so it cannot be None"))
304    }
305
306    /// The [`AuthServer`] that issues the tokens.
307    #[must_use]
308    pub const fn auth_server(&self) -> &AuthServer {
309        &self.auth_server
310    }
311
312    /// Validate the access token, returning it if it is valid, or an error describing why it is
313    /// invalid.
314    ///
315    /// # Errors
316    ///
317    /// - [`TokenError::NoAccessToken`] if an access token has not been requested.
318    /// - [`TokenError::InvalidAccessToken`] if the access token is invalid.
319    pub fn validate(&self) -> Result<String, TokenError> {
320        self.access_token().map_or_else(
321            |_| Err(TokenError::NoAccessToken),
322            |access_token| {
323                let placeholder_key = DecodingKey::from_secret(&[]);
324                let mut validation = Validation::new(Algorithm::RS256);
325                validation.validate_exp = true;
326                validation.leeway = 60;
327                validation.set_audience(&[QCS_AUDIENCE]);
328                validation.insecure_disable_signature_validation();
329                jsonwebtoken::decode::<toml::Value>(access_token, &placeholder_key, &validation)
330                    .map(|_| access_token.to_string())
331                    .map_err(TokenError::InvalidAccessToken)
332            },
333        )
334    }
335}
336
337/// A wrapper for [`OAuthSession`] that provides thread-safe access to the inner tokens.
338#[derive(Clone, Debug)]
339#[cfg_attr(feature = "python", pyo3::pyclass)]
340pub struct TokenDispatcher {
341    lock: Arc<RwLock<OAuthSession>>,
342    refreshing: Arc<Mutex<bool>>,
343    notify_refreshed: Arc<Notify>,
344}
345
346impl From<OAuthSession> for TokenDispatcher {
347    fn from(value: OAuthSession) -> Self {
348        Self {
349            lock: Arc::new(RwLock::new(value)),
350            refreshing: Arc::new(Mutex::new(false)),
351            notify_refreshed: Arc::new(Notify::new()),
352        }
353    }
354}
355
356impl TokenDispatcher {
357    /// Executes a user-provided closure on a reference to the `Tokens` instance managed by the
358    /// dispatcher.
359    ///
360    /// This function locks the mutex, safely exposing the protected `Tokens` instance to the provided closure `f`.
361    /// It is designed to allow safe and controlled access to the `Tokens` instance for reading its state.
362    ///
363    /// # Parameters
364    /// - `f`: A closure that takes a reference to `Tokens` and returns a value of type `O`. The closure is called
365    ///   with the `Tokens` instance as an argument once the mutex is successfully locked.
366    pub async fn use_tokens<F, O>(&self, f: F) -> O
367    where
368        F: FnOnce(&OAuthSession) -> O + Send,
369    {
370        let tokens = self.lock.read().await;
371        f(&tokens)
372    }
373
374    /// Get a copy of the current access token.
375    #[must_use]
376    pub async fn tokens(&self) -> OAuthSession {
377        self.use_tokens(Clone::clone).await
378    }
379
380    /// Refreshes the tokens. Readers will be blocked until the refresh is complete.
381    ///
382    /// # Errors
383    ///
384    /// See [`TokenError`]
385    pub async fn refresh(
386        &self,
387        source: &ConfigSource,
388        profile: &str,
389    ) -> Result<OAuthSession, TokenError> {
390        self.managed_refresh(Self::perform_refresh, source, profile)
391            .await
392    }
393
394    /// Validate the access token, returning it if it is valid, or an error describing why it is
395    /// invalid.
396    ///
397    /// # Errors
398    ///
399    /// - [`TokenError::NoAccessToken`] if there is no access token
400    /// - [`TokenError::InvalidAccessToken`] if the access token is invalid
401    pub async fn validate(&self) -> Result<String, TokenError> {
402        self.use_tokens(OAuthSession::validate).await
403    }
404
405    /// If tokens are already being refreshed, wait and return the updated tokens. Otherwise, run
406    /// ``refresh_fn``.
407    async fn managed_refresh<F, Fut>(
408        &self,
409        refresh_fn: F,
410        source: &ConfigSource,
411        profile: &str,
412    ) -> Result<OAuthSession, TokenError>
413    where
414        F: FnOnce(Arc<RwLock<OAuthSession>>) -> Fut + Send,
415        Fut: Future<Output = Result<OAuthSession, TokenError>> + Send,
416    {
417        let mut is_refreshing = self.refreshing.lock().await;
418
419        if *is_refreshing {
420            drop(is_refreshing);
421            self.notify_refreshed.notified().await;
422            return Ok(self.tokens().await);
423        }
424
425        *is_refreshing = true;
426        drop(is_refreshing);
427
428        let oauth_session = refresh_fn(self.lock.clone()).await?;
429
430        // If the config source is a file, write the new access token to the file
431        if let ConfigSource::File {
432            settings_path: _,
433            secrets_path,
434        } = source
435        {
436            if !Secrets::is_read_only(secrets_path).await? {
437                let now = OffsetDateTime::now_utc();
438                Secrets::write_access_token(
439                    secrets_path,
440                    profile,
441                    oauth_session.access_token()?,
442                    now,
443                )
444                .await?;
445            }
446        }
447
448        *self.refreshing.lock().await = false;
449        self.notify_refreshed.notify_waiters();
450        Ok(oauth_session)
451    }
452
453    /// Refreshes the tokens. Readers will be blocked until the refresh is complete. Returns a copy
454    /// of the updated [`Credentials`]
455    ///
456    /// # Errors
457    ///
458    /// See [`TokenError`]
459    async fn perform_refresh(lock: Arc<RwLock<OAuthSession>>) -> Result<OAuthSession, TokenError> {
460        let mut credentials = lock.write().await;
461        credentials.request_access_token().await?;
462        Ok(credentials.clone())
463    }
464}
465
466pub(crate) type RefreshResult =
467    Pin<Box<dyn Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>> + Send>>;
468
469/// A function that asynchronously refreshes a token.
470pub type RefreshFunction = Box<dyn (Fn(AuthServer) -> RefreshResult) + Send + Sync>;
471
472/// A struct that manages access tokens by utilizing a user-provided refresh function.
473///
474/// The [`ExternallyManaged`] struct allows users to define custom logic for
475/// fetching or refreshing access tokens.
476#[derive(Clone)]
477#[cfg_attr(feature = "python", pyo3::pyclass)]
478pub struct ExternallyManaged {
479    refresh_function: Arc<RefreshFunction>,
480}
481
482impl ExternallyManaged {
483    /// Creates a new [`ExternallyManaged`] instance from a [`RefreshFunction`].
484    ///
485    /// Consider using [`ExternallyManaged::from_async`], and [`ExternallyManaged::from_sync`], if
486    /// they better fit your use case.
487    ///
488    /// # Arguments
489    ///
490    /// * `refresh_function` - A function or closure that asynchronously refreshes a token.
491    ///
492    /// # Example
493    ///
494    /// ```
495    /// use qcs_api_client_common::configuration::{AuthServer, ExternallyManaged, TokenError};
496    /// use std::future::Future;
497    /// use std::pin::Pin;
498    /// use std::boxed::Box;
499    /// use std::error::Error;
500    ///
501    /// async fn example_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
502    /// + Send + Sync>> {
503    ///     Ok("new_token_value".to_string())
504    /// }
505    /// let token_manager = ExternallyManaged::new(|auth_server| Box::pin(example_refresh_function(auth_server)));
506    /// ```
507    pub fn new(
508        refresh_function: impl Fn(AuthServer) -> RefreshResult + Send + Sync + 'static,
509    ) -> Self {
510        Self {
511            refresh_function: Arc::new(Box::new(refresh_function)),
512        }
513    }
514
515    /// Constructs a new [`ExternallyManaged`] instance using an async function or closure.
516    ///
517    /// This method simplifies the creation of the [`ExternallyManaged`] instance by handling
518    /// the boxing and pinning of the future internally.
519    ///
520    /// # Arguments
521    ///
522    /// * `refresh_function` - An async function or closure that returns a [`Future`] which, when awaited,
523    ///   produces a [`Result<String, TokenError>`].
524    ///
525    /// # Example
526    ///
527    /// ```
528    /// use qcs_api_client_common::configuration::{AuthServer, ExternallyManaged, TokenError};
529    /// use tokio::runtime::Runtime;
530    /// use std::error::Error;
531    ///
532    /// async fn example_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
533    /// + Send + Sync>> {
534    ///     Ok("new_token_value".to_string())
535    /// }
536    ///
537    /// let token_manager = ExternallyManaged::from_async(example_refresh_function);
538    ///
539    /// let rt = Runtime::new().unwrap();
540    /// rt.block_on(async {
541    ///     match token_manager.request_access_token(&AuthServer::default()).await {
542    ///         Ok(token) => println!("Token: {}", token),
543    ///         Err(e) => println!("Failed to refresh token: {:?}", e),
544    ///     }
545    /// });
546    /// ```
547    pub fn from_async<F, Fut>(refresh_function: F) -> Self
548    where
549        F: Fn(AuthServer) -> Fut + Send + Sync + 'static,
550        Fut: Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>>
551            + Send
552            + 'static,
553    {
554        Self {
555            refresh_function: Arc::new(Box::new(move |auth_server| {
556                Box::pin(refresh_function(auth_server))
557            })),
558        }
559    }
560
561    /// Constructs a new [`ExternallyManaged`] instance using a synchronous function.
562    ///
563    /// The synchronous function is wrapped in an async block to fit the expected signature.
564    ///
565    /// # Arguments
566    ///
567    /// * `refresh_function` - A synchronous function that returns a [`Result<String, TokenError>`].
568    ///
569    /// # Example
570    ///
571    /// ```
572    /// use qcs_api_client_common::configuration::{AuthServer, ExternallyManaged, TokenError};
573    /// use tokio::runtime::Runtime;
574    /// use std::error::Error;
575    ///
576    /// fn example_sync_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
577    /// + Send + Sync>> {
578    ///     Ok("sync_token_value".to_string())
579    /// }
580    ///
581    /// let token_manager = ExternallyManaged::from_sync(example_sync_refresh_function);
582    ///
583    /// let rt = Runtime::new().unwrap();
584    /// rt.block_on(async {
585    ///     match token_manager.request_access_token(&AuthServer::default()).await {
586    ///         Ok(token) => println!("Token: {}", token),
587    ///         Err(e) => println!("Failed to refresh token: {:?}", e),
588    ///     }
589    /// });
590    /// ```
591    pub fn from_sync(
592        refresh_function: impl Fn(AuthServer) -> Result<String, Box<dyn std::error::Error + Send + Sync>>
593            + Send
594            + Sync
595            + 'static,
596    ) -> Self {
597        Self {
598            refresh_function: Arc::new(Box::new(move |auth_server| {
599                let result = refresh_function(auth_server);
600                Box::pin(async move { result })
601            })),
602        }
603    }
604
605    /// Request an updated access token using the provided refresh function.
606    ///
607    /// # Errors
608    ///
609    /// Errors are propagated from the refresh function.
610    pub async fn request_access_token(
611        &self,
612        auth_server: &AuthServer,
613    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
614        (self.refresh_function)(auth_server.clone()).await
615    }
616}
617
618impl std::fmt::Debug for ExternallyManaged {
619    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
620        f.debug_struct("ExternallyManaged")
621            .field(
622                "refresh_function",
623                &"Fn() -> Pin<Box<dyn Future<Output = Result<String, TokenError>> + Send>>",
624            )
625            .finish()
626    }
627}
628
629#[derive(Debug, Serialize, Deserialize)]
630pub(super) struct TokenRefreshRequest<'a> {
631    grant_type: &'static str,
632    client_id: &'a str,
633    refresh_token: &'a str,
634}
635
636impl<'a> TokenRefreshRequest<'a> {
637    pub(super) const fn new(client_id: &'a str, refresh_token: &'a str) -> Self {
638        Self {
639            grant_type: "refresh_token",
640            client_id,
641            refresh_token,
642        }
643    }
644}
645
646#[derive(Debug, Serialize, Deserialize)]
647pub(super) struct ClientCredentialsRequest<'a> {
648    grant_type: &'static str,
649    client_id: &'a str,
650    client_secret: &'a str,
651}
652
653impl<'a> ClientCredentialsRequest<'a> {
654    pub(super) const fn new(client_id: &'a str, client_secret: &'a str) -> Self {
655        Self {
656            grant_type: "client_credentials",
657            client_id,
658            client_secret,
659        }
660    }
661}
662
663#[derive(Deserialize, Debug, Serialize)]
664pub(super) struct TokenResponse {
665    pub(super) refresh_token: String,
666    pub(super) access_token: String,
667}
668
669/// Get and refresh access tokens
670#[async_trait::async_trait]
671pub trait TokenRefresher: Clone + std::fmt::Debug + Send {
672    /// The type to be returned in the event of a error during getting or
673    /// refreshing an access token
674    type Error;
675
676    /// Get and validate the current access token, refreshing it if it doesn't exist or is invalid.
677    async fn validated_access_token(&self) -> Result<String, Self::Error>;
678
679    /// Get the current access token, if any
680    async fn get_access_token(&self) -> Result<Option<String>, Self::Error>;
681
682    /// Get a fresh access token
683    async fn refresh_access_token(&self) -> Result<String, Self::Error>;
684
685    /// Get the base URL for requests
686    #[cfg(feature = "tracing")]
687    fn base_url(&self) -> &str;
688
689    /// Get the tracing configuration
690    #[cfg(feature = "tracing-config")]
691    fn tracing_configuration(&self) -> Option<&TracingConfiguration>;
692
693    /// Returns whether the given URL should be traced. Following
694    /// [`TracingConfiguration::is_enabled`], this defaults to `true`.
695    #[cfg(feature = "tracing")]
696    #[allow(clippy::needless_return)]
697    fn should_trace(&self, url: &UrlPatternMatchInput) -> bool {
698        #[cfg(not(feature = "tracing-config"))]
699        {
700            let _ = url;
701            return true;
702        }
703
704        #[cfg(feature = "tracing-config")]
705        self.tracing_configuration()
706            .is_none_or(|config| config.is_enabled(url))
707    }
708}
709
710#[async_trait::async_trait]
711impl TokenRefresher for ClientConfiguration {
712    type Error = TokenError;
713
714    async fn validated_access_token(&self) -> Result<String, Self::Error> {
715        self.get_bearer_access_token().await
716    }
717
718    async fn refresh_access_token(&self) -> Result<String, Self::Error> {
719        Ok(self.refresh().await?.access_token()?.to_string())
720    }
721
722    async fn get_access_token(&self) -> Result<Option<String>, Self::Error> {
723        Ok(Some(
724            self.oauth_session().await?.access_token()?.to_string(),
725        ))
726    }
727
728    #[cfg(feature = "tracing")]
729    fn base_url(&self) -> &str {
730        &self.grpc_api_url
731    }
732
733    #[cfg(feature = "tracing-config")]
734    fn tracing_configuration(&self) -> Option<&TracingConfiguration> {
735        self.tracing_configuration.as_ref()
736    }
737}
738
739#[cfg(test)]
740mod test {
741    use std::time::Duration;
742
743    use super::*;
744    use httpmock::prelude::*;
745    use rstest::rstest;
746    use time::format_description::well_known::Rfc3339;
747    use tokio::time::Instant;
748    use toml_edit::DocumentMut;
749
750    #[tokio::test]
751    async fn test_tokens_blocked_during_refresh() {
752        let mock_server = MockServer::start_async().await;
753
754        let issuer_mock = mock_server
755            .mock_async(|when, then| {
756                when.method(POST).path("/v1/token");
757
758                then.status(200)
759                    .delay(Duration::from_secs(3))
760                    .json_body_obj(&TokenResponse {
761                        access_token: "new_access".to_string(),
762                        refresh_token: "new_refresh".to_string(),
763                    });
764            })
765            .await;
766
767        let original_tokens = OAuthSession::from_refresh_token(
768            RefreshToken::new("refresh".to_string()),
769            AuthServer::new("client_id".to_string(), mock_server.base_url()),
770            None,
771        );
772        let dispatcher: TokenDispatcher = original_tokens.clone().into();
773        let dispatcher_clone1 = dispatcher.clone();
774        let dispatcher_clone2 = dispatcher.clone();
775
776        let refresh_duration = Duration::from_secs(3);
777
778        let start_write = Instant::now();
779        let write_future = tokio::spawn(async move {
780            dispatcher_clone1
781                .refresh(&ConfigSource::Default, "")
782                .await
783                .unwrap()
784        });
785
786        let start_read = Instant::now();
787        let read_future = tokio::spawn(async move { dispatcher_clone2.tokens().await });
788
789        let _ = write_future.await.unwrap();
790        let read_result = read_future.await.unwrap();
791
792        let write_duration = start_write.elapsed();
793        let read_duration = start_read.elapsed();
794
795        issuer_mock.assert_async().await;
796
797        assert!(
798            write_duration >= refresh_duration,
799            "Write operation did not take enough time"
800        );
801        assert!(
802            read_duration >= refresh_duration,
803            "Read operation was not blocked by the write operation"
804        );
805        assert_eq!(read_result.access_token.as_ref().unwrap(), "new_access");
806        if let OAuthGrant::RefreshToken(payload) = read_result.payload {
807            assert_eq!(&payload.refresh_token, "new_refresh");
808        } else {
809            panic!(
810                "Expected RefreshToken payload, got {:?}",
811                read_result.payload
812            );
813        }
814    }
815
816    #[rstest]
817    fn test_qcs_secrets_readonly(
818        #[values(
819            (Some("TRUE"), true),
820            (Some("tRue"), true),
821            (Some("true"), true),
822            (Some("YES"), true),
823            (Some("yEs"), true),
824            (Some("yes"), true),
825            (Some("1"), true),
826            (Some("2"), false),
827            (Some("other"), false),
828            (Some(""), false),
829            (None, false),
830        )]
831        read_only_values: (Option<&str>, bool),
832        #[values(true, false)] read_only_perm: bool,
833    ) {
834        let (maybe_read_only_env, env_is_read_only) = read_only_values;
835        let expected_update = !env_is_read_only && !read_only_perm;
836        figment::Jail::expect_with(|jail| {
837            let profile_name = "test";
838            let initial_access_token = "initial_access_token";
839            let initial_refresh_token = "initial_refresh_token";
840
841            let initial_secrets_file_contents = format!(
842                r#"
843[credentials]
844[credentials.{profile_name}]
845[credentials.{profile_name}.token_payload]
846access_token = "{initial_access_token}"
847expires_in = 3600
848id_token = "id_token"
849refresh_token = "{initial_refresh_token}"
850scope = "offline_access openid profile email"
851token_type = "Bearer"
852updated_at = "2024-01-01T00:00:00Z"
853"#
854            );
855
856            // Create a temporary secrets file
857            let secrets_path = "secrets.toml";
858            jail.create_file(secrets_path, initial_secrets_file_contents.as_str())
859                .expect("should create test secrets.toml");
860
861            if read_only_perm {
862                let mut permissions = std::fs::metadata(secrets_path)
863                    .expect("Should be able to get file metadata")
864                    .permissions();
865                permissions.set_readonly(true);
866                std::fs::set_permissions(secrets_path, permissions)
867                    .expect("Should be able to set file permissions");
868            }
869
870            let rt = tokio::runtime::Runtime::new().unwrap();
871            rt.block_on(async {
872                let mock_server = MockServer::start_async().await;
873
874                // Set up the mock token endpoint
875                let new_access_token = "new_access_token";
876                let issuer_mock = mock_server
877                    .mock_async(|when, then| {
878                        when.method(POST).path("/v1/token");
879                        then.status(200).json_body_obj(&TokenResponse {
880                            access_token: new_access_token.to_string(),
881                            refresh_token: initial_refresh_token.to_string(),
882                        });
883                    })
884                    .await;
885
886                // Create tokens and dispatcher
887                let original_tokens = OAuthSession::from_refresh_token(
888                    RefreshToken::new(initial_refresh_token.to_string()),
889                    AuthServer::new("client_id".to_string(), mock_server.base_url()),
890                    Some(initial_refresh_token.to_string()),
891                );
892                let dispatcher: TokenDispatcher = original_tokens.into();
893
894                // Test with QCS_SECRETS_READ_ONLY set first
895                jail.set_env("QCS_SECRETS_FILE_PATH", "secrets.toml");
896                jail.set_env("QCS_PROFILE_NAME", "test");
897                if let Some(read_only_env) = maybe_read_only_env {
898                    jail.set_env("QCS_SECRETS_READ_ONLY", read_only_env);
899                }
900
901                let before_refresh = OffsetDateTime::now_utc();
902
903                dispatcher
904                    .refresh(
905                        &ConfigSource::File {
906                            settings_path: "".into(),
907                            secrets_path: "secrets.toml".into(),
908                        },
909                        profile_name,
910                    )
911                    .await
912                    .unwrap();
913
914                issuer_mock.assert_async().await;
915
916                // Verify the file was not updated if QCS_SECRETS_READ_ONLY is set truthy
917                let content = std::fs::read_to_string("secrets.toml").unwrap();
918                if !expected_update {
919                    assert!(
920                        content.eq(initial_secrets_file_contents.as_str()),
921                        "File should not be updated when QCS_SECRETS_READ_ONLY is set or file permissions are read-only"
922                    );
923                    return;
924                }
925
926                // Verify the file was updated
927                let mut toml = std::fs::read_to_string(secrets_path)
928                    .unwrap()
929                    .parse::<DocumentMut>()
930                    .unwrap();
931
932                let token_payload = toml
933                    .get_mut("credentials")
934                    .and_then(|credentials| {
935                        credentials.get_mut(profile_name)?.get_mut("token_payload")
936                    })
937                    .expect("Should be able to get token_payload table");
938
939                assert_eq!(
940                    token_payload.get("access_token").unwrap().as_str().unwrap(),
941                    new_access_token
942                );
943
944                assert!(
945                    OffsetDateTime::parse(
946                        token_payload.get("updated_at").unwrap().as_str().unwrap(),
947                        &Rfc3339
948                    )
949                    .unwrap()
950                        > before_refresh
951                );
952
953                let content = std::fs::read_to_string("secrets.toml").unwrap();
954                assert!(
955                content.contains("new_access_token"),
956                "File should be updated with new access token when QCS_SECRETS_READ_ONLY is not set or is set but disabled, and file permissions allow writing"
957                );
958            });
959            Ok(())
960        });
961    }
962}