Skip to main content

lean_ctx/http_server/
mod.rs

1use std::net::SocketAddr;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use anyhow::{anyhow, Context, Result};
6use axum::{
7    extract::Json,
8    extract::Query,
9    extract::State,
10    http::{header, Request, StatusCode},
11    middleware::{self, Next},
12    response::sse::{Event as SseEvent, KeepAlive, Sse},
13    response::{IntoResponse, Response},
14    routing::get,
15    Router,
16};
17use futures::Stream;
18use rmcp::transport::{StreamableHttpServerConfig, StreamableHttpService};
19use serde::Deserialize;
20use serde_json::Value;
21use tokio::sync::broadcast;
22use tokio::time::{Duration, Instant};
23
24use crate::core::context_os::ContextOsMetrics;
25use crate::engine::ContextEngine;
26use crate::tools::LeanCtxServer;
27
28pub mod context_views;
29
30#[cfg(feature = "team-server")]
31pub mod team;
32
33/// Wrapper stream that calls `record_sse_disconnect` on drop.
34use std::pin::Pin;
35
36pub(crate) struct SseDisconnectGuard<I> {
37    pub(crate) inner: Pin<Box<dyn Stream<Item = I> + Send>>,
38    pub(crate) metrics: Arc<ContextOsMetrics>,
39}
40
41impl<I> Stream for SseDisconnectGuard<I> {
42    type Item = I;
43
44    fn poll_next(
45        mut self: Pin<&mut Self>,
46        cx: &mut std::task::Context<'_>,
47    ) -> std::task::Poll<Option<Self::Item>> {
48        self.inner.as_mut().poll_next(cx)
49    }
50}
51
52impl<I> Drop for SseDisconnectGuard<I> {
53    fn drop(&mut self) {
54        self.metrics.record_sse_disconnect();
55    }
56}
57
/// Settings for the HTTP (and Unix-socket) front end.
#[derive(Clone)]
pub struct HttpServerConfig {
    /// Host name or IP address to bind to.
    pub host: String,
    /// TCP port to bind to.
    pub port: u16,
    /// Project root handed to every server instance created for a request or session.
    pub project_root: PathBuf,
    /// Bearer token clients must present; `None` (or an empty string) disables auth.
    pub auth_token: Option<String>,
    /// Forwarded to the MCP transport's stateful-mode setting.
    pub stateful_mode: bool,
    /// Forwarded to the MCP transport's JSON-response setting.
    pub json_response: bool,
    /// When set, the MCP transport's allowed-hosts check is disabled.
    pub disable_host_check: bool,
    /// Explicit allowed-hosts list for the MCP transport; empty keeps its defaults.
    pub allowed_hosts: Vec<String>,
    /// Request body size limit in bytes.
    pub max_body_bytes: usize,
    /// Number of permits in the concurrency semaphore.
    pub max_concurrency: usize,
    /// Token-bucket refill rate, in requests per second.
    pub max_rps: u32,
    /// Token-bucket capacity.
    pub rate_burst: u32,
    /// Timeout, in milliseconds, applied to tool calls.
    pub request_timeout_ms: u64,
}

