1use std::{pin::Pin, sync::Arc};
3
4use futures::Future;
5use jsonwebtoken::{Algorithm, DecodingKey, Validation};
6use oauth2::TokenResponse;
7use serde::{Deserialize, Serialize};
8use time::OffsetDateTime;
9use tokio::sync::{Mutex, Notify, RwLock};
10use tokio_util::sync::CancellationToken;
11
12#[cfg(feature = "stubs")]
13use pyo3_stub_gen::derive::gen_stub_pyclass;
14
15use super::{
16 oidc, secrets::Secrets, settings::AuthServer, ClientConfiguration, ConfigSource, TokenError,
17};
18use crate::configuration::{
19 error::DiscoveryError,
20 pkce::{pkce_login, PkceLoginError, PkceLoginRequest},
21 secrets::{Credential, SecretAccessToken, SecretRefreshToken, TokenPayload},
22};
23#[cfg(feature = "tracing-config")]
24use crate::tracing_configuration::TracingConfiguration;
25#[cfg(feature = "tracing")]
26use urlpattern::UrlPatternMatchInput;
27
28pub use super::secret_string::ClientSecret;
29
30#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
32#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
33#[cfg_attr(
34 feature = "python",
35 pyo3::pyclass(eq, get_all, set_all, module = "qcs_api_client_common.configuration")
36)]
37pub struct RefreshToken {
38 pub refresh_token: SecretRefreshToken,
40}
41
42impl RefreshToken {
43 #[must_use]
45 pub const fn new(refresh_token: SecretRefreshToken) -> Self {
46 Self { refresh_token }
47 }
48
49 pub async fn request_access_token(
55 &mut self,
56 auth_server: &AuthServer,
57 ) -> Result<SecretAccessToken, TokenError> {
58 if self.refresh_token.is_empty() {
59 return Err(TokenError::NoRefreshToken);
60 }
61
62 let client = default_http_client()?;
63 let token_url = oidc::fetch_discovery(&client, &auth_server.issuer)
64 .await?
65 .token_endpoint;
66 let data = TokenRefreshRequest::new(&auth_server.client_id, self.refresh_token.secret());
67 let resp = client.post(token_url).form(&data).send().await?;
68
69 let RefreshTokenResponse {
70 access_token,
71 refresh_token,
72 } = resp.error_for_status()?.json().await?;
73
74 if let Some(refresh_token) = refresh_token {
75 self.refresh_token = refresh_token;
76 }
77 Ok(access_token)
78 }
79}
80
/// Body of a successful `client_credentials` token response.
///
/// Only the access token is captured.
#[derive(Deserialize, Debug, Serialize)]
pub(super) struct ClientCredentialsResponse {
    /// The issued access token.
    pub(super) access_token: SecretAccessToken,
}
85
86#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
88#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
89#[cfg_attr(
90 feature = "python",
91 pyo3::pyclass(eq, get_all, frozen, module = "qcs_api_client_common.configuration")
92)]
93pub struct ClientCredentials {
94 pub client_id: String,
96 pub client_secret: ClientSecret,
98}
99
100impl ClientCredentials {
101 #[must_use]
102 pub fn new(client_id: impl Into<String>, client_secret: impl Into<ClientSecret>) -> Self {
104 Self {
105 client_id: client_id.into(),
106 client_secret: client_secret.into(),
107 }
108 }
109
110 #[must_use]
112 pub fn client_id(&self) -> &str {
113 &self.client_id
114 }
115
116 #[must_use]
118 pub const fn client_secret(&self) -> &ClientSecret {
119 &self.client_secret
120 }
121
122 pub async fn request_access_token(
128 &self,
129 auth_server: &AuthServer,
130 ) -> Result<SecretAccessToken, TokenError> {
131 let request = ClientCredentialsRequest::new(None);
132 let client = default_http_client()?;
133
134 let url = oidc::fetch_discovery(&client, &auth_server.issuer)
135 .await?
136 .token_endpoint;
137 let ready_to_send = client
138 .post(url)
139 .basic_auth(&auth_server.client_id, Some(&self.client_secret.secret()))
140 .form(&request);
141 let response = ready_to_send.send().await?;
142
143 response.error_for_status_ref()?;
144
145 let ClientCredentialsResponse { access_token } = response.json().await?;
146 Ok(access_token)
147 }
148}
149
150#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
151#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
152#[cfg_attr(
153 feature = "python",
154 pyo3::pyclass(eq, get_all, frozen, module = "qcs_api_client_common.configuration")
155)]
156pub struct PkceFlow {
158 pub access_token: SecretAccessToken,
160 pub refresh_token: Option<RefreshToken>,
162}
163
164#[derive(Debug, thiserror::Error)]
166pub enum PkceFlowError {
167 #[error(transparent)]
169 PkceLogin(#[from] PkceLoginError),
170 #[error(transparent)]
172 Discovery(#[from] DiscoveryError),
173 #[error(transparent)]
175 Request(#[from] reqwest::Error),
176}
177
178impl PkceFlow {
179 pub async fn new_login_flow(
185 cancel_token: CancellationToken,
186 auth_server: &AuthServer,
187 ) -> Result<Self, PkceFlowError> {
188 let issuer = auth_server.issuer.clone();
189
190 let client = default_http_client()?;
191 let discovery = oidc::fetch_discovery(&client, &issuer).await?;
192
193 let response = pkce_login(
194 cancel_token,
195 PkceLoginRequest {
196 client_id: auth_server.client_id.clone(),
197 redirect_port: None,
198 discovery,
199 scopes: auth_server.scopes.clone(),
200 },
201 )
202 .await?;
203
204 Ok(Self {
205 access_token: SecretAccessToken::from(response.access_token().secret().clone()),
206 refresh_token: response
207 .refresh_token()
208 .map(|rt| RefreshToken::new(SecretRefreshToken::from(rt.secret().clone()))),
209 })
210 }
211
212 pub async fn request_access_token(
218 &mut self,
219 auth_server: &AuthServer,
220 ) -> Result<SecretAccessToken, TokenError> {
221 if insecure_validate_token_exp(&self.access_token).is_ok() {
222 return Ok(self.access_token.clone());
223 }
224
225 if let Some(refresh_token) = &mut self.refresh_token {
226 let access_token = refresh_token.request_access_token(auth_server).await?;
227 self.access_token.clone_from(&access_token);
228 return Ok(access_token);
229 }
230
231 Err(TokenError::NoRefreshToken)
232 }
233}
234
235impl From<PkceFlow> for Credential {
236 fn from(value: PkceFlow) -> Self {
237 let mut token_payload = TokenPayload::default();
238 token_payload.access_token = Some(value.access_token);
239 token_payload.refresh_token = value.refresh_token.map(|rt| rt.refresh_token);
240
241 Self {
242 token_payload: Some(token_payload),
243 }
244 }
245}
246
247#[derive(Clone)]
248#[cfg_attr(feature = "python", derive(pyo3::FromPyObject, pyo3::IntoPyObject))]
249pub enum OAuthGrant {
252 RefreshToken(RefreshToken),
254 ClientCredentials(ClientCredentials),
256 ExternallyManaged(ExternallyManaged),
258 PkceFlow(PkceFlow),
260}
261
262impl From<ExternallyManaged> for OAuthGrant {
263 fn from(v: ExternallyManaged) -> Self {
264 Self::ExternallyManaged(v)
265 }
266}
267
268impl From<ClientCredentials> for OAuthGrant {
269 fn from(v: ClientCredentials) -> Self {
270 Self::ClientCredentials(v)
271 }
272}
273
274impl From<RefreshToken> for OAuthGrant {
275 fn from(v: RefreshToken) -> Self {
276 Self::RefreshToken(v)
277 }
278}
279
280impl From<PkceFlow> for OAuthGrant {
281 fn from(v: PkceFlow) -> Self {
282 Self::PkceFlow(v)
283 }
284}
285
286impl OAuthGrant {
287 async fn request_access_token(
289 &mut self,
290 auth_server: &AuthServer,
291 ) -> Result<SecretAccessToken, TokenError> {
292 match self {
293 Self::RefreshToken(tokens) => tokens.request_access_token(auth_server).await,
294 Self::ClientCredentials(tokens) => tokens.request_access_token(auth_server).await,
295 Self::ExternallyManaged(tokens) => tokens
296 .request_access_token(auth_server)
297 .await
298 .map_err(|e| TokenError::ExternallyManaged(e.to_string())),
299 Self::PkceFlow(tokens) => tokens.request_access_token(auth_server).await,
300 }
301 }
302}
303
304impl std::fmt::Debug for OAuthGrant {
305 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
306 match self {
307 Self::RefreshToken(_) => f.write_str("RefreshToken"),
308 Self::ClientCredentials(_) => f.write_str("ClientCredentials"),
309 Self::ExternallyManaged(_) => f.write_str("ExternallyManaged"),
310 Self::PkceFlow(_) => f.write_str("PkceTokens"),
311 }
312 }
313}
314
315#[derive(Clone)]
327#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
328#[cfg_attr(
329 feature = "python",
330 pyo3::pyclass(module = "qcs_api_client_common.configuration", frozen, get_all)
331)]
332pub struct OAuthSession {
333 payload: OAuthGrant,
335 access_token: Option<SecretAccessToken>,
337 auth_server: AuthServer,
339}
340
341impl OAuthSession {
342 #[must_use]
347 pub const fn new(
348 payload: OAuthGrant,
349 auth_server: AuthServer,
350 access_token: Option<SecretAccessToken>,
351 ) -> Self {
352 Self {
353 payload,
354 access_token,
355 auth_server,
356 }
357 }
358
359 #[must_use]
364 pub const fn from_externally_managed(
365 tokens: ExternallyManaged,
366 auth_server: AuthServer,
367 access_token: Option<SecretAccessToken>,
368 ) -> Self {
369 Self::new(
370 OAuthGrant::ExternallyManaged(tokens),
371 auth_server,
372 access_token,
373 )
374 }
375
376 #[must_use]
381 pub const fn from_refresh_token(
382 tokens: RefreshToken,
383 auth_server: AuthServer,
384 access_token: Option<SecretAccessToken>,
385 ) -> Self {
386 Self::new(OAuthGrant::RefreshToken(tokens), auth_server, access_token)
387 }
388
389 #[must_use]
394 pub const fn from_client_credentials(
395 tokens: ClientCredentials,
396 auth_server: AuthServer,
397 access_token: Option<SecretAccessToken>,
398 ) -> Self {
399 Self::new(
400 OAuthGrant::ClientCredentials(tokens),
401 auth_server,
402 access_token,
403 )
404 }
405
406 #[must_use]
411 pub const fn from_pkce_flow(
412 flow: PkceFlow,
413 auth_server: AuthServer,
414 access_token: Option<SecretAccessToken>,
415 ) -> Self {
416 Self::new(OAuthGrant::PkceFlow(flow), auth_server, access_token)
417 }
418
419 pub fn access_token(&self) -> Result<&SecretAccessToken, TokenError> {
428 self.access_token.as_ref().ok_or(TokenError::NoAccessToken)
429 }
430
431 #[must_use]
433 pub const fn payload(&self) -> &OAuthGrant {
434 &self.payload
435 }
436
437 #[allow(clippy::missing_panics_doc)]
443 pub async fn request_access_token(&mut self) -> Result<&SecretAccessToken, TokenError> {
444 let access_token = self.payload.request_access_token(&self.auth_server).await?;
445 Ok(self.access_token.insert(access_token))
446 }
447
448 #[must_use]
450 pub const fn auth_server(&self) -> &AuthServer {
451 &self.auth_server
452 }
453
454 pub fn validate(&self) -> Result<SecretAccessToken, TokenError> {
462 let access_token = self.access_token()?;
463 insecure_validate_token_exp(access_token)?;
464 Ok(access_token.clone())
465 }
466}
467
468pub(crate) fn insecure_validate_token_exp(
472 access_token: &SecretAccessToken,
473) -> Result<(), TokenError> {
474 let placeholder_key = DecodingKey::from_secret(&[]);
475 let mut validation = Validation::new(Algorithm::RS256);
476 validation.validate_exp = true;
477 validation.leeway = 60;
478 validation.validate_aud = false;
479 validation.insecure_disable_signature_validation();
480
481 jsonwebtoken::decode::<toml::Value>(access_token.secret(), &placeholder_key, &validation)
482 .map(|_| ())
483 .map_err(TokenError::InvalidAccessToken)
484}
485
486impl std::fmt::Debug for OAuthSession {
487 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
488 let token_populated = if self.access_token.is_some() {
489 Some(())
490 } else {
491 None
492 };
493 f.debug_struct("OAuthSession")
494 .field("payload", &self.payload)
495 .field("access_token", &token_populated)
496 .field("auth_server", &self.auth_server)
497 .finish()
498 }
499}
500
501#[derive(Clone, Debug)]
503#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
504#[cfg_attr(
505 feature = "python",
506 pyo3::pyclass(module = "qcs_api_client_common.configuration")
507)]
508pub struct TokenDispatcher {
509 lock: Arc<RwLock<OAuthSession>>,
510 refreshing: Arc<Mutex<bool>>,
511 notify_refreshed: Arc<Notify>,
512}
513
514impl From<OAuthSession> for TokenDispatcher {
515 fn from(value: OAuthSession) -> Self {
516 Self {
517 lock: Arc::new(RwLock::new(value)),
518 refreshing: Arc::new(Mutex::new(false)),
519 notify_refreshed: Arc::new(Notify::new()),
520 }
521 }
522}
523
524impl TokenDispatcher {
525 pub async fn use_tokens<F, O>(&self, f: F) -> O
535 where
536 F: FnOnce(&OAuthSession) -> O + Send,
537 {
538 let tokens = self.lock.read().await;
539 f(&tokens)
540 }
541
542 #[must_use]
544 pub async fn tokens(&self) -> OAuthSession {
545 self.use_tokens(Clone::clone).await
546 }
547
548 pub async fn refresh(
554 &self,
555 source: &ConfigSource,
556 profile: &str,
557 ) -> Result<OAuthSession, TokenError> {
558 self.managed_refresh(Self::perform_refresh, source, profile)
559 .await
560 }
561
562 pub async fn validate(&self) -> Result<SecretAccessToken, TokenError> {
570 self.use_tokens(OAuthSession::validate).await
571 }
572
573 async fn managed_refresh<F, Fut>(
576 &self,
577 refresh_fn: F,
578 source: &ConfigSource,
579 profile: &str,
580 ) -> Result<OAuthSession, TokenError>
581 where
582 F: FnOnce(Arc<RwLock<OAuthSession>>) -> Fut + Send,
583 Fut: Future<Output = Result<OAuthSession, TokenError>> + Send,
584 {
585 let mut is_refreshing = self.refreshing.lock().await;
586
587 if *is_refreshing {
588 drop(is_refreshing);
589 self.notify_refreshed.notified().await;
590 return Ok(self.tokens().await);
591 }
592
593 *is_refreshing = true;
594 drop(is_refreshing);
595
596 let oauth_session = refresh_fn(self.lock.clone()).await?;
597
598 if let ConfigSource::File {
600 settings_path: _,
601 secrets_path,
602 } = source
603 {
604 if !Secrets::is_read_only(secrets_path).await? {
605 let refresh_token = match &oauth_session.payload {
607 OAuthGrant::PkceFlow(payload) => {
608 payload.refresh_token.as_ref().map(|rt| &rt.refresh_token)
609 }
610 _ => None,
611 };
612
613 let now = OffsetDateTime::now_utc();
614 Secrets::write_tokens(
615 secrets_path,
616 profile,
617 refresh_token,
618 oauth_session.access_token()?,
619 now,
620 )
621 .await?;
622 }
623 }
624
625 *self.refreshing.lock().await = false;
626 self.notify_refreshed.notify_waiters();
627 Ok(oauth_session)
628 }
629
630 async fn perform_refresh(lock: Arc<RwLock<OAuthSession>>) -> Result<OAuthSession, TokenError> {
637 let mut credentials = lock.write().await;
638 credentials.request_access_token().await?;
639 Ok(credentials.clone())
640 }
641}
642
/// The boxed future produced by a [`RefreshFunction`].
///
/// It resolves to a new access token string or an error.
pub(crate) type RefreshResult =
    Pin<Box<dyn Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>> + Send>>;

