use crate::auth::oauth::error::OAuthError;
use crate::auth::oauth::hmac::{constant_time_compare, validate_hmac};
use crate::auth::oauth::AuthQuery;
use crate::auth::session::AccessTokenResponse;
use crate::auth::Session;
use crate::config::{ShopDomain, ShopifyConfig};
/// JSON body POSTed to a shop's `/admin/oauth/access_token` endpoint to trade
/// an OAuth authorization code for an access token.
///
/// Every field borrows from the caller, so building a request copies no strings.
#[derive(serde::Serialize)]
struct TokenExchangeRequest<'req> {
    /// The app's API key (taken from `ShopifyConfig::api_key`).
    client_id: &'req str,
    /// The app's API secret key (taken from `ShopifyConfig::api_secret_key`).
    client_secret: &'req str,
    /// The authorization code delivered on the OAuth callback query.
    code: &'req str,
}
pub async fn validate_auth_callback(
config: &ShopifyConfig,
auth_query: &AuthQuery,
expected_state: &str,
) -> Result<Session, OAuthError> {
if !validate_hmac(auth_query, config) {
return Err(OAuthError::InvalidHmac);
}
if !constant_time_compare(&auth_query.state, expected_state) {
return Err(OAuthError::StateMismatch {
expected: expected_state.to_string(),
received: auth_query.state.clone(),
});
}
let shop = ShopDomain::new(&auth_query.shop).map_err(|_| OAuthError::InvalidCallback {
reason: format!("Invalid shop domain: {}", auth_query.shop),
})?;
let token_url = format!("https://{}/admin/oauth/access_token", shop.as_ref());
let request_body = TokenExchangeRequest {
client_id: config.api_key().as_ref(),
client_secret: config.api_secret_key().as_ref(),
code: &auth_query.code,
};
let client = reqwest::Client::new();
let response = client
.post(&token_url)
.json(&request_body)
.send()
.await
.map_err(|e| OAuthError::TokenExchangeFailed {
status: 0,
message: format!("Network error: {e}"),
})?;
let status = response.status().as_u16();
if !response.status().is_success() {
let error_body = response.text().await.unwrap_or_default();
return Err(OAuthError::TokenExchangeFailed {
status,
message: error_body,
});
}
let token_response: AccessTokenResponse =
response
.json()
.await
.map_err(|e| OAuthError::TokenExchangeFailed {
status,
message: format!("Failed to parse token response: {e}"),
})?;
let session = Session::from_access_token_response(shop, &token_response);
Ok(session)
}
/// Tests for `validate_auth_callback`.
///
/// NOTE(review): the token URL is always built as
/// `https://{shop}/admin/oauth/access_token`. Every test that gets past the
/// HMAC, state and shop checks therefore makes a real outbound HTTPS request
/// to `test-shop.myshopify.com`.
///
/// Those tests expect `TokenExchangeFailed`. They get it either from a
/// connection failure or, presumably, from Shopify rejecting the fake
/// credentials. Either way they depend on external network conditions.
/// TODO: make the token endpoint injectable so these can run hermetically.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::oauth::hmac::compute_signature;
    use crate::config::{ApiKey, ApiSecretKey, HostUrl};
    use wiremock::matchers::{body_json, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds a config with API key `test-api-key` and API secret `test-secret`.
    fn create_test_config() -> ShopifyConfig {
        ShopifyConfig::builder()
            .api_key(ApiKey::new("test-api-key").unwrap())
            .api_secret_key(ApiSecretKey::new("test-secret").unwrap())
            .host(HostUrl::new("https://myapp.example.com").unwrap())
            .build()
            .unwrap()
    }

    /// Builds a callback query signed with `secret`.
    ///
    /// The query has code `auth-code-123`, shop `test-shop.myshopify.com` and
    /// state `test-state`. The last constructor argument (apparently the HMAC)
    /// starts empty. It is filled in after signing the string from
    /// `to_signable_string`.
    fn create_valid_auth_query(secret: &str) -> AuthQuery {
        let mut query = AuthQuery::new(
            "auth-code-123".to_string(),
            "test-shop.myshopify.com".to_string(),
            "1700000000".to_string(),
            "test-state".to_string(),
            "dGVzdC1ob3N0".to_string(),
            String::new(),
        );
        let signable = query.to_signable_string();
        query.hmac = compute_signature(&signable, secret);
        query
    }

    /// A query carrying the literal signature `invalid-hmac` is rejected with
    /// `InvalidHmac`; the HMAC is the first thing `validate_auth_callback` checks.
    #[tokio::test]
    async fn test_validate_auth_callback_validates_hmac() {
        let config = create_test_config();
        let query = AuthQuery::new(
            "code".to_string(),
            "shop.myshopify.com".to_string(),
            "12345".to_string(),
            "state".to_string(),
            "host".to_string(),
            "invalid-hmac".to_string(),
        );
        let result = validate_auth_callback(&config, &query, "state").await;
        assert!(matches!(result, Err(OAuthError::InvalidHmac)));
    }

    /// A correctly signed query whose state differs from the expected one yields
    /// `StateMismatch` carrying both the expected and the received values.
    #[tokio::test]
    async fn test_validate_auth_callback_rejects_state_mismatch() {
        let config = create_test_config();
        let query = create_valid_auth_query("test-secret");
        let result = validate_auth_callback(&config, &query, "wrong-state").await;
        match result {
            Err(OAuthError::StateMismatch { expected, received }) => {
                assert_eq!(expected, "wrong-state");
                assert_eq!(received, "test-state");
            }
            _ => panic!("Expected StateMismatch error"),
        }
    }

    /// A correctly signed query whose shop value contains spaces fails with
    /// `InvalidCallback`.
    /// NOTE(review): this relies on `ShopDomain::new` rejecting that value,
    /// which is not visible here.
    #[tokio::test]
    async fn test_validate_auth_callback_with_invalid_shop() {
        let config = create_test_config();
        let mut query = AuthQuery::new(
            "code".to_string(),
            "invalid shop domain".to_string(), "12345".to_string(),
            "test-state".to_string(),
            "host".to_string(),
            String::new(),
        );
        let signable = query.to_signable_string();
        query.hmac = compute_signature(&signable, "test-secret");
        let result = validate_auth_callback(&config, &query, "test-state").await;
        match result {
            Err(OAuthError::InvalidCallback { reason }) => {
                assert!(reason.contains("Invalid shop domain"));
            }
            _ => panic!("Expected InvalidCallback error"),
        }
    }

    /// NOTE(review): despite its name, this test asserts `TokenExchangeFailed`,
    /// not a returned session. The mock server below never receives a request.
    /// `validate_auth_callback` posts to `https://test-shop.myshopify.com/...`,
    /// not to `mock_server.uri()`. The mock also sets no `.expect(..)`, so it
    /// being unused goes unnoticed.
    /// TODO: point the token exchange at the mock and assert on the `Session`.
    #[tokio::test]
    async fn test_validate_auth_callback_returns_session_on_success() {
        let mock_server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/admin/oauth/access_token"))
            .and(body_json(serde_json::json!({
                "client_id": "test-api-key",
                "client_secret": "test-secret",
                "code": "auth-code-123"
            })))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "new-access-token",
                "scope": "read_products,write_orders"
            })))
            .mount(&mock_server)
            .await;
        let config = create_test_config();
        let query = create_valid_auth_query("test-secret");
        // This goes to the live shop host, not the mock; see the note above.
        let result = validate_auth_callback(&config, &query, "test-state").await;
        assert!(matches!(
            result,
            Err(OAuthError::TokenExchangeFailed { .. })
        ));
    }

    /// A valid query passes the HMAC, state and shop checks, and the token
    /// exchange itself is expected to fail (see the module note about the live
    /// request).
    #[tokio::test]
    async fn test_validate_auth_callback_handles_token_exchange_error() {
        let config = create_test_config();
        let query = create_valid_auth_query("test-secret");
        let result = validate_auth_callback(&config, &query, "test-state").await;
        assert!(matches!(
            result,
            Err(OAuthError::TokenExchangeFailed { .. })
        ));
    }

    /// Sanity check that `constant_time_compare` reports equal strings as equal
    /// and different strings as different.
    #[test]
    fn test_constant_time_compare_in_state_validation() {
        assert!(constant_time_compare("state123", "state123"));
        assert!(!constant_time_compare("state123", "state124"));
    }

    /// The query is signed with `old-secret`, while the config's current secret
    /// is `new-secret`. Reaching `TokenExchangeFailed` rather than `InvalidHmac`
    /// shows the HMAC was accepted.
    /// NOTE(review): this presumably relies on `validate_hmac` falling back to
    /// `old_api_secret_key`, which is not visible here.
    #[tokio::test]
    async fn test_validate_hmac_with_old_key_fallback() {
        let config = ShopifyConfig::builder()
            .api_key(ApiKey::new("test-key").unwrap())
            .api_secret_key(ApiSecretKey::new("new-secret").unwrap())
            .old_api_secret_key(ApiSecretKey::new("old-secret").unwrap())
            .host(HostUrl::new("https://app.example.com").unwrap())
            .build()
            .unwrap();
        let query = create_valid_auth_query("old-secret");
        let result = validate_auth_callback(&config, &query, "test-state").await;
        assert!(matches!(
            result,
            Err(OAuthError::TokenExchangeFailed { .. })
        ));
    }

    /// Only asserts that a valid query with the matching state is not rejected
    /// for HMAC or state reasons. Any other outcome passes, including a failed
    /// token exchange.
    #[tokio::test]
    async fn test_validate_auth_callback_with_correct_state() {
        let config = create_test_config();
        let query = create_valid_auth_query("test-secret");
        let result = validate_auth_callback(&config, &query, "test-state").await;
        match &result {
            Err(OAuthError::StateMismatch { .. }) => {
                panic!("Should not fail on state mismatch with correct state")
            }
            Err(OAuthError::InvalidHmac) => {
                panic!("Should not fail on HMAC with valid HMAC")
            }
            _ => {} }
    }
}