1use std::{pin::Pin, sync::Arc};
3
4use futures::Future;
5use jsonwebtoken::{Algorithm, DecodingKey, Validation};
6use oauth2::TokenResponse;
7use serde::{Deserialize, Serialize};
8use time::OffsetDateTime;
9use tokio::sync::{Mutex, Notify, RwLock};
10use tokio_util::sync::CancellationToken;
11
12use super::{
13 oidc, secrets::Secrets, settings::AuthServer, ClientConfiguration, ConfigSource, TokenError,
14};
15use crate::configuration::{
16 error::DiscoveryError,
17 pkce::{pkce_login, PkceLoginError, PkceLoginRequest},
18 secrets::{Credential, SecretAccessToken, SecretRefreshToken, TokenPayload},
19};
20#[cfg(feature = "tracing-config")]
21use crate::tracing_configuration::TracingConfiguration;
22#[cfg(feature = "tracing")]
23use urlpattern::UrlPatternMatchInput;
24
25pub use super::secret_string::ClientSecret;
26
27#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
29#[cfg_attr(feature = "python", pyo3::pyclass)]
30pub struct RefreshToken {
31 pub refresh_token: SecretRefreshToken,
33}
34
35impl RefreshToken {
36 #[must_use]
38 pub const fn new(refresh_token: SecretRefreshToken) -> Self {
39 Self { refresh_token }
40 }
41
42 pub async fn request_access_token(
48 &mut self,
49 auth_server: &AuthServer,
50 ) -> Result<SecretAccessToken, TokenError> {
51 if self.refresh_token.is_empty() {
52 return Err(TokenError::NoRefreshToken);
53 }
54
55 let client = default_http_client()?;
56 let token_url = oidc::fetch_discovery(&client, &auth_server.issuer)
57 .await?
58 .token_endpoint;
59 let data = TokenRefreshRequest::new(&auth_server.client_id, self.refresh_token.secret());
60 let resp = client.post(token_url).form(&data).send().await?;
61
62 let RefreshTokenResponse {
63 access_token,
64 refresh_token,
65 } = resp.error_for_status()?.json().await?;
66
67 self.refresh_token = refresh_token;
68 Ok(access_token)
69 }
70}
71
72#[derive(Deserialize, Debug, Serialize)]
73pub(super) struct ClientCredentialsResponse {
74 pub(super) access_token: SecretAccessToken,
75}
76
77#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
78#[cfg_attr(feature = "python", pyo3::pyclass)]
79pub struct ClientCredentials {
81 pub client_id: String,
83 pub client_secret: ClientSecret,
85}
86
87impl ClientCredentials {
88 #[must_use]
89 pub fn new(client_id: impl Into<String>, client_secret: impl Into<ClientSecret>) -> Self {
91 Self {
92 client_id: client_id.into(),
93 client_secret: client_secret.into(),
94 }
95 }
96
97 #[must_use]
99 pub fn client_id(&self) -> &str {
100 &self.client_id
101 }
102
103 #[must_use]
105 pub const fn client_secret(&self) -> &ClientSecret {
106 &self.client_secret
107 }
108
109 pub async fn request_access_token(
115 &self,
116 auth_server: &AuthServer,
117 ) -> Result<SecretAccessToken, TokenError> {
118 let request = ClientCredentialsRequest::new(None);
119 let client = default_http_client()?;
120
121 let url = oidc::fetch_discovery(&client, &auth_server.issuer)
122 .await?
123 .token_endpoint;
124 let ready_to_send = client
125 .post(url)
126 .basic_auth(&auth_server.client_id, Some(&self.client_secret.secret()))
127 .form(&request);
128 let response = ready_to_send.send().await?;
129
130 response.error_for_status_ref()?;
131
132 let ClientCredentialsResponse { access_token } = response.json().await?;
133 Ok(access_token)
134 }
135}
136
137#[derive(Clone, PartialEq, Eq, Deserialize)]
138#[expect(missing_debug_implementations, reason = "contains secret data")]
139#[cfg_attr(feature = "python", pyo3::pyclass)]
140pub struct PkceFlow {
142 pub access_token: SecretAccessToken,
144 pub refresh_token: Option<RefreshToken>,
146}
147
148#[derive(Debug, thiserror::Error)]
150pub enum PkceFlowError {
151 #[error(transparent)]
153 PkceLogin(#[from] PkceLoginError),
154 #[error(transparent)]
156 Discovery(#[from] DiscoveryError),
157 #[error(transparent)]
159 Request(#[from] reqwest::Error),
160}
161
162impl PkceFlow {
163 pub async fn new_login_flow(
169 cancel_token: CancellationToken,
170 auth_server: &AuthServer,
171 ) -> Result<Self, PkceFlowError> {
172 let issuer = auth_server.issuer.clone();
173
174 let client = default_http_client()?;
175 let discovery = oidc::fetch_discovery(&client, &issuer).await?;
176
177 let response = pkce_login(
178 cancel_token,
179 PkceLoginRequest {
180 client_id: auth_server.client_id.clone(),
181 redirect_port: None,
182 discovery,
183 },
184 )
185 .await?;
186
187 Ok(Self {
188 access_token: SecretAccessToken::from(response.access_token().secret().clone()),
189 refresh_token: response
190 .refresh_token()
191 .map(|rt| RefreshToken::new(SecretRefreshToken::from(rt.secret().clone()))),
192 })
193 }
194
195 pub async fn request_access_token(
201 &mut self,
202 auth_server: &AuthServer,
203 ) -> Result<SecretAccessToken, TokenError> {
204 if insecure_validate_token_exp(&self.access_token).is_ok() {
205 return Ok(self.access_token.clone());
206 }
207
208 if let Some(refresh_token) = &mut self.refresh_token {
209 let access_token = refresh_token.request_access_token(auth_server).await?;
210 self.access_token.clone_from(&access_token);
211 return Ok(access_token);
212 }
213
214 Err(TokenError::NoRefreshToken)
215 }
216}
217
218impl From<PkceFlow> for Credential {
219 fn from(value: PkceFlow) -> Self {
220 let mut token_payload = TokenPayload::default();
221 token_payload.access_token = Some(value.access_token);
222 token_payload.refresh_token = value.refresh_token.map(|rt| rt.refresh_token);
223
224 Self {
225 token_payload: Some(token_payload),
226 }
227 }
228}
229
230#[derive(Clone)]
231#[cfg_attr(feature = "python", derive(pyo3::FromPyObject))]
232pub enum OAuthGrant {
235 RefreshToken(RefreshToken),
237 ClientCredentials(ClientCredentials),
239 ExternallyManaged(ExternallyManaged),
241 PkceFlow(PkceFlow),
243}
244
245impl From<ExternallyManaged> for OAuthGrant {
246 fn from(v: ExternallyManaged) -> Self {
247 Self::ExternallyManaged(v)
248 }
249}
250
251impl From<ClientCredentials> for OAuthGrant {
252 fn from(v: ClientCredentials) -> Self {
253 Self::ClientCredentials(v)
254 }
255}
256
257impl From<RefreshToken> for OAuthGrant {
258 fn from(v: RefreshToken) -> Self {
259 Self::RefreshToken(v)
260 }
261}
262
263impl From<PkceFlow> for OAuthGrant {
264 fn from(v: PkceFlow) -> Self {
265 Self::PkceFlow(v)
266 }
267}
268
269impl OAuthGrant {
270 async fn request_access_token(
272 &mut self,
273 auth_server: &AuthServer,
274 ) -> Result<SecretAccessToken, TokenError> {
275 match self {
276 Self::RefreshToken(tokens) => tokens.request_access_token(auth_server).await,
277 Self::ClientCredentials(tokens) => tokens.request_access_token(auth_server).await,
278 Self::ExternallyManaged(tokens) => tokens
279 .request_access_token(auth_server)
280 .await
281 .map_err(|e| TokenError::ExternallyManaged(e.to_string())),
282 Self::PkceFlow(tokens) => tokens.request_access_token(auth_server).await,
283 }
284 }
285}
286
287impl std::fmt::Debug for OAuthGrant {
288 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
289 match self {
290 Self::RefreshToken(_) => f.write_str("RefreshToken"),
291 Self::ClientCredentials(_) => f.write_str("ClientCredentials"),
292 Self::ExternallyManaged(_) => f.write_str("ExternallyManaged"),
293 Self::PkceFlow(_) => f.write_str("PkceTokens"),
294 }
295 }
296}
297
298#[derive(Clone)]
310#[cfg_attr(feature = "python", pyo3::pyclass)]
311pub struct OAuthSession {
312 payload: OAuthGrant,
314 access_token: Option<SecretAccessToken>,
316 auth_server: AuthServer,
318}
319
320impl OAuthSession {
321 #[must_use]
326 pub const fn new(
327 payload: OAuthGrant,
328 auth_server: AuthServer,
329 access_token: Option<SecretAccessToken>,
330 ) -> Self {
331 Self {
332 payload,
333 access_token,
334 auth_server,
335 }
336 }
337
338 #[must_use]
343 pub const fn from_externally_managed(
344 tokens: ExternallyManaged,
345 auth_server: AuthServer,
346 access_token: Option<SecretAccessToken>,
347 ) -> Self {
348 Self::new(
349 OAuthGrant::ExternallyManaged(tokens),
350 auth_server,
351 access_token,
352 )
353 }
354
355 #[must_use]
360 pub const fn from_refresh_token(
361 tokens: RefreshToken,
362 auth_server: AuthServer,
363 access_token: Option<SecretAccessToken>,
364 ) -> Self {
365 Self::new(OAuthGrant::RefreshToken(tokens), auth_server, access_token)
366 }
367
368 #[must_use]
373 pub const fn from_client_credentials(
374 tokens: ClientCredentials,
375 auth_server: AuthServer,
376 access_token: Option<SecretAccessToken>,
377 ) -> Self {
378 Self::new(
379 OAuthGrant::ClientCredentials(tokens),
380 auth_server,
381 access_token,
382 )
383 }
384
385 #[must_use]
390 pub const fn from_pkce_flow(
391 flow: PkceFlow,
392 auth_server: AuthServer,
393 access_token: Option<SecretAccessToken>,
394 ) -> Self {
395 Self::new(OAuthGrant::PkceFlow(flow), auth_server, access_token)
396 }
397
398 pub fn access_token(&self) -> Result<&SecretAccessToken, TokenError> {
407 self.access_token.as_ref().ok_or(TokenError::NoAccessToken)
408 }
409
410 #[must_use]
412 pub const fn payload(&self) -> &OAuthGrant {
413 &self.payload
414 }
415
416 #[allow(clippy::missing_panics_doc)]
422 pub async fn request_access_token(&mut self) -> Result<&SecretAccessToken, TokenError> {
423 let access_token = self.payload.request_access_token(&self.auth_server).await?;
424 Ok(self.access_token.insert(access_token))
425 }
426
427 #[must_use]
429 pub const fn auth_server(&self) -> &AuthServer {
430 &self.auth_server
431 }
432
433 pub fn validate(&self) -> Result<SecretAccessToken, TokenError> {
441 let access_token = self.access_token()?;
442 insecure_validate_token_exp(access_token)?;
443 Ok(access_token.clone())
444 }
445}
446
447pub(crate) fn insecure_validate_token_exp(
451 access_token: &SecretAccessToken,
452) -> Result<(), TokenError> {
453 let placeholder_key = DecodingKey::from_secret(&[]);
454 let mut validation = Validation::new(Algorithm::RS256);
455 validation.validate_exp = true;
456 validation.leeway = 60;
457 validation.validate_aud = false;
458 validation.insecure_disable_signature_validation();
459
460 jsonwebtoken::decode::<toml::Value>(access_token.secret(), &placeholder_key, &validation)
461 .map(|_| ())
462 .map_err(TokenError::InvalidAccessToken)
463}
464
465impl std::fmt::Debug for OAuthSession {
466 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
467 let token_populated = if self.access_token.is_some() {
468 Some(())
469 } else {
470 None
471 };
472 f.debug_struct("OAuthSession")
473 .field("payload", &self.payload)
474 .field("access_token", &token_populated)
475 .field("auth_server", &self.auth_server)
476 .finish()
477 }
478}
479
480#[derive(Clone, Debug)]
482#[cfg_attr(feature = "python", pyo3::pyclass)]
483pub struct TokenDispatcher {
484 lock: Arc<RwLock<OAuthSession>>,
485 refreshing: Arc<Mutex<bool>>,
486 notify_refreshed: Arc<Notify>,
487}
488
489impl From<OAuthSession> for TokenDispatcher {
490 fn from(value: OAuthSession) -> Self {
491 Self {
492 lock: Arc::new(RwLock::new(value)),
493 refreshing: Arc::new(Mutex::new(false)),
494 notify_refreshed: Arc::new(Notify::new()),
495 }
496 }
497}
498
499impl TokenDispatcher {
500 pub async fn use_tokens<F, O>(&self, f: F) -> O
510 where
511 F: FnOnce(&OAuthSession) -> O + Send,
512 {
513 let tokens = self.lock.read().await;
514 f(&tokens)
515 }
516
517 #[must_use]
519 pub async fn tokens(&self) -> OAuthSession {
520 self.use_tokens(Clone::clone).await
521 }
522
523 pub async fn refresh(
529 &self,
530 source: &ConfigSource,
531 profile: &str,
532 ) -> Result<OAuthSession, TokenError> {
533 self.managed_refresh(Self::perform_refresh, source, profile)
534 .await
535 }
536
537 pub async fn validate(&self) -> Result<SecretAccessToken, TokenError> {
545 self.use_tokens(OAuthSession::validate).await
546 }
547
548 async fn managed_refresh<F, Fut>(
551 &self,
552 refresh_fn: F,
553 source: &ConfigSource,
554 profile: &str,
555 ) -> Result<OAuthSession, TokenError>
556 where
557 F: FnOnce(Arc<RwLock<OAuthSession>>) -> Fut + Send,
558 Fut: Future<Output = Result<OAuthSession, TokenError>> + Send,
559 {
560 let mut is_refreshing = self.refreshing.lock().await;
561
562 if *is_refreshing {
563 drop(is_refreshing);
564 self.notify_refreshed.notified().await;
565 return Ok(self.tokens().await);
566 }
567
568 *is_refreshing = true;
569 drop(is_refreshing);
570
571 let oauth_session = refresh_fn(self.lock.clone()).await?;
572
573 if let ConfigSource::File {
575 settings_path: _,
576 secrets_path,
577 } = source
578 {
579 if !Secrets::is_read_only(secrets_path).await? {
580 let refresh_token = match &oauth_session.payload {
582 OAuthGrant::PkceFlow(payload) => {
583 payload.refresh_token.as_ref().map(|rt| &rt.refresh_token)
584 }
585 _ => None,
586 };
587
588 let now = OffsetDateTime::now_utc();
589 Secrets::write_tokens(
590 secrets_path,
591 profile,
592 refresh_token,
593 oauth_session.access_token()?,
594 now,
595 )
596 .await?;
597 }
598 }
599
600 *self.refreshing.lock().await = false;
601 self.notify_refreshed.notify_waiters();
602 Ok(oauth_session)
603 }
604
605 async fn perform_refresh(lock: Arc<RwLock<OAuthSession>>) -> Result<OAuthSession, TokenError> {
612 let mut credentials = lock.write().await;
613 credentials.request_access_token().await?;
614 Ok(credentials.clone())
615 }
616}
617
618pub(crate) type RefreshResult =
619 Pin<Box<dyn Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>> + Send>>;
620
621pub type RefreshFunction = Box<dyn (Fn(AuthServer) -> RefreshResult) + Send + Sync>;
623
624#[derive(Clone)]
629#[cfg_attr(feature = "python", pyo3::pyclass)]
630pub struct ExternallyManaged {
631 refresh_function: Arc<RefreshFunction>,
632}
633
634impl ExternallyManaged {
635 pub fn new(
660 refresh_function: impl Fn(AuthServer) -> RefreshResult + Send + Sync + 'static,
661 ) -> Self {
662 Self {
663 refresh_function: Arc::new(Box::new(refresh_function)),
664 }
665 }
666
667 pub fn from_async<F, Fut>(refresh_function: F) -> Self
700 where
701 F: Fn(AuthServer) -> Fut + Send + Sync + 'static,
702 Fut: Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>>
703 + Send
704 + 'static,
705 {
706 Self {
707 refresh_function: Arc::new(Box::new(move |auth_server| {
708 Box::pin(refresh_function(auth_server))
709 })),
710 }
711 }
712
713 pub fn from_sync(
744 refresh_function: impl Fn(AuthServer) -> Result<String, Box<dyn std::error::Error + Send + Sync>>
745 + Send
746 + Sync
747 + 'static,
748 ) -> Self {
749 Self {
750 refresh_function: Arc::new(Box::new(move |auth_server| {
751 let result = refresh_function(auth_server);
752 Box::pin(async move { result })
753 })),
754 }
755 }
756
757 pub async fn request_access_token(
763 &self,
764 auth_server: &AuthServer,
765 ) -> Result<SecretAccessToken, Box<dyn std::error::Error + Send + Sync>> {
766 (self.refresh_function)(auth_server.clone())
767 .await
768 .map(SecretAccessToken::from)
769 }
770}
771
772impl std::fmt::Debug for ExternallyManaged {
773 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
774 f.debug_struct("ExternallyManaged")
775 .field(
776 "refresh_function",
777 &"Fn() -> Pin<Box<dyn Future<Output = Result<String, TokenError>> + Send>>",
778 )
779 .finish()
780 }
781}
782
/// Form body for a `refresh_token` grant request to the token endpoint.
// NOTE(review): the derived `Debug` prints `refresh_token` in plain text; avoid logging this
// value, or hand-write `Debug` to redact it.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct TokenRefreshRequest<'a> {
    grant_type: &'static str,
    client_id: &'a str,
    refresh_token: &'a str,
}

