use turbomcp_protocol::{Error as McpError, Result as McpResult};
/// Validates that `uri` is a canonical `http`/`https` resource URI.
///
/// The URI must:
/// - contain a `scheme://` prefix whose scheme is written in lowercase,
/// - parse as an absolute URL,
/// - use the `http` or `https` scheme,
/// - carry no fragment,
/// - include a host that is written in lowercase.
///
/// Case is checked against the text as written rather than the parsed URL.
/// [`url::Url`] lowercases the scheme and host while parsing, so the parsed
/// form can no longer reveal a non-canonical input.
///
/// # Errors
///
/// Returns an `invalid_params` error describing the first rule that is violated.
pub fn validate_canonical_resource_uri(uri: &str) -> McpResult<()> {
    use url::Url;

    const NON_CANONICAL: &str = "Resource URI must use canonical form (lowercase scheme and host)";

    let scheme_end = uri.find("://").ok_or_else(|| {
        McpError::invalid_params("Resource URI must have a valid scheme".to_string())
    })?;
    let scheme = &uri[..scheme_end];
    if scheme != scheme.to_lowercase() {
        return Err(McpError::invalid_params(NON_CANONICAL.to_string()));
    }

    let parsed = Url::parse(uri)
        .map_err(|e| McpError::invalid_params(format!("Invalid resource URI: {e}")))?;
    if parsed.scheme() != "https" && parsed.scheme() != "http" {
        return Err(McpError::invalid_params(
            "Resource URI must use http or https scheme".to_string(),
        ));
    }
    if parsed.fragment().is_some() {
        return Err(McpError::invalid_params(
            "Resource URI must not contain fragment".to_string(),
        ));
    }
    if parsed.host_str().is_none() {
        return Err(McpError::invalid_params(
            "Resource URI must include host".to_string(),
        ));
    }

    // `Url::parse` implements the WHATWG URL Standard. That standard first trims
    // leading/trailing C0 controls and spaces, then deletes every ASCII tab or
    // newline. Apply the same preprocessing so the host inspected below is the
    // one the parser actually saw. Otherwise e.g. "https:///\t/EXAMPLE.COM"
    // would slip through.
    let preprocessed: String = uri
        .trim_matches(|c: char| c <= ' ')
        .chars()
        .filter(|c| !matches!(c, '\t' | '\n' | '\r'))
        .collect();
    let original_host = host_as_written(&preprocessed, parsed.scheme().len());
    if original_host != original_host.to_lowercase() {
        return Err(McpError::invalid_params(NON_CANONICAL.to_string()));
    }
    Ok(())
}

