1use std::{pin::Pin, sync::Arc};
2
3use futures::Future;
4use jsonwebtoken::{Algorithm, DecodingKey, Validation};
5use serde::{Deserialize, Serialize};
6use time::OffsetDateTime;
7use tokio::sync::{Mutex, Notify, RwLock};
8
9use super::{
10 secrets::Secrets, settings::AuthServer, ClientConfiguration, ConfigSource, TokenError,
11 QCS_AUDIENCE,
12};
13#[cfg(feature = "tracing-config")]
14use crate::tracing_configuration::TracingConfiguration;
15#[cfg(feature = "tracing")]
16use urlpattern::UrlPatternMatchInput;
17
18#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
20#[cfg_attr(feature = "python", pyo3::pyclass)]
21pub struct RefreshToken {
22 pub refresh_token: String,
24}
25
26impl RefreshToken {
27 #[must_use]
29 pub const fn new(refresh_token: String) -> Self {
30 Self { refresh_token }
31 }
32
33 pub async fn request_access_token(
39 &mut self,
40 auth_server: &AuthServer,
41 ) -> Result<String, TokenError> {
42 if self.refresh_token.is_empty() {
43 return Err(TokenError::NoRefreshToken);
44 }
45 let token_url = format!("{}/v1/token", auth_server.issuer());
46 let data = TokenRefreshRequest::new(auth_server.client_id(), &self.refresh_token);
47 let resp = reqwest::Client::builder()
48 .timeout(std::time::Duration::from_secs(10))
49 .build()?
50 .post(token_url)
51 .form(&data)
52 .send()
53 .await?;
54
55 let response_data: RefreshTokenResponse = resp.error_for_status()?.json().await?;
56 self.refresh_token = response_data.refresh_token;
57 Ok(response_data.access_token)
58 }
59}
60
/// Token endpoint response to a `client_credentials` grant request.
///
/// Only `access_token` is read from the response body.
#[derive(Deserialize, Debug, Serialize)]
pub(super) struct ClientCredentialsResponse {
    /// The newly issued access token.
    pub(super) access_token: String,
}
65
/// A client ID and client secret used for the OAuth2 `client_credentials` grant.
///
/// Intentionally has no `Debug` implementation because it holds a secret.
#[derive(Clone, PartialEq, Eq, Deserialize)]
#[expect(missing_debug_implementations, reason = "contains secret data")]
#[cfg_attr(feature = "python", pyo3::pyclass)]
pub struct ClientCredentials {
    /// The OAuth2 client ID.
    pub client_id: String,
    /// The OAuth2 client secret.
    pub client_secret: String,
}
76
77impl ClientCredentials {
78 #[must_use]
79 pub const fn new(client_id: String, client_secret: String) -> Self {
81 Self {
82 client_id,
83 client_secret,
84 }
85 }
86
87 #[must_use]
89 pub fn client_id(&self) -> &str {
90 &self.client_id
91 }
92
93 #[must_use]
95 pub fn client_secret(&self) -> &str {
96 &self.client_secret
97 }
98
99 pub async fn request_access_token(
105 &self,
106 auth_server: &AuthServer,
107 ) -> Result<String, TokenError> {
108 let request = ClientCredentialsRequest::new(None);
109 let url = format!("{}/v1/token", auth_server.issuer());
110
111 let client = reqwest::Client::builder()
112 .timeout(std::time::Duration::from_secs(10))
113 .build()?;
114
115 let ready_to_send = client
116 .post(url)
117 .basic_auth(auth_server.client_id(), Some(&self.client_secret))
118 .form(&request);
119 let response = ready_to_send.send().await?;
120
121 response.error_for_status_ref()?;
122
123 let response_body: ClientCredentialsResponse = response.json().await?;
124
125 Ok(response_body.access_token)
126 }
127}
128
/// The mechanism an [`OAuthSession`] uses to obtain new access tokens.
///
/// NOTE(review): with the `python` feature, the derived `FromPyObject` tries variants in
/// declaration order, so reordering the variants may change which one a Python value maps to.
#[derive(Clone)]
#[cfg_attr(feature = "python", derive(pyo3::FromPyObject))]
pub enum OAuthGrant {
    /// Exchange a refresh token for a new access token.
    RefreshToken(RefreshToken),
    /// Request an access token with a client ID and secret (`client_credentials` grant).
    ClientCredentials(ClientCredentials),
    /// Delegate token acquisition to a caller-supplied function.
    ExternallyManaged(ExternallyManaged),
}
141
142impl From<ExternallyManaged> for OAuthGrant {
143 fn from(v: ExternallyManaged) -> Self {
144 Self::ExternallyManaged(v)
145 }
146}
147
148impl From<ClientCredentials> for OAuthGrant {
149 fn from(v: ClientCredentials) -> Self {
150 Self::ClientCredentials(v)
151 }
152}
153
154impl From<RefreshToken> for OAuthGrant {
155 fn from(v: RefreshToken) -> Self {
156 Self::RefreshToken(v)
157 }
158}
159
160impl OAuthGrant {
161 async fn request_access_token(
163 &mut self,
164 auth_server: &AuthServer,
165 ) -> Result<String, TokenError> {
166 match self {
167 Self::RefreshToken(tokens) => tokens.request_access_token(auth_server).await,
168 Self::ClientCredentials(tokens) => tokens.request_access_token(auth_server).await,
169 Self::ExternallyManaged(tokens) => tokens
170 .request_access_token(auth_server)
171 .await
172 .map_err(|e| TokenError::ExternallyManaged(e.to_string())),
173 }
174 }
175}
176
177impl std::fmt::Debug for OAuthGrant {
178 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179 match self {
180 Self::RefreshToken(_) => f.write_str("RefreshToken"),
181 Self::ClientCredentials(_) => f.write_str("ClientCredentials"),
182 Self::ExternallyManaged(_) => f.write_str("ExternallyManaged"),
183 }
184 }
185}
186
/// An OAuth session: the grant used to obtain tokens, the auth server to request them from,
/// and the most recently obtained access token, if any.
///
/// `Debug` is implemented by hand and reports only whether an access token is present.
#[derive(Clone)]
#[cfg_attr(feature = "python", pyo3::pyclass)]
pub struct OAuthSession {
    /// How new access tokens are obtained.
    payload: OAuthGrant,
    /// The current access token; `None` until one is supplied or requested.
    access_token: Option<String>,
    /// The auth server that token requests are sent to.
    auth_server: AuthServer,
}
208
209impl OAuthSession {
210 #[must_use]
215 pub const fn new(
216 payload: OAuthGrant,
217 auth_server: AuthServer,
218 access_token: Option<String>,
219 ) -> Self {
220 Self {
221 payload,
222 access_token,
223 auth_server,
224 }
225 }
226
227 #[must_use]
232 pub const fn from_externally_managed(
233 tokens: ExternallyManaged,
234 auth_server: AuthServer,
235 access_token: Option<String>,
236 ) -> Self {
237 Self::new(
238 OAuthGrant::ExternallyManaged(tokens),
239 auth_server,
240 access_token,
241 )
242 }
243
244 #[must_use]
249 pub const fn from_refresh_token(
250 tokens: RefreshToken,
251 auth_server: AuthServer,
252 access_token: Option<String>,
253 ) -> Self {
254 Self::new(OAuthGrant::RefreshToken(tokens), auth_server, access_token)
255 }
256
257 #[must_use]
262 pub const fn from_client_credentials(
263 tokens: ClientCredentials,
264 auth_server: AuthServer,
265 access_token: Option<String>,
266 ) -> Self {
267 Self::new(
268 OAuthGrant::ClientCredentials(tokens),
269 auth_server,
270 access_token,
271 )
272 }
273
274 pub fn access_token(&self) -> Result<&str, TokenError> {
283 self.access_token.as_ref().map_or_else(
284 || Err(TokenError::NoAccessToken),
285 |token| Ok(token.as_str()),
286 )
287 }
288
289 #[must_use]
291 pub const fn payload(&self) -> &OAuthGrant {
292 &self.payload
293 }
294
295 #[allow(clippy::missing_panics_doc)]
301 pub async fn request_access_token(&mut self) -> Result<&str, TokenError> {
302 let access_token = self.payload.request_access_token(&self.auth_server).await?;
303 self.access_token = Some(access_token);
304 Ok(self
305 .access_token
306 .as_ref()
307 .expect("This value is set in the previous line, so it cannot be None"))
308 }
309
310 #[must_use]
312 pub const fn auth_server(&self) -> &AuthServer {
313 &self.auth_server
314 }
315
316 pub fn validate(&self) -> Result<String, TokenError> {
324 self.access_token().map_or_else(
325 |_| Err(TokenError::NoAccessToken),
326 |access_token| {
327 let placeholder_key = DecodingKey::from_secret(&[]);
328 let mut validation = Validation::new(Algorithm::RS256);
329 validation.validate_exp = true;
330 validation.leeway = 60;
331 validation.set_audience(&[QCS_AUDIENCE]);
332 validation.insecure_disable_signature_validation();
333 jsonwebtoken::decode::<toml::Value>(access_token, &placeholder_key, &validation)
334 .map(|_| access_token.to_string())
335 .map_err(TokenError::InvalidAccessToken)
336 },
337 )
338 }
339}
340
341impl std::fmt::Debug for OAuthSession {
342 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
343 let token_populated = if self.access_token.is_some() {
344 Some(())
345 } else {
346 None
347 };
348 f.debug_struct("OAuthSession")
349 .field("payload", &self.payload)
350 .field("access_token", &token_populated)
351 .field("auth_server", &self.auth_server)
352 .finish()
353 }
354}
355
/// Shares one [`OAuthSession`] between clones and coordinates refreshes, so that concurrent
/// callers wait for a single in-flight refresh instead of each starting their own.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "python", pyo3::pyclass)]
pub struct TokenDispatcher {
    /// The shared session. Readers take the read lock; a refresh holds the write lock while it
    /// requests a new token.
    lock: Arc<RwLock<OAuthSession>>,
    /// `true` while a refresh is in progress.
    refreshing: Arc<Mutex<bool>>,
    /// Used to wake callers that are waiting for an in-progress refresh to finish.
    notify_refreshed: Arc<Notify>,
}
364
365impl From<OAuthSession> for TokenDispatcher {
366 fn from(value: OAuthSession) -> Self {
367 Self {
368 lock: Arc::new(RwLock::new(value)),
369 refreshing: Arc::new(Mutex::new(false)),
370 notify_refreshed: Arc::new(Notify::new()),
371 }
372 }
373}
374
375impl TokenDispatcher {
376 pub async fn use_tokens<F, O>(&self, f: F) -> O
386 where
387 F: FnOnce(&OAuthSession) -> O + Send,
388 {
389 let tokens = self.lock.read().await;
390 f(&tokens)
391 }
392
393 #[must_use]
395 pub async fn tokens(&self) -> OAuthSession {
396 self.use_tokens(Clone::clone).await
397 }
398
399 pub async fn refresh(
405 &self,
406 source: &ConfigSource,
407 profile: &str,
408 ) -> Result<OAuthSession, TokenError> {
409 self.managed_refresh(Self::perform_refresh, source, profile)
410 .await
411 }
412
413 pub async fn validate(&self) -> Result<String, TokenError> {
421 self.use_tokens(OAuthSession::validate).await
422 }
423
424 async fn managed_refresh<F, Fut>(
427 &self,
428 refresh_fn: F,
429 source: &ConfigSource,
430 profile: &str,
431 ) -> Result<OAuthSession, TokenError>
432 where
433 F: FnOnce(Arc<RwLock<OAuthSession>>) -> Fut + Send,
434 Fut: Future<Output = Result<OAuthSession, TokenError>> + Send,
435 {
436 let mut is_refreshing = self.refreshing.lock().await;
437
438 if *is_refreshing {
439 drop(is_refreshing);
440 self.notify_refreshed.notified().await;
441 return Ok(self.tokens().await);
442 }
443
444 *is_refreshing = true;
445 drop(is_refreshing);
446
447 let oauth_session = refresh_fn(self.lock.clone()).await?;
448
449 if let ConfigSource::File {
451 settings_path: _,
452 secrets_path,
453 } = source
454 {
455 if !Secrets::is_read_only(secrets_path).await? {
456 let now = OffsetDateTime::now_utc();
457 Secrets::write_access_token(
458 secrets_path,
459 profile,
460 oauth_session.access_token()?,
461 now,
462 )
463 .await?;
464 }
465 }
466
467 *self.refreshing.lock().await = false;
468 self.notify_refreshed.notify_waiters();
469 Ok(oauth_session)
470 }
471
472 async fn perform_refresh(lock: Arc<RwLock<OAuthSession>>) -> Result<OAuthSession, TokenError> {
479 let mut credentials = lock.write().await;
480 credentials.request_access_token().await?;
481 Ok(credentials.clone())
482 }
483}
484
/// The future returned by a `RefreshFunction`. It is boxed, pinned, and `Send`, and resolves
/// to an access token or a boxed error.
pub(crate) type RefreshResult =
    Pin<Box<dyn Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>> + Send>>;