/// A caller-supplied function that, given the auth server configuration, produces a new
/// access token.
pub type RefreshFunction = Box<dyn (Fn(AuthServer) -> RefreshResult) + Send + Sync>;
648
649#[derive(Clone)]
654#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
655#[cfg_attr(
656 feature = "python",
657 pyo3::pyclass(module = "qcs_api_client_common.configuration")
658)]
659pub struct ExternallyManaged {
660 refresh_function: Arc<RefreshFunction>,
661}
662
663impl ExternallyManaged {
664 pub fn new(
689 refresh_function: impl Fn(AuthServer) -> RefreshResult + Send + Sync + 'static,
690 ) -> Self {
691 Self {
692 refresh_function: Arc::new(Box::new(refresh_function)),
693 }
694 }
695
696 pub fn from_async<F, Fut>(refresh_function: F) -> Self
729 where
730 F: Fn(AuthServer) -> Fut + Send + Sync + 'static,
731 Fut: Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>>
732 + Send
733 + 'static,
734 {
735 Self {
736 refresh_function: Arc::new(Box::new(move |auth_server| {
737 Box::pin(refresh_function(auth_server))
738 })),
739 }
740 }
741
742 pub fn from_sync(
773 refresh_function: impl Fn(AuthServer) -> Result<String, Box<dyn std::error::Error + Send + Sync>>
774 + Send
775 + Sync
776 + 'static,
777 ) -> Self {
778 Self {
779 refresh_function: Arc::new(Box::new(move |auth_server| {
780 let result = refresh_function(auth_server);
781 Box::pin(async move { result })
782 })),
783 }
784 }
785
786 pub async fn request_access_token(
792 &self,
793 auth_server: &AuthServer,
794 ) -> Result<SecretAccessToken, Box<dyn std::error::Error + Send + Sync>> {
795 (self.refresh_function)(auth_server.clone())
796 .await
797 .map(SecretAccessToken::from)
798 }
799}
800
801impl std::fmt::Debug for ExternallyManaged {
802 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
803 f.debug_struct("ExternallyManaged")
804 .field(
805 "refresh_function",
806 &"Fn() -> Pin<Box<dyn Future<Output = Result<String, TokenError>> + Send>>",
807 )
808 .finish()
809 }
810}
811
812#[derive(Debug, Serialize, Deserialize)]
813pub(super) struct TokenRefreshRequest<'a> {
814 grant_type: &'static str,
815 client_id: &'a str,
816 refresh_token: &'a str,
817}
818
819impl<'a> TokenRefreshRequest<'a> {
820 pub(super) const fn new(client_id: &'a str, refresh_token: &'a str) -> Self {
821 Self {
822 grant_type: "refresh_token",
823 client_id,
824 refresh_token,
825 }
826 }
827}
828
829#[derive(Debug, Serialize, Deserialize)]
830pub(super) struct ClientCredentialsRequest {
831 grant_type: &'static str,
832 scope: Option<&'static str>,
833}
834
835impl ClientCredentialsRequest {
836 pub(super) const fn new(scope: Option<&'static str>) -> Self {
837 Self {
838 grant_type: "client_credentials",
839 scope,
840 }
841 }
842}
843
/// Body of a successful `refresh_token` grant response.
#[derive(Deserialize, Debug, Serialize)]
pub(super) struct RefreshTokenResponse {
    /// A replacement refresh token, present only when the auth server issued one.
    pub(super) refresh_token: Option<SecretRefreshToken>,
    /// The newly issued access token.
    pub(super) access_token: SecretAccessToken,
}
849
/// A source of OAuth access tokens for API clients.
#[async_trait::async_trait]
pub trait TokenRefresher: Clone + std::fmt::Debug + Send {
    /// The error returned when a token cannot be obtained.
    type Error;