/// Returns the host of a preprocessed `http`/`https` URI exactly as written
/// (no case folding), for a URI whose scheme is `scheme_len` bytes long.
///
/// This mirrors how the WHATWG parser locates the host for special schemes:
/// - skip `scheme:` and any run of `/` or `\`,
/// - the authority ends at `/`, `\`, `?` or `#`,
/// - userinfo ends at the last `@`,
/// - a `:` ends the host unless it is inside an IPv6 `[...]` literal.
fn host_as_written(preprocessed: &str, scheme_len: usize) -> &str {
    let rest = preprocessed
        .get(scheme_len + 1..)
        .unwrap_or("")
        .trim_start_matches(['/', '\\']);
    let authority_end = rest.find(['/', '\\', '?', '#']).unwrap_or(rest.len());
    let authority = &rest[..authority_end];
    let host_and_port = authority
        .rsplit_once('@')
        .map_or(authority, |(_, host_port)| host_port);
    if host_and_port.starts_with('[') {
        // Keep the whole bracketed IPv6 literal; its colons are not port separators.
        host_and_port
            .find(']')
            .map_or(host_and_port, |end| &host_and_port[..=end])
    } else {
        host_and_port
            .split_once(':')
            .map_or(host_and_port, |(host, _)| host)
    }
}
/// Checks that the OAuth `state` returned by the authorization server matches
/// the value this client issued, guarding against CSRF.
///
/// The byte comparison uses `subtle::ConstantTimeEq`, so equal-length inputs
/// are compared without an early exit on the first differing byte.
///
/// # Errors
///
/// - `invalid_params` if either value is empty.
/// - `invalid_params` if the two values differ (comparison is case-sensitive).
pub fn validate_oauth_state(expected_state: &str, received_state: &str) -> McpResult<()> {
    use subtle::ConstantTimeEq;

    let any_empty = expected_state.is_empty() || received_state.is_empty();
    if any_empty {
        return Err(McpError::invalid_params(
            "OAuth state parameter cannot be empty".to_string(),
        ));
    }

    let states_match: bool = expected_state
        .as_bytes()
        .ct_eq(received_state.as_bytes())
        .into();

    if !states_match {
        return Err(McpError::invalid_params(
            "OAuth state parameter mismatch - possible CSRF attack".to_string(),
        ));
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    //! Unit tests for resource-URI canonicalisation and OAuth state checks.
    use super::*;

    #[test]
    fn test_valid_https_uri() {
        assert!(validate_canonical_resource_uri("https://example.com/resource").is_ok());
    }

    #[test]
    fn test_valid_http_uri() {
        assert!(validate_canonical_resource_uri("http://example.com/resource").is_ok());
    }

    #[test]
    fn test_valid_uri_with_port_query_and_mixed_case_path() {
        // Only the scheme and host must be lowercase; path and query may use any case.
        assert!(
            validate_canonical_resource_uri("https://example.com:8443/Mcp/Resource?Id=ABC").is_ok()
        );
    }

    #[test]
    fn test_non_canonical_uppercase_host() {
        let result = validate_canonical_resource_uri("https://Example.COM/resource");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("canonical form"));
    }

    #[test]
    fn test_non_canonical_uppercase_host_with_port() {
        let result = validate_canonical_resource_uri("https://EXAMPLE.com:8443/mcp");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("canonical form"));
    }

    #[test]
    fn test_non_canonical_uppercase_scheme() {
        let result = validate_canonical_resource_uri("HTTPS://example.com/resource");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("canonical form"));
    }

    #[test]
    fn test_non_canonical_mixed_case_scheme() {
        let result = validate_canonical_resource_uri("Https://example.com/resource");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("canonical form"));
    }

    #[test]
    fn test_missing_scheme() {
        let result = validate_canonical_resource_uri("example.com/resource");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("valid scheme"));
    }

    #[test]
    fn test_non_http_scheme_without_host() {
        // `file:` URIs have no host; the scheme check rejects them first.
        let result = validate_canonical_resource_uri("file:///etc/passwd");
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("http or https scheme")
        );
    }

    #[test]
    fn test_empty_host_rejected() {
        // http/https require a host, so the URL parser itself refuses this input.
        let result = validate_canonical_resource_uri("https://");
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Invalid resource URI")
        );
    }

    #[test]
    fn test_fragment_not_allowed() {
        let result = validate_canonical_resource_uri("https://example.com/resource#fragment");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("fragment"));
    }

    #[test]
    fn test_invalid_scheme() {
        let result = validate_canonical_resource_uri("ftp://example.com/resource");
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("http or https scheme")
        );
    }

    #[test]
    fn test_oauth_state_validation_success() {
        let state = "random-csrf-token-123";
        assert!(validate_oauth_state(state, state).is_ok());
    }

    #[test]
    fn test_oauth_state_validation_mismatch() {
        let expected = "state-abc123";
        let received = "state-xyz789";
        let result = validate_oauth_state(expected, received);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("state parameter mismatch")
        );
    }

    #[test]
    fn test_oauth_state_validation_prefix_mismatch() {
        // A received value that merely starts with the expected one must not match.
        let result = validate_oauth_state("state-abc", "state-abc-extra");
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("state parameter mismatch")
        );
    }

    #[test]
    fn test_oauth_state_validation_empty_expected() {
        let result = validate_oauth_state("", "some-state");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("cannot be empty"));
    }

    #[test]
    fn test_oauth_state_validation_empty_received() {
        let result = validate_oauth_state("some-state", "");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("cannot be empty"));
    }

    #[test]
    fn test_oauth_state_validation_case_sensitive() {
        let result = validate_oauth_state("State123", "state123");
        assert!(result.is_err());
    }
}