Skip to main content

lean_ctx/http_server/
mod.rs

1use std::net::SocketAddr;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use anyhow::{anyhow, Context, Result};
6use axum::{
7    extract::Json,
8    extract::Query,
9    extract::State,
10    http::{header, Request, StatusCode},
11    middleware::{self, Next},
12    response::{IntoResponse, Response},
13    routing::get,
14    Router,
15};
16use rmcp::transport::{StreamableHttpServerConfig, StreamableHttpService};
17use serde::Deserialize;
18use serde_json::Value;
19use tokio::time::{Duration, Instant};
20
21use crate::engine::ContextEngine;
22use crate::tools::LeanCtxServer;
23
/// Runtime settings for the lean-ctx HTTP server.
///
/// Groups the bind address, the project directory handed to the tool server, optional
/// bearer-token auth, rmcp Streamable HTTP transport options, and the request-shaping limits
/// applied by the middleware stack.
#[derive(Clone, Debug)]
pub struct HttpServerConfig {
    /// Interface to bind, e.g. `127.0.0.1`.
    pub host: String,
    /// TCP port to listen on.
    pub port: u16,
    /// Project directory the tool server operates on.
    pub project_root: PathBuf,
    /// Bearer token required on authenticated routes; `None` (or an empty string) disables auth.
    pub auth_token: Option<String>,
    /// Forwarded to rmcp's `with_stateful_mode`.
    pub stateful_mode: bool,
    /// Forwarded to rmcp's `with_json_response`.
    pub json_response: bool,
    /// Turns off rmcp's Host-header allow-list entirely.
    pub disable_host_check: bool,
    /// Explicit Host allow-list; when empty, rmcp's defaults (plus a loopback bind host) apply.
    pub allowed_hosts: Vec<String>,
    /// Request body size limit, in bytes, installed via axum's `DefaultBodyLimit`.
    pub max_body_bytes: usize,
    /// Maximum number of requests handled at once before `429` is returned.
    pub max_concurrency: usize,
    /// Token-bucket refill rate, in requests per second.
    pub max_rps: u32,
    /// Token-bucket capacity (largest burst admitted at once).
    pub rate_burst: u32,
    /// Timeout for a single `/v1/tools/call` execution, in milliseconds.
    pub request_timeout_ms: u64,
}