    /// Returns an access token that has been validated; implementations may refresh first.
    async fn validated_access_token(&self) -> Result<SecretAccessToken, Self::Error>;

    /// Returns the currently held access token, if any.
    async fn get_access_token(&self) -> Result<Option<SecretAccessToken>, Self::Error>;

    /// Obtains and returns a new access token.
    async fn refresh_access_token(&self) -> Result<SecretAccessToken, Self::Error>;

    /// The base URL of the API being called.
    #[cfg(feature = "tracing")]
    fn base_url(&self) -> &str;

    /// The tracing configuration, if one is set.
    #[cfg(feature = "tracing-config")]
    fn tracing_configuration(&self) -> Option<&TracingConfiguration>;

    /// Whether requests to `url` should be traced.
    ///
    /// Without the `tracing-config` feature, every URL is traced. With it, a URL is traced
    /// when no tracing configuration is set or the configuration enables that URL.
    #[cfg(feature = "tracing")]
    // The explicit `return` is needed because only one of the two cfg-gated branches below
    // is compiled in any given build.
    #[allow(clippy::needless_return)]
    fn should_trace(&self, url: &UrlPatternMatchInput) -> bool {
        #[cfg(not(feature = "tracing-config"))]
        {
            let _ = url;
            return true;
        }

        #[cfg(feature = "tracing-config")]
        self.tracing_configuration()
            .is_none_or(|config| config.is_enabled(url))
    }
}
890
891#[async_trait::async_trait]
892impl TokenRefresher for ClientConfiguration {
893 type Error = TokenError;
894
895 async fn validated_access_token(&self) -> Result<SecretAccessToken, Self::Error> {
896 self.get_bearer_access_token().await
897 }
898
899 async fn refresh_access_token(&self) -> Result<SecretAccessToken, Self::Error> {
900 Ok(self.refresh().await?.access_token()?.clone())
901 }
902
903 async fn get_access_token(&self) -> Result<Option<SecretAccessToken>, Self::Error> {
904 Ok(Some(self.oauth_session().await?.access_token()?.clone()))
905 }
906
907 #[cfg(feature = "tracing")]
908 fn base_url(&self) -> &str {
909 &self.grpc_api_url
910 }
911
912 #[cfg(feature = "tracing-config")]
913 fn tracing_configuration(&self) -> Option<&TracingConfiguration> {
914 self.tracing_configuration.as_ref()
915 }
916}
917
918pub(super) fn default_http_client() -> Result<reqwest::Client, reqwest::Error> {
920 reqwest::Client::builder()
921 .timeout(std::time::Duration::from_secs(10))
922 .build()
923}
924
//! Tests for token refresh coordination, secrets-file write-back, and redacted `Debug` output.
#[cfg(test)]
mod test {
    use std::time::Duration;