/// Hand-written so the bearer token never ends up in logs or panic messages:
/// the field is printed as `Some("<redacted>")` / `None` instead of its value.
impl std::fmt::Debug for HttpServerConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let token = self.auth_token.as_ref().map(|_| "<redacted>");
        f.debug_struct("HttpServerConfig")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("project_root", &self.project_root)
            .field("auth_token", &token)
            .field("stateful_mode", &self.stateful_mode)
            .field("json_response", &self.json_response)
            .field("disable_host_check", &self.disable_host_check)
            .field("allowed_hosts", &self.allowed_hosts)
            .field("max_body_bytes", &self.max_body_bytes)
            .field("max_concurrency", &self.max_concurrency)
            .field("max_rps", &self.max_rps)
            .field("rate_burst", &self.rate_burst)
            .field("request_timeout_ms", &self.request_timeout_ms)
            .finish()
    }
}
74
75impl Default for HttpServerConfig {
76    fn default() -> Self {
77        let project_root = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
78        Self {
79            host: "127.0.0.1".to_string(),
80            port: 8080,
81            project_root,
82            auth_token: None,
83            stateful_mode: false,
84            json_response: true,
85            disable_host_check: false,
86            allowed_hosts: Vec::new(),
87            max_body_bytes: 2 * 1024 * 1024,
88            max_concurrency: 32,
89            max_rps: 50,
90            rate_burst: 100,
91            request_timeout_ms: 30_000,
92        }
93    }
94}
95
96impl HttpServerConfig {
97    pub fn validate(&self) -> Result<()> {
98        let host = self.host.trim().to_lowercase();
99        let is_loopback = host == "127.0.0.1" || host == "localhost" || host == "::1";
100        if !is_loopback && self.auth_token.as_deref().unwrap_or("").is_empty() {
101            return Err(anyhow!(
102                "Refusing to bind to host='{host}' without auth. Provide --auth-token (or bind to 127.0.0.1)."
103            ));
104        }
105        Ok(())
106    }
107
108    fn mcp_http_config(&self) -> StreamableHttpServerConfig {
109        let mut cfg = StreamableHttpServerConfig::default()
110            .with_stateful_mode(self.stateful_mode)
111            .with_json_response(self.json_response);
112
113        if self.disable_host_check {
114            cfg = cfg.disable_allowed_hosts();
115            return cfg;
116        }
117
118        if !self.allowed_hosts.is_empty() {
119            cfg = cfg.with_allowed_hosts(self.allowed_hosts.clone());
120            return cfg;
121        }
122
123        // Keep rmcp's secure loopback defaults; also allow the configured host (if it's loopback).
124        let host = self.host.trim();
125        if host == "127.0.0.1" || host == "localhost" || host == "::1" {
126            cfg.allowed_hosts.push(host.to_string());
127        }
128
129        cfg
130    }
131}
132
133#[derive(Clone)]
134struct AppState {
135    token: Option<String>,
136    concurrency: Arc<tokio::sync::Semaphore>,
137    rate: Arc<RateLimiter>,
138    project_root: String,
139    timeout: Duration,
140}
141
142#[derive(Debug)]
143struct RateLimiter {
144    max_rps: f64,
145    burst: f64,
146    state: tokio::sync::Mutex<RateState>,
147}
148
149#[derive(Debug, Clone, Copy)]
150struct RateState {
151    tokens: f64,
152    last: Instant,
153}
154
155impl RateLimiter {
156    fn new(max_rps: u32, burst: u32) -> Self {
157        let now = Instant::now();
158        Self {
159            max_rps: (max_rps.max(1)) as f64,
160            burst: (burst.max(1)) as f64,
161            state: tokio::sync::Mutex::new(RateState {
162                tokens: (burst.max(1)) as f64,
163                last: now,
164            }),
165        }
166    }
167
168    async fn allow(&self) -> bool {
169        let mut s = self.state.lock().await;
170        let now = Instant::now();
171        let elapsed = now.saturating_duration_since(s.last);
172        let refill = elapsed.as_secs_f64() * self.max_rps;
173        s.tokens = (s.tokens + refill).min(self.burst);
174        s.last = now;
175        if s.tokens >= 1.0 {
176            s.tokens -= 1.0;
177            true
178        } else {
179            false
180        }
181    }
182}
183
184async fn auth_middleware(
185    State(state): State<AppState>,
186    req: Request<axum::body::Body>,
187    next: Next,
188) -> Response {
189    if state.token.is_none() {
190        return next.run(req).await;
191    }
192
193    if req.uri().path() == "/health" {
194        return next.run(req).await;
195    }
196
197    let expected = state.token.as_deref().unwrap_or("");
198    let Some(h) = req.headers().get(header::AUTHORIZATION) else {
199        return StatusCode::UNAUTHORIZED.into_response();
200    };
201    let Ok(s) = h.to_str() else {
202        return StatusCode::UNAUTHORIZED.into_response();
203    };
204    let Some(token) = s
205        .strip_prefix("Bearer ")
206        .or_else(|| s.strip_prefix("bearer "))
207    else {
208        return StatusCode::UNAUTHORIZED.into_response();
209    };
210    if !constant_time_eq(token.as_bytes(), expected.as_bytes()) {
211        return StatusCode::UNAUTHORIZED.into_response();
212    }
213
214    next.run(req).await
215}
216
/// Compares two byte slices without short-circuiting on the first differing byte.
///
/// Slices of different length compare unequal immediately; for equal lengths every
/// byte pair is XOR-ed into an accumulator that is checked once at the end.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }

    let mut diff = 0u8;
    for (lhs, rhs) in a.iter().zip(b) {
        diff |= lhs ^ rhs;
    }
    diff == 0
}
226
227async fn rate_limit_middleware(
228    State(state): State<AppState>,
229    req: Request<axum::body::Body>,
230    next: Next,
231) -> Response {
232    if req.uri().path() == "/health" {
233        return next.run(req).await;
234    }
235    if !state.rate.allow().await {
236        return StatusCode::TOO_MANY_REQUESTS.into_response();
237    }
238    next.run(req).await
239}
240
241async fn concurrency_middleware(
242    State(state): State<AppState>,
243    req: Request<axum::body::Body>,
244    next: Next,
245) -> Response {
246    if req.uri().path() == "/health" {
247        return next.run(req).await;
248    }
249    let Ok(permit) = state.concurrency.clone().try_acquire_owned() else {
250        return StatusCode::TOO_MANY_REQUESTS.into_response();
251    };
252    let resp = next.run(req).await;
253    drop(permit);
254    resp
255}
256
257async fn health() -> impl IntoResponse {
258    (StatusCode::OK, "ok\n")
259}
260
261#[derive(Debug, Deserialize)]
262#[serde(rename_all = "camelCase")]
263struct ToolCallBody {
264    name: String,
265    #[serde(default)]
266    arguments: Option<Value>,
267    #[serde(default)]
268    workspace_id: Option<String>,
269    #[serde(default)]
270    channel_id: Option<String>,
271}
272
273#[derive(Debug, Deserialize)]
274#[serde(rename_all = "camelCase")]
275struct EventsQuery {
276    #[serde(default)]
277    workspace_id: Option<String>,
278    #[serde(default)]
279    channel_id: Option<String>,
280    #[serde(default)]
281    since: Option<i64>,
282    #[serde(default)]
283    limit: Option<usize>,
284    /// Comma-separated event kind filter (e.g. `tool_call,session_start`).
285    /// When set, only matching events are delivered via SSE.
286    #[serde(default)]
287    kind: Option<String>,
288}
289
290async fn v1_manifest(State(state): State<AppState>) -> impl IntoResponse {
291    let _ = state;
292    let v = crate::core::mcp_manifest::manifest_value();
293    (StatusCode::OK, Json(v))
294}
295
296#[derive(Debug, Deserialize)]
297#[serde(rename_all = "camelCase")]
298struct ToolsQuery {
299    #[serde(default)]
300    offset: Option<usize>,
301    #[serde(default)]
302    limit: Option<usize>,
303}
304
305async fn v1_tools(State(state): State<AppState>, Query(q): Query<ToolsQuery>) -> impl IntoResponse {
306    let _ = state;
307    let v = crate::core::mcp_manifest::manifest_value();
308    let tools = v
309        .get("tools")
310        .and_then(|t| t.get("granular"))
311        .cloned()
312        .unwrap_or(Value::Array(vec![]));
313
314    let all = tools.as_array().cloned().unwrap_or_default();
315    let total = all.len();
316    let offset = q.offset.unwrap_or(0).min(total);
317    let limit = q.limit.unwrap_or(200).min(500);
318    let page = all.into_iter().skip(offset).take(limit).collect::<Vec<_>>();
319
320    (
321        StatusCode::OK,
322        Json(serde_json::json!({
323            "tools": page,
324            "total": total,
325            "offset": offset,
326            "limit": limit,
327        })),
328    )
329}
330
331async fn v1_tool_call(
332    State(state): State<AppState>,
333    Json(body): Json<ToolCallBody>,
334) -> impl IntoResponse {
335    let ws = body.workspace_id.as_deref().unwrap_or("default");
336    let ch = body.channel_id.as_deref().unwrap_or("default");
337    let server = LeanCtxServer::new_shared_with_context(&state.project_root, ws, ch);
338    let engine = ContextEngine::from_server(server);
339    match tokio::time::timeout(
340        state.timeout,
341        engine.call_tool_value(&body.name, body.arguments),
342    )
343    .await
344    {
345        Ok(Ok(v)) => (StatusCode::OK, Json(serde_json::json!({ "result": v }))).into_response(),
346        Ok(Err(e)) => {
347            tracing::warn!("tool call error: {e}");
348            (
349                StatusCode::BAD_REQUEST,
350                Json(serde_json::json!({ "error": "tool_error", "code": "TOOL_ERROR" })),
351            )
352                .into_response()
353        }
354        Err(_) => (
355            StatusCode::GATEWAY_TIMEOUT,
356            Json(serde_json::json!({ "error": "request_timeout" })),
357        )
358            .into_response(),
359    }
360}
361
362async fn v1_events(
363    State(_state): State<AppState>,
364    Query(q): Query<EventsQuery>,
365) -> Sse<impl Stream<Item = Result<SseEvent, std::convert::Infallible>>> {
366    use crate::core::context_os::{redact_event_payload, ContextEventV1, RedactionLevel};
367
368    let ws = q.workspace_id.unwrap_or_else(|| "default".to_string());
369    let ch = q.channel_id.unwrap_or_else(|| "default".to_string());
370    let since = q.since.unwrap_or(0);
371    let limit = q.limit.unwrap_or(200).min(1000);
372    let redaction = RedactionLevel::RefsOnly;
373
374    let kind_filter: Option<Vec<String>> = q
375        .kind
376        .as_deref()
377        .map(|k| k.split(',').map(|s| s.trim().to_string()).collect());
378
379    let rt = crate::core::context_os::runtime();
380    let replay = rt.bus.read(&ws, &ch, since, limit);
381
382    let replay = if let Some(ref kinds) = kind_filter {
383        replay
384            .into_iter()
385            .filter(|ev| kinds.contains(&ev.kind))
386            .collect()
387    } else {
388        replay
389    };
390
391    let rx = if let Some(ref kinds) = kind_filter {
392        let kind_refs: Vec<&str> = kinds.iter().map(String::as_str).collect();
393        let filter = crate::core::context_os::TopicFilter::kinds(&kind_refs);
394        crate::core::context_os::SubscriptionKind::Filtered(
395            rt.bus.subscribe_filtered(&ws, &ch, filter),
396        )
397    } else {
398        crate::core::context_os::SubscriptionKind::Unfiltered(rt.bus.subscribe(&ws, &ch))
399    };
400
401    rt.metrics.record_sse_connect();
402    rt.metrics.record_events_replayed(replay.len() as u64);
403    rt.metrics.record_workspace_active(&ws);
404
405    let bus = rt.bus.clone();
406    let metrics = rt.metrics.clone();
407    let pending: std::collections::VecDeque<ContextEventV1> = replay.into();
408
409    let stream = futures::stream::unfold(
410        (
411            pending,
412            rx,
413            ws.clone(),
414            ch.clone(),
415            since,
416            redaction,
417            bus,
418            metrics,
419        ),
420        |(mut pending, mut rx, ws, ch, mut last_id, redaction, bus, metrics)| async move {
421            if let Some(mut ev) = pending.pop_front() {
422                last_id = ev.id;
423                redact_event_payload(&mut ev, redaction);
424                let data = serde_json::to_string(&ev).unwrap_or_else(|_| "{}".to_string());
425                let evt = SseEvent::default()
426                    .id(ev.id.to_string())
427                    .event(ev.kind)
428                    .data(data);
429                return Some((
430                    Ok(evt),
431                    (pending, rx, ws, ch, last_id, redaction, bus, metrics),
432                ));
433            }
434
435            loop {
436                match rx.recv().await {
437                    Ok(mut ev) if ev.id > last_id => {
438                        last_id = ev.id;
439                        redact_event_payload(&mut ev, redaction);
440                        let data = serde_json::to_string(&ev).unwrap_or_else(|_| "{}".to_string());
441                        let evt = SseEvent::default()
442                            .id(ev.id.to_string())
443                            .event(ev.kind)
444                            .data(data);
445                        return Some((
446                            Ok(evt),
447                            (pending, rx, ws, ch, last_id, redaction, bus, metrics),
448                        ));
449                    }
450                    Ok(_) => {}
451                    Err(broadcast::error::RecvError::Closed) => return None,
452                    Err(broadcast::error::RecvError::Lagged(skipped)) => {
453                        let missed = bus.read(&ws, &ch, last_id, skipped as usize);
454                        metrics.record_events_replayed(missed.len() as u64);
455                        for ev in missed {
456                            last_id = last_id.max(ev.id);
457                            pending.push_back(ev);
458                        }
459                    }
460                }
461            }
462        },
463    );
464
465    let metrics_ref = rt.metrics.clone();
466    let guarded = SseDisconnectGuard {
467        inner: Box::pin(stream),
468        metrics: metrics_ref,
469    };
470
471    Sse::new(guarded).keep_alive(KeepAlive::new().interval(Duration::from_secs(15)))
472}
473
474async fn v1_metrics(State(_state): State<AppState>) -> impl IntoResponse {
475    let rt = crate::core::context_os::runtime();
476    let snap = rt.metrics.snapshot();
477    (
478        StatusCode::OK,
479        Json(serde_json::to_value(snap).unwrap_or_default()),
480    )
481}
482
483async fn v1_a2a_handoff(
484    State(state): State<AppState>,
485    Json(body): Json<Value>,
486) -> impl IntoResponse {
487    let envelope = match crate::core::a2a_transport::parse_envelope(
488        &serde_json::to_string(&body).unwrap_or_default(),
489    ) {
490        Ok(env) => env,
491        Err(e) => {
492            return (
493                StatusCode::BAD_REQUEST,
494                Json(serde_json::json!({"error": e})),
495            );
496        }
497    };
498
499    let rt = crate::core::context_os::runtime();
500    rt.bus.append(
501        &state.project_root,
502        "a2a",
503        &crate::core::context_os::ContextEventKindV1::SessionMutated,
504        Some(&envelope.sender.agent_id),
505        serde_json::json!({
506            "type": "handoff_received",
507            "content_type": format!("{:?}", envelope.content_type),
508            "sender": envelope.sender.agent_id,
509            "payload_size": envelope.payload_json.len(),
510        }),
511    );
512
513    match envelope.content_type {
514        crate::core::a2a_transport::TransportContentType::ContextPackage => {
515            let tmp = std::env::temp_dir().join(format!(
516                "lean-ctx-a2a-{}.lctxpkg",
517                chrono::Utc::now().format("%Y%m%d_%H%M%S")
518            ));
519            if let Err(e) = std::fs::write(&tmp, &envelope.payload_json) {
520                return (
521                    StatusCode::INTERNAL_SERVER_ERROR,
522                    Json(serde_json::json!({"error": format!("write: {e}")})),
523                );
524            }
525            (
526                StatusCode::OK,
527                Json(serde_json::json!({
528                    "status": "received",
529                    "content_type": "context_package",
530                    "stored": tmp.display().to_string(),
531                })),
532            )
533        }
534        crate::core::a2a_transport::TransportContentType::HandoffBundle => {
535            let dir = std::path::Path::new(&state.project_root)
536                .join(".lean-ctx")
537                .join("handoffs");
538            let _ = std::fs::create_dir_all(&dir);
539            let out = dir.join(format!(
540                "received-{}.json",
541                chrono::Utc::now().format("%Y%m%d_%H%M%S")
542            ));
543            if let Err(e) = std::fs::write(&out, &envelope.payload_json) {
544                return (
545                    StatusCode::INTERNAL_SERVER_ERROR,
546                    Json(serde_json::json!({"error": format!("write: {e}")})),
547                );
548            }
549            (
550                StatusCode::OK,
551                Json(serde_json::json!({
552                    "status": "received",
553                    "content_type": "handoff_bundle",
554                    "stored": out.display().to_string(),
555                })),
556            )
557        }
558        _ => (
559            StatusCode::OK,
560            Json(serde_json::json!({
561                "status": "received",
562                "content_type": format!("{:?}", envelope.content_type),
563            })),
564        ),
565    }
566}
567
568async fn a2a_jsonrpc(Json(body): Json<Value>) -> impl IntoResponse {
569    let req: crate::core::a2a::a2a_compat::JsonRpcRequest = match serde_json::from_value(body) {
570        Ok(r) => r,
571        Err(e) => {
572            return (
573                StatusCode::BAD_REQUEST,
574                Json(serde_json::json!({
575                    "jsonrpc": "2.0",
576                    "id": null,
577                    "error": {"code": -32700, "message": format!("parse error: {e}")}
578                })),
579            );
580        }
581    };
582    let resp = crate::core::a2a::a2a_compat::handle_a2a_jsonrpc(&req);
583    let json = serde_json::to_value(resp).unwrap_or_default();
584    (StatusCode::OK, Json(json))
585}
586
587async fn v1_a2a_agent_card(State(state): State<AppState>) -> impl IntoResponse {
588    let card = crate::core::a2a::agent_card::build_agent_card(&state.project_root);
589    (
590        StatusCode::OK,
591        [(header::CONTENT_TYPE, "application/json")],
592        Json(card),
593    )
594}
595
596pub async fn serve(cfg: HttpServerConfig) -> Result<()> {
597    cfg.validate()?;
598
599    let addr: SocketAddr = format!("{}:{}", cfg.host, cfg.port)
600        .parse()
601        .context("invalid host/port")?;
602
603    let project_root = cfg.project_root.to_string_lossy().to_string();
604    // IMPORTANT: Create a fresh server per MCP session in *shared* mode.
605    // This avoids per-client state clobbering while still sharing the Context OS store.
606    let service_project_root = project_root.clone();
607    let service_factory = move || -> Result<LeanCtxServer, std::io::Error> {
608        Ok(LeanCtxServer::new_shared_with_context(
609            &service_project_root,
610            "default",
611            "default",
612        ))
613    };
614    let mcp_http = StreamableHttpService::new(
615        service_factory,
616        Arc::new(
617            rmcp::transport::streamable_http_server::session::local::LocalSessionManager::default(),
618        ),
619        cfg.mcp_http_config(),
620    );
621
622    let state = AppState {
623        token: cfg.auth_token.clone().filter(|t| !t.is_empty()),
624        concurrency: Arc::new(tokio::sync::Semaphore::new(cfg.max_concurrency.max(1))),
625        rate: Arc::new(RateLimiter::new(cfg.max_rps, cfg.rate_burst)),
626        project_root: project_root.clone(),
627        timeout: Duration::from_millis(cfg.request_timeout_ms.max(1)),
628    };
629
630    let app = Router::new()
631        .route("/health", get(health))
632        .route("/v1/manifest", get(v1_manifest))
633        .route("/v1/tools", get(v1_tools))
634        .route("/v1/tools/call", axum::routing::post(v1_tool_call))
635        .route("/v1/events", get(v1_events))
636        .route(
637            "/v1/context/summary",
638            get(context_views::v1_context_summary),
639        )
640        .route("/v1/events/search", get(context_views::v1_events_search))
641        .route("/v1/events/lineage", get(context_views::v1_event_lineage))
642        .route("/v1/metrics", get(v1_metrics))
643        .route("/v1/a2a/handoff", axum::routing::post(v1_a2a_handoff))
644        .route("/v1/a2a/agent-card", get(v1_a2a_agent_card))
645        .route("/.well-known/agent.json", get(v1_a2a_agent_card))
646        .route("/a2a", axum::routing::post(a2a_jsonrpc))
647        .fallback_service(mcp_http)
648        .layer(axum::extract::DefaultBodyLimit::max(cfg.max_body_bytes))
649        .layer(middleware::from_fn_with_state(
650            state.clone(),
651            rate_limit_middleware,
652        ))
653        .layer(middleware::from_fn_with_state(
654            state.clone(),
655            concurrency_middleware,
656        ))
657        .layer(middleware::from_fn_with_state(
658            state.clone(),
659            auth_middleware,
660        ))
661        .with_state(state);
662
663    let listener = tokio::net::TcpListener::bind(addr)
664        .await
665        .with_context(|| format!("bind {addr}"))?;
666
667    tracing::info!(
668        "lean-ctx Streamable HTTP server listening on http://{addr} (project_root={})",
669        cfg.project_root.display()
670    );
671
672    axum::serve(listener, app)
673        .with_graceful_shutdown(async move {
674            let _ = tokio::signal::ctrl_c().await;
675        })
676        .await
677        .context("http server")?;
678    Ok(())
679}
680
681#[cfg(unix)]
682pub async fn serve_uds(cfg: HttpServerConfig, socket_path: PathBuf) -> Result<()> {
683    cfg.validate()?;
684
685    if socket_path.exists() {
686        std::fs::remove_file(&socket_path)
687            .with_context(|| format!("remove stale socket {}", socket_path.display()))?;
688    }
689
690    let project_root = cfg.project_root.to_string_lossy().to_string();
691    let service_project_root = project_root.clone();
692    let service_factory = move || -> Result<LeanCtxServer, std::io::Error> {
693        Ok(LeanCtxServer::new_shared_with_context(
694            &service_project_root,
695            "default",
696            "default",
697        ))
698    };
699    let mcp_http = StreamableHttpService::new(
700        service_factory,
701        Arc::new(
702            rmcp::transport::streamable_http_server::session::local::LocalSessionManager::default(),
703        ),
704        cfg.mcp_http_config(),
705    );
706
707    let state = AppState {
708        token: cfg.auth_token.clone().filter(|t| !t.is_empty()),
709        concurrency: Arc::new(tokio::sync::Semaphore::new(cfg.max_concurrency.max(1))),
710        rate: Arc::new(RateLimiter::new(cfg.max_rps, cfg.rate_burst)),
711        project_root: project_root.clone(),
712        timeout: Duration::from_millis(cfg.request_timeout_ms.max(1)),
713    };
714
715    let app = Router::new()
716        .route("/health", get(health))
717        .route("/v1/manifest", get(v1_manifest))
718        .route("/v1/tools", get(v1_tools))
719        .route("/v1/tools/call", axum::routing::post(v1_tool_call))
720        .route("/v1/events", get(v1_events))
721        .route(
722            "/v1/context/summary",
723            get(context_views::v1_context_summary),
724        )
725        .route("/v1/events/search", get(context_views::v1_events_search))
726        .route("/v1/events/lineage", get(context_views::v1_event_lineage))
727        .route("/v1/metrics", get(v1_metrics))
728        .route("/v1/a2a/handoff", axum::routing::post(v1_a2a_handoff))
729        .route("/v1/a2a/agent-card", get(v1_a2a_agent_card))
730        .route("/.well-known/agent.json", get(v1_a2a_agent_card))
731        .route("/a2a", axum::routing::post(a2a_jsonrpc))
732        .fallback_service(mcp_http)
733        .layer(axum::extract::DefaultBodyLimit::max(cfg.max_body_bytes))
734        .layer(middleware::from_fn_with_state(
735            state.clone(),
736            rate_limit_middleware,
737        ))
738        .layer(middleware::from_fn_with_state(
739            state.clone(),
740            concurrency_middleware,
741        ))
742        .layer(middleware::from_fn_with_state(
743            state.clone(),
744            auth_middleware,
745        ))
746        .with_state(state);
747
748    let listener = tokio::net::UnixListener::bind(&socket_path)
749        .with_context(|| format!("bind UDS {}", socket_path.display()))?;
750
751    {
752        use std::os::unix::fs::PermissionsExt;
753        let perms = std::fs::Permissions::from_mode(0o600);
754        std::fs::set_permissions(&socket_path, perms)
755            .with_context(|| format!("chmod 600 UDS {}", socket_path.display()))?;
756    }
757
758    tracing::info!(
759        "lean-ctx daemon listening on {} (project_root={})",
760        socket_path.display(),
761        cfg.project_root.display()
762    );
763
764    axum::serve(listener, app.into_make_service())
765        .with_graceful_shutdown(async move {
766            let _ = tokio::signal::ctrl_c().await;
767        })
768        .await
769        .context("uds server")?;
770    Ok(())
771}
772
773#[cfg(test)]
774mod tests {
775    use super::*;
776    use axum::body::Body;
777    use axum::http::Request;
778    use futures::StreamExt;
779    use rmcp::transport::{StreamableHttpServerConfig, StreamableHttpService};
780    use serde_json::json;
781    use tower::ServiceExt;
782
783    async fn read_first_sse_message(body: Body) -> String {
784        let mut stream = body.into_data_stream();
785        let mut buf: Vec<u8> = Vec::new();
786        for _ in 0..32 {
787            let next = tokio::time::timeout(Duration::from_secs(2), stream.next()).await;
788            let Ok(Some(Ok(bytes))) = next else {
789                break;
790            };
791            buf.extend_from_slice(&bytes);
792            if buf.windows(2).any(|w| w == b"\n\n") {
793                break;
794            }
795        }
796        String::from_utf8_lossy(&buf).to_string()
797    }
798
799    #[tokio::test]
800    async fn auth_token_blocks_requests_without_bearer_header() {
801        let dir = tempfile::tempdir().expect("tempdir");
802        let root_str = dir.path().to_string_lossy().to_string();
803        let service_project_root = root_str.clone();
804        let service_factory = move || -> Result<LeanCtxServer, std::io::Error> {
805            Ok(LeanCtxServer::new_shared_with_context(
806                &service_project_root,
807                "default",
808                "default",
809            ))
810        };
811        let cfg = StreamableHttpServerConfig::default()
812            .with_stateful_mode(false)
813            .with_json_response(true);
814
815        let mcp_http = StreamableHttpService::new(
816            service_factory,
817            Arc::new(
818                rmcp::transport::streamable_http_server::session::local::LocalSessionManager::default(),
819            ),
820            cfg,
821        );
822
823        let state = AppState {
824            token: Some("secret".to_string()),
825            concurrency: Arc::new(tokio::sync::Semaphore::new(4)),
826            rate: Arc::new(RateLimiter::new(50, 100)),
827            project_root: root_str.clone(),
828            timeout: Duration::from_millis(30_000),
829        };
830
831        let app = Router::new()
832            .fallback_service(mcp_http)
833            .layer(middleware::from_fn_with_state(
834                state.clone(),
835                auth_middleware,
836            ))
837            .with_state(state);
838
839        let body = json!({
840            "jsonrpc": "2.0",
841            "id": 1,
842            "method": "tools/list",
843            "params": {}
844        })
845        .to_string();
846
847        let req = Request::builder()
848            .method("POST")
849            .uri("/")
850            .header("Host", "localhost")
851            .header("Accept", "application/json, text/event-stream")
852            .header("Content-Type", "application/json")
853            .body(Body::from(body))
854            .expect("request");
855
856        let resp = app.clone().oneshot(req).await.expect("resp");
857        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
858    }
859
860    #[tokio::test]
861    async fn mcp_service_factory_isolates_per_client_state() {
862        let dir = tempfile::tempdir().expect("tempdir");
863        let root_str = dir.path().to_string_lossy().to_string();
864
865        // Mirrors the serve() setup: service_factory must create a fresh server per MCP session.
866        let service_project_root = root_str.clone();
867        let service_factory = move || -> Result<LeanCtxServer, std::convert::Infallible> {
868            Ok(LeanCtxServer::new_shared_with_context(
869                &service_project_root,
870                "default",
871                "default",
872            ))
873        };
874
875        let s1 = service_factory().expect("server 1");
876        let s2 = service_factory().expect("server 2");
877
878        // If the two servers accidentally share the same Arc-backed fields, these writes would
879        // clobber each other. This test stays independent of rmcp's InitializeRequestParams API.
880        *s1.client_name.write().await = "client-a".to_string();
881        *s2.client_name.write().await = "client-b".to_string();
882
883        let a = s1.client_name.read().await.clone();
884        let b = s2.client_name.read().await.clone();
885        assert_eq!(a, "client-a");
886        assert_eq!(b, "client-b");
887    }
888
889    #[tokio::test]
890    async fn rate_limit_returns_429_when_exhausted() {
891        let state = AppState {
892            token: None,
893            concurrency: Arc::new(tokio::sync::Semaphore::new(16)),
894            rate: Arc::new(RateLimiter::new(1, 1)),
895            project_root: ".".to_string(),
896            timeout: Duration::from_millis(30_000),
897        };
898
899        let app = Router::new()
900            .route("/limited", get(|| async { (StatusCode::OK, "ok\n") }))
901            .layer(middleware::from_fn_with_state(
902                state.clone(),
903                rate_limit_middleware,
904            ))
905            .with_state(state);
906
907        let req1 = Request::builder()
908            .method("GET")
909            .uri("/limited")
910            .header("Host", "localhost")
911            .body(Body::empty())
912            .expect("req1");
913        let resp1 = app.clone().oneshot(req1).await.expect("resp1");
914        assert_eq!(resp1.status(), StatusCode::OK);
915
916        let req2 = Request::builder()
917            .method("GET")
918            .uri("/limited")
919            .header("Host", "localhost")
920            .body(Body::empty())
921            .expect("req2");
922        let resp2 = app.clone().oneshot(req2).await.expect("resp2");
923        assert_eq!(resp2.status(), StatusCode::TOO_MANY_REQUESTS);
924    }
925
926    #[tokio::test]
927    async fn events_endpoint_replays_tool_call_event() {
928        let dir = tempfile::tempdir().expect("tempdir");
929        std::fs::create_dir_all(dir.path().join(".git")).expect("git marker");
930        std::fs::write(dir.path().join("a.txt"), "ok").expect("file");
931        let root_str = dir.path().to_string_lossy().to_string();
932
933        let state = AppState {
934            token: None,
935            concurrency: Arc::new(tokio::sync::Semaphore::new(16)),
936            rate: Arc::new(RateLimiter::new(50, 100)),
937            project_root: root_str.clone(),
938            timeout: Duration::from_millis(30_000),
939        };
940
941        let app = Router::new()
942            .route("/v1/tools/call", axum::routing::post(v1_tool_call))
943            .route("/v1/events", get(v1_events))
944            .with_state(state);
945
946        let body = json!({
947            "name": "ctx_session",
948            "arguments": { "action": "status" },
949            "workspaceId": "ws1",
950            "channelId": "ch1"
951        })
952        .to_string();
953        let req = Request::builder()
954            .method("POST")
955            .uri("/v1/tools/call")
956            .header("Host", "localhost")
957            .header("Content-Type", "application/json")
958            .body(Body::from(body))
959            .expect("req");
960        let resp = app.clone().oneshot(req).await.expect("call");
961        assert_eq!(resp.status(), StatusCode::OK);
962
963        // Allow async event persistence to complete (Windows CI disk IO is slower).
964        tokio::time::sleep(Duration::from_millis(250)).await;
965
966        // Subscribe with replay semantics; read the first SSE message.
967        let req = Request::builder()
968            .method("GET")
969            .uri("/v1/events?workspaceId=ws1&channelId=ch1&since=0&limit=1")
970            .header("Host", "localhost")
971            .header("Accept", "text/event-stream")
972            .body(Body::empty())
973            .expect("req");
974        let resp = app.clone().oneshot(req).await.expect("events");
975        assert_eq!(resp.status(), StatusCode::OK);
976
977        let msg = read_first_sse_message(resp.into_body()).await;
978        assert!(msg.contains("event: tool_call_recorded"), "msg={msg:?}");
979        assert!(msg.contains("\"workspaceId\":\"ws1\""), "msg={msg:?}");
980        assert!(msg.contains("\"channelId\":\"ch1\""), "msg={msg:?}");
981    }
982}