1use std::{pin::Pin, sync::Arc};
2
3use futures::Future;
4use http::{header::CONTENT_TYPE, HeaderMap, HeaderValue};
5use jsonwebtoken::{Algorithm, DecodingKey, Validation};
6use serde::{Deserialize, Serialize};
7use time::OffsetDateTime;
8use tokio::sync::{Mutex, Notify, RwLock};
9
10use super::{
11 secrets::Secrets, settings::AuthServer, ClientConfiguration, ConfigSource, TokenError,
12 QCS_AUDIENCE,
13};
14#[cfg(feature = "tracing-config")]
15use crate::tracing_configuration::TracingConfiguration;
16#[cfg(feature = "tracing")]
17use urlpattern::UrlPatternMatchInput;
18
19#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
21#[cfg_attr(feature = "python", pyo3::pyclass)]
22pub struct RefreshToken {
23 pub refresh_token: String,
25}
26
27impl RefreshToken {
28 #[must_use]
30 pub const fn new(refresh_token: String) -> Self {
31 Self { refresh_token }
32 }
33
34 pub async fn request_access_token(
40 &mut self,
41 auth_server: &AuthServer,
42 ) -> Result<String, TokenError> {
43 if self.refresh_token.is_empty() {
44 return Err(TokenError::NoRefreshToken);
45 }
46 let token_url = format!("{}/v1/token", auth_server.issuer());
47 let data = TokenRefreshRequest::new(auth_server.client_id(), &self.refresh_token);
48 let resp = reqwest::Client::builder()
49 .timeout(std::time::Duration::from_secs(10))
50 .build()?
51 .post(token_url)
52 .form(&data)
53 .send()
54 .await?;
55
56 let response_data: TokenResponse = resp.error_for_status()?.json().await?;
57 self.refresh_token = response_data.refresh_token;
58 Ok(response_data.access_token)
59 }
60}
61
62#[derive(Clone, PartialEq, Eq, Deserialize)]
63#[expect(missing_debug_implementations, reason = "contains secret data")]
64#[cfg_attr(feature = "python", pyo3::pyclass)]
65pub struct ClientCredentials {
67 pub client_id: String,
69 pub client_secret: String,
71}
72
73impl ClientCredentials {
74 #[must_use]
75 pub const fn new(client_id: String, client_secret: String) -> Self {
77 Self {
78 client_id,
79 client_secret,
80 }
81 }
82
83 #[must_use]
85 pub fn client_id(&self) -> &str {
86 &self.client_id
87 }
88
89 #[must_use]
91 pub fn client_secret(&self) -> &str {
92 &self.client_secret
93 }
94
95 pub async fn request_access_token(
101 &self,
102 auth_server: &AuthServer,
103 ) -> Result<String, TokenError> {
104 let request = ClientCredentialsRequest::new(&self.client_id, &self.client_secret);
105 let url = format!("{}/v1/token", auth_server.issuer());
106
107 let mut headers = HeaderMap::new();
111 headers.insert(
113 CONTENT_TYPE,
114 HeaderValue::from_static("application/x-www-form-urlencoded"),
115 );
116
117 let client = reqwest::Client::builder()
118 .timeout(std::time::Duration::from_secs(10))
119 .build()?;
120
121 let response = client
122 .post(url)
123 .headers(headers)
124 .form(&request)
125 .send()
126 .await?;
127
128 response.error_for_status_ref()?;
129
130 let response_body: TokenResponse = response.json().await?;
131
132 Ok(response_body.access_token)
133 }
134}
135
136#[derive(Clone)]
137#[cfg_attr(feature = "python", derive(pyo3::FromPyObject))]
138pub enum OAuthGrant {
141 RefreshToken(RefreshToken),
143 ClientCredentials(ClientCredentials),
145 ExternallyManaged(ExternallyManaged),
147}
148
149impl From<ExternallyManaged> for OAuthGrant {
150 fn from(v: ExternallyManaged) -> Self {
151 Self::ExternallyManaged(v)
152 }
153}
154
155impl From<ClientCredentials> for OAuthGrant {
156 fn from(v: ClientCredentials) -> Self {
157 Self::ClientCredentials(v)
158 }
159}
160
161impl From<RefreshToken> for OAuthGrant {
162 fn from(v: RefreshToken) -> Self {
163 Self::RefreshToken(v)
164 }
165}
166
167impl OAuthGrant {
168 async fn request_access_token(
170 &mut self,
171 auth_server: &AuthServer,
172 ) -> Result<String, TokenError> {
173 match self {
174 Self::RefreshToken(tokens) => tokens.request_access_token(auth_server).await,
175 Self::ClientCredentials(tokens) => tokens.request_access_token(auth_server).await,
176 Self::ExternallyManaged(tokens) => tokens
177 .request_access_token(auth_server)
178 .await
179 .map_err(|e| TokenError::ExternallyManaged(e.to_string())),
180 }
181 }
182}
183
184impl std::fmt::Debug for OAuthGrant {
185 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
186 match self {
187 Self::RefreshToken(_) => f.write_str("RefreshToken"),
188 Self::ClientCredentials(_) => f.write_str("ClientCredentials"),
189 Self::ExternallyManaged(_) => f.write_str("ExternallyManaged"),
190 }
191 }
192}
193
194#[derive(Clone)]
206#[cfg_attr(feature = "python", pyo3::pyclass)]
207pub struct OAuthSession {
208 payload: OAuthGrant,
210 access_token: Option<String>,
212 auth_server: AuthServer,
214}
215
216impl OAuthSession {
217 #[must_use]
222 pub const fn new(
223 payload: OAuthGrant,
224 auth_server: AuthServer,
225 access_token: Option<String>,
226 ) -> Self {
227 Self {
228 payload,
229 access_token,
230 auth_server,
231 }
232 }
233
234 #[must_use]
239 pub const fn from_externally_managed(
240 tokens: ExternallyManaged,
241 auth_server: AuthServer,
242 access_token: Option<String>,
243 ) -> Self {
244 Self::new(
245 OAuthGrant::ExternallyManaged(tokens),
246 auth_server,
247 access_token,
248 )
249 }
250
251 #[must_use]
256 pub const fn from_refresh_token(
257 tokens: RefreshToken,
258 auth_server: AuthServer,
259 access_token: Option<String>,
260 ) -> Self {
261 Self::new(OAuthGrant::RefreshToken(tokens), auth_server, access_token)
262 }
263
264 #[must_use]
269 pub const fn from_client_credentials(
270 tokens: ClientCredentials,
271 auth_server: AuthServer,
272 access_token: Option<String>,
273 ) -> Self {
274 Self::new(
275 OAuthGrant::ClientCredentials(tokens),
276 auth_server,
277 access_token,
278 )
279 }
280
281 pub fn access_token(&self) -> Result<&str, TokenError> {
290 self.access_token.as_ref().map_or_else(
291 || Err(TokenError::NoAccessToken),
292 |token| Ok(token.as_str()),
293 )
294 }
295
296 #[must_use]
298 pub const fn payload(&self) -> &OAuthGrant {
299 &self.payload
300 }
301
302 #[allow(clippy::missing_panics_doc)]
308 pub async fn request_access_token(&mut self) -> Result<&str, TokenError> {
309 let access_token = self.payload.request_access_token(&self.auth_server).await?;
310 self.access_token = Some(access_token);
311 Ok(self
312 .access_token
313 .as_ref()
314 .expect("This value is set in the previous line, so it cannot be None"))
315 }
316
317 #[must_use]
319 pub const fn auth_server(&self) -> &AuthServer {
320 &self.auth_server
321 }
322
323 pub fn validate(&self) -> Result<String, TokenError> {
331 self.access_token().map_or_else(
332 |_| Err(TokenError::NoAccessToken),
333 |access_token| {
334 let placeholder_key = DecodingKey::from_secret(&[]);
335 let mut validation = Validation::new(Algorithm::RS256);
336 validation.validate_exp = true;
337 validation.leeway = 60;
338 validation.set_audience(&[QCS_AUDIENCE]);
339 validation.insecure_disable_signature_validation();
340 jsonwebtoken::decode::<toml::Value>(access_token, &placeholder_key, &validation)
341 .map(|_| access_token.to_string())
342 .map_err(TokenError::InvalidAccessToken)
343 },
344 )
345 }
346}
347
348impl std::fmt::Debug for OAuthSession {
349 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
350 let token_populated = if self.access_token.is_some() {
351 Some(())
352 } else {
353 None
354 };
355 f.debug_struct("OAuthSession")
356 .field("payload", &self.payload)
357 .field("access_token", &token_populated)
358 .field("auth_server", &self.auth_server)
359 .finish()
360 }
361}
362
363#[derive(Clone, Debug)]
365#[cfg_attr(feature = "python", pyo3::pyclass)]
366pub struct TokenDispatcher {
367 lock: Arc<RwLock<OAuthSession>>,
368 refreshing: Arc<Mutex<bool>>,
369 notify_refreshed: Arc<Notify>,
370}
371
372impl From<OAuthSession> for TokenDispatcher {
373 fn from(value: OAuthSession) -> Self {
374 Self {
375 lock: Arc::new(RwLock::new(value)),
376 refreshing: Arc::new(Mutex::new(false)),
377 notify_refreshed: Arc::new(Notify::new()),
378 }
379 }
380}
381
382impl TokenDispatcher {
383 pub async fn use_tokens<F, O>(&self, f: F) -> O
393 where
394 F: FnOnce(&OAuthSession) -> O + Send,
395 {
396 let tokens = self.lock.read().await;
397 f(&tokens)
398 }
399
400 #[must_use]
402 pub async fn tokens(&self) -> OAuthSession {
403 self.use_tokens(Clone::clone).await
404 }
405
406 pub async fn refresh(
412 &self,
413 source: &ConfigSource,
414 profile: &str,
415 ) -> Result<OAuthSession, TokenError> {
416 self.managed_refresh(Self::perform_refresh, source, profile)
417 .await
418 }
419
420 pub async fn validate(&self) -> Result<String, TokenError> {
428 self.use_tokens(OAuthSession::validate).await
429 }
430
431 async fn managed_refresh<F, Fut>(
434 &self,
435 refresh_fn: F,
436 source: &ConfigSource,
437 profile: &str,
438 ) -> Result<OAuthSession, TokenError>
439 where
440 F: FnOnce(Arc<RwLock<OAuthSession>>) -> Fut + Send,
441 Fut: Future<Output = Result<OAuthSession, TokenError>> + Send,
442 {
443 let mut is_refreshing = self.refreshing.lock().await;
444
445 if *is_refreshing {
446 drop(is_refreshing);
447 self.notify_refreshed.notified().await;
448 return Ok(self.tokens().await);
449 }
450
451 *is_refreshing = true;
452 drop(is_refreshing);
453
454 let oauth_session = refresh_fn(self.lock.clone()).await?;
455
456 if let ConfigSource::File {
458 settings_path: _,
459 secrets_path,
460 } = source
461 {
462 if !Secrets::is_read_only(secrets_path).await? {
463 let now = OffsetDateTime::now_utc();
464 Secrets::write_access_token(
465 secrets_path,
466 profile,
467 oauth_session.access_token()?,
468 now,
469 )
470 .await?;
471 }
472 }
473
474 *self.refreshing.lock().await = false;
475 self.notify_refreshed.notify_waiters();
476 Ok(oauth_session)
477 }
478
479 async fn perform_refresh(lock: Arc<RwLock<OAuthSession>>) -> Result<OAuthSession, TokenError> {
486 let mut credentials = lock.write().await;
487 credentials.request_access_token().await?;
488 Ok(credentials.clone())
489 }
490}
491
492pub(crate) type RefreshResult =
493 Pin<Box<dyn Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>> + Send>>;
494
495pub type RefreshFunction = Box<dyn (Fn(AuthServer) -> RefreshResult) + Send + Sync>;
497
498#[derive(Clone)]
503#[cfg_attr(feature = "python", pyo3::pyclass)]
504pub struct ExternallyManaged {
505 refresh_function: Arc<RefreshFunction>,
506}
507
508impl ExternallyManaged {
509 pub fn new(
534 refresh_function: impl Fn(AuthServer) -> RefreshResult + Send + Sync + 'static,
535 ) -> Self {
536 Self {
537 refresh_function: Arc::new(Box::new(refresh_function)),
538 }
539 }
540
541 pub fn from_async<F, Fut>(refresh_function: F) -> Self
574 where
575 F: Fn(AuthServer) -> Fut + Send + Sync + 'static,
576 Fut: Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>>
577 + Send
578 + 'static,
579 {
580 Self {
581 refresh_function: Arc::new(Box::new(move |auth_server| {
582 Box::pin(refresh_function(auth_server))
583 })),
584 }
585 }
586
587 pub fn from_sync(
618 refresh_function: impl Fn(AuthServer) -> Result<String, Box<dyn std::error::Error + Send + Sync>>
619 + Send
620 + Sync
621 + 'static,
622 ) -> Self {
623 Self {
624 refresh_function: Arc::new(Box::new(move |auth_server| {
625 let result = refresh_function(auth_server);
626 Box::pin(async move { result })
627 })),
628 }
629 }
630
631 pub async fn request_access_token(
637 &self,
638 auth_server: &AuthServer,
639 ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
640 (self.refresh_function)(auth_server.clone()).await
641 }
642}
643
644impl std::fmt::Debug for ExternallyManaged {
645 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
646 f.debug_struct("ExternallyManaged")
647 .field(
648 "refresh_function",
649 &"Fn() -> Pin<Box<dyn Future<Output = Result<String, TokenError>> + Send>>",
650 )
651 .finish()
652 }
653}
654
655#[derive(Debug, Serialize, Deserialize)]
656pub(super) struct TokenRefreshRequest<'a> {
657 grant_type: &'static str,
658 client_id: &'a str,
659 refresh_token: &'a str,
660}
661
662impl<'a> TokenRefreshRequest<'a> {
663 pub(super) const fn new(client_id: &'a str, refresh_token: &'a str) -> Self {
664 Self {
665 grant_type: "refresh_token",
666 client_id,
667 refresh_token,
668 }
669 }
670}
671
672#[derive(Debug, Serialize, Deserialize)]
673pub(super) struct ClientCredentialsRequest<'a> {
674 grant_type: &'static str,
675 client_id: &'a str,
676 client_secret: &'a str,
677}
678
679impl<'a> ClientCredentialsRequest<'a> {
680 pub(super) const fn new(client_id: &'a str, client_secret: &'a str) -> Self {
681 Self {
682 grant_type: "client_credentials",
683 client_id,
684 client_secret,
685 }
686 }
687}
688
689#[derive(Deserialize, Debug, Serialize)]
690pub(super) struct TokenResponse {
691 pub(super) refresh_token: String,
692 pub(super) access_token: String,
693}
694
/// A source of OAuth access tokens for API clients, able to validate and refresh them.
#[async_trait::async_trait]
pub trait TokenRefresher: Clone + std::fmt::Debug + Send {
    /// The error type returned by the token operations.
    type Error;

