1use std::path::PathBuf;
8
9use chrono::{DateTime, Utc};
10use oauth2::basic::{BasicClient, BasicTokenResponse};
11use oauth2::{
12 AccessToken, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge,
13 PkceCodeVerifier, RedirectUrl, RefreshToken, TokenResponse, TokenUrl,
14};
15use serde::{Deserialize, Serialize};
16
/// Onshape's OAuth2 authorization endpoint; wrapped as an [`AuthUrl`] by `onshape_auth_url`.
const ONSHAPE_AUTH_URL_STR: &str = "https://oauth.onshape.com/oauth/authorize";

/// Onshape's OAuth2 token endpoint; wrapped as a [`TokenUrl`] by `onshape_token_url`.
const ONSHAPE_TOKEN_URL_STR: &str = "https://oauth.onshape.com/oauth/token";
26
27#[must_use]
34pub fn onshape_auth_url() -> AuthUrl {
35 #[allow(clippy::expect_used)]
36 AuthUrl::new(ONSHAPE_AUTH_URL_STR.to_string()).expect("hard-coded Onshape auth URL is valid")
37}
38
39#[must_use]
46pub fn onshape_token_url() -> TokenUrl {
47 #[allow(clippy::expect_used)]
48 TokenUrl::new(ONSHAPE_TOKEN_URL_STR.to_string()).expect("hard-coded Onshape token URL is valid")
49}
50
51#[derive(Clone, Debug, Deserialize, Serialize)]
61pub struct OAuthTokenData {
62 #[serde(
64 serialize_with = "serialize_access_token",
65 deserialize_with = "deserialize_access_token"
66 )]
67 pub access_token: AccessToken,
68 #[serde(
70 serialize_with = "serialize_refresh_token",
71 deserialize_with = "deserialize_refresh_token"
72 )]
73 pub refresh_token: RefreshToken,
74 #[serde(default, skip_serializing_if = "Option::is_none")]
78 pub expires_at: Option<DateTime<Utc>>,
79 #[serde(
85 default = "default_token_type",
86 deserialize_with = "deserialize_token_type"
87 )]
88 pub token_type: String,
89 #[serde(default, skip_serializing_if = "Option::is_none")]
95 pub scopes: Option<Vec<String>>,
96 #[serde(default, skip_serializing_if = "Option::is_none")]
102 pub client_id: Option<String>,
103 #[serde(default, skip_serializing_if = "Option::is_none")]
112 pub client_secret: Option<String>,
113 #[serde(default, skip_serializing_if = "Option::is_none")]
122 pub proxy_url: Option<String>,
123}
124
125impl OAuthTokenData {
126 #[must_use]
131 pub fn is_expired(&self, now: DateTime<Utc>) -> bool {
132 self.expires_at.is_some_and(|expires| expires <= now)
133 }
134
135 #[must_use]
139 pub fn is_expiring_soon(&self, now: DateTime<Utc>, margin: chrono::Duration) -> bool {
140 self.expires_at
141 .is_some_and(|expires| expires <= now + margin)
142 }
143}
144
145impl OAuthTokenData {
146 #[must_use]
153 pub fn from_response(response: &BasicTokenResponse, now: DateTime<Utc>) -> Self {
154 let expires_at = response
155 .expires_in()
156 .and_then(|d| chrono::Duration::from_std(d).ok())
157 .map(|d| now + d);
158
159 let scopes = response
160 .scopes()
161 .map(|scopes| scopes.iter().map(|s| s.as_ref().to_owned()).collect());
162
163 Self {
164 access_token: response.access_token().clone(),
165 refresh_token: response
166 .refresh_token()
167 .cloned()
168 .unwrap_or_else(|| RefreshToken::new(String::new())),
169 expires_at,
170 token_type: response.token_type().as_ref().to_string(),
171 scopes,
172 client_id: None,
175 client_secret: None,
176 proxy_url: None,
177 }
178 }
179}
180
181impl OAuthTokenData {
182 #[must_use]
187 pub fn from_raw(
188 access_token: String,
189 refresh_token: String,
190 expires_at: Option<DateTime<Utc>>,
191 token_type: String,
192 scopes: Option<Vec<String>>,
193 ) -> Self {
194 Self {
195 access_token: AccessToken::new(access_token),
196 refresh_token: RefreshToken::new(refresh_token),
197 expires_at,
198 token_type,
199 scopes,
200 client_id: None,
201 client_secret: None,
202 proxy_url: None,
203 }
204 }
205}
206
/// Serde default for `OAuthTokenData::token_type` when the field is absent.
fn default_token_type() -> String {
    String::from("bearer")
}
210
211fn deserialize_token_type<'de, D>(deserializer: D) -> Result<String, D::Error>
220where
221 D: serde::Deserializer<'de>,
222{
223 let s = String::deserialize(deserializer)?;
224 if s.eq_ignore_ascii_case("bearer") {
225 Ok("bearer".to_string())
226 } else {
227 Err(serde::de::Error::custom(format!(
228 "invalid token_type \"{s}\", expected \"bearer\""
229 )))
230 }
231}
232
233fn serialize_access_token<S>(token: &AccessToken, serializer: S) -> Result<S::Ok, S::Error>
237where
238 S: serde::Serializer,
239{
240 serializer.serialize_str(token.secret())
241}
242
243fn deserialize_access_token<'de, D>(deserializer: D) -> Result<AccessToken, D::Error>
245where
246 D: serde::Deserializer<'de>,
247{
248 let s = String::deserialize(deserializer)?;
249 Ok(AccessToken::new(s))
250}
251
252fn serialize_refresh_token<S>(token: &RefreshToken, serializer: S) -> Result<S::Ok, S::Error>
254where
255 S: serde::Serializer,
256{
257 serializer.serialize_str(token.secret())
258}
259
260fn deserialize_refresh_token<'de, D>(deserializer: D) -> Result<RefreshToken, D::Error>
262where
263 D: serde::Deserializer<'de>,
264{
265 let s = String::deserialize(deserializer)?;
266 Ok(RefreshToken::new(s))
267}
268
/// [`BasicClient`] with the endpoint typestate produced by `onshape_oauth_client`.
///
/// Following oauth2 5.x's `BasicClient` parameter order, the slots are: authorization URL
/// (set), device-authorization URL (not set), introspection URL (not set), revocation
/// URL (not set), token URL (set).
pub type OnshapeOAuthClient = BasicClient<
    oauth2::EndpointSet,
    oauth2::EndpointNotSet,
    oauth2::EndpointNotSet,
    oauth2::EndpointNotSet,
    oauth2::EndpointSet,
