1use std::{pin::Pin, sync::Arc};
3
4use futures::Future;
5use jsonwebtoken::{Algorithm, DecodingKey, Validation};
6use oauth2::TokenResponse;
7use serde::{Deserialize, Serialize};
8use time::OffsetDateTime;
9use tokio::sync::{Mutex, Notify, RwLock};
10use tokio_util::sync::CancellationToken;
11
12use super::{
13 oidc, secrets::Secrets, settings::AuthServer, ClientConfiguration, ConfigSource, TokenError,
14};
15use crate::configuration::{
16 error::DiscoveryError,
17 pkce::{pkce_login, PkceLoginError, PkceLoginRequest},
18 secrets::{Credential, SecretAccessToken, SecretRefreshToken, TokenPayload},
19};
20#[cfg(feature = "tracing-config")]
21use crate::tracing_configuration::TracingConfiguration;
22#[cfg(feature = "tracing")]
23use urlpattern::UrlPatternMatchInput;
24
25pub use super::secret_string::ClientSecret;
26
27#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
29#[cfg_attr(feature = "python", pyo3::pyclass)]
30pub struct RefreshToken {
31 pub refresh_token: SecretRefreshToken,
33}
34
35impl RefreshToken {
36 #[must_use]
38 pub const fn new(refresh_token: SecretRefreshToken) -> Self {
39 Self { refresh_token }
40 }
41
42 pub async fn request_access_token(
48 &mut self,
49 auth_server: &AuthServer,
50 ) -> Result<SecretAccessToken, TokenError> {
51 if self.refresh_token.is_empty() {
52 return Err(TokenError::NoRefreshToken);
53 }
54
55 let client = default_http_client()?;
56 let token_url = oidc::fetch_discovery(&client, &auth_server.issuer)
57 .await?
58 .token_endpoint;
59 let data = TokenRefreshRequest::new(&auth_server.client_id, self.refresh_token.secret());
60 let resp = client.post(token_url).form(&data).send().await?;
61
62 let RefreshTokenResponse {
63 access_token,
64 refresh_token,
65 } = resp.error_for_status()?.json().await?;
66
67 if let Some(refresh_token) = refresh_token {
68 self.refresh_token = refresh_token;
69 }
70 Ok(access_token)
71 }
72}
73
74#[derive(Deserialize, Debug, Serialize)]
75pub(super) struct ClientCredentialsResponse {
76 pub(super) access_token: SecretAccessToken,
77}
78
79#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
80#[cfg_attr(feature = "python", pyo3::pyclass)]
81pub struct ClientCredentials {
83 pub client_id: String,
85 pub client_secret: ClientSecret,
87}
88
89impl ClientCredentials {
90 #[must_use]
91 pub fn new(client_id: impl Into<String>, client_secret: impl Into<ClientSecret>) -> Self {
93 Self {
94 client_id: client_id.into(),
95 client_secret: client_secret.into(),
96 }
97 }
98
99 #[must_use]
101 pub fn client_id(&self) -> &str {
102 &self.client_id
103 }
104
105 #[must_use]
107 pub const fn client_secret(&self) -> &ClientSecret {
108 &self.client_secret
109 }
110
111 pub async fn request_access_token(
117 &self,
118 auth_server: &AuthServer,
119 ) -> Result<SecretAccessToken, TokenError> {
120 let request = ClientCredentialsRequest::new(None);
121 let client = default_http_client()?;
122
123 let url = oidc::fetch_discovery(&client, &auth_server.issuer)
124 .await?
125 .token_endpoint;
126 let ready_to_send = client
127 .post(url)
128 .basic_auth(&auth_server.client_id, Some(&self.client_secret.secret()))
129 .form(&request);
130 let response = ready_to_send.send().await?;
131
132 response.error_for_status_ref()?;
133
134 let ClientCredentialsResponse { access_token } = response.json().await?;
135 Ok(access_token)
136 }
137}
138
139#[derive(Clone, PartialEq, Eq, Deserialize)]
140#[expect(missing_debug_implementations, reason = "contains secret data")]
141#[cfg_attr(feature = "python", pyo3::pyclass)]
142pub struct PkceFlow {
144 pub access_token: SecretAccessToken,
146 pub refresh_token: Option<RefreshToken>,
148}
149
150#[derive(Debug, thiserror::Error)]
152pub enum PkceFlowError {
153 #[error(transparent)]
155 PkceLogin(#[from] PkceLoginError),
156 #[error(transparent)]
158 Discovery(#[from] DiscoveryError),
159 #[error(transparent)]
161 Request(#[from] reqwest::Error),
162}
163
164impl PkceFlow {
165 pub async fn new_login_flow(
171 cancel_token: CancellationToken,
172 auth_server: &AuthServer,
173 ) -> Result<Self, PkceFlowError> {
174 let issuer = auth_server.issuer.clone();
175
176 let client = default_http_client()?;
177 let discovery = oidc::fetch_discovery(&client, &issuer).await?;
178
179 let response = pkce_login(
180 cancel_token,
181 PkceLoginRequest {
182 client_id: auth_server.client_id.clone(),
183 redirect_port: None,
184 discovery,
185 scopes: auth_server.scopes.clone(),
186 },
187 )
188 .await?;
189
190 Ok(Self {
191 access_token: SecretAccessToken::from(response.access_token().secret().clone()),
192 refresh_token: response
193 .refresh_token()
194 .map(|rt| RefreshToken::new(SecretRefreshToken::from(rt.secret().clone()))),
195 })
196 }
197
198 pub async fn request_access_token(
204 &mut self,
205 auth_server: &AuthServer,
206 ) -> Result<SecretAccessToken, TokenError> {
207 if insecure_validate_token_exp(&self.access_token).is_ok() {
208 return Ok(self.access_token.clone());
209 }
210
211 if let Some(refresh_token) = &mut self.refresh_token {
212 let access_token = refresh_token.request_access_token(auth_server).await?;
213 self.access_token.clone_from(&access_token);
214 return Ok(access_token);
215 }
216
217 Err(TokenError::NoRefreshToken)
218 }
219}
220
221impl From<PkceFlow> for Credential {
222 fn from(value: PkceFlow) -> Self {
223 let mut token_payload = TokenPayload::default();
224 token_payload.access_token = Some(value.access_token);
225 token_payload.refresh_token = value.refresh_token.map(|rt| rt.refresh_token);
226
227 Self {
228 token_payload: Some(token_payload),
229 }
230 }
231}
232
233#[derive(Clone)]
234#[cfg_attr(feature = "python", derive(pyo3::FromPyObject))]
235pub enum OAuthGrant {
238 RefreshToken(RefreshToken),
240 ClientCredentials(ClientCredentials),
242 ExternallyManaged(ExternallyManaged),
244 PkceFlow(PkceFlow),
246}
247
248impl From<ExternallyManaged> for OAuthGrant {
249 fn from(v: ExternallyManaged) -> Self {
250 Self::ExternallyManaged(v)
251 }
252}
253
254impl From<ClientCredentials> for OAuthGrant {
255 fn from(v: ClientCredentials) -> Self {
256 Self::ClientCredentials(v)
257 }
258}
259
260impl From<RefreshToken> for OAuthGrant {
261 fn from(v: RefreshToken) -> Self {
262 Self::RefreshToken(v)
263 }
264}
265
266impl From<PkceFlow> for OAuthGrant {
267 fn from(v: PkceFlow) -> Self {
268 Self::PkceFlow(v)
269 }
270}
271
272impl OAuthGrant {
273 async fn request_access_token(
275 &mut self,
276 auth_server: &AuthServer,
277 ) -> Result<SecretAccessToken, TokenError> {
278 match self {
279 Self::RefreshToken(tokens) => tokens.request_access_token(auth_server).await,
280 Self::ClientCredentials(tokens) => tokens.request_access_token(auth_server).await,
281 Self::ExternallyManaged(tokens) => tokens
282 .request_access_token(auth_server)
283 .await
284 .map_err(|e| TokenError::ExternallyManaged(e.to_string())),
285 Self::PkceFlow(tokens) => tokens.request_access_token(auth_server).await,
286 }
287 }
288}
289
290impl std::fmt::Debug for OAuthGrant {
291 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
292 match self {
293 Self::RefreshToken(_) => f.write_str("RefreshToken"),
294 Self::ClientCredentials(_) => f.write_str("ClientCredentials"),
295 Self::ExternallyManaged(_) => f.write_str("ExternallyManaged"),
296 Self::PkceFlow(_) => f.write_str("PkceTokens"),
297 }
298 }
299}
300
301#[derive(Clone)]
313#[cfg_attr(feature = "python", pyo3::pyclass)]
314pub struct OAuthSession {
315 payload: OAuthGrant,
317 access_token: Option<SecretAccessToken>,
319 auth_server: AuthServer,
321}
322
323impl OAuthSession {
324 #[must_use]
329 pub const fn new(
330 payload: OAuthGrant,
331 auth_server: AuthServer,
332 access_token: Option<SecretAccessToken>,
333 ) -> Self {
334 Self {
335 payload,
336 access_token,
337 auth_server,
338 }
339 }
340
341 #[must_use]
346 pub const fn from_externally_managed(
347 tokens: ExternallyManaged,
348 auth_server: AuthServer,
349 access_token: Option<SecretAccessToken>,
350 ) -> Self {
351 Self::new(
352 OAuthGrant::ExternallyManaged(tokens),
353 auth_server,
354 access_token,
355 )
356 }
357
358 #[must_use]
363 pub const fn from_refresh_token(
364 tokens: RefreshToken,
365 auth_server: AuthServer,
366 access_token: Option<SecretAccessToken>,
367 ) -> Self {
368 Self::new(OAuthGrant::RefreshToken(tokens), auth_server, access_token)
369 }
370
371 #[must_use]
376 pub const fn from_client_credentials(
377 tokens: ClientCredentials,
378 auth_server: AuthServer,
379 access_token: Option<SecretAccessToken>,
380 ) -> Self {
381 Self::new(
382 OAuthGrant::ClientCredentials(tokens),
383 auth_server,
384 access_token,
385 )
386 }
387
388 #[must_use]
393 pub const fn from_pkce_flow(
394 flow: PkceFlow,
395 auth_server: AuthServer,
396 access_token: Option<SecretAccessToken>,
397 ) -> Self {
398 Self::new(OAuthGrant::PkceFlow(flow), auth_server, access_token)
399 }
400
401 pub fn access_token(&self) -> Result<&SecretAccessToken, TokenError> {
410 self.access_token.as_ref().ok_or(TokenError::NoAccessToken)
411 }
412
413 #[must_use]
415 pub const fn payload(&self) -> &OAuthGrant {
416 &self.payload
417 }
418
419 #[allow(clippy::missing_panics_doc)]
425 pub async fn request_access_token(&mut self) -> Result<&SecretAccessToken, TokenError> {
426 let access_token = self.payload.request_access_token(&self.auth_server).await?;
427 Ok(self.access_token.insert(access_token))
428 }
429
430 #[must_use]
432 pub const fn auth_server(&self) -> &AuthServer {
433 &self.auth_server
434 }
435
436 pub fn validate(&self) -> Result<SecretAccessToken, TokenError> {
444 let access_token = self.access_token()?;
445 insecure_validate_token_exp(access_token)?;
446 Ok(access_token.clone())
447 }
448}
449
450pub(crate) fn insecure_validate_token_exp(
454 access_token: &SecretAccessToken,
455) -> Result<(), TokenError> {
456 let placeholder_key = DecodingKey::from_secret(&[]);
457 let mut validation = Validation::new(Algorithm::RS256);
458 validation.validate_exp = true;
459 validation.leeway = 60;
460 validation.validate_aud = false;
461 validation.insecure_disable_signature_validation();
462
463 jsonwebtoken::decode::<toml::Value>(access_token.secret(), &placeholder_key, &validation)
464 .map(|_| ())
465 .map_err(TokenError::InvalidAccessToken)
466}
467
468impl std::fmt::Debug for OAuthSession {
469 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
470 let token_populated = if self.access_token.is_some() {
471 Some(())
472 } else {
473 None
474 };
475 f.debug_struct("OAuthSession")
476 .field("payload", &self.payload)
477 .field("access_token", &token_populated)
478 .field("auth_server", &self.auth_server)
479 .finish()
480 }
481}
482
483#[derive(Clone, Debug)]
485#[cfg_attr(feature = "python", pyo3::pyclass)]
486pub struct TokenDispatcher {
487 lock: Arc<RwLock<OAuthSession>>,
488 refreshing: Arc<Mutex<bool>>,
489 notify_refreshed: Arc<Notify>,
490}
491
492impl From<OAuthSession> for TokenDispatcher {
493 fn from(value: OAuthSession) -> Self {
494 Self {
495 lock: Arc::new(RwLock::new(value)),
496 refreshing: Arc::new(Mutex::new(false)),
497 notify_refreshed: Arc::new(Notify::new()),
498 }
499 }
500}
501
502impl TokenDispatcher {
503 pub async fn use_tokens<F, O>(&self, f: F) -> O
513 where
514 F: FnOnce(&OAuthSession) -> O + Send,
515 {
516 let tokens = self.lock.read().await;
517 f(&tokens)
518 }
519
520 #[must_use]
522 pub async fn tokens(&self) -> OAuthSession {
523 self.use_tokens(Clone::clone).await
524 }
525
526 pub async fn refresh(
532 &self,
533 source: &ConfigSource,
534 profile: &str,
535 ) -> Result<OAuthSession, TokenError> {
536 self.managed_refresh(Self::perform_refresh, source, profile)
537 .await
538 }
539
540 pub async fn validate(&self) -> Result<SecretAccessToken, TokenError> {
548 self.use_tokens(OAuthSession::validate).await
549 }
550
551 async fn managed_refresh<F, Fut>(
554 &self,
555 refresh_fn: F,
556 source: &ConfigSource,
557 profile: &str,
558 ) -> Result<OAuthSession, TokenError>
559 where
560 F: FnOnce(Arc<RwLock<OAuthSession>>) -> Fut + Send,
561 Fut: Future<Output = Result<OAuthSession, TokenError>> + Send,
562 {
563 let mut is_refreshing = self.refreshing.lock().await;
564
565 if *is_refreshing {
566 drop(is_refreshing);
567 self.notify_refreshed.notified().await;
568 return Ok(self.tokens().await);
569 }
570
571 *is_refreshing = true;
572 drop(is_refreshing);
573
574 let oauth_session = refresh_fn(self.lock.clone()).await?;
575
576 if let ConfigSource::File {
578 settings_path: _,
579 secrets_path,
580 } = source
581 {
582 if !Secrets::is_read_only(secrets_path).await? {
583 let refresh_token = match &oauth_session.payload {
585 OAuthGrant::PkceFlow(payload) => {
586 payload.refresh_token.as_ref().map(|rt| &rt.refresh_token)
587 }
588 _ => None,
589 };
590
591 let now = OffsetDateTime::now_utc();
592 Secrets::write_tokens(
593 secrets_path,
594 profile,
595 refresh_token,
596 oauth_session.access_token()?,
597 now,
598 )
599 .await?;
600 }
601 }
602
603 *self.refreshing.lock().await = false;
604 self.notify_refreshed.notify_waiters();
605 Ok(oauth_session)
606 }
607
608 async fn perform_refresh(lock: Arc<RwLock<OAuthSession>>) -> Result<OAuthSession, TokenError> {
615 let mut credentials = lock.write().await;
616 credentials.request_access_token().await?;
617 Ok(credentials.clone())
618 }
619}
620
621pub(crate) type RefreshResult =
622 Pin<Box<dyn Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>> + Send>>;
623
624pub type RefreshFunction = Box<dyn (Fn(AuthServer) -> RefreshResult) + Send + Sync>;
626
627#[derive(Clone)]
632#[cfg_attr(feature = "python", pyo3::pyclass)]
633pub struct ExternallyManaged {
634 refresh_function: Arc<RefreshFunction>,
635}
636
637impl ExternallyManaged {
638 pub fn new(
663 refresh_function: impl Fn(AuthServer) -> RefreshResult + Send + Sync + 'static,
664 ) -> Self {
665 Self {
666 refresh_function: Arc::new(Box::new(refresh_function)),
667 }
668 }
669
670 pub fn from_async<F, Fut>(refresh_function: F) -> Self
703 where
704 F: Fn(AuthServer) -> Fut + Send + Sync + 'static,
705 Fut: Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>>
706 + Send
707 + 'static,
708 {
709 Self {
710 refresh_function: Arc::new(Box::new(move |auth_server| {
711 Box::pin(refresh_function(auth_server))
712 })),
713 }
714 }
715
716 pub fn from_sync(
747 refresh_function: impl Fn(AuthServer) -> Result<String, Box<dyn std::error::Error + Send + Sync>>
748 + Send
749 + Sync
750 + 'static,
751 ) -> Self {
752 Self {
753 refresh_function: Arc::new(Box::new(move |auth_server| {
754 let result = refresh_function(auth_server);
755 Box::pin(async move { result })
756 })),
757 }
758 }
759
760 pub async fn request_access_token(
766 &self,
767 auth_server: &AuthServer,
768 ) -> Result<SecretAccessToken, Box<dyn std::error::Error + Send + Sync>> {
769 (self.refresh_function)(auth_server.clone())
770 .await
771 .map(SecretAccessToken::from)
772 }
773}
774
775impl std::fmt::Debug for ExternallyManaged {
776 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
777 f.debug_struct("ExternallyManaged")
778 .field(
779 "refresh_function",
780 &"Fn() -> Pin<Box<dyn Future<Output = Result<String, TokenError>> + Send>>",
781 )
782 .finish()
783 }
784}
785
786#[derive(Debug, Serialize, Deserialize)]
787pub(super) struct TokenRefreshRequest<'a> {
788 grant_type: &'static str,
789 client_id: &'a str,
790 refresh_token: &'a str,
791}
792
793impl<'a> TokenRefreshRequest<'a> {
794 pub(super) const fn new(client_id: &'a str, refresh_token: &'a str) -> Self {
795 Self {
796 grant_type: "refresh_token",
797 client_id,
798 refresh_token,
799 }
800 }
801}
802
803#[derive(Debug, Serialize, Deserialize)]
804pub(super) struct ClientCredentialsRequest {
805 grant_type: &'static str,
806 scope: Option<&'static str>,
807}
808
809impl ClientCredentialsRequest {
810 pub(super) const fn new(scope: Option<&'static str>) -> Self {
811 Self {
812 grant_type: "client_credentials",
813 scope,
814 }
815 }
816}
817
818#[derive(Deserialize, Debug, Serialize)]
819pub(super) struct RefreshTokenResponse {
820 pub(super) refresh_token: Option<SecretRefreshToken>,
821 pub(super) access_token: SecretAccessToken,
822}
823
824#[async_trait::async_trait]
826pub trait TokenRefresher: Clone + std::fmt::Debug + Send {
827 type Error;
830
831 async fn validated_access_token(&self) -> Result<SecretAccessToken, Self::Error>;
833
834 async fn get_access_token(&self) -> Result<Option<SecretAccessToken>, Self::Error>;
836
837 async fn refresh_access_token(&self) -> Result<SecretAccessToken, Self::Error>;
839
840 #[cfg(feature = "tracing")]
842 fn base_url(&self) -> &str;
843
844 #[cfg(feature = "tracing-config")]
846 fn tracing_configuration(&self) -> Option<&TracingConfiguration>;
847
848 #[cfg(feature = "tracing")]
851 #[allow(clippy::needless_return)]
852 fn should_trace(&self, url: &UrlPatternMatchInput) -> bool {
853 #[cfg(not(feature = "tracing-config"))]
854 {
855 let _ = url;
856 return true;
857 }
858
859 #[cfg(feature = "tracing-config")]
860 self.tracing_configuration()
861 .is_none_or(|config| config.is_enabled(url))
862 }
863}
864
865#[async_trait::async_trait]
866impl TokenRefresher for ClientConfiguration {
867 type Error = TokenError;
868
869 async fn validated_access_token(&self) -> Result<SecretAccessToken, Self::Error> {
870 self.get_bearer_access_token().await
871 }
872
873 async fn refresh_access_token(&self) -> Result<SecretAccessToken, Self::Error> {
874 Ok(self.refresh().await?.access_token()?.clone())
875 }
876
877 async fn get_access_token(&self) -> Result<Option<SecretAccessToken>, Self::Error> {
878 Ok(Some(self.oauth_session().await?.access_token()?.clone()))
879 }
880
881 #[cfg(feature = "tracing")]
882 fn base_url(&self) -> &str {
883 &self.grpc_api_url
884 }
885
886 #[cfg(feature = "tracing-config")]
887 fn tracing_configuration(&self) -> Option<&TracingConfiguration> {
888 self.tracing_configuration.as_ref()
889 }
890}
891
892pub(super) fn default_http_client() -> Result<reqwest::Client, reqwest::Error> {
894 reqwest::Client::builder()
895 .timeout(std::time::Duration::from_secs(10))
896 .build()
897}
898
/// Tests covering refresh/read coordination, secrets-file persistence, and redacted Debug
/// output.
#[cfg(test)]
mod test {
    use std::time::Duration;

