use std::path::PathBuf;
use chrono::{DateTime, Utc};
use oauth2::basic::{BasicClient, BasicTokenResponse};
use oauth2::{
AccessToken, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge,
PkceCodeVerifier, RedirectUrl, RefreshToken, TokenResponse, TokenUrl,
};
use serde::{Deserialize, Serialize};
/// Onshape OAuth2 authorization endpoint as a raw string; wrapped into an
/// [`AuthUrl`] by [`onshape_auth_url`].
const ONSHAPE_AUTH_URL_STR: &str = "https://oauth.onshape.com/oauth/authorize";
/// Onshape OAuth2 token endpoint as a raw string; wrapped into a
/// [`TokenUrl`] by [`onshape_token_url`].
const ONSHAPE_TOKEN_URL_STR: &str = "https://oauth.onshape.com/oauth/token";
/// Returns the Onshape authorization endpoint as a typed [`AuthUrl`].
///
/// # Panics
///
/// Panics if the hard-coded endpoint string is rejected by [`AuthUrl::new`],
/// which would indicate a programming error in this file.
#[must_use]
// The input is a compile-time constant, so `expect` can only fire on a bug here.
#[allow(clippy::expect_used)]
pub fn onshape_auth_url() -> AuthUrl {
    let raw = String::from(ONSHAPE_AUTH_URL_STR);
    AuthUrl::new(raw).expect("hard-coded Onshape auth URL is valid")
}
/// Returns the Onshape token endpoint as a typed [`TokenUrl`].
///
/// # Panics
///
/// Panics if the hard-coded endpoint string is rejected by [`TokenUrl::new`],
/// which would indicate a programming error in this file.
#[must_use]
// The input is a compile-time constant, so `expect` can only fire on a bug here.
#[allow(clippy::expect_used)]
pub fn onshape_token_url() -> TokenUrl {
    let raw = String::from(ONSHAPE_TOKEN_URL_STR);
    TokenUrl::new(raw).expect("hard-coded Onshape token URL is valid")
}
/// OAuth token bundle that can be serialized to and deserialized from JSON.
///
/// `Debug` is implemented by hand rather than derived so that `client_secret`
/// is never written to logs in clear text. The `oauth2` token newtypes take
/// care of redacting `access_token` and `refresh_token` themselves.
#[derive(Clone, Deserialize, Serialize)]
pub struct OAuthTokenData {
    /// Access token; (de)serialized as a plain JSON string via the
    /// `serialize_access_token` / `deserialize_access_token` helpers.
    #[serde(
        serialize_with = "serialize_access_token",
        deserialize_with = "deserialize_access_token"
    )]
    pub access_token: AccessToken,
    /// Refresh token; (de)serialized as a plain JSON string. May be the empty
    /// string when the value was built by `from_response` from a response that
    /// carried no refresh token.
    #[serde(
        serialize_with = "serialize_refresh_token",
        deserialize_with = "deserialize_refresh_token"
    )]
    pub refresh_token: RefreshToken,
    /// Absolute expiry instant, or `None` when no expiry is known. Omitted
    /// from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub expires_at: Option<DateTime<Utc>>,
    /// Token type. Defaults to `"bearer"` when absent from the input; on
    /// deserialization only a case-insensitive `"bearer"` is accepted and it
    /// is normalized to lowercase.
    #[serde(
        default = "default_token_type",
        deserialize_with = "deserialize_token_type"
    )]
    pub token_type: String,
    /// Granted scopes, if known. Omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub scopes: Option<Vec<String>>,
    /// Optional OAuth client id stored alongside the tokens.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub client_id: Option<String>,
    /// Optional OAuth client secret stored alongside the tokens. Redacted in
    /// `Debug` output.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub client_secret: Option<String>,
    /// Optional proxy URL stored alongside the tokens.
    /// NOTE(review): exact semantics are defined by callers outside this
    /// file — confirm whether this URL can itself embed credentials.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub proxy_url: Option<String>,
}

impl std::fmt::Debug for OAuthTokenData {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the presence/absence of the secret visible for diagnostics,
        // but never its value.
        let client_secret = self.client_secret.as_ref().map(|_| "[redacted]");
        f.debug_struct("OAuthTokenData")
            .field("access_token", &self.access_token)
            .field("refresh_token", &self.refresh_token)
            .field("expires_at", &self.expires_at)
            .field("token_type", &self.token_type)
            .field("scopes", &self.scopes)
            .field("client_id", &self.client_id)
            .field("client_secret", &client_secret)
            .field("proxy_url", &self.proxy_url)
            .finish()
    }
}
impl OAuthTokenData {
    /// Reports whether the token has expired as of `now`.
    ///
    /// A token whose expiry equals `now` counts as expired. Tokens without a
    /// recorded expiry are never reported as expired.
    #[must_use]
    pub fn is_expired(&self, now: DateTime<Utc>) -> bool {
        matches!(self.expires_at, Some(deadline) if deadline <= now)
    }