    use super::*;
    use httpmock::prelude::*;
    use rstest::rstest;
    use time::format_description::well_known::Rfc3339;
    use tokio::time::Instant;
    use toml_edit::DocumentMut;

    /// While a refresh holds the session's write lock, a concurrent reader must wait for it
    /// and then observe the refreshed tokens.
    #[tokio::test]
    async fn test_tokens_blocked_during_refresh() {
        let mock_server = MockServer::start_async().await;

        // Serve an OIDC discovery document that points back at the mock server.
        let oidc_mock = mock_server
            .mock_async(|when, then| {
                when.method(GET).path("/.well-known/openid-configuration");
                then.status(200)
                    .json_body_obj(&oidc::Discovery::new_for_test(
                        mock_server.base_url().parse().unwrap(),
                    ));
            })
            .await;

        // The token endpoint responds after 3 seconds, so the refresh holds the write lock
        // long enough for the reader below to be blocked by it.
        let issuer_mock = mock_server
            .mock_async(|when, then| {
                when.method(POST).path("/v1/token");

                then.status(200)
                    .delay(Duration::from_secs(3))
                    .json_body_obj(&RefreshTokenResponse {
                        access_token: SecretAccessToken::from("new_access"),
                        refresh_token: Some(SecretRefreshToken::from("new_refresh")),
                    });
            })
            .await;

        let original_tokens = OAuthSession::from_refresh_token(
            RefreshToken::new(SecretRefreshToken::from("refresh")),
            AuthServer {
                client_id: "client_id".to_string(),
                issuer: mock_server.base_url(),
                scopes: None,
            },
            None,
        );
        let dispatcher: TokenDispatcher = original_tokens.clone().into();
        let dispatcher_clone1 = dispatcher.clone();
        let dispatcher_clone2 = dispatcher.clone();

        // Matches the delay configured on the token endpoint mock.
        let refresh_duration = Duration::from_secs(3);

        let start_write = Instant::now();
        let write_future = tokio::spawn(async move {
            dispatcher_clone1
                .refresh(&ConfigSource::Default, "")
                .await
                .unwrap()
        });

        let start_read = Instant::now();
        let read_future = tokio::spawn(async move { dispatcher_clone2.tokens().await });

        let _ = write_future.await.unwrap();
        let read_result = read_future.await.unwrap();

        let write_duration = start_write.elapsed();
        let read_duration = start_read.elapsed();

        oidc_mock.assert_async().await;
        issuer_mock.assert_async().await;

        assert!(
            write_duration >= refresh_duration,
            "Write operation did not take enough time"
        );
        assert!(
            read_duration >= refresh_duration,
            "Read operation was not blocked by the write operation"
        );
        assert_eq!(
            read_result.access_token.unwrap(),
            SecretAccessToken::from("new_access")
        );
        // The rotated refresh token from the response must replace the original one.
        if let OAuthGrant::RefreshToken(payload) = read_result.payload {
            assert_eq!(
                payload.refresh_token,
                SecretRefreshToken::from("new_refresh")
            );
        } else {
            panic!(
                "Expected RefreshToken payload, got {:?}",
                read_result.payload
            );
        }
    }

