1use std::{pin::Pin, sync::Arc};
3
4use futures::Future;
5use jsonwebtoken::{Algorithm, DecodingKey, Validation};
6use oauth2::TokenResponse;
7use serde::{Deserialize, Serialize};
8use time::OffsetDateTime;
9use tokio::sync::{Mutex, Notify, RwLock};
10use tokio_util::sync::CancellationToken;
11
12#[cfg(feature = "stubs")]
13use pyo3_stub_gen::derive::gen_stub_pyclass;
14
15use super::{
16 ClientConfiguration, ConfigSource, TokenError, oidc, secrets::Secrets, settings::AuthServer,
17};
18use crate::configuration::{
19 error::DiscoveryError,
20 pkce::{PkceLoginError, PkceLoginRequest, pkce_login},
21 secrets::{Credential, SecretAccessToken, SecretRefreshToken, TokenPayload},
22};
23#[cfg(feature = "tracing-config")]
24use crate::tracing_configuration::TracingConfiguration;
25#[cfg(feature = "tracing")]
26use urlpattern::UrlPatternMatchInput;
27
28pub use super::secret_string::ClientSecret;
29
/// A refresh token that can be exchanged at the auth server's token endpoint for a new
/// access token.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(eq, get_all, set_all, module = "qcs_api_client_common.configuration")
)]
pub struct RefreshToken {
    /// The refresh token secret. If empty, [`RefreshToken::request_access_token`] fails with
    /// [`TokenError::NoRefreshToken`] without contacting the server.
    pub refresh_token: SecretRefreshToken,
}
41
42impl RefreshToken {
43 #[must_use]
45 pub const fn new(refresh_token: SecretRefreshToken) -> Self {
46 Self { refresh_token }
47 }
48
49 pub async fn request_access_token(
55 &mut self,
56 auth_server: &AuthServer,
57 ) -> Result<SecretAccessToken, TokenError> {
58 if self.refresh_token.is_empty() {
59 return Err(TokenError::NoRefreshToken);
60 }
61
62 let client = default_http_client()?;
63 let token_url = oidc::fetch_discovery(&client, &auth_server.issuer)
64 .await?
65 .token_endpoint;
66 let data = TokenRefreshRequest::new(&auth_server.client_id, self.refresh_token.secret());
67 let resp = client.post(token_url).form(&data).send().await?;
68
69 let RefreshTokenResponse {
70 access_token,
71 refresh_token,
72 } = resp.error_for_status()?.json().await?;
73
74 if let Some(refresh_token) = refresh_token {
75 self.refresh_token = refresh_token;
76 }
77 Ok(access_token)
78 }
79}
80
/// Response body of a client-credentials token request.
///
/// Only `access_token` is declared here. Other fields in the server's JSON response are
/// skipped by serde, because unknown fields are not denied.
#[derive(Deserialize, Debug, Serialize)]
pub(super) struct ClientCredentialsResponse {
    /// The newly issued access token.
    pub(super) access_token: SecretAccessToken,
}
85
/// Credentials for the OAuth 2.0 client-credentials grant: a client id and its secret.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(eq, get_all, frozen, module = "qcs_api_client_common.configuration")
)]
pub struct ClientCredentials {
    /// The client id these credentials authenticate as.
    pub client_id: String,
    /// The secret paired with `client_id`.
    pub client_secret: ClientSecret,
}
99
100impl ClientCredentials {
101 #[must_use]
102 pub fn new(client_id: impl Into<String>, client_secret: impl Into<ClientSecret>) -> Self {
104 Self {
105 client_id: client_id.into(),
106 client_secret: client_secret.into(),
107 }
108 }
109
110 #[must_use]
112 pub fn client_id(&self) -> &str {
113 &self.client_id
114 }
115
116 #[must_use]
118 pub const fn client_secret(&self) -> &ClientSecret {
119 &self.client_secret
120 }
121
122 pub async fn request_access_token(
128 &self,
129 auth_server: &AuthServer,
130 ) -> Result<SecretAccessToken, TokenError> {
131 let request = ClientCredentialsRequest::new(None);
132 let client = default_http_client()?;
133
134 let url = oidc::fetch_discovery(&client, &auth_server.issuer)
135 .await?
136 .token_endpoint;
137 let ready_to_send = client
138 .post(url)
139 .basic_auth(&auth_server.client_id, Some(&self.client_secret.secret()))
140 .form(&request);
141 let response = ready_to_send.send().await?;
142
143 response.error_for_status_ref()?;
144
145 let ClientCredentialsResponse { access_token } = response.json().await?;
146 Ok(access_token)
147 }
148}
149
/// Tokens obtained from an interactive PKCE login (see [`PkceFlow::new_login_flow`]).
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(eq, get_all, frozen, module = "qcs_api_client_common.configuration")
)]
pub struct PkceFlow {
    /// The access token issued by the login. Reused until it fails the expiry check.
    pub access_token: SecretAccessToken,
    /// The refresh token, if the server issued one. Used to renew `access_token` once it
    /// has expired.
    pub refresh_token: Option<RefreshToken>,
}
163
/// Errors that can occur while running [`PkceFlow::new_login_flow`].
#[derive(Debug, thiserror::Error)]
pub enum PkceFlowError {
    /// The PKCE login exchange itself failed.
    #[error(transparent)]
    PkceLogin(#[from] PkceLoginError),
    /// OIDC discovery against the issuer failed.
    #[error(transparent)]
    Discovery(#[from] DiscoveryError),
    /// The HTTP client could not be constructed.
    #[error(transparent)]
    Request(#[from] qcs_dependencies_client::reqwest::Error),
}
177
178impl PkceFlow {
179 pub async fn new_login_flow(
185 cancel_token: CancellationToken,
186 auth_server: &AuthServer,
187 ) -> Result<Self, PkceFlowError> {
188 let issuer = auth_server.issuer.clone();
189
190 let client = default_http_client()?;
191 let discovery = oidc::fetch_discovery(&client, &issuer).await?;
192
193 let response = pkce_login(
194 cancel_token,
195 PkceLoginRequest {
196 client_id: auth_server.client_id.clone(),
197 redirect_port: None,
198 discovery,
199 scopes: auth_server.scopes.clone(),
200 },
201 )
202 .await?;
203
204 Ok(Self {
205 access_token: SecretAccessToken::from(response.access_token().secret().clone()),
206 refresh_token: response
207 .refresh_token()
208 .map(|rt| RefreshToken::new(SecretRefreshToken::from(rt.secret().clone()))),
209 })
210 }
211
212 pub async fn request_access_token(
218 &mut self,
219 auth_server: &AuthServer,
220 ) -> Result<SecretAccessToken, TokenError> {
221 if insecure_validate_token_exp(&self.access_token).is_ok() {
222 return Ok(self.access_token.clone());
223 }
224
225 if let Some(refresh_token) = &mut self.refresh_token {
226 let access_token = refresh_token.request_access_token(auth_server).await?;
227 self.access_token.clone_from(&access_token);
228 return Ok(access_token);
229 }
230
231 Err(TokenError::NoRefreshToken)
232 }
233}
234
235impl From<PkceFlow> for Credential {
236 fn from(value: PkceFlow) -> Self {
237 let mut token_payload = TokenPayload::default();
238 token_payload.access_token = Some(value.access_token);
239 token_payload.refresh_token = value.refresh_token.map(|rt| rt.refresh_token);
240
241 Self {
242 token_payload: Some(token_payload),
243 }
244 }
245}
246
/// The OAuth grant an [`OAuthSession`] uses to obtain new access tokens.
#[derive(Clone)]
#[cfg_attr(feature = "python", derive(pyo3::FromPyObject, pyo3::IntoPyObject))]
pub enum OAuthGrant {
    /// Exchange a stored refresh token for new access tokens.
    RefreshToken(RefreshToken),
    /// Use the client-credentials grant (client id + secret).
    ClientCredentials(ClientCredentials),
    /// Delegate token acquisition to a user-supplied refresh function.
    ExternallyManaged(ExternallyManaged),
    /// Tokens obtained from an interactive PKCE login.
    PkceFlow(PkceFlow),
}
261
262impl From<ExternallyManaged> for OAuthGrant {
263 fn from(v: ExternallyManaged) -> Self {
264 Self::ExternallyManaged(v)
265 }
266}
267
268impl From<ClientCredentials> for OAuthGrant {
269 fn from(v: ClientCredentials) -> Self {
270 Self::ClientCredentials(v)
271 }
272}
273
274impl From<RefreshToken> for OAuthGrant {
275 fn from(v: RefreshToken) -> Self {
276 Self::RefreshToken(v)
277 }
278}
279
280impl From<PkceFlow> for OAuthGrant {
281 fn from(v: PkceFlow) -> Self {
282 Self::PkceFlow(v)
283 }
284}
285
286impl OAuthGrant {
287 async fn request_access_token(
289 &mut self,
290 auth_server: &AuthServer,
291 ) -> Result<SecretAccessToken, TokenError> {
292 match self {
293 Self::RefreshToken(tokens) => tokens.request_access_token(auth_server).await,
294 Self::ClientCredentials(tokens) => tokens.request_access_token(auth_server).await,
295 Self::ExternallyManaged(tokens) => tokens
296 .request_access_token(auth_server)
297 .await
298 .map_err(|e| TokenError::ExternallyManaged(e.to_string())),
299 Self::PkceFlow(tokens) => tokens.request_access_token(auth_server).await,
300 }
301 }
302}
303
304impl std::fmt::Debug for OAuthGrant {
305 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
306 match self {
307 Self::RefreshToken(_) => f.write_str("RefreshToken"),
308 Self::ClientCredentials(_) => f.write_str("ClientCredentials"),
309 Self::ExternallyManaged(_) => f.write_str("ExternallyManaged"),
310 Self::PkceFlow(_) => f.write_str("PkceTokens"),
311 }
312 }
313}
314
/// An OAuth grant together with the auth server it targets and the most recently obtained
/// access token (if any).
#[derive(Clone)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "qcs_api_client_common.configuration", frozen, get_all)
)]
pub struct OAuthSession {
    /// The grant used to obtain new access tokens.
    payload: OAuthGrant,
    /// The cached access token; `None` until one is supplied or requested.
    access_token: Option<SecretAccessToken>,
    /// The auth server that tokens are requested from.
    auth_server: AuthServer,
}
340
341impl OAuthSession {
342 #[must_use]
347 pub const fn new(
348 payload: OAuthGrant,
349 auth_server: AuthServer,
350 access_token: Option<SecretAccessToken>,
351 ) -> Self {
352 Self {
353 payload,
354 access_token,
355 auth_server,
356 }
357 }
358
359 #[must_use]
364 pub const fn from_externally_managed(
365 tokens: ExternallyManaged,
366 auth_server: AuthServer,
367 access_token: Option<SecretAccessToken>,
368 ) -> Self {
369 Self::new(
370 OAuthGrant::ExternallyManaged(tokens),
371 auth_server,
372 access_token,
373 )
374 }
375
376 #[must_use]
381 pub const fn from_refresh_token(
382 tokens: RefreshToken,
383 auth_server: AuthServer,
384 access_token: Option<SecretAccessToken>,
385 ) -> Self {
386 Self::new(OAuthGrant::RefreshToken(tokens), auth_server, access_token)
387 }
388
389 #[must_use]
394 pub const fn from_client_credentials(
395 tokens: ClientCredentials,
396 auth_server: AuthServer,
397 access_token: Option<SecretAccessToken>,
398 ) -> Self {
399 Self::new(
400 OAuthGrant::ClientCredentials(tokens),
401 auth_server,
402 access_token,
403 )
404 }
405
406 #[must_use]
411 pub const fn from_pkce_flow(
412 flow: PkceFlow,
413 auth_server: AuthServer,
414 access_token: Option<SecretAccessToken>,
415 ) -> Self {
416 Self::new(OAuthGrant::PkceFlow(flow), auth_server, access_token)
417 }
418
419 pub fn access_token(&self) -> Result<&SecretAccessToken, TokenError> {
428 self.access_token.as_ref().ok_or(TokenError::NoAccessToken)
429 }
430
431 #[must_use]
433 pub const fn payload(&self) -> &OAuthGrant {
434 &self.payload
435 }
436
437 #[allow(clippy::missing_panics_doc)]
443 pub async fn request_access_token(&mut self) -> Result<&SecretAccessToken, TokenError> {
444 let access_token = self.payload.request_access_token(&self.auth_server).await?;
445 Ok(self.access_token.insert(access_token))
446 }
447
448 #[must_use]
450 pub const fn auth_server(&self) -> &AuthServer {
451 &self.auth_server
452 }
453
454 pub fn validate(&self) -> Result<SecretAccessToken, TokenError> {
462 let access_token = self.access_token()?;
463 insecure_validate_token_exp(access_token)?;
464 Ok(access_token.clone())
465 }
466}
467
468pub(crate) fn insecure_validate_token_exp(
472 access_token: &SecretAccessToken,
473) -> Result<(), TokenError> {
474 let placeholder_key = DecodingKey::from_secret(&[]);
475 let mut validation = Validation::new(Algorithm::RS256);
476 validation.validate_exp = true;
477 validation.leeway = 60;
478 validation.validate_aud = false;
479 validation.insecure_disable_signature_validation();
480
481 jsonwebtoken::decode::<toml::Value>(access_token.secret(), &placeholder_key, &validation)
482 .map(|_| ())
483 .map_err(TokenError::InvalidAccessToken)
484}
485
486impl std::fmt::Debug for OAuthSession {
487 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
488 let token_populated = if self.access_token.is_some() {
489 Some(())
490 } else {
491 None
492 };
493 f.debug_struct("OAuthSession")
494 .field("payload", &self.payload)
495 .field("access_token", &token_populated)
496 .field("auth_server", &self.auth_server)
497 .finish()
498 }
499}
500
/// Shared, cloneable handle to an [`OAuthSession`] that serializes token refreshes.
///
/// Clones share the same session, refresh flag, and notifier. Concurrent calls to
/// [`TokenDispatcher::refresh`] coalesce: one caller performs the refresh while the others
/// wait to be notified.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "qcs_api_client_common.configuration", frozen)
)]
pub struct TokenDispatcher {
    /// The session. A refresh holds the write lock while it runs.
    lock: Arc<RwLock<OAuthSession>>,
    /// `true` while some caller is performing a refresh.
    refreshing: Arc<Mutex<bool>>,
    /// Wakes callers waiting for an in-flight refresh to finish.
    notify_refreshed: Arc<Notify>,
}
513
514impl From<OAuthSession> for TokenDispatcher {
515 fn from(value: OAuthSession) -> Self {
516 Self {
517 lock: Arc::new(RwLock::new(value)),
518 refreshing: Arc::new(Mutex::new(false)),
519 notify_refreshed: Arc::new(Notify::new()),
520 }
521 }
522}
523
524impl TokenDispatcher {
525 pub async fn use_tokens<F, O>(&self, f: F) -> O
535 where
536 F: FnOnce(&OAuthSession) -> O + Send,
537 {
538 let tokens = self.lock.read().await;
539 f(&tokens)
540 }
541
542 #[must_use]
544 pub async fn tokens(&self) -> OAuthSession {
545 self.use_tokens(Clone::clone).await
546 }
547
548 pub async fn refresh(
554 &self,
555 source: &ConfigSource,
556 profile: &str,
557 ) -> Result<OAuthSession, TokenError> {
558 self.managed_refresh(Self::perform_refresh, source, profile)
559 .await
560 }
561
562 pub async fn validate(&self) -> Result<SecretAccessToken, TokenError> {
570 self.use_tokens(OAuthSession::validate).await
571 }
572
573 async fn managed_refresh<F, Fut>(
576 &self,
577 refresh_fn: F,
578 source: &ConfigSource,
579 profile: &str,
580 ) -> Result<OAuthSession, TokenError>
581 where
582 F: FnOnce(Arc<RwLock<OAuthSession>>) -> Fut + Send,
583 Fut: Future<Output = Result<OAuthSession, TokenError>> + Send,
584 {
585 let mut is_refreshing = self.refreshing.lock().await;
586
587 if *is_refreshing {
588 drop(is_refreshing);
589 self.notify_refreshed.notified().await;
590 return Ok(self.tokens().await);
591 }
592
593 *is_refreshing = true;
594 drop(is_refreshing);
595
596 let oauth_session = refresh_fn(self.lock.clone()).await?;
597
598 let write_result = if let ConfigSource::File {
600 settings_path: _,
601 secrets_path,
602 } = source
603 {
604 match Secrets::is_read_only(secrets_path).await {
605 Ok(true) => Ok(()),
606 Ok(false) => {
607 let refresh_token = match &oauth_session.payload {
611 OAuthGrant::PkceFlow(payload) => {
612 payload.refresh_token.as_ref().map(|rt| &rt.refresh_token)
613 }
614 OAuthGrant::RefreshToken(payload) => Some(&payload.refresh_token),
615 OAuthGrant::ExternallyManaged(_) | OAuthGrant::ClientCredentials(_) => None,
616 };
617
618 let now = OffsetDateTime::now_utc();
619 Secrets::write_tokens(
620 secrets_path,
621 profile,
622 refresh_token,
623 oauth_session.access_token()?,
624 now,
625 )
626 .await
627 }
628 Err(e) => Err(e),
629 }
630 } else {
631 Ok(())
632 };
633
634 *self.refreshing.lock().await = false;
636 self.notify_refreshed.notify_waiters();
637
638 if let Err(error) = write_result {
640 return Err(TokenError::Write {
641 error,
642 oauth_session: Box::new(oauth_session),
643 });
644 }
645
646 Ok(oauth_session)
647 }
648
649 async fn perform_refresh(lock: Arc<RwLock<OAuthSession>>) -> Result<OAuthSession, TokenError> {
656 let mut credentials = lock.write().await;
657 credentials.request_access_token().await?;
658 Ok(credentials.clone())
659 }
660}
661
/// The boxed, `Send` future returned by an externally managed refresh function.
/// It resolves to the new access token string or a boxed error.
pub(crate) type RefreshResult =
    Pin<Box<dyn Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>> + Send>>;