impl<'a> TokenRefreshRequest<'a> {
    /// Builds a request with `grant_type` set to `"refresh_token"`.
    pub(super) const fn new(client_id: &'a str, refresh_token: &'a str) -> Self {
        Self {
            grant_type: "refresh_token",
            client_id,
            refresh_token,
        }
    }
}

/// Form body for a `client_credentials` grant request to the token endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct ClientCredentialsRequest {
    grant_type: &'static str,
    scope: Option<&'static str>,
}

impl ClientCredentialsRequest {
    /// Builds a request with `grant_type` set to `"client_credentials"` and the given scope.
    pub(super) const fn new(scope: Option<&'static str>) -> Self {
        Self {
            grant_type: "client_credentials",
            scope,
        }
    }
}

/// Token-endpoint response body for a refresh-token grant.
// NOTE(review): `refresh_token` is required here, but RFC 6749 §5.1 makes it optional in
// token responses; a server that does not rotate refresh tokens would fail to deserialize.
#[derive(Deserialize, Debug, Serialize)]
pub(super) struct RefreshTokenResponse {
    pub(super) refresh_token: SecretRefreshToken,
    pub(super) access_token: SecretAccessToken,
}
820
/// A source of access tokens that can validate and refresh them.
#[async_trait::async_trait]
pub trait TokenRefresher: Clone + std::fmt::Debug + Send {
    /// The error type returned by this refresher's operations.
    type Error;