>;
284
285#[must_use]
296pub fn onshape_oauth_client(client_id: &str, client_secret: &str) -> OnshapeOAuthClient {
297 BasicClient::new(ClientId::new(client_id.to_string()))
298 .set_client_secret(ClientSecret::new(client_secret.to_string()))
299 .set_auth_uri(onshape_auth_url())
300 .set_token_uri(onshape_token_url())
301}
302
303#[must_use]
315pub fn default_data_dir() -> Option<PathBuf> {
316 dirs::data_dir().map(|dir| dir.join("onshape-mcp"))
317}
318
319#[must_use]
327pub fn default_token_file_path() -> Option<PathBuf> {
328 default_data_dir().map(|dir| dir.join("tokens.json"))
329}
330
/// Decision returned by `OAuthSession::pre_execute_action` before a request is sent.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PreExecuteAction {
    /// The access token is not within the refresh margin (or has no known expiry); send as is.
    Proceed,
    /// The access token expires within the refresh margin; refresh before sending.
    RefreshNeeded,
}
343
/// Decision returned by `OAuthSession::post_execute_action` after a response arrives.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PostExecuteAction {
    /// Nothing more to do; hand the response to the caller.
    Done,
    /// The request was rejected with 401 and has not been retried yet: refresh, then retry once.
    RefreshAndRetry,
}
352
353pub struct OAuthSession {
358 pub tokens: OAuthTokenData,
360 refresh_margin: chrono::Duration,
361}
362
363impl OAuthSession {
364 #[must_use]
369 pub const fn new(tokens: OAuthTokenData, refresh_margin: chrono::Duration) -> Self {
370 Self {
371 tokens,
372 refresh_margin,
373 }
374 }
375
376 #[must_use]
380 pub fn pre_execute_action(&self, now: DateTime<Utc>) -> PreExecuteAction {
381 if self.tokens.is_expiring_soon(now, self.refresh_margin) {
382 PreExecuteAction::RefreshNeeded
383 } else {
384 PreExecuteAction::Proceed
385 }
386 }
387
388 #[must_use]
393 pub const fn post_execute_action(
394 &self,
395 status: u16,
396 already_refreshed: bool,
397 ) -> PostExecuteAction {
398 if status == 401 && !already_refreshed {
399 PostExecuteAction::RefreshAndRetry
400 } else {
401 PostExecuteAction::Done
402 }
403 }
404
405 pub fn apply_refresh(&mut self, response: &BasicTokenResponse, now: DateTime<Utc>) {
411 let mut new_tokens = OAuthTokenData::from_response(response, now);
412 if response.refresh_token().is_none() {
415 new_tokens.refresh_token = self.tokens.refresh_token.clone();
416 }
417 new_tokens.client_id.clone_from(&self.tokens.client_id);
420 new_tokens
421 .client_secret
422 .clone_from(&self.tokens.client_secret);
423 self.tokens = new_tokens;
424 }
425
426 pub fn apply_external_tokens(
435 &mut self,
436 file_tokens: OAuthTokenData,
437 now: DateTime<Utc>,
438 ) -> bool {
439 let (Some(file_expires), Some(current_expires)) =
441 (file_tokens.expires_at, self.tokens.expires_at)
442 else {
443 return false;
444 };
445
446 if file_expires > current_expires && file_expires > now {
448 self.tokens = file_tokens;
449 true
450 } else {
451 false
452 }
453 }
454
455 #[must_use]
457 pub const fn access_token(&self) -> &AccessToken {
458 &self.tokens.access_token
459 }
460
461 #[must_use]
463 pub const fn refresh_token(&self) -> &RefreshToken {
464 &self.tokens.refresh_token
465 }
466}
467
/// Inputs for starting an interactive authorization-code login.
#[derive(Debug, Clone)]
pub struct OAuthLoginConfig {
    /// OAuth client id placed in the authorization URL.
    pub client_id: String,
    /// Redirect URI placed in the authorization URL. It must parse as a URL, or
    /// `build_authorize_url` panics.
    pub redirect_uri: String,
    /// Scopes requested, added to the authorization URL in this order.
    pub scopes: Vec<String>,
}
486
487pub struct OAuthLoginSession {
493 pub pkce_verifier: PkceCodeVerifier,
495 pub csrf_state: CsrfToken,
497 pub config: OAuthLoginConfig,
499}
500
501#[derive(Debug, thiserror::Error)]
503pub enum CallbackValidationError {
504 #[error("invalid callback URL: {0}")]
506 InvalidUrl(String),
507 #[error("OAuth error from provider: {error} (description: {description:?})")]
509 OAuthError {
510 error: String,
512 description: Option<String>,
514 },
515 #[error("CSRF state mismatch: expected {expected}, got {actual}")]
517 StateMismatch {
518 expected: String,
520 actual: String,
522 },
523 #[error("callback is missing the 'state' parameter")]
525 MissingState,
526 #[error("callback is missing the 'code' parameter")]
528 MissingCode,
529}
530
531#[must_use]
549pub fn build_authorize_url(
550 config: &OAuthLoginConfig,
551 csrf_state: &CsrfToken,
552 pkce_challenge: PkceCodeChallenge,
553) -> String {
554 let client = BasicClient::new(ClientId::new(config.client_id.clone()))
555 .set_auth_uri(onshape_auth_url())
556 .set_token_uri(onshape_token_url())
557 .set_redirect_uri(
558 #[allow(clippy::expect_used)]
559 RedirectUrl::new(config.redirect_uri.clone())
560 .expect("redirect_uri should be a valid URL"),
561 );
562
563 let mut auth_request = client
564 .authorize_url(|| csrf_state.clone())
565 .set_pkce_challenge(pkce_challenge);
566
567 for scope in &config.scopes {
568 auth_request = auth_request.add_scope(oauth2::Scope::new(scope.clone()));
569 }
570
571 let (url, _csrf_token) = auth_request.url();
572 url.to_string()
573}
574
575pub fn validate_callback(
593 callback_url: &str,
594 expected_state: &CsrfToken,
595) -> Result<AuthorizationCode, CallbackValidationError> {
596 let url = url::Url::parse(callback_url)
597 .map_err(|e| CallbackValidationError::InvalidUrl(e.to_string()))?;
598
599 let params: std::collections::HashMap<String, String> = url
600 .query_pairs()
601 .map(|(k, v)| (k.into_owned(), v.into_owned()))
602 .collect();
603
604 if let Some(error) = params.get("error") {
606 return Err(CallbackValidationError::OAuthError {
607 error: error.clone(),
608 description: params.get("error_description").cloned(),
609 });
610 }
611
612 let state = params
614 .get("state")
615 .ok_or(CallbackValidationError::MissingState)?;
616
617 if state != expected_state.secret() {
618 return Err(CallbackValidationError::StateMismatch {
619 expected: expected_state.secret().clone(),
620 actual: state.clone(),
621 });
622 }
623
624 let code = params
626 .get("code")
627 .ok_or(CallbackValidationError::MissingCode)?;
628
629 Ok(AuthorizationCode::new(code.clone()))
630}
631
632#[cfg(test)]
637#[allow(clippy::expect_used, clippy::panic)]
638mod tests {
639 use super::*;
640
641 #[test]
642 fn token_data_serializes_to_json() {
643 let tokens = OAuthTokenData {
644 access_token: AccessToken::new("access-123".to_string()),
645 refresh_token: RefreshToken::new("refresh-456".to_string()),
646 expires_at: None,
647 token_type: "bearer".into(),
648 scopes: None,
649 client_id: None,
650 client_secret: None,
651 proxy_url: None,
652 };
653 let json = serde_json::to_string(&tokens).expect("should serialize");
654 let value: serde_json::Value = serde_json::from_str(&json).expect("should be valid JSON");
655 assert_eq!(value["access_token"], "access-123");
656 assert_eq!(value["refresh_token"], "refresh-456");
657 assert_eq!(value["token_type"], "bearer");
658 assert!(value.get("expires_at").is_none());
659 assert!(value.get("scopes").is_none());
660 }
661
662 #[test]
663 fn token_data_deserializes_from_json() {
664 let json = r#"{
665 "access_token": "access-789",
666 "refresh_token": "refresh-012",
667 "token_type": "bearer"
668 }"#;
669 let tokens: OAuthTokenData = serde_json::from_str(json).expect("should deserialize");
670 assert_eq!(tokens.access_token.secret(), "access-789");
671 assert_eq!(tokens.refresh_token.secret(), "refresh-012");
672 assert_eq!(tokens.token_type, "bearer");
673 assert!(tokens.expires_at.is_none());
674 assert!(tokens.scopes.is_none());
675 }
676
677 #[test]
678 fn token_data_roundtrips_with_expiry() {
679 let expires = DateTime::parse_from_rfc3339("2025-06-15T12:00:00Z")
680 .expect("should parse")
681 .to_utc();
682 let tokens = OAuthTokenData {
683 access_token: AccessToken::new("at".to_string()),
684 refresh_token: RefreshToken::new("rt".to_string()),
685 expires_at: Some(expires),
686 token_type: "bearer".into(),
687 scopes: None,
688 client_id: None,
689 client_secret: None,
690 proxy_url: None,
691 };
692 let json = serde_json::to_string(&tokens).expect("should serialize");
693 let roundtripped: OAuthTokenData = serde_json::from_str(&json).expect("should deserialize");
694 assert_eq!(roundtripped.expires_at, Some(expires));
695 }
696
697 #[test]
698 fn is_expired_returns_true_when_past() {
699 let expires = DateTime::parse_from_rfc3339("2024-01-01T00:00:00Z")
700 .expect("should parse")
701 .to_utc();
702 let tokens = OAuthTokenData {
703 access_token: AccessToken::new("at".to_string()),
704 refresh_token: RefreshToken::new("rt".to_string()),
705 expires_at: Some(expires),
706 token_type: "bearer".into(),
707 scopes: None,
708 client_id: None,
709 client_secret: None,
710 proxy_url: None,
711 };
712 let now = DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z")
713 .expect("should parse")
714 .to_utc();
715 assert!(tokens.is_expired(now));
716 }
717
718 #[test]
719 fn is_expired_returns_true_when_exactly_at_expiry() {
720 let expires = DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z")
721 .expect("should parse")
722 .to_utc();
723 let tokens = OAuthTokenData {
724 access_token: AccessToken::new("at".to_string()),
725 refresh_token: RefreshToken::new("rt".to_string()),
726 expires_at: Some(expires),
727 token_type: "bearer".into(),
728 scopes: None,
729 client_id: None,
730 client_secret: None,
731 proxy_url: None,
732 };
733 assert!(tokens.is_expired(expires));
734 }
735
736 #[test]
737 fn is_expired_returns_false_when_future() {
738 let expires = DateTime::parse_from_rfc3339("2030-01-01T00:00:00Z")
739 .expect("should parse")
740 .to_utc();
741 let tokens = OAuthTokenData {
742 access_token: AccessToken::new("at".to_string()),
743 refresh_token: RefreshToken::new("rt".to_string()),
744 expires_at: Some(expires),
745 token_type: "bearer".into(),
746 scopes: None,
747 client_id: None,
748 client_secret: None,
749 proxy_url: None,
750 };
751 let now = DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z")
752 .expect("should parse")
753 .to_utc();
754 assert!(!tokens.is_expired(now));
755 }
756
757 #[test]
758 fn is_expired_returns_false_when_no_expiry() {
759 let tokens = OAuthTokenData {
760 access_token: AccessToken::new("at".to_string()),
761 refresh_token: RefreshToken::new("rt".to_string()),
762 expires_at: None,
763 token_type: "bearer".into(),
764 scopes: None,
765 client_id: None,
766 client_secret: None,
767 proxy_url: None,
768 };
769 let now = DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z")
770 .expect("should parse")
771 .to_utc();
772 assert!(!tokens.is_expired(now));
773 }
774
775 #[test]
776 fn default_token_type_is_bearer() {
777 let json = r#"{
778 "access_token": "at",
779 "refresh_token": "rt"
780 }"#;
781 let tokens: OAuthTokenData = serde_json::from_str(json).expect("should deserialize");
782 assert_eq!(tokens.token_type, "bearer");
783 }
784
785 #[test]
786 fn token_type_bearer_case_insensitive() {
787 let json = r#"{
789 "access_token": "at",
790 "refresh_token": "rt",
791 "token_type": "Bearer"
792 }"#;
793 let tokens: OAuthTokenData = serde_json::from_str(json).expect("should deserialize");
794 assert_eq!(tokens.token_type, "bearer");
795 }
796
797 #[test]
798 fn token_type_bearer_all_caps() {
799 let json = r#"{
800 "access_token": "at",
801 "refresh_token": "rt",
802 "token_type": "BEARER"
803 }"#;
804 let tokens: OAuthTokenData = serde_json::from_str(json).expect("should deserialize");
805 assert_eq!(tokens.token_type, "bearer");
806 }
807
808 #[test]
809 fn token_type_invalid_rejects() {
810 let json = r#"{
811 "access_token": "at",
812 "refresh_token": "rt",
813 "token_type": "mac"
814 }"#;
815 let result: Result<OAuthTokenData, _> = serde_json::from_str(json);
816 let err = result.expect_err("should reject non-bearer token type");
817 let msg = err.to_string();
818 assert!(
819 msg.contains("invalid token_type"),
820 "error should mention invalid token_type: {msg}"
821 );
822 }
823
824 #[test]
825 fn scopes_deserialize_when_present() {
826 let json = r#"{
827 "access_token": "at",
828 "refresh_token": "rt",
829 "token_type": "bearer",
830 "scopes": ["OAuth2Read", "OAuth2Write"]
831 }"#;
832 let tokens: OAuthTokenData = serde_json::from_str(json).expect("should deserialize");
833 let scopes = tokens.scopes.expect("should have scopes");
834 assert_eq!(scopes, vec!["OAuth2Read", "OAuth2Write"]);
835 }
836
837 #[test]
838 fn scopes_default_to_none_when_absent() {
839 let json = r#"{
840 "access_token": "at",
841 "refresh_token": "rt",
842 "token_type": "bearer"
843 }"#;
844 let tokens: OAuthTokenData = serde_json::from_str(json).expect("should deserialize");
845 assert!(tokens.scopes.is_none());
846 }
847
848 #[test]
849 fn scopes_serialize_when_present() {
850 let tokens = OAuthTokenData {
851 access_token: AccessToken::new("at".to_string()),
852 refresh_token: RefreshToken::new("rt".to_string()),
853 expires_at: None,
854 token_type: "bearer".into(),
855 scopes: Some(vec!["OAuth2Read".into(), "OAuth2Write".into()]),
856 client_id: None,
857 client_secret: None,
858 proxy_url: None,
859 };
860 let json = serde_json::to_string(&tokens).expect("should serialize");
861 let value: serde_json::Value = serde_json::from_str(&json).expect("should be valid JSON");
862 let scopes = value["scopes"].as_array().expect("scopes should be array");
863 assert_eq!(scopes.len(), 2);
864 assert_eq!(scopes[0], "OAuth2Read");
865 assert_eq!(scopes[1], "OAuth2Write");
866 }
867
868 #[test]
869 fn scopes_omitted_from_json_when_none() {
870 let tokens = OAuthTokenData {
871 access_token: AccessToken::new("at".to_string()),
872 refresh_token: RefreshToken::new("rt".to_string()),
873 expires_at: None,
874 token_type: "bearer".into(),
875 scopes: None,
876 client_id: None,
877 client_secret: None,
878 proxy_url: None,
879 };
880 let json = serde_json::to_string(&tokens).expect("should serialize");
881 let value: serde_json::Value = serde_json::from_str(&json).expect("should be valid JSON");
882 assert!(
883 value.get("scopes").is_none(),
884 "scopes should be omitted from JSON when None"
885 );
886 }
887
888 #[test]
889 fn default_token_file_path_returns_some() {
890 let path = default_token_file_path();
893 if let Some(ref p) = path {
894 assert!(p.ends_with("onshape-mcp/tokens.json"));
895 }
896 }
898
899 #[test]
900 fn onshape_auth_url_is_valid() {
901 let url = onshape_auth_url();
902 let url_str = url.url().as_str();
903 assert!(url_str.starts_with("https://"));
904 assert!(url_str.contains("oauth.onshape.com"));
905 }
906
907 #[test]
908 fn onshape_token_url_is_valid() {
909 let url = onshape_token_url();
910 let url_str = url.url().as_str();
911 assert!(url_str.starts_with("https://"));
912 assert!(url_str.contains("oauth.onshape.com"));
913 }
914
915 #[test]
916 fn onshape_oauth_client_builds_successfully() {
917 let _client = onshape_oauth_client("test-client-id", "test-client-secret");
918 }
919
920 #[test]
921 fn from_response_with_expiry() {
922 let json = r#"{
923 "access_token": "test-access-token",
924 "token_type": "Bearer",
925 "expires_in": 3600,
926 "refresh_token": "test-refresh-token"
927 }"#;
928 let response: BasicTokenResponse =
929 serde_json::from_str(json).expect("should deserialize token response");
930 let now = DateTime::parse_from_rfc3339("2025-06-01T12:00:00Z")
931 .expect("parse")
932 .to_utc();
933
934 let token_data = OAuthTokenData::from_response(&response, now);
935
936 assert_eq!(token_data.access_token.secret(), "test-access-token");
937 assert_eq!(token_data.refresh_token.secret(), "test-refresh-token");
938
939 let expires_at = token_data.expires_at.expect("should have expiry");
940 assert_eq!(
941 expires_at,
942 now + chrono::Duration::seconds(3600),
943 "expires_at should be exactly now + 3600s"
944 );
945 }
946
947 #[test]
948 fn from_response_without_expiry() {
949 let json = r#"{
950 "access_token": "test-access-token",
951 "token_type": "Bearer"
952 }"#;
953 let response: BasicTokenResponse =
954 serde_json::from_str(json).expect("should deserialize token response");
955 let now = DateTime::parse_from_rfc3339("2025-06-01T12:00:00Z")
956 .expect("parse")
957 .to_utc();
958
959 let token_data = OAuthTokenData::from_response(&response, now);
960
961 assert_eq!(token_data.access_token.secret(), "test-access-token");
962 assert!(token_data.expires_at.is_none());
963 assert!(token_data.refresh_token.secret().is_empty());
965 assert!(token_data.scopes.is_none());
967 }
968
969 #[test]
970 fn from_response_preserves_scopes() {
971 let json = r#"{
972 "access_token": "test-at",
973 "token_type": "Bearer",
974 "refresh_token": "test-rt",
975 "scope": "OAuth2Read OAuth2Write"
976 }"#;
977 let response: BasicTokenResponse =
978 serde_json::from_str(json).expect("should deserialize token response");
979 let now = DateTime::parse_from_rfc3339("2025-06-01T12:00:00Z")
980 .expect("parse")
981 .to_utc();
982
983 let token_data = OAuthTokenData::from_response(&response, now);
984
985 let scopes = token_data.scopes.expect("should have scopes");
986 assert_eq!(scopes, vec!["OAuth2Read", "OAuth2Write"]);
987 }
988
989 #[test]
990 fn token_data_json_shape_backward_compatible() {
991 let tokens = OAuthTokenData {
994 access_token: AccessToken::new("my-access".to_string()),
995 refresh_token: RefreshToken::new("my-refresh".to_string()),
996 expires_at: None,
997 token_type: "bearer".into(),
998 scopes: None,
999 client_id: None,
1000 client_secret: None,
1001 proxy_url: None,
1002 };
1003 let json = serde_json::to_string_pretty(&tokens).expect("should serialize");
1004 let value: serde_json::Value = serde_json::from_str(&json).expect("should be valid JSON");
1005
1006 assert!(value["access_token"].is_string());
1008 assert!(value["refresh_token"].is_string());
1009 assert!(value["token_type"].is_string());
1010 assert_eq!(value["access_token"], "my-access");
1011 assert_eq!(value["refresh_token"], "my-refresh");
1012 }
1013
1014 #[test]
1019 fn is_expiring_soon_false_when_well_before_margin() {
1020 let tokens = OAuthTokenData {
1021 access_token: AccessToken::new("at".into()),
1022 refresh_token: RefreshToken::new("rt".into()),
1023 expires_at: Some(
1024 DateTime::parse_from_rfc3339("2025-01-01T00:02:00Z")
1025 .expect("parse")
1026 .to_utc(),
1027 ),
1028 token_type: "bearer".into(),
1029 scopes: None,
1030 client_id: None,
1031 client_secret: None,
1032 proxy_url: None,
1033 };
1034 let now = DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z")
1035 .expect("parse")
1036 .to_utc();
1037 assert!(!tokens.is_expiring_soon(now, chrono::Duration::seconds(60)));
1039 }
1040
1041 #[test]
1042 fn is_expiring_soon_true_when_within_margin() {
1043 let tokens = OAuthTokenData {
1044 access_token: AccessToken::new("at".into()),
1045 refresh_token: RefreshToken::new("rt".into()),
1046 expires_at: Some(
1047 DateTime::parse_from_rfc3339("2025-01-01T00:00:55Z")
1048 .expect("parse")
1049 .to_utc(),
1050 ),
1051 token_type: "bearer".into(),
1052 scopes: None,
1053 client_id: None,
1054 client_secret: None,
1055 proxy_url: None,
1056 };
1057 let now = DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z")
1058 .expect("parse")
1059 .to_utc();
1060 assert!(tokens.is_expiring_soon(now, chrono::Duration::seconds(60)));
1062 }
1063
1064 #[test]
1065 fn is_expiring_soon_true_when_already_expired() {
1066 let tokens = OAuthTokenData {
1067 access_token: AccessToken::new("at".into()),
1068 refresh_token: RefreshToken::new("rt".into()),
1069 expires_at: Some(
1070 DateTime::parse_from_rfc3339("2024-12-31T23:59:00Z")
1071 .expect("parse")
1072 .to_utc(),
1073 ),
1074 token_type: "bearer".into(),
1075 scopes: None,
1076 client_id: None,
1077 client_secret: None,
1078 proxy_url: None,
1079 };
1080 let now = DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z")
1081 .expect("parse")
1082 .to_utc();
1083 assert!(tokens.is_expiring_soon(now, chrono::Duration::seconds(60)));
1084 }
1085
1086 #[test]
1087 fn is_expiring_soon_false_when_no_expiry() {
1088 let tokens = OAuthTokenData {
1089 access_token: AccessToken::new("at".into()),
1090 refresh_token: RefreshToken::new("rt".into()),
1091 expires_at: None,
1092 token_type: "bearer".into(),
1093 scopes: None,
1094 client_id: None,
1095 client_secret: None,
1096 proxy_url: None,
1097 };
1098 let now = Utc::now();
1099 assert!(!tokens.is_expiring_soon(now, chrono::Duration::seconds(60)));
1100 }
1101
1102 #[test]
1103 fn is_expiring_soon_true_at_exact_margin_boundary() {
1104 let tokens = OAuthTokenData {
1105 access_token: AccessToken::new("at".into()),
1106 refresh_token: RefreshToken::new("rt".into()),
1107 expires_at: Some(
1108 DateTime::parse_from_rfc3339("2025-01-01T00:01:00Z")
1109 .expect("parse")
1110 .to_utc(),
1111 ),
1112 token_type: "bearer".into(),
1113 scopes: None,
1114 client_id: None,
1115 client_secret: None,
1116 proxy_url: None,
1117 };
1118 let now = DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z")
1119 .expect("parse")
1120 .to_utc();
1121 assert!(tokens.is_expiring_soon(now, chrono::Duration::seconds(60)));
1123 }
1124
1125 fn make_session(expires_at: Option<DateTime<Utc>>) -> OAuthSession {
1130 OAuthSession::new(
1131 OAuthTokenData {
1132 access_token: AccessToken::new("at".into()),
1133 refresh_token: RefreshToken::new("rt".into()),
1134 expires_at,
1135 token_type: "bearer".into(),
1136 scopes: None,
1137 client_id: None,
1138 client_secret: None,
1139 proxy_url: None,
1140 },
1141 chrono::Duration::seconds(60),
1142 )
1143 }
1144
1145 #[test]
1146 fn pre_execute_proceed_when_well_before_expiry() {
1147 let session = make_session(Some(
1148 DateTime::parse_from_rfc3339("2025-01-01T00:02:00Z")
1149 .expect("parse")
1150 .to_utc(),
1151 ));
1152 let now = DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z")
1153 .expect("parse")
1154 .to_utc();
1155 assert_eq!(session.pre_execute_action(now), PreExecuteAction::Proceed);
1156 }
1157
1158 #[test]
1159 fn pre_execute_refresh_when_within_margin() {
1160 let session = make_session(Some(
1161 DateTime::parse_from_rfc3339("2025-01-01T00:00:55Z")
1162 .expect("parse")
1163 .to_utc(),
1164 ));
1165 let now = DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z")
1166 .expect("parse")
1167 .to_utc();
1168 assert_eq!(
1169 session.pre_execute_action(now),
1170 PreExecuteAction::RefreshNeeded
1171 );
1172 }
1173
1174 #[test]
1175 fn pre_execute_refresh_when_already_expired() {
1176 let session = make_session(Some(
1177 DateTime::parse_from_rfc3339("2024-12-31T23:00:00Z")
1178 .expect("parse")
1179 .to_utc(),
1180 ));
1181 let now = DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z")
1182 .expect("parse")
1183 .to_utc();
1184 assert_eq!(
1185 session.pre_execute_action(now),
1186 PreExecuteAction::RefreshNeeded
1187 );
1188 }
1189
1190 #[test]
1191 fn pre_execute_proceed_when_no_expiry() {
1192 let session = make_session(None);
1193 let now = Utc::now();
1194 assert_eq!(session.pre_execute_action(now), PreExecuteAction::Proceed);
1195 }
1196
1197 #[test]
1202 fn post_execute_done_on_200() {
1203 let session = make_session(None);
1204 assert_eq!(
1205 session.post_execute_action(200, false),
1206 PostExecuteAction::Done
1207 );
1208 }
1209
1210 #[test]
1211 fn post_execute_refresh_and_retry_on_401_not_refreshed() {
1212 let session = make_session(None);
1213 assert_eq!(
1214 session.post_execute_action(401, false),
1215 PostExecuteAction::RefreshAndRetry
1216 );
1217 }
1218
1219 #[test]
1220 fn post_execute_done_on_401_already_refreshed() {
1221 let session = make_session(None);
1222 assert_eq!(
1223 session.post_execute_action(401, true),
1224 PostExecuteAction::Done
1225 );
1226 }
1227
1228 #[test]
1229 fn post_execute_done_on_403() {
1230 let session = make_session(None);
1231 assert_eq!(
1232 session.post_execute_action(403, false),
1233 PostExecuteAction::Done
1234 );
1235 }
1236
1237 #[test]
1242 fn apply_external_tokens_adopts_fresher_tokens() {
1243 let now = DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z")
1244 .expect("parse")
1245 .to_utc();
1246 let mut session = OAuthSession::new(
1247 OAuthTokenData {
1248 access_token: AccessToken::new("old-at".into()),
1249 refresh_token: RefreshToken::new("old-rt".into()),
1250 expires_at: Some(now + chrono::Duration::seconds(100)),
1251 token_type: "bearer".into(),
1252 scopes: None,
1253 client_id: None,
1254 client_secret: None,
1255 proxy_url: None,
1256 },
1257 chrono::Duration::seconds(60),
1258 );
1259 let file_tokens = OAuthTokenData {
1260 access_token: AccessToken::new("new-at".into()),
1261 refresh_token: RefreshToken::new("new-rt".into()),
1262 expires_at: Some(now + chrono::Duration::seconds(3600)),
1263 token_type: "bearer".into(),
1264 scopes: None,
1265 client_id: None,
1266 client_secret: None,
1267 proxy_url: None,
1268 };
1269 assert!(session.apply_external_tokens(file_tokens, now));
1270 assert_eq!(session.access_token().secret(), "new-at");
1271 assert_eq!(session.refresh_token().secret(), "new-rt");
1272 }
1273
1274 #[test]
1275 fn apply_external_tokens_rejects_same_or_earlier_expiry() {
1276 let now = DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z")
1277 .expect("parse")
1278 .to_utc();
1279 let mut session = OAuthSession::new(
1280 OAuthTokenData {
1281 access_token: AccessToken::new("current-at".into()),
1282 refresh_token: RefreshToken::new("current-rt".into()),
1283 expires_at: Some(now + chrono::Duration::seconds(3600)),
1284 token_type: "bearer".into(),
1285 scopes: None,
1286 client_id: None,
1287 client_secret: None,
1288 proxy_url: None,
1289 },
1290 chrono::Duration::seconds(60),
1291 );
1292 let file_tokens = OAuthTokenData {
1293 access_token: AccessToken::new("file-at".into()),
1294 refresh_token: RefreshToken::new("file-rt".into()),
1295 expires_at: Some(now + chrono::Duration::seconds(3600)),
1296 token_type: "bearer".into(),
1297 scopes: None,
1298 client_id: None,
1299 client_secret: None,
1300 proxy_url: None,
1301 };
1302 assert!(!session.apply_external_tokens(file_tokens, now));
1303 assert_eq!(session.access_token().secret(), "current-at");
1304 }
1305
1306 #[test]
1307 fn apply_external_tokens_rejects_expired_file_tokens() {
1308 let now = DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z")
1309 .expect("parse")
1310 .to_utc();
1311 let mut session = OAuthSession::new(
1312 OAuthTokenData {
1313 access_token: AccessToken::new("current-at".into()),
1314 refresh_token: RefreshToken::new("current-rt".into()),
1315 expires_at: Some(now - chrono::Duration::seconds(100)),
1316 token_type: "bearer".into(),
1317 scopes: None,
1318 client_id: None,
1319 client_secret: None,
1320 proxy_url: None,
1321 },
1322 chrono::Duration::seconds(60),
1323 );
1324 let file_tokens = OAuthTokenData {
1325 access_token: AccessToken::new("file-at".into()),
1326 refresh_token: RefreshToken::new("file-rt".into()),
1327 expires_at: Some(now - chrono::Duration::seconds(50)),
1328 token_type: "bearer".into(),
1329 scopes: None,
1330 client_id: None,
1331 client_secret: None,
1332 proxy_url: None,
1333 };
1334 assert!(!session.apply_external_tokens(file_tokens, now));
1335 assert_eq!(session.access_token().secret(), "current-at");
1336 }
1337
1338 #[test]
1339 fn apply_external_tokens_rejects_when_both_none_expiry() {
1340 let now = Utc::now();
1341 let mut session = OAuthSession::new(
1342 OAuthTokenData {
1343 access_token: AccessToken::new("current-at".into()),
1344 refresh_token: RefreshToken::new("current-rt".into()),
1345 expires_at: None,
1346 token_type: "bearer".into(),
1347 scopes: None,
1348 client_id: None,
1349 client_secret: None,
1350 proxy_url: None,
1351 },
1352 chrono::Duration::seconds(60),
1353 );
1354 let file_tokens = OAuthTokenData {
1355 access_token: AccessToken::new("file-at".into()),
1356 refresh_token: RefreshToken::new("file-rt".into()),
1357 expires_at: None,
1358 token_type: "bearer".into(),
1359 scopes: None,
1360 client_id: None,
1361 client_secret: None,
1362 proxy_url: None,
1363 };
1364 assert!(!session.apply_external_tokens(file_tokens, now));
1365 assert_eq!(session.access_token().secret(), "current-at");
1366 }
1367
1368 #[test]
1369 fn apply_external_tokens_rejects_when_file_has_none_expiry() {
1370 let now = DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z")
1371 .expect("parse")
1372 .to_utc();
1373 let mut session = OAuthSession::new(
1374 OAuthTokenData {
1375 access_token: AccessToken::new("current-at".into()),
1376 refresh_token: RefreshToken::new("current-rt".into()),
1377 expires_at: Some(now + chrono::Duration::seconds(100)),
1378 token_type: "bearer".into(),
1379 scopes: None,
1380 client_id: None,
1381 client_secret: None,
1382 proxy_url: None,
1383 },
1384 chrono::Duration::seconds(60),
1385 );
1386 let file_tokens = OAuthTokenData {
1387 access_token: AccessToken::new("file-at".into()),
1388 refresh_token: RefreshToken::new("file-rt".into()),
1389 expires_at: None,
1390 token_type: "bearer".into(),
1391 scopes: None,
1392 client_id: None,
1393 client_secret: None,
1394 proxy_url: None,
1395 };
1396 assert!(!session.apply_external_tokens(file_tokens, now));
1397 assert_eq!(session.access_token().secret(), "current-at");
1398 }
1399
1400 #[test]
1405 fn apply_refresh_updates_tokens_with_expiry() {
1406 let mut session = make_session(None);
1407 let json = r#"{
1408 "access_token": "new-access-token",
1409 "token_type": "bearer",
1410 "expires_in": 3600,
1411 "refresh_token": "new-refresh-token"
1412 }"#;
1413 let response: BasicTokenResponse = serde_json::from_str(json).expect("should deserialize");
1414 let now = DateTime::parse_from_rfc3339("2025-06-01T12:00:00Z")
1415 .expect("parse")
1416 .to_utc();
1417 session.apply_refresh(&response, now);
1418
1419 assert_eq!(session.access_token().secret(), "new-access-token");
1420 assert_eq!(session.refresh_token().secret(), "new-refresh-token");
1421
1422 let expires_at = session.tokens.expires_at.expect("should have expiry");
1423 assert_eq!(
1424 expires_at,
1425 now + chrono::Duration::seconds(3600),
1426 "expires_at should be exactly now + 3600s"
1427 );
1428 }
1429
1430 #[test]
1431 fn apply_refresh_updates_tokens_without_expiry() {
1432 let now = DateTime::parse_from_rfc3339("2025-06-01T12:00:00Z")
1433 .expect("parse")
1434 .to_utc();
1435 let mut session = make_session(Some(now));
1436 let json = r#"{
1437 "access_token": "new-at",
1438 "token_type": "bearer"
1439 }"#;
1440 let response: BasicTokenResponse = serde_json::from_str(json).expect("should deserialize");
1441 session.apply_refresh(&response, now);
1442
1443 assert_eq!(session.access_token().secret(), "new-at");
1444 assert!(session.tokens.expires_at.is_none());
1445 }
1446
1447 #[test]
1452 fn token_data_roundtrips_with_proxy_url() {
1453 let tokens = OAuthTokenData {
1454 access_token: AccessToken::new("at".to_string()),
1455 refresh_token: RefreshToken::new("rt".to_string()),
1456 expires_at: None,
1457 token_type: "bearer".into(),
1458 scopes: None,
1459 client_id: Some("cid".into()),
1460 client_secret: None,
1461 proxy_url: Some("https://proxy.example.com".into()),
1462 };
1463 let json = serde_json::to_string(&tokens).expect("should serialize");
1464 let roundtripped: OAuthTokenData = serde_json::from_str(&json).expect("should deserialize");
1465 assert_eq!(
1466 roundtripped.proxy_url.as_deref(),
1467 Some("https://proxy.example.com")
1468 );
1469 assert_eq!(roundtripped.client_id.as_deref(), Some("cid"));
1470 assert!(roundtripped.client_secret.is_none());
1471 }
1472
1473 #[test]
1474 fn token_data_backward_compat_without_proxy_url() {
1475 let json = r#"{
1477 "access_token": "at",
1478 "refresh_token": "rt",
1479 "token_type": "bearer",
1480 "client_id": "cid",
1481 "client_secret": "cs"
1482 }"#;
1483 let tokens: OAuthTokenData = serde_json::from_str(json).expect("should deserialize");
1484 assert!(tokens.proxy_url.is_none());
1485 assert_eq!(tokens.client_id.as_deref(), Some("cid"));
1486 assert_eq!(tokens.client_secret.as_deref(), Some("cs"));
1487 }
1488
1489 #[test]
1490 fn token_data_proxy_url_omitted_from_json_when_none() {
1491 let tokens = OAuthTokenData {
1492 access_token: AccessToken::new("at".to_string()),
1493 refresh_token: RefreshToken::new("rt".to_string()),
1494 expires_at: None,
1495 token_type: "bearer".into(),
1496 scopes: None,
1497 client_id: None,
1498 client_secret: None,
1499 proxy_url: None,
1500 };
1501 let json = serde_json::to_string(&tokens).expect("should serialize");
1502 let value: serde_json::Value = serde_json::from_str(&json).expect("should be valid JSON");
1503 assert!(value.get("proxy_url").is_none());
1504 }
1505
1506 #[test]
1507 fn from_raw_creates_token_data() {
1508 let tokens = OAuthTokenData::from_raw(
1509 "access-token".into(),
1510 "refresh-token".into(),
1511 None,
1512 "bearer".into(),
1513 Some(vec!["OAuth2Read".into(), "OAuth2Write".into()]),
1514 );
1515 assert_eq!(tokens.access_token.secret(), "access-token");
1516 assert_eq!(tokens.refresh_token.secret(), "refresh-token");
1517 assert!(tokens.expires_at.is_none());
1518 assert_eq!(tokens.token_type, "bearer");
1519 assert_eq!(
1520 tokens.scopes,
1521 Some(vec!["OAuth2Read".into(), "OAuth2Write".into()])
1522 );
1523 assert!(tokens.client_id.is_none());
1524 assert!(tokens.client_secret.is_none());
1525 assert!(tokens.proxy_url.is_none());
1526 }
1527
1528 fn test_login_config() -> OAuthLoginConfig {
1533 OAuthLoginConfig {
1534 client_id: "test-client-id".into(),
1535 redirect_uri: "http://127.0.0.1:18338/callback".into(),
1536 scopes: vec!["OAuth2Read".into(), "OAuth2Write".into()],
1537 }
1538 }
1539
1540 #[test]
1541 fn build_authorize_url_contains_required_params() {
1542 let config = test_login_config();
1543 let state = CsrfToken::new("test-state-token".into());
1544 let (challenge, _verifier) = PkceCodeChallenge::new_random_sha256();
1545
1546 let url_str = build_authorize_url(&config, &state, challenge);
1547 let url = url::Url::parse(&url_str).expect("should be a valid URL");
1548
1549 assert_eq!(url.scheme(), "https");
1550 assert_eq!(url.host_str(), Some("oauth.onshape.com"));
1551 assert_eq!(url.path(), "/oauth/authorize");
1552
1553 let params: std::collections::HashMap<String, String> = url
1554 .query_pairs()
1555 .map(|(k, v)| (k.into_owned(), v.into_owned()))
1556 .collect();
1557
1558 assert_eq!(
1559 params.get("client_id").map(String::as_str),
1560 Some("test-client-id")
1561 );
1562 assert_eq!(
1563 params.get("redirect_uri").map(String::as_str),
1564 Some("http://127.0.0.1:18338/callback")
1565 );
1566 assert_eq!(
1567 params.get("response_type").map(String::as_str),
1568 Some("code")
1569 );
1570 assert_eq!(
1571 params.get("state").map(String::as_str),
1572 Some("test-state-token")
1573 );
1574 assert!(params.contains_key("code_challenge"));
1575 assert_eq!(
1576 params.get("code_challenge_method").map(String::as_str),
1577 Some("S256")
1578 );
1579 let scope = params.get("scope").expect("should have scope parameter");
1581 assert!(scope.contains("OAuth2Read"));
1582 assert!(scope.contains("OAuth2Write"));
1583 }
1584
1585 #[test]
1586 fn build_authorize_url_with_no_scopes() {
1587 let config = OAuthLoginConfig {
1588 client_id: "cid".into(),
1589 redirect_uri: "http://127.0.0.1:18338/callback".into(),
1590 scopes: vec![],
1591 };
1592 let state = CsrfToken::new("state".into());
1593 let (challenge, _verifier) = PkceCodeChallenge::new_random_sha256();
1594
1595 let url_str = build_authorize_url(&config, &state, challenge);
1596 let url = url::Url::parse(&url_str).expect("should be a valid URL");
1597 let params: std::collections::HashMap<String, String> = url
1598 .query_pairs()
1599 .map(|(k, v)| (k.into_owned(), v.into_owned()))
1600 .collect();
1601
1602 assert!(!params.contains_key("scope"));
1604 }
1605
1606 #[test]
1607 fn validate_callback_extracts_code() {
1608 let state = CsrfToken::new("my-state".into());
1609 let callback = "http://127.0.0.1:18338/callback?code=auth-code-123&state=my-state";
1610
1611 let code = validate_callback(callback, &state).expect("should validate");
1612 assert_eq!(code.secret(), "auth-code-123");
1613 }
1614
1615 #[test]
1616 fn validate_callback_detects_state_mismatch() {
1617 let state = CsrfToken::new("expected-state".into());
1618 let callback = "http://127.0.0.1:18338/callback?code=abc&state=wrong-state";
1619
1620 let err = validate_callback(callback, &state).expect_err("should fail");
1621 assert!(
1622 matches!(err, CallbackValidationError::StateMismatch { .. }),
1623 "expected StateMismatch, got: {err:?}"
1624 );
1625 }
1626
1627 #[test]
1628 fn validate_callback_detects_oauth_error() {
1629 let state = CsrfToken::new("my-state".into());
1630 let callback = "http://127.0.0.1:18338/callback?error=access_denied&error_description=User+denied+access&state=my-state";
1631
1632 let err = validate_callback(callback, &state).expect_err("should fail");
1633 match err {
1634 CallbackValidationError::OAuthError { error, description } => {
1635 assert_eq!(error, "access_denied");
1636 assert_eq!(description.as_deref(), Some("User denied access"));
1637 }
1638 other => panic!("expected OAuthError, got: {other:?}"),
1639 }
1640 }
1641
1642 #[test]
1643 fn validate_callback_detects_missing_state() {
1644 let state = CsrfToken::new("my-state".into());
1645 let callback = "http://127.0.0.1:18338/callback?code=abc";
1646
1647 let err = validate_callback(callback, &state).expect_err("should fail");
1648 assert!(
1649 matches!(err, CallbackValidationError::MissingState),
1650 "expected MissingState, got: {err:?}"
1651 );
1652 }
1653
1654 #[test]
1655 fn validate_callback_detects_missing_code() {
1656 let state = CsrfToken::new("my-state".into());
1657 let callback = "http://127.0.0.1:18338/callback?state=my-state";
1658
1659 let err = validate_callback(callback, &state).expect_err("should fail");
1660 assert!(
1661 matches!(err, CallbackValidationError::MissingCode),
1662 "expected MissingCode, got: {err:?}"
1663 );
1664 }
1665
1666 #[test]
1667 fn validate_callback_detects_invalid_url() {
1668 let state = CsrfToken::new("my-state".into());
1669 let err = validate_callback("not a url at all ://", &state).expect_err("should fail");
1670 assert!(
1671 matches!(err, CallbackValidationError::InvalidUrl(_)),
1672 "expected InvalidUrl, got: {err:?}"
1673 );
1674 }
1675}