1use std::{pin::Pin, sync::Arc};
3
4use futures::Future;
5use jsonwebtoken::{Algorithm, DecodingKey, Validation};
6use oauth2::TokenResponse;
7use serde::{Deserialize, Serialize};
8use time::OffsetDateTime;
9use tokio::sync::{Mutex, Notify, RwLock};
10use tokio_util::sync::CancellationToken;
11
12#[cfg(feature = "stubs")]
13use pyo3_stub_gen::derive::gen_stub_pyclass;
14
15use super::{
16 oidc, secrets::Secrets, settings::AuthServer, ClientConfiguration, ConfigSource, TokenError,
17};
18use crate::configuration::{
19 error::DiscoveryError,
20 pkce::{pkce_login, PkceLoginError, PkceLoginRequest},
21 secrets::{Credential, SecretAccessToken, SecretRefreshToken, TokenPayload},
22};
23#[cfg(feature = "tracing-config")]
24use crate::tracing_configuration::TracingConfiguration;
25#[cfg(feature = "tracing")]
26use urlpattern::UrlPatternMatchInput;
27
28pub use super::secret_string::ClientSecret;
29
30#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
32#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
33#[cfg_attr(
34 feature = "python",
35 pyo3::pyclass(eq, get_all, set_all, module = "qcs_api_client_common.configuration")
36)]
37pub struct RefreshToken {
38 pub refresh_token: SecretRefreshToken,
40}
41
42impl RefreshToken {
43 #[must_use]
45 pub const fn new(refresh_token: SecretRefreshToken) -> Self {
46 Self { refresh_token }
47 }
48
49 pub async fn request_access_token(
55 &mut self,
56 auth_server: &AuthServer,
57 ) -> Result<SecretAccessToken, TokenError> {
58 if self.refresh_token.is_empty() {
59 return Err(TokenError::NoRefreshToken);
60 }
61
62 let client = default_http_client()?;
63 let token_url = oidc::fetch_discovery(&client, &auth_server.issuer)
64 .await?
65 .token_endpoint;
66 let data = TokenRefreshRequest::new(&auth_server.client_id, self.refresh_token.secret());
67 let resp = client.post(token_url).form(&data).send().await?;
68
69 let RefreshTokenResponse {
70 access_token,
71 refresh_token,
72 } = resp.error_for_status()?.json().await?;
73
74 if let Some(refresh_token) = refresh_token {
75 self.refresh_token = refresh_token;
76 }
77 Ok(access_token)
78 }
79}
80
/// JSON body returned by the token endpoint for a client credentials grant.
///
/// Only the access token is declared here; any other fields in the response are not read.
#[derive(Deserialize, Debug, Serialize)]
pub(super) struct ClientCredentialsResponse {
    /// The newly issued access token.
    pub(super) access_token: SecretAccessToken,
}
85
86#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
88#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
89#[cfg_attr(
90 feature = "python",
91 pyo3::pyclass(eq, get_all, frozen, module = "qcs_api_client_common.configuration")
92)]
93pub struct ClientCredentials {
94 pub client_id: String,
96 pub client_secret: ClientSecret,
98}
99
100impl ClientCredentials {
101 #[must_use]
102 pub fn new(client_id: impl Into<String>, client_secret: impl Into<ClientSecret>) -> Self {
104 Self {
105 client_id: client_id.into(),
106 client_secret: client_secret.into(),
107 }
108 }
109
110 #[must_use]
112 pub fn client_id(&self) -> &str {
113 &self.client_id
114 }
115
116 #[must_use]
118 pub const fn client_secret(&self) -> &ClientSecret {
119 &self.client_secret
120 }
121
122 pub async fn request_access_token(
128 &self,
129 auth_server: &AuthServer,
130 ) -> Result<SecretAccessToken, TokenError> {
131 let request = ClientCredentialsRequest::new(None);
132 let client = default_http_client()?;
133
134 let url = oidc::fetch_discovery(&client, &auth_server.issuer)
135 .await?
136 .token_endpoint;
137 let ready_to_send = client
138 .post(url)
139 .basic_auth(&auth_server.client_id, Some(&self.client_secret.secret()))
140 .form(&request);
141 let response = ready_to_send.send().await?;
142
143 response.error_for_status_ref()?;
144
145 let ClientCredentialsResponse { access_token } = response.json().await?;
146 Ok(access_token)
147 }
148}
149
/// Tokens obtained through an interactive PKCE login.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(eq, get_all, frozen, module = "qcs_api_client_common.configuration")
)]
pub struct PkceFlow {
    /// The most recently obtained access token.
    pub access_token: SecretAccessToken,
    /// The refresh token, if the server issued one; used to renew `access_token`.
    pub refresh_token: Option<RefreshToken>,
}