    /// Returns an access token the implementor considers usable.
    // NOTE(review): whether this refreshes an expired token is implementation-defined.
    async fn validated_access_token(&self) -> Result<SecretAccessToken, Self::Error>;

    /// Returns the current access token, if any, without necessarily validating it.
    async fn get_access_token(&self) -> Result<Option<SecretAccessToken>, Self::Error>;

    /// Obtains a new access token and returns it.
    async fn refresh_access_token(&self) -> Result<SecretAccessToken, Self::Error>;

    /// The base URL of the API this refresher's tokens are used against.
    #[cfg(feature = "tracing")]
    fn base_url(&self) -> &str;

    /// The tracing configuration, if any, that decides which URLs are traced.
    #[cfg(feature = "tracing-config")]
    fn tracing_configuration(&self) -> Option<&TracingConfiguration>;

    /// Whether a request to `url` should be traced.
    ///
    /// Without the `tracing-config` feature, every URL is traced. With it, URLs are traced
    /// when there is no tracing configuration or the configuration enables `url`.
    #[cfg(feature = "tracing")]
    // The explicit `return` is needed because the tail expression is cfg-gated.
    #[allow(clippy::needless_return)]
    fn should_trace(&self, url: &UrlPatternMatchInput) -> bool {
        #[cfg(not(feature = "tracing-config"))]
        {
            let _ = url;
            return true;
        }

        #[cfg(feature = "tracing-config")]
        self.tracing_configuration()
            .is_none_or(|config| config.is_enabled(url))
    }
}
861
862#[async_trait::async_trait]
863impl TokenRefresher for ClientConfiguration {
864 type Error = TokenError;
865
866 async fn validated_access_token(&self) -> Result<SecretAccessToken, Self::Error> {
867 self.get_bearer_access_token().await
868 }
869
870 async fn refresh_access_token(&self) -> Result<SecretAccessToken, Self::Error> {
871 Ok(self.refresh().await?.access_token()?.clone())
872 }
873
874 async fn get_access_token(&self) -> Result<Option<SecretAccessToken>, Self::Error> {
875 Ok(Some(self.oauth_session().await?.access_token()?.clone()))
876 }
877
878 #[cfg(feature = "tracing")]
879 fn base_url(&self) -> &str {
880 &self.grpc_api_url
881 }
882
883 #[cfg(feature = "tracing-config")]
884 fn tracing_configuration(&self) -> Option<&TracingConfiguration> {
885 self.tracing_configuration.as_ref()
886 }
887}
888
889fn default_http_client() -> Result<reqwest::Client, reqwest::Error> {
891 reqwest::Client::builder()
892 .timeout(std::time::Duration::from_secs(10))
893 .build()
894}
895
#[cfg(test)]
mod test {
    use std::time::Duration;

