use axum::{
Router,
body::Body,
http::{Request, StatusCode},
middleware::{self, Next},
response::Response,
};
use ipnetwork::IpNetwork;
use std::{net::IpAddr, str::FromStr, sync::Arc};
/// Request header through which an MCP client announces its protocol revision.
const PROTOCOL_HEADER: &str = "MCP-Protocol-Version";

/// Protocol revisions the version guard accepts, listed newest first.
const SUPPORTED_PROTOCOL_VERSIONS: &[&str] = &[
    "2025-11-25",
    "2025-06-18",
    "2025-03-26",
];

/// Revision assumed when a request carries no `MCP-Protocol-Version` header.
///
/// It is only reported in a debug log; such requests are forwarded unchanged.
/// The MCP Streamable HTTP transport specification says servers SHOULD assume
/// `2025-03-26` when the header is absent.
const DEFAULT_PROTOCOL_VERSION: &str = "2025-03-26";
/// Policy deciding which `Origin` header values are acceptable.
///
/// Origins are matched by host only. The host must be either an IP literal
/// inside one of `allowed_networks`, or the name `localhost` while a loopback
/// address (`127.0.0.1` or `::1`) is inside one of them. Cloning is cheap
/// because the network list is shared behind an [`Arc`].
#[derive(Clone)]
pub struct OriginConfig {
    /// Networks whose addresses are acceptable origin hosts. The constructors
    /// below never produce an empty list.
    allowed_networks: Arc<Vec<IpNetwork>>,
    /// Whether requests carrying no `Origin` header are let through.
    allow_missing_origin: bool,
}

impl OriginConfig {
    /// Builds a config accepting origins whose host lies in any of
    /// `allowed_cidrs` (e.g. `"192.168.1.0/24"`, `"::1/128"`).
    ///
    /// Requests without an `Origin` header are allowed by default; see
    /// [`OriginConfig::allow_missing_origin`].
    ///
    /// # Errors
    ///
    /// Returns a message naming the first entry that is not valid CIDR
    /// notation, or an error if `allowed_cidrs` is empty.
    pub fn new(allowed_cidrs: Vec<String>) -> Result<Self, String> {
        let networks = allowed_cidrs
            .iter()
            .map(|cidr| {
                IpNetwork::from_str(cidr)
                    .map_err(|e| format!("Invalid CIDR notation '{}': {}", cidr, e))
            })
            .collect::<Result<Vec<_>, _>>()?;
        if networks.is_empty() {
            return Err("At least one allowed network must be specified".to_string());
        }
        Ok(Self {
            allowed_networks: Arc::new(networks),
            allow_missing_origin: true,
        })
    }

    /// Builds a config accepting only loopback origins: `127.0.0.0/8`, `::1`,
    /// and the host name `localhost`. Requests without an `Origin` header are
    /// allowed.
    pub fn localhost_only() -> Self {
        Self {
            allowed_networks: Arc::new(vec![
                IpNetwork::from_str("127.0.0.0/8").expect("hard-coded IPv4 loopback CIDR is valid"),
                IpNetwork::from_str("::1/128").expect("hard-coded IPv6 loopback CIDR is valid"),
            ]),
            allow_missing_origin: true,
        }
    }

    /// Sets whether requests without an `Origin` header are let through.
    /// Builder style: consumes and returns `self`.
    // Public builder option; the crate itself may not call it.
    #[allow(dead_code)]
    pub fn allow_missing_origin(mut self, allow: bool) -> Self {
        self.allow_missing_origin = allow;
        self
    }

    /// Returns `true` if the host of `origin` is permitted by this config.
    ///
    /// IP-literal hosts, including bracketed IPv6 such as `[::1]`, are checked
    /// against the allowed networks. `localhost` is accepted when `127.0.0.1`
    /// or `::1` is allowed. Every other host name is rejected; no name
    /// resolution is performed.
    pub fn is_allowed_origin(&self, origin: &str) -> bool {
        let Some(host) = Self::extract_host_from_origin(origin) else {
            return false;
        };
        if let Ok(ip) = IpAddr::from_str(host) {
            return self.is_ip_allowed(&ip);
        }
        host == "localhost"
            && (self.is_ip_allowed(&IpAddr::V4(std::net::Ipv4Addr::LOCALHOST))
                || self.is_ip_allowed(&IpAddr::V6(std::net::Ipv6Addr::LOCALHOST)))
    }

