1use std::net::SocketAddr;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use anyhow::{anyhow, Context, Result};
6use axum::{
7 extract::Json,
8 extract::Query,
9 extract::State,
10 http::{header, Request, StatusCode},
11 middleware::{self, Next},
12 response::{IntoResponse, Response},
13 routing::get,
14 Router,
15};
16use rmcp::transport::{StreamableHttpServerConfig, StreamableHttpService};
17use serde::Deserialize;
18use serde_json::Value;
19use tokio::time::{Duration, Instant};
20
21use crate::engine::ContextEngine;
22use crate::tools::LeanCtxServer;
23
24#[cfg(feature = "team-server")]
25pub mod team;
26
/// Settings consumed by [`serve`] to start the lean-ctx HTTP server.
#[derive(Debug, Clone)]
pub struct HttpServerConfig {
    /// Host part of the listen address.
    pub host: String,
    /// TCP port part of the listen address.
    pub port: u16,
    /// Project root handed to the `LeanCtxServer` that backs every endpoint.
    pub project_root: PathBuf,
    /// Bearer token clients must present; `None` or an empty string disables
    /// authentication. `/health` is never authenticated.
    pub auth_token: Option<String>,
    /// Forwarded to the MCP transport configuration's stateful-mode setting.
    pub stateful_mode: bool,
    /// Forwarded to the MCP transport configuration's JSON-response setting.
    pub json_response: bool,
    /// When `true`, the MCP transport's allowed-hosts check is disabled.
    pub disable_host_check: bool,
    /// Explicit allowed-hosts list for the MCP transport; only consulted when
    /// `disable_host_check` is `false` and the list is non-empty.
    pub allowed_hosts: Vec<String>,
    /// Default request body size limit, in bytes.
    pub max_body_bytes: usize,
    /// Number of requests that may be in flight at once (values below 1 are
    /// treated as 1); additional requests are answered with HTTP 429.
    pub max_concurrency: usize,
    /// Refill rate of the rate limiter in requests per second (minimum 1).
    pub max_rps: u32,
    /// Capacity of the rate limiter's token bucket (minimum 1).
    pub rate_burst: u32,
    /// Deadline, in milliseconds, applied to a `/v1/tools/call` invocation
    /// (minimum 1).
    pub request_timeout_ms: u64,
}
43
44impl Default for HttpServerConfig {
45 fn default() -> Self {
46 let project_root = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
47 Self {
48 host: "127.0.0.1".to_string(),
49 port: 8080,
50 project_root,
51 auth_token: None,
52 stateful_mode: false,
53 json_response: true,
54 disable_host_check: false,
55 allowed_hosts: Vec::new(),
56 max_body_bytes: 2 * 1024 * 1024,
57 max_concurrency: 32,
58 max_rps: 50,
59 rate_burst: 100,
60 request_timeout_ms: 30_000,
61 }
62 }
63}
64
65impl HttpServerConfig {
66 pub fn validate(&self) -> Result<()> {
67 let host = self.host.trim().to_lowercase();
68 let is_loopback = host == "127.0.0.1" || host == "localhost" || host == "::1";
69 if !is_loopback && self.auth_token.as_deref().unwrap_or("").is_empty() {
70 return Err(anyhow!(
71 "Refusing to bind to host='{host}' without auth. Provide --auth-token (or bind to 127.0.0.1)."
72 ));
73 }
74 Ok(())
75 }
76
77 fn mcp_http_config(&self) -> StreamableHttpServerConfig {
78 let mut cfg = StreamableHttpServerConfig::default()
79 .with_stateful_mode(self.stateful_mode)
80 .with_json_response(self.json_response);
81
82 if self.disable_host_check {
83 cfg = cfg.disable_allowed_hosts();
84 return cfg;
85 }
86
87 if !self.allowed_hosts.is_empty() {
88 cfg = cfg.with_allowed_hosts(self.allowed_hosts.clone());
89 return cfg;
90 }
91
92 let host = self.host.trim();
94 if host == "127.0.0.1" || host == "localhost" || host == "::1" {
95 cfg.allowed_hosts.push(host.to_string());
96 }
97
98 cfg
99 }
100}
101
102#[derive(Clone)]
103struct AppState {
104 token: Option<String>,
105 concurrency: Arc<tokio::sync::Semaphore>,
106 rate: Arc<RateLimiter>,
107 engine: Arc<ContextEngine>,
108 timeout: Duration,
109}
110
111#[derive(Debug)]
112struct RateLimiter {
113 max_rps: f64,
114 burst: f64,
115 state: tokio::sync::Mutex<RateState>,
116}
117
118#[derive(Debug, Clone, Copy)]
119struct RateState {
120 tokens: f64,
121 last: Instant,
122}
123
124impl RateLimiter {
125 fn new(max_rps: u32, burst: u32) -> Self {
126 let now = Instant::now();
127 Self {
128 max_rps: (max_rps.max(1)) as f64,
129 burst: (burst.max(1)) as f64,
130 state: tokio::sync::Mutex::new(RateState {
131 tokens: (burst.max(1)) as f64,
132 last: now,
133 }),
134 }
135 }
136
137 async fn allow(&self) -> bool {
138 let mut s = self.state.lock().await;
139 let now = Instant::now();
140 let elapsed = now.saturating_duration_since(s.last);
141 let refill = elapsed.as_secs_f64() * self.max_rps;
142 s.tokens = (s.tokens + refill).min(self.burst);
143 s.last = now;
144 if s.tokens >= 1.0 {
145 s.tokens -= 1.0;
146 true
147 } else {
148 false
149 }
150 }
151}
152
153async fn auth_middleware(
154 State(state): State<AppState>,
155 req: Request<axum::body::Body>,
156 next: Next,
157) -> Response {
158 if state.token.is_none() {
159 return next.run(req).await;
160 }
161
162 if req.uri().path() == "/health" {
163 return next.run(req).await;
164 }
165
166 let expected = state.token.as_deref().unwrap_or("");
167 let Some(h) = req.headers().get(header::AUTHORIZATION) else {
168 return StatusCode::UNAUTHORIZED.into_response();
169 };
170 let Ok(s) = h.to_str() else {
171 return StatusCode::UNAUTHORIZED.into_response();
172 };
173 let Some(token) = s
174 .strip_prefix("Bearer ")
175 .or_else(|| s.strip_prefix("bearer "))
176 else {
177 return StatusCode::UNAUTHORIZED.into_response();
178 };
179 if !constant_time_eq(token.as_bytes(), expected.as_bytes()) {
180 return StatusCode::UNAUTHORIZED.into_response();
181 }
182
183 next.run(req).await
184}
185
/// Compares two byte slices without stopping at the first differing byte.
///
/// Slices of different lengths compare unequal immediately; for equal
/// lengths every byte pair is XOR-ed into an accumulator, and the slices are
/// equal exactly when the accumulator stays zero.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }

    let mut diff = 0u8;
    for (lhs, rhs) in a.iter().zip(b) {
        diff |= lhs ^ rhs;
    }
    diff == 0
}
195
196async fn rate_limit_middleware(
197 State(state): State<AppState>,
198 req: Request<axum::body::Body>,
199 next: Next,
200) -> Response {
201 if req.uri().path() == "/health" {
202 return next.run(req).await;
203 }
204 if !state.rate.allow().await {
205 return StatusCode::TOO_MANY_REQUESTS.into_response();
206 }
207 next.run(req).await
208}
209
210async fn concurrency_middleware(
211 State(state): State<AppState>,
212 req: Request<axum::body::Body>,
213 next: Next,
214) -> Response {
215 if req.uri().path() == "/health" {
216 return next.run(req).await;
217 }
218 let Ok(permit) = state.concurrency.clone().try_acquire_owned() else {
219 return StatusCode::TOO_MANY_REQUESTS.into_response();
220 };
221 let resp = next.run(req).await;
222 drop(permit);
223 resp
224}
225
226async fn health() -> impl IntoResponse {
227 (StatusCode::OK, "ok\n")
228}
229
230#[derive(Debug, Deserialize)]
231#[serde(rename_all = "camelCase")]
232struct ToolCallBody {
233 name: String,
234 #[serde(default)]
235 arguments: Option<Value>,
236}
237
238async fn v1_manifest(State(state): State<AppState>) -> impl IntoResponse {
239 let v = state.engine.manifest();
240 (StatusCode::OK, Json(v))
241}
242
243#[derive(Debug, Deserialize)]
244#[serde(rename_all = "camelCase")]
245struct ToolsQuery {
246 #[serde(default)]
247 offset: Option<usize>,
248 #[serde(default)]
249 limit: Option<usize>,
250}
251
252async fn v1_tools(State(state): State<AppState>, Query(q): Query<ToolsQuery>) -> impl IntoResponse {
253 let v = state.engine.manifest();
254 let tools = v
255 .get("tools")
256 .and_then(|t| t.get("granular"))
257 .cloned()
258 .unwrap_or(Value::Array(vec![]));
259
260 let all = tools.as_array().cloned().unwrap_or_default();
261 let total = all.len();
262 let offset = q.offset.unwrap_or(0).min(total);
263 let limit = q.limit.unwrap_or(200).min(500);
264 let page = all.into_iter().skip(offset).take(limit).collect::<Vec<_>>();
265
266 (
267 StatusCode::OK,
268 Json(serde_json::json!({
269 "tools": page,
270 "total": total,
271 "offset": offset,
272 "limit": limit,
273 })),
274 )
275}
276
277async fn v1_tool_call(
278 State(state): State<AppState>,
279 Json(body): Json<ToolCallBody>,
280) -> impl IntoResponse {
281 match tokio::time::timeout(
282 state.timeout,
283 state.engine.call_tool_value(&body.name, body.arguments),
284 )
285 .await
286 {
287 Ok(Ok(v)) => (StatusCode::OK, Json(serde_json::json!({ "result": v }))).into_response(),
288 Ok(Err(e)) => (
289 StatusCode::BAD_REQUEST,
290 Json(serde_json::json!({ "error": e.to_string() })),
291 )
292 .into_response(),
293 Err(_) => (
294 StatusCode::GATEWAY_TIMEOUT,
295 Json(serde_json::json!({ "error": "request_timeout" })),
296 )
297 .into_response(),
298 }
299}
300
301pub async fn serve(cfg: HttpServerConfig) -> Result<()> {
302 cfg.validate()?;
303
304 let addr: SocketAddr = format!("{}:{}", cfg.host, cfg.port)
305 .parse()
306 .context("invalid host/port")?;
307
308 let project_root = cfg.project_root.to_string_lossy().to_string();
309 let base = LeanCtxServer::new_with_project_root(Some(&project_root));
310 let engine = Arc::new(ContextEngine::from_server(base.clone()));
311
312 let service_factory = move || Ok(base.clone());
313 let mcp_http = StreamableHttpService::new(
314 service_factory,
315 Arc::new(
316 rmcp::transport::streamable_http_server::session::local::LocalSessionManager::default(),
317 ),
318 cfg.mcp_http_config(),
319 );
320
321 let state = AppState {
322 token: cfg.auth_token.clone().filter(|t| !t.is_empty()),
323 concurrency: Arc::new(tokio::sync::Semaphore::new(cfg.max_concurrency.max(1))),
324 rate: Arc::new(RateLimiter::new(cfg.max_rps, cfg.rate_burst)),
325 engine,
326 timeout: Duration::from_millis(cfg.request_timeout_ms.max(1)),
327 };
328
329 let app = Router::new()
330 .route("/health", get(health))
331 .route("/v1/manifest", get(v1_manifest))
332 .route("/v1/tools", get(v1_tools))
333 .route("/v1/tools/call", axum::routing::post(v1_tool_call))
334 .fallback_service(mcp_http)
335 .layer(axum::extract::DefaultBodyLimit::max(cfg.max_body_bytes))
336 .layer(middleware::from_fn_with_state(
337 state.clone(),
338 rate_limit_middleware,
339 ))
340 .layer(middleware::from_fn_with_state(
341 state.clone(),
342 concurrency_middleware,
343 ))
344 .layer(middleware::from_fn_with_state(
345 state.clone(),
346 auth_middleware,
347 ))
348 .with_state(state);
349
350 let listener = tokio::net::TcpListener::bind(addr)
351 .await
352 .with_context(|| format!("bind {addr}"))?;
353
354 tracing::info!(
355 "lean-ctx Streamable HTTP server listening on http://{addr} (project_root={})",
356 cfg.project_root.display()
357 );
358
359 axum::serve(listener, app)
360 .with_graceful_shutdown(async move {
361 let _ = tokio::signal::ctrl_c().await;
362 })
363 .await
364 .context("http server")?;
365 Ok(())
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Request;
    use rmcp::transport::{StreamableHttpServerConfig, StreamableHttpService};
    use serde_json::json;
    use tower::ServiceExt;

    /// With a token configured, an MCP request lacking an `Authorization`
    /// header must be answered with 401 by the auth middleware.
    #[tokio::test]
    async fn auth_token_blocks_requests_without_bearer_header() {
        let project = tempfile::tempdir().expect("tempdir");
        let root = project.path().to_string_lossy().to_string();

        let server = LeanCtxServer::new_with_project_root(Some(&root));
        let make_service = move || Ok(server.clone());
        let transport_cfg = StreamableHttpServerConfig::default()
            .with_stateful_mode(false)
            .with_json_response(true);
        let sessions = Arc::new(
            rmcp::transport::streamable_http_server::session::local::LocalSessionManager::default(),
        );
        let mcp_http = StreamableHttpService::new(make_service, sessions, transport_cfg);

        let state = AppState {
            token: Some("secret".to_string()),
            concurrency: Arc::new(tokio::sync::Semaphore::new(4)),
            rate: Arc::new(RateLimiter::new(50, 100)),
            engine: Arc::new(ContextEngine::from_server(
                LeanCtxServer::new_with_project_root(Some(&root)),
            )),
            timeout: Duration::from_millis(30_000),
        };

        let auth_layer = middleware::from_fn_with_state(state.clone(), auth_middleware);
        let app = Router::new()
            .fallback_service(mcp_http)
            .layer(auth_layer)
            .with_state(state);

        let payload = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "tools/list",
            "params": {}
        })
        .to_string();

        // Deliberately no Authorization header.
        let request = Request::builder()
            .method("POST")
            .uri("/")
            .header("Host", "localhost")
            .header("Accept", "application/json, text/event-stream")
            .header("Content-Type", "application/json")
            .body(Body::from(payload))
            .expect("request");

        let response = app.clone().oneshot(request).await.expect("resp");
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    /// With a bucket of one token, the first request succeeds and the one
    /// sent immediately afterwards is rejected with 429.
    #[tokio::test]
    async fn rate_limit_returns_429_when_exhausted() {
        let state = AppState {
            token: None,
            concurrency: Arc::new(tokio::sync::Semaphore::new(16)),
            rate: Arc::new(RateLimiter::new(1, 1)),
            engine: Arc::new(ContextEngine::new()),
            timeout: Duration::from_millis(30_000),
        };

        let limiter_layer = middleware::from_fn_with_state(state.clone(), rate_limit_middleware);
        let app = Router::new()
            .route("/limited", get(|| async { (StatusCode::OK, "ok\n") }))
            .layer(limiter_layer)
            .with_state(state);

        let attempts = [
            ("req1", "resp1", StatusCode::OK),
            ("req2", "resp2", StatusCode::TOO_MANY_REQUESTS),
        ];
        for (req_label, resp_label, expected) in attempts {
            let request = Request::builder()
                .method("GET")
                .uri("/limited")
                .header("Host", "localhost")
                .body(Body::empty())
                .expect(req_label);
            let response = app.clone().oneshot(request).await.expect(resp_label);
            assert_eq!(response.status(), expected);
        }
    }
}