    use super::*;
    use httpmock::prelude::*;
    use rstest::rstest;
    use time::format_description::well_known::Rfc3339;
    use tokio::time::Instant;
    use toml_edit::DocumentMut;

    /// A refresh whose token request is delayed by 3 seconds should produce the new tokens.
    /// A concurrent `tokens()` call should observe the refreshed session.
    #[tokio::test]
    async fn test_tokens_blocked_during_refresh() {
        let mock_server = MockServer::start_async().await;

        // OIDC discovery document pointing the token endpoint at the mock server.
        let oidc_mock = mock_server
            .mock_async(|when, then| {
                when.method(GET).path("/.well-known/openid-configuration");
                then.status(200)
                    .json_body_obj(&oidc::Discovery::new_for_test(
                        mock_server.base_url().parse().unwrap(),
                    ));
            })
            .await;

        // Token endpoint that answers slowly, keeping the refresh in progress.
        let issuer_mock = mock_server
            .mock_async(|when, then| {
                when.method(POST).path("/v1/token");

                then.status(200)
                    .delay(Duration::from_secs(3))
                    .json_body_obj(&RefreshTokenResponse {
                        access_token: SecretAccessToken::from("new_access"),
                        refresh_token: SecretRefreshToken::from("new_refresh"),
                    });
            })
            .await;

        let original_tokens = OAuthSession::from_refresh_token(
            RefreshToken::new(SecretRefreshToken::from("refresh")),
            AuthServer {
                client_id: "client_id".to_string(),
                issuer: mock_server.base_url(),
            },
            None,
        );
        let dispatcher: TokenDispatcher = original_tokens.clone().into();
        let dispatcher_clone1 = dispatcher.clone();
        let dispatcher_clone2 = dispatcher.clone();

        let refresh_duration = Duration::from_secs(3);

        let start_write = Instant::now();
        let write_future = tokio::spawn(async move {
            dispatcher_clone1
                .refresh(&ConfigSource::Default, "")
                .await
                .unwrap()
        });

        let start_read = Instant::now();
        let read_future = tokio::spawn(async move { dispatcher_clone2.tokens().await });

        let _ = write_future.await.unwrap();
        let read_result = read_future.await.unwrap();

        // NOTE(review): both durations are sampled only after both tasks are joined, so the
        // read assertion below would also pass if the read had not blocked. Consider timing
        // inside the spawned task instead.
        let write_duration = start_write.elapsed();
        let read_duration = start_read.elapsed();

        oidc_mock.assert_async().await;
        issuer_mock.assert_async().await;

        assert!(
            write_duration >= refresh_duration,
            "Write operation did not take enough time"
        );
        assert!(
            read_duration >= refresh_duration,
            "Read operation was not blocked by the write operation"
        );
        // The read must see the refreshed token. This relies on the refresh taking the
        // write lock before the read takes the read lock.
        assert_eq!(
            read_result.access_token.unwrap(),
            SecretAccessToken::from("new_access")
        );
        if let OAuthGrant::RefreshToken(payload) = read_result.payload {
            assert_eq!(
                payload.refresh_token,
                SecretRefreshToken::from("new_refresh")
            );
        } else {
            panic!(
                "Expected RefreshToken payload, got {:?}",
                read_result.payload
            );
        }
    }