/// Errors that can occur while running [`PkceFlow::new_login_flow`].
#[derive(Debug, thiserror::Error)]
pub enum PkceFlowError {
    /// The PKCE login itself failed.
    #[error(transparent)]
    PkceLogin(#[from] PkceLoginError),
    /// Fetching the issuer's OIDC discovery document failed.
    #[error(transparent)]
    Discovery(#[from] DiscoveryError),
    /// An HTTP client error occurred (e.g. the client could not be built).
    #[error(transparent)]
    Request(#[from] reqwest::Error),
}
177
178impl PkceFlow {
179 pub async fn new_login_flow(
185 cancel_token: CancellationToken,
186 auth_server: &AuthServer,
187 ) -> Result<Self, PkceFlowError> {
188 let issuer = auth_server.issuer.clone();
189
190 let client = default_http_client()?;
191 let discovery = oidc::fetch_discovery(&client, &issuer).await?;
192
193 let response = pkce_login(
194 cancel_token,
195 PkceLoginRequest {
196 client_id: auth_server.client_id.clone(),
197 redirect_port: None,
198 discovery,
199 scopes: auth_server.scopes.clone(),
200 },
201 )
202 .await?;
203
204 Ok(Self {
205 access_token: SecretAccessToken::from(response.access_token().secret().clone()),
206 refresh_token: response
207 .refresh_token()
208 .map(|rt| RefreshToken::new(SecretRefreshToken::from(rt.secret().clone()))),
209 })
210 }
211
212 pub async fn request_access_token(
218 &mut self,
219 auth_server: &AuthServer,
220 ) -> Result<SecretAccessToken, TokenError> {
221 if insecure_validate_token_exp(&self.access_token).is_ok() {
222 return Ok(self.access_token.clone());
223 }
224
225 if let Some(refresh_token) = &mut self.refresh_token {
226 let access_token = refresh_token.request_access_token(auth_server).await?;
227 self.access_token.clone_from(&access_token);
228 return Ok(access_token);
229 }
230
231 Err(TokenError::NoRefreshToken)
232 }
233}
234
235impl From<PkceFlow> for Credential {
236 fn from(value: PkceFlow) -> Self {
237 let mut token_payload = TokenPayload::default();
238 token_payload.access_token = Some(value.access_token);
239 token_payload.refresh_token = value.refresh_token.map(|rt| rt.refresh_token);
240
241 Self {
242 token_payload: Some(token_payload),
243 }
244 }
245}
246
/// The OAuth2 grant an [`OAuthSession`] uses to obtain access tokens.
#[derive(Clone)]
#[cfg_attr(feature = "python", derive(pyo3::FromPyObject, pyo3::IntoPyObject))]
pub enum OAuthGrant {
    /// Exchange a refresh token for access tokens.
    RefreshToken(RefreshToken),
    /// Authenticate with a client ID and secret (client credentials grant).
    ClientCredentials(ClientCredentials),
    /// Delegate token acquisition to a caller-supplied refresh function.
    ExternallyManaged(ExternallyManaged),
    /// Use tokens from an interactive PKCE login, renewing them via its refresh token.
    PkceFlow(PkceFlow),
}
261
262impl From<ExternallyManaged> for OAuthGrant {
263 fn from(v: ExternallyManaged) -> Self {
264 Self::ExternallyManaged(v)
265 }
266}
267
268impl From<ClientCredentials> for OAuthGrant {
269 fn from(v: ClientCredentials) -> Self {
270 Self::ClientCredentials(v)
271 }
272}
273
274impl From<RefreshToken> for OAuthGrant {
275 fn from(v: RefreshToken) -> Self {
276 Self::RefreshToken(v)
277 }
278}
279
280impl From<PkceFlow> for OAuthGrant {
281 fn from(v: PkceFlow) -> Self {
282 Self::PkceFlow(v)
283 }
284}
285
286impl OAuthGrant {
287 async fn request_access_token(
289 &mut self,
290 auth_server: &AuthServer,
291 ) -> Result<SecretAccessToken, TokenError> {
292 match self {
293 Self::RefreshToken(tokens) => tokens.request_access_token(auth_server).await,
294 Self::ClientCredentials(tokens) => tokens.request_access_token(auth_server).await,
295 Self::ExternallyManaged(tokens) => tokens
296 .request_access_token(auth_server)
297 .await
298 .map_err(|e| TokenError::ExternallyManaged(e.to_string())),
299 Self::PkceFlow(tokens) => tokens.request_access_token(auth_server).await,
300 }
301 }
302}
303
304impl std::fmt::Debug for OAuthGrant {
305 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
306 match self {
307 Self::RefreshToken(_) => f.write_str("RefreshToken"),
308 Self::ClientCredentials(_) => f.write_str("ClientCredentials"),
309 Self::ExternallyManaged(_) => f.write_str("ExternallyManaged"),
310 Self::PkceFlow(_) => f.write_str("PkceTokens"),
311 }
312 }
313}
314
315#[derive(Clone)]
327#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
328#[cfg_attr(
329 feature = "python",
330 pyo3::pyclass(module = "qcs_api_client_common.configuration", frozen, get_all)
331)]
332pub struct OAuthSession {
333 payload: OAuthGrant,
335 access_token: Option<SecretAccessToken>,
337 auth_server: AuthServer,
339}
340
341impl OAuthSession {
342 #[must_use]
347 pub const fn new(
348 payload: OAuthGrant,
349 auth_server: AuthServer,
350 access_token: Option<SecretAccessToken>,
351 ) -> Self {
352 Self {
353 payload,
354 access_token,
355 auth_server,
356 }
357 }
358
359 #[must_use]
364 pub const fn from_externally_managed(
365 tokens: ExternallyManaged,
366 auth_server: AuthServer,
367 access_token: Option<SecretAccessToken>,
368 ) -> Self {
369 Self::new(
370 OAuthGrant::ExternallyManaged(tokens),
371 auth_server,
372 access_token,
373 )
374 }
375
376 #[must_use]
381 pub const fn from_refresh_token(
382 tokens: RefreshToken,
383 auth_server: AuthServer,
384 access_token: Option<SecretAccessToken>,
385 ) -> Self {
386 Self::new(OAuthGrant::RefreshToken(tokens), auth_server, access_token)
387 }
388
389 #[must_use]
394 pub const fn from_client_credentials(
395 tokens: ClientCredentials,
396 auth_server: AuthServer,
397 access_token: Option<SecretAccessToken>,
398 ) -> Self {
399 Self::new(
400 OAuthGrant::ClientCredentials(tokens),
401 auth_server,
402 access_token,
403 )
404 }
405
406 #[must_use]
411 pub const fn from_pkce_flow(
412 flow: PkceFlow,
413 auth_server: AuthServer,
414 access_token: Option<SecretAccessToken>,
415 ) -> Self {
416 Self::new(OAuthGrant::PkceFlow(flow), auth_server, access_token)
417 }
418
419 pub fn access_token(&self) -> Result<&SecretAccessToken, TokenError> {
428 self.access_token.as_ref().ok_or(TokenError::NoAccessToken)
429 }
430
431 #[must_use]
433 pub const fn payload(&self) -> &OAuthGrant {
434 &self.payload
435 }
436
437 #[allow(clippy::missing_panics_doc)]
443 pub async fn request_access_token(&mut self) -> Result<&SecretAccessToken, TokenError> {
444 let access_token = self.payload.request_access_token(&self.auth_server).await?;
445 Ok(self.access_token.insert(access_token))
446 }
447
448 #[must_use]
450 pub const fn auth_server(&self) -> &AuthServer {
451 &self.auth_server
452 }
453
454 pub fn validate(&self) -> Result<SecretAccessToken, TokenError> {
462 let access_token = self.access_token()?;
463 insecure_validate_token_exp(access_token)?;
464 Ok(access_token.clone())
465 }
466}
467
468pub(crate) fn insecure_validate_token_exp(
472 access_token: &SecretAccessToken,
473) -> Result<(), TokenError> {
474 let placeholder_key = DecodingKey::from_secret(&[]);
475 let mut validation = Validation::new(Algorithm::RS256);
476 validation.validate_exp = true;
477 validation.leeway = 60;
478 validation.validate_aud = false;
479 validation.insecure_disable_signature_validation();
480
481 jsonwebtoken::decode::<toml::Value>(access_token.secret(), &placeholder_key, &validation)
482 .map(|_| ())
483 .map_err(TokenError::InvalidAccessToken)
484}
485
486impl std::fmt::Debug for OAuthSession {
487 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
488 let token_populated = if self.access_token.is_some() {
489 Some(())
490 } else {
491 None
492 };
493 f.debug_struct("OAuthSession")
494 .field("payload", &self.payload)
495 .field("access_token", &token_populated)
496 .field("auth_server", &self.auth_server)
497 .finish()
498 }
499}
500
501#[derive(Clone, Debug)]
503#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
504#[cfg_attr(
505 feature = "python",
506 pyo3::pyclass(module = "qcs_api_client_common.configuration")
507)]
508pub struct TokenDispatcher {
509 lock: Arc<RwLock<OAuthSession>>,
510 refreshing: Arc<Mutex<bool>>,
511 notify_refreshed: Arc<Notify>,
512}
513
514impl From<OAuthSession> for TokenDispatcher {
515 fn from(value: OAuthSession) -> Self {
516 Self {
517 lock: Arc::new(RwLock::new(value)),
518 refreshing: Arc::new(Mutex::new(false)),
519 notify_refreshed: Arc::new(Notify::new()),
520 }
521 }
522}
523
524impl TokenDispatcher {
525 pub async fn use_tokens<F, O>(&self, f: F) -> O
535 where
536 F: FnOnce(&OAuthSession) -> O + Send,
537 {
538 let tokens = self.lock.read().await;
539 f(&tokens)
540 }
541
542 #[must_use]
544 pub async fn tokens(&self) -> OAuthSession {
545 self.use_tokens(Clone::clone).await
546 }
547
548 pub async fn refresh(
554 &self,
555 source: &ConfigSource,
556 profile: &str,
557 ) -> Result<OAuthSession, TokenError> {
558 self.managed_refresh(Self::perform_refresh, source, profile)
559 .await
560 }
561
562 pub async fn validate(&self) -> Result<SecretAccessToken, TokenError> {
570 self.use_tokens(OAuthSession::validate).await
571 }
572
573 async fn managed_refresh<F, Fut>(
576 &self,
577 refresh_fn: F,
578 source: &ConfigSource,
579 profile: &str,
580 ) -> Result<OAuthSession, TokenError>
581 where
582 F: FnOnce(Arc<RwLock<OAuthSession>>) -> Fut + Send,
583 Fut: Future<Output = Result<OAuthSession, TokenError>> + Send,
584 {
585 let mut is_refreshing = self.refreshing.lock().await;
586
587 if *is_refreshing {
588 drop(is_refreshing);
589 self.notify_refreshed.notified().await;
590 return Ok(self.tokens().await);
591 }
592
593 *is_refreshing = true;
594 drop(is_refreshing);
595
596 let oauth_session = refresh_fn(self.lock.clone()).await?;
597
598 if let ConfigSource::File {
600 settings_path: _,
601 secrets_path,
602 } = source
603 {
604 if !Secrets::is_read_only(secrets_path).await? {
605 let refresh_token = match &oauth_session.payload {
607 OAuthGrant::PkceFlow(payload) => {
608 payload.refresh_token.as_ref().map(|rt| &rt.refresh_token)
609 }
610 _ => None,
611 };
612
613 let now = OffsetDateTime::now_utc();
614 Secrets::write_tokens(
615 secrets_path,
616 profile,
617 refresh_token,
618 oauth_session.access_token()?,
619 now,
620 )
621 .await?;
622 }
623 }
624
625 *self.refreshing.lock().await = false;
626 self.notify_refreshed.notify_waiters();
627 Ok(oauth_session)
628 }
629
630 async fn perform_refresh(lock: Arc<RwLock<OAuthSession>>) -> Result<OAuthSession, TokenError> {
637 let mut credentials = lock.write().await;
638 credentials.request_access_token().await?;
639 Ok(credentials.clone())
640 }
641}
642
/// The boxed, `Send` future returned by a [`RefreshFunction`]. It resolves to a new access
/// token string or a boxed error.
pub(crate) type RefreshResult =
    Pin<Box<dyn Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>> + Send>>;