/// A user-supplied function that, given the [`AuthServer`], produces a [`RefreshResult`].
pub type RefreshFunction = Box<dyn (Fn(AuthServer) -> RefreshResult) + Send + Sync>;
667
/// A grant whose access tokens are obtained by a caller-supplied refresh function.
///
/// Cloning is cheap: the function is shared behind an [`Arc`].
#[derive(Clone)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "qcs_api_client_common.configuration", frozen)
)]
pub struct ExternallyManaged {
    /// The function invoked to obtain a new access token.
    refresh_function: Arc<RefreshFunction>,
}
681
682impl ExternallyManaged {
683 pub fn new(
708 refresh_function: impl Fn(AuthServer) -> RefreshResult + Send + Sync + 'static,
709 ) -> Self {
710 Self {
711 refresh_function: Arc::new(Box::new(refresh_function)),
712 }
713 }
714
715 pub fn from_async<F, Fut>(refresh_function: F) -> Self
748 where
749 F: Fn(AuthServer) -> Fut + Send + Sync + 'static,
750 Fut: Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>>
751 + Send
752 + 'static,
753 {
754 Self {
755 refresh_function: Arc::new(Box::new(move |auth_server| {
756 Box::pin(refresh_function(auth_server))
757 })),
758 }
759 }
760
761 pub fn from_sync(
792 refresh_function: impl Fn(
793 AuthServer,
794 ) -> Result<String, Box<dyn std::error::Error + Send + Sync>>
795 + Send
796 + Sync
797 + 'static,
798 ) -> Self {
799 Self {
800 refresh_function: Arc::new(Box::new(move |auth_server| {
801 let result = refresh_function(auth_server);
802 Box::pin(async move { result })
803 })),
804 }
805 }
806
807 pub async fn request_access_token(
813 &self,
814 auth_server: &AuthServer,
815 ) -> Result<SecretAccessToken, Box<dyn std::error::Error + Send + Sync>> {
816 (self.refresh_function)(auth_server.clone())
817 .await
818 .map(SecretAccessToken::from)
819 }
820}
821
822impl std::fmt::Debug for ExternallyManaged {
823 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
824 f.debug_struct("ExternallyManaged")
825 .field(
826 "refresh_function",
827 &"Fn() -> Pin<Box<dyn Future<Output = Result<String, TokenError>> + Send>>",
828 )
829 .finish()
830 }
831}
832
833#[derive(Debug, Serialize, Deserialize)]
834pub(super) struct TokenRefreshRequest<'a> {
835 grant_type: &'static str,
836 client_id: &'a str,
837 refresh_token: &'a str,
838}
839
840impl<'a> TokenRefreshRequest<'a> {
841 pub(super) const fn new(client_id: &'a str, refresh_token: &'a str) -> Self {
842 Self {
843 grant_type: "refresh_token",
844 client_id,
845 refresh_token,
846 }
847 }
848}
849
850#[derive(Debug, Serialize, Deserialize)]
851pub(super) struct ClientCredentialsRequest {
852 grant_type: &'static str,
853 scope: Option<&'static str>,
854}
855
856impl ClientCredentialsRequest {
857 pub(super) const fn new(scope: Option<&'static str>) -> Self {
858 Self {
859 grant_type: "client_credentials",
860 scope,
861 }
862 }
863}
864
/// Response body of a refresh-token grant request.
#[derive(Deserialize, Debug, Serialize)]
pub(super) struct RefreshTokenResponse {
    /// A rotated refresh token, if the server issued one. When `None`, the existing
    /// refresh token is kept.
    pub(super) refresh_token: Option<SecretRefreshToken>,
    /// The newly issued access token.
    pub(super) access_token: SecretAccessToken,
}
870
/// Abstraction over anything that can supply and refresh access tokens for API clients.
#[async_trait::async_trait]
pub trait TokenRefresher: Clone + std::fmt::Debug + Send {
    /// The error type returned by token operations.
    type Error;