/// A boxed, thread-safe function that receives an `AuthServer` and returns a future resolving
/// to an access token.
pub type RefreshFunction = Box<dyn (Fn(AuthServer) -> RefreshResult) + Send + Sync>;
490
/// A grant whose access tokens are produced by a caller-supplied function, rather than by a
/// token request this crate makes itself.
///
/// `Debug` is implemented by hand because the wrapped function cannot be printed.
#[derive(Clone)]
#[cfg_attr(feature = "python", pyo3::pyclass)]
pub struct ExternallyManaged {
    /// Kept behind an `Arc` so that clones share the same function.
    refresh_function: Arc<RefreshFunction>,
}
500
501impl ExternallyManaged {
502 pub fn new(
527 refresh_function: impl Fn(AuthServer) -> RefreshResult + Send + Sync + 'static,
528 ) -> Self {
529 Self {
530 refresh_function: Arc::new(Box::new(refresh_function)),
531 }
532 }
533
534 pub fn from_async<F, Fut>(refresh_function: F) -> Self
567 where
568 F: Fn(AuthServer) -> Fut + Send + Sync + 'static,
569 Fut: Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>>
570 + Send
571 + 'static,
572 {
573 Self {
574 refresh_function: Arc::new(Box::new(move |auth_server| {
575 Box::pin(refresh_function(auth_server))
576 })),
577 }
578 }
579
580 pub fn from_sync(
611 refresh_function: impl Fn(AuthServer) -> Result<String, Box<dyn std::error::Error + Send + Sync>>
612 + Send
613 + Sync
614 + 'static,
615 ) -> Self {
616 Self {
617 refresh_function: Arc::new(Box::new(move |auth_server| {
618 let result = refresh_function(auth_server);
619 Box::pin(async move { result })
620 })),
621 }
622 }
623
624 pub async fn request_access_token(
630 &self,
631 auth_server: &AuthServer,
632 ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
633 (self.refresh_function)(auth_server.clone()).await
634 }
635}
636
637impl std::fmt::Debug for ExternallyManaged {
638 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
639 f.debug_struct("ExternallyManaged")
640 .field(
641 "refresh_function",
642 &"Fn() -> Pin<Box<dyn Future<Output = Result<String, TokenError>> + Send>>",
643 )
644 .finish()
645 }
646}
647
/// Form body for an OAuth2 `refresh_token` grant request.
///
/// NOTE(review): the derived `Debug` prints `refresh_token` in full; avoid logging this value.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct TokenRefreshRequest<'a> {
    /// Always `"refresh_token"` when built with `new`.
    grant_type: &'static str,
    /// Client ID sent with the request (`RefreshToken::request_access_token` takes it from the
    /// `AuthServer`).
    client_id: &'a str,
    /// The refresh token being exchanged.
    refresh_token: &'a str,
}
654
655impl<'a> TokenRefreshRequest<'a> {
656 pub(super) const fn new(client_id: &'a str, refresh_token: &'a str) -> Self {
657 Self {
658 grant_type: "refresh_token",
659 client_id,
660 refresh_token,
661 }
662 }
663}
664
/// Form body for an OAuth2 `client_credentials` grant request.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct ClientCredentialsRequest {
    /// Always `"client_credentials"` when built with `new`.
    grant_type: &'static str,
    /// Optional requested scope (`ClientCredentials::request_access_token` passes `None`).
    scope: Option<&'static str>,
}
670
671impl ClientCredentialsRequest {
672 pub(super) const fn new(scope: Option<&'static str>) -> Self {
673 Self {
674 grant_type: "client_credentials",
675 scope,
676 }
677 }
678}
679
/// Token endpoint response to a `refresh_token` grant request.
///
/// NOTE(review): both fields are required when deserializing, so a response without a new
/// `refresh_token` is rejected. Confirm the issuer always rotates refresh tokens. The derived
/// `Debug` also prints both tokens in full.
#[derive(Deserialize, Debug, Serialize)]
pub(super) struct RefreshTokenResponse {
    /// The rotated refresh token; it replaces the stored one.
    pub(super) refresh_token: String,
    /// The newly issued access token.
    pub(super) access_token: String,
}
685
#[async_trait::async_trait]
/// Supplies, validates, and refreshes access tokens for API clients.
pub trait TokenRefresher: Clone + std::fmt::Debug + Send {
    /// The error type returned by this trait's token operations.
    type Error;

