use std::sync::Arc;
use axum::{
extract::Request,
http::{HeaderName, HeaderValue, Method, StatusCode, header},
middleware::Next,
response::Response,
};
use tower_http::cors::{AllowOrigin, CorsLayer};
/// `mcp-session-id` header. [`build_cors_layer`] both allows it on cross-origin requests and
/// exposes it on responses so browser scripts can read it.
const MCP_SESSION_ID_HEADER: HeaderName = HeaderName::from_static("mcp-session-id");
/// `last-event-id` request header, allowed on cross-origin requests by [`build_cors_layer`].
// NOTE(review): presumably sent by clients resuming an event stream — confirm against the
// transport handler that reads it.
const LAST_EVENT_ID_HEADER: HeaderName = HeaderName::from_static("last-event-id");
#[derive(Debug, Clone)]
pub struct OriginAllowlist {
inner: Arc<Vec<HeaderValue>>,
}
impl OriginAllowlist {
pub fn new<I, S>(origins: I) -> Self
where
I: IntoIterator<Item = S>,
S: AsRef<str>,
{
let inner = origins
.into_iter()
.filter_map(|o| HeaderValue::from_str(o.as_ref()).ok())
.collect::<Vec<_>>();
Self {
inner: Arc::new(inner),
}
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.inner.is_empty()
}
pub fn header_values(&self) -> impl Iterator<Item = &HeaderValue> {
self.inner.iter()
}
fn matches(&self, candidate: &HeaderValue) -> bool {
self.inner.iter().any(|allowed| allowed == candidate)
}
}
/// Axum middleware that refuses requests whose `Origin` header is not on the allowlist.
///
/// Requests that carry no `Origin` header at all are forwarded unchanged. When the header is
/// present, the request is forwarded only if the value is not the opaque origin `null` and
/// matches an entry of `allowlist`. Otherwise the value is logged at `warn` level and the
/// request is refused.
///
/// An empty allowlist therefore refuses every request that carries an `Origin` header.
///
/// # Errors
///
/// Returns [`StatusCode::FORBIDDEN`] when the `Origin` header is present but not permitted.
pub async fn origin_guard(
    axum::extract::State(allowlist): axum::extract::State<OriginAllowlist>,
    request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    if let Some(origin) = request.headers().get(header::ORIGIN) {
        // The opaque origin is refused outright, regardless of the allowlist contents.
        let is_opaque = origin.as_bytes() == b"null";
        let permitted = !is_opaque && allowlist.matches(origin);
        if !permitted {
            tracing::warn!(
                origin = %String::from_utf8_lossy(origin.as_bytes()),
                "Rejecting request with disallowed Origin header"
            );
            return Err(StatusCode::FORBIDDEN);
        }
    }
    Ok(next.run(request).await)
}
#[must_use]
pub fn build_cors_layer(allowlist: &OriginAllowlist) -> Option<CorsLayer> {
if allowlist.is_empty() {
return None;
}
let origins = allowlist.header_values().cloned().collect::<Vec<_>>();
Some(
CorsLayer::new()
.allow_origin(AllowOrigin::list(origins))
.allow_methods([Method::GET, Method::POST, Method::OPTIONS])
.allow_headers([
header::CONTENT_TYPE,
header::AUTHORIZATION,
header::ACCEPT,
MCP_SESSION_ID_HEADER,
LAST_EVENT_ID_HEADER,
])
.expose_headers([MCP_SESSION_ID_HEADER]),
)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an allowlist from string literals.
    fn allowlist_of(origins: &[&str]) -> OriginAllowlist {
        OriginAllowlist::new(origins.iter().copied())
    }

    #[test]
    fn empty_allowlist_is_empty() {
        let empty = allowlist_of(&[]);
        assert!(empty.is_empty());
        assert!(build_cors_layer(&empty).is_none());
    }

    #[test]
    fn allowlist_matches_exact_origin() {
        let allowlist = allowlist_of(&["https://app.example.com", "http://localhost:8080"]);
        assert!(!allowlist.is_empty());

        // (candidate Origin, expected match result)
        let cases = [
            ("https://app.example.com", true),
            ("https://evil.example.com", false),
        ];
        for (origin, expected) in cases {
            let candidate = HeaderValue::from_static(origin);
            assert_eq!(allowlist.matches(&candidate), expected, "origin {origin}");
        }
    }

    #[test]
    fn allowlist_does_not_match_null_origin() {
        let allowlist = allowlist_of(&["https://app.example.com"]);
        let opaque = HeaderValue::from_static("null");
        assert!(!allowlist.matches(&opaque));
    }

    #[test]
    fn cors_layer_built_when_allowlist_non_empty() {
        let layer = build_cors_layer(&allowlist_of(&["https://app.example.com"]));
        assert!(layer.is_some(), "a non-empty allowlist should yield a CORS layer");
    }
}