    /// Returns an access token that has passed validation.
    async fn validated_access_token(&self) -> Result<SecretAccessToken, Self::Error>;

    /// Returns the current access token, if one is available, without refreshing.
    async fn get_access_token(&self) -> Result<Option<SecretAccessToken>, Self::Error>;

    /// Obtains a new access token from the auth server.
    async fn refresh_access_token(&self) -> Result<SecretAccessToken, Self::Error>;

    /// The base URL of the API this refresher authenticates against.
    #[cfg(feature = "tracing")]
    fn base_url(&self) -> &str;

    /// Optional tracing configuration used by [`TokenRefresher::should_trace`].
    #[cfg(feature = "tracing-config")]
    fn tracing_configuration(&self) -> Option<&TracingConfiguration>;

    /// Whether requests to `url` should be traced.
    ///
    /// Without `tracing-config`, everything is traced. With it, tracing follows
    /// `tracing_configuration()`, and everything is traced when that returns `None`.
    #[cfg(feature = "tracing")]
    // The explicit `return` is needed because the trailing expression below is
    // compiled out when `tracing-config` is disabled.
    #[allow(clippy::needless_return)]
    fn should_trace(&self, url: &UrlPatternMatchInput) -> bool {
        #[cfg(not(feature = "tracing-config"))]
        {
            let _ = url;
            return true;
        }

        #[cfg(feature = "tracing-config")]
        self.tracing_configuration()
            .is_none_or(|config| config.is_enabled(url))
    }
}
911
912#[async_trait::async_trait]
913impl TokenRefresher for ClientConfiguration {
914 type Error = TokenError;
915
916 async fn validated_access_token(&self) -> Result<SecretAccessToken, Self::Error> {
917 self.get_bearer_access_token().await
918 }
919
920 async fn refresh_access_token(&self) -> Result<SecretAccessToken, Self::Error> {
921 match self.refresh().await {
922 Ok(session) => Ok(session.access_token()?.clone()),
923 Err(TokenError::Write {
924 error,
925 oauth_session,
926 }) => {
927 #[cfg(feature = "tracing")]
929 tracing::warn!(
930 "Token refresh succeeded but failed to persist: {}. Returning access token from error.",
931 error
932 );
933 Ok(oauth_session.access_token()?.clone())
934 }
935 Err(e) => Err(e),
936 }
937 }
938
939 async fn get_access_token(&self) -> Result<Option<SecretAccessToken>, Self::Error> {
940 Ok(Some(self.oauth_session().await?.access_token()?.clone()))
941 }
942
943 #[cfg(feature = "tracing")]
944 fn base_url(&self) -> &str {
945 &self.grpc_api_url
946 }
947
948 #[cfg(feature = "tracing-config")]
949 fn tracing_configuration(&self) -> Option<&TracingConfiguration> {
950 self.tracing_configuration.as_ref()
951 }
952}
953
954pub(super) fn default_http_client()
956-> Result<qcs_dependencies_client::reqwest::Client, qcs_dependencies_client::reqwest::Error> {
957 qcs_dependencies_client::reqwest::Client::builder()
958 .timeout(std::time::Duration::from_secs(10))
959 .build()
960}
961
// Unit tests for token refresh coordination, persistence to the secrets file, and
// redaction in `Debug` output.
#[cfg(test)]
mod test {
    #![allow(clippy::result_large_err, reason = "happens in figment tests")]