    /// Return a validated access token.
    ///
    /// NOTE(review): what "validated" entails (for example, whether an invalid token triggers
    /// a refresh) is up to each implementation. The `ClientConfiguration` impl delegates to
    /// `get_bearer_access_token`.
    async fn validated_access_token(&self) -> Result<String, Self::Error>;

    /// Return the current access token, if there is one.
    async fn get_access_token(&self) -> Result<Option<String>, Self::Error>;

    /// Obtain a new access token and return it.
    async fn refresh_access_token(&self) -> Result<String, Self::Error>;

    /// The base URL of the API being called (the `ClientConfiguration` impl returns its gRPC
    /// API URL).
    #[cfg(feature = "tracing")]
    fn base_url(&self) -> &str;

    /// The tracing configuration, if one is set.
    #[cfg(feature = "tracing-config")]
    fn tracing_configuration(&self) -> Option<&TracingConfiguration>;

    /// Whether requests to `url` should be traced.
    ///
    /// Without the `tracing-config` feature, every URL is traced. With it, every URL is traced
    /// when no tracing configuration is set; otherwise the configuration decides.
    #[cfg(feature = "tracing")]
    // The explicit `return` is required: a cfg-gated tail expression follows it.
    #[allow(clippy::needless_return)]
    fn should_trace(&self, url: &UrlPatternMatchInput) -> bool {
        #[cfg(not(feature = "tracing-config"))]
        {
            let _ = url;
            return true;
        }

        #[cfg(feature = "tracing-config")]
        self.tracing_configuration()
            .is_none_or(|config| config.is_enabled(url))
    }
}
726
727#[async_trait::async_trait]
728impl TokenRefresher for ClientConfiguration {
729 type Error = TokenError;
730
731 async fn validated_access_token(&self) -> Result<String, Self::Error> {
732 self.get_bearer_access_token().await
733 }
734
735 async fn refresh_access_token(&self) -> Result<String, Self::Error> {
736 Ok(self.refresh().await?.access_token()?.to_string())
737 }
738
739 async fn get_access_token(&self) -> Result<Option<String>, Self::Error> {
740 Ok(Some(
741 self.oauth_session().await?.access_token()?.to_string(),
742 ))
743 }
744
745 #[cfg(feature = "tracing")]
746 fn base_url(&self) -> &str {
747 &self.grpc_api_url
748 }
749
750 #[cfg(feature = "tracing-config")]
751 fn tracing_configuration(&self) -> Option<&TracingConfiguration> {
752 self.tracing_configuration.as_ref()
753 }
754}
755
// Tests for token refresh coordination, secrets-file persistence, and redacted `Debug` output.
#[cfg(test)]
mod test {
    use std::time::Duration;

