1use std::{pin::Pin, sync::Arc};
3
4use futures::Future;
5use jsonwebtoken::{Algorithm, DecodingKey, Validation};
6use oauth2::TokenResponse;
7use serde::{Deserialize, Serialize};
8use time::OffsetDateTime;
9use tokio::sync::{Mutex, Notify, RwLock};
10use tokio_util::sync::CancellationToken;
11
12#[cfg(feature = "stubs")]
13use pyo3_stub_gen::derive::gen_stub_pyclass;
14
15use super::{
16 oidc, secrets::Secrets, settings::AuthServer, ClientConfiguration, ConfigSource, TokenError,
17};
18use crate::configuration::{
19 error::DiscoveryError,
20 pkce::{pkce_login, PkceLoginError, PkceLoginRequest},
21 secrets::{Credential, SecretAccessToken, SecretRefreshToken, TokenPayload},
22};
23#[cfg(feature = "tracing-config")]
24use crate::tracing_configuration::TracingConfiguration;
25#[cfg(feature = "tracing")]
26use urlpattern::UrlPatternMatchInput;
27
28pub use super::secret_string::ClientSecret;
29
/// A refresh token that can be exchanged for new access tokens.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(eq, get_all, set_all, module = "qcs_api_client_common.configuration")
)]
pub struct RefreshToken {
    /// The secret refresh token value sent to the token endpoint.
    pub refresh_token: SecretRefreshToken,
}
41
42impl RefreshToken {
43 #[must_use]
45 pub const fn new(refresh_token: SecretRefreshToken) -> Self {
46 Self { refresh_token }
47 }
48
49 pub async fn request_access_token(
55 &mut self,
56 auth_server: &AuthServer,
57 ) -> Result<SecretAccessToken, TokenError> {
58 if self.refresh_token.is_empty() {
59 return Err(TokenError::NoRefreshToken);
60 }
61
62 let client = default_http_client()?;
63 let token_url = oidc::fetch_discovery(&client, &auth_server.issuer)
64 .await?
65 .token_endpoint;
66 let data = TokenRefreshRequest::new(&auth_server.client_id, self.refresh_token.secret());
67 let resp = client.post(token_url).form(&data).send().await?;
68
69 let RefreshTokenResponse {
70 access_token,
71 refresh_token,
72 } = resp.error_for_status()?.json().await?;
73
74 if let Some(refresh_token) = refresh_token {
75 self.refresh_token = refresh_token;
76 }
77 Ok(access_token)
78 }
79}
80
/// The portion of a client-credentials token response that this module reads.
#[derive(Deserialize, Debug, Serialize)]
pub(super) struct ClientCredentialsResponse {
    /// The access token returned by the token endpoint.
    pub(super) access_token: SecretAccessToken,
}
85
/// A client ID and client secret used to request access tokens with the
/// `client_credentials` grant.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(eq, get_all, frozen, module = "qcs_api_client_common.configuration")
)]
pub struct ClientCredentials {
    /// The client ID.
    pub client_id: String,
    /// The client secret paired with `client_id`.
    pub client_secret: ClientSecret,
}
99
100impl ClientCredentials {
101 #[must_use]
102 pub fn new(client_id: impl Into<String>, client_secret: impl Into<ClientSecret>) -> Self {
104 Self {
105 client_id: client_id.into(),
106 client_secret: client_secret.into(),
107 }
108 }
109
110 #[must_use]
112 pub fn client_id(&self) -> &str {
113 &self.client_id
114 }
115
116 #[must_use]
118 pub const fn client_secret(&self) -> &ClientSecret {
119 &self.client_secret
120 }
121
122 pub async fn request_access_token(
128 &self,
129 auth_server: &AuthServer,
130 ) -> Result<SecretAccessToken, TokenError> {
131 let request = ClientCredentialsRequest::new(None);
132 let client = default_http_client()?;
133
134 let url = oidc::fetch_discovery(&client, &auth_server.issuer)
135 .await?
136 .token_endpoint;
137 let ready_to_send = client
138 .post(url)
139 .basic_auth(&auth_server.client_id, Some(&self.client_secret.secret()))
140 .form(&request);
141 let response = ready_to_send.send().await?;
142
143 response.error_for_status_ref()?;
144
145 let ClientCredentialsResponse { access_token } = response.json().await?;
146 Ok(access_token)
147 }
148}
149
/// The tokens obtained from a PKCE login: an access token and, optionally, a refresh token.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(eq, get_all, frozen, module = "qcs_api_client_common.configuration")
)]
pub struct PkceFlow {
    /// The most recently obtained access token.
    pub access_token: SecretAccessToken,
    /// The refresh token, if the login response included one.
    pub refresh_token: Option<RefreshToken>,
}
163
/// Errors returned by [`PkceFlow::new_login_flow`].
///
/// Every variant is transparent: its message and source come from the wrapped error.
#[derive(Debug, thiserror::Error)]
pub enum PkceFlowError {
    /// The PKCE login itself failed.
    #[error(transparent)]
    PkceLogin(#[from] PkceLoginError),
    /// Fetching the OIDC discovery document failed.
    #[error(transparent)]
    Discovery(#[from] DiscoveryError),
    /// An error from `reqwest`, such as a failure to build the HTTP client.
    #[error(transparent)]
    Request(#[from] reqwest::Error),
}
177
178impl PkceFlow {
179 pub async fn new_login_flow(
185 cancel_token: CancellationToken,
186 auth_server: &AuthServer,
187 ) -> Result<Self, PkceFlowError> {
188 let issuer = auth_server.issuer.clone();
189
190 let client = default_http_client()?;
191 let discovery = oidc::fetch_discovery(&client, &issuer).await?;
192
193 let response = pkce_login(
194 cancel_token,
195 PkceLoginRequest {
196 client_id: auth_server.client_id.clone(),
197 redirect_port: None,
198 discovery,
199 scopes: auth_server.scopes.clone(),
200 },
201 )
202 .await?;
203
204 Ok(Self {
205 access_token: SecretAccessToken::from(response.access_token().secret().clone()),
206 refresh_token: response
207 .refresh_token()
208 .map(|rt| RefreshToken::new(SecretRefreshToken::from(rt.secret().clone()))),
209 })
210 }
211
212 pub async fn request_access_token(
218 &mut self,
219 auth_server: &AuthServer,
220 ) -> Result<SecretAccessToken, TokenError> {
221 if insecure_validate_token_exp(&self.access_token).is_ok() {
222 return Ok(self.access_token.clone());
223 }
224
225 if let Some(refresh_token) = &mut self.refresh_token {
226 let access_token = refresh_token.request_access_token(auth_server).await?;
227 self.access_token.clone_from(&access_token);
228 return Ok(access_token);
229 }
230
231 Err(TokenError::NoRefreshToken)
232 }
233}
234
235impl From<PkceFlow> for Credential {
236 fn from(value: PkceFlow) -> Self {
237 let mut token_payload = TokenPayload::default();
238 token_payload.access_token = Some(value.access_token);
239 token_payload.refresh_token = value.refresh_token.map(|rt| rt.refresh_token);
240
241 Self {
242 token_payload: Some(token_payload),
243 }
244 }
245}
246
/// The supported ways of obtaining an access token.
#[derive(Clone)]
#[cfg_attr(feature = "python", derive(pyo3::FromPyObject, pyo3::IntoPyObject))]
pub enum OAuthGrant {
    /// Exchange a refresh token for an access token.
    RefreshToken(RefreshToken),
    /// Authenticate with a client ID and secret.
    ClientCredentials(ClientCredentials),
    /// Delegate to a caller-supplied refresh function.
    ExternallyManaged(ExternallyManaged),
    /// Use tokens obtained from a PKCE login.
    PkceFlow(PkceFlow),
}
261
262impl From<ExternallyManaged> for OAuthGrant {
263 fn from(v: ExternallyManaged) -> Self {
264 Self::ExternallyManaged(v)
265 }
266}
267
268impl From<ClientCredentials> for OAuthGrant {
269 fn from(v: ClientCredentials) -> Self {
270 Self::ClientCredentials(v)
271 }
272}
273
274impl From<RefreshToken> for OAuthGrant {
275 fn from(v: RefreshToken) -> Self {
276 Self::RefreshToken(v)
277 }
278}
279
280impl From<PkceFlow> for OAuthGrant {
281 fn from(v: PkceFlow) -> Self {
282 Self::PkceFlow(v)
283 }
284}
285
286impl OAuthGrant {
287 async fn request_access_token(
289 &mut self,
290 auth_server: &AuthServer,
291 ) -> Result<SecretAccessToken, TokenError> {
292 match self {
293 Self::RefreshToken(tokens) => tokens.request_access_token(auth_server).await,
294 Self::ClientCredentials(tokens) => tokens.request_access_token(auth_server).await,
295 Self::ExternallyManaged(tokens) => tokens
296 .request_access_token(auth_server)
297 .await
298 .map_err(|e| TokenError::ExternallyManaged(e.to_string())),
299 Self::PkceFlow(tokens) => tokens.request_access_token(auth_server).await,
300 }
301 }
302}
303
304impl std::fmt::Debug for OAuthGrant {
305 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
306 match self {
307 Self::RefreshToken(_) => f.write_str("RefreshToken"),
308 Self::ClientCredentials(_) => f.write_str("ClientCredentials"),
309 Self::ExternallyManaged(_) => f.write_str("ExternallyManaged"),
310 Self::PkceFlow(_) => f.write_str("PkceTokens"),
311 }
312 }
313}
314
/// An OAuth grant together with the auth server it is used against and the most recently
/// obtained access token, if any.
#[derive(Clone)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "qcs_api_client_common.configuration", frozen, get_all)
)]
pub struct OAuthSession {
    /// The grant used to request access tokens.
    payload: OAuthGrant,
    /// The current access token; `None` until one is supplied or requested.
    access_token: Option<SecretAccessToken>,
    /// The authorization server that token requests are sent to.
    auth_server: AuthServer,
}
340
341impl OAuthSession {
342 #[must_use]
347 pub const fn new(
348 payload: OAuthGrant,
349 auth_server: AuthServer,
350 access_token: Option<SecretAccessToken>,
351 ) -> Self {
352 Self {
353 payload,
354 access_token,
355 auth_server,
356 }
357 }
358
359 #[must_use]
364 pub const fn from_externally_managed(
365 tokens: ExternallyManaged,
366 auth_server: AuthServer,
367 access_token: Option<SecretAccessToken>,
368 ) -> Self {
369 Self::new(
370 OAuthGrant::ExternallyManaged(tokens),
371 auth_server,
372 access_token,
373 )
374 }
375
376 #[must_use]
381 pub const fn from_refresh_token(
382 tokens: RefreshToken,
383 auth_server: AuthServer,
384 access_token: Option<SecretAccessToken>,
385 ) -> Self {
386 Self::new(OAuthGrant::RefreshToken(tokens), auth_server, access_token)
387 }
388
389 #[must_use]
394 pub const fn from_client_credentials(
395 tokens: ClientCredentials,
396 auth_server: AuthServer,
397 access_token: Option<SecretAccessToken>,
398 ) -> Self {
399 Self::new(
400 OAuthGrant::ClientCredentials(tokens),
401 auth_server,
402 access_token,
403 )
404 }
405
406 #[must_use]
411 pub const fn from_pkce_flow(
412 flow: PkceFlow,
413 auth_server: AuthServer,
414 access_token: Option<SecretAccessToken>,
415 ) -> Self {
416 Self::new(OAuthGrant::PkceFlow(flow), auth_server, access_token)
417 }
418
419 pub fn access_token(&self) -> Result<&SecretAccessToken, TokenError> {
428 self.access_token.as_ref().ok_or(TokenError::NoAccessToken)
429 }
430
431 #[must_use]
433 pub const fn payload(&self) -> &OAuthGrant {
434 &self.payload
435 }
436
437 #[allow(clippy::missing_panics_doc)]
443 pub async fn request_access_token(&mut self) -> Result<&SecretAccessToken, TokenError> {
444 let access_token = self.payload.request_access_token(&self.auth_server).await?;
445 Ok(self.access_token.insert(access_token))
446 }
447
448 #[must_use]
450 pub const fn auth_server(&self) -> &AuthServer {
451 &self.auth_server
452 }
453
454 pub fn validate(&self) -> Result<SecretAccessToken, TokenError> {
462 let access_token = self.access_token()?;
463 insecure_validate_token_exp(access_token)?;
464 Ok(access_token.clone())
465 }
466}
467
468pub(crate) fn insecure_validate_token_exp(
472 access_token: &SecretAccessToken,
473) -> Result<(), TokenError> {
474 let placeholder_key = DecodingKey::from_secret(&[]);
475 let mut validation = Validation::new(Algorithm::RS256);
476 validation.validate_exp = true;
477 validation.leeway = 60;
478 validation.validate_aud = false;
479 validation.insecure_disable_signature_validation();
480
481 jsonwebtoken::decode::<toml::Value>(access_token.secret(), &placeholder_key, &validation)
482 .map(|_| ())
483 .map_err(TokenError::InvalidAccessToken)
484}
485
486impl std::fmt::Debug for OAuthSession {
487 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
488 let token_populated = if self.access_token.is_some() {
489 Some(())
490 } else {
491 None
492 };
493 f.debug_struct("OAuthSession")
494 .field("payload", &self.payload)
495 .field("access_token", &token_populated)
496 .field("auth_server", &self.auth_server)
497 .finish()
498 }
499}
500
/// A cloneable handle to a shared [`OAuthSession`] that coordinates token refreshes.
///
/// All fields are behind `Arc`, so clones of a dispatcher share the same session and the same
/// refresh state.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "qcs_api_client_common.configuration", frozen)
)]
pub struct TokenDispatcher {
    /// The shared session, guarded by a read-write lock.
    lock: Arc<RwLock<OAuthSession>>,
    /// `true` while a refresh is in progress.
    refreshing: Arc<Mutex<bool>>,
    /// Used to wake tasks that are waiting for an in-progress refresh to finish.
    notify_refreshed: Arc<Notify>,
}
513
514impl From<OAuthSession> for TokenDispatcher {
515 fn from(value: OAuthSession) -> Self {
516 Self {
517 lock: Arc::new(RwLock::new(value)),
518 refreshing: Arc::new(Mutex::new(false)),
519 notify_refreshed: Arc::new(Notify::new()),
520 }
521 }
522}
523
524impl TokenDispatcher {
525 pub async fn use_tokens<F, O>(&self, f: F) -> O
535 where
536 F: FnOnce(&OAuthSession) -> O + Send,
537 {
538 let tokens = self.lock.read().await;
539 f(&tokens)
540 }
541
542 #[must_use]
544 pub async fn tokens(&self) -> OAuthSession {
545 self.use_tokens(Clone::clone).await
546 }
547
548 pub async fn refresh(
554 &self,
555 source: &ConfigSource,
556 profile: &str,
557 ) -> Result<OAuthSession, TokenError> {
558 self.managed_refresh(Self::perform_refresh, source, profile)
559 .await
560 }
561
562 pub async fn validate(&self) -> Result<SecretAccessToken, TokenError> {
570 self.use_tokens(OAuthSession::validate).await
571 }
572
573 async fn managed_refresh<F, Fut>(
576 &self,
577 refresh_fn: F,
578 source: &ConfigSource,
579 profile: &str,
580 ) -> Result<OAuthSession, TokenError>
581 where
582 F: FnOnce(Arc<RwLock<OAuthSession>>) -> Fut + Send,
583 Fut: Future<Output = Result<OAuthSession, TokenError>> + Send,
584 {
585 let mut is_refreshing = self.refreshing.lock().await;
586
587 if *is_refreshing {
588 drop(is_refreshing);
589 self.notify_refreshed.notified().await;
590 return Ok(self.tokens().await);
591 }
592
593 *is_refreshing = true;
594 drop(is_refreshing);
595
596 let oauth_session = refresh_fn(self.lock.clone()).await?;
597
598 let write_result = if let ConfigSource::File {
600 settings_path: _,
601 secrets_path,
602 } = source
603 {
604 match Secrets::is_read_only(secrets_path).await {
605 Ok(true) => Ok(()),
606 Ok(false) => {
607 let refresh_token = match &oauth_session.payload {
609 OAuthGrant::PkceFlow(payload) => {
610 payload.refresh_token.as_ref().map(|rt| &rt.refresh_token)
611 }
612 _ => None,
613 };
614
615 let now = OffsetDateTime::now_utc();
616 Secrets::write_tokens(
617 secrets_path,
618 profile,
619 refresh_token,
620 oauth_session.access_token()?,
621 now,
622 )
623 .await
624 }
625 Err(e) => Err(e),
626 }
627 } else {
628 Ok(())
629 };
630
631 *self.refreshing.lock().await = false;
633 self.notify_refreshed.notify_waiters();
634
635 if let Err(error) = write_result {
637 return Err(TokenError::Write {
638 error,
639 oauth_session: Box::new(oauth_session),
640 });
641 }
642
643 Ok(oauth_session)
644 }
645
646 async fn perform_refresh(lock: Arc<RwLock<OAuthSession>>) -> Result<OAuthSession, TokenError> {
653 let mut credentials = lock.write().await;
654 credentials.request_access_token().await?;
655 Ok(credentials.clone())
656 }
657}
658
/// The boxed, `Send` future returned by a [`RefreshFunction`]. It resolves to the access token
/// as a `String`, or to a boxed error.
pub(crate) type RefreshResult =
    Pin<Box<dyn Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>> + Send>>;