impl Default for HttpServerConfig {
    /// Loopback-only, unauthenticated defaults rooted at the current working directory
    /// (or `.` when the working directory cannot be determined).
    fn default() -> Self {
        let cwd = match std::env::current_dir() {
            Ok(dir) => dir,
            Err(_) => PathBuf::from("."),
        };
        Self {
            host: String::from("127.0.0.1"),
            port: 8080,
            project_root: cwd,
            auth_token: None,
            stateful_mode: false,
            json_response: true,
            disable_host_check: false,
            allowed_hosts: vec![],
            max_body_bytes: 2 * 1024 * 1024,
            max_concurrency: 32,
            max_rps: 50,
            rate_burst: 100,
            request_timeout_ms: 30_000,
        }
    }
}
61
62impl HttpServerConfig {
63    pub fn validate(&self) -> Result<()> {
64        let host = self.host.trim().to_lowercase();
65        let is_loopback = host == "127.0.0.1" || host == "localhost" || host == "::1";
66        if !is_loopback && self.auth_token.as_deref().unwrap_or("").is_empty() {
67            return Err(anyhow!(
68                "Refusing to bind to host='{host}' without auth. Provide --auth-token (or bind to 127.0.0.1)."
69            ));
70        }
71        Ok(())
72    }
73
74    fn mcp_http_config(&self) -> StreamableHttpServerConfig {
75        let mut cfg = StreamableHttpServerConfig::default()
76            .with_stateful_mode(self.stateful_mode)
77            .with_json_response(self.json_response);
78
79        if self.disable_host_check {
80            cfg = cfg.disable_allowed_hosts();
81            return cfg;
82        }
83
84        if !self.allowed_hosts.is_empty() {
85            cfg = cfg.with_allowed_hosts(self.allowed_hosts.clone());
86            return cfg;
87        }
88
89        // Keep rmcp's secure loopback defaults; also allow the configured host (if it's loopback).
90        let host = self.host.trim();
91        if host == "127.0.0.1" || host == "localhost" || host == "::1" {
92            cfg.allowed_hosts.push(host.to_string());
93        }
94
95        cfg
96    }
97}
98
99#[derive(Clone)]
100struct AppState {
101    token: Option<String>,
102    concurrency: Arc<tokio::sync::Semaphore>,
103    rate: Arc<RateLimiter>,
104    engine: Arc<ContextEngine>,
105    timeout: Duration,
106}
107
108#[derive(Debug)]
109struct RateLimiter {
110    max_rps: f64,
111    burst: f64,
112    state: tokio::sync::Mutex<RateState>,
113}
114
115#[derive(Debug, Clone, Copy)]
116struct RateState {
117    tokens: f64,
118    last: Instant,
119}
120
121impl RateLimiter {
122    fn new(max_rps: u32, burst: u32) -> Self {
123        let now = Instant::now();
124        Self {
125            max_rps: (max_rps.max(1)) as f64,
126            burst: (burst.max(1)) as f64,
127            state: tokio::sync::Mutex::new(RateState {
128                tokens: (burst.max(1)) as f64,
129                last: now,
130            }),
131        }
132    }
133
134    async fn allow(&self) -> bool {
135        let mut s = self.state.lock().await;
136        let now = Instant::now();
137        let elapsed = now.saturating_duration_since(s.last);
138        let refill = elapsed.as_secs_f64() * self.max_rps;
139        s.tokens = (s.tokens + refill).min(self.burst);
140        s.last = now;
141        if s.tokens >= 1.0 {
142            s.tokens -= 1.0;
143            true
144        } else {
145            false
146        }
147    }
148}
149
150async fn auth_middleware(
151    State(state): State<AppState>,
152    req: Request<axum::body::Body>,
153    next: Next,
154) -> Response {
155    if state.token.is_none() {
156        return next.run(req).await;
157    }
158
159    if req.uri().path() == "/health" {
160        return next.run(req).await;
161    }
162
163    let expected = state.token.as_deref().unwrap_or("");
164    let Some(h) = req.headers().get(header::AUTHORIZATION) else {
165        return StatusCode::UNAUTHORIZED.into_response();
166    };
167    let Ok(s) = h.to_str() else {
168        return StatusCode::UNAUTHORIZED.into_response();
169    };
170    let Some(token) = s
171        .strip_prefix("Bearer ")
172        .or_else(|| s.strip_prefix("bearer "))
173    else {
174        return StatusCode::UNAUTHORIZED.into_response();
175    };
176    if !constant_time_eq(token.as_bytes(), expected.as_bytes()) {
177        return StatusCode::UNAUTHORIZED.into_response();
178    }
179
180    next.run(req).await
181}
182
/// Compares two byte slices without short-circuiting on the first mismatching byte.
///
/// Slices of different lengths are rejected up front, so timing can reveal the *length* of
/// the secret but not where the contents differ.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        // Accumulate every differing bit; the result is zero only if all bytes matched.
        let mut diff = 0u8;
        for (lhs, rhs) in a.iter().zip(b) {
            diff |= lhs ^ rhs;
        }
        diff == 0
    } else {
        false
    }
}
192
193async fn rate_limit_middleware(
194    State(state): State<AppState>,
195    req: Request<axum::body::Body>,
196    next: Next,
197) -> Response {
198    if req.uri().path() == "/health" {
199        return next.run(req).await;
200    }
201    if !state.rate.allow().await {
202        return StatusCode::TOO_MANY_REQUESTS.into_response();
203    }
204    next.run(req).await
205}
206
207async fn concurrency_middleware(
208    State(state): State<AppState>,
209    req: Request<axum::body::Body>,
210    next: Next,
211) -> Response {
212    if req.uri().path() == "/health" {
213        return next.run(req).await;
214    }
215    let Ok(permit) = state.concurrency.clone().try_acquire_owned() else {
216        return StatusCode::TOO_MANY_REQUESTS.into_response();
217    };
218    let resp = next.run(req).await;
219    drop(permit);
220    resp
221}
222
223async fn health() -> impl IntoResponse {
224    (StatusCode::OK, "ok\n")
225}
226
227#[derive(Debug, Deserialize)]
228#[serde(rename_all = "camelCase")]
229struct ToolCallBody {
230    name: String,
231    #[serde(default)]
232    arguments: Option<Value>,
233}
234
235async fn v1_manifest(State(state): State<AppState>) -> impl IntoResponse {
236    let v = state.engine.manifest();
237    (StatusCode::OK, Json(v))
238}
239
240#[derive(Debug, Deserialize)]
241#[serde(rename_all = "camelCase")]
242struct ToolsQuery {
243    #[serde(default)]
244    offset: Option<usize>,
245    #[serde(default)]
246    limit: Option<usize>,
247}
248
249async fn v1_tools(State(state): State<AppState>, Query(q): Query<ToolsQuery>) -> impl IntoResponse {
250    let v = state.engine.manifest();
251    let tools = v
252        .get("tools")
253        .and_then(|t| t.get("granular"))
254        .cloned()
255        .unwrap_or(Value::Array(vec![]));
256
257    let all = tools.as_array().cloned().unwrap_or_default();
258    let total = all.len();
259    let offset = q.offset.unwrap_or(0).min(total);
260    let limit = q.limit.unwrap_or(200).min(500);
261    let page = all.into_iter().skip(offset).take(limit).collect::<Vec<_>>();
262
263    (
264        StatusCode::OK,
265        Json(serde_json::json!({
266            "tools": page,
267            "total": total,
268            "offset": offset,
269            "limit": limit,
270        })),
271    )
272}
273
274async fn v1_tool_call(
275    State(state): State<AppState>,
276    Json(body): Json<ToolCallBody>,
277) -> impl IntoResponse {
278    match tokio::time::timeout(
279        state.timeout,
280        state.engine.call_tool_value(&body.name, body.arguments),
281    )
282    .await
283    {
284        Ok(Ok(v)) => (StatusCode::OK, Json(serde_json::json!({ "result": v }))).into_response(),
285        Ok(Err(e)) => (
286            StatusCode::BAD_REQUEST,
287            Json(serde_json::json!({ "error": e.to_string() })),
288        )
289            .into_response(),
290        Err(_) => (
291            StatusCode::GATEWAY_TIMEOUT,
292            Json(serde_json::json!({ "error": "request_timeout" })),
293        )
294            .into_response(),
295    }
296}
297
298pub async fn serve(cfg: HttpServerConfig) -> Result<()> {
299    cfg.validate()?;
300
301    let addr: SocketAddr = format!("{}:{}", cfg.host, cfg.port)
302        .parse()
303        .context("invalid host/port")?;
304
305    let project_root = cfg.project_root.to_string_lossy().to_string();
306    let base = LeanCtxServer::new_with_project_root(Some(project_root));
307    let engine = Arc::new(ContextEngine::from_server(base.clone()));
308
309    let service_factory = move || Ok(base.clone());
310    let mcp_http = StreamableHttpService::new(
311        service_factory,
312        Arc::new(
313            rmcp::transport::streamable_http_server::session::local::LocalSessionManager::default(),
314        ),
315        cfg.mcp_http_config(),
316    );
317
318    let state = AppState {
319        token: cfg.auth_token.clone().filter(|t| !t.is_empty()),
320        concurrency: Arc::new(tokio::sync::Semaphore::new(cfg.max_concurrency.max(1))),
321        rate: Arc::new(RateLimiter::new(cfg.max_rps, cfg.rate_burst)),
322        engine,
323        timeout: Duration::from_millis(cfg.request_timeout_ms.max(1)),
324    };
325
326    let app = Router::new()
327        .route("/health", get(health))
328        .route("/v1/manifest", get(v1_manifest))
329        .route("/v1/tools", get(v1_tools))
330        .route("/v1/tools/call", axum::routing::post(v1_tool_call))
331        .fallback_service(mcp_http)
332        .layer(axum::extract::DefaultBodyLimit::max(cfg.max_body_bytes))
333        .layer(middleware::from_fn_with_state(
334            state.clone(),
335            rate_limit_middleware,
336        ))
337        .layer(middleware::from_fn_with_state(
338            state.clone(),
339            concurrency_middleware,
340        ))
341        .layer(middleware::from_fn_with_state(
342            state.clone(),
343            auth_middleware,
344        ))
345        .with_state(state);
346
347    let listener = tokio::net::TcpListener::bind(addr)
348        .await
349        .with_context(|| format!("bind {addr}"))?;
350
351    tracing::info!(
352        "lean-ctx Streamable HTTP server listening on http://{addr} (project_root={})",
353        cfg.project_root.display()
354    );
355
356    axum::serve(listener, app)
357        .with_graceful_shutdown(async move {
358            let _ = tokio::signal::ctrl_c().await;
359        })
360        .await
361        .context("http server")?;
362    Ok(())
363}
364
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Request;
    use rmcp::transport::{StreamableHttpServerConfig, StreamableHttpService};
    use serde_json::json;
    use tower::ServiceExt;

    /// Assembles an `AppState` from the given pieces with a fixed 30s request timeout.
    fn test_state(
        token: Option<&str>,
        permits: usize,
        rate: RateLimiter,
        engine: ContextEngine,
    ) -> AppState {
        AppState {
            token: token.map(str::to_string),
            concurrency: Arc::new(tokio::sync::Semaphore::new(permits)),
            rate: Arc::new(rate),
            engine: Arc::new(engine),
            timeout: Duration::from_millis(30_000),
        }
    }

    /// Builds an empty-bodied `GET` request for `uri` with a loopback Host header; `label` is
    /// the panic message if building fails.
    fn get_request(uri: &str, label: &str) -> Request<Body> {
        Request::builder()
            .method("GET")
            .uri(uri)
            .header("Host", "localhost")
            .body(Body::empty())
            .expect(label)
    }

    #[tokio::test]
    async fn auth_token_blocks_requests_without_bearer_header() {
        let dir = tempfile::tempdir().expect("tempdir");
        let root = dir.path().to_string_lossy().to_string();

        let base = LeanCtxServer::new_with_project_root(Some(root.clone()));
        let service_factory = move || Ok(base.clone());
        let transport_cfg = StreamableHttpServerConfig::default()
            .with_stateful_mode(false)
            .with_json_response(true);
        let mcp_http = StreamableHttpService::new(
            service_factory,
            Arc::new(
                rmcp::transport::streamable_http_server::session::local::LocalSessionManager::default(),
            ),
            transport_cfg,
        );

        let state = test_state(
            Some("secret"),
            4,
            RateLimiter::new(50, 100),
            ContextEngine::from_server(LeanCtxServer::new_with_project_root(Some(root))),
        );

        let app = Router::new()
            .fallback_service(mcp_http)
            .layer(middleware::from_fn_with_state(
                state.clone(),
                auth_middleware,
            ))
            .with_state(state);

        let payload = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "tools/list",
            "params": {}
        })
        .to_string();

        // No Authorization header at all: must be rejected before reaching the MCP transport.
        let req = Request::builder()
            .method("POST")
            .uri("/")
            .header("Host", "localhost")
            .header("Accept", "application/json, text/event-stream")
            .header("Content-Type", "application/json")
            .body(Body::from(payload))
            .expect("request");

        let resp = app.clone().oneshot(req).await.expect("resp");
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn rate_limit_returns_429_when_exhausted() {
        // A bucket of one token at 1 rps: the first request passes, the second is shed.
        let state = test_state(None, 16, RateLimiter::new(1, 1), ContextEngine::new());

        let app = Router::new()
            .route("/limited", get(|| async { (StatusCode::OK, "ok\n") }))
            .layer(middleware::from_fn_with_state(
                state.clone(),
                rate_limit_middleware,
            ))
            .with_state(state);

        let first = app
            .clone()
            .oneshot(get_request("/limited", "req1"))
            .await
            .expect("resp1");
        assert_eq!(first.status(), StatusCode::OK);

        let second = app
            .clone()
            .oneshot(get_request("/limited", "req2"))
            .await
            .expect("resp2");
        assert_eq!(second.status(), StatusCode::TOO_MANY_REQUESTS);
    }
}