    use std::time::Duration;

    use super::*;
    use httpmock::prelude::*;
    use rstest::rstest;
    use time::format_description::well_known::Rfc3339;
    use tokio::time::Instant;
    use toml_edit::DocumentMut;

    /// A read issued while a refresh is in flight should observe the refreshed session.
    ///
    /// The mocked token endpoint delays its response by 3 seconds.
    #[tokio::test]
    async fn test_tokens_blocked_during_refresh() {
        let mock_server = MockServer::start_async().await;

        let oidc_mock = mock_server
            .mock_async(|when, then| {
                when.method(GET).path("/.well-known/openid-configuration");
                then.status(200)
                    .json_body_obj(&oidc::Discovery::new_for_test(
                        mock_server.base_url().parse().unwrap(),
                    ));
            })
            .await;

        // Slow token endpoint, so the refresh holds the write lock long enough to observe
        // blocking.
        let issuer_mock = mock_server
            .mock_async(|when, then| {
                when.method(POST).path("/v1/token");

                then.status(200)
                    .delay(Duration::from_secs(3))
                    .json_body_obj(&RefreshTokenResponse {
                        access_token: SecretAccessToken::from("new_access"),
                        refresh_token: Some(SecretRefreshToken::from("new_refresh")),
                    });
            })
            .await;

        let original_tokens = OAuthSession::from_refresh_token(
            RefreshToken::new(SecretRefreshToken::from("refresh")),
            AuthServer {
                client_id: "client_id".to_string(),
                issuer: mock_server.base_url(),
                scopes: None,
            },
            None,
        );
        let dispatcher: TokenDispatcher = original_tokens.clone().into();
        let dispatcher_clone1 = dispatcher.clone();
        let dispatcher_clone2 = dispatcher.clone();

        let refresh_duration = Duration::from_secs(3);

        let start_write = Instant::now();
        let write_future = tokio::spawn(async move {
            dispatcher_clone1
                .refresh(&ConfigSource::Default, "")
                .await
                .unwrap()
        });

        let start_read = Instant::now();
        let read_future = tokio::spawn(async move { dispatcher_clone2.tokens().await });

        let _ = write_future.await.unwrap();
        let read_result = read_future.await.unwrap();

        // NOTE(review): `read_duration` is sampled only after `write_future` has been
        // awaited, so the second timing assertion would hold even if the read were not
        // blocked. The value assertions further down are what show the read saw the
        // refreshed session. They presumably rely on the default current-thread test
        // runtime scheduling the write task first.
        let write_duration = start_write.elapsed();
        let read_duration = start_read.elapsed();

        oidc_mock.assert_async().await;
        issuer_mock.assert_async().await;

        assert!(
            write_duration >= refresh_duration,
            "Write operation did not take enough time"
        );
        assert!(
            read_duration >= refresh_duration,
            "Read operation was not blocked by the write operation"
        );
        assert_eq!(
            read_result.access_token.unwrap(),
            SecretAccessToken::from("new_access")
        );
        if let OAuthGrant::RefreshToken(payload) = read_result.payload {
            assert_eq!(
                payload.refresh_token,
                SecretRefreshToken::from("new_refresh")
            );
        } else {
            panic!(
                "Expected RefreshToken payload, got {:?}",
                read_result.payload
            );
        }
    }

