use crate::{
mcp_http::{error_response, types::GenericBody, McpAppState, Middleware, MiddlewareNext},
mcp_server::error::TransportServerResult,
schema::schema_utils::SdkError,
};
use async_trait::async_trait;
use http::{
header::{HOST, ORIGIN},
Request, Response, StatusCode,
};
use std::sync::Arc;
/// Middleware that guards against DNS rebinding by validating the `Host` and
/// `Origin` request headers against configured allow-lists.
///
/// A request that fails validation is answered with `403 Forbidden` and is
/// not forwarded to the next handler.
pub(crate) struct DnsRebindProtector {
/// Accepted `Host` header values, compared ASCII case-insensitively.
/// `None` or an empty list disables the `Host` check.
pub allowed_hosts: Option<Vec<String>>,
/// Accepted `Origin` header values, compared ASCII case-insensitively.
/// `None` or an empty list disables the `Origin` check.
pub allowed_origins: Option<Vec<String>>,
}
#[async_trait]
impl Middleware for DnsRebindProtector {
    /// Runs the DNS-rebinding header checks before handing the request on.
    ///
    /// When the checks pass, the request and state are forwarded to `next`
    /// unchanged and its result is returned. When they fail, the chain is
    /// short-circuited with a `403 Forbidden` response built from the
    /// validation error.
    async fn handle<'req>(
        &self,
        req: Request<&'req str>,
        state: Arc<McpAppState>,
        next: MiddlewareNext<'req>,
    ) -> TransportServerResult<Response<GenericBody>> {
        match self.protect_dns_rebinding(req.headers()).await {
            Ok(()) => next(req, state).await,
            Err(rejection) => error_response(StatusCode::FORBIDDEN, rejection),
        }
    }
}
impl DnsRebindProtector {
    /// Creates a protector from optional allow-lists.
    ///
    /// Passing `None` (or an empty list) for either argument disables the
    /// corresponding header check.
    pub fn new(allowed_hosts: Option<Vec<String>>, allowed_origins: Option<Vec<String>>) -> Self {
        Self {
            allowed_hosts,
            allowed_origins,
        }
    }

    /// Validates the `Host` and then the `Origin` header of a request.
    ///
    /// # Errors
    ///
    /// Returns a bad-request [`SdkError`] describing the first header that is
    /// missing, not valid visible ASCII, or absent from its allow-list.
    async fn protect_dns_rebinding(&self, headers: &http::HeaderMap) -> Result<(), SdkError> {
        // `Host` is checked first so its error takes precedence when both headers are bad.
        Self::ensure_header_allowed(headers, HOST, "Host", self.allowed_hosts.as_deref())?;
        Self::ensure_header_allowed(headers, ORIGIN, "Origin", self.allowed_origins.as_deref())
    }

    /// Checks one header against its allow-list.
    ///
    /// `label` is the human-readable header name used in the error message.
    /// A missing or empty `allow_list` means the check is disabled and the
    /// header is accepted without being inspected.
    fn ensure_header_allowed(
        headers: &http::HeaderMap,
        name: http::header::HeaderName,
        label: &str,
        allow_list: Option<&[String]>,
    ) -> Result<(), SdkError> {
        let Some(allow_list) = allow_list.filter(|entries| !entries.is_empty()) else {
            return Ok(());
        };

        // A header that is absent or cannot be read as a string is rejected
        // outright: there is nothing to compare against the allow-list.
        let Some(value) = headers.get(name).and_then(|raw| raw.to_str().ok()) else {
            let message = format!("Invalid {label} header: [unknown] ");
            return Err(SdkError::bad_request().with_message(message.as_str()));
        };

        if allow_list
            .iter()
            .any(|allowed| allowed.eq_ignore_ascii_case(value))
        {
            return Ok(());
        }

        let message = format!("Invalid {label} header: \"{value}\" ");
        Err(SdkError::bad_request().with_message(message.as_str()))
    }
}