qcs_api_client_common/configuration/
tokens.rs

1use std::{pin::Pin, sync::Arc};
2
3use futures::Future;
4use http::{header::CONTENT_TYPE, HeaderMap, HeaderValue};
5use jsonwebtoken::{Algorithm, DecodingKey, Validation};
6use serde::{Deserialize, Serialize};
7use time::OffsetDateTime;
8use tokio::sync::{Mutex, Notify, RwLock};
9
10use super::{
11    secrets::Secrets, settings::AuthServer, ClientConfiguration, ConfigSource, TokenError,
12    QCS_AUDIENCE,
13};
14#[cfg(feature = "tracing-config")]
15use crate::tracing_configuration::TracingConfiguration;
16#[cfg(feature = "tracing")]
17use urlpattern::UrlPatternMatchInput;
18
/// A refresh token that can be exchanged for a new access token via the
/// [Refresh Token grant type](https://oauth.net/2/grant-types/refresh-token/).
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "python", pyo3::pyclass)]
pub struct RefreshToken {
    /// The token used to refresh the access token.
    pub refresh_token: String,
}
26
27impl RefreshToken {
28    /// Create a new [`RefreshToken`] with the given refresh token.
29    #[must_use]
30    pub const fn new(refresh_token: String) -> Self {
31        Self { refresh_token }
32    }
33
34    /// Request and return a new access token from the given authorization server using this refresh token.
35    ///
36    /// # Errors
37    ///
38    /// See [`TokenError`]
39    pub async fn request_access_token(
40        &mut self,
41        auth_server: &AuthServer,
42    ) -> Result<String, TokenError> {
43        if self.refresh_token.is_empty() {
44            return Err(TokenError::NoRefreshToken);
45        }
46        let token_url = format!("{}/v1/token", auth_server.issuer());
47        let data = TokenRefreshRequest::new(auth_server.client_id(), &self.refresh_token);
48        let resp = reqwest::Client::builder()
49            .timeout(std::time::Duration::from_secs(10))
50            .build()?
51            .post(token_url)
52            .form(&data)
53            .send()
54            .await?;
55
56        let response_data: TokenResponse = resp.error_for_status()?.json().await?;
57        self.refresh_token = response_data.refresh_token;
58        Ok(response_data.access_token)
59    }
60}
61
#[derive(Clone, PartialEq, Eq, Deserialize)]
#[expect(missing_debug_implementations, reason = "contains secret data")]
#[cfg_attr(feature = "python", pyo3::pyclass)]
/// A pair of Client ID and Client Secret, used to request an OAuth Client Credentials Grant.
///
/// `Debug` is intentionally not implemented so the secret cannot be printed by accident.
pub struct ClientCredentials {
    /// The client ID.
    pub client_id: String,
    /// The client secret.
    pub client_secret: String,
}
72
73impl ClientCredentials {
74    #[must_use]
75    /// Construct a new [`ClientCredentials`]
76    pub const fn new(client_id: String, client_secret: String) -> Self {
77        Self {
78            client_id,
79            client_secret,
80        }
81    }
82
83    /// Get the client ID.
84    #[must_use]
85    pub fn client_id(&self) -> &str {
86        &self.client_id
87    }
88
89    /// Get the client secret.
90    #[must_use]
91    pub fn client_secret(&self) -> &str {
92        &self.client_secret
93    }
94
95    /// Request and return an access token from the given auth server using this set of client credentials.
96    ///
97    /// # Errors
98    ///
99    /// See [`TokenError`]
100    pub async fn request_access_token(
101        &self,
102        auth_server: &AuthServer,
103    ) -> Result<String, TokenError> {
104        let request = ClientCredentialsRequest::new(&self.client_id, &self.client_secret);
105        let url = format!("{}/v1/token", auth_server.issuer());
106
107        // let credentials = format!("{}:{}", self.auth_server.client_id(), self.client_secret);
108        // let encoded_credentials = base64::encode(credentials);
109        // let authorization_value = format!("Basic {}", encoded_credentials);
110        let mut headers = HeaderMap::new();
111        // headers.insert(AUTHORIZATION, HeaderValue::from_str(&authorization_value)?);
112        headers.insert(
113            CONTENT_TYPE,
114            HeaderValue::from_static("application/x-www-form-urlencoded"),
115        );
116
117        let client = reqwest::Client::builder()
118            .timeout(std::time::Duration::from_secs(10))
119            .build()?;
120
121        let response = client
122            .post(url)
123            .headers(headers)
124            .form(&request)
125            .send()
126            .await?;
127
128        response.error_for_status_ref()?;
129
130        let response_body: TokenResponse = response.json().await?;
131
132        Ok(response_body.access_token)
133    }
134}
135
#[derive(Clone)]
#[cfg_attr(feature = "python", derive(pyo3::FromPyObject))]
/// Specifies the [OAuth2 grant type](https://oauth.net/2/grant-types/) to use, along with the data
/// needed to request said grant type.
pub enum OAuthGrant {
    /// Credentials that can be used with the [Refresh Token grant type](https://oauth.net/2/grant-types/refresh-token/).
    RefreshToken(RefreshToken),
    /// Credentials that can be used with the [Client Credentials grant type](https://oauth.net/2/grant-types/client-credentials/).
    ClientCredentials(ClientCredentials),
    /// Defers to a user-provided function for access token requests.
    ExternallyManaged(ExternallyManaged),
}
148
149impl From<ExternallyManaged> for OAuthGrant {
150    fn from(v: ExternallyManaged) -> Self {
151        Self::ExternallyManaged(v)
152    }
153}
154
155impl From<ClientCredentials> for OAuthGrant {
156    fn from(v: ClientCredentials) -> Self {
157        Self::ClientCredentials(v)
158    }
159}
160
161impl From<RefreshToken> for OAuthGrant {
162    fn from(v: RefreshToken) -> Self {
163        Self::RefreshToken(v)
164    }
165}
166
167impl OAuthGrant {
168    /// Request a new access token from the given issuer using this grant type and payload.
169    async fn request_access_token(
170        &mut self,
171        auth_server: &AuthServer,
172    ) -> Result<String, TokenError> {
173        match self {
174            Self::RefreshToken(tokens) => tokens.request_access_token(auth_server).await,
175            Self::ClientCredentials(tokens) => tokens.request_access_token(auth_server).await,
176            Self::ExternallyManaged(tokens) => tokens
177                .request_access_token(auth_server)
178                .await
179                .map_err(|e| TokenError::ExternallyManaged(e.to_string())),
180        }
181    }
182}
183
184impl std::fmt::Debug for OAuthGrant {
185    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
186        match self {
187            Self::RefreshToken(_) => f.write_str("RefreshToken"),
188            Self::ClientCredentials(_) => f.write_str("ClientCredentials"),
189            Self::ExternallyManaged(_) => f.write_str("ExternallyManaged"),
190        }
191    }
192}
193
/// Manages the `OAuth2` authorization process and token lifecycle for accessing the QCS API.
///
/// This struct encapsulates the necessary information to request an access token
/// from an authorization server, including the `OAuth2` grant type and any associated
/// credentials or payload data.
///
/// # Fields
///
/// * `payload` - The [`OAuthGrant`] (grant type and associated data) that will be used to request an access token.
/// * `access_token` - The access token currently in use, if any. If no token has been provided or requested yet, this will be `None`.
/// * `auth_server` - The [`AuthServer`] responsible for issuing tokens.
#[derive(Clone)]
#[cfg_attr(feature = "python", pyo3::pyclass)]
pub struct OAuthSession {
    /// The grant type to use to request an access token.
    payload: OAuthGrant,
    /// The access token that is currently in use. None if no token has been requested yet.
    access_token: Option<String>,
    /// The [`AuthServer`] that issues the tokens.
    auth_server: AuthServer,
}
215
216impl OAuthSession {
217    /// Initialize a new set of [`Credentials`] using a [`GrantPayload`].
218    ///
219    /// Optionally include an `access_token`, if not included, then one can be requested
220    /// with [`Self::request_access_token`].
221    #[must_use]
222    pub const fn new(
223        payload: OAuthGrant,
224        auth_server: AuthServer,
225        access_token: Option<String>,
226    ) -> Self {
227        Self {
228            payload,
229            access_token,
230            auth_server,
231        }
232    }
233
234    /// Initialize a new set of [`Credentials`] using an [`ExternallyManaged`].
235    ///
236    /// Optionally include an `access_token`, if not included, then one can be requested
237    /// with [`Self::request_access_token`].
238    #[must_use]
239    pub const fn from_externally_managed(
240        tokens: ExternallyManaged,
241        auth_server: AuthServer,
242        access_token: Option<String>,
243    ) -> Self {
244        Self::new(
245            OAuthGrant::ExternallyManaged(tokens),
246            auth_server,
247            access_token,
248        )
249    }
250
251    /// Initialize a new set of [`Credentials`] using a [`RefreshToken`].
252    ///
253    /// Optionally include an `access_token`, if not included, then one can be requested
254    /// with [`Self::request_access_token`].
255    #[must_use]
256    pub const fn from_refresh_token(
257        tokens: RefreshToken,
258        auth_server: AuthServer,
259        access_token: Option<String>,
260    ) -> Self {
261        Self::new(OAuthGrant::RefreshToken(tokens), auth_server, access_token)
262    }
263
264    /// Initialize a new set of [`Credentials`] using [`ClientCredentials`].
265    ///
266    /// Optionally include an `access_token`, if not included, then one can be requested
267    /// with [`Self::request_access_token`].
268    #[must_use]
269    pub const fn from_client_credentials(
270        tokens: ClientCredentials,
271        auth_server: AuthServer,
272        access_token: Option<String>,
273    ) -> Self {
274        Self::new(
275            OAuthGrant::ClientCredentials(tokens),
276            auth_server,
277            access_token,
278        )
279    }
280
281    /// Get the current access token.
282    ///
283    /// This is an unvalidated copy of the access token. Meaning it can become stale, or may
284    /// even be already be stale. See [`Self::validate`] and [`Self::request_access_token`].
285    ///
286    /// # Errors
287    ///
288    /// - [`TokenError::NoAccessToken`] if there is no access token
289    pub fn access_token(&self) -> Result<&str, TokenError> {
290        self.access_token.as_ref().map_or_else(
291            || Err(TokenError::NoAccessToken),
292            |token| Ok(token.as_str()),
293        )
294    }
295
296    /// Get the payload used to request an access token.
297    #[must_use]
298    pub const fn payload(&self) -> &OAuthGrant {
299        &self.payload
300    }
301
302    /// Request and return an updated access token using these credentials.
303    ///
304    /// # Errors
305    ///
306    /// See [`TokenError`]
307    #[allow(clippy::missing_panics_doc)]
308    pub async fn request_access_token(&mut self) -> Result<&str, TokenError> {
309        let access_token = self.payload.request_access_token(&self.auth_server).await?;
310        self.access_token = Some(access_token);
311        Ok(self
312            .access_token
313            .as_ref()
314            .expect("This value is set in the previous line, so it cannot be None"))
315    }
316
317    /// The [`AuthServer`] that issues the tokens.
318    #[must_use]
319    pub const fn auth_server(&self) -> &AuthServer {
320        &self.auth_server
321    }
322
323    /// Validate the access token, returning it if it is valid, or an error describing why it is
324    /// invalid.
325    ///
326    /// # Errors
327    ///
328    /// - [`TokenError::NoAccessToken`] if an access token has not been requested.
329    /// - [`TokenError::InvalidAccessToken`] if the access token is invalid.
330    pub fn validate(&self) -> Result<String, TokenError> {
331        self.access_token().map_or_else(
332            |_| Err(TokenError::NoAccessToken),
333            |access_token| {
334                let placeholder_key = DecodingKey::from_secret(&[]);
335                let mut validation = Validation::new(Algorithm::RS256);
336                validation.validate_exp = true;
337                validation.leeway = 60;
338                validation.set_audience(&[QCS_AUDIENCE]);
339                validation.insecure_disable_signature_validation();
340                jsonwebtoken::decode::<toml::Value>(access_token, &placeholder_key, &validation)
341                    .map(|_| access_token.to_string())
342                    .map_err(TokenError::InvalidAccessToken)
343            },
344        )
345    }
346}
347
348impl std::fmt::Debug for OAuthSession {
349    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
350        let token_populated = if self.access_token.is_some() {
351            Some(())
352        } else {
353            None
354        };
355        f.debug_struct("OAuthSession")
356            .field("payload", &self.payload)
357            .field("access_token", &token_populated)
358            .field("auth_server", &self.auth_server)
359            .finish()
360    }
361}
362
/// A wrapper for [`OAuthSession`] that provides thread-safe access to the inner tokens.
///
/// Cloning a dispatcher is cheap: every clone shares the same session, refresh flag and
/// notifier through [`Arc`]s.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "python", pyo3::pyclass)]
pub struct TokenDispatcher {
    // The shared session; read-locked to inspect tokens, write-locked while refreshing.
    lock: Arc<RwLock<OAuthSession>>,
    // `true` while some task is performing a refresh.
    refreshing: Arc<Mutex<bool>>,
    // Used to wake tasks that are waiting on an in-flight refresh.
    notify_refreshed: Arc<Notify>,
}
371
372impl From<OAuthSession> for TokenDispatcher {
373    fn from(value: OAuthSession) -> Self {
374        Self {
375            lock: Arc::new(RwLock::new(value)),
376            refreshing: Arc::new(Mutex::new(false)),
377            notify_refreshed: Arc::new(Notify::new()),
378        }
379    }
380}
381
382impl TokenDispatcher {
383    /// Executes a user-provided closure on a reference to the `Tokens` instance managed by the
384    /// dispatcher.
385    ///
386    /// This function locks the mutex, safely exposing the protected `Tokens` instance to the provided closure `f`.
387    /// It is designed to allow safe and controlled access to the `Tokens` instance for reading its state.
388    ///
389    /// # Parameters
390    /// - `f`: A closure that takes a reference to `Tokens` and returns a value of type `O`. The closure is called
391    ///   with the `Tokens` instance as an argument once the mutex is successfully locked.
392    pub async fn use_tokens<F, O>(&self, f: F) -> O
393    where
394        F: FnOnce(&OAuthSession) -> O + Send,
395    {
396        let tokens = self.lock.read().await;
397        f(&tokens)
398    }
399
400    /// Get a copy of the current access token.
401    #[must_use]
402    pub async fn tokens(&self) -> OAuthSession {
403        self.use_tokens(Clone::clone).await
404    }
405
406    /// Refreshes the tokens. Readers will be blocked until the refresh is complete.
407    ///
408    /// # Errors
409    ///
410    /// See [`TokenError`]
411    pub async fn refresh(
412        &self,
413        source: &ConfigSource,
414        profile: &str,
415    ) -> Result<OAuthSession, TokenError> {
416        self.managed_refresh(Self::perform_refresh, source, profile)
417            .await
418    }
419
420    /// Validate the access token, returning it if it is valid, or an error describing why it is
421    /// invalid.
422    ///
423    /// # Errors
424    ///
425    /// - [`TokenError::NoAccessToken`] if there is no access token
426    /// - [`TokenError::InvalidAccessToken`] if the access token is invalid
427    pub async fn validate(&self) -> Result<String, TokenError> {
428        self.use_tokens(OAuthSession::validate).await
429    }
430
431    /// If tokens are already being refreshed, wait and return the updated tokens. Otherwise, run
432    /// ``refresh_fn``.
433    async fn managed_refresh<F, Fut>(
434        &self,
435        refresh_fn: F,
436        source: &ConfigSource,
437        profile: &str,
438    ) -> Result<OAuthSession, TokenError>
439    where
440        F: FnOnce(Arc<RwLock<OAuthSession>>) -> Fut + Send,
441        Fut: Future<Output = Result<OAuthSession, TokenError>> + Send,
442    {
443        let mut is_refreshing = self.refreshing.lock().await;
444
445        if *is_refreshing {
446            drop(is_refreshing);
447            self.notify_refreshed.notified().await;
448            return Ok(self.tokens().await);
449        }
450
451        *is_refreshing = true;
452        drop(is_refreshing);
453
454        let oauth_session = refresh_fn(self.lock.clone()).await?;
455
456        // If the config source is a file, write the new access token to the file
457        if let ConfigSource::File {
458            settings_path: _,
459            secrets_path,
460        } = source
461        {
462            if !Secrets::is_read_only(secrets_path).await? {
463                let now = OffsetDateTime::now_utc();
464                Secrets::write_access_token(
465                    secrets_path,
466                    profile,
467                    oauth_session.access_token()?,
468                    now,
469                )
470                .await?;
471            }
472        }
473
474        *self.refreshing.lock().await = false;
475        self.notify_refreshed.notify_waiters();
476        Ok(oauth_session)
477    }
478
479    /// Refreshes the tokens. Readers will be blocked until the refresh is complete. Returns a copy
480    /// of the updated [`Credentials`]
481    ///
482    /// # Errors
483    ///
484    /// See [`TokenError`]
485    async fn perform_refresh(lock: Arc<RwLock<OAuthSession>>) -> Result<OAuthSession, TokenError> {
486        let mut credentials = lock.write().await;
487        credentials.request_access_token().await?;
488        Ok(credentials.clone())
489    }
490}
491
/// The boxed, pinned, `Send` future returned by a [`RefreshFunction`]. It resolves to the new
/// access token string, or to a boxed error.
pub(crate) type RefreshResult =
    Pin<Box<dyn Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>> + Send>>;
