1use std::{pin::Pin, sync::Arc};
3
4use futures::Future;
5use jsonwebtoken::{Algorithm, DecodingKey, Validation};
6use oauth2::TokenResponse;
7use serde::{Deserialize, Serialize};
8use time::OffsetDateTime;
9use tokio::sync::{Mutex, Notify, RwLock};
10use tokio_util::sync::CancellationToken;
11
12#[cfg(feature = "stubs")]
13use pyo3_stub_gen::derive::gen_stub_pyclass;
14
15use super::{
16 ClientConfiguration, ConfigSource, TokenError, oidc, secrets::Secrets, settings::AuthServer,
17};
18use crate::configuration::{
19 error::DiscoveryError,
20 pkce::{PkceLoginError, PkceLoginRequest, pkce_login},
21 secrets::{Credential, SecretAccessToken, SecretRefreshToken, TokenPayload},
22};
23#[cfg(feature = "tracing-config")]
24use crate::tracing_configuration::TracingConfiguration;
25#[cfg(feature = "tracing")]
26use urlpattern::UrlPatternMatchInput;
27
28pub use super::secret_string::ClientSecret;
29
/// An OAuth grant that holds a refresh token, which can be exchanged for access tokens via
/// [`RefreshToken::request_access_token`].
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(eq, get_all, set_all, module = "qcs_api_client_common.configuration")
)]
pub struct RefreshToken {
    /// The refresh token sent to the token endpoint. It is replaced in place whenever the
    /// token endpoint's response carries a new refresh token.
    pub refresh_token: SecretRefreshToken,
}
41
42impl RefreshToken {
43 #[must_use]
45 pub const fn new(refresh_token: SecretRefreshToken) -> Self {
46 Self { refresh_token }
47 }
48
49 pub async fn request_access_token(
55 &mut self,
56 auth_server: &AuthServer,
57 ) -> Result<SecretAccessToken, TokenError> {
58 if self.refresh_token.is_empty() {
59 return Err(TokenError::NoRefreshToken);
60 }
61
62 let client = default_http_client()?;
63 let token_url = oidc::fetch_discovery(&client, &auth_server.issuer)
64 .await?
65 .token_endpoint;
66 let data = TokenRefreshRequest::new(&auth_server.client_id, self.refresh_token.secret());
67 let resp = client.post(token_url).form(&data).send().await?;
68
69 let RefreshTokenResponse {
70 access_token,
71 refresh_token,
72 } = resp.error_for_status()?.json().await?;
73
74 if let Some(refresh_token) = refresh_token {
75 self.refresh_token = refresh_token;
76 }
77 Ok(access_token)
78 }
79}
80
/// The subset of the token endpoint's response that is read after a client credentials
/// request: only the access token.
#[derive(Deserialize, Debug, Serialize)]
pub(super) struct ClientCredentialsResponse {
    /// The access token issued by the token endpoint.
    pub(super) access_token: SecretAccessToken,
}
85
/// An OAuth grant made of a client ID and client secret, used to request access tokens via
/// [`ClientCredentials::request_access_token`].
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(eq, get_all, frozen, module = "qcs_api_client_common.configuration")
)]
pub struct ClientCredentials {
    /// The client ID.
    pub client_id: String,
    /// The client secret paired with `client_id`.
    pub client_secret: ClientSecret,
}
99
100impl ClientCredentials {
101 #[must_use]
102 pub fn new(client_id: impl Into<String>, client_secret: impl Into<ClientSecret>) -> Self {
104 Self {
105 client_id: client_id.into(),
106 client_secret: client_secret.into(),
107 }
108 }
109
110 #[must_use]
112 pub fn client_id(&self) -> &str {
113 &self.client_id
114 }
115
116 #[must_use]
118 pub const fn client_secret(&self) -> &ClientSecret {
119 &self.client_secret
120 }
121
122 pub async fn request_access_token(
128 &self,
129 auth_server: &AuthServer,
130 ) -> Result<SecretAccessToken, TokenError> {
131 let request = ClientCredentialsRequest::new(None);
132 let client = default_http_client()?;
133
134 let url = oidc::fetch_discovery(&client, &auth_server.issuer)
135 .await?
136 .token_endpoint;
137 let ready_to_send = client
138 .post(url)
139 .basic_auth(&auth_server.client_id, Some(&self.client_secret.secret()))
140 .form(&request);
141 let response = ready_to_send.send().await?;
142
143 response.error_for_status_ref()?;
144
145 let ClientCredentialsResponse { access_token } = response.json().await?;
146 Ok(access_token)
147 }
148}
149
/// The tokens obtained from a PKCE login: an access token and, when the server issued one, a
/// refresh token.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(eq, get_all, frozen, module = "qcs_api_client_common.configuration")
)]
pub struct PkceFlow {
    /// The most recently obtained access token.
    pub access_token: SecretAccessToken,
    /// The refresh token used to obtain a new access token once the current one fails
    /// validation, if one is available.
    pub refresh_token: Option<RefreshToken>,
}
163
/// Errors that can occur in [`PkceFlow::new_login_flow`].
#[derive(Debug, thiserror::Error)]
pub enum PkceFlowError {
    /// The PKCE login itself failed.
    #[error(transparent)]
    PkceLogin(#[from] PkceLoginError),
    /// OIDC discovery against the issuer failed.
    #[error(transparent)]
    Discovery(#[from] DiscoveryError),
    /// An HTTP client error, including failure to build the client.
    #[error(transparent)]
    Request(#[from] qcs_dependencies_client::reqwest::Error),
}
177
178impl PkceFlow {
179 pub async fn new_login_flow(
185 cancel_token: CancellationToken,
186 auth_server: &AuthServer,
187 ) -> Result<Self, PkceFlowError> {
188 let issuer = auth_server.issuer.clone();
189
190 let client = default_http_client()?;
191 let discovery = oidc::fetch_discovery(&client, &issuer).await?;
192
193 let response = pkce_login(
194 cancel_token,
195 PkceLoginRequest {
196 client_id: auth_server.client_id.clone(),
197 redirect_port: None,
198 discovery,
199 scopes: auth_server.scopes.clone(),
200 },
201 )
202 .await?;
203
204 Ok(Self {
205 access_token: SecretAccessToken::from(response.access_token().secret().clone()),
206 refresh_token: response
207 .refresh_token()
208 .map(|rt| RefreshToken::new(SecretRefreshToken::from(rt.secret().clone()))),
209 })
210 }
211
212 pub async fn request_access_token(
218 &mut self,
219 auth_server: &AuthServer,
220 ) -> Result<SecretAccessToken, TokenError> {
221 if insecure_validate_token_exp(&self.access_token).is_ok() {
222 return Ok(self.access_token.clone());
223 }
224
225 if let Some(refresh_token) = &mut self.refresh_token {
226 let access_token = refresh_token.request_access_token(auth_server).await?;
227 self.access_token.clone_from(&access_token);
228 return Ok(access_token);
229 }
230
231 Err(TokenError::NoRefreshToken)
232 }
233}
234
235impl From<PkceFlow> for Credential {
236 fn from(value: PkceFlow) -> Self {
237 let mut token_payload = TokenPayload::default();
238 token_payload.access_token = Some(value.access_token);
239 token_payload.refresh_token = value.refresh_token.map(|rt| rt.refresh_token);
240
241 Self {
242 token_payload: Some(token_payload),
243 }
244 }
245}
246
/// The kinds of OAuth grant that can be used to request an access token.
#[derive(Clone)]
#[cfg_attr(feature = "python", derive(pyo3::FromPyObject, pyo3::IntoPyObject))]
pub enum OAuthGrant {
    /// Exchange a refresh token for access tokens.
    RefreshToken(RefreshToken),
    /// Authenticate with a client ID and client secret.
    ClientCredentials(ClientCredentials),
    /// Delegate token retrieval to a caller-supplied refresh function.
    ExternallyManaged(ExternallyManaged),
    /// Tokens obtained through a PKCE login.
    PkceFlow(PkceFlow),
}
261
262impl From<ExternallyManaged> for OAuthGrant {
263 fn from(v: ExternallyManaged) -> Self {
264 Self::ExternallyManaged(v)
265 }
266}
267
268impl From<ClientCredentials> for OAuthGrant {
269 fn from(v: ClientCredentials) -> Self {
270 Self::ClientCredentials(v)
271 }
272}
273
274impl From<RefreshToken> for OAuthGrant {
275 fn from(v: RefreshToken) -> Self {
276 Self::RefreshToken(v)
277 }
278}
279
280impl From<PkceFlow> for OAuthGrant {
281 fn from(v: PkceFlow) -> Self {
282 Self::PkceFlow(v)
283 }
284}
285
286impl OAuthGrant {
287 async fn request_access_token(
289 &mut self,
290 auth_server: &AuthServer,
291 ) -> Result<SecretAccessToken, TokenError> {
292 match self {
293 Self::RefreshToken(tokens) => tokens.request_access_token(auth_server).await,
294 Self::ClientCredentials(tokens) => tokens.request_access_token(auth_server).await,
295 Self::ExternallyManaged(tokens) => tokens
296 .request_access_token(auth_server)
297 .await
298 .map_err(|e| TokenError::ExternallyManaged(e.to_string())),
299 Self::PkceFlow(tokens) => tokens.request_access_token(auth_server).await,
300 }
301 }
302}
303
304impl std::fmt::Debug for OAuthGrant {
305 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
306 match self {
307 Self::RefreshToken(_) => f.write_str("RefreshToken"),
308 Self::ClientCredentials(_) => f.write_str("ClientCredentials"),
309 Self::ExternallyManaged(_) => f.write_str("ExternallyManaged"),
310 Self::PkceFlow(_) => f.write_str("PkceTokens"),
311 }
312 }
313}
314
/// An OAuth session: the grant used to obtain tokens, the most recent access token (if any),
/// and the authorization server the tokens are requested from.
#[derive(Clone)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "qcs_api_client_common.configuration", frozen, get_all)
)]
pub struct OAuthSession {
    /// The grant used to request access tokens.
    payload: OAuthGrant,
    /// The most recently stored access token, or `None` if there is none yet.
    access_token: Option<SecretAccessToken>,
    /// The authorization server that access tokens are requested from.
    auth_server: AuthServer,
}
340
341impl OAuthSession {
342 #[must_use]
347 pub const fn new(
348 payload: OAuthGrant,
349 auth_server: AuthServer,
350 access_token: Option<SecretAccessToken>,
351 ) -> Self {
352 Self {
353 payload,
354 access_token,
355 auth_server,
356 }
357 }
358
359 #[must_use]
364 pub const fn from_externally_managed(
365 tokens: ExternallyManaged,
366 auth_server: AuthServer,
367 access_token: Option<SecretAccessToken>,
368 ) -> Self {
369 Self::new(
370 OAuthGrant::ExternallyManaged(tokens),
371 auth_server,
372 access_token,
373 )
374 }
375
376 #[must_use]
381 pub const fn from_refresh_token(
382 tokens: RefreshToken,
383 auth_server: AuthServer,
384 access_token: Option<SecretAccessToken>,
385 ) -> Self {
386 Self::new(OAuthGrant::RefreshToken(tokens), auth_server, access_token)
387 }
388
389 #[must_use]
394 pub const fn from_client_credentials(
395 tokens: ClientCredentials,
396 auth_server: AuthServer,
397 access_token: Option<SecretAccessToken>,
398 ) -> Self {
399 Self::new(
400 OAuthGrant::ClientCredentials(tokens),
401 auth_server,
402 access_token,
403 )
404 }
405
406 #[must_use]
411 pub const fn from_pkce_flow(
412 flow: PkceFlow,
413 auth_server: AuthServer,
414 access_token: Option<SecretAccessToken>,
415 ) -> Self {
416 Self::new(OAuthGrant::PkceFlow(flow), auth_server, access_token)
417 }
418
419 pub fn access_token(&self) -> Result<&SecretAccessToken, TokenError> {
428 self.access_token.as_ref().ok_or(TokenError::NoAccessToken)
429 }
430
431 #[must_use]
433 pub const fn payload(&self) -> &OAuthGrant {
434 &self.payload
435 }
436
437 #[allow(clippy::missing_panics_doc)]
443 pub async fn request_access_token(&mut self) -> Result<&SecretAccessToken, TokenError> {
444 let access_token = self.payload.request_access_token(&self.auth_server).await?;
445 Ok(self.access_token.insert(access_token))
446 }
447
448 #[must_use]
450 pub const fn auth_server(&self) -> &AuthServer {
451 &self.auth_server
452 }
453
454 pub fn validate(&self) -> Result<SecretAccessToken, TokenError> {
462 let access_token = self.access_token()?;
463 insecure_validate_token_exp(access_token)?;
464 Ok(access_token.clone())
465 }
466}
467
468pub(crate) fn insecure_validate_token_exp(
472 access_token: &SecretAccessToken,
473) -> Result<(), TokenError> {
474 let placeholder_key = DecodingKey::from_secret(&[]);
475 let mut validation = Validation::new(Algorithm::RS256);
476 validation.validate_exp = true;
477 validation.leeway = 60;
478 validation.validate_aud = false;
479 validation.insecure_disable_signature_validation();
480
481 jsonwebtoken::decode::<toml::Value>(access_token.secret(), &placeholder_key, &validation)
482 .map(|_| ())
483 .map_err(TokenError::InvalidAccessToken)
484}
485
486impl std::fmt::Debug for OAuthSession {
487 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
488 let token_populated = if self.access_token.is_some() {
489 Some(())
490 } else {
491 None
492 };
493 f.debug_struct("OAuthSession")
494 .field("payload", &self.payload)
495 .field("access_token", &token_populated)
496 .field("auth_server", &self.auth_server)
497 .finish()
498 }
499}
500
/// A cloneable handle around a shared [`OAuthSession`] that coordinates token refreshes
/// between its clones.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "qcs_api_client_common.configuration", frozen)
)]
pub struct TokenDispatcher {
    /// The shared session; readers take the read lock, a refresh takes the write lock.
    lock: Arc<RwLock<OAuthSession>>,
    /// `true` while a refresh started through this dispatcher is in progress.
    refreshing: Arc<Mutex<bool>>,
    /// Used to wake callers that are waiting for an in-progress refresh to finish.
    notify_refreshed: Arc<Notify>,
}
513
514impl From<OAuthSession> for TokenDispatcher {
515 fn from(value: OAuthSession) -> Self {
516 Self {
517 lock: Arc::new(RwLock::new(value)),
518 refreshing: Arc::new(Mutex::new(false)),
519 notify_refreshed: Arc::new(Notify::new()),
520 }
521 }
522}
523
524impl TokenDispatcher {
525 pub async fn use_tokens<F, O>(&self, f: F) -> O
535 where
536 F: FnOnce(&OAuthSession) -> O + Send,
537 {
538 let tokens = self.lock.read().await;
539 f(&tokens)
540 }
541
542 #[must_use]
544 pub async fn tokens(&self) -> OAuthSession {
545 self.use_tokens(Clone::clone).await
546 }
547
548 pub async fn refresh(
554 &self,
555 source: &ConfigSource,
556 profile: &str,
557 ) -> Result<OAuthSession, TokenError> {
558 self.managed_refresh(Self::perform_refresh, source, profile)
559 .await
560 }
561
562 pub async fn validate(&self) -> Result<SecretAccessToken, TokenError> {
570 self.use_tokens(OAuthSession::validate).await
571 }
572
573 async fn managed_refresh<F, Fut>(
576 &self,
577 refresh_fn: F,
578 source: &ConfigSource,
579 profile: &str,
580 ) -> Result<OAuthSession, TokenError>
581 where
582 F: FnOnce(Arc<RwLock<OAuthSession>>) -> Fut + Send,
583 Fut: Future<Output = Result<OAuthSession, TokenError>> + Send,
584 {
585 let mut is_refreshing = self.refreshing.lock().await;
586
587 if *is_refreshing {
588 drop(is_refreshing);
589 self.notify_refreshed.notified().await;
590 return Ok(self.tokens().await);
591 }
592
593 *is_refreshing = true;
594 drop(is_refreshing);
595
596 let oauth_session = refresh_fn(self.lock.clone()).await?;
597
598 let write_result = if let ConfigSource::File {
600 settings_path: _,
601 secrets_path,
602 } = source
603 {
604 match Secrets::is_read_only(secrets_path).await {
605 Ok(true) => Ok(()),
606 Ok(false) => {
607 let refresh_token = match &oauth_session.payload {
609 OAuthGrant::PkceFlow(payload) => {
610 payload.refresh_token.as_ref().map(|rt| &rt.refresh_token)
611 }
612 _ => None,
613 };
614
615 let now = OffsetDateTime::now_utc();
616 Secrets::write_tokens(
617 secrets_path,
618 profile,
619 refresh_token,
620 oauth_session.access_token()?,
621 now,
622 )
623 .await
624 }
625 Err(e) => Err(e),
626 }
627 } else {
628 Ok(())
629 };
630
631 *self.refreshing.lock().await = false;
633 self.notify_refreshed.notify_waiters();
634
635 if let Err(error) = write_result {
637 return Err(TokenError::Write {
638 error,
639 oauth_session: Box::new(oauth_session),
640 });
641 }
642
643 Ok(oauth_session)
644 }
645
646 async fn perform_refresh(lock: Arc<RwLock<OAuthSession>>) -> Result<OAuthSession, TokenError> {
653 let mut credentials = lock.write().await;
654 credentials.request_access_token().await?;
655 Ok(credentials.clone())
656 }
657}
658
/// The future returned by a [`RefreshFunction`]: a boxed, pinned, `Send` future that resolves
/// to a `String` on success or a boxed error on failure.
pub(crate) type RefreshResult =
    Pin<Box<dyn Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>> + Send>>;

