1use std::net::SocketAddr;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use anyhow::{anyhow, Context, Result};
6use axum::{
7 extract::Json,
8 extract::Query,
9 extract::State,
10 http::{header, Request, StatusCode},
11 middleware::{self, Next},
12 response::{IntoResponse, Response},
13 routing::get,
14 Router,
15};
16use rmcp::transport::{StreamableHttpServerConfig, StreamableHttpService};
17use serde::Deserialize;
18use serde_json::Value;
19use tokio::time::{Duration, Instant};
20
21use crate::engine::ContextEngine;
22use crate::tools::LeanCtxServer;
23
/// Settings for the HTTP server started by [`serve`].
#[derive(Debug, Clone)]
pub struct HttpServerConfig {
    /// Host to bind to; combined with `port` to form the listen address.
    pub host: String,
    /// TCP port to bind to.
    pub port: u16,
    /// Project root handed to the underlying `LeanCtxServer`.
    pub project_root: PathBuf,
    /// Bearer token clients must present; `None` or an empty string disables auth.
    pub auth_token: Option<String>,
    /// Forwarded to the MCP transport's stateful-mode setting.
    pub stateful_mode: bool,
    /// Forwarded to the MCP transport's JSON-response setting.
    pub json_response: bool,
    /// When `true`, the MCP transport's allowed-hosts check is disabled.
    pub disable_host_check: bool,
    /// Explicit allowed-hosts list for the MCP transport; used when non-empty.
    pub allowed_hosts: Vec<String>,
    /// Default request-body size limit, in bytes.
    pub max_body_bytes: usize,
    /// Number of requests allowed in flight at once (values below 1 are raised to 1).
    pub max_concurrency: usize,
    /// Token-bucket refill rate, in requests per second.
    pub max_rps: u32,
    /// Token-bucket capacity.
    pub rate_burst: u32,
    /// Time budget for a `/v1/tools/call` invocation, in milliseconds.
    pub request_timeout_ms: u64,
}
40
41impl Default for HttpServerConfig {
42 fn default() -> Self {
43 let project_root = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
44 Self {
45 host: "127.0.0.1".to_string(),
46 port: 8080,
47 project_root,
48 auth_token: None,
49 stateful_mode: false,
50 json_response: true,
51 disable_host_check: false,
52 allowed_hosts: Vec::new(),
53 max_body_bytes: 2 * 1024 * 1024,
54 max_concurrency: 32,
55 max_rps: 50,
56 rate_burst: 100,
57 request_timeout_ms: 30_000,
58 }
59 }
60}
61
62impl HttpServerConfig {
63 pub fn validate(&self) -> Result<()> {
64 let host = self.host.trim().to_lowercase();
65 let is_loopback = host == "127.0.0.1" || host == "localhost" || host == "::1";
66 if !is_loopback && self.auth_token.as_deref().unwrap_or("").is_empty() {
67 return Err(anyhow!(
68 "Refusing to bind to host='{host}' without auth. Provide --auth-token (or bind to 127.0.0.1)."
69 ));
70 }
71 Ok(())
72 }
73
74 fn mcp_http_config(&self) -> StreamableHttpServerConfig {
75 let mut cfg = StreamableHttpServerConfig::default()
76 .with_stateful_mode(self.stateful_mode)
77 .with_json_response(self.json_response);
78
79 if self.disable_host_check {
80 cfg = cfg.disable_allowed_hosts();
81 return cfg;
82 }
83
84 if !self.allowed_hosts.is_empty() {
85 cfg = cfg.with_allowed_hosts(self.allowed_hosts.clone());
86 return cfg;
87 }
88
89 let host = self.host.trim();
91 if host == "127.0.0.1" || host == "localhost" || host == "::1" {
92 cfg.allowed_hosts.push(host.to_string());
93 }
94
95 cfg
96 }
97}
98
99#[derive(Clone)]
100struct AppState {
101 token: Option<String>,
102 concurrency: Arc<tokio::sync::Semaphore>,
103 rate: Arc<RateLimiter>,
104 engine: Arc<ContextEngine>,
105 timeout: Duration,
106}
107
108#[derive(Debug)]
109struct RateLimiter {
110 max_rps: f64,
111 burst: f64,
112 state: tokio::sync::Mutex<RateState>,
113}
114
115#[derive(Debug, Clone, Copy)]
116struct RateState {
117 tokens: f64,
118 last: Instant,
119}
120
121impl RateLimiter {
122 fn new(max_rps: u32, burst: u32) -> Self {
123 let now = Instant::now();
124 Self {
125 max_rps: (max_rps.max(1)) as f64,
126 burst: (burst.max(1)) as f64,
127 state: tokio::sync::Mutex::new(RateState {
128 tokens: (burst.max(1)) as f64,
129 last: now,
130 }),
131 }
132 }
133
134 async fn allow(&self) -> bool {
135 let mut s = self.state.lock().await;
136 let now = Instant::now();
137 let elapsed = now.saturating_duration_since(s.last);
138 let refill = elapsed.as_secs_f64() * self.max_rps;
139 s.tokens = (s.tokens + refill).min(self.burst);
140 s.last = now;
141 if s.tokens >= 1.0 {
142 s.tokens -= 1.0;
143 true
144 } else {
145 false
146 }
147 }
148}
149
150async fn auth_middleware(
151 State(state): State<AppState>,
152 req: Request<axum::body::Body>,
153 next: Next,
154) -> Response {
155 if state.token.is_none() {
156 return next.run(req).await;
157 }
158
159 if req.uri().path() == "/health" {
160 return next.run(req).await;
161 }
162
163 let expected = state.token.as_deref().unwrap_or("");
164 let Some(h) = req.headers().get(header::AUTHORIZATION) else {
165 return StatusCode::UNAUTHORIZED.into_response();
166 };
167 let Ok(s) = h.to_str() else {
168 return StatusCode::UNAUTHORIZED.into_response();
169 };
170 let Some(token) = s
171 .strip_prefix("Bearer ")
172 .or_else(|| s.strip_prefix("bearer "))
173 else {
174 return StatusCode::UNAUTHORIZED.into_response();
175 };
176 if !constant_time_eq(token.as_bytes(), expected.as_bytes()) {
177 return StatusCode::UNAUTHORIZED.into_response();
178 }
179
180 next.run(req).await
181}
182
/// Compares two byte slices without stopping at the first differing byte.
///
/// Slices of different lengths are reported unequal immediately; for equal
/// lengths every byte pair is XOR-ed into an accumulator and the result is
/// `true` only when the accumulator stays zero.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }

    let mut diff = 0u8;
    for (lhs, rhs) in a.iter().zip(b) {
        diff |= lhs ^ rhs;
    }
    diff == 0
}
192
193async fn rate_limit_middleware(
194 State(state): State<AppState>,
195 req: Request<axum::body::Body>,
196 next: Next,
197) -> Response {
198 if req.uri().path() == "/health" {
199 return next.run(req).await;
200 }
201 if !state.rate.allow().await {
202 return StatusCode::TOO_MANY_REQUESTS.into_response();
203 }
204 next.run(req).await
205}
206
207async fn concurrency_middleware(
208 State(state): State<AppState>,
209 req: Request<axum::body::Body>,
210 next: Next,
211) -> Response {
212 if req.uri().path() == "/health" {
213 return next.run(req).await;
214 }
215 let Ok(permit) = state.concurrency.clone().try_acquire_owned() else {
216 return StatusCode::TOO_MANY_REQUESTS.into_response();
217 };
218 let resp = next.run(req).await;
219 drop(permit);
220 resp
221}
222
223async fn health() -> impl IntoResponse {
224 (StatusCode::OK, "ok\n")
225}
226
227#[derive(Debug, Deserialize)]
228#[serde(rename_all = "camelCase")]
229struct ToolCallBody {
230 name: String,
231 #[serde(default)]
232 arguments: Option<Value>,
233}
234
235async fn v1_manifest(State(state): State<AppState>) -> impl IntoResponse {
236 let v = state.engine.manifest();
237 (StatusCode::OK, Json(v))
238}
239
240#[derive(Debug, Deserialize)]
241#[serde(rename_all = "camelCase")]
242struct ToolsQuery {
243 #[serde(default)]
244 offset: Option<usize>,
245 #[serde(default)]
246 limit: Option<usize>,
247}
248
249async fn v1_tools(State(state): State<AppState>, Query(q): Query<ToolsQuery>) -> impl IntoResponse {
250 let v = state.engine.manifest();
251 let tools = v
252 .get("tools")
253 .and_then(|t| t.get("granular"))
254 .cloned()
255 .unwrap_or(Value::Array(vec![]));
256
257 let all = tools.as_array().cloned().unwrap_or_default();
258 let total = all.len();
259 let offset = q.offset.unwrap_or(0).min(total);
260 let limit = q.limit.unwrap_or(200).min(500);
261 let page = all.into_iter().skip(offset).take(limit).collect::<Vec<_>>();
262
263 (
264 StatusCode::OK,
265 Json(serde_json::json!({
266 "tools": page,
267 "total": total,
268 "offset": offset,
269 "limit": limit,
270 })),
271 )
272}
273
274async fn v1_tool_call(
275 State(state): State<AppState>,
276 Json(body): Json<ToolCallBody>,
277) -> impl IntoResponse {
278 match tokio::time::timeout(
279 state.timeout,
280 state.engine.call_tool_value(&body.name, body.arguments),
281 )
282 .await
283 {
284 Ok(Ok(v)) => (StatusCode::OK, Json(serde_json::json!({ "result": v }))).into_response(),
285 Ok(Err(e)) => (
286 StatusCode::BAD_REQUEST,
287 Json(serde_json::json!({ "error": e.to_string() })),
288 )
289 .into_response(),
290 Err(_) => (
291 StatusCode::GATEWAY_TIMEOUT,
292 Json(serde_json::json!({ "error": "request_timeout" })),
293 )
294 .into_response(),
295 }
296}
297
298pub async fn serve(cfg: HttpServerConfig) -> Result<()> {
299 cfg.validate()?;
300
301 let addr: SocketAddr = format!("{}:{}", cfg.host, cfg.port)
302 .parse()
303 .context("invalid host/port")?;
304
305 let project_root = cfg.project_root.to_string_lossy().to_string();
306 let base = LeanCtxServer::new_with_project_root(Some(project_root));
307 let engine = Arc::new(ContextEngine::from_server(base.clone()));
308
309 let service_factory = move || Ok(base.clone());
310 let mcp_http = StreamableHttpService::new(
311 service_factory,
312 Arc::new(
313 rmcp::transport::streamable_http_server::session::local::LocalSessionManager::default(),
314 ),
315 cfg.mcp_http_config(),
316 );
317
318 let state = AppState {
319 token: cfg.auth_token.clone().filter(|t| !t.is_empty()),
320 concurrency: Arc::new(tokio::sync::Semaphore::new(cfg.max_concurrency.max(1))),
321 rate: Arc::new(RateLimiter::new(cfg.max_rps, cfg.rate_burst)),
322 engine,
323 timeout: Duration::from_millis(cfg.request_timeout_ms.max(1)),
324 };
325
326 let app = Router::new()
327 .route("/health", get(health))
328 .route("/v1/manifest", get(v1_manifest))
329 .route("/v1/tools", get(v1_tools))
330 .route("/v1/tools/call", axum::routing::post(v1_tool_call))
331 .fallback_service(mcp_http)
332 .layer(axum::extract::DefaultBodyLimit::max(cfg.max_body_bytes))
333 .layer(middleware::from_fn_with_state(
334 state.clone(),
335 rate_limit_middleware,
336 ))
337 .layer(middleware::from_fn_with_state(
338 state.clone(),
339 concurrency_middleware,
340 ))
341 .layer(middleware::from_fn_with_state(
342 state.clone(),
343 auth_middleware,
344 ))
345 .with_state(state);
346
347 let listener = tokio::net::TcpListener::bind(addr)
348 .await
349 .with_context(|| format!("bind {addr}"))?;
350
351 tracing::info!(
352 "lean-ctx Streamable HTTP server listening on http://{addr} (project_root={})",
353 cfg.project_root.display()
354 );
355
356 axum::serve(listener, app)
357 .with_graceful_shutdown(async move {
358 let _ = tokio::signal::ctrl_c().await;
359 })
360 .await
361 .context("http server")?;
362 Ok(())
363}
364
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Request;
    use rmcp::transport::{StreamableHttpServerConfig, StreamableHttpService};
    use serde_json::json;
    use tower::ServiceExt;

    /// Builds a bodiless `GET /limited` request; `label` is the panic message
    /// used if the request cannot be built.
    fn limited_request(label: &str) -> Request<Body> {
        Request::builder()
            .method("GET")
            .uri("/limited")
            .header("Host", "localhost")
            .body(Body::empty())
            .expect(label)
    }

    /// A request to the MCP fallback without an `Authorization` header must be
    /// rejected with 401 when a token is configured.
    #[tokio::test]
    async fn auth_token_blocks_requests_without_bearer_header() {
        let dir = tempfile::tempdir().expect("tempdir");
        let root = dir.path().to_string_lossy().to_string();

        let base = LeanCtxServer::new_with_project_root(Some(root.clone()));
        let factory = move || Ok(base.clone());
        let transport_cfg = StreamableHttpServerConfig::default()
            .with_stateful_mode(false)
            .with_json_response(true);
        let sessions = Arc::new(
            rmcp::transport::streamable_http_server::session::local::LocalSessionManager::default(),
        );
        let mcp_http = StreamableHttpService::new(factory, sessions, transport_cfg);

        let engine = ContextEngine::from_server(LeanCtxServer::new_with_project_root(Some(root)));
        let state = AppState {
            token: Some("secret".to_string()),
            concurrency: Arc::new(tokio::sync::Semaphore::new(4)),
            rate: Arc::new(RateLimiter::new(50, 100)),
            engine: Arc::new(engine),
            timeout: Duration::from_millis(30_000),
        };

        let app = Router::new()
            .fallback_service(mcp_http)
            .layer(middleware::from_fn_with_state(
                state.clone(),
                auth_middleware,
            ))
            .with_state(state);

        let payload = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "tools/list",
            "params": {}
        })
        .to_string();

        // Deliberately no Authorization header.
        let req = Request::builder()
            .method("POST")
            .uri("/")
            .header("Host", "localhost")
            .header("Accept", "application/json, text/event-stream")
            .header("Content-Type", "application/json")
            .body(Body::from(payload))
            .expect("request");

        let resp = app.clone().oneshot(req).await.expect("resp");
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    /// With a bucket of one token, the first request succeeds and the
    /// immediately following one is answered with 429.
    #[tokio::test]
    async fn rate_limit_returns_429_when_exhausted() {
        let state = AppState {
            token: None,
            concurrency: Arc::new(tokio::sync::Semaphore::new(16)),
            rate: Arc::new(RateLimiter::new(1, 1)),
            engine: Arc::new(ContextEngine::new()),
            timeout: Duration::from_millis(30_000),
        };

        let app = Router::new()
            .route("/limited", get(|| async { (StatusCode::OK, "ok\n") }))
            .layer(middleware::from_fn_with_state(
                state.clone(),
                rate_limit_middleware,
            ))
            .with_state(state);

        let first = app
            .clone()
            .oneshot(limited_request("req1"))
            .await
            .expect("resp1");
        assert_eq!(first.status(), StatusCode::OK);

        let second = app
            .clone()
            .oneshot(limited_request("req2"))
            .await
            .expect("resp2");
        assert_eq!(second.status(), StatusCode::TOO_MANY_REQUESTS);
    }
}