/// A boxed, thread-safe callback that receives the [`AuthServer`] and returns a
/// [`RefreshResult`].
pub type RefreshFunction = Box<dyn (Fn(AuthServer) -> RefreshResult) + Send + Sync>;
664
/// A grant whose access tokens are produced by a caller-supplied refresh function.
///
/// Clones share the same underlying function through an `Arc`.
#[derive(Clone)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "qcs_api_client_common.configuration", frozen)
)]
pub struct ExternallyManaged {
    /// The function invoked to obtain a new access token.
    refresh_function: Arc<RefreshFunction>,
}
678
679impl ExternallyManaged {
680 pub fn new(
705 refresh_function: impl Fn(AuthServer) -> RefreshResult + Send + Sync + 'static,
706 ) -> Self {
707 Self {
708 refresh_function: Arc::new(Box::new(refresh_function)),
709 }
710 }
711
712 pub fn from_async<F, Fut>(refresh_function: F) -> Self
745 where
746 F: Fn(AuthServer) -> Fut + Send + Sync + 'static,
747 Fut: Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>>
748 + Send
749 + 'static,
750 {
751 Self {
752 refresh_function: Arc::new(Box::new(move |auth_server| {
753 Box::pin(refresh_function(auth_server))
754 })),
755 }
756 }
757
758 pub fn from_sync(
789 refresh_function: impl Fn(AuthServer) -> Result<String, Box<dyn std::error::Error + Send + Sync>>
790 + Send
791 + Sync
792 + 'static,
793 ) -> Self {
794 Self {
795 refresh_function: Arc::new(Box::new(move |auth_server| {
796 let result = refresh_function(auth_server);
797 Box::pin(async move { result })
798 })),
799 }
800 }
801
802 pub async fn request_access_token(
808 &self,
809 auth_server: &AuthServer,
810 ) -> Result<SecretAccessToken, Box<dyn std::error::Error + Send + Sync>> {
811 (self.refresh_function)(auth_server.clone())
812 .await
813 .map(SecretAccessToken::from)
814 }
815}
816
817impl std::fmt::Debug for ExternallyManaged {
818 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
819 f.debug_struct("ExternallyManaged")
820 .field(
821 "refresh_function",
822 &"Fn() -> Pin<Box<dyn Future<Output = Result<String, TokenError>> + Send>>",
823 )
824 .finish()
825 }
826}
827
/// The form body sent to the token endpoint when exchanging a refresh token.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct TokenRefreshRequest<'a> {
    /// The grant type; set to `"refresh_token"` by [`TokenRefreshRequest::new`].
    grant_type: &'static str,
    /// The client ID the request is made for.
    client_id: &'a str,
    /// The refresh token being exchanged.
    refresh_token: &'a str,
}
834
835impl<'a> TokenRefreshRequest<'a> {
836 pub(super) const fn new(client_id: &'a str, refresh_token: &'a str) -> Self {
837 Self {
838 grant_type: "refresh_token",
839 client_id,
840 refresh_token,
841 }
842 }
843}
844
/// The form body sent to the token endpoint for a client-credentials grant.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct ClientCredentialsRequest {
    /// The grant type; set to `"client_credentials"` by [`ClientCredentialsRequest::new`].
    grant_type: &'static str,
    /// The optional scope to request.
    scope: Option<&'static str>,
}
850
851impl ClientCredentialsRequest {
852 pub(super) const fn new(scope: Option<&'static str>) -> Self {
853 Self {
854 grant_type: "client_credentials",
855 scope,
856 }
857 }
858}
859
/// The portion of a refresh-token response that this module reads.
#[derive(Deserialize, Debug, Serialize)]
pub(super) struct RefreshTokenResponse {
    /// A replacement refresh token, if the server returned one.
    pub(super) refresh_token: Option<SecretRefreshToken>,
    /// The newly issued access token.
    pub(super) access_token: SecretAccessToken,
}
865
866#[async_trait::async_trait]
868pub trait TokenRefresher: Clone + std::fmt::Debug + Send {
869 type Error;
872
873 async fn validated_access_token(&self) -> Result<SecretAccessToken, Self::Error>;
875
876 async fn get_access_token(&self) -> Result<Option<SecretAccessToken>, Self::Error>;
878
879 async fn refresh_access_token(&self) -> Result<SecretAccessToken, Self::Error>;
881
882 #[cfg(feature = "tracing")]
884 fn base_url(&self) -> &str;
885
886 #[cfg(feature = "tracing-config")]
888 fn tracing_configuration(&self) -> Option<&TracingConfiguration>;
889
890 #[cfg(feature = "tracing")]
893 #[allow(clippy::needless_return)]
894 fn should_trace(&self, url: &UrlPatternMatchInput) -> bool {
895 #[cfg(not(feature = "tracing-config"))]
896 {
897 let _ = url;
898 return true;
899 }
900
901 #[cfg(feature = "tracing-config")]
902 self.tracing_configuration()
903 .is_none_or(|config| config.is_enabled(url))
904 }
905}
906
907#[async_trait::async_trait]
908impl TokenRefresher for ClientConfiguration {
909 type Error = TokenError;
910
911 async fn validated_access_token(&self) -> Result<SecretAccessToken, Self::Error> {
912 self.get_bearer_access_token().await
913 }
914
915 async fn refresh_access_token(&self) -> Result<SecretAccessToken, Self::Error> {
916 match self.refresh().await {
917 Ok(session) => Ok(session.access_token()?.clone()),
918 Err(TokenError::Write {
919 error,
920 oauth_session,
921 }) => {
922 #[cfg(feature = "tracing")]
924 tracing::warn!(
925 "Token refresh succeeded but failed to persist: {}. Returning access token from error.",
926 error
927 );
928 Ok(oauth_session.access_token()?.clone())
929 }
930 Err(e) => Err(e),
931 }
932 }
933
934 async fn get_access_token(&self) -> Result<Option<SecretAccessToken>, Self::Error> {
935 Ok(Some(self.oauth_session().await?.access_token()?.clone()))
936 }
937
938 #[cfg(feature = "tracing")]
939 fn base_url(&self) -> &str {
940 &self.grpc_api_url
941 }
942
943 #[cfg(feature = "tracing-config")]
944 fn tracing_configuration(&self) -> Option<&TracingConfiguration> {
945 self.tracing_configuration.as_ref()
946 }
947}
948
949pub(super) fn default_http_client() -> Result<reqwest::Client, reqwest::Error> {
951 reqwest::Client::builder()
952 .timeout(std::time::Duration::from_secs(10))
953 .build()
954}
955
#[cfg(test)]
mod test {
    #![allow(clippy::result_large_err, reason = "happens in figment tests")]