    /// Reports whether the token expires at or before `now + margin`.
    ///
    /// Already-expired tokens also return `true`. Tokens without a recorded
    /// expiry always return `false`.
    #[must_use]
    pub fn is_expiring_soon(&self, now: DateTime<Utc>, margin: chrono::Duration) -> bool {
        let Some(deadline) = self.expires_at else {
            return false;
        };
        let threshold = now + margin;
        deadline <= threshold
    }
}
impl OAuthTokenData {
    /// Builds token data from a token-endpoint `response` received at `now`.
    ///
    /// * `expires_at` is `now + expires_in` when the response carries a
    ///   lifetime. If the lifetime is missing, does not fit a
    ///   `chrono::Duration`, or would push the timestamp outside the range
    ///   representable by `DateTime<Utc>`, `expires_at` is `None`.
    /// * `refresh_token` falls back to an empty token when the response has
    ///   none; callers that need to keep a previous refresh token must
    ///   restore it themselves.
    /// * `client_id`, `client_secret` and `proxy_url` are always `None`.
    #[must_use]
    pub fn from_response(response: &BasicTokenResponse, now: DateTime<Utc>) -> Self {
        // `expires_in` is server-controlled, so the addition is checked: a
        // huge value must not panic the client via `DateTime + Duration`.
        let expires_at = response
            .expires_in()
            .and_then(|lifetime| chrono::Duration::from_std(lifetime).ok())
            .and_then(|lifetime| now.checked_add_signed(lifetime));
        let scopes = response
            .scopes()
            .map(|scopes| scopes.iter().map(|s| s.as_ref().to_owned()).collect());
        Self {
            access_token: response.access_token().clone(),
            refresh_token: response
                .refresh_token()
                .cloned()
                .unwrap_or_else(|| RefreshToken::new(String::new())),
            expires_at,
            token_type: response.token_type().as_ref().to_string(),
            scopes,
            client_id: None,
            client_secret: None,
            proxy_url: None,
        }
    }
}
impl OAuthTokenData {
    /// Assembles token data from already-extracted raw values.
    ///
    /// The two token strings are wrapped in their `oauth2` newtypes; the
    /// remaining arguments are stored unchanged. `token_type` is not
    /// validated or normalized here. `client_id`, `client_secret` and
    /// `proxy_url` start out as `None`.
    #[must_use]
    pub fn from_raw(
        access_token: String,
        refresh_token: String,
        expires_at: Option<DateTime<Utc>>,
        token_type: String,
        scopes: Option<Vec<String>>,
    ) -> Self {
        let access_token = AccessToken::new(access_token);
        let refresh_token = RefreshToken::new(refresh_token);
        Self {
            access_token,
            refresh_token,
            expires_at,
            token_type,
            scopes,
            client_id: None,
            client_secret: None,
            proxy_url: None,
        }
    }
}
/// Serde default for `token_type`: the lowercase string `"bearer"`.
fn default_token_type() -> String {
    String::from("bearer")
}
fn deserialize_token_type<'de, D>(deserializer: D) -> Result<String, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
if s.eq_ignore_ascii_case("bearer") {
Ok("bearer".to_string())
} else {
Err(serde::de::Error::custom(format!(
"invalid token_type \"{s}\", expected \"bearer\""
)))
}
}
/// Serde serializer that writes an [`AccessToken`] as its raw secret string.
fn serialize_access_token<S>(token: &AccessToken, serializer: S) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    let secret: &str = token.secret();
    serializer.serialize_str(secret)
}
/// Serde deserializer that reads a string and wraps it in an [`AccessToken`].
///
/// # Errors
///
/// Propagates the deserializer's error when the value is not a string.
fn deserialize_access_token<'de, D>(deserializer: D) -> Result<AccessToken, D::Error>
where
    D: serde::Deserializer<'de>,
{
    String::deserialize(deserializer).map(AccessToken::new)
}
/// Serde serializer that writes a [`RefreshToken`] as its raw secret string.
fn serialize_refresh_token<S>(token: &RefreshToken, serializer: S) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    let secret: &str = token.secret();
    serializer.serialize_str(secret)
}
/// Serde deserializer that reads a string and wraps it in a [`RefreshToken`].
///
/// # Errors
///
/// Propagates the deserializer's error when the value is not a string.
fn deserialize_refresh_token<'de, D>(deserializer: D) -> Result<RefreshToken, D::Error>
where
    D: serde::Deserializer<'de>,
{
    String::deserialize(deserializer).map(RefreshToken::new)
}
/// Concrete [`BasicClient`] type returned by [`onshape_oauth_client`].
///
/// The first and last type-state parameters are `EndpointSet`, the three in
/// between are `EndpointNotSet`. `onshape_oauth_client` reaches this type by
/// calling `set_auth_uri` and `set_token_uri`, so the two `EndpointSet` slots
/// presumably track the authorization and token endpoints — confirm against
/// the `oauth2` crate's `Client` type parameters.
pub type OnshapeOAuthClient = BasicClient<
oauth2::EndpointSet,
oauth2::EndpointNotSet,
oauth2::EndpointNotSet,
oauth2::EndpointNotSet,
oauth2::EndpointSet,
>;
/// Builds an OAuth client for Onshape from the given client credentials.
///
/// The client is configured with the client secret plus the Onshape
/// authorization and token endpoints returned by [`onshape_auth_url`] and
/// [`onshape_token_url`].
#[must_use]
pub fn onshape_oauth_client(client_id: &str, client_secret: &str) -> OnshapeOAuthClient {
    let id = ClientId::new(client_id.to_owned());
    let secret = ClientSecret::new(client_secret.to_owned());
    BasicClient::new(id)
        .set_client_secret(secret)
        .set_auth_uri(onshape_auth_url())
        .set_token_uri(onshape_token_url())
}
/// Returns the `onshape-mcp` directory under the platform data directory
/// reported by `dirs::data_dir`, or `None` when that lookup yields nothing.
#[must_use]
pub fn default_data_dir() -> Option<PathBuf> {
    let base = dirs::data_dir()?;
    Some(base.join("onshape-mcp"))
}
/// Returns the path of `tokens.json` inside [`default_data_dir`], or `None`
/// when no default data directory is available.
#[must_use]
pub fn default_token_file_path() -> Option<PathBuf> {
    let data_dir = default_data_dir()?;
    Some(data_dir.join("tokens.json"))
}
/// Decision returned by `OAuthSession::pre_execute_action`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PreExecuteAction {
/// The current token is not within the refresh margin of its expiry (or has
/// no recorded expiry).
Proceed,
/// The current token expires within the session's refresh margin, or has
/// already expired.
RefreshNeeded,
}
/// Decision returned by `OAuthSession::post_execute_action`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PostExecuteAction {
/// No retry: the status was not 401, or a refresh was already attempted.
Done,
/// The status was 401 and no refresh has been attempted yet.
RefreshAndRetry,
}
/// In-memory OAuth session: the current tokens plus the margin used to decide
/// when a refresh is needed.
pub struct OAuthSession {
/// Tokens currently in use; replaced by `apply_refresh` and
/// `apply_external_tokens`.
pub tokens: OAuthTokenData,
/// Margin passed to `OAuthTokenData::is_expiring_soon` by
/// `pre_execute_action`.
refresh_margin: chrono::Duration,
}
impl OAuthSession {
    /// Creates a session around `tokens` that asks for a refresh once the
    /// token is within `refresh_margin` of its expiry.
    #[must_use]
    pub const fn new(tokens: OAuthTokenData, refresh_margin: chrono::Duration) -> Self {
        Self {
            tokens,
            refresh_margin,
        }
    }

    /// Decides whether the token should be refreshed before a request is sent
    /// at `now`.
    ///
    /// Returns [`PreExecuteAction::RefreshNeeded`] when the token expires
    /// within the refresh margin (or has already expired), otherwise
    /// [`PreExecuteAction::Proceed`]. Tokens without a recorded expiry always
    /// proceed.
    #[must_use]
    pub fn pre_execute_action(&self, now: DateTime<Utc>) -> PreExecuteAction {
        if self.tokens.is_expiring_soon(now, self.refresh_margin) {
            PreExecuteAction::RefreshNeeded
        } else {
            PreExecuteAction::Proceed
        }
    }