494
/// A function that asynchronously refreshes a token.
///
/// It receives the [`AuthServer`] by value and returns a boxed future that resolves to the new
/// access token or a boxed error.
pub type RefreshFunction = Box<dyn (Fn(AuthServer) -> RefreshResult) + Send + Sync>;
497
/// A struct that manages access tokens by utilizing a user-provided refresh function.
///
/// The [`ExternallyManaged`] struct allows users to define custom logic for
/// fetching or refreshing access tokens.
#[derive(Clone)]
#[cfg_attr(feature = "python", pyo3::pyclass)]
pub struct ExternallyManaged {
    // The user-provided refresh function; held in an `Arc` so clones share it.
    refresh_function: Arc<RefreshFunction>,
}
507
508impl ExternallyManaged {
509    /// Creates a new [`ExternallyManaged`] instance from a [`RefreshFunction`].
510    ///
511    /// Consider using [`ExternallyManaged::from_async`], and [`ExternallyManaged::from_sync`], if
512    /// they better fit your use case.
513    ///
514    /// # Arguments
515    ///
516    /// * `refresh_function` - A function or closure that asynchronously refreshes a token.
517    ///
518    /// # Example
519    ///
520    /// ```
521    /// use qcs_api_client_common::configuration::{AuthServer, ExternallyManaged, TokenError};
522    /// use std::future::Future;
523    /// use std::pin::Pin;
524    /// use std::boxed::Box;
525    /// use std::error::Error;
526    ///
527    /// async fn example_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
528    /// + Send + Sync>> {
529    ///     Ok("new_token_value".to_string())
530    /// }
531    /// let token_manager = ExternallyManaged::new(|auth_server| Box::pin(example_refresh_function(auth_server)));
532    /// ```
533    pub fn new(
534        refresh_function: impl Fn(AuthServer) -> RefreshResult + Send + Sync + 'static,
535    ) -> Self {
536        Self {
537            refresh_function: Arc::new(Box::new(refresh_function)),
538        }
539    }
540
541    /// Constructs a new [`ExternallyManaged`] instance using an async function or closure.
542    ///
543    /// This method simplifies the creation of the [`ExternallyManaged`] instance by handling
544    /// the boxing and pinning of the future internally.
545    ///
546    /// # Arguments
547    ///
548    /// * `refresh_function` - An async function or closure that returns a [`Future`] which, when awaited,
549    ///   produces a [`Result<String, TokenError>`].
550    ///
551    /// # Example
552    ///
553    /// ```
554    /// use qcs_api_client_common::configuration::{AuthServer, ExternallyManaged, TokenError};
555    /// use tokio::runtime::Runtime;
556    /// use std::error::Error;
557    ///
558    /// async fn example_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
559    /// + Send + Sync>> {
560    ///     Ok("new_token_value".to_string())
561    /// }
562    ///
563    /// let token_manager = ExternallyManaged::from_async(example_refresh_function);
564    ///
565    /// let rt = Runtime::new().unwrap();
566    /// rt.block_on(async {
567    ///     match token_manager.request_access_token(&AuthServer::default()).await {
568    ///         Ok(token) => println!("Token: {}", token),
569    ///         Err(e) => println!("Failed to refresh token: {:?}", e),
570    ///     }
571    /// });
572    /// ```
573    pub fn from_async<F, Fut>(refresh_function: F) -> Self
574    where
575        F: Fn(AuthServer) -> Fut + Send + Sync + 'static,
576        Fut: Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>>
577            + Send
578            + 'static,
579    {
580        Self {
581            refresh_function: Arc::new(Box::new(move |auth_server| {
582                Box::pin(refresh_function(auth_server))
583            })),
584        }
585    }
586
587    /// Constructs a new [`ExternallyManaged`] instance using a synchronous function.
588    ///
589    /// The synchronous function is wrapped in an async block to fit the expected signature.
590    ///
591    /// # Arguments
592    ///
593    /// * `refresh_function` - A synchronous function that returns a [`Result<String, TokenError>`].
594    ///
595    /// # Example
596    ///
597    /// ```
598    /// use qcs_api_client_common::configuration::{AuthServer, ExternallyManaged, TokenError};
599    /// use tokio::runtime::Runtime;
600    /// use std::error::Error;
601    ///
602    /// fn example_sync_refresh_function(_auth_server: AuthServer) -> Result<String, Box<dyn Error
603    /// + Send + Sync>> {
604    ///     Ok("sync_token_value".to_string())
605    /// }
606    ///
607    /// let token_manager = ExternallyManaged::from_sync(example_sync_refresh_function);
608    ///
609    /// let rt = Runtime::new().unwrap();
610    /// rt.block_on(async {
611    ///     match token_manager.request_access_token(&AuthServer::default()).await {
612    ///         Ok(token) => println!("Token: {}", token),
613    ///         Err(e) => println!("Failed to refresh token: {:?}", e),
614    ///     }
615    /// });
616    /// ```
617    pub fn from_sync(
618        refresh_function: impl Fn(AuthServer) -> Result<String, Box<dyn std::error::Error + Send + Sync>>
619            + Send
620            + Sync
621            + 'static,
622    ) -> Self {
623        Self {
624            refresh_function: Arc::new(Box::new(move |auth_server| {
625                let result = refresh_function(auth_server);
626                Box::pin(async move { result })
627            })),
628        }
629    }
630
631    /// Request an updated access token using the provided refresh function.
632    ///
633    /// # Errors
634    ///
635    /// Errors are propagated from the refresh function.
636    pub async fn request_access_token(
637        &self,
638        auth_server: &AuthServer,
639    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
640        (self.refresh_function)(auth_server.clone()).await
641    }
642}
643
644impl std::fmt::Debug for ExternallyManaged {
645    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
646        f.debug_struct("ExternallyManaged")
647            .field(
648                "refresh_function",
649                &"Fn() -> Pin<Box<dyn Future<Output = Result<String, TokenError>> + Send>>",
650            )
651            .finish()
652    }
653}
654
655#[derive(Debug, Serialize, Deserialize)]
656pub(super) struct TokenRefreshRequest<'a> {
657    grant_type: &'static str,
658    client_id: &'a str,
659    refresh_token: &'a str,
660}
661
662impl<'a> TokenRefreshRequest<'a> {
663    pub(super) const fn new(client_id: &'a str, refresh_token: &'a str) -> Self {
664        Self {
665            grant_type: "refresh_token",
666            client_id,
667            refresh_token,
668        }
669    }
670}
671
672#[derive(Debug, Serialize, Deserialize)]
673pub(super) struct ClientCredentialsRequest<'a> {
674    grant_type: &'static str,
675    client_id: &'a str,
676    client_secret: &'a str,
677}
678
679impl<'a> ClientCredentialsRequest<'a> {
680    pub(super) const fn new(client_id: &'a str, client_secret: &'a str) -> Self {
681        Self {
682            grant_type: "client_credentials",
683            client_id,
684            client_secret,
685        }
686    }
687}
688
/// The fields read from the token endpoint's JSON response.
///
/// Both fields are required when deserializing.
#[derive(Deserialize, Debug, Serialize)]
pub(super) struct TokenResponse {
    // The refresh token returned alongside the access token.
    pub(super) refresh_token: String,
    // The newly issued access token.
    pub(super) access_token: String,
}
694
/// Get and refresh access tokens
#[async_trait::async_trait]
pub trait TokenRefresher: Clone + std::fmt::Debug + Send {
    /// The type to be returned in the event of an error while getting or
    /// refreshing an access token
    type Error;