/// A caller-supplied function that, given the [`AuthServer`], produces a new access token.
pub type RefreshFunction = Box<dyn (Fn(AuthServer) -> RefreshResult) + Send + Sync>;
648
649#[derive(Clone)]
654#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
655#[cfg_attr(
656 feature = "python",
657 pyo3::pyclass(module = "qcs_api_client_common.configuration")
658)]
659pub struct ExternallyManaged {
660 refresh_function: Arc<RefreshFunction>,
661}
662
663impl ExternallyManaged {
664 pub fn new(
689 refresh_function: impl Fn(AuthServer) -> RefreshResult + Send + Sync + 'static,
690 ) -> Self {
691 Self {
692 refresh_function: Arc::new(Box::new(refresh_function)),
693 }
694 }
695
696 pub fn from_async<F, Fut>(refresh_function: F) -> Self
729 where
730 F: Fn(AuthServer) -> Fut + Send + Sync + 'static,
731 Fut: Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>>
732 + Send
733 + 'static,
734 {
735 Self {
736 refresh_function: Arc::new(Box::new(move |auth_server| {
737 Box::pin(refresh_function(auth_server))
738 })),
739 }
740 }
741
742 pub fn from_sync(
773 refresh_function: impl Fn(AuthServer) -> Result<String, Box<dyn std::error::Error + Send + Sync>>
774 + Send
775 + Sync
776 + 'static,
777 ) -> Self {
778 Self {
779 refresh_function: Arc::new(Box::new(move |auth_server| {
780 let result = refresh_function(auth_server);
781 Box::pin(async move { result })
782 })),
783 }
784 }
785
786 pub async fn request_access_token(
792 &self,
793 auth_server: &AuthServer,
794 ) -> Result<SecretAccessToken, Box<dyn std::error::Error + Send + Sync>> {
795 (self.refresh_function)(auth_server.clone())
796 .await
797 .map(SecretAccessToken::from)
798 }
799}
800
801impl std::fmt::Debug for ExternallyManaged {
802 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
803 f.debug_struct("ExternallyManaged")
804 .field(
805 "refresh_function",
806 &"Fn() -> Pin<Box<dyn Future<Output = Result<String, TokenError>> + Send>>",
807 )
808 .finish()
809 }
810}
811
812#[derive(Debug, Serialize, Deserialize)]
813pub(super) struct TokenRefreshRequest<'a> {
814 grant_type: &'static str,
815 client_id: &'a str,
816 refresh_token: &'a str,
817}
818
819impl<'a> TokenRefreshRequest<'a> {
820 pub(super) const fn new(client_id: &'a str, refresh_token: &'a str) -> Self {
821 Self {
822 grant_type: "refresh_token",
823 client_id,
824 refresh_token,
825 }
826 }
827}
828
829#[derive(Debug, Serialize, Deserialize)]
830pub(super) struct ClientCredentialsRequest {
831 grant_type: &'static str,
832 scope: Option<&'static str>,
833}
834
835impl ClientCredentialsRequest {
836 pub(super) const fn new(scope: Option<&'static str>) -> Self {
837 Self {
838 grant_type: "client_credentials",
839 scope,
840 }
841 }
842}
843
/// JSON body returned by the token endpoint for a refresh token grant.
#[derive(Deserialize, Debug, Serialize)]
pub(super) struct RefreshTokenResponse {
    /// A replacement refresh token, if the server returned one.
    pub(super) refresh_token: Option<SecretRefreshToken>,
    /// The newly issued access token.
    pub(super) access_token: SecretAccessToken,
}
849
/// A source of OAuth2 access tokens that can hand out, validate, and refresh them.
#[async_trait::async_trait]
pub trait TokenRefresher: Clone + std::fmt::Debug + Send {
    /// The error type returned by this refresher's token operations.
    type Error;