    /// Decides what to do after a request completed with HTTP `status`.
    ///
    /// Only a 401 that has not already been followed by a refresh
    /// (`already_refreshed == false`) yields
    /// [`PostExecuteAction::RefreshAndRetry`]; everything else is
    /// [`PostExecuteAction::Done`].
    #[must_use]
    pub const fn post_execute_action(
        &self,
        status: u16,
        already_refreshed: bool,
    ) -> PostExecuteAction {
        if status == 401 && !already_refreshed {
            PostExecuteAction::RefreshAndRetry
        } else {
            PostExecuteAction::Done
        }
    }

    /// Replaces the session tokens with those from a refresh `response`
    /// received at `now`.
    ///
    /// Values the token endpoint does not return are carried over from the
    /// current tokens: the refresh token (when the response omits one),
    /// `client_id`, `client_secret` and `proxy_url`.
    pub fn apply_refresh(&mut self, response: &BasicTokenResponse, now: DateTime<Utc>) {
        let mut new_tokens = OAuthTokenData::from_response(response, now);
        if response.refresh_token().is_none() {
            // Keep the existing refresh token rather than the empty
            // placeholder `from_response` falls back to.
            new_tokens.refresh_token = self.tokens.refresh_token.clone();
        }
        new_tokens.client_id.clone_from(&self.tokens.client_id);
        new_tokens
            .client_secret
            .clone_from(&self.tokens.client_secret);
        // `proxy_url` is session configuration just like the client
        // credentials; without this it was silently dropped on every refresh.
        new_tokens.proxy_url.clone_from(&self.tokens.proxy_url);
        // NOTE(review): `scopes` is taken from the response as-is, so it
        // becomes `None` when the response omits `scope` — confirm whether the
        // previous scopes should be retained in that case.
        self.tokens = new_tokens;
    }

    /// Adopts `file_tokens` when they are strictly fresher than the current
    /// tokens.
    ///
    /// The replacement happens only if both token sets have an expiry, the
    /// file expiry is strictly later than the current one, and the file
    /// expiry is strictly after `now`. Returns `true` when the tokens were
    /// replaced.
    pub fn apply_external_tokens(
        &mut self,
        file_tokens: OAuthTokenData,
        now: DateTime<Utc>,
    ) -> bool {
        let (Some(file_expires), Some(current_expires)) =
            (file_tokens.expires_at, self.tokens.expires_at)
        else {
            return false;
        };
        if file_expires > current_expires && file_expires > now {
            self.tokens = file_tokens;
            true
        } else {
            false
        }
    }

    /// Borrows the current access token.
    #[must_use]
    pub const fn access_token(&self) -> &AccessToken {
        &self.tokens.access_token
    }

