//! HTTP (Streamable HTTP) server mode for origin-mcp — `origin_mcp/serve.rs`.

use std::net::SocketAddr;
use std::sync::Arc;

use anyhow::Context as _;
use axum::extract::Request;
use axum::http::{HeaderName, Method, StatusCode};
use axum::middleware::{self, Next};
use axum::response::IntoResponse;
use axum::routing::get;
use axum::Router;
use rmcp::transport::streamable_http_server::{
    session::local::LocalSessionManager, StreamableHttpServerConfig, StreamableHttpService,
};
use tower_http::cors::CorsLayer;

use crate::auth;
use crate::client::OriginClient;
use crate::tools::{OriginMcpServer, TransportMode};

/// Settings consumed by [`run_serve`] to start the HTTP transport.
#[derive(Clone, Debug)]
pub struct ServeConfig {
    /// TCP port the HTTP listener binds to.
    pub port: u16,
    /// Host the HTTP listener binds to (combined with `port`).
    pub host: String,
    /// URL handed to `OriginClient::new`.
    pub origin_url: String,
    /// Expected bearer token; `None` disables bearer-token authentication.
    pub token: Option<String>,
    /// Agent name passed to every `OriginMcpServer` created for the HTTP transport.
    pub agent_name: String,
    /// Optional user id passed to every `OriginMcpServer` created for the HTTP transport.
    pub user_id: Option<String>,
    /// Origins accepted by the CORS layer and the `Origin` header check.
    /// An entry of `"*"` makes the CORS layer accept any origin.
    pub allowed_origins: Vec<String>,
}

30async fn health() -> impl IntoResponse {
31    axum::Json(serde_json::json!({
32        "status": "ok",
33        "server": "origin-mcp",
34        "version": env!("CARGO_PKG_VERSION"),
35    }))
36}
37
38pub async fn run_serve(config: ServeConfig) -> anyhow::Result<()> {
39    let client = OriginClient::new(config.origin_url.clone());
40    let agent_name = config.agent_name.clone();
41    let user_id = config.user_id.clone();
42    let token = config.token.clone();
43    let allowed_origins = config.allowed_origins.clone();
44
45    let mcp_service = StreamableHttpService::new(
46        move || {
47            Ok(OriginMcpServer::new(
48                client.clone(),
49                TransportMode::Http,
50                agent_name.clone(),
51                user_id.clone(),
52            ))
53        },
54        Arc::new(LocalSessionManager::default()),
55        StreamableHttpServerConfig::default(),
56    );
57
58    let cors = build_cors_layer(&config.allowed_origins);
59
60    let mut router = Router::new()
61        .nest_service("/mcp", mcp_service)
62        .route("/health", get(health))
63        .layer(cors);
64
65    if let Some(ref expected_token) = token {
66        let token_for_middleware = expected_token.clone();
67        let origins_for_middleware = allowed_origins.clone();
68        router = router.layer(middleware::from_fn(move |req: Request, next: Next| {
69            let token = token_for_middleware.clone();
70            let origins = origins_for_middleware.clone();
71            async move { auth_and_origin_middleware(req, next, &token, &origins).await }
72        }));
73    }
74
75    let addr: SocketAddr = format!("{}:{}", config.host, config.port).parse()?;
76    let listener = tokio::net::TcpListener::bind(addr).await?;
77    tracing::info!("origin-mcp HTTP server listening on {}", addr);
78
79    if token.is_some() {
80        tracing::info!("Bearer token authentication enabled");
81    } else {
82        tracing::warn!("Running without authentication — only safe on loopback");
83    }
84
85    let shutdown = async {
86        #[cfg(unix)]
87        {
88            use tokio::signal::unix::{signal, SignalKind};
89            let ctrl_c = tokio::signal::ctrl_c();
90            let mut sigterm =
91                signal(SignalKind::terminate()).expect("failed to register SIGTERM handler");
92            tokio::select! {
93                _ = ctrl_c => {},
94                _ = sigterm.recv() => {},
95            }
96        }
97        #[cfg(not(unix))]
98        {
99            tokio::signal::ctrl_c().await.ok();
100        }
101        tracing::info!("Shutting down origin-mcp HTTP server");
102    };
103
104    axum::serve(listener, router)
105        .with_graceful_shutdown(shutdown)
106        .await?;
107
108    Ok(())
109}
110
111fn build_cors_layer(allowed_origins: &[String]) -> CorsLayer {
112    let cors = CorsLayer::new()
113        .allow_methods([Method::GET, Method::POST, Method::DELETE, Method::OPTIONS])
114        .allow_headers([
115            http::header::AUTHORIZATION,
116            http::header::CONTENT_TYPE,
117            http::header::ACCEPT,
118            HeaderName::from_static("mcp-session-id"),
119            HeaderName::from_static("mcp-protocol-version"),
120        ]);
121
122    if allowed_origins.iter().any(|o| o == "*") {
123        cors.allow_origin(tower_http::cors::Any)
124    } else {
125        let origins: Vec<http::HeaderValue> = allowed_origins
126            .iter()
127            .filter_map(|o| o.parse().ok())
128            .collect();
129        cors.allow_origin(origins)
130    }
131}
132
133/// Auth middleware: bearer token first (401), then Origin header (403).
134async fn auth_and_origin_middleware(
135    req: Request,
136    next: Next,
137    expected_token: &str,
138    allowed_origins: &[String],
139) -> axum::response::Response {
140    let is_preflight = req.method() == Method::OPTIONS;
141    let is_health = req.uri().path() == "/health";
142    if is_preflight || is_health {
143        return next.run(req).await;
144    }
145
146    // 1. Validate bearer token FIRST
147    let auth_header = req.headers().get(http::header::AUTHORIZATION);
148    match auth_header {
149        Some(value) => {
150            let value_str = value.to_str().unwrap_or("");
151            match auth::extract_bearer_token(value_str) {
152                Some(provided) if auth::verify_token(provided, expected_token) => {}
153                _ => return (StatusCode::UNAUTHORIZED, "Invalid bearer token").into_response(),
154            }
155        }
156        None => return (StatusCode::UNAUTHORIZED, "Authorization header required").into_response(),
157    }
158
159    // 2. Validate Origin header AFTER auth
160    if let Some(origin) = req.headers().get(http::header::ORIGIN) {
161        if let Ok(origin_str) = origin.to_str() {
162            if !auth::is_origin_allowed(origin_str, allowed_origins) {
163                return (StatusCode::FORBIDDEN, "Origin not allowed").into_response();
164            }
165        }
166    }
167
168    next.run(req).await
169}