/// A boxed, thread-safe function that takes an [`AuthServer`] and returns a [`RefreshResult`].
pub type RefreshFunction = Box<dyn (Fn(AuthServer) -> RefreshResult) + Send + Sync>;
664
/// An OAuth grant whose access tokens are produced by a caller-supplied refresh function.
///
/// Cloning is cheap: clones share the same function through an [`Arc`].
#[derive(Clone)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "qcs_api_client_common.configuration", frozen)
)]
pub struct ExternallyManaged {
    /// The function invoked to obtain an access token.
    refresh_function: Arc<RefreshFunction>,
}
678
679impl ExternallyManaged {
680 pub fn new(
705 refresh_function: impl Fn(AuthServer) -> RefreshResult + Send + Sync + 'static,
706 ) -> Self {
707 Self {
708 refresh_function: Arc::new(Box::new(refresh_function)),
709 }
710 }
711
712 pub fn from_async<F, Fut>(refresh_function: F) -> Self
745 where
746 F: Fn(AuthServer) -> Fut + Send + Sync + 'static,
747 Fut: Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>>
748 + Send
749 + 'static,
750 {
751 Self {
752 refresh_function: Arc::new(Box::new(move |auth_server| {
753 Box::pin(refresh_function(auth_server))
754 })),
755 }
756 }
757
758 pub fn from_sync(
789 refresh_function: impl Fn(
790 AuthServer,
791 ) -> Result<String, Box<dyn std::error::Error + Send + Sync>>
792 + Send
793 + Sync
794 + 'static,
795 ) -> Self {
796 Self {
797 refresh_function: Arc::new(Box::new(move |auth_server| {
798 let result = refresh_function(auth_server);
799 Box::pin(async move { result })
800 })),
801 }
802 }
803
804 pub async fn request_access_token(
810 &self,
811 auth_server: &AuthServer,
812 ) -> Result<SecretAccessToken, Box<dyn std::error::Error + Send + Sync>> {
813 (self.refresh_function)(auth_server.clone())
814 .await
815 .map(SecretAccessToken::from)
816 }
817}
818
819impl std::fmt::Debug for ExternallyManaged {
820 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
821 f.debug_struct("ExternallyManaged")
822 .field(
823 "refresh_function",
824 &"Fn() -> Pin<Box<dyn Future<Output = Result<String, TokenError>> + Send>>",
825 )
826 .finish()
827 }
828}
829
830#[derive(Debug, Serialize, Deserialize)]
831pub(super) struct TokenRefreshRequest<'a> {
832 grant_type: &'static str,
833 client_id: &'a str,
834 refresh_token: &'a str,
835}
836
837impl<'a> TokenRefreshRequest<'a> {
838 pub(super) const fn new(client_id: &'a str, refresh_token: &'a str) -> Self {
839 Self {
840 grant_type: "refresh_token",
841 client_id,
842 refresh_token,
843 }
844 }
845}
846
847#[derive(Debug, Serialize, Deserialize)]
848pub(super) struct ClientCredentialsRequest {
849 grant_type: &'static str,
850 scope: Option<&'static str>,
851}
852
853impl ClientCredentialsRequest {
854 pub(super) const fn new(scope: Option<&'static str>) -> Self {
855 Self {
856 grant_type: "client_credentials",
857 scope,
858 }
859 }
860}
861
/// The subset of the token endpoint's response that is read after a refresh token request.
#[derive(Deserialize, Debug, Serialize)]
pub(super) struct RefreshTokenResponse {
    /// A new refresh token, when the response includes one.
    pub(super) refresh_token: Option<SecretRefreshToken>,
    /// The newly issued access token.
    pub(super) access_token: SecretAccessToken,
}
867
/// A source of access tokens that can also refresh them.
#[async_trait::async_trait]
pub trait TokenRefresher: Clone + std::fmt::Debug + Send {
    /// The error type returned by the token operations.
    type Error;