    /// Borrows the current refresh token.
    #[must_use]
    pub const fn refresh_token(&self) -> &RefreshToken {
        &self.tokens.refresh_token
    }
}
/// Inputs used by [`build_authorize_url`] to construct the authorization URL.
#[derive(Clone, Debug)]
pub struct OAuthLoginConfig {
/// OAuth client id placed in the authorization request.
pub client_id: String,
/// Redirect URI placed in the authorization request. Must parse as a URL;
/// `build_authorize_url` panics otherwise.
pub redirect_uri: String,
/// Scopes added to the authorization request, in order.
pub scopes: Vec<String>,
}
/// Values bundled together for one login attempt.
/// NOTE(review): nothing in this file constructs or consumes this struct;
/// presumably it keeps the state that must survive between building the
/// authorization URL and validating the callback — confirm against callers.
pub struct OAuthLoginSession {
/// PKCE code verifier.
pub pkce_verifier: PkceCodeVerifier,
/// CSRF state token; presumably the value later passed to
/// `validate_callback` as `expected_state` — confirm against callers.
pub csrf_state: CsrfToken,
/// Configuration the login was started with.
pub config: OAuthLoginConfig,
}
/// Reasons [`validate_callback`] can reject an OAuth redirect callback.
#[derive(Debug, thiserror::Error)]
pub enum CallbackValidationError {
/// The callback string could not be parsed as a URL; carries the parser's
/// error text.
#[error("invalid callback URL: {0}")]
InvalidUrl(String),
/// The callback carried an `error` query parameter.
#[error("OAuth error from provider: {error} (description: {description:?})")]
OAuthError {
/// Value of the `error` query parameter.
error: String,
/// Value of the `error_description` query parameter, if present.
description: Option<String>,
},
/// The `state` query parameter did not equal the expected CSRF state.
/// NOTE(review): the rendered message includes the expected state value —
/// confirm that is acceptable wherever this error is logged or displayed.
#[error("CSRF state mismatch: expected {expected}, got {actual}")]
StateMismatch {
/// The CSRF state the caller expected.
expected: String,
/// The `state` value found in the callback.
actual: String,
},
/// The callback had no `state` query parameter.
#[error("callback is missing the 'state' parameter")]
MissingState,
/// The callback had no `code` query parameter.
#[error("callback is missing the 'code' parameter")]
MissingCode,
}
/// Builds the Onshape authorization URL for a PKCE authorization-code login.
///
/// The request uses `config.client_id`, `config.redirect_uri` and every entry
/// of `config.scopes`, `csrf_state` as the state value, and `pkce_challenge`
/// as the PKCE challenge. Returns the URL as a string.
///
/// # Panics
///
/// Panics if `config.redirect_uri` is not a valid URL.
#[must_use]
pub fn build_authorize_url(
    config: &OAuthLoginConfig,
    csrf_state: &CsrfToken,
    pkce_challenge: PkceCodeChallenge,
) -> String {
    let base_client = BasicClient::new(ClientId::new(config.client_id.clone()))
        .set_auth_uri(onshape_auth_url())
        .set_token_uri(onshape_token_url());
    #[allow(clippy::expect_used)]
    let redirect_url = RedirectUrl::new(config.redirect_uri.clone())
        .expect("redirect_uri should be a valid URL");
    let client = base_client.set_redirect_uri(redirect_url);

    // Start from the PKCE-enabled request and append each configured scope.
    let request = config.scopes.iter().fold(
        client
            .authorize_url(|| csrf_state.clone())
            .set_pkce_challenge(pkce_challenge),
        |request, scope| request.add_scope(oauth2::Scope::new(scope.clone())),
    );

    // The returned state is the clone supplied above, so it is not needed.
    let (authorize_url, _state) = request.url();
    authorize_url.to_string()
}
/// Validates an OAuth redirect callback URL and extracts the authorization
/// code.
///
/// Checks run in this order: the URL must parse, there must be no `error`
/// parameter, `state` must be present and equal `expected_state`, and `code`
/// must be present. When a query parameter is repeated, the last occurrence
/// is the one used.
///
/// # Errors
///
/// Returns the [`CallbackValidationError`] variant for the first check that
/// fails.
pub fn validate_callback(
    callback_url: &str,
    expected_state: &CsrfToken,
) -> Result<AuthorizationCode, CallbackValidationError> {
    let parsed = match url::Url::parse(callback_url) {
        Ok(parsed) => parsed,
        Err(e) => return Err(CallbackValidationError::InvalidUrl(e.to_string())),
    };

    let mut query = std::collections::HashMap::<String, String>::new();
    for (key, value) in parsed.query_pairs() {
        query.insert(key.into_owned(), value.into_owned());
    }

    // A provider-reported error takes precedence over every other check.
    if let Some(error) = query.get("error") {
        let description = query.get("error_description").cloned();
        return Err(CallbackValidationError::OAuthError {
            error: error.clone(),
            description,
        });
    }

    let Some(actual) = query.get("state") else {
        return Err(CallbackValidationError::MissingState);
    };
    let expected = expected_state.secret();
    if actual != expected {
        return Err(CallbackValidationError::StateMismatch {
            expected: expected.clone(),
            actual: actual.clone(),
        });
    }

    match query.get("code") {
        Some(code) => Ok(AuthorizationCode::new(code.clone())),
        None => Err(CallbackValidationError::MissingCode),
    }
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::panic)]
mod tests {
use super::*;
#[test]
fn token_data_serializes_to_json() {
let tokens = OAuthTokenData {
access_token: AccessToken::new("access-123".to_string()),
refresh_token: RefreshToken::new("refresh-456".to_string()),
expires_at: None,
token_type: "bearer".into(),
scopes: None,
client_id: None,
client_secret: None,
proxy_url: None,
};
let json = serde_json::to_string(&tokens).expect("should serialize");
let value: serde_json::Value = serde_json::from_str(&json).expect("should be valid JSON");
assert_eq!(value["access_token"], "access-123");
assert_eq!(value["refresh_token"], "refresh-456");
assert_eq!(value["token_type"], "bearer");
assert!(value.get("expires_at").is_none());
assert!(value.get("scopes").is_none());
}
#[test]
fn token_data_deserializes_from_json() {
let json = r#"{
"access_token": "access-789",
"refresh_token": "refresh-012",
"token_type": "bearer"
}"#;
let tokens: OAuthTokenData = serde_json::from_str(json).expect("should deserialize");
assert_eq!(tokens.access_token.secret(), "access-789");
assert_eq!(tokens.refresh_token.secret(), "refresh-012");
assert_eq!(tokens.token_type, "bearer");
assert!(tokens.expires_at.is_none());
assert!(tokens.scopes.is_none());
}
#[test]
fn token_data_roundtrips_with_expiry() {
let expires = DateTime::parse_from_rfc3339("2025-06-15T12:00:00Z")
.expect("should parse")
.to_utc();
let tokens = OAuthTokenData {
access_token: AccessToken::new("at".to_string()),
refresh_token: RefreshToken::new("rt".to_string()),
expires_at: Some(expires),
token_type: "bearer".into(),
scopes: None,
client_id: None,
client_secret: None,
proxy_url: None,
};
let json = serde_json::to_string(&tokens).expect("should serialize");
let roundtripped: OAuthTokenData = serde_json::from_str(&json).expect("should deserialize");
assert_eq!(roundtripped.expires_at, Some(expires));
}
#[test]
fn is_expired_returns_true_when_past() {
let expires = DateTime::parse_from_rfc3339("2024-01-01T00:00:00Z")
.expect("should parse")
.to_utc();
let tokens = OAuthTokenData {
access_token: AccessToken::new("at".to_string()),
refresh_token: RefreshToken::new("rt".to_string()),
expires_at: Some(expires),
token_type: "bearer".into(),
scopes: None,
client_id: None,
client_secret: None,
proxy_url: None,
};
let now = DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z")
.expect("should parse")
.to_utc();
assert!(tokens.is_expired(now));
}
#[test]
fn is_expired_returns_true_when_exactly_at_expiry() {
let expires = DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z")
.expect("should parse")
.to_utc();
let tokens = OAuthTokenData {
access_token: AccessToken::new("at".to_string()),
refresh_token: RefreshToken::new("rt".to_string()),
expires_at: Some(expires),
token_type: "bearer".into(),
scopes: None,
client_id: None,
client_secret: None,
proxy_url: None,
};
assert!(tokens.is_expired(expires));
}
#[test]
fn is_expired_returns_false_when_future() {
let expires = DateTime::parse_from_rfc3339("2030-01-01T00:00:00Z")
.expect("should parse")
.to_utc();
let tokens = OAuthTokenData {
access_token: AccessToken::new("at".to_string()),
refresh_token: RefreshToken::new("rt".to_string()),
expires_at: Some(expires),
token_type: "bearer".into(),
scopes: None,
client_id: None,
client_secret: None,
proxy_url: None,
};
let now = DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z")
.expect("should parse")
.to_utc();
assert!(!tokens.is_expired(now));
}
#[test]
fn is_expired_returns_false_when_no_expiry() {
let tokens = OAuthTokenData {
access_token: AccessToken::new("at".to_string()),
refresh_token: RefreshToken::new("rt".to_string()),
expires_at: None,
token_type: "bearer".into(),
scopes: None,
client_id: None,
client_secret: None,
proxy_url: None,
};
let now = DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z")
.expect("should parse")
.to_utc();
assert!(!tokens.is_expired(now));
}
#[test]
fn default_token_type_is_bearer() {
let json = r#"{
"access_token": "at",
"refresh_token": "rt"
}"#;
let tokens: OAuthTokenData = serde_json::from_str(json).expect("should deserialize");
assert_eq!(tokens.token_type, "bearer");
}
#[test]
fn token_type_bearer_case_insensitive() {
let json = r#"{
"access_token": "at",
"refresh_token": "rt",
"token_type": "Bearer"
}"#;
let tokens: OAuthTokenData = serde_json::from_str(json).expect("should deserialize");
assert_eq!(tokens.token_type, "bearer");
}
#[test]
fn token_type_bearer_all_caps() {
let json = r#"{
"access_token": "at",
"refresh_token": "rt",
"token_type": "BEARER"
}"#;
let tokens: OAuthTokenData = serde_json::from_str(json).expect("should deserialize");
assert_eq!(tokens.token_type, "bearer");
}
#[test]
fn token_type_invalid_rejects() {
let json = r#"{
"access_token": "at",
"refresh_token": "rt",
"token_type": "mac"
}"#;
let result: Result<OAuthTokenData, _> = serde_json::from_str(json);
let err = result.expect_err("should reject non-bearer token type");
let msg = err.to_string();
assert!(
msg.contains("invalid token_type"),
"error should mention invalid token_type: {msg}"
);
}
#[test]
fn scopes_deserialize_when_present() {
let json = r#"{
"access_token": "at",
"refresh_token": "rt",
"token_type": "bearer",
"scopes": ["OAuth2Read", "OAuth2Write"]
}"#;
let tokens: OAuthTokenData = serde_json::from_str(json).expect("should deserialize");
let scopes = tokens.scopes.expect("should have scopes");
assert_eq!(scopes, vec!["OAuth2Read", "OAuth2Write"]);
}
#[test]
fn scopes_default_to_none_when_absent() {
let json = r#"{
"access_token": "at",
"refresh_token": "rt",
"token_type": "bearer"
}"#;
let tokens: OAuthTokenData = serde_json::from_str(json).expect("should deserialize");
assert!(tokens.scopes.is_none());
}
#[test]
fn scopes_serialize_when_present() {
let tokens = OAuthTokenData {
access_token: AccessToken::new("at".to_string()),
refresh_token: RefreshToken::new("rt".to_string()),
expires_at: None,
token_type: "bearer".into(),
scopes: Some(vec!["OAuth2Read".into(), "OAuth2Write".into()]),
client_id: None,
client_secret: None,
proxy_url: None,
};
let json = serde_json::to_string(&tokens).expect("should serialize");
let value: serde_json::Value = serde_json::from_str(&json).expect("should be valid JSON");
let scopes = value["scopes"].as_array().expect("scopes should be array");
assert_eq!(scopes.len(), 2);
assert_eq!(scopes[0], "OAuth2Read");
assert_eq!(scopes[1], "OAuth2Write");
}
#[test]
fn scopes_omitted_from_json_when_none() {
let tokens = OAuthTokenData {
access_token: AccessToken::new("at".to_string()),
refresh_token: RefreshToken::new("rt".to_string()),
expires_at: None,
token_type: "bearer".into(),
scopes: None,
client_id: None,
client_secret: None,
proxy_url: None,
};
let json = serde_json::to_string(&tokens).expect("should serialize");
let value: serde_json::Value = serde_json::from_str(&json).expect("should be valid JSON");
assert!(
value.get("scopes").is_none(),
"scopes should be omitted from JSON when None"
);
}
#[test]
fn default_token_file_path_returns_some() {
let path = default_token_file_path();
if let Some(ref p) = path {
assert!(p.ends_with("onshape-mcp/tokens.json"));
}
}
#[test]
fn onshape_auth_url_is_valid() {
let url = onshape_auth_url();
let url_str = url.url().as_str();
assert!(url_str.starts_with("https://"));
assert!(url_str.contains("oauth.onshape.com"));
}
#[test]
fn onshape_token_url_is_valid() {
let url = onshape_token_url();
let url_str = url.url().as_str();
assert!(url_str.starts_with("https://"));
assert!(url_str.contains("oauth.onshape.com"));
}
#[test]
fn onshape_oauth_client_builds_successfully() {
let _client = onshape_oauth_client("test-client-id", "test-client-secret");
}
#[test]
fn from_response_with_expiry() {
let json = r#"{
"access_token": "test-access-token",
"token_type": "Bearer",
"expires_in": 3600,
"refresh_token": "test-refresh-token"
}"#;
let response: BasicTokenResponse =
serde_json::from_str(json).expect("should deserialize token response");
let now = DateTime::parse_from_rfc3339("2025-06-01T12:00:00Z")
.expect("parse")
.to_utc();
let token_data = OAuthTokenData::from_response(&response, now);
assert_eq!(token_data.access_token.secret(), "test-access-token");
assert_eq!(token_data.refresh_token.secret(), "test-refresh-token");
let expires_at = token_data.expires_at.expect("should have expiry");
assert_eq!(
expires_at,
now + chrono::Duration::seconds(3600),
"expires_at should be exactly now + 3600s"
);
}
#[test]
fn from_response_without_expiry() {
let json = r#"{
"access_token": "test-access-token",
"token_type": "Bearer"
}"#;
let response: BasicTokenResponse =
serde_json::from_str(json).expect("should deserialize token response");
let now = DateTime::parse_from_rfc3339("2025-06-01T12:00:00Z")
.expect("parse")
.to_utc();
let token_data = OAuthTokenData::from_response(&response, now);
assert_eq!(token_data.access_token.secret(), "test-access-token");
assert!(token_data.expires_at.is_none());
assert!(token_data.refresh_token.secret().is_empty());
assert!(token_data.scopes.is_none());
}
#[test]
fn from_response_preserves_scopes() {
let json = r#"{
"access_token": "test-at",
"token_type": "Bearer",
"refresh_token": "test-rt",
"scope": "OAuth2Read OAuth2Write"
}"#;
let response: BasicTokenResponse =
serde_json::from_str(json).expect("should deserialize token response");
let now = DateTime::parse_from_rfc3339("2025-06-01T12:00:00Z")
.expect("parse")
.to_utc();
let token_data = OAuthTokenData::from_response(&response, now);
let scopes = token_data.scopes.expect("should have scopes");
assert_eq!(scopes, vec!["OAuth2Read", "OAuth2Write"]);
}
#[test]
fn token_data_json_shape_backward_compatible() {
let tokens = OAuthTokenData {
access_token: AccessToken::new("my-access".to_string()),
refresh_token: RefreshToken::new("my-refresh".to_string()),
expires_at: None,
token_type: "bearer".into(),
scopes: None,
client_id: None,
client_secret: None,
proxy_url: None,
};
let json = serde_json::to_string_pretty(&tokens).expect("should serialize");
let value: serde_json::Value = serde_json::from_str(&json).expect("should be valid JSON");
assert!(value["access_token"].is_string());
assert!(value["refresh_token"].is_string());
assert!(value["token_type"].is_string());
assert_eq!(value["access_token"], "my-access");
assert_eq!(value["refresh_token"], "my-refresh");
}
#[test]
fn is_expiring_soon_false_when_well_before_margin() {
let tokens = OAuthTokenData {
access_token: AccessToken::new("at".into()),
refresh_token: RefreshToken::new("rt".into()),
expires_at: Some(
DateTime::parse_from_rfc3339("2025-01-01T00:02:00Z")
.expect("parse")
.to_utc(),
),
token_type: "bearer".into(),
scopes: None,
client_id: None,
client_secret: None,
proxy_url: None,
};
let now = DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z")
.expect("parse")
.to_utc();
assert!(!tokens.is_expiring_soon(now, chrono::Duration::seconds(60)));
}
#[test]
fn is_expiring_soon_true_when_within_margin() {
let tokens = OAuthTokenData {
access_token: AccessToken::new("at".into()),
refresh_token: RefreshToken::new("rt".into()),
expires_at: Some(
DateTime::parse_from_rfc3339("2025-01-01T00:00:55Z")
.expect("parse")
.to_utc(),
),
token_type: "bearer".into(),
scopes: None,
client_id: None,
client_secret: None,
proxy_url: None,
};
let now = DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z")
.expect("parse")
.to_utc();
assert!(tokens.is_expiring_soon(now, chrono::Duration::seconds(60)));
}
#[test]
fn is_expiring_soon_true_when_already_expired() {
let tokens = OAuthTokenData {
access_token: AccessToken::new("at".into()),
refresh_token: RefreshToken::new("rt".into()),
expires_at: Some(
DateTime::parse_from_rfc3339("2024-12-31T23:59:00Z")
.expect("parse")
.to_utc(),
),
token_type: "bearer".into(),
scopes: None,
client_id: None,
client_secret: None,
proxy_url: None,
};
let now = DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z")
.expect("parse")
.to_utc();
assert!(tokens.is_expiring_soon(now, chrono::Duration::seconds(60)));
}
#[test]
fn is_expiring_soon_false_when_no_expiry() {
let tokens = OAuthTokenData {
access_token: AccessToken::new("at".into()),
refresh_token: RefreshToken::new("rt".into()),
expires_at: None,
token_type: "bearer".into(),
scopes: None,
client_id: None,
client_secret: None,
proxy_url: None,
};
let now = Utc::now();
assert!(!tokens.is_expiring_soon(now, chrono::Duration::seconds(60)));
}
#[test]
fn is_expiring_soon_true_at_exact_margin_boundary() {
let tokens = OAuthTokenData {
access_token: AccessToken::new("at".into()),
refresh_token: RefreshToken::new("rt".into()),
expires_at: Some(
DateTime::parse_from_rfc3339("2025-01-01T00:01:00Z")
.expect("parse")
.to_utc(),
),
token_type: "bearer".into(),
scopes: None,
client_id: None,
client_secret: None,
proxy_url: None,
};
let now = DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z")
.expect("parse")
.to_utc();
assert!(tokens.is_expiring_soon(now, chrono::Duration::seconds(60)));
}
fn make_session(expires_at: Option<DateTime<Utc>>) -> OAuthSession {
OAuthSession::new(
OAuthTokenData {
access_token: AccessToken::new("at".into()),
refresh_token: RefreshToken::new("rt".into()),
expires_at,
token_type: "bearer".into(),
scopes: None,
client_id: None,
client_secret: None,
proxy_url: None,
},
chrono::Duration::seconds(60),
)
}
#[test]
fn pre_execute_proceed_when_well_before_expiry() {
let session = make_session(Some(
DateTime::parse_from_rfc3339("2025-01-01T00:02:00Z")
.expect("parse")
.to_utc(),
));
let now = DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z")
.expect("parse")
.to_utc();
assert_eq!(session.pre_execute_action(now), PreExecuteAction::Proceed);
}
#[test]
fn pre_execute_refresh_when_within_margin() {
let session = make_session(Some(
DateTime::parse_from_rfc3339("2025-01-01T00:00:55Z")
.expect("parse")
.to_utc(),
));
let now = DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z")
.expect("parse")
.to_utc();
assert_eq!(
session.pre_execute_action(now),
PreExecuteAction::RefreshNeeded
);
}
#[test]
fn pre_execute_refresh_when_already_expired() {
let session = make_session(Some(
DateTime::parse_from_rfc3339("2024-12-31T23:00:00Z")
.expect("parse")
.to_utc(),
));
let now = DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z")
.expect("parse")
.to_utc();
assert_eq!(
session.pre_execute_action(now),
PreExecuteAction::RefreshNeeded
);
}
#[test]
fn pre_execute_proceed_when_no_expiry() {
let session = make_session(None);
let now = Utc::now();
assert_eq!(session.pre_execute_action(now), PreExecuteAction::Proceed);
}
#[test]
fn post_execute_done_on_200() {
let session = make_session(None);
assert_eq!(
session.post_execute_action(200, false),
PostExecuteAction::Done
);
}
#[test]
fn post_execute_refresh_and_retry_on_401_not_refreshed() {
let session = make_session(None);
assert_eq!(
session.post_execute_action(401, false),
PostExecuteAction::RefreshAndRetry
);
}
#[test]
fn post_execute_done_on_401_already_refreshed() {
let session = make_session(None);
assert_eq!(
session.post_execute_action(401, true),
PostExecuteAction::Done
);
}
#[test]
fn post_execute_done_on_403() {
let session = make_session(None);
assert_eq!(
session.post_execute_action(403, false),
PostExecuteAction::Done
);
}
#[test]
fn apply_external_tokens_adopts_fresher_tokens() {
let now = DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z")
.expect("parse")
.to_utc();
let mut session = OAuthSession::new(
OAuthTokenData {
access_token: AccessToken::new("old-at".into()),
refresh_token: RefreshToken::new("old-rt".into()),
expires_at: Some(now + chrono::Duration::seconds(100)),
token_type: "bearer".into(),
scopes: None,
client_id: None,
client_secret: None,
proxy_url: None,
},
chrono::Duration::seconds(60),
);
let file_tokens = OAuthTokenData {
access_token: AccessToken::new("new-at".into()),
refresh_token: RefreshToken::new("new-rt".into()),
expires_at: Some(now + chrono::Duration::seconds(3600)),
token_type: "bearer".into(),
scopes: None,
client_id: None,
client_secret: None,
proxy_url: None,
};
assert!(session.apply_external_tokens(file_tokens, now));
assert_eq!(session.access_token().secret(), "new-at");
assert_eq!(session.refresh_token().secret(), "new-rt");
}
#[test]
fn apply_external_tokens_rejects_same_or_earlier_expiry() {
let now = DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z")
.expect("parse")
.to_utc();
let mut session = OAuthSession::new(
OAuthTokenData {
access_token: AccessToken::new("current-at".into()),
refresh_token: RefreshToken::new("current-rt".into()),
expires_at: Some(now + chrono::Duration::seconds(3600)),
token_type: "bearer".into(),
scopes: None,
client_id: None,
client_secret: None,
proxy_url: None,
},
chrono::Duration::seconds(60),
);
let file_tokens = OAuthTokenData {
access_token: AccessToken::new("file-at".into()),
refresh_token: RefreshToken::new("file-rt".into()),
expires_at: Some(now + chrono::Duration::seconds(3600)),
token_type: "bearer".into(),
scopes: None,
client_id: None,
client_secret: None,
proxy_url: None,
};
assert!(!session.apply_external_tokens(file_tokens, now));
assert_eq!(session.access_token().secret(), "current-at");
}
#[test]
fn apply_external_tokens_rejects_expired_file_tokens() {
let now = DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z")
.expect("parse")
.to_utc();
let mut session = OAuthSession::new(
OAuthTokenData {
access_token: AccessToken::new("current-at".into()),
refresh_token: RefreshToken::new("current-rt".into()),
expires_at: Some(now - chrono::Duration::seconds(100)),
token_type: "bearer".into(),
scopes: None,
client_id: None,
client_secret: None,
proxy_url: None,
},
chrono::Duration::seconds(60),
);
let file_tokens = OAuthTokenData {
access_token: AccessToken::new("file-at".into()),
refresh_token: RefreshToken::new("file-rt".into()),
expires_at: Some(now - chrono::Duration::seconds(50)),
token_type: "bearer".into(),
scopes: None,
client_id: None,
client_secret: None,
proxy_url: None,
};
assert!(!session.apply_external_tokens(file_tokens, now));
assert_eq!(session.access_token().secret(), "current-at");
}
#[test]
fn apply_external_tokens_rejects_when_both_none_expiry() {
let now = Utc::now();
let mut session = OAuthSession::new(
OAuthTokenData {
access_token: AccessToken::new("current-at".into()),
refresh_token: RefreshToken::new("current-rt".into()),
expires_at: None,
token_type: "bearer".into(),
scopes: None,
client_id: None,
client_secret: None,
proxy_url: None,
},
chrono::Duration::seconds(60),
);
let file_tokens = OAuthTokenData {
access_token: AccessToken::new("file-at".into()),
refresh_token: RefreshToken::new("file-rt".into()),
expires_at: None,
token_type: "bearer".into(),
scopes: None,
client_id: None,
client_secret: None,
proxy_url: None,
};
assert!(!session.apply_external_tokens(file_tokens, now));
assert_eq!(session.access_token().secret(), "current-at");
}
#[test]
fn apply_external_tokens_rejects_when_file_has_none_expiry() {
let now = DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z")
.expect("parse")
.to_utc();
let mut session = OAuthSession::new(
OAuthTokenData {
access_token: AccessToken::new("current-at".into()),
refresh_token: RefreshToken::new("current-rt".into()),
expires_at: Some(now + chrono::Duration::seconds(100)),
token_type: "bearer".into(),
scopes: None,
client_id: None,
client_secret: None,
proxy_url: None,
},
chrono::Duration::seconds(60),
);
let file_tokens = OAuthTokenData {
access_token: AccessToken::new("file-at".into()),
refresh_token: RefreshToken::new("file-rt".into()),
expires_at: None,
token_type: "bearer".into(),
scopes: None,
client_id: None,
client_secret: None,
proxy_url: None,
};
assert!(!session.apply_external_tokens(file_tokens, now));
assert_eq!(session.access_token().secret(), "current-at");
}
#[test]
fn apply_refresh_updates_tokens_with_expiry() {
let mut session = make_session(None);
let json = r#"{
"access_token": "new-access-token",
"token_type": "bearer",
"expires_in": 3600,
"refresh_token": "new-refresh-token"
}"#;
let response: BasicTokenResponse = serde_json::from_str(json).expect("should deserialize");
let now = DateTime::parse_from_rfc3339("2025-06-01T12:00:00Z")
.expect("parse")
.to_utc();
session.apply_refresh(&response, now);
assert_eq!(session.access_token().secret(), "new-access-token");
assert_eq!(session.refresh_token().secret(), "new-refresh-token");
let expires_at = session.tokens.expires_at.expect("should have expiry");
assert_eq!(
expires_at,
now + chrono::Duration::seconds(3600),
"expires_at should be exactly now + 3600s"
);
}
#[test]
fn apply_refresh_updates_tokens_without_expiry() {
let now = DateTime::parse_from_rfc3339("2025-06-01T12:00:00Z")
.expect("parse")
.to_utc();
let mut session = make_session(Some(now));
let json = r#"{
"access_token": "new-at",
"token_type": "bearer"
}"#;
let response: BasicTokenResponse = serde_json::from_str(json).expect("should deserialize");
session.apply_refresh(&response, now);
assert_eq!(session.access_token().secret(), "new-at");
assert!(session.tokens.expires_at.is_none());
}
#[test]
fn token_data_roundtrips_with_proxy_url() {
let tokens = OAuthTokenData {
access_token: AccessToken::new("at".to_string()),
refresh_token: RefreshToken::new("rt".to_string()),
expires_at: None,
token_type: "bearer".into(),
scopes: None,
client_id: Some("cid".into()),
client_secret: None,
proxy_url: Some("https://proxy.example.com".into()),
};
let json = serde_json::to_string(&tokens).expect("should serialize");
let roundtripped: OAuthTokenData = serde_json::from_str(&json).expect("should deserialize");
assert_eq!(
roundtripped.proxy_url.as_deref(),
Some("https://proxy.example.com")
);
assert_eq!(roundtripped.client_id.as_deref(), Some("cid"));
assert!(roundtripped.client_secret.is_none());
}
#[test]
fn token_data_backward_compat_without_proxy_url() {
let json = r#"{
"access_token": "at",
"refresh_token": "rt",
"token_type": "bearer",
"client_id": "cid",
"client_secret": "cs"
}"#;
let tokens: OAuthTokenData = serde_json::from_str(json).expect("should deserialize");
assert!(tokens.proxy_url.is_none());
assert_eq!(tokens.client_id.as_deref(), Some("cid"));
assert_eq!(tokens.client_secret.as_deref(), Some("cs"));
}
#[test]
fn token_data_proxy_url_omitted_from_json_when_none() {
let tokens = OAuthTokenData {
access_token: AccessToken::new("at".to_string()),
refresh_token: RefreshToken::new("rt".to_string()),
expires_at: None,
token_type: "bearer".into(),
scopes: None,
client_id: None,
client_secret: None,
proxy_url: None,
};
let json = serde_json::to_string(&tokens).expect("should serialize");
let value: serde_json::Value = serde_json::from_str(&json).expect("should be valid JSON");
assert!(value.get("proxy_url").is_none());
}
#[test]
fn from_raw_creates_token_data() {
let tokens = OAuthTokenData::from_raw(
"access-token".into(),
"refresh-token".into(),
None,
"bearer".into(),
Some(vec!["OAuth2Read".into(), "OAuth2Write".into()]),
);
assert_eq!(tokens.access_token.secret(), "access-token");
assert_eq!(tokens.refresh_token.secret(), "refresh-token");
assert!(tokens.expires_at.is_none());
assert_eq!(tokens.token_type, "bearer");
assert_eq!(
tokens.scopes,
Some(vec!["OAuth2Read".into(), "OAuth2Write".into()])
);
assert!(tokens.client_id.is_none());
assert!(tokens.client_secret.is_none());
assert!(tokens.proxy_url.is_none());
}
/// Builds the login configuration shared by the authorize-URL tests: a fixed
/// client id, a loopback redirect URI, and the read/write scopes.
fn test_login_config() -> OAuthLoginConfig {
    let client_id = "test-client-id".into();
    let redirect_uri = "http://127.0.0.1:18338/callback".into();
    let scopes = vec!["OAuth2Read".into(), "OAuth2Write".into()];
    OAuthLoginConfig {
        client_id,
        redirect_uri,
        scopes,
    }
}
/// The authorize URL must point at the Onshape authorize endpoint and carry
/// the client id, redirect URI, response type, state, PKCE challenge data,
/// and both requested scopes as query parameters.
#[test]
fn build_authorize_url_contains_required_params() {
    let config = test_login_config();
    let state = CsrfToken::new("test-state-token".to_string());
    let (challenge, _verifier) = PkceCodeChallenge::new_random_sha256();

    let url_str = build_authorize_url(&config, &state, challenge);
    let url = url::Url::parse(&url_str).expect("should be a valid URL");

    // Endpoint location.
    assert_eq!(url.scheme(), "https");
    assert_eq!(url.host_str(), Some("oauth.onshape.com"));
    assert_eq!(url.path(), "/oauth/authorize");

    // Query parameters, gathered into an owned map for lookup by name.
    let params = url
        .query_pairs()
        .map(|(key, value)| (key.into_owned(), value.into_owned()))
        .collect::<std::collections::HashMap<String, String>>();
    let param = |name: &str| params.get(name).map(String::as_str);

    assert_eq!(param("client_id"), Some("test-client-id"));
    assert_eq!(
        param("redirect_uri"),
        Some("http://127.0.0.1:18338/callback")
    );
    assert_eq!(param("response_type"), Some("code"));
    assert_eq!(param("state"), Some("test-state-token"));
    // The challenge value is random, so only its presence is checked.
    assert!(params.contains_key("code_challenge"));
    assert_eq!(param("code_challenge_method"), Some("S256"));

    let scope = params.get("scope").expect("should have scope parameter");
    assert!(scope.contains("OAuth2Read"));
    assert!(scope.contains("OAuth2Write"));
}
/// With an empty scope list, the authorize URL must not contain a `scope`
/// query parameter at all.
#[test]
fn build_authorize_url_with_no_scopes() {
    let state = CsrfToken::new("state".to_string());
    let (challenge, _verifier) = PkceCodeChallenge::new_random_sha256();
    let config = OAuthLoginConfig {
        client_id: "cid".into(),
        redirect_uri: "http://127.0.0.1:18338/callback".into(),
        scopes: vec![],
    };

    let url_str = build_authorize_url(&config, &state, challenge);
    let url = url::Url::parse(&url_str).expect("should be a valid URL");

    // No query pair may be keyed `scope`.
    let has_scope = url.query_pairs().any(|(key, _)| key == "scope");
    assert!(!has_scope);
}
/// A callback URL whose `state` matches the expected token yields the
/// authorization code from its `code` parameter.
#[test]
fn validate_callback_extracts_code() {
    let callback = "http://127.0.0.1:18338/callback?code=auth-code-123&state=my-state";
    let expected_state = CsrfToken::new("my-state".to_string());

    let outcome = validate_callback(callback, &expected_state);
    let code = outcome.expect("should validate");

    assert_eq!(code.secret(), "auth-code-123");
}
/// A callback whose `state` differs from the expected token is rejected with
/// `CallbackValidationError::StateMismatch`.
#[test]
fn validate_callback_detects_state_mismatch() {
    let callback = "http://127.0.0.1:18338/callback?code=abc&state=wrong-state";
    let expected_state = CsrfToken::new("expected-state".to_string());

    let outcome = validate_callback(callback, &expected_state);
    let err = outcome.expect_err("should fail");

    let is_state_mismatch = matches!(err, CallbackValidationError::StateMismatch { .. });
    assert!(is_state_mismatch, "expected StateMismatch, got: {err:?}");
}
/// A callback carrying `error` / `error_description` parameters is reported
/// as `CallbackValidationError::OAuthError`, with the description
/// form-decoded (`+` becomes a space).
#[test]
fn validate_callback_detects_oauth_error() {
    let callback = "http://127.0.0.1:18338/callback?error=access_denied&error_description=User+denied+access&state=my-state";
    let expected_state = CsrfToken::new("my-state".to_string());

    let outcome = validate_callback(callback, &expected_state);
    let failure = outcome.expect_err("should fail");

    match failure {
        CallbackValidationError::OAuthError { error, description } => {
            assert_eq!(error, "access_denied");
            assert_eq!(description.as_deref(), Some("User denied access"));
        }
        other => panic!("expected OAuthError, got: {other:?}"),
    }
}
/// A callback without any `state` parameter is rejected with
/// `CallbackValidationError::MissingState`.
#[test]
fn validate_callback_detects_missing_state() {
    let callback = "http://127.0.0.1:18338/callback?code=abc";
    let expected_state = CsrfToken::new("my-state".to_string());

    let outcome = validate_callback(callback, &expected_state);
    let err = outcome.expect_err("should fail");

    let is_missing_state = matches!(err, CallbackValidationError::MissingState);
    assert!(is_missing_state, "expected MissingState, got: {err:?}");
}
/// A callback with a matching `state` but no `code` parameter is rejected
/// with `CallbackValidationError::MissingCode`.
#[test]
fn validate_callback_detects_missing_code() {
    let callback = "http://127.0.0.1:18338/callback?state=my-state";
    let expected_state = CsrfToken::new("my-state".to_string());

    let outcome = validate_callback(callback, &expected_state);
    let err = outcome.expect_err("should fail");

    let is_missing_code = matches!(err, CallbackValidationError::MissingCode);
    assert!(is_missing_code, "expected MissingCode, got: {err:?}");
}
/// Input that cannot be parsed as a URL is rejected with
/// `CallbackValidationError::InvalidUrl`.
#[test]
fn validate_callback_detects_invalid_url() {
    let malformed = "not a url at all ://";
    let expected_state = CsrfToken::new("my-state".to_string());

    let outcome = validate_callback(malformed, &expected_state);
    let err = outcome.expect_err("should fail");

    let is_invalid_url = matches!(err, CallbackValidationError::InvalidUrl(_));
    assert!(is_invalid_url, "expected InvalidUrl, got: {err:?}");
}
}