    /// Returns an access token that the implementation considers valid for use.
    async fn validated_access_token(&self) -> Result<SecretAccessToken, Self::Error>;

    /// Returns the current access token, or `None` if the implementation has none.
    async fn get_access_token(&self) -> Result<Option<SecretAccessToken>, Self::Error>;

    /// Obtains a new access token and returns it.
    async fn refresh_access_token(&self) -> Result<SecretAccessToken, Self::Error>;

    /// The base URL associated with this refresher, used by tracing.
    #[cfg(feature = "tracing")]
    fn base_url(&self) -> &str;

    /// The tracing configuration, if any.
    #[cfg(feature = "tracing-config")]
    fn tracing_configuration(&self) -> Option<&TracingConfiguration>;

    /// Whether a request to `url` should be traced.
    ///
    /// Without the `tracing-config` feature every request is traced. With it, a request is
    /// traced when no tracing configuration is set or when the configuration enables `url`.
    #[cfg(feature = "tracing")]
    #[allow(clippy::needless_return)]
    fn should_trace(&self, url: &UrlPatternMatchInput) -> bool {
        // The two `cfg` branches below are mutually exclusive; exactly one is compiled in.
        #[cfg(not(feature = "tracing-config"))]
        {
            let _ = url;
            return true;
        }

        #[cfg(feature = "tracing-config")]
        self.tracing_configuration()
            .is_none_or(|config| config.is_enabled(url))
    }
}
890
891#[async_trait::async_trait]
892impl TokenRefresher for ClientConfiguration {
893 type Error = TokenError;
894
895 async fn validated_access_token(&self) -> Result<SecretAccessToken, Self::Error> {
896 self.get_bearer_access_token().await
897 }
898
899 async fn refresh_access_token(&self) -> Result<SecretAccessToken, Self::Error> {
900 Ok(self.refresh().await?.access_token()?.clone())
901 }
902
903 async fn get_access_token(&self) -> Result<Option<SecretAccessToken>, Self::Error> {
904 Ok(Some(self.oauth_session().await?.access_token()?.clone()))
905 }
906
907 #[cfg(feature = "tracing")]
908 fn base_url(&self) -> &str {
909 &self.grpc_api_url
910 }
911
912 #[cfg(feature = "tracing-config")]
913 fn tracing_configuration(&self) -> Option<&TracingConfiguration> {
914 self.tracing_configuration.as_ref()
915 }
916}
917
918pub(super) fn default_http_client() -> Result<reqwest::Client, reqwest::Error> {
920 reqwest::Client::builder()
921 .timeout(std::time::Duration::from_secs(10))
922 .build()
923}
924
#[cfg(test)]
mod test {
    #![allow(clippy::result_large_err, reason = "happens in figment tests")]