    /// Get and validate the current access token, refreshing it if it doesn't exist or is invalid.
    async fn validated_access_token(&self) -> Result<String, Self::Error>;

    /// Get the current access token, if any
    async fn get_access_token(&self) -> Result<Option<String>, Self::Error>;

    /// Get a fresh access token
    async fn refresh_access_token(&self) -> Result<String, Self::Error>;

    /// Get the base URL for requests
    #[cfg(feature = "tracing")]
    fn base_url(&self) -> &str;

    /// Get the tracing configuration
    #[cfg(feature = "tracing-config")]
    fn tracing_configuration(&self) -> Option<&TracingConfiguration>;

    /// Returns whether the given URL should be traced. Following
    /// [`TracingConfiguration::is_enabled`], this defaults to `true`.
    ///
    /// Without the `tracing-config` feature every URL is traced. With it, a URL is traced when
    /// there is no tracing configuration, or when the configuration reports it as enabled.
    #[cfg(feature = "tracing")]
    #[allow(clippy::needless_return)]
    fn should_trace(&self, url: &UrlPatternMatchInput) -> bool {
        // Only one of the two cfg-gated sections below is compiled in. The explicit `return`
        // keeps this one a statement; hence the `needless_return` allowance above.
        #[cfg(not(feature = "tracing-config"))]
        {
            // `url` is only consulted when a tracing configuration exists.
            let _ = url;
            return true;
        }

        #[cfg(feature = "tracing-config")]
        self.tracing_configuration()
            .is_none_or(|config| config.is_enabled(url))
    }
}
735
736#[async_trait::async_trait]
737impl TokenRefresher for ClientConfiguration {
738    type Error = TokenError;
739
740    async fn validated_access_token(&self) -> Result<String, Self::Error> {
741        self.get_bearer_access_token().await
742    }
743
744    async fn refresh_access_token(&self) -> Result<String, Self::Error> {
745        Ok(self.refresh().await?.access_token()?.to_string())
746    }
747
748    async fn get_access_token(&self) -> Result<Option<String>, Self::Error> {
749        Ok(Some(
750            self.oauth_session().await?.access_token()?.to_string(),
751        ))
752    }
753
754    #[cfg(feature = "tracing")]
755    fn base_url(&self) -> &str {
756        &self.grpc_api_url
757    }
758
759    #[cfg(feature = "tracing-config")]
760    fn tracing_configuration(&self) -> Option<&TracingConfiguration> {
761        self.tracing_configuration.as_ref()
762    }
763}
764
765#[cfg(test)]
766mod test {
767    use std::time::Duration;
768
769    use super::*;
770    use httpmock::prelude::*;
771    use rstest::rstest;
772    use time::format_description::well_known::Rfc3339;
773    use tokio::time::Instant;
774    use toml_edit::DocumentMut;
775
776    #[tokio::test]
777    async fn test_tokens_blocked_during_refresh() {
778        let mock_server = MockServer::start_async().await;
779
780        let issuer_mock = mock_server
781            .mock_async(|when, then| {
782                when.method(POST).path("/v1/token");
783
784                then.status(200)
785                    .delay(Duration::from_secs(3))
786                    .json_body_obj(&TokenResponse {
787                        access_token: "new_access".to_string(),
788                        refresh_token: "new_refresh".to_string(),
789                    });
790            })
791            .await;
792
793        let original_tokens = OAuthSession::from_refresh_token(
794            RefreshToken::new("refresh".to_string()),
795            AuthServer::new("client_id".to_string(), mock_server.base_url()),
796            None,
797        );
798        let dispatcher: TokenDispatcher = original_tokens.clone().into();
799        let dispatcher_clone1 = dispatcher.clone();
800        let dispatcher_clone2 = dispatcher.clone();
801
802        let refresh_duration = Duration::from_secs(3);
803
804        let start_write = Instant::now();
805        let write_future = tokio::spawn(async move {
806            dispatcher_clone1
807                .refresh(&ConfigSource::Default, "")
808                .await
809                .unwrap()
810        });
811
812        let start_read = Instant::now();
813        let read_future = tokio::spawn(async move { dispatcher_clone2.tokens().await });
814
815        let _ = write_future.await.unwrap();
816        let read_result = read_future.await.unwrap();
817
818        let write_duration = start_write.elapsed();
819        let read_duration = start_read.elapsed();
820
821        issuer_mock.assert_async().await;
822
823        assert!(
824            write_duration >= refresh_duration,
825            "Write operation did not take enough time"
826        );
827        assert!(
828            read_duration >= refresh_duration,
829            "Read operation was not blocked by the write operation"
830        );
831        assert_eq!(read_result.access_token.as_ref().unwrap(), "new_access");
832        if let OAuthGrant::RefreshToken(payload) = read_result.payload {
833            assert_eq!(&payload.refresh_token, "new_refresh");
834        } else {
835            panic!(
836                "Expected RefreshToken payload, got {:?}",
837                read_result.payload
838            );
839        }
840    }
841
842    #[rstest]
843    fn test_qcs_secrets_readonly(
844        #[values(
845            (Some("TRUE"), true),
846            (Some("tRue"), true),
847            (Some("true"), true),
848            (Some("YES"), true),
849            (Some("yEs"), true),
850            (Some("yes"), true),
851            (Some("1"), true),
852            (Some("2"), false),
853            (Some("other"), false),
854            (Some(""), false),
855            (None, false),
856        )]
857        read_only_values: (Option<&str>, bool),
858        #[values(true, false)] read_only_perm: bool,
859    ) {
860        let (maybe_read_only_env, env_is_read_only) = read_only_values;
861        let expected_update = !env_is_read_only && !read_only_perm;
862        figment::Jail::expect_with(|jail| {
863            let profile_name = "test";
864            let initial_access_token = "initial_access_token";
865            let initial_refresh_token = "initial_refresh_token";
866
867            let initial_secrets_file_contents = format!(
868                r#"
869[credentials]
870[credentials.{profile_name}]
871[credentials.{profile_name}.token_payload]
872access_token = "{initial_access_token}"
873expires_in = 3600
874id_token = "id_token"
875refresh_token = "{initial_refresh_token}"
876scope = "offline_access openid profile email"
877token_type = "Bearer"
878updated_at = "2024-01-01T00:00:00Z"
879"#
880            );
881
882            // Create a temporary secrets file
883            let secrets_path = "secrets.toml";
884            jail.create_file(secrets_path, initial_secrets_file_contents.as_str())
885                .expect("should create test secrets.toml");
886
887            if read_only_perm {
888                let mut permissions = std::fs::metadata(secrets_path)
889                    .expect("Should be able to get file metadata")
890                    .permissions();
891                permissions.set_readonly(true);
892                std::fs::set_permissions(secrets_path, permissions)
893                    .expect("Should be able to set file permissions");
894            }
895
896            let rt = tokio::runtime::Runtime::new().unwrap();
897            rt.block_on(async {
898                let mock_server = MockServer::start_async().await;
899
900                // Set up the mock token endpoint
901                let new_access_token = "new_access_token";
902                let issuer_mock = mock_server
903                    .mock_async(|when, then| {
904                        when.method(POST).path("/v1/token");
905                        then.status(200).json_body_obj(&TokenResponse {
906                            access_token: new_access_token.to_string(),
907                            refresh_token: initial_refresh_token.to_string(),
908                        });
909                    })
910                    .await;
911
912                // Create tokens and dispatcher
913                let original_tokens = OAuthSession::from_refresh_token(
914                    RefreshToken::new(initial_refresh_token.to_string()),
915                    AuthServer::new("client_id".to_string(), mock_server.base_url()),
916                    Some(initial_refresh_token.to_string()),
917                );
918                let dispatcher: TokenDispatcher = original_tokens.into();
919
920                // Test with QCS_SECRETS_READ_ONLY set first
921                jail.set_env("QCS_SECRETS_FILE_PATH", "secrets.toml");
922                jail.set_env("QCS_PROFILE_NAME", "test");
923                if let Some(read_only_env) = maybe_read_only_env {
924                    jail.set_env("QCS_SECRETS_READ_ONLY", read_only_env);
925                }
926
927                let before_refresh = OffsetDateTime::now_utc();
928
929                dispatcher
930                    .refresh(
931                        &ConfigSource::File {
932                            settings_path: "".into(),
933                            secrets_path: "secrets.toml".into(),
934                        },
935                        profile_name,
936                    )
937                    .await
938                    .unwrap();
939
940                issuer_mock.assert_async().await;
941
942                // Verify the file was not updated if QCS_SECRETS_READ_ONLY is set truthy
943                let content = std::fs::read_to_string("secrets.toml").unwrap();
944                if !expected_update {
945                    assert!(
946                        content.eq(initial_secrets_file_contents.as_str()),
947                        "File should not be updated when QCS_SECRETS_READ_ONLY is set or file permissions are read-only"
948                    );
949                    return;
950                }
951
952                // Verify the file was updated
953                let mut toml = std::fs::read_to_string(secrets_path)
954                    .unwrap()
955                    .parse::<DocumentMut>()
956                    .unwrap();
957
958                let token_payload = toml
959                    .get_mut("credentials")
960                    .and_then(|credentials| {
961                        credentials.get_mut(profile_name)?.get_mut("token_payload")
962                    })
963                    .expect("Should be able to get token_payload table");
964
965                assert_eq!(
966                    token_payload.get("access_token").unwrap().as_str().unwrap(),
967                    new_access_token
968                );
969
970                assert!(
971                    OffsetDateTime::parse(
972                        token_payload.get("updated_at").unwrap().as_str().unwrap(),
973                        &Rfc3339
974                    )
975                    .unwrap()
976                        > before_refresh
977                );
978
979                let content = std::fs::read_to_string("secrets.toml").unwrap();
980                assert!(
981                content.contains("new_access_token"),
982                "File should be updated with new access token when QCS_SECRETS_READ_ONLY is not set or is set but disabled, and file permissions allow writing"
983                );
984            });
985            Ok(())
986        });
987    }
988
989    #[test]
990    fn test_auth_session_debug_fmt() {
991        let session = OAuthSession {
992            payload: OAuthGrant::ClientCredentials(ClientCredentials {
993                client_id: "hidden_id".into(),
994                client_secret: "hidden_secret".into(),
995            }),
996            access_token: Some("token".into()),
997            auth_server: AuthServer::new("some_id".into(), "some_url".into()),
998        };
999
1000        assert_eq!("OAuthSession { payload: ClientCredentials, access_token: Some(()), auth_server: AuthServer { client_id: \"some_id\", issuer: \"some_url\" } }", &format!("{session:?}"));
1001    }
1002}