    use std::time::Duration;

    use super::*;
    use httpmock::prelude::*;
    use rstest::rstest;
    use time::format_description::well_known::Rfc3339;
    use tokio::time::Instant;
    use toml_edit::DocumentMut;

    /// A read of the session that races with a refresh must wait for the refresh to finish and
    /// must then observe the refreshed tokens.
    ///
    /// The mocked token endpoint delays its response by three seconds; both the refresh and the
    /// concurrent read are expected to take at least that long.
    #[tokio::test]
    async fn test_tokens_blocked_during_refresh() {
        let mock_server = MockServer::start_async().await;

        let oidc_mock = mock_server
            .mock_async(|when, then| {
                when.method(GET).path("/.well-known/openid-configuration");
                then.status(200)
                    .json_body_obj(&oidc::Discovery::new_for_test(
                        mock_server.base_url().parse().unwrap(),
                    ));
            })
            .await;

        let issuer_mock = mock_server
            .mock_async(|when, then| {
                when.method(POST).path("/v1/token");

                then.status(200)
                    .delay(Duration::from_secs(3))
                    .json_body_obj(&RefreshTokenResponse {
                        access_token: SecretAccessToken::from("new_access"),
                        refresh_token: Some(SecretRefreshToken::from("new_refresh")),
                    });
            })
            .await;

        let original_tokens = OAuthSession::from_refresh_token(
            RefreshToken::new(SecretRefreshToken::from("refresh")),
            AuthServer {
                client_id: "client_id".to_string(),
                issuer: mock_server.base_url(),
                scopes: None,
            },
            None,
        );
        let dispatcher: TokenDispatcher = original_tokens.clone().into();
        let dispatcher_clone1 = dispatcher.clone();
        let dispatcher_clone2 = dispatcher.clone();

        // Must match the delay configured on the token endpoint mock above.
        let refresh_duration = Duration::from_secs(3);

        let start_write = Instant::now();
        let write_future = tokio::spawn(async move {
            dispatcher_clone1
                .refresh(&ConfigSource::Default, "")
                .await
                .unwrap()
        });

        let start_read = Instant::now();
        let read_future = tokio::spawn(async move { dispatcher_clone2.tokens().await });

        let _ = write_future.await.unwrap();
        let read_result = read_future.await.unwrap();

        let write_duration = start_write.elapsed();
        let read_duration = start_read.elapsed();

        oidc_mock.assert_async().await;
        issuer_mock.assert_async().await;

        assert!(
            write_duration >= refresh_duration,
            "Write operation did not take enough time"
        );
        assert!(
            read_duration >= refresh_duration,
            "Read operation was not blocked by the write operation"
        );
        assert_eq!(
            read_result.access_token.unwrap(),
            SecretAccessToken::from("new_access")
        );
        if let OAuthGrant::RefreshToken(payload) = read_result.payload {
            assert_eq!(
                payload.refresh_token,
                SecretRefreshToken::from("new_refresh")
            );
        } else {
            panic!(
                "Expected RefreshToken payload, got {:?}",
                read_result.payload
            );
        }
    }