    /// The secrets file should be rewritten after a refresh only when neither
    /// `QCS_SECRETS_READ_ONLY` nor the file's read-only permission marks it read-only.
    ///
    /// The `(env value, expected read-only)` pairs encode which environment values this test
    /// expects to be treated as read-only.
    #[rstest]
    fn test_qcs_secrets_readonly(
        #[values(
            (Some("TRUE"), true),
            (Some("tRue"), true),
            (Some("true"), true),
            (Some("YES"), true),
            (Some("yEs"), true),
            (Some("yes"), true),
            (Some("1"), true),
            (Some("2"), false),
            (Some("other"), false),
            (Some(""), false),
            (None, false),
        )]
        read_only_values: (Option<&str>, bool),
        #[values(true, false)] read_only_perm: bool,
    ) {
        let (maybe_read_only_env, env_is_read_only) = read_only_values;
        let expected_update = !env_is_read_only && !read_only_perm;
        figment::Jail::expect_with(|jail| {
            let profile_name = "test";
            let initial_access_token = "initial_access_token";
            let initial_refresh_token = "initial_refresh_token";

            let initial_secrets_file_contents = format!(
                r#"
[credentials]
[credentials.{profile_name}]
[credentials.{profile_name}.token_payload]
access_token = "{initial_access_token}"
expires_in = 3600
id_token = "id_token"
refresh_token = "{initial_refresh_token}"
scope = "offline_access openid profile email"
token_type = "Bearer"
updated_at = "2024-01-01T00:00:00Z"
"#
            );

            // Start from an empty environment inside the jail.
            jail.clear_env();

            // Write the initial secrets file into the jail's working directory.
            let secrets_path = "secrets.toml";
            jail.create_file(secrets_path, initial_secrets_file_contents.as_str())
                .expect("should create test secrets.toml");

            if read_only_perm {
                let mut permissions = std::fs::metadata(secrets_path)
                    .expect("Should be able to get file metadata")
                    .permissions();
                permissions.set_readonly(true);
                std::fs::set_permissions(secrets_path, permissions)
                    .expect("Should be able to set file permissions");
            }

            // `Jail::expect_with` takes a synchronous closure, so drive the async code on a
            // dedicated runtime.
            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async {
                let mock_server = MockServer::start_async().await;

                let oidc_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(GET).path("/.well-known/openid-configuration");
                        then.status(200)
                            .json_body_obj(&oidc::Discovery::new_for_test(mock_server.base_url().parse().unwrap()));
                    })
                    .await;

                // The token endpoint returns a new access token and the same refresh token.
                let new_access_token = SecretAccessToken::from("new_access_token");
                let issuer_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(POST).path("/v1/token");
                        then.status(200).json_body_obj(&RefreshTokenResponse {
                            access_token: new_access_token.clone(),
                            refresh_token: SecretRefreshToken::from(initial_refresh_token),
                        });
                    })
                    .await;

                // NOTE(review): the initial access token here is built from
                // `initial_refresh_token`; it is replaced by the refresh, so this looks
                // harmless but may be unintended.
                let original_tokens = OAuthSession::from_refresh_token(
                    RefreshToken::new(SecretRefreshToken::from(initial_refresh_token)),
                    AuthServer { client_id: "client_id".to_string(), issuer: mock_server.base_url() },
                    Some(SecretAccessToken::from(initial_refresh_token)),
                );
                let dispatcher: TokenDispatcher = original_tokens.into();

                // The env vars are set before `refresh`, which is where read-only status is
                // checked (presumably reading the environment at call time).
                jail.set_env("QCS_SECRETS_FILE_PATH", "secrets.toml");
                jail.set_env("QCS_PROFILE_NAME", "test");
                if let Some(read_only_env) = maybe_read_only_env {
                    jail.set_env("QCS_SECRETS_READ_ONLY", read_only_env);
                }

                let before_refresh = OffsetDateTime::now_utc();

                dispatcher
                    .refresh(
                        &ConfigSource::File {
                            settings_path: "".into(),
                            secrets_path: "secrets.toml".into(),
                        },
                        profile_name,
                    )
                    .await
                    .unwrap();

                oidc_mock.assert_async().await;
                issuer_mock.assert_async().await;

                // Read-only cases: the file must be byte-for-byte unchanged.
                let content = std::fs::read_to_string("secrets.toml").unwrap();
                if !expected_update {
                    assert!(
                        content.eq(initial_secrets_file_contents.as_str()),
                        "File should not be updated when QCS_SECRETS_READ_ONLY is set or file permissions are read-only"
                    );
                    return;
                }

                // Writable cases: parse the file and check the updated token payload.
                let mut toml = std::fs::read_to_string(secrets_path)
                    .unwrap()
                    .parse::<DocumentMut>()
                    .unwrap();

                let token_payload = toml
                    .get_mut("credentials")
                    .and_then(|credentials| {
                        credentials.get_mut(profile_name)?.get_mut("token_payload")
                    })
                    .expect("Should be able to get token_payload table");

                let access_token = token_payload.get("access_token").unwrap().as_str().map(str::to_string).map(SecretAccessToken::from);

                assert_eq!(
                    access_token,
                    Some(new_access_token)
                );

                // `updated_at` must be a timestamp taken during this refresh.
                assert!(
                    OffsetDateTime::parse(
                        token_payload.get("updated_at").unwrap().as_str().unwrap(),
                        &Rfc3339
                    )
                    .unwrap()
                        > before_refresh
                );

                let content = std::fs::read_to_string("secrets.toml").unwrap();
                assert!(
                    content.contains("new_access_token"),
                    "File should be updated with new access token when QCS_SECRETS_READ_ONLY is not set or is set but disabled, and file permissions allow writing"
                );
            });
            Ok(())
        });
    }

    /// `Debug` for `OAuthSession` must hide the grant's secrets and the access token.
    #[test]
    fn test_auth_session_debug_fmt() {
        let session = OAuthSession {
            payload: OAuthGrant::ClientCredentials(ClientCredentials::new(
                "hidden_id",
                "hidden_secret",
            )),
            access_token: Some(SecretAccessToken::from("token")),
            auth_server: AuthServer {
                client_id: "some_id".into(),
                issuer: "some_url".into(),
            },
        };

        assert_eq!("OAuthSession { payload: ClientCredentials, access_token: Some(()), auth_server: AuthServer { client_id: \"some_id\", issuer: \"some_url\" } }", &format!("{session:?}"));
    }
}