    use std::time::Duration;

    use super::*;
    use httpmock::prelude::*;
    use rstest::rstest;
    use time::format_description::well_known::Rfc3339;
    use tokio::time::Instant;
    use toml_edit::DocumentMut;

    /// A refresh holds the session's write lock for the whole (deliberately delayed) token
    /// request. A read started during the refresh must therefore wait for it and observe the
    /// refreshed tokens.
    #[tokio::test]
    async fn test_tokens_blocked_during_refresh() {
        let mock_server = MockServer::start_async().await;

        let oidc_mock = mock_server
            .mock_async(|when, then| {
                when.method(GET).path("/.well-known/openid-configuration");
                then.status(200)
                    .json_body_obj(&oidc::Discovery::new_for_test(
                        mock_server.base_url().parse().unwrap(),
                    ));
            })
            .await;

        let issuer_mock = mock_server
            .mock_async(|when, then| {
                when.method(POST).path("/v1/token");

                then.status(200)
                    .delay(Duration::from_secs(3))
                    .json_body_obj(&RefreshTokenResponse {
                        access_token: SecretAccessToken::from("new_access"),
                        refresh_token: Some(SecretRefreshToken::from("new_refresh")),
                    });
            })
            .await;

        let original_tokens = OAuthSession::from_refresh_token(
            RefreshToken::new(SecretRefreshToken::from("refresh")),
            AuthServer {
                client_id: "client_id".to_string(),
                issuer: mock_server.base_url(),
                scopes: None,
            },
            None,
        );
        let dispatcher: TokenDispatcher = original_tokens.clone().into();
        let dispatcher_clone1 = dispatcher.clone();
        let dispatcher_clone2 = dispatcher.clone();

        // Must match the delay configured on the token endpoint mock above.
        let refresh_duration = Duration::from_secs(3);

        let start_write = Instant::now();
        let write_future = tokio::spawn(async move {
            dispatcher_clone1
                .refresh(&ConfigSource::Default, "")
                .await
                .unwrap()
        });

        let start_read = Instant::now();
        let read_future = tokio::spawn(async move { dispatcher_clone2.tokens().await });

        let _ = write_future.await.unwrap();
        let read_result = read_future.await.unwrap();

        let write_duration = start_write.elapsed();
        let read_duration = start_read.elapsed();

        oidc_mock.assert_async().await;
        issuer_mock.assert_async().await;

        assert!(
            write_duration >= refresh_duration,
            "Write operation did not take enough time"
        );
        assert!(
            read_duration >= refresh_duration,
            "Read operation was not blocked by the write operation"
        );
        assert_eq!(
            read_result.access_token.unwrap(),
            SecretAccessToken::from("new_access")
        );
        if let OAuthGrant::RefreshToken(payload) = read_result.payload {
            assert_eq!(
                payload.refresh_token,
                SecretRefreshToken::from("new_refresh")
            );
        } else {
            panic!(
                "Expected RefreshToken payload, got {:?}",
                read_result.payload
            );
        }
    }