    /// Return an access token that has been validated.
    // NOTE(review): how validation happens (and whether a refresh may be triggered) is up to
    // the implementor — confirm against each implementation.
    async fn validated_access_token(&self) -> Result<SecretAccessToken, Self::Error>;

    /// Return the current access token, if any, without refreshing it.
    async fn get_access_token(&self) -> Result<Option<SecretAccessToken>, Self::Error>;

    /// Refresh the access token and return the new one.
    async fn refresh_access_token(&self) -> Result<SecretAccessToken, Self::Error>;

    /// The base URL associated with this refresher (only with the `tracing` feature).
    #[cfg(feature = "tracing")]
    fn base_url(&self) -> &str;

    /// The tracing configuration, if one is set (only with the `tracing-config` feature).
    #[cfg(feature = "tracing-config")]
    fn tracing_configuration(&self) -> Option<&TracingConfiguration>;

    /// Whether requests to `url` should be traced.
    ///
    /// Without the `tracing-config` feature this is always `true`. With it, the result is
    /// `true` when there is no tracing configuration, and otherwise whatever the
    /// configuration's `is_enabled` reports for `url`.
    #[cfg(feature = "tracing")]
    #[allow(clippy::needless_return)]
    fn should_trace(&self, url: &UrlPatternMatchInput) -> bool {
        #[cfg(not(feature = "tracing-config"))]
        {
            // `url` is only consulted when a tracing configuration can exist.
            let _ = url;
            return true;
        }

        #[cfg(feature = "tracing-config")]
        self.tracing_configuration()
            .is_none_or(|config| config.is_enabled(url))
    }
}
908
909#[async_trait::async_trait]
910impl TokenRefresher for ClientConfiguration {
911 type Error = TokenError;
912
913 async fn validated_access_token(&self) -> Result<SecretAccessToken, Self::Error> {
914 self.get_bearer_access_token().await
915 }
916
917 async fn refresh_access_token(&self) -> Result<SecretAccessToken, Self::Error> {
918 match self.refresh().await {
919 Ok(session) => Ok(session.access_token()?.clone()),
920 Err(TokenError::Write {
921 error,
922 oauth_session,
923 }) => {
924 #[cfg(feature = "tracing")]
926 tracing::warn!(
927 "Token refresh succeeded but failed to persist: {}. Returning access token from error.",
928 error
929 );
930 Ok(oauth_session.access_token()?.clone())
931 }
932 Err(e) => Err(e),
933 }
934 }
935
936 async fn get_access_token(&self) -> Result<Option<SecretAccessToken>, Self::Error> {
937 Ok(Some(self.oauth_session().await?.access_token()?.clone()))
938 }
939
940 #[cfg(feature = "tracing")]
941 fn base_url(&self) -> &str {
942 &self.grpc_api_url
943 }
944
945 #[cfg(feature = "tracing-config")]
946 fn tracing_configuration(&self) -> Option<&TracingConfiguration> {
947 self.tracing_configuration.as_ref()
948 }
949}
950
951pub(super) fn default_http_client()
953-> Result<qcs_dependencies_client::reqwest::Client, qcs_dependencies_client::reqwest::Error> {
954 qcs_dependencies_client::reqwest::Client::builder()
955 .timeout(std::time::Duration::from_secs(10))
956 .build()
957}
958
#[cfg(test)]
mod test {
    #![allow(clippy::result_large_err, reason = "happens in figment tests")]

