1use std::{pin::Pin, sync::Arc};
2
3use futures::Future;
4use http::{header::CONTENT_TYPE, HeaderMap, HeaderValue};
5use jsonwebtoken::{Algorithm, DecodingKey, Validation};
6use serde::{Deserialize, Serialize};
7use time::OffsetDateTime;
8use tokio::sync::{Mutex, Notify, RwLock};
9
10use super::{
11 secrets::Secrets, settings::AuthServer, ClientConfiguration, ConfigSource, TokenError,
12 QCS_AUDIENCE,
13};
14#[cfg(feature = "tracing-config")]
15use crate::tracing_configuration::TracingConfiguration;
16#[cfg(feature = "tracing")]
17use urlpattern::UrlPatternMatchInput;
18
19#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
21#[cfg_attr(feature = "python", pyo3::pyclass)]
22pub struct RefreshToken {
23 pub refresh_token: String,
25}
26
27impl RefreshToken {
28 #[must_use]
30 pub const fn new(refresh_token: String) -> Self {
31 Self { refresh_token }
32 }
33
34 pub async fn request_access_token(
40 &mut self,
41 auth_server: &AuthServer,
42 ) -> Result<String, TokenError> {
43 if self.refresh_token.is_empty() {
44 return Err(TokenError::NoRefreshToken);
45 }
46 let token_url = format!("{}/v1/token", auth_server.issuer());
47 let data = TokenRefreshRequest::new(auth_server.client_id(), &self.refresh_token);
48 let resp = reqwest::Client::builder()
49 .timeout(std::time::Duration::from_secs(10))
50 .build()?
51 .post(token_url)
52 .form(&data)
53 .send()
54 .await?;
55
56 let response_data: TokenResponse = resp.error_for_status()?.json().await?;
57 self.refresh_token = response_data.refresh_token;
58 Ok(response_data.access_token)
59 }
60}
61
62#[derive(Clone, Debug, PartialEq, Eq)]
63#[cfg_attr(feature = "python", pyo3::pyclass)]
64pub struct ClientCredentials {
66 pub client_id: String,
68 pub client_secret: String,
70}
71
72impl ClientCredentials {
73 #[must_use]
74 pub const fn new(client_id: String, client_secret: String) -> Self {
76 Self {
77 client_id,
78 client_secret,
79 }
80 }
81
82 #[must_use]
84 pub fn client_id(&self) -> &str {
85 &self.client_id
86 }
87
88 #[must_use]
90 pub fn client_secret(&self) -> &str {
91 &self.client_secret
92 }
93
94 pub async fn request_access_token(
100 &self,
101 auth_server: &AuthServer,
102 ) -> Result<String, TokenError> {
103 let request = ClientCredentialsRequest::new(&self.client_id, &self.client_secret);
104 let url = format!("{}/v1/token", auth_server.issuer());
105
106 let mut headers = HeaderMap::new();
110 headers.insert(
112 CONTENT_TYPE,
113 HeaderValue::from_static("application/x-www-form-urlencoded"),
114 );
115
116 let client = reqwest::Client::builder()
117 .timeout(std::time::Duration::from_secs(10))
118 .build()?;
119
120 let response = client
121 .post(url)
122 .headers(headers)
123 .form(&request)
124 .send()
125 .await?;
126
127 response.error_for_status_ref()?;
128
129 let response_body: TokenResponse = response.json().await?;
130
131 Ok(response_body.access_token)
132 }
133}
134
135#[derive(Clone, Debug)]
136#[cfg_attr(feature = "python", derive(pyo3::FromPyObject))]
137pub enum OAuthGrant {
140 RefreshToken(RefreshToken),
142 ClientCredentials(ClientCredentials),
144 ExternallyManaged(ExternallyManaged),
146}
147
148impl From<ExternallyManaged> for OAuthGrant {
149 fn from(v: ExternallyManaged) -> Self {
150 Self::ExternallyManaged(v)
151 }
152}
153
154impl From<ClientCredentials> for OAuthGrant {
155 fn from(v: ClientCredentials) -> Self {
156 Self::ClientCredentials(v)
157 }
158}
159
160impl From<RefreshToken> for OAuthGrant {
161 fn from(v: RefreshToken) -> Self {
162 Self::RefreshToken(v)
163 }
164}
165
166impl OAuthGrant {
167 async fn request_access_token(
169 &mut self,
170 auth_server: &AuthServer,
171 ) -> Result<String, TokenError> {
172 match self {
173 Self::RefreshToken(tokens) => tokens.request_access_token(auth_server).await,
174 Self::ClientCredentials(tokens) => tokens.request_access_token(auth_server).await,
175 Self::ExternallyManaged(tokens) => tokens
176 .request_access_token(auth_server)
177 .await
178 .map_err(|e| TokenError::ExternallyManaged(e.to_string())),
179 }
180 }
181}
182
183#[derive(Clone, Debug)]
195#[cfg_attr(feature = "python", pyo3::pyclass)]
196pub struct OAuthSession {
197 payload: OAuthGrant,
199 access_token: Option<String>,
201 auth_server: AuthServer,
203}
204
205impl OAuthSession {
206 #[must_use]
211 pub const fn new(
212 payload: OAuthGrant,
213 auth_server: AuthServer,
214 access_token: Option<String>,
215 ) -> Self {
216 Self {
217 payload,
218 access_token,
219 auth_server,
220 }
221 }
222
223 #[must_use]
228 pub const fn from_externally_managed(
229 tokens: ExternallyManaged,
230 auth_server: AuthServer,
231 access_token: Option<String>,
232 ) -> Self {
233 Self::new(
234 OAuthGrant::ExternallyManaged(tokens),
235 auth_server,
236 access_token,
237 )
238 }
239
240 #[must_use]
245 pub const fn from_refresh_token(
246 tokens: RefreshToken,
247 auth_server: AuthServer,
248 access_token: Option<String>,
249 ) -> Self {
250 Self::new(OAuthGrant::RefreshToken(tokens), auth_server, access_token)
251 }
252
253 #[must_use]
258 pub const fn from_client_credentials(
259 tokens: ClientCredentials,
260 auth_server: AuthServer,
261 access_token: Option<String>,
262 ) -> Self {
263 Self::new(
264 OAuthGrant::ClientCredentials(tokens),
265 auth_server,
266 access_token,
267 )
268 }
269
270 pub fn access_token(&self) -> Result<&str, TokenError> {
279 self.access_token.as_ref().map_or_else(
280 || Err(TokenError::NoAccessToken),
281 |token| Ok(token.as_str()),
282 )
283 }
284
285 #[must_use]
287 pub const fn payload(&self) -> &OAuthGrant {
288 &self.payload
289 }
290
291 #[allow(clippy::missing_panics_doc)]
297 pub async fn request_access_token(&mut self) -> Result<&str, TokenError> {
298 let access_token = self.payload.request_access_token(&self.auth_server).await?;
299 self.access_token = Some(access_token);
300 Ok(self
301 .access_token
302 .as_ref()
303 .expect("This value is set in the previous line, so it cannot be None"))
304 }
305
306 #[must_use]
308 pub const fn auth_server(&self) -> &AuthServer {
309 &self.auth_server
310 }
311
312 pub fn validate(&self) -> Result<String, TokenError> {
320 self.access_token().map_or_else(
321 |_| Err(TokenError::NoAccessToken),
322 |access_token| {
323 let placeholder_key = DecodingKey::from_secret(&[]);
324 let mut validation = Validation::new(Algorithm::RS256);
325 validation.validate_exp = true;
326 validation.leeway = 60;
327 validation.set_audience(&[QCS_AUDIENCE]);
328 validation.insecure_disable_signature_validation();
329 jsonwebtoken::decode::<toml::Value>(access_token, &placeholder_key, &validation)
330 .map(|_| access_token.to_string())
331 .map_err(TokenError::InvalidAccessToken)
332 },
333 )
334 }
335}
336
337#[derive(Clone, Debug)]
339#[cfg_attr(feature = "python", pyo3::pyclass)]
340pub struct TokenDispatcher {
341 lock: Arc<RwLock<OAuthSession>>,
342 refreshing: Arc<Mutex<bool>>,
343 notify_refreshed: Arc<Notify>,
344}
345
346impl From<OAuthSession> for TokenDispatcher {
347 fn from(value: OAuthSession) -> Self {
348 Self {
349 lock: Arc::new(RwLock::new(value)),
350 refreshing: Arc::new(Mutex::new(false)),
351 notify_refreshed: Arc::new(Notify::new()),
352 }
353 }
354}
355
356impl TokenDispatcher {
357 pub async fn use_tokens<F, O>(&self, f: F) -> O
367 where
368 F: FnOnce(&OAuthSession) -> O + Send,
369 {
370 let tokens = self.lock.read().await;
371 f(&tokens)
372 }
373
374 #[must_use]
376 pub async fn tokens(&self) -> OAuthSession {
377 self.use_tokens(Clone::clone).await
378 }
379
380 pub async fn refresh(
386 &self,
387 source: &ConfigSource,
388 profile: &str,
389 ) -> Result<OAuthSession, TokenError> {
390 self.managed_refresh(Self::perform_refresh, source, profile)
391 .await
392 }
393
394 pub async fn validate(&self) -> Result<String, TokenError> {
402 self.use_tokens(OAuthSession::validate).await
403 }
404
405 async fn managed_refresh<F, Fut>(
408 &self,
409 refresh_fn: F,
410 source: &ConfigSource,
411 profile: &str,
412 ) -> Result<OAuthSession, TokenError>
413 where
414 F: FnOnce(Arc<RwLock<OAuthSession>>) -> Fut + Send,
415 Fut: Future<Output = Result<OAuthSession, TokenError>> + Send,
416 {
417 let mut is_refreshing = self.refreshing.lock().await;
418
419 if *is_refreshing {
420 drop(is_refreshing);
421 self.notify_refreshed.notified().await;
422 return Ok(self.tokens().await);
423 }
424
425 *is_refreshing = true;
426 drop(is_refreshing);
427
428 let oauth_session = refresh_fn(self.lock.clone()).await?;
429
430 if let ConfigSource::File {
432 settings_path: _,
433 secrets_path,
434 } = source
435 {
436 if !Secrets::is_read_only(secrets_path).await? {
437 let now = OffsetDateTime::now_utc();
438 Secrets::write_access_token(
439 secrets_path,
440 profile,
441 oauth_session.access_token()?,
442 now,
443 )
444 .await?;
445 }
446 }
447
448 *self.refreshing.lock().await = false;
449 self.notify_refreshed.notify_waiters();
450 Ok(oauth_session)
451 }
452
453 async fn perform_refresh(lock: Arc<RwLock<OAuthSession>>) -> Result<OAuthSession, TokenError> {
460 let mut credentials = lock.write().await;
461 credentials.request_access_token().await?;
462 Ok(credentials.clone())
463 }
464}
465
466pub(crate) type RefreshResult =
467 Pin<Box<dyn Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>> + Send>>;
468
469pub type RefreshFunction = Box<dyn (Fn(AuthServer) -> RefreshResult) + Send + Sync>;
471
472#[derive(Clone)]
477#[cfg_attr(feature = "python", pyo3::pyclass)]
478pub struct ExternallyManaged {
479 refresh_function: Arc<RefreshFunction>,
480}
481
482impl ExternallyManaged {
483 pub fn new(
508 refresh_function: impl Fn(AuthServer) -> RefreshResult + Send + Sync + 'static,
509 ) -> Self {
510 Self {
511 refresh_function: Arc::new(Box::new(refresh_function)),
512 }
513 }
514
515 pub fn from_async<F, Fut>(refresh_function: F) -> Self
548 where
549 F: Fn(AuthServer) -> Fut + Send + Sync + 'static,
550 Fut: Future<Output = Result<String, Box<dyn std::error::Error + Send + Sync>>>
551 + Send
552 + 'static,
553 {
554 Self {
555 refresh_function: Arc::new(Box::new(move |auth_server| {
556 Box::pin(refresh_function(auth_server))
557 })),
558 }
559 }
560
561 pub fn from_sync(
592 refresh_function: impl Fn(AuthServer) -> Result<String, Box<dyn std::error::Error + Send + Sync>>
593 + Send
594 + Sync
595 + 'static,
596 ) -> Self {
597 Self {
598 refresh_function: Arc::new(Box::new(move |auth_server| {
599 let result = refresh_function(auth_server);
600 Box::pin(async move { result })
601 })),
602 }
603 }
604
605 pub async fn request_access_token(
611 &self,
612 auth_server: &AuthServer,
613 ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
614 (self.refresh_function)(auth_server.clone()).await
615 }
616}
617
618impl std::fmt::Debug for ExternallyManaged {
619 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
620 f.debug_struct("ExternallyManaged")
621 .field(
622 "refresh_function",
623 &"Fn() -> Pin<Box<dyn Future<Output = Result<String, TokenError>> + Send>>",
624 )
625 .finish()
626 }
627}
628
629#[derive(Debug, Serialize, Deserialize)]
630pub(super) struct TokenRefreshRequest<'a> {
631 grant_type: &'static str,
632 client_id: &'a str,
633 refresh_token: &'a str,
634}
635
636impl<'a> TokenRefreshRequest<'a> {
637 pub(super) const fn new(client_id: &'a str, refresh_token: &'a str) -> Self {
638 Self {
639 grant_type: "refresh_token",
640 client_id,
641 refresh_token,
642 }
643 }
644}
645
646#[derive(Debug, Serialize, Deserialize)]
647pub(super) struct ClientCredentialsRequest<'a> {
648 grant_type: &'static str,
649 client_id: &'a str,
650 client_secret: &'a str,
651}
652
653impl<'a> ClientCredentialsRequest<'a> {
654 pub(super) const fn new(client_id: &'a str, client_secret: &'a str) -> Self {
655 Self {
656 grant_type: "client_credentials",
657 client_id,
658 client_secret,
659 }
660 }
661}
662
663#[derive(Deserialize, Debug, Serialize)]
664pub(super) struct TokenResponse {
665 pub(super) refresh_token: String,
666 pub(super) access_token: String,
667}
668
/// A source of OAuth access tokens that can also refresh them.
///
/// Implementors must be `Clone`, `Debug`, and `Send`.
#[async_trait::async_trait]
pub trait TokenRefresher: Clone + std::fmt::Debug + Send {
    /// Error returned by the token operations.
    type Error;