    use super::*;
    use httpmock::prelude::*;
    use rstest::rstest;
    use time::format_description::well_known::Rfc3339;
    use tokio::time::Instant;
    use toml_edit::DocumentMut;

    // A read issued while a refresh is in flight should observe the refreshed tokens, because
    // the refresh holds the session's write lock until the (delayed) token response arrives.
    #[tokio::test]
    async fn test_tokens_blocked_during_refresh() {
        let mock_server = MockServer::start_async().await;

        let issuer_mock = mock_server
            .mock_async(|when, then| {
                when.method(POST).path("/v1/token");

                // The delay keeps the refresh running long enough for the read to overlap it.
                then.status(200)
                    .delay(Duration::from_secs(3))
                    .json_body_obj(&RefreshTokenResponse {
                        access_token: "new_access".to_string(),
                        refresh_token: "new_refresh".to_string(),
                    });
            })
            .await;

        let original_tokens = OAuthSession::from_refresh_token(
            RefreshToken::new("refresh".to_string()),
            AuthServer::new("client_id".to_string(), mock_server.base_url()),
            None,
        );
        let dispatcher: TokenDispatcher = original_tokens.clone().into();
        let dispatcher_clone1 = dispatcher.clone();
        let dispatcher_clone2 = dispatcher.clone();

        let refresh_duration = Duration::from_secs(3);

        let start_write = Instant::now();
        let write_future = tokio::spawn(async move {
            dispatcher_clone1
                .refresh(&ConfigSource::Default, "")
                .await
                .unwrap()
        });

        // NOTE(review): nothing forces the spawned refresh to take the write lock before this
        // read takes the read lock. If the read wins, it sees `access_token: None` and the
        // `unwrap` below panics, so confirm this test is not flaky.
        let start_read = Instant::now();
        let read_future = tokio::spawn(async move { dispatcher_clone2.tokens().await });

        let _ = write_future.await.unwrap();
        let read_result = read_future.await.unwrap();

        // NOTE(review): both durations are measured only after awaiting the write task, so the
        // read-duration assertion holds even if the read was never blocked. The access-token
        // check at the end is what actually shows the read saw the refreshed session.
        let write_duration = start_write.elapsed();
        let read_duration = start_read.elapsed();

        issuer_mock.assert_async().await;

        assert!(
            write_duration >= refresh_duration,
            "Write operation did not take enough time"
        );
        assert!(
            read_duration >= refresh_duration,
            "Read operation was not blocked by the write operation"
        );
        assert_eq!(read_result.access_token.as_ref().unwrap(), "new_access");
        if let OAuthGrant::RefreshToken(payload) = read_result.payload {
            assert_eq!(&payload.refresh_token, "new_refresh");
        } else {
            panic!(
                "Expected RefreshToken payload, got {:?}",
                read_result.payload
            );
        }
    }