    use std::time::Duration;

    use super::*;
    use httpmock::prelude::*;
    use rstest::rstest;
    use time::format_description::well_known::Rfc3339;
    use tokio::time::Instant;
    use toml_edit::DocumentMut;

    /// While a refresh holds the session's write lock, a concurrent `tokens()` call must wait
    /// for it and then observe the refreshed tokens.
    #[tokio::test]
    async fn test_tokens_blocked_during_refresh() {
        let mock_server = MockServer::start_async().await;

        let oidc_mock = mock_server
            .mock_async(|when, then| {
                when.method(GET).path("/.well-known/openid-configuration");
                then.status(200)
                    .json_body_obj(&oidc::Discovery::new_for_test(
                        mock_server.base_url().parse().unwrap(),
                    ));
            })
            .await;

        // The token endpoint answers slowly so the refresh keeps the write lock for a while.
        let issuer_mock = mock_server
            .mock_async(|when, then| {
                when.method(POST).path("/v1/token");

                then.status(200)
                    .delay(Duration::from_secs(3))
                    .json_body_obj(&RefreshTokenResponse {
                        access_token: SecretAccessToken::from("new_access"),
                        refresh_token: Some(SecretRefreshToken::from("new_refresh")),
                    });
            })
            .await;

        let original_tokens = OAuthSession::from_refresh_token(
            RefreshToken::new(SecretRefreshToken::from("refresh")),
            AuthServer {
                client_id: "client_id".to_string(),
                issuer: mock_server.base_url(),
                scopes: None,
            },
            None,
        );
        let dispatcher: TokenDispatcher = original_tokens.clone().into();
        let dispatcher_clone1 = dispatcher.clone();
        let dispatcher_clone2 = dispatcher.clone();

        let refresh_duration = Duration::from_secs(3);

        let start_write = Instant::now();
        let write_future = tokio::spawn(async move {
            dispatcher_clone1
                .refresh(&ConfigSource::Default, "")
                .await
                .unwrap()
        });

        let start_read = Instant::now();
        let read_future = tokio::spawn(async move { dispatcher_clone2.tokens().await });

        let _ = write_future.await.unwrap();
        let read_result = read_future.await.unwrap();

        let write_duration = start_write.elapsed();
        let read_duration = start_read.elapsed();

        oidc_mock.assert_async().await;
        issuer_mock.assert_async().await;

        assert!(
            write_duration >= refresh_duration,
            "Write operation did not take enough time"
        );
        assert!(
            read_duration >= refresh_duration,
            "Read operation was not blocked by the write operation"
        );
        assert_eq!(
            read_result.access_token.unwrap(),
            SecretAccessToken::from("new_access")
        );
        if let OAuthGrant::RefreshToken(payload) = read_result.payload {
            assert_eq!(
                payload.refresh_token,
                SecretRefreshToken::from("new_refresh")
            );
        } else {
            panic!(
                "Expected RefreshToken payload, got {:?}",
                read_result.payload
            );
        }
    }