    /// Refreshing with a file config source rewrites the secrets file unless
    /// `QCS_SECRETS_READ_ONLY` is set to a truthy value or the file itself is read-only.
    #[rstest]
    fn test_qcs_secrets_readonly(
        #[values(
            (Some("TRUE"), true),
            (Some("tRue"), true),
            (Some("true"), true),
            (Some("YES"), true),
            (Some("yEs"), true),
            (Some("yes"), true),
            (Some("1"), true),
            (Some("2"), false),
            (Some("other"), false),
            (Some(""), false),
            (None, false),
        )]
        read_only_values: (Option<&str>, bool),
        #[values(true, false)] read_only_perm: bool,
    ) {
        let (maybe_read_only_env, env_is_read_only) = read_only_values;
        let expected_update = !env_is_read_only && !read_only_perm;
        figment::Jail::expect_with(|jail| {
            let profile_name = "test";
            let initial_access_token = "initial_access_token";
            let initial_refresh_token = "initial_refresh_token";

            let initial_secrets_file_contents = format!(
                r#"
[credentials]
[credentials.{profile_name}]
[credentials.{profile_name}.token_payload]
access_token = "{initial_access_token}"
expires_in = 3600
id_token = "id_token"
refresh_token = "{initial_refresh_token}"
scope = "offline_access openid profile email"
token_type = "Bearer"
updated_at = "2024-01-01T00:00:00Z"
"#
            );

            jail.clear_env();

            let secrets_path = "secrets.toml";
            jail.create_file(secrets_path, initial_secrets_file_contents.as_str())
                .expect("should create test secrets.toml");

            if read_only_perm {
                let mut permissions = std::fs::metadata(secrets_path)
                    .expect("Should be able to get file metadata")
                    .permissions();
                permissions.set_readonly(true);
                std::fs::set_permissions(secrets_path, permissions)
                    .expect("Should be able to set file permissions");
            }

            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async {
                let mock_server = MockServer::start_async().await;

                let oidc_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(GET).path("/.well-known/openid-configuration");
                        then.status(200)
                            .json_body_obj(&oidc::Discovery::new_for_test(mock_server.base_url().parse().unwrap()));
                    })
                    .await;

                let new_access_token = SecretAccessToken::from("new_access_token");
                let issuer_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(POST).path("/v1/token");
                        then.status(200).json_body_obj(&RefreshTokenResponse {
                            access_token: new_access_token.clone(),
                            refresh_token: Some(SecretRefreshToken::from(initial_refresh_token)),
                        });
                    })
                    .await;

                // Seed the session with the same tokens that are in the secrets file.
                let original_tokens = OAuthSession::from_refresh_token(
                    RefreshToken::new(SecretRefreshToken::from(initial_refresh_token)),
                    AuthServer { client_id: "client_id".to_string(), issuer: mock_server.base_url(), scopes: None },
                    Some(SecretAccessToken::from(initial_access_token)),
                );
                let dispatcher: TokenDispatcher = original_tokens.into();

                jail.set_env("QCS_SECRETS_FILE_PATH", "secrets.toml");
                jail.set_env("QCS_PROFILE_NAME", "test");
                if let Some(read_only_env) = maybe_read_only_env {
                    jail.set_env("QCS_SECRETS_READ_ONLY", read_only_env);
                }

                let before_refresh = OffsetDateTime::now_utc();

                dispatcher
                    .refresh(
                        &ConfigSource::File {
                            settings_path: "".into(),
                            secrets_path: "secrets.toml".into(),
                        },
                        profile_name,
                    )
                    .await
                    .unwrap();

                oidc_mock.assert_async().await;
                issuer_mock.assert_async().await;

                let content = std::fs::read_to_string("secrets.toml").unwrap();
                if !expected_update {
                    assert!(
                        content.eq(initial_secrets_file_contents.as_str()),
                        "File should not be updated when QCS_SECRETS_READ_ONLY is set or file permissions are read-only"
                    );
                    return;
                }

                let mut toml = std::fs::read_to_string(secrets_path)
                    .unwrap()
                    .parse::<DocumentMut>()
                    .unwrap();

                let token_payload = toml
                    .get_mut("credentials")
                    .and_then(|credentials| {
                        credentials.get_mut(profile_name)?.get_mut("token_payload")
                    })
                    .expect("Should be able to get token_payload table");

                let access_token = token_payload.get("access_token").unwrap().as_str().map(str::to_string).map(SecretAccessToken::from);

                assert_eq!(
                    access_token,
                    Some(new_access_token)
                );

                assert!(
                    OffsetDateTime::parse(
                        token_payload.get("updated_at").unwrap().as_str().unwrap(),
                        &Rfc3339
                    )
                    .unwrap()
                        > before_refresh
                );

                let content = std::fs::read_to_string("secrets.toml").unwrap();
                assert!(
                    content.contains("new_access_token"),
                    "File should be updated with new access token when QCS_SECRETS_READ_ONLY is not set or is set but disabled, and file permissions allow writing"
                );
            });
            Ok(())
        });
    }

    /// `Debug` for `OAuthSession` must not leak the grant's credentials or the access token.
    #[test]
    fn test_auth_session_debug_fmt() {
        let session = OAuthSession {
            payload: OAuthGrant::ClientCredentials(ClientCredentials::new(
                "hidden_id",
                "hidden_secret",
            )),
            access_token: Some(SecretAccessToken::from("token")),
            auth_server: AuthServer {
                client_id: "some_id".into(),
                issuer: "some_url".into(),
                scopes: None,
            },
        };

        assert_eq!("OAuthSession { payload: ClientCredentials, access_token: Some(()), auth_server: AuthServer { client_id: \"some_id\", issuer: \"some_url\", scopes: None } }", &format!("{session:?}"));
    }
}