    use super::*;
    use httpmock::prelude::*;
    use rstest::rstest;
    use time::format_description::well_known::Rfc3339;
    use tokio::time::Instant;
    use toml_edit::DocumentMut;

    /// A read issued while a refresh holds the write lock must wait for the refresh and then
    /// observe the refreshed tokens.
    #[tokio::test]
    async fn test_tokens_blocked_during_refresh() {
        let mock_server = MockServer::start_async().await;

        // OIDC discovery pointing the token endpoint back at the mock server.
        let oidc_mock = mock_server
            .mock_async(|when, then| {
                when.method(GET).path("/.well-known/openid-configuration");
                then.status(200)
                    .json_body_obj(&oidc::Discovery::new_for_test(
                        mock_server.base_url().parse().unwrap(),
                    ));
            })
            .await;

        // Slow token endpoint: the 3 s delay keeps the refresh holding the write lock while
        // the concurrent read is queued.
        let issuer_mock = mock_server
            .mock_async(|when, then| {
                when.method(POST).path("/v1/token");

                then.status(200)
                    .delay(Duration::from_secs(3))
                    .json_body_obj(&RefreshTokenResponse {
                        access_token: SecretAccessToken::from("new_access"),
                        refresh_token: Some(SecretRefreshToken::from("new_refresh")),
                    });
            })
            .await;

        let original_tokens = OAuthSession::from_refresh_token(
            RefreshToken::new(SecretRefreshToken::from("refresh")),
            AuthServer {
                client_id: "client_id".to_string(),
                issuer: mock_server.base_url(),
                scopes: None,
            },
            None,
        );
        let dispatcher: TokenDispatcher = original_tokens.clone().into();
        let dispatcher_clone1 = dispatcher.clone();
        let dispatcher_clone2 = dispatcher.clone();

        let refresh_duration = Duration::from_secs(3);

        // Spawn the refresh first so it takes the write lock before the read task runs.
        let start_write = Instant::now();
        let write_future = tokio::spawn(async move {
            dispatcher_clone1
                .refresh(&ConfigSource::Default, "")
                .await
                .unwrap()
        });

        let start_read = Instant::now();
        let read_future = tokio::spawn(async move { dispatcher_clone2.tokens().await });

        let _ = write_future.await.unwrap();
        let read_result = read_future.await.unwrap();

        let write_duration = start_write.elapsed();
        // NOTE(review): this is measured after awaiting *both* tasks, so it exceeds 3 s
        // whenever the write does. The `new_access` assertion below is what actually shows
        // that the read waited for the refresh.
        let read_duration = start_read.elapsed();

        oidc_mock.assert_async().await;
        issuer_mock.assert_async().await;

        assert!(
            write_duration >= refresh_duration,
            "Write operation did not take enough time"
        );
        assert!(
            read_duration >= refresh_duration,
            "Read operation was not blocked by the write operation"
        );
        assert_eq!(
            read_result.access_token.unwrap(),
            SecretAccessToken::from("new_access")
        );
        // The rotated refresh token from the response must replace the original one.
        if let OAuthGrant::RefreshToken(payload) = read_result.payload {
            assert_eq!(
                payload.refresh_token,
                SecretRefreshToken::from("new_refresh")
            );
        } else {
            panic!(
                "Expected RefreshToken payload, got {:?}",
                read_result.payload
            );
        }
    }