    /// The secrets file is rewritten after a refresh only when neither the
    /// `QCS_SECRETS_READ_ONLY` environment variable nor the file's permissions mark it as
    /// read-only.
    #[rstest]
    fn test_qcs_secrets_readonly(
        #[values(
            (Some("TRUE"), true),
            (Some("tRue"), true),
            (Some("true"), true),
            (Some("YES"), true),
            (Some("yEs"), true),
            (Some("yes"), true),
            (Some("1"), true),
            (Some("2"), false),
            (Some("other"), false),
            (Some(""), false),
            (None, false),
        )]
        read_only_values: (Option<&str>, bool),
        #[values(true, false)] read_only_perm: bool,
    ) {
        let (maybe_read_only_env, env_is_read_only) = read_only_values;
        let expected_update = !env_is_read_only && !read_only_perm;
        figment::Jail::expect_with(|jail| {
            let profile_name = "test";
            let initial_access_token = "initial_access_token";
            let initial_refresh_token = "initial_refresh_token";

            let initial_secrets_file_contents = format!(
                r#"
[credentials]
[credentials.{profile_name}]
[credentials.{profile_name}.token_payload]
access_token = "{initial_access_token}"
expires_in = 3600
id_token = "id_token"
refresh_token = "{initial_refresh_token}"
scope = "offline_access openid profile email"
token_type = "Bearer"
updated_at = "2024-01-01T00:00:00Z"
"#
            );

            // Start from a clean environment so stray QCS_* variables cannot leak in.
            jail.clear_env();

            let secrets_path = "secrets.toml";
            jail.create_file(secrets_path, initial_secrets_file_contents.as_str())
                .expect("should create test secrets.toml");

            if read_only_perm {
                let mut permissions = std::fs::metadata(secrets_path)
                    .expect("Should be able to get file metadata")
                    .permissions();
                permissions.set_readonly(true);
                std::fs::set_permissions(secrets_path, permissions)
                    .expect("Should be able to set file permissions");
            }

            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async {
                let mock_server = MockServer::start_async().await;

                let oidc_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(GET).path("/.well-known/openid-configuration");
                        then.status(200)
                            .json_body_obj(&oidc::Discovery::new_for_test(mock_server.base_url().parse().unwrap()));
                    })
                    .await;

                let new_access_token = SecretAccessToken::from("new_access_token");
                let issuer_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(POST).path("/v1/token");
                        then.status(200).json_body_obj(&RefreshTokenResponse {
                            access_token: new_access_token.clone(),
                            refresh_token: Some(SecretRefreshToken::from(initial_refresh_token)),
                        });
                    })
                    .await;

                // The session starts with the same access token the secrets file holds.
                let original_tokens = OAuthSession::from_refresh_token(
                    RefreshToken::new(SecretRefreshToken::from(initial_refresh_token)),
                    AuthServer { client_id: "client_id".to_string(), issuer: mock_server.base_url(), scopes: None },
                    Some(SecretAccessToken::from(initial_access_token)),
                );
                let dispatcher: TokenDispatcher = original_tokens.into();

                jail.set_env("QCS_SECRETS_FILE_PATH", "secrets.toml");
                jail.set_env("QCS_PROFILE_NAME", "test");
                if let Some(read_only_env) = maybe_read_only_env {
                    jail.set_env("QCS_SECRETS_READ_ONLY", read_only_env);
                }

                let before_refresh = OffsetDateTime::now_utc();

                dispatcher
                    .refresh(
                        &ConfigSource::File {
                            settings_path: "".into(),
                            secrets_path: "secrets.toml".into(),
                        },
                        profile_name,
                    )
                    .await
                    .unwrap();

                oidc_mock.assert_async().await;
                issuer_mock.assert_async().await;

                let content = std::fs::read_to_string("secrets.toml").unwrap();
                if !expected_update {
                    assert!(
                        content.eq(initial_secrets_file_contents.as_str()),
                        "File should not be updated when QCS_SECRETS_READ_ONLY is set or file permissions are read-only"
                    );
                    return;
                }

                let mut toml = std::fs::read_to_string(secrets_path)
                    .unwrap()
                    .parse::<DocumentMut>()
                    .unwrap();

                let token_payload = toml
                    .get_mut("credentials")
                    .and_then(|credentials| {
                        credentials.get_mut(profile_name)?.get_mut("token_payload")
                    })
                    .expect("Should be able to get token_payload table");

                let access_token = token_payload.get("access_token").unwrap().as_str().map(str::to_string).map(SecretAccessToken::from);

                assert_eq!(
                    access_token,
                    Some(new_access_token)
                );

                assert!(
                    OffsetDateTime::parse(
                        token_payload.get("updated_at").unwrap().as_str().unwrap(),
                        &Rfc3339
                    )
                    .unwrap()
                        > before_refresh
                );

                let content = std::fs::read_to_string("secrets.toml").unwrap();
                assert!(
                    content.contains("new_access_token"),
                    "File should be updated with new access token when QCS_SECRETS_READ_ONLY is not set or is set but disabled, and file permissions allow writing"
                );
            });
            Ok(())
        });
    }

    /// The `Debug` output of a session names the grant variant and shows only whether an
    /// access token is present, never the secrets themselves.
    #[test]
    fn test_auth_session_debug_fmt() {
        let session = OAuthSession {
            payload: OAuthGrant::ClientCredentials(ClientCredentials::new(
                "hidden_id",
                "hidden_secret",
            )),
            access_token: Some(SecretAccessToken::from("token")),
            auth_server: AuthServer {
                client_id: "some_id".into(),
                issuer: "some_url".into(),
                scopes: None,
            },
        };

        assert_eq!(
            "OAuthSession { payload: ClientCredentials, access_token: Some(()), auth_server: AuthServer { client_id: \"some_id\", issuer: \"some_url\", scopes: None } }",
            &format!("{session:?}")
        );
    }
}