    /// The secrets file is rewritten after a refresh only when neither the
    /// `QCS_SECRETS_READ_ONLY` environment variable nor the file's permissions mark it as
    /// read-only.
    #[rstest]
    fn test_qcs_secrets_readonly(
        #[values(
            (Some("TRUE"), true),
            (Some("tRue"), true),
            (Some("true"), true),
            (Some("YES"), true),
            (Some("yEs"), true),
            (Some("yes"), true),
            (Some("1"), true),
            (Some("2"), false),
            (Some("other"), false),
            (Some(""), false),
            (None, false),
        )]
        read_only_values: (Option<&str>, bool),
        #[values(true, false)] read_only_perm: bool,
    ) {
        let (maybe_read_only_env, env_is_read_only) = read_only_values;
        let expected_update = !env_is_read_only && !read_only_perm;
        figment::Jail::expect_with(|jail| {
            let profile_name = "test";
            let initial_access_token = "initial_access_token";
            let initial_refresh_token = "initial_refresh_token";

            let initial_secrets_file_contents = format!(
                r#"
[credentials]
[credentials.{profile_name}]
[credentials.{profile_name}.token_payload]
access_token = "{initial_access_token}"
expires_in = 3600
id_token = "id_token"
refresh_token = "{initial_refresh_token}"
scope = "offline_access openid profile email"
token_type = "Bearer"
updated_at = "2024-01-01T00:00:00Z"
"#
            );

            jail.clear_env();

            let secrets_path = "secrets.toml";
            jail.create_file(secrets_path, initial_secrets_file_contents.as_str())
                .expect("should create test secrets.toml");

            if read_only_perm {
                let mut permissions = std::fs::metadata(secrets_path)
                    .expect("Should be able to get file metadata")
                    .permissions();
                permissions.set_readonly(true);
                std::fs::set_permissions(secrets_path, permissions)
                    .expect("Should be able to set file permissions");
            }

            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async {
                let mock_server = MockServer::start_async().await;

                let oidc_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(GET).path("/.well-known/openid-configuration");
                        then.status(200)
                            .json_body_obj(&oidc::Discovery::new_for_test(
                                mock_server.base_url().parse().unwrap(),
                            ));
                    })
                    .await;

                let new_access_token = SecretAccessToken::from("new_access_token");
                let issuer_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(POST).path("/v1/token");
                        then.status(200).json_body_obj(&RefreshTokenResponse {
                            access_token: new_access_token.clone(),
                            refresh_token: Some(SecretRefreshToken::from(initial_refresh_token)),
                        });
                    })
                    .await;

                // Seed the session with the same access token that the secrets file holds.
                let original_tokens = OAuthSession::from_refresh_token(
                    RefreshToken::new(SecretRefreshToken::from(initial_refresh_token)),
                    AuthServer {
                        client_id: "client_id".to_string(),
                        issuer: mock_server.base_url(),
                        scopes: None,
                    },
                    Some(SecretAccessToken::from(initial_access_token)),
                );
                let dispatcher: TokenDispatcher = original_tokens.into();

                jail.set_env("QCS_SECRETS_FILE_PATH", "secrets.toml");
                jail.set_env("QCS_PROFILE_NAME", "test");
                if let Some(read_only_env) = maybe_read_only_env {
                    jail.set_env("QCS_SECRETS_READ_ONLY", read_only_env);
                }

                let before_refresh = OffsetDateTime::now_utc();

                dispatcher
                    .refresh(
                        &ConfigSource::File {
                            settings_path: "".into(),
                            secrets_path: "secrets.toml".into(),
                        },
                        profile_name,
                    )
                    .await
                    .unwrap();

                oidc_mock.assert_async().await;
                issuer_mock.assert_async().await;

                let content = std::fs::read_to_string("secrets.toml").unwrap();
                if !expected_update {
                    assert!(
                        content.eq(initial_secrets_file_contents.as_str()),
                        "File should not be updated when QCS_SECRETS_READ_ONLY is set or file permissions are read-only"
                    );
                    return;
                }

                let mut toml = std::fs::read_to_string(secrets_path)
                    .unwrap()
                    .parse::<DocumentMut>()
                    .unwrap();

                let token_payload = toml
                    .get_mut("credentials")
                    .and_then(|credentials| {
                        credentials.get_mut(profile_name)?.get_mut("token_payload")
                    })
                    .expect("Should be able to get token_payload table");

                let access_token = token_payload
                    .get("access_token")
                    .unwrap()
                    .as_str()
                    .map(str::to_string)
                    .map(SecretAccessToken::from);

                assert_eq!(access_token, Some(new_access_token));

                // The write must have stamped a fresh `updated_at`.
                assert!(
                    OffsetDateTime::parse(
                        token_payload.get("updated_at").unwrap().as_str().unwrap(),
                        &Rfc3339
                    )
                    .unwrap()
                        > before_refresh
                );

                let content = std::fs::read_to_string("secrets.toml").unwrap();
                assert!(
                    content.contains("new_access_token"),
                    "File should be updated with new access token when QCS_SECRETS_READ_ONLY is not set or is set but disabled, and file permissions allow writing"
                );
            });
            Ok(())
        });
    }

    /// `Debug` output for a session must show neither the client secret nor the access token.
    #[test]
    fn test_auth_session_debug_fmt() {
        let session = OAuthSession {
            payload: OAuthGrant::ClientCredentials(ClientCredentials::new(
                "hidden_id",
                "hidden_secret",
            )),
            access_token: Some(SecretAccessToken::from("token")),
            auth_server: AuthServer {
                client_id: "some_id".into(),
                issuer: "some_url".into(),
                scopes: None,
            },
        };

        assert_eq!("OAuthSession { payload: ClientCredentials, access_token: Some(()), auth_server: AuthServer { client_id: \"some_id\", issuer: \"some_url\", scopes: None } }", &format!("{session:?}"));
    }
}