    /// Returns `true` if `ip` lies inside any allowed network.
    fn is_ip_allowed(&self, ip: &IpAddr) -> bool {
        self.allowed_networks.iter().any(|network| network.contains(*ip))
    }

    /// Extracts the host from an origin of the form `scheme://host[:port]`.
    ///
    /// Only `https://` and `http://` prefixes are stripped; other inputs are
    /// scanned as-is. A bracketed IPv6 literal is returned without its
    /// brackets, and `None` is returned when its closing `]` is missing.
    fn extract_host_from_origin(origin: &str) -> Option<&str> {
        let authority = origin
            .strip_prefix("https://")
            .or_else(|| origin.strip_prefix("http://"))
            .unwrap_or(origin);
        // An IPv6 literal must open the authority. Brackets elsewhere (e.g. in
        // a path) are not part of the host. `split_once` only looks for `]`
        // after the `[`, so malformed input such as "]x[" cannot panic.
        if let Some(bracketed) = authority.strip_prefix('[') {
            return bracketed.split_once(']').map(|(host, _)| host);
        }
        // The host ends at whichever delimiter (port or path) comes first.
        let host_end = authority
            .find(|c: char| c == ':' || c == '/')
            .unwrap_or(authority.len());
        Some(&authority[..host_end])
    }
}
/// Reports whether `version` exactly matches one of the supported MCP
/// protocol revisions. No trimming or case folding is applied.
// Within this file it is only called from the unit tests, hence the allow.
#[allow(dead_code)]
pub fn is_valid_protocol_version(version: &str) -> bool {
    SUPPORTED_PROTOCOL_VERSIONS
        .iter()
        .any(|&supported| supported == version)
}
/// Wraps `router` with [`origin_guard`] and [`protocol_version_guard`].
///
/// In axum the layer added last is the outermost. Every request is therefore
/// screened by `origin_guard` (using `origin_config`) first, then by
/// `protocol_version_guard`, before it reaches the router's handlers.
pub fn with_guards(router: Router, origin_config: OriginConfig) -> Router {
    let version_layer = middleware::from_fn(protocol_version_guard);
    let origin_layer = middleware::from_fn(move |req, next| {
        // Each request gets its own handle; cloning only bumps the `Arc`
        // around the network list and copies a bool.
        origin_guard(req, next, origin_config.clone())
    });
    router.layer(version_layer).layer(origin_layer)
}
/// Middleware that validates the `MCP-Protocol-Version` request header.
///
/// - Header absent: forwarded; a debug log notes the assumed default revision.
/// - Header equal to a supported revision: forwarded.
/// - Any other value: `400 Bad Request` naming the supported revisions.
/// - Value not readable as visible ASCII: `400 Bad Request`.
pub async fn protocol_version_guard(req: Request<Body>, next: Next) -> Response {
    let rejection: Option<Body> =
        match req.headers().get(PROTOCOL_HEADER).map(|value| value.to_str()) {
            None => {
                tracing::debug!(
                    "No MCP-Protocol-Version header; assuming {}",
                    DEFAULT_PROTOCOL_VERSION
                );
                None
            },
            Some(Ok(version)) if SUPPORTED_PROTOCOL_VERSIONS.contains(&version) => {
                tracing::debug!("Accepted MCP-Protocol-Version: {}", version);
                None
            },
            Some(Ok(version)) => {
                tracing::warn!("Unsupported MCP-Protocol-Version: {}", version);
                Some(
                    format!(
                        "unsupported MCP-Protocol-Version: {}. Supported versions: {:?}",
                        version, SUPPORTED_PROTOCOL_VERSIONS
                    )
                    .into(),
                )
            },
            Some(Err(_)) => {
                tracing::warn!("Invalid MCP-Protocol-Version header encoding");
                Some("invalid MCP-Protocol-Version header encoding".into())
            },
        };

    match rejection {
        Some(body) => Response::builder()
            .status(StatusCode::BAD_REQUEST)
            .body(body)
            .unwrap(),
        None => next.run(req).await,
    }
}
/// Middleware that enforces `config` on the request's `Origin` header.
///
/// - Origin accepted by [`OriginConfig::is_allowed_origin`]: forwarded.
/// - Origin rejected by it, or not readable as visible ASCII: `403 Forbidden`.
/// - No `Origin` header: forwarded when `config` allows missing origins,
///   otherwise `403 Forbidden`.
///
/// Every rejection is logged at `warn` level.
pub async fn origin_guard(req: Request<Body>, next: Next, config: OriginConfig) -> Response {
    let rejection = match req.headers().get("Origin").map(|value| value.to_str()) {
        Some(Ok(origin)) if config.is_allowed_origin(origin) => {
            tracing::debug!("Accepted Origin header: {}", origin);
            None
        },
        Some(Ok(origin)) => {
            tracing::warn!("Rejected request due to non-allowed Origin header: {}", origin);
            Some("origin not allowed - origin IP not in allowed networks")
        },
        Some(Err(_)) => {
            tracing::warn!("Rejected request due to invalid Origin header encoding");
            Some("invalid Origin header encoding")
        },
        None if config.allow_missing_origin => {
            tracing::trace!("No Origin header present (non-browser client)");
            None
        },
        None => {
            tracing::warn!("Rejected request due to missing Origin header");
            Some("missing Origin header")
        },
    };

    match rejection {
        Some(message) => Response::builder()
            .status(StatusCode::FORBIDDEN)
            .body(message.into())
            .unwrap(),
        None => next.run(req).await,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_accepts_supported_protocol_version() {
        for version in ["2025-11-25", "2025-06-18", "2025-03-26"] {
            assert!(
                is_valid_protocol_version(version),
                "Version {} should be valid",
                version
            );
        }
    }

    #[test]
    fn test_rejects_unsupported_protocol_version() {
        for version in ["2020-01-01", "1.0", "invalid", "", "2024-01-01", "2025-01-01"] {
            assert!(
                !is_valid_protocol_version(version),
                "Version {} should be invalid",
                version
            );
        }
    }

    #[test]
    fn test_localhost_config() {
        let config = OriginConfig::localhost_only();
        assert!(config.is_allowed_origin("http://localhost"));
        assert!(config.is_allowed_origin("https://localhost:8080"));
        assert!(config.is_allowed_origin("http://127.0.0.1"));
        assert!(config.is_allowed_origin("https://127.0.0.1:3000"));
        assert!(config.is_allowed_origin("http://[::1]"));
        assert!(config.is_allowed_origin("https://[::1]:8080"));
        assert!(!config.is_allowed_origin("http://192.168.1.1"));
        assert!(!config.is_allowed_origin("https://example.com"));
    }

    #[test]
    fn test_rejects_non_ip_and_lookalike_hosts() {
        let config = OriginConfig::localhost_only();
        assert!(!config.is_allowed_origin("null"));
        assert!(!config.is_allowed_origin(""));
        assert!(!config.is_allowed_origin("http://localhost.evil.com"));
        assert!(!config.is_allowed_origin("http://127.0.0.1.evil.com"));
    }

    #[test]
    fn test_custom_networks() {
        let config =
            OriginConfig::new(vec!["192.168.1.0/24".to_string(), "10.0.0.0/8".to_string()])
                .unwrap();
        assert!(config.is_allowed_origin("http://192.168.1.1"));
        assert!(config.is_allowed_origin("https://192.168.1.254"));
        assert!(config.is_allowed_origin("http://10.0.0.1"));
        assert!(config.is_allowed_origin("http://10.255.255.255"));
        assert!(!config.is_allowed_origin("http://192.168.2.1"));
        assert!(!config.is_allowed_origin("http://127.0.0.1"));
        assert!(!config.is_allowed_origin("http://localhost"));
    }

    #[test]
    fn test_new_rejects_invalid_input() {
        assert!(OriginConfig::new(Vec::new()).is_err());
        assert!(OriginConfig::new(vec!["not-a-cidr".to_string()]).is_err());
    }
}