    /// Return an access token that the implementor has validated.
    ///
    /// NOTE(review): what "validated" entails, and whether an invalid token is refreshed first,
    /// is up to each implementation. Confirm for yours.
    async fn validated_access_token(&self) -> Result<String, Self::Error>;

    /// Return the access token currently held, or `None` if the implementor has none.
    ///
    /// NOTE(review): whether this may trigger a refresh is implementation-defined. Confirm.
    async fn get_access_token(&self) -> Result<Option<String>, Self::Error>;

    /// Obtain a new access token and return it.
    async fn refresh_access_token(&self) -> Result<String, Self::Error>;

    /// Base URL associated with this implementor; only available with the `tracing` feature.
    #[cfg(feature = "tracing")]
    fn base_url(&self) -> &str;

    /// Tracing configuration, if any; only available with the `tracing-config` feature.
    #[cfg(feature = "tracing-config")]
    fn tracing_configuration(&self) -> Option<&TracingConfiguration>;

    /// Whether a request to `url` should be traced.
    ///
    /// Without the `tracing-config` feature, every URL is traced. With it, every URL is traced
    /// when no tracing configuration is present; otherwise the configuration's `is_enabled`
    /// decides.
    #[cfg(feature = "tracing")]
    #[allow(clippy::needless_return)]
    fn should_trace(&self, url: &UrlPatternMatchInput) -> bool {
        // Exactly one of the two cfg-gated branches below is compiled. The explicit `return`
        // lets the first branch be written as a cfg-gated statement; clippy flags it as
        // needless, hence the `allow` above.
        #[cfg(not(feature = "tracing-config"))]
        {
            // `url` is otherwise unused in this configuration.
            let _ = url;
            return true;
        }

        #[cfg(feature = "tracing-config")]
        self.tracing_configuration()
            .is_none_or(|config| config.is_enabled(url))
    }
}
709
710#[async_trait::async_trait]
711impl TokenRefresher for ClientConfiguration {
712 type Error = TokenError;
713
714 async fn validated_access_token(&self) -> Result<String, Self::Error> {
715 self.get_bearer_access_token().await
716 }
717
718 async fn refresh_access_token(&self) -> Result<String, Self::Error> {
719 Ok(self.refresh().await?.access_token()?.to_string())
720 }
721
722 async fn get_access_token(&self) -> Result<Option<String>, Self::Error> {
723 Ok(Some(
724 self.oauth_session().await?.access_token()?.to_string(),
725 ))
726 }
727
728 #[cfg(feature = "tracing")]
729 fn base_url(&self) -> &str {
730 &self.grpc_api_url
731 }
732
733 #[cfg(feature = "tracing-config")]
734 fn tracing_configuration(&self) -> Option<&TracingConfiguration> {
735 self.tracing_configuration.as_ref()
736 }
737}
738
#[cfg(test)]
mod test {
    use std::time::Duration;