    /// Returns an access token intended to be valid for immediate use.
    ///
    /// What "validated" means, and whether a refresh may happen first, is up to the
    /// implementor.
    async fn validated_access_token(&self) -> Result<String, Self::Error>;

    /// Returns the current access token, if any.
    async fn get_access_token(&self) -> Result<Option<String>, Self::Error>;

    /// Obtains a new access token and returns it.
    async fn refresh_access_token(&self) -> Result<String, Self::Error>;

    /// The base URL of the API being called. Only available with the `tracing` feature.
    #[cfg(feature = "tracing")]
    fn base_url(&self) -> &str;

    /// The tracing configuration, if one is set. Only available with the `tracing-config`
    /// feature.
    #[cfg(feature = "tracing-config")]
    fn tracing_configuration(&self) -> Option<&TracingConfiguration>;

    /// Returns whether a request to `url` should be traced.
    ///
    /// Without the `tracing-config` feature, every URL is traced. With it, a URL is traced
    /// when no tracing configuration is set, or when the configuration enables that URL.
    #[cfg(feature = "tracing")]
    // The explicit `return` in the non-`tracing-config` branch is what this allow covers.
    #[allow(clippy::needless_return)]
    fn should_trace(&self, url: &UrlPatternMatchInput) -> bool {
        // Exactly one of the two cfg-gated branches below is compiled in.
        #[cfg(not(feature = "tracing-config"))]
        {
            // `url` is otherwise unused in this configuration.
            let _ = url;
            return true;
        }

        #[cfg(feature = "tracing-config")]
        self.tracing_configuration()
            .is_none_or(|config| config.is_enabled(url))
    }
}
735
736#[async_trait::async_trait]
737impl TokenRefresher for ClientConfiguration {
738 type Error = TokenError;
739
740 async fn validated_access_token(&self) -> Result<String, Self::Error> {
741 self.get_bearer_access_token().await
742 }
743
744 async fn refresh_access_token(&self) -> Result<String, Self::Error> {
745 Ok(self.refresh().await?.access_token()?.to_string())
746 }
747
748 async fn get_access_token(&self) -> Result<Option<String>, Self::Error> {
749 Ok(Some(
750 self.oauth_session().await?.access_token()?.to_string(),
751 ))
752 }
753
754 #[cfg(feature = "tracing")]
755 fn base_url(&self) -> &str {
756 &self.grpc_api_url
757 }
758
759 #[cfg(feature = "tracing-config")]
760 fn tracing_configuration(&self) -> Option<&TracingConfiguration> {
761 self.tracing_configuration.as_ref()
762 }
763}
764
/// Tests for OAuth session refresh, secrets-file persistence and `Debug` redaction.
#[cfg(test)]
mod test {
    use std::time::Duration;