    /// A refresh writes tokens to the secrets file only when the file is writable.
    ///
    /// The file counts as read-only if `QCS_SECRETS_READ_ONLY` is truthy (`true`, `yes`,
    /// `1`, any case) or if the filesystem permissions forbid writing.
    #[rstest]
    fn test_qcs_secrets_readonly(
        #[values(
            (Some("TRUE"), true),
            (Some("tRue"), true),
            (Some("true"), true),
            (Some("YES"), true),
            (Some("yEs"), true),
            (Some("yes"), true),
            (Some("1"), true),
            (Some("2"), false),
            (Some("other"), false),
            (Some(""), false),
            (None, false),
        )]
        read_only_values: (Option<&str>, bool),
        #[values(true, false)] read_only_perm: bool,
    ) {
        let (maybe_read_only_env, env_is_read_only) = read_only_values;
        // The file should only be rewritten when neither mechanism marks it read-only.
        let expected_update = !env_is_read_only && !read_only_perm;
        figment::Jail::expect_with(|jail| {
            let profile_name = "test";
            let initial_access_token = "initial_access_token";
            let initial_refresh_token = "initial_refresh_token";

            let initial_secrets_file_contents = format!(
                r#"
[credentials]
[credentials.{profile_name}]
[credentials.{profile_name}.token_payload]
access_token = "{initial_access_token}"
expires_in = 3600
id_token = "id_token"
refresh_token = "{initial_refresh_token}"
scope = "offline_access openid profile email"
token_type = "Bearer"
updated_at = "2024-01-01T00:00:00Z"
"#
            );

            jail.clear_env();

            let secrets_path = "secrets.toml";
            jail.create_file(secrets_path, initial_secrets_file_contents.as_str())
                .expect("should create test secrets.toml");

            if read_only_perm {
                let mut permissions = std::fs::metadata(secrets_path)
                    .expect("Should be able to get file metadata")
                    .permissions();
                permissions.set_readonly(true);
                std::fs::set_permissions(secrets_path, permissions)
                    .expect("Should be able to set file permissions");
            }

            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async {
                let mock_server = MockServer::start_async().await;

                let oidc_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(GET).path("/.well-known/openid-configuration");
                        then.status(200)
                            .json_body_obj(&oidc::Discovery::new_for_test(mock_server.base_url().parse().unwrap()));
                    })
                    .await;

                let new_access_token = SecretAccessToken::from("new_access_token");
                let issuer_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(POST).path("/v1/token");
                        then.status(200).json_body_obj(&RefreshTokenResponse {
                            access_token: new_access_token.clone(),
                            refresh_token: Some(SecretRefreshToken::from(initial_refresh_token)),
                        });
                    })
                    .await;

                // NOTE(review): the cached access token is seeded with the refresh-token
                // string. This is harmless because `refresh()` replaces it, but it looks
                // like `initial_access_token` was intended.
                let original_tokens = OAuthSession::from_refresh_token(
                    RefreshToken::new(SecretRefreshToken::from(initial_refresh_token)),
                    AuthServer { client_id: "client_id".to_string(), issuer: mock_server.base_url(), scopes: None },
                    Some(SecretAccessToken::from(initial_refresh_token)),
                );
                let dispatcher: TokenDispatcher = original_tokens.into();

                jail.set_env("QCS_SECRETS_FILE_PATH", "secrets.toml");
                jail.set_env("QCS_PROFILE_NAME", "test");
                if let Some(read_only_env) = maybe_read_only_env {
                    jail.set_env("QCS_SECRETS_READ_ONLY", read_only_env);
                }

                let before_refresh = OffsetDateTime::now_utc();

                dispatcher
                    .refresh(
                        &ConfigSource::File {
                            settings_path: "".into(),
                            secrets_path: "secrets.toml".into(),
                        },
                        profile_name,
                    )
                    .await
                    .unwrap();

                oidc_mock.assert_async().await;
                issuer_mock.assert_async().await;

                let content = std::fs::read_to_string("secrets.toml").unwrap();
                // Read-only case: the file must be byte-for-byte unchanged.
                if !expected_update {
                    assert!(
                        content.eq(initial_secrets_file_contents.as_str()),
                        "File should not be updated when QCS_SECRETS_READ_ONLY is set or file permissions are read-only"
                    );
                    return;
                }

                let mut toml = std::fs::read_to_string(secrets_path)
                    .unwrap()
                    .parse::<DocumentMut>()
                    .unwrap();

                let token_payload = toml
                    .get_mut("credentials")
                    .and_then(|credentials| {
                        credentials.get_mut(profile_name)?.get_mut("token_payload")
                    })
                    .expect("Should be able to get token_payload table");

                let access_token = token_payload.get("access_token").unwrap().as_str().map(str::to_string).map(SecretAccessToken::from);

                assert_eq!(
                    access_token,
                    Some(new_access_token)
                );

                // `updated_at` must have been bumped by the write.
                assert!(
                    OffsetDateTime::parse(
                        token_payload.get("updated_at").unwrap().as_str().unwrap(),
                        &Rfc3339
                    )
                    .unwrap()
                        > before_refresh
                );

                let content = std::fs::read_to_string("secrets.toml").unwrap();
                assert!(
                    content.contains("new_access_token"),
                    "File should be updated with new access token when QCS_SECRETS_READ_ONLY is not set or is set but disabled, and file permissions allow writing"
                );
            });
            Ok(())
        });
    }

    /// When the server rotates the refresh token, both the new refresh token and the new
    /// access token are written to the secrets file.
    #[test]
    fn test_refresh_token_grant_persists_rotated_refresh_token() {
        let initial_refresh_token = "initial_refresh_token";
        let rotated_refresh_token = "rotated_refresh_token";
        let new_access_token = "new_access_token";

        figment::Jail::expect_with(|jail| {
            jail.clear_env();

            let secrets_path = "secrets.toml";
            let initial_secrets_file_contents = format!(
                r#"
[credentials]
[credentials.test]
[credentials.test.token_payload]
access_token = "initial_access_token"
refresh_token = "{initial_refresh_token}"
updated_at = "2024-01-01T00:00:00Z"
"#
            );
            jail.create_file(secrets_path, &initial_secrets_file_contents)
                .expect("should create test secrets.toml");

            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async {
                let mock_server = MockServer::start_async().await;
                let oidc_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(GET).path("/.well-known/openid-configuration");
                        then.status(200)
                            .json_body_obj(&oidc::Discovery::new_for_test(
                                mock_server.base_url().parse().unwrap(),
                            ));
                    })
                    .await;
                // The token endpoint returns a *different* refresh token (rotation).
                let issuer_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(POST).path("/v1/token");
                        then.status(200).json_body_obj(&RefreshTokenResponse {
                            access_token: SecretAccessToken::from(new_access_token),
                            refresh_token: Some(SecretRefreshToken::from(rotated_refresh_token)),
                        });
                    })
                    .await;

                let dispatcher: TokenDispatcher = OAuthSession::from_refresh_token(
                    RefreshToken::new(SecretRefreshToken::from(initial_refresh_token)),
                    AuthServer {
                        client_id: "client_id".to_string(),
                        issuer: mock_server.base_url(),
                        scopes: None,
                    },
                    Some(SecretAccessToken::from("initial_access_token")),
                )
                .into();

                dispatcher
                    .refresh(
                        &ConfigSource::File {
                            settings_path: "".into(),
                            secrets_path: secrets_path.into(),
                        },
                        "test",
                    )
                    .await
                    .expect("refresh should succeed");

                oidc_mock.assert_async().await;
                issuer_mock.assert_async().await;
            });

            let payload = Secrets::load_from_path(&secrets_path.into())
                .expect("should load secrets")
                .credentials
                .remove("test")
                .expect("should have test credentials")
                .token_payload
                .expect("should have token payload");
            assert_eq!(
                payload.refresh_token.unwrap(),
                SecretRefreshToken::from(rotated_refresh_token),
                "rotated refresh token should be persisted to the secrets file"
            );
            assert_eq!(
                payload.access_token.unwrap(),
                SecretAccessToken::from(new_access_token),
                "new access token should be persisted to the secrets file"
            );

            Ok(())
        });
    }

    /// `Debug` for `OAuthSession` must not leak the client id/secret or the access token.
    #[test]
    fn test_auth_session_debug_fmt() {
        let session = OAuthSession {
            payload: OAuthGrant::ClientCredentials(ClientCredentials::new(
                "hidden_id",
                "hidden_secret",
            )),
            access_token: Some(SecretAccessToken::from("token")),
            auth_server: AuthServer {
                client_id: "some_id".into(),
                issuer: "some_url".into(),
                scopes: None,
            },
        };

        assert_eq!(
            "OAuthSession { payload: ClientCredentials, access_token: Some(()), auth_server: AuthServer { client_id: \"some_id\", issuer: \"some_url\", scopes: None } }",
            &format!("{session:?}")
        );
    }
}