    use super::*;
    use httpmock::prelude::*;
    use rstest::rstest;
    use time::format_description::well_known::Rfc3339;
    use tokio::time::Instant;
    use toml_edit::DocumentMut;

    /// While a refresh holds the session's write lock, a concurrent read must wait for it and
    /// then observe the refreshed tokens.
    #[tokio::test]
    async fn test_tokens_blocked_during_refresh() {
        let mock_server = MockServer::start_async().await;

        // The token endpoint answers only after `refresh_duration`, keeping the refresh (and its
        // write lock) in flight long enough for the read to queue up behind it.
        let refresh_duration = Duration::from_secs(3);
        let issuer_mock = mock_server
            .mock_async(|when, then| {
                when.method(POST).path("/v1/token");

                then.status(200)
                    .delay(refresh_duration)
                    .json_body_obj(&TokenResponse {
                        access_token: "new_access".to_string(),
                        refresh_token: "new_refresh".to_string(),
                    });
            })
            .await;

        let original_tokens = OAuthSession::from_refresh_token(
            RefreshToken::new("refresh".to_string()),
            AuthServer::new("client_id".to_string(), mock_server.base_url()),
            None,
        );
        let dispatcher: TokenDispatcher = original_tokens.into();
        let dispatcher_clone1 = dispatcher.clone();
        let dispatcher_clone2 = dispatcher.clone();

        let start_write = Instant::now();
        let write_future = tokio::spawn(async move {
            dispatcher_clone1
                .refresh(&ConfigSource::Default, "")
                .await
                .unwrap()
        });

        // Time the read inside its own task. Measuring after awaiting the write task as well
        // would let the duration check pass even if the read was never blocked.
        // NOTE(review): this relies on the refresh task taking the write lock before the read
        // task runs; confirm the runtime's scheduling order guarantees that.
        let start_read = Instant::now();
        let read_future = tokio::spawn(async move {
            let tokens = dispatcher_clone2.tokens().await;
            (tokens, start_read.elapsed())
        });

        let _ = write_future.await.unwrap();
        let write_duration = start_write.elapsed();
        let (read_result, read_duration) = read_future.await.unwrap();

        issuer_mock.assert_async().await;

        assert!(
            write_duration >= refresh_duration,
            "Write operation did not take enough time"
        );
        assert!(
            read_duration >= refresh_duration,
            "Read operation was not blocked by the write operation"
        );
        // The read observed the refreshed access token and the rotated refresh token.
        assert_eq!(read_result.access_token.as_ref().unwrap(), "new_access");
        if let OAuthGrant::RefreshToken(payload) = read_result.payload {
            assert_eq!(&payload.refresh_token, "new_refresh");
        } else {
            panic!(
                "Expected RefreshToken payload, got {:?}",
                read_result.payload
            );
        }
    }