    use super::*;
    use httpmock::prelude::*;
    use rstest::rstest;
    use time::format_description::well_known::Rfc3339;
    use tokio::time::Instant;
    use toml_edit::DocumentMut;

    /// A read of the dispatcher's tokens issued while a refresh is in flight must wait for the
    /// refresh to finish, and must then observe the refreshed access and refresh tokens.
    #[tokio::test]
    async fn test_tokens_blocked_during_refresh() {
        let mock_server = MockServer::start_async().await;

        // The token endpoint answers only after a 3 second delay, so the refresh holds the
        // session's write lock at least that long.
        let issuer_mock = mock_server
            .mock_async(|when, then| {
                when.method(POST).path("/v1/token");

                then.status(200)
                    .delay(Duration::from_secs(3))
                    .json_body_obj(&TokenResponse {
                        access_token: "new_access".to_string(),
                        refresh_token: "new_refresh".to_string(),
                    });
            })
            .await;

        let original_tokens = OAuthSession::from_refresh_token(
            RefreshToken::new("refresh".to_string()),
            AuthServer::new("client_id".to_string(), mock_server.base_url()),
            None,
        );
        let dispatcher: TokenDispatcher = original_tokens.clone().into();
        let dispatcher_clone1 = dispatcher.clone();
        let dispatcher_clone2 = dispatcher.clone();

        let refresh_duration = Duration::from_secs(3);

        // NOTE(review): the write task is spawned before the read task. This presumably
        // relies on `#[tokio::test]`'s default current-thread runtime polling them in order,
        // so the refresh takes the write lock before the read starts. Confirm if the runtime
        // flavor changes.
        let start_write = Instant::now();
        let write_future = tokio::spawn(async move {
            dispatcher_clone1
                .refresh(&ConfigSource::Default, "")
                .await
                .unwrap()
        });

        let start_read = Instant::now();
        let read_future = tokio::spawn(async move { dispatcher_clone2.tokens().await });

        let _ = write_future.await.unwrap();
        let read_result = read_future.await.unwrap();

        let write_duration = start_write.elapsed();
        let read_duration = start_read.elapsed();

        issuer_mock.assert_async().await;

        assert!(
            write_duration >= refresh_duration,
            "Write operation did not take enough time"
        );
        assert!(
            read_duration >= refresh_duration,
            "Read operation was not blocked by the write operation"
        );
        assert_eq!(read_result.access_token.as_ref().unwrap(), "new_access");
        if let OAuthGrant::RefreshToken(payload) = read_result.payload {
            assert_eq!(&payload.refresh_token, "new_refresh");
        } else {
            panic!(
                "Expected RefreshToken payload, got {:?}",
                read_result.payload
            );
        }
    }

