Skip to main content

lean_ctx/http_server/
mod.rs

1use std::net::SocketAddr;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use anyhow::{anyhow, Context, Result};
6use axum::{
7    extract::Json,
8    extract::Query,
9    extract::State,
10    http::{header, Request, StatusCode},
11    middleware::{self, Next},
12    response::{IntoResponse, Response},
13    routing::get,
14    Router,
15};
16use rmcp::transport::{StreamableHttpServerConfig, StreamableHttpService};
17use serde::Deserialize;
18use serde_json::Value;
19use tokio::time::{Duration, Instant};
20
21use crate::engine::ContextEngine;
22use crate::tools::LeanCtxServer;
23
24#[cfg(feature = "team-server")]
25pub mod team;
26
/// User-facing settings for the HTTP server started by [`serve`].
#[derive(Debug, Clone)]
pub struct HttpServerConfig {
    /// Host part of the listen address.
    pub host: String,
    /// TCP port of the listen address.
    pub port: u16,
    /// Project root handed to the `LeanCtxServer` that backs every endpoint.
    pub project_root: PathBuf,
    /// Bearer token clients must present; `None` (or an empty string) turns
    /// authentication off.
    pub auth_token: Option<String>,
    /// Forwarded to rmcp's `with_stateful_mode`.
    pub stateful_mode: bool,
    /// Forwarded to rmcp's `with_json_response`.
    pub json_response: bool,
    /// When set, rmcp's allowed-hosts check is disabled entirely.
    pub disable_host_check: bool,
    /// Explicit allowed-hosts list; when non-empty it replaces rmcp's defaults.
    pub allowed_hosts: Vec<String>,
    /// Request body size cap applied through axum's `DefaultBodyLimit`.
    pub max_body_bytes: usize,
    /// Number of requests allowed in flight at once (at least 1 is used).
    pub max_concurrency: usize,
    /// Sustained request rate of the limiter, in requests per second.
    pub max_rps: u32,
    /// Maximum number of requests the limiter lets through in one burst.
    pub rate_burst: u32,
    /// Time budget, in milliseconds, for a `/v1/tools/call` invocation.
    pub request_timeout_ms: u64,
}
43
44impl Default for HttpServerConfig {
45    fn default() -> Self {
46        let project_root = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
47        Self {
48            host: "127.0.0.1".to_string(),
49            port: 8080,
50            project_root,
51            auth_token: None,
52            stateful_mode: false,
53            json_response: true,
54            disable_host_check: false,
55            allowed_hosts: Vec::new(),
56            max_body_bytes: 2 * 1024 * 1024,
57            max_concurrency: 32,
58            max_rps: 50,
59            rate_burst: 100,
60            request_timeout_ms: 30_000,
61        }
62    }
63}
64
65impl HttpServerConfig {
66    pub fn validate(&self) -> Result<()> {
67        let host = self.host.trim().to_lowercase();
68        let is_loopback = host == "127.0.0.1" || host == "localhost" || host == "::1";
69        if !is_loopback && self.auth_token.as_deref().unwrap_or("").is_empty() {
70            return Err(anyhow!(
71                "Refusing to bind to host='{host}' without auth. Provide --auth-token (or bind to 127.0.0.1)."
72            ));
73        }
74        Ok(())
75    }
76
77    fn mcp_http_config(&self) -> StreamableHttpServerConfig {
78        let mut cfg = StreamableHttpServerConfig::default()
79            .with_stateful_mode(self.stateful_mode)
80            .with_json_response(self.json_response);
81
82        if self.disable_host_check {
83            cfg = cfg.disable_allowed_hosts();
84            return cfg;
85        }
86
87        if !self.allowed_hosts.is_empty() {
88            cfg = cfg.with_allowed_hosts(self.allowed_hosts.clone());
89            return cfg;
90        }
91
92        // Keep rmcp's secure loopback defaults; also allow the configured host (if it's loopback).
93        let host = self.host.trim();
94        if host == "127.0.0.1" || host == "localhost" || host == "::1" {
95            cfg.allowed_hosts.push(host.to_string());
96        }
97
98        cfg
99    }
100}
101
102#[derive(Clone)]
103struct AppState {
104    token: Option<String>,
105    concurrency: Arc<tokio::sync::Semaphore>,
106    rate: Arc<RateLimiter>,
107    engine: Arc<ContextEngine>,
108    timeout: Duration,
109}
110
111#[derive(Debug)]
112struct RateLimiter {
113    max_rps: f64,
114    burst: f64,
115    state: tokio::sync::Mutex<RateState>,
116}
117
118#[derive(Debug, Clone, Copy)]
119struct RateState {
120    tokens: f64,
121    last: Instant,
122}
123
124impl RateLimiter {
125    fn new(max_rps: u32, burst: u32) -> Self {
126        let now = Instant::now();
127        Self {
128            max_rps: (max_rps.max(1)) as f64,
129            burst: (burst.max(1)) as f64,
130            state: tokio::sync::Mutex::new(RateState {
131                tokens: (burst.max(1)) as f64,
132                last: now,
133            }),
134        }
135    }
136
137    async fn allow(&self) -> bool {
138        let mut s = self.state.lock().await;
139        let now = Instant::now();
140        let elapsed = now.saturating_duration_since(s.last);
141        let refill = elapsed.as_secs_f64() * self.max_rps;
142        s.tokens = (s.tokens + refill).min(self.burst);
143        s.last = now;
144        if s.tokens >= 1.0 {
145            s.tokens -= 1.0;
146            true
147        } else {
148            false
149        }
150    }
151}
152
153async fn auth_middleware(
154    State(state): State<AppState>,
155    req: Request<axum::body::Body>,
156    next: Next,
157) -> Response {
158    if state.token.is_none() {
159        return next.run(req).await;
160    }
161
162    if req.uri().path() == "/health" {
163        return next.run(req).await;
164    }
165
166    let expected = state.token.as_deref().unwrap_or("");
167    let Some(h) = req.headers().get(header::AUTHORIZATION) else {
168        return StatusCode::UNAUTHORIZED.into_response();
169    };
170    let Ok(s) = h.to_str() else {
171        return StatusCode::UNAUTHORIZED.into_response();
172    };
173    let Some(token) = s
174        .strip_prefix("Bearer ")
175        .or_else(|| s.strip_prefix("bearer "))
176    else {
177        return StatusCode::UNAUTHORIZED.into_response();
178    };
179    if !constant_time_eq(token.as_bytes(), expected.as_bytes()) {
180        return StatusCode::UNAUTHORIZED.into_response();
181    }
182
183    next.run(req).await
184}
185
/// Compares two byte strings without short-circuiting on the first mismatch.
///
/// Slices of different length are rejected immediately; for equal lengths
/// every byte pair is examined before the result is produced.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }

    // Accumulate the XOR of every pair; any differing bit leaves `diff` non-zero.
    let mut diff = 0u8;
    for (lhs, rhs) in a.iter().zip(b) {
        diff |= lhs ^ rhs;
    }
    diff == 0
}
195
196async fn rate_limit_middleware(
197    State(state): State<AppState>,
198    req: Request<axum::body::Body>,
199    next: Next,
200) -> Response {
201    if req.uri().path() == "/health" {
202        return next.run(req).await;
203    }
204    if !state.rate.allow().await {
205        return StatusCode::TOO_MANY_REQUESTS.into_response();
206    }
207    next.run(req).await
208}
209
210async fn concurrency_middleware(
211    State(state): State<AppState>,
212    req: Request<axum::body::Body>,
213    next: Next,
214) -> Response {
215    if req.uri().path() == "/health" {
216        return next.run(req).await;
217    }
218    let Ok(permit) = state.concurrency.clone().try_acquire_owned() else {
219        return StatusCode::TOO_MANY_REQUESTS.into_response();
220    };
221    let resp = next.run(req).await;
222    drop(permit);
223    resp
224}
225
226async fn health() -> impl IntoResponse {
227    (StatusCode::OK, "ok\n")
228}
229
230#[derive(Debug, Deserialize)]
231#[serde(rename_all = "camelCase")]
232struct ToolCallBody {
233    name: String,
234    #[serde(default)]
235    arguments: Option<Value>,
236}
237
238async fn v1_manifest(State(state): State<AppState>) -> impl IntoResponse {
239    let v = state.engine.manifest();
240    (StatusCode::OK, Json(v))
241}
242
243#[derive(Debug, Deserialize)]
244#[serde(rename_all = "camelCase")]
245struct ToolsQuery {
246    #[serde(default)]
247    offset: Option<usize>,
248    #[serde(default)]
249    limit: Option<usize>,
250}
251
252async fn v1_tools(State(state): State<AppState>, Query(q): Query<ToolsQuery>) -> impl IntoResponse {
253    let v = state.engine.manifest();
254    let tools = v
255        .get("tools")
256        .and_then(|t| t.get("granular"))
257        .cloned()
258        .unwrap_or(Value::Array(vec![]));
259
260    let all = tools.as_array().cloned().unwrap_or_default();
261    let total = all.len();
262    let offset = q.offset.unwrap_or(0).min(total);
263    let limit = q.limit.unwrap_or(200).min(500);
264    let page = all.into_iter().skip(offset).take(limit).collect::<Vec<_>>();
265
266    (
267        StatusCode::OK,
268        Json(serde_json::json!({
269            "tools": page,
270            "total": total,
271            "offset": offset,
272            "limit": limit,
273        })),
274    )
275}
276
277async fn v1_tool_call(
278    State(state): State<AppState>,
279    Json(body): Json<ToolCallBody>,
280) -> impl IntoResponse {
281    match tokio::time::timeout(
282        state.timeout,
283        state.engine.call_tool_value(&body.name, body.arguments),
284    )
285    .await
286    {
287        Ok(Ok(v)) => (StatusCode::OK, Json(serde_json::json!({ "result": v }))).into_response(),
288        Ok(Err(e)) => (
289            StatusCode::BAD_REQUEST,
290            Json(serde_json::json!({ "error": e.to_string() })),
291        )
292            .into_response(),
293        Err(_) => (
294            StatusCode::GATEWAY_TIMEOUT,
295            Json(serde_json::json!({ "error": "request_timeout" })),
296        )
297            .into_response(),
298    }
299}
300
301pub async fn serve(cfg: HttpServerConfig) -> Result<()> {
302    cfg.validate()?;
303
304    let addr: SocketAddr = format!("{}:{}", cfg.host, cfg.port)
305        .parse()
306        .context("invalid host/port")?;
307
308    let project_root = cfg.project_root.to_string_lossy().to_string();
309    let base = LeanCtxServer::new_with_project_root(Some(&project_root));
310    let engine = Arc::new(ContextEngine::from_server(base.clone()));
311
312    let service_factory = move || Ok(base.clone());
313    let mcp_http = StreamableHttpService::new(
314        service_factory,
315        Arc::new(
316            rmcp::transport::streamable_http_server::session::local::LocalSessionManager::default(),
317        ),
318        cfg.mcp_http_config(),
319    );
320
321    let state = AppState {
322        token: cfg.auth_token.clone().filter(|t| !t.is_empty()),
323        concurrency: Arc::new(tokio::sync::Semaphore::new(cfg.max_concurrency.max(1))),
324        rate: Arc::new(RateLimiter::new(cfg.max_rps, cfg.rate_burst)),
325        engine,
326        timeout: Duration::from_millis(cfg.request_timeout_ms.max(1)),
327    };
328
329    let app = Router::new()
330        .route("/health", get(health))
331        .route("/v1/manifest", get(v1_manifest))
332        .route("/v1/tools", get(v1_tools))
333        .route("/v1/tools/call", axum::routing::post(v1_tool_call))
334        .fallback_service(mcp_http)
335        .layer(axum::extract::DefaultBodyLimit::max(cfg.max_body_bytes))
336        .layer(middleware::from_fn_with_state(
337            state.clone(),
338            rate_limit_middleware,
339        ))
340        .layer(middleware::from_fn_with_state(
341            state.clone(),
342            concurrency_middleware,
343        ))
344        .layer(middleware::from_fn_with_state(
345            state.clone(),
346            auth_middleware,
347        ))
348        .with_state(state);
349
350    let listener = tokio::net::TcpListener::bind(addr)
351        .await
352        .with_context(|| format!("bind {addr}"))?;
353
354    tracing::info!(
355        "lean-ctx Streamable HTTP server listening on http://{addr} (project_root={})",
356        cfg.project_root.display()
357    );
358
359    axum::serve(listener, app)
360        .with_graceful_shutdown(async move {
361            let _ = tokio::signal::ctrl_c().await;
362        })
363        .await
364        .context("http server")?;
365    Ok(())
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Request;
    use rmcp::transport::{StreamableHttpServerConfig, StreamableHttpService};
    use serde_json::json;
    use tower::ServiceExt;

    /// Builds a bodiless `GET /limited` request; `label` is the panic message
    /// used if the request cannot be constructed.
    fn limited_request(label: &str) -> Request<Body> {
        Request::builder()
            .method("GET")
            .uri("/limited")
            .header("Host", "localhost")
            .body(Body::empty())
            .expect(label)
    }

    /// A request without an `Authorization` header must be rejected with 401
    /// when a token is configured.
    #[tokio::test]
    async fn auth_token_blocks_requests_without_bearer_header() {
        let dir = tempfile::tempdir().expect("tempdir");
        let root_str = dir.path().to_string_lossy().to_string();

        let base = LeanCtxServer::new_with_project_root(Some(&root_str));
        let service_factory = move || Ok(base.clone());
        let transport_cfg = StreamableHttpServerConfig::default()
            .with_stateful_mode(false)
            .with_json_response(true);
        let session_manager = Arc::new(
            rmcp::transport::streamable_http_server::session::local::LocalSessionManager::default(),
        );
        let mcp_http = StreamableHttpService::new(service_factory, session_manager, transport_cfg);

        let engine = Arc::new(ContextEngine::from_server(
            LeanCtxServer::new_with_project_root(Some(&root_str)),
        ));
        let state = AppState {
            token: Some("secret".to_string()),
            concurrency: Arc::new(tokio::sync::Semaphore::new(4)),
            rate: Arc::new(RateLimiter::new(50, 100)),
            engine,
            timeout: Duration::from_millis(30_000),
        };

        let auth_layer = middleware::from_fn_with_state(state.clone(), auth_middleware);
        let app = Router::new()
            .fallback_service(mcp_http)
            .layer(auth_layer)
            .with_state(state);

        let payload = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "tools/list",
            "params": {}
        })
        .to_string();

        // Deliberately no Authorization header.
        let request = Request::builder()
            .method("POST")
            .uri("/")
            .header("Host", "localhost")
            .header("Accept", "application/json, text/event-stream")
            .header("Content-Type", "application/json")
            .body(Body::from(payload))
            .expect("request");

        let response = app.clone().oneshot(request).await.expect("resp");
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    /// With a one-token bucket, the first request passes and the second one,
    /// sent right after, is answered with 429.
    #[tokio::test]
    async fn rate_limit_returns_429_when_exhausted() {
        let state = AppState {
            token: None,
            concurrency: Arc::new(tokio::sync::Semaphore::new(16)),
            rate: Arc::new(RateLimiter::new(1, 1)),
            engine: Arc::new(ContextEngine::new()),
            timeout: Duration::from_millis(30_000),
        };

        let limiter_layer = middleware::from_fn_with_state(state.clone(), rate_limit_middleware);
        let app = Router::new()
            .route("/limited", get(|| async { (StatusCode::OK, "ok\n") }))
            .layer(limiter_layer)
            .with_state(state);

        let first = app
            .clone()
            .oneshot(limited_request("req1"))
            .await
            .expect("resp1");
        assert_eq!(first.status(), StatusCode::OK);

        let second = app
            .clone()
            .oneshot(limited_request("req2"))
            .await
            .expect("resp2");
        assert_eq!(second.status(), StatusCode::TOO_MANY_REQUESTS);
    }
}