    /// After a refresh, the secrets file is updated only when neither `QCS_SECRETS_READ_ONLY`
    /// marks it read-only nor the file's permissions are read-only.
    #[rstest]
    fn test_qcs_secrets_readonly(
        #[values(
            (Some("TRUE"), true),
            (Some("tRue"), true),
            (Some("true"), true),
            (Some("YES"), true),
            (Some("yEs"), true),
            (Some("yes"), true),
            (Some("1"), true),
            (Some("2"), false),
            (Some("other"), false),
            (Some(""), false),
            (None, false),
        )]
        read_only_values: (Option<&str>, bool),
        #[values(true, false)] read_only_perm: bool,
    ) {
        let (maybe_read_only_env, env_is_read_only) = read_only_values;
        let expected_update = !env_is_read_only && !read_only_perm;
        figment::Jail::expect_with(|jail| {
            let profile_name = "test";
            let initial_access_token = "initial_access_token";
            let initial_refresh_token = "initial_refresh_token";

            let initial_secrets_file_contents = format!(
                r#"
[credentials]
[credentials.{profile_name}]
[credentials.{profile_name}.token_payload]
access_token = "{initial_access_token}"
expires_in = 3600
id_token = "id_token"
refresh_token = "{initial_refresh_token}"
scope = "offline_access openid profile email"
token_type = "Bearer"
updated_at = "2024-01-01T00:00:00Z"
"#
            );

            let secrets_path = "secrets.toml";
            jail.create_file(secrets_path, initial_secrets_file_contents.as_str())
                .expect("should create test secrets.toml");

            if read_only_perm {
                let mut permissions = std::fs::metadata(secrets_path)
                    .expect("Should be able to get file metadata")
                    .permissions();
                permissions.set_readonly(true);
                std::fs::set_permissions(secrets_path, permissions)
                    .expect("Should be able to set file permissions");
            }

            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async {
                let mock_server = MockServer::start_async().await;

                let new_access_token = "new_access_token";
                let issuer_mock = mock_server
                    .mock_async(|when, then| {
                        when.method(POST).path("/v1/token");
                        then.status(200).json_body_obj(&TokenResponse {
                            access_token: new_access_token.to_string(),
                            refresh_token: initial_refresh_token.to_string(),
                        });
                    })
                    .await;

                let original_tokens = OAuthSession::from_refresh_token(
                    RefreshToken::new(initial_refresh_token.to_string()),
                    AuthServer::new("client_id".to_string(), mock_server.base_url()),
                    Some(initial_access_token.to_string()),
                );
                let dispatcher: TokenDispatcher = original_tokens.into();

                jail.set_env("QCS_SECRETS_FILE_PATH", "secrets.toml");
                jail.set_env("QCS_PROFILE_NAME", "test");
                if let Some(read_only_env) = maybe_read_only_env {
                    jail.set_env("QCS_SECRETS_READ_ONLY", read_only_env);
                }

                let before_refresh = OffsetDateTime::now_utc();

                dispatcher
                    .refresh(
                        &ConfigSource::File {
                            settings_path: "".into(),
                            secrets_path: "secrets.toml".into(),
                        },
                        profile_name,
                    )
                    .await
                    .unwrap();

                issuer_mock.assert_async().await;

                let content = std::fs::read_to_string("secrets.toml").unwrap();
                if !expected_update {
                    assert!(
                        content.eq(initial_secrets_file_contents.as_str()),
                        "File should not be updated when QCS_SECRETS_READ_ONLY is set or file permissions are read-only"
                    );
                    return;
                }

                let mut toml = std::fs::read_to_string(secrets_path)
                    .unwrap()
                    .parse::<DocumentMut>()
                    .unwrap();

                let token_payload = toml
                    .get_mut("credentials")
                    .and_then(|credentials| {
                        credentials.get_mut(profile_name)?.get_mut("token_payload")
                    })
                    .expect("Should be able to get token_payload table");

                assert_eq!(
                    token_payload.get("access_token").unwrap().as_str().unwrap(),
                    new_access_token
                );

                // The write also stamps `updated_at` with the time of the refresh.
                assert!(
                    OffsetDateTime::parse(
                        token_payload.get("updated_at").unwrap().as_str().unwrap(),
                        &Rfc3339
                    )
                    .unwrap()
                        > before_refresh
                );

                let content = std::fs::read_to_string("secrets.toml").unwrap();
                assert!(
                    content.contains("new_access_token"),
                    "File should be updated with new access token when QCS_SECRETS_READ_ONLY is not set or is set but disabled, and file permissions allow writing"
                );
            });
            Ok(())
        });
    }
}