    /// Checks whether a refresh through a file-backed configuration persists the new access
    /// token to the secrets file.
    ///
    /// The token is written only when the `QCS_SECRETS_READ_ONLY` value is not truthy (as the
    /// value table below defines truthy) AND the file is not read-only on disk.
    #[rstest]
    fn test_qcs_secrets_readonly(
        // Each pair is (QCS_SECRETS_READ_ONLY value, whether it should count as read-only).
        #[values(
            (Some("TRUE"), true),
            (Some("tRue"), true),
            (Some("true"), true),
            (Some("YES"), true),
            (Some("yEs"), true),
            (Some("yes"), true),
            (Some("1"), true),
            (Some("2"), false),
            (Some("other"), false),
            (Some(""), false),
            (None, false),
        )]
        read_only_values: (Option<&str>, bool),
        #[values(true, false)] read_only_perm: bool,
    ) {
        let (maybe_read_only_env, env_is_read_only) = read_only_values;
        let expected_update = !env_is_read_only && !read_only_perm;
        figment::Jail::expect_with(|jail| {
            let profile_name = "test";
            let initial_access_token = "initial_access_token";
            let initial_refresh_token = "initial_refresh_token";

            let initial_secrets_file_contents = format!(
                r#"
[credentials]
[credentials.{profile_name}]
[credentials.{profile_name}.token_payload]
access_token = "{initial_access_token}"
expires_in = 3600
id_token = "id_token"
refresh_token = "{initial_refresh_token}"
scope = "offline_access openid profile email"
token_type = "Bearer"
updated_at = "2024-01-01T00:00:00Z"
"#
            );

            // The secrets file is created inside the jail through this relative path.
            let secrets_path = "secrets.toml";
            jail.create_file(secrets_path, initial_secrets_file_contents.as_str())
                .expect("should create test secrets.toml");

            if read_only_perm {
                let mut permissions = std::fs::metadata(secrets_path)
                    .expect("Should be able to get file metadata")
                    .permissions();
                permissions.set_readonly(true);
                std::fs::set_permissions(secrets_path, permissions)
                    .expect("Should be able to set file permissions");
            }

            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async {
                let mock_server = MockServer::start_async().await;

                let new_access_token = "new_access_token";
                // The token endpoint returns a new access token and echoes the refresh token.
                let issuer_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(POST).path("/v1/token");
                        then.status(200).json_body_obj(&TokenResponse {
                            access_token: new_access_token.to_string(),
                            refresh_token: initial_refresh_token.to_string(),
                        });
                    })
                    .await;

                let original_tokens = OAuthSession::from_refresh_token(
                    RefreshToken::new(initial_refresh_token.to_string()),
                    AuthServer::new("client_id".to_string(), mock_server.base_url()),
                    // NOTE(review): the refresh token is passed as the initial access token.
                    // It is replaced by the refresh, but `initial_access_token` was likely
                    // intended.
                    Some(initial_refresh_token.to_string()),
                );
                let dispatcher: TokenDispatcher = original_tokens.into();

                jail.set_env("QCS_SECRETS_FILE_PATH", "secrets.toml");
                jail.set_env("QCS_PROFILE_NAME", "test");
                if let Some(read_only_env) = maybe_read_only_env {
                    jail.set_env("QCS_SECRETS_READ_ONLY", read_only_env);
                }

                let before_refresh = OffsetDateTime::now_utc();

                dispatcher
                    .refresh(
                        &ConfigSource::File {
                            settings_path: "".into(),
                            secrets_path: "secrets.toml".into(),
                        },
                        profile_name,
                    )
                    .await
                    .unwrap();

                issuer_mock.assert_async().await;

                let content = std::fs::read_to_string("secrets.toml").unwrap();
                // When writes are disabled, the file must be byte-for-byte unchanged.
                if !expected_update {
                    assert!(
                        content.eq(initial_secrets_file_contents.as_str()),
                        "File should not be updated when QCS_SECRETS_READ_ONLY is set or file permissions are read-only"
                    );
                    return;
                }

                // Otherwise parse the updated file and inspect the stored token payload.
                let mut toml = std::fs::read_to_string(secrets_path)
                    .unwrap()
                    .parse::<DocumentMut>()
                    .unwrap();

                let token_payload = toml
                    .get_mut("credentials")
                    .and_then(|credentials| {
                        credentials.get_mut(profile_name)?.get_mut("token_payload")
                    })
                    .expect("Should be able to get token_payload table");

                assert_eq!(
                    token_payload.get("access_token").unwrap().as_str().unwrap(),
                    new_access_token
                );

                // The write must also stamp `updated_at` with a time after the refresh began.
                assert!(
                    OffsetDateTime::parse(
                        token_payload.get("updated_at").unwrap().as_str().unwrap(),
                        &Rfc3339
                    )
                    .unwrap()
                        > before_refresh
                );

                let content = std::fs::read_to_string("secrets.toml").unwrap();
                assert!(
                    content.contains("new_access_token"),
                    "File should be updated with new access token when QCS_SECRETS_READ_ONLY is not set or is set but disabled, and file permissions allow writing"
                );
            });
            Ok(())
        });
    }

    /// `OAuthSession`'s `Debug` output must hide the client credentials and the access token.
    #[test]
    fn test_auth_session_debug_fmt() {
        let session = OAuthSession {
            payload: OAuthGrant::ClientCredentials(ClientCredentials {
                client_id: "hidden_id".into(),
                client_secret: "hidden_secret".into(),
            }),
            access_token: Some("token".into()),
            auth_server: AuthServer::new("some_id".into(), "some_url".into()),
        };

        assert_eq!("OAuthSession { payload: ClientCredentials, access_token: Some(()), auth_server: AuthServer { client_id: \"some_id\", issuer: \"some_url\" } }", &format!("{session:?}"));
    }
}