    // A refresh writes the new access token to the secrets file only when neither
    // `QCS_SECRETS_READ_ONLY` (per the value table below) nor the file's read-only permission
    // marks the file as read-only.
    #[rstest]
    fn test_qcs_secrets_readonly(
        #[values(
            (Some("TRUE"), true),
            (Some("tRue"), true),
            (Some("true"), true),
            (Some("YES"), true),
            (Some("yEs"), true),
            (Some("yes"), true),
            (Some("1"), true),
            (Some("2"), false),
            (Some("other"), false),
            (Some(""), false),
            (None, false),
        )]
        read_only_values: (Option<&str>, bool),
        #[values(true, false)] read_only_perm: bool,
    ) {
        let (maybe_read_only_env, env_is_read_only) = read_only_values;
        let expected_update = !env_is_read_only && !read_only_perm;
        figment::Jail::expect_with(|jail| {
            let profile_name = "test";
            let initial_access_token = "initial_access_token";
            let initial_refresh_token = "initial_refresh_token";

            let initial_secrets_file_contents = format!(
                r#"
[credentials]
[credentials.{profile_name}]
[credentials.{profile_name}.token_payload]
access_token = "{initial_access_token}"
expires_in = 3600
id_token = "id_token"
refresh_token = "{initial_refresh_token}"
scope = "offline_access openid profile email"
token_type = "Bearer"
updated_at = "2024-01-01T00:00:00Z"
"#
            );

            jail.clear_env();

            let secrets_path = "secrets.toml";
            jail.create_file(secrets_path, initial_secrets_file_contents.as_str())
                .expect("should create test secrets.toml");

            if read_only_perm {
                let mut permissions = std::fs::metadata(secrets_path)
                    .expect("Should be able to get file metadata")
                    .permissions();
                permissions.set_readonly(true);
                std::fs::set_permissions(secrets_path, permissions)
                    .expect("Should be able to set file permissions");
            }

            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async {
                let mock_server = MockServer::start_async().await;

                let new_access_token = "new_access_token";
                let issuer_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(POST).path("/v1/token");
                        then.status(200).json_body_obj(&RefreshTokenResponse {
                            access_token: new_access_token.to_string(),
                            refresh_token: initial_refresh_token.to_string(),
                        });
                    })
                    .await;

                // NOTE(review): the third argument is the initial *access* token, but the refresh
                // token is passed here; `initial_access_token` looks intended. It is overwritten
                // by the refresh either way.
                let original_tokens = OAuthSession::from_refresh_token(
                    RefreshToken::new(initial_refresh_token.to_string()),
                    AuthServer::new("client_id".to_string(), mock_server.base_url()),
                    Some(initial_refresh_token.to_string()),
                );
                let dispatcher: TokenDispatcher = original_tokens.into();

                jail.set_env("QCS_SECRETS_FILE_PATH", "secrets.toml");
                jail.set_env("QCS_PROFILE_NAME", "test");
                if let Some(read_only_env) = maybe_read_only_env {
                    jail.set_env("QCS_SECRETS_READ_ONLY", read_only_env);
                }

                let before_refresh = OffsetDateTime::now_utc();

                dispatcher
                    .refresh(
                        &ConfigSource::File {
                            settings_path: "".into(),
                            secrets_path: "secrets.toml".into(),
                        },
                        profile_name,
                    )
                    .await
                    .unwrap();

                issuer_mock.assert_async().await;

                let content = std::fs::read_to_string("secrets.toml").unwrap();
                if !expected_update {
                    assert!(
                        content.eq(initial_secrets_file_contents.as_str()),
                        "File should not be updated when QCS_SECRETS_READ_ONLY is set or file permissions are read-only"
                    );
                    return;
                }

                let mut toml = std::fs::read_to_string(secrets_path)
                    .unwrap()
                    .parse::<DocumentMut>()
                    .unwrap();

                let token_payload = toml
                    .get_mut("credentials")
                    .and_then(|credentials| {
                        credentials.get_mut(profile_name)?.get_mut("token_payload")
                    })
                    .expect("Should be able to get token_payload table");

                assert_eq!(
                    token_payload.get("access_token").unwrap().as_str().unwrap(),
                    new_access_token
                );

                // `updated_at` must have been rewritten with a timestamp taken during the refresh.
                assert!(
                    OffsetDateTime::parse(
                        token_payload.get("updated_at").unwrap().as_str().unwrap(),
                        &Rfc3339
                    )
                    .unwrap()
                        > before_refresh
                );

                let content = std::fs::read_to_string("secrets.toml").unwrap();
                assert!(
                    content.contains("new_access_token"),
                    "File should be updated with new access token when QCS_SECRETS_READ_ONLY is not set or is set but disabled, and file permissions allow writing"
                );
            });
            Ok(())
        });
    }

    // `OAuthSession`'s `Debug` must not leak the client ID/secret or the access token value.
    #[test]
    fn test_auth_session_debug_fmt() {
        let session = OAuthSession {
            payload: OAuthGrant::ClientCredentials(ClientCredentials {
                client_id: "hidden_id".into(),
                client_secret: "hidden_secret".into(),
            }),
            access_token: Some("token".into()),
            auth_server: AuthServer::new("some_id".into(), "some_url".into()),
        };

        assert_eq!("OAuthSession { payload: ClientCredentials, access_token: Some(()), auth_server: AuthServer { client_id: \"some_id\", issuer: \"some_url\" } }", &format!("{session:?}"));
    }
}