    /// Tries every combination of a `QCS_SECRETS_READ_ONLY` env value and a read-only file
    /// permission. The secrets file must be rewritten after a refresh only when neither marks
    /// it read-only.
    #[rstest]
    fn test_qcs_secrets_readonly(
        // (env value, whether that value should mark the secrets file read-only)
        #[values(
            (Some("TRUE"), true),
            (Some("tRue"), true),
            (Some("true"), true),
            (Some("YES"), true),
            (Some("yEs"), true),
            (Some("yes"), true),
            (Some("1"), true),
            (Some("2"), false),
            (Some("other"), false),
            (Some(""), false),
            (None, false),
        )]
        read_only_values: (Option<&str>, bool),
        #[values(true, false)] read_only_perm: bool,
    ) {
        let (maybe_read_only_env, env_is_read_only) = read_only_values;
        let expected_update = !env_is_read_only && !read_only_perm;
        figment::Jail::expect_with(|jail| {
            let profile_name = "test";
            let initial_access_token = "initial_access_token";
            let initial_refresh_token = "initial_refresh_token";

            let initial_secrets_file_contents = format!(
                r#"
[credentials]
[credentials.{profile_name}]
[credentials.{profile_name}.token_payload]
access_token = "{initial_access_token}"
expires_in = 3600
id_token = "id_token"
refresh_token = "{initial_refresh_token}"
scope = "offline_access openid profile email"
token_type = "Bearer"
updated_at = "2024-01-01T00:00:00Z"
"#
            );

            jail.clear_env();

            let secrets_path = "secrets.toml";
            jail.create_file(secrets_path, initial_secrets_file_contents.as_str())
                .expect("should create test secrets.toml");

            if read_only_perm {
                let mut permissions = std::fs::metadata(secrets_path)
                    .expect("Should be able to get file metadata")
                    .permissions();
                permissions.set_readonly(true);
                std::fs::set_permissions(secrets_path, permissions)
                    .expect("Should be able to set file permissions");
            }

            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async {
                let mock_server = MockServer::start_async().await;

                let oidc_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(GET).path("/.well-known/openid-configuration");
                        then.status(200)
                            .json_body_obj(&oidc::Discovery::new_for_test(mock_server.base_url().parse().unwrap()));
                    })
                    .await;

                let new_access_token = SecretAccessToken::from("new_access_token");
                let issuer_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(POST).path("/v1/token");
                        then.status(200).json_body_obj(&RefreshTokenResponse {
                            access_token: new_access_token.clone(),
                            refresh_token: Some(SecretRefreshToken::from(initial_refresh_token)),
                        });
                    })
                    .await;

                // NOTE(review): the session's initial access token is seeded with the
                // *refresh* token string; `initial_access_token` was probably intended. It
                // does not change the outcome, because a refresh always requests a new token.
                let original_tokens = OAuthSession::from_refresh_token(
                    RefreshToken::new(SecretRefreshToken::from(initial_refresh_token)),
                    AuthServer { client_id: "client_id".to_string(), issuer: mock_server.base_url(), scopes: None },
                    Some(SecretAccessToken::from(initial_refresh_token)),
                );
                let dispatcher: TokenDispatcher = original_tokens.into();

                jail.set_env("QCS_SECRETS_FILE_PATH", "secrets.toml");
                jail.set_env("QCS_PROFILE_NAME", "test");
                if let Some(read_only_env) = maybe_read_only_env {
                    jail.set_env("QCS_SECRETS_READ_ONLY", read_only_env);
                }

                let before_refresh = OffsetDateTime::now_utc();

                dispatcher
                    .refresh(
                        &ConfigSource::File {
                            settings_path: "".into(),
                            secrets_path: "secrets.toml".into(),
                        },
                        profile_name,
                    )
                    .await
                    .unwrap();

                oidc_mock.assert_async().await;
                issuer_mock.assert_async().await;

                // A read-only secrets file must be left byte-for-byte unchanged.
                let content = std::fs::read_to_string("secrets.toml").unwrap();
                if !expected_update {
                    assert!(
                        content.eq(initial_secrets_file_contents.as_str()),
                        "File should not be updated when QCS_SECRETS_READ_ONLY is set or file permissions are read-only"
                    );
                    return;
                }

                let mut toml = std::fs::read_to_string(secrets_path)
                    .unwrap()
                    .parse::<DocumentMut>()
                    .unwrap();

                let token_payload = toml
                    .get_mut("credentials")
                    .and_then(|credentials| {
                        credentials.get_mut(profile_name)?.get_mut("token_payload")
                    })
                    .expect("Should be able to get token_payload table");

                let access_token = token_payload.get("access_token").unwrap().as_str().map(str::to_string).map(SecretAccessToken::from);

                assert_eq!(
                    access_token,
                    Some(new_access_token)
                );

                // The write-back must also bump `updated_at` to a time after the refresh began.
                assert!(
                    OffsetDateTime::parse(
                        token_payload.get("updated_at").unwrap().as_str().unwrap(),
                        &Rfc3339
                    )
                    .unwrap()
                        > before_refresh
                );

                let content = std::fs::read_to_string("secrets.toml").unwrap();
                assert!(
                    content.contains("new_access_token"),
                    "File should be updated with new access token when QCS_SECRETS_READ_ONLY is not set or is set but disabled, and file permissions allow writing"
                );
            });
            Ok(())
        });
    }

    /// `Debug` for a session must show the grant kind only, and the access token only as
    /// `Some(())`. The client credentials and token values must not appear.
    #[test]
    fn test_auth_session_debug_fmt() {
        let session = OAuthSession {
            payload: OAuthGrant::ClientCredentials(ClientCredentials::new(
                "hidden_id",
                "hidden_secret",
            )),
            access_token: Some(SecretAccessToken::from("token")),
            auth_server: AuthServer {
                client_id: "some_id".into(),
                issuer: "some_url".into(),
                scopes: None,
            },
        };

        assert_eq!("OAuthSession { payload: ClientCredentials, access_token: Some(()), auth_server: AuthServer { client_id: \"some_id\", issuer: \"some_url\", scopes: None } }", &format!("{session:?}"));
    }
}