1use std::net::SocketAddr;
2use std::sync::Arc;
3
4use axum::extract::Request;
5use axum::http::{HeaderName, Method, StatusCode};
6use axum::middleware::{self, Next};
7use axum::response::IntoResponse;
8use axum::routing::get;
9use axum::Router;
10use rmcp::transport::streamable_http_server::{
11 session::local::LocalSessionManager, StreamableHttpServerConfig, StreamableHttpService,
12};
13use tower_http::cors::CorsLayer;
14
15use crate::auth;
16use crate::client::OriginClient;
17use crate::tools::{OriginMcpServer, TransportMode};
18
/// Settings for the `origin-mcp` streamable-HTTP server.
///
/// `Debug` is implemented by hand so that the bearer `token` is never printed
/// in clear text when the config is logged with `{:?}`.
#[derive(Clone)]
pub struct ServeConfig {
    /// TCP port to listen on.
    pub port: u16,
    /// Host/IP to bind the listener to.
    pub host: String,
    /// Base URL handed to `OriginClient::new`.
    pub origin_url: String,
    /// Expected bearer token; when `None`, authentication is disabled.
    pub token: Option<String>,
    /// Agent name passed to the client and to each MCP server instance.
    pub agent_name: String,
    /// Optional user id passed to each MCP server instance.
    pub user_id: Option<String>,
    /// Origins permitted for CORS / `Origin` checks; `"*"` allows any origin in CORS.
    pub allowed_origins: Vec<String>,
}

impl std::fmt::Debug for ServeConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ServeConfig")
            .field("port", &self.port)
            .field("host", &self.host)
            .field("origin_url", &self.origin_url)
            // Only reveal whether a token is configured, never its value.
            .field("token", &self.token.as_ref().map(|_| "<redacted>"))
            .field("agent_name", &self.agent_name)
            .field("user_id", &self.user_id)
            .field("allowed_origins", &self.allowed_origins)
            .finish()
    }
}
29
30async fn health() -> impl IntoResponse {
31 axum::Json(serde_json::json!({
32 "status": "ok",
33 "server": "origin-mcp",
34 "version": env!("CARGO_PKG_VERSION"),
35 }))
36}
37
38pub async fn run_serve(config: ServeConfig) -> anyhow::Result<()> {
39 let client =
40 OriginClient::new(config.origin_url.clone()).with_agent_name(config.agent_name.clone());
41 let agent_name = config.agent_name.clone();
42 let user_id = config.user_id.clone();
43 let token = config.token.clone();
44 let allowed_origins = config.allowed_origins.clone();
45
46 let mcp_config = StreamableHttpServerConfig::default().disable_allowed_hosts();
60
61 let mcp_service = StreamableHttpService::new(
62 move || {
63 Ok(OriginMcpServer::new(
64 client.clone(),
65 TransportMode::Http,
66 agent_name.clone(),
67 user_id.clone(),
68 ))
69 },
70 Arc::new(LocalSessionManager::default()),
71 mcp_config,
72 );
73
74 let cors = build_cors_layer(&config.allowed_origins);
75
76 let mut router = Router::new()
77 .nest_service("/mcp", mcp_service)
78 .route("/health", get(health))
79 .layer(cors);
80
81 if let Some(ref expected_token) = token {
82 let token_for_middleware = expected_token.clone();
83 let origins_for_middleware = allowed_origins.clone();
84 router = router.layer(middleware::from_fn(move |req: Request, next: Next| {
85 let token = token_for_middleware.clone();
86 let origins = origins_for_middleware.clone();
87 async move { auth_and_origin_middleware(req, next, &token, &origins).await }
88 }));
89 }
90
91 let addr: SocketAddr = format!("{}:{}", config.host, config.port).parse()?;
92 let listener = tokio::net::TcpListener::bind(addr).await?;
93 tracing::info!("origin-mcp HTTP server listening on {}", addr);
94
95 if token.is_some() {
96 tracing::info!("Bearer token authentication enabled");
97 } else {
98 tracing::warn!("Running without authentication — only safe on loopback");
99 }
100
101 let shutdown = async {
102 #[cfg(unix)]
103 {
104 use tokio::signal::unix::{signal, SignalKind};
105 let ctrl_c = tokio::signal::ctrl_c();
106 let mut sigterm =
107 signal(SignalKind::terminate()).expect("failed to register SIGTERM handler");
108 tokio::select! {
109 _ = ctrl_c => {},
110 _ = sigterm.recv() => {},
111 }
112 }
113 #[cfg(not(unix))]
114 {
115 tokio::signal::ctrl_c().await.ok();
116 }
117 tracing::info!("Shutting down origin-mcp HTTP server");
118 };
119
120 axum::serve(listener, router)
121 .with_graceful_shutdown(shutdown)
122 .await?;
123
124 Ok(())
125}
126
127fn build_cors_layer(allowed_origins: &[String]) -> CorsLayer {
128 let cors = CorsLayer::new()
129 .allow_methods([Method::GET, Method::POST, Method::DELETE, Method::OPTIONS])
130 .allow_headers([
131 http::header::AUTHORIZATION,
132 http::header::CONTENT_TYPE,
133 http::header::ACCEPT,
134 HeaderName::from_static("mcp-session-id"),
135 HeaderName::from_static("mcp-protocol-version"),
136 ]);
137
138 if allowed_origins.iter().any(|o| o == "*") {
139 cors.allow_origin(tower_http::cors::Any)
140 } else {
141 let origins: Vec<http::HeaderValue> = allowed_origins
142 .iter()
143 .filter_map(|o| o.parse().ok())
144 .collect();
145 cors.allow_origin(origins)
146 }
147}
148
149async fn auth_and_origin_middleware(
151 req: Request,
152 next: Next,
153 expected_token: &str,
154 allowed_origins: &[String],
155) -> axum::response::Response {
156 let is_preflight = req.method() == Method::OPTIONS;
157 let is_health = req.uri().path() == "/health";
158 if is_preflight || is_health {
159 return next.run(req).await;
160 }
161
162 let auth_header = req.headers().get(http::header::AUTHORIZATION);
164 match auth_header {
165 Some(value) => {
166 let value_str = value.to_str().unwrap_or("");
167 match auth::extract_bearer_token(value_str) {
168 Some(provided) if auth::verify_token(provided, expected_token) => {}
169 _ => return (StatusCode::UNAUTHORIZED, "Invalid bearer token").into_response(),
170 }
171 }
172 None => return (StatusCode::UNAUTHORIZED, "Authorization header required").into_response(),
173 }
174
175 if let Some(origin) = req.headers().get(http::header::ORIGIN) {
177 if let Ok(origin_str) = origin.to_str() {
178 if !auth::is_origin_allowed(origin_str, allowed_origins) {
179 return (StatusCode::FORBIDDEN, "Origin not allowed").into_response();
180 }
181 }
182 }
183
184 next.run(req).await
185}