    /// The secrets file is rewritten after a refresh only when `QCS_SECRETS_READ_ONLY` is
    /// not set to a truthy value and the file itself is writable.
    ///
    /// Every env-var value is combined with both file-permission settings.
    #[rstest]
    fn test_qcs_secrets_readonly(
        #[values(
            (Some("TRUE"), true),
            (Some("tRue"), true),
            (Some("true"), true),
            (Some("YES"), true),
            (Some("yEs"), true),
            (Some("yes"), true),
            (Some("1"), true),
            (Some("2"), false),
            (Some("other"), false),
            (Some(""), false),
            (None, false),
        )]
        read_only_values: (Option<&str>, bool),
        #[values(true, false)] read_only_perm: bool,
    ) {
        let (maybe_read_only_env, env_is_read_only) = read_only_values;
        let expected_update = !env_is_read_only && !read_only_perm;
        figment::Jail::expect_with(|jail| {
            let profile_name = "test";
            let initial_access_token = "initial_access_token";
            let initial_refresh_token = "initial_refresh_token";

            let initial_secrets_file_contents = format!(
                r#"
[credentials]
[credentials.{profile_name}]
[credentials.{profile_name}.token_payload]
access_token = "{initial_access_token}"
expires_in = 3600
id_token = "id_token"
refresh_token = "{initial_refresh_token}"
scope = "offline_access openid profile email"
token_type = "Bearer"
updated_at = "2024-01-01T00:00:00Z"
"#
            );

            // Start from an empty environment so no host QCS_* variables leak into the test.
            jail.clear_env();

            // Relative path inside the jail's temporary directory.
            let secrets_path = "secrets.toml";
            jail.create_file(secrets_path, initial_secrets_file_contents.as_str())
                .expect("should create test secrets.toml");

            if read_only_perm {
                let mut permissions = std::fs::metadata(secrets_path)
                    .expect("Should be able to get file metadata")
                    .permissions();
                permissions.set_readonly(true);
                std::fs::set_permissions(secrets_path, permissions)
                    .expect("Should be able to set file permissions");
            }

            // The jail closure is synchronous, so drive the async refresh on its own runtime.
            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async {
                let mock_server = MockServer::start_async().await;

                let oidc_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(GET).path("/.well-known/openid-configuration");
                        then.status(200)
                            .json_body_obj(&oidc::Discovery::new_for_test(mock_server.base_url().parse().unwrap()));
                    })
                    .await;

                // The token endpoint returns a new access token and echoes the same refresh token.
                let new_access_token = SecretAccessToken::from("new_access_token");
                let issuer_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(POST).path("/v1/token");
                        then.status(200).json_body_obj(&RefreshTokenResponse {
                            access_token: new_access_token.clone(),
                            refresh_token: Some(SecretRefreshToken::from(initial_refresh_token)),
                        });
                    })
                    .await;

                // NOTE(review): the cached access token is seeded with the refresh-token string.
                // The value is irrelevant here because a RefreshToken grant always requests a
                // new token on refresh.
                let original_tokens = OAuthSession::from_refresh_token(
                    RefreshToken::new(SecretRefreshToken::from(initial_refresh_token)),
                    AuthServer { client_id: "client_id".to_string(), issuer: mock_server.base_url(), scopes: None },
                    Some(SecretAccessToken::from(initial_refresh_token)),
                );
                let dispatcher: TokenDispatcher = original_tokens.into();

                // Env vars are set before `refresh`, which is presumably when
                // `Secrets::is_read_only` reads QCS_SECRETS_READ_ONLY (verify in `secrets`).
                jail.set_env("QCS_SECRETS_FILE_PATH", "secrets.toml");
                jail.set_env("QCS_PROFILE_NAME", "test");
                if let Some(read_only_env) = maybe_read_only_env {
                    jail.set_env("QCS_SECRETS_READ_ONLY", read_only_env);
                }

                let before_refresh = OffsetDateTime::now_utc();

                dispatcher
                    .refresh(
                        &ConfigSource::File {
                            settings_path: "".into(),
                            secrets_path: "secrets.toml".into(),
                        },
                        profile_name,
                    )
                    .await
                    .unwrap();

                oidc_mock.assert_async().await;
                issuer_mock.assert_async().await;

                // When writing is forbidden, the file must be byte-for-byte unchanged.
                let content = std::fs::read_to_string("secrets.toml").unwrap();
                if !expected_update {
                    assert!(
                        content.eq(initial_secrets_file_contents.as_str()),
                        "File should not be updated when QCS_SECRETS_READ_ONLY is set or file permissions are read-only"
                    );
                    return;
                }

                // Otherwise, parse the TOML and inspect the profile's token payload.
                let mut toml = std::fs::read_to_string(secrets_path)
                    .unwrap()
                    .parse::<DocumentMut>()
                    .unwrap();

                let token_payload = toml
                    .get_mut("credentials")
                    .and_then(|credentials| {
                        credentials.get_mut(profile_name)?.get_mut("token_payload")
                    })
                    .expect("Should be able to get token_payload table");

                let access_token = token_payload.get("access_token").unwrap().as_str().map(str::to_string).map(SecretAccessToken::from);

                assert_eq!(
                    access_token,
                    Some(new_access_token)
                );

                // `updated_at` must be bumped to the time of this refresh.
                assert!(
                    OffsetDateTime::parse(
                        token_payload.get("updated_at").unwrap().as_str().unwrap(),
                        &Rfc3339
                    )
                    .unwrap()
                    > before_refresh
                );

                let content = std::fs::read_to_string("secrets.toml").unwrap();
                assert!(
                    content.contains("new_access_token"),
                    "File should be updated with new access token when QCS_SECRETS_READ_ONLY is not set or is set but disabled, and file permissions allow writing"
                );
            });
            Ok(())
        });
    }

    /// Debug output must show only the grant kind and whether an access token is cached,
    /// never the client ID/secret or the token itself.
    #[test]
    fn test_auth_session_debug_fmt() {
        let session = OAuthSession {
            payload: OAuthGrant::ClientCredentials(ClientCredentials::new(
                "hidden_id",
                "hidden_secret",
            )),
            access_token: Some(SecretAccessToken::from("token")),
            auth_server: AuthServer {
                client_id: "some_id".into(),
                issuer: "some_url".into(),
                scopes: None,
            },
        };

        assert_eq!("OAuthSession { payload: ClientCredentials, access_token: Some(()), auth_server: AuthServer { client_id: \"some_id\", issuer: \"some_url\", scopes: None } }", &format!("{session:?}"));
    }
}