Skip to main content

lean_ctx/http_server/
mod.rs

1use std::net::SocketAddr;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use anyhow::{anyhow, Context, Result};
6use axum::{
7    extract::Json,
8    extract::Query,
9    extract::State,
10    http::{header, Request, StatusCode},
11    middleware::{self, Next},
12    response::sse::{Event as SseEvent, KeepAlive, Sse},
13    response::{IntoResponse, Response},
14    routing::get,
15    Router,
16};
17use futures::Stream;
18use rmcp::transport::{StreamableHttpServerConfig, StreamableHttpService};
19use serde::Deserialize;
20use serde_json::Value;
21use tokio::sync::broadcast;
22use tokio::time::{Duration, Instant};
23
24use crate::core::context_os::ContextOsMetrics;
25use crate::engine::ContextEngine;
26use crate::tools::LeanCtxServer;
27
28pub mod context_views;
29
30#[cfg(feature = "team-server")]
31pub mod team;
32
33/// Wrapper stream that calls `record_sse_disconnect` on drop.
34use std::pin::Pin;
35
36pub(crate) struct SseDisconnectGuard<I> {
37    pub(crate) inner: Pin<Box<dyn Stream<Item = I> + Send>>,
38    pub(crate) metrics: Arc<ContextOsMetrics>,
39}
40
41impl<I> Stream for SseDisconnectGuard<I> {
42    type Item = I;
43
44    fn poll_next(
45        mut self: Pin<&mut Self>,
46        cx: &mut std::task::Context<'_>,
47    ) -> std::task::Poll<Option<Self::Item>> {
48        self.inner.as_mut().poll_next(cx)
49    }
50}
51
52impl<I> Drop for SseDisconnectGuard<I> {
53    fn drop(&mut self) {
54        self.metrics.record_sse_disconnect();
55    }
56}
57
/// Maximum number of characters kept in a sanitized identifier.
const MAX_ID_LEN: usize = 64;

/// Normalizes a caller-supplied identifier.
///
/// Surrounding whitespace is ignored, every character other than ASCII
/// alphanumerics, `-`, `_` and `.` is dropped, and the result is truncated to
/// `MAX_ID_LEN` characters. If nothing survives, `"default"` is returned.
fn sanitize_id(raw: &str) -> String {
    let mut id = String::new();
    for ch in raw.trim().chars() {
        // Only ASCII characters are ever pushed, so byte length equals character count.
        if id.len() == MAX_ID_LEN {
            break;
        }
        if ch.is_ascii_alphanumeric() || matches!(ch, '-' | '_' | '.') {
            id.push(ch);
        }
    }
    if id.is_empty() {
        id.push_str("default");
    }
    id
}
76
/// Settings for the HTTP / Unix-socket server.
#[derive(Clone, Debug)]
pub struct HttpServerConfig {
    /// Host to bind to; also inspected by `validate` for the loopback check.
    pub host: String,
    /// TCP port to bind to.
    pub port: u16,
    /// Project root handed to every server instance that is created.
    pub project_root: PathBuf,
    /// Optional bearer token; `None` or an empty string disables authentication.
    pub auth_token: Option<String>,
    /// Forwarded to the MCP transport configuration (`with_stateful_mode`).
    pub stateful_mode: bool,
    /// Forwarded to the MCP transport configuration (`with_json_response`).
    pub json_response: bool,
    /// Turns off the MCP transport's allowed-hosts check.
    pub disable_host_check: bool,
    /// Explicit allowed-hosts list; when non-empty it replaces the transport defaults.
    pub allowed_hosts: Vec<String>,
    /// Request body size limit in bytes.
    pub max_body_bytes: usize,
    /// Number of requests allowed in flight at once.
    pub max_concurrency: usize,
    /// Token-bucket refill rate, in requests per second.
    pub max_rps: u32,
    /// Token-bucket capacity.
    pub rate_burst: u32,
    /// Timeout applied to tool calls, in milliseconds.
    pub request_timeout_ms: u64,
}

impl Default for HttpServerConfig {
    /// Loopback-only defaults rooted at the current working directory (or `.`
    /// when the working directory cannot be determined).
    fn default() -> Self {
        let cwd = match std::env::current_dir() {
            Ok(dir) => dir,
            Err(_) => PathBuf::from("."),
        };
        Self {
            host: String::from("127.0.0.1"),
            port: 8080,
            project_root: cwd,
            auth_token: None,
            stateful_mode: false,
            json_response: true,
            disable_host_check: false,
            allowed_hosts: Vec::new(),
            max_body_bytes: 2 * 1024 * 1024,
            max_concurrency: 32,
            max_rps: 50,
            rate_burst: 100,
            request_timeout_ms: 30_000,
        }
    }
}
114
115impl HttpServerConfig {
116    pub fn validate(&self) -> Result<()> {
117        let host = self.host.trim().to_lowercase();
118        let is_loopback = host == "127.0.0.1" || host == "localhost" || host == "::1";
119        if !is_loopback && self.auth_token.as_deref().unwrap_or("").is_empty() {
120            return Err(anyhow!(
121                "Refusing to bind to host='{host}' without auth. Provide --auth-token (or bind to 127.0.0.1)."
122            ));
123        }
124        Ok(())
125    }
126
127    fn mcp_http_config(&self) -> StreamableHttpServerConfig {
128        let mut cfg = StreamableHttpServerConfig::default()
129            .with_stateful_mode(self.stateful_mode)
130            .with_json_response(self.json_response);
131
132        if self.disable_host_check {
133            tracing::warn!(
134                "⚠ --disable-host-check is active: DNS rebinding protection is OFF. \
135                 Do NOT use this in production or on non-loopback interfaces."
136            );
137            cfg = cfg.disable_allowed_hosts();
138            return cfg;
139        }
140
141        if !self.allowed_hosts.is_empty() {
142            cfg = cfg.with_allowed_hosts(self.allowed_hosts.clone());
143            return cfg;
144        }
145
146        // Keep rmcp's secure loopback defaults; also allow the configured host (if it's loopback).
147        let host = self.host.trim();
148        if host == "127.0.0.1" || host == "localhost" || host == "::1" {
149            cfg.allowed_hosts.push(host.to_string());
150        }
151
152        cfg
153    }
154}
155
/// Shared state handed to the middleware and route handlers.
#[derive(Clone)]
struct AppState {
    /// Expected bearer token; `None` means requests are not authenticated.
    token: Option<String>,
    /// Permits for the in-flight request cap (see `concurrency_middleware`).
    concurrency: Arc<tokio::sync::Semaphore>,
    /// Token bucket consulted by `rate_limit_middleware`.
    rate: Arc<RateLimiter>,
    /// Project root, stored as a string, passed to per-request servers and handlers.
    project_root: String,
    /// Upper bound on how long a tool call may run (see `v1_tool_call`).
    timeout: Duration,
}
164
/// Token-bucket rate limiter whose mutable state sits behind an async mutex.
#[derive(Debug)]
struct RateLimiter {
    /// Tokens added per second of elapsed time.
    max_rps: f64,
    /// Maximum number of tokens the bucket can hold.
    burst: f64,
    /// Current token count and the instant it was last refilled.
    state: tokio::sync::Mutex<RateState>,
}
171
/// Mutable part of the token bucket.
#[derive(Debug, Clone, Copy)]
struct RateState {
    /// Tokens currently available (fractional while refilling).
    tokens: f64,
    /// Instant of the most recent refill.
    last: Instant,
}
177
178impl RateLimiter {
179    fn new(max_rps: u32, burst: u32) -> Self {
180        let now = Instant::now();
181        Self {
182            max_rps: (max_rps.max(1)) as f64,
183            burst: (burst.max(1)) as f64,
184            state: tokio::sync::Mutex::new(RateState {
185                tokens: (burst.max(1)) as f64,
186                last: now,
187            }),
188        }
189    }
190
191    async fn allow(&self) -> bool {
192        let mut s = self.state.lock().await;
193        let now = Instant::now();
194        let elapsed = now.saturating_duration_since(s.last);
195        let refill = elapsed.as_secs_f64() * self.max_rps;
196        s.tokens = (s.tokens + refill).min(self.burst);
197        s.last = now;
198        if s.tokens >= 1.0 {
199            s.tokens -= 1.0;
200            true
201        } else {
202            false
203        }
204    }
205}
206
207async fn auth_middleware(
208    State(state): State<AppState>,
209    req: Request<axum::body::Body>,
210    next: Next,
211) -> Response {
212    if state.token.is_none() {
213        return next.run(req).await;
214    }
215
216    if req.uri().path() == "/health" {
217        return next.run(req).await;
218    }
219
220    let expected = state.token.as_deref().unwrap_or("");
221    let Some(h) = req.headers().get(header::AUTHORIZATION) else {
222        return StatusCode::UNAUTHORIZED.into_response();
223    };
224    let Ok(s) = h.to_str() else {
225        return StatusCode::UNAUTHORIZED.into_response();
226    };
227    let Some(token) = s
228        .strip_prefix("Bearer ")
229        .or_else(|| s.strip_prefix("bearer "))
230    else {
231        return StatusCode::UNAUTHORIZED.into_response();
232    };
233    if !constant_time_eq(token.as_bytes(), expected.as_bytes()) {
234        return StatusCode::UNAUTHORIZED.into_response();
235    }
236
237    next.run(req).await
238}
239
240fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
241    use subtle::ConstantTimeEq;
242    if a.len() != b.len() {
243        return false;
244    }
245    bool::from(a.ct_eq(b))
246}
247
248async fn rate_limit_middleware(
249    State(state): State<AppState>,
250    req: Request<axum::body::Body>,
251    next: Next,
252) -> Response {
253    if !state.rate.allow().await {
254        return StatusCode::TOO_MANY_REQUESTS.into_response();
255    }
256    next.run(req).await
257}
258
259async fn concurrency_middleware(
260    State(state): State<AppState>,
261    req: Request<axum::body::Body>,
262    next: Next,
263) -> Response {
264    let Ok(permit) = state.concurrency.clone().try_acquire_owned() else {
265        return StatusCode::TOO_MANY_REQUESTS.into_response();
266    };
267    let resp = next.run(req).await;
268    drop(permit);
269    resp
270}
271
272async fn health() -> impl IntoResponse {
273    (StatusCode::OK, "ok\n")
274}
275
/// JSON body accepted by `POST /v1/tools/call`.
///
/// Field names are camelCase on the wire (`workspaceId`, `channelId`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ToolCallBody {
    /// Name of the tool to invoke.
    name: String,
    /// Optional tool arguments, passed through to the engine as-is.
    #[serde(default)]
    arguments: Option<Value>,
    /// Workspace identifier; sanitized by the handler, `"default"` when absent.
    #[serde(default)]
    workspace_id: Option<String>,
    /// Channel identifier; sanitized by the handler, `"default"` when absent.
    #[serde(default)]
    channel_id: Option<String>,
}
287
/// Query parameters accepted by `GET /v1/events`.
///
/// Parameter names are camelCase on the wire (`workspaceId`, `channelId`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct EventsQuery {
    /// Workspace identifier; sanitized by the handler, `"default"` when absent.
    #[serde(default)]
    workspace_id: Option<String>,
    /// Channel identifier; sanitized by the handler, `"default"` when absent.
    #[serde(default)]
    channel_id: Option<String>,
    /// Event id used as the replay starting point (handler default: 0).
    #[serde(default)]
    since: Option<i64>,
    /// Maximum number of replayed events (handler default: 200, capped at 1000).
    #[serde(default)]
    limit: Option<usize>,
    /// Comma-separated event kind filter (e.g. `tool_call,session_start`).
    /// When set, only matching events are delivered via SSE.
    #[serde(default)]
    kind: Option<String>,
}
304
305async fn v1_manifest(State(state): State<AppState>) -> impl IntoResponse {
306    let _ = state;
307    let v = crate::core::mcp_manifest::manifest_value();
308    (StatusCode::OK, Json(v))
309}
310
/// Pagination parameters accepted by `GET /v1/tools`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ToolsQuery {
    /// Index of the first tool to return (handler default: 0).
    #[serde(default)]
    offset: Option<usize>,
    /// Page size (handler default: 200, capped at 500).
    #[serde(default)]
    limit: Option<usize>,
}
319
320async fn v1_tools(State(state): State<AppState>, Query(q): Query<ToolsQuery>) -> impl IntoResponse {
321    let _ = state;
322    let v = crate::core::mcp_manifest::manifest_value();
323    let tools = v
324        .get("tools")
325        .and_then(|t| t.get("granular"))
326        .cloned()
327        .unwrap_or(Value::Array(vec![]));
328
329    let all = tools.as_array().cloned().unwrap_or_default();
330    let total = all.len();
331    let offset = q.offset.unwrap_or(0).min(total);
332    let limit = q.limit.unwrap_or(200).min(500);
333    let page = all.into_iter().skip(offset).take(limit).collect::<Vec<_>>();
334
335    (
336        StatusCode::OK,
337        Json(serde_json::json!({
338            "tools": page,
339            "total": total,
340            "offset": offset,
341            "limit": limit,
342        })),
343    )
344}
345
346async fn v1_tool_call(
347    State(state): State<AppState>,
348    Json(body): Json<ToolCallBody>,
349) -> impl IntoResponse {
350    let ws = sanitize_id(body.workspace_id.as_deref().unwrap_or("default"));
351    let ch = sanitize_id(body.channel_id.as_deref().unwrap_or("default"));
352    let ws = ws.as_str();
353    let ch = ch.as_str();
354    let server = LeanCtxServer::new_shared_with_context(&state.project_root, ws, ch);
355    let engine = ContextEngine::from_server(server);
356    match tokio::time::timeout(
357        state.timeout,
358        engine.call_tool_value(&body.name, body.arguments),
359    )
360    .await
361    {
362        Ok(Ok(v)) => (StatusCode::OK, Json(serde_json::json!({ "result": v }))).into_response(),
363        Ok(Err(e)) => {
364            tracing::warn!("tool call error: {e}");
365            (
366                StatusCode::BAD_REQUEST,
367                Json(serde_json::json!({ "error": "tool_error", "code": "TOOL_ERROR" })),
368            )
369                .into_response()
370        }
371        Err(_) => (
372            StatusCode::GATEWAY_TIMEOUT,
373            Json(serde_json::json!({ "error": "request_timeout" })),
374        )
375            .into_response(),
376    }
377}
378
379async fn v1_events(
380    State(state): State<AppState>,
381    Query(q): Query<EventsQuery>,
382) -> Sse<impl Stream<Item = Result<SseEvent, std::convert::Infallible>>> {
383    use crate::core::context_os::{redact_event_payload, ContextEventV1, RedactionLevel};
384
385    let ws = sanitize_id(&q.workspace_id.unwrap_or_else(|| "default".to_string()));
386    let ch = sanitize_id(&q.channel_id.unwrap_or_else(|| "default".to_string()));
387    let _ = &state.project_root;
388    let since = q.since.unwrap_or(0);
389    let limit = q.limit.unwrap_or(200).min(1000);
390    let redaction = RedactionLevel::RefsOnly;
391
392    let kind_filter: Option<Vec<String>> = q
393        .kind
394        .as_deref()
395        .map(|k| k.split(',').map(|s| s.trim().to_string()).collect());
396
397    let rt = crate::core::context_os::runtime();
398    let replay = rt.bus.read(&ws, &ch, since, limit);
399
400    let replay = if let Some(ref kinds) = kind_filter {
401        replay
402            .into_iter()
403            .filter(|ev| kinds.contains(&ev.kind))
404            .collect()
405    } else {
406        replay
407    };
408
409    let rx = if let Some(ref kinds) = kind_filter {
410        let kind_refs: Vec<&str> = kinds.iter().map(String::as_str).collect();
411        let filter = crate::core::context_os::TopicFilter::kinds(&kind_refs);
412        if let Some(sub) = rt.bus.subscribe_filtered(&ws, &ch, filter) {
413            crate::core::context_os::SubscriptionKind::Filtered(sub)
414        } else {
415            tracing::warn!("SSE subscriber limit reached for {ws}/{ch}");
416            let (_, rx) = broadcast::channel::<ContextEventV1>(1);
417            crate::core::context_os::SubscriptionKind::Unfiltered(rx)
418        }
419    } else if let Some(sub) = rt.bus.subscribe(&ws, &ch) {
420        crate::core::context_os::SubscriptionKind::Unfiltered(sub)
421    } else {
422        tracing::warn!("SSE subscriber limit reached for {ws}/{ch}");
423        let (_, rx) = broadcast::channel::<ContextEventV1>(1);
424        crate::core::context_os::SubscriptionKind::Unfiltered(rx)
425    };
426
427    rt.metrics.record_sse_connect();
428    rt.metrics.record_events_replayed(replay.len() as u64);
429    rt.metrics.record_workspace_active(&ws);
430
431    let bus = rt.bus.clone();
432    let metrics = rt.metrics.clone();
433    let pending: std::collections::VecDeque<ContextEventV1> = replay.into();
434
435    let stream = futures::stream::unfold(
436        (
437            pending,
438            rx,
439            ws.clone(),
440            ch.clone(),
441            since,
442            redaction,
443            bus,
444            metrics,
445        ),
446        |(mut pending, mut rx, ws, ch, mut last_id, redaction, bus, metrics)| async move {
447            if let Some(mut ev) = pending.pop_front() {
448                last_id = ev.id;
449                redact_event_payload(&mut ev, redaction);
450                let data = serde_json::to_string(&ev).unwrap_or_else(|_| "{}".to_string());
451                let evt = SseEvent::default()
452                    .id(ev.id.to_string())
453                    .event(ev.kind)
454                    .data(data);
455                return Some((
456                    Ok(evt),
457                    (pending, rx, ws, ch, last_id, redaction, bus, metrics),
458                ));
459            }
460
461            loop {
462                match rx.recv().await {
463                    Ok(mut ev) if ev.id > last_id => {
464                        last_id = ev.id;
465                        redact_event_payload(&mut ev, redaction);
466                        let data = serde_json::to_string(&ev).unwrap_or_else(|_| "{}".to_string());
467                        let evt = SseEvent::default()
468                            .id(ev.id.to_string())
469                            .event(ev.kind)
470                            .data(data);
471                        return Some((
472                            Ok(evt),
473                            (pending, rx, ws, ch, last_id, redaction, bus, metrics),
474                        ));
475                    }
476                    Ok(_) => {}
477                    Err(broadcast::error::RecvError::Closed) => return None,
478                    Err(broadcast::error::RecvError::Lagged(skipped)) => {
479                        let missed = bus.read(&ws, &ch, last_id, skipped as usize);
480                        metrics.record_events_replayed(missed.len() as u64);
481                        for ev in missed {
482                            last_id = last_id.max(ev.id);
483                            pending.push_back(ev);
484                        }
485                    }
486                }
487            }
488        },
489    );
490
491    let metrics_ref = rt.metrics.clone();
492    let guarded = SseDisconnectGuard {
493        inner: Box::pin(stream),
494        metrics: metrics_ref,
495    };
496
497    Sse::new(guarded).keep_alive(KeepAlive::new().interval(Duration::from_secs(15)))
498}
499
500async fn v1_metrics(State(_state): State<AppState>) -> impl IntoResponse {
501    let rt = crate::core::context_os::runtime();
502    let snap = rt.metrics.snapshot();
503    (
504        StatusCode::OK,
505        Json(serde_json::to_value(snap).unwrap_or_default()),
506    )
507}
508
509const MAX_HANDOFF_PAYLOAD_BYTES: usize = 1_000_000;
510const MAX_HANDOFF_FILES: usize = 50;
511
512async fn v1_a2a_handoff(
513    State(state): State<AppState>,
514    Json(body): Json<Value>,
515) -> impl IntoResponse {
516    let envelope = match crate::core::a2a_transport::parse_envelope(
517        &serde_json::to_string(&body).unwrap_or_default(),
518    ) {
519        Ok(env) => env,
520        Err(e) => {
521            tracing::warn!("a2a handoff parse error: {e}");
522            return (
523                StatusCode::BAD_REQUEST,
524                Json(serde_json::json!({"error": "invalid_envelope"})),
525            );
526        }
527    };
528
529    if envelope.payload_json.len() > MAX_HANDOFF_PAYLOAD_BYTES {
530        tracing::warn!(
531            "a2a handoff payload too large: {} bytes (limit {MAX_HANDOFF_PAYLOAD_BYTES})",
532            envelope.payload_json.len()
533        );
534        return (
535            StatusCode::PAYLOAD_TOO_LARGE,
536            Json(serde_json::json!({"error": "payload_too_large"})),
537        );
538    }
539
540    let rt = crate::core::context_os::runtime();
541    rt.bus.append(
542        &state.project_root,
543        "a2a",
544        &crate::core::context_os::ContextEventKindV1::SessionMutated,
545        Some(&envelope.sender.agent_id),
546        serde_json::json!({
547            "type": "handoff_received",
548            "content_type": format!("{:?}", envelope.content_type),
549            "sender": envelope.sender.agent_id,
550            "payload_size": envelope.payload_json.len(),
551        }),
552    );
553
554    match envelope.content_type {
555        crate::core::a2a_transport::TransportContentType::ContextPackage => {
556            let dir = std::path::Path::new(&state.project_root)
557                .join(".lean-ctx")
558                .join("handoffs")
559                .join("packages");
560            let _ = std::fs::create_dir_all(&dir);
561            evict_oldest_files(&dir, MAX_HANDOFF_FILES);
562            let out = dir.join(format!(
563                "ctx-{}.lctxpkg",
564                chrono::Utc::now().format("%Y%m%d_%H%M%S")
565            ));
566            if let Err(e) = std::fs::write(&out, &envelope.payload_json) {
567                tracing::error!("a2a handoff write failed: {e}");
568                return (
569                    StatusCode::INTERNAL_SERVER_ERROR,
570                    Json(serde_json::json!({"error": "write_failed"})),
571                );
572            }
573            (
574                StatusCode::OK,
575                Json(serde_json::json!({
576                    "status": "received",
577                    "content_type": "context_package",
578                })),
579            )
580        }
581        crate::core::a2a_transport::TransportContentType::HandoffBundle => {
582            let dir = std::path::Path::new(&state.project_root)
583                .join(".lean-ctx")
584                .join("handoffs");
585            let _ = std::fs::create_dir_all(&dir);
586            evict_oldest_files(&dir, MAX_HANDOFF_FILES);
587            let out = dir.join(format!(
588                "received-{}.json",
589                chrono::Utc::now().format("%Y%m%d_%H%M%S")
590            ));
591            if let Err(e) = std::fs::write(&out, &envelope.payload_json) {
592                tracing::error!("a2a handoff write failed: {e}");
593                return (
594                    StatusCode::INTERNAL_SERVER_ERROR,
595                    Json(serde_json::json!({"error": "write_failed"})),
596                );
597            }
598            (
599                StatusCode::OK,
600                Json(serde_json::json!({
601                    "status": "received",
602                    "content_type": "handoff_bundle",
603                })),
604            )
605        }
606        _ => (
607            StatusCode::OK,
608            Json(serde_json::json!({
609                "status": "received",
610                "content_type": format!("{:?}", envelope.content_type),
611            })),
612        ),
613    }
614}
615
/// Makes room for one new file in `dir`.
///
/// When `dir` already holds `max_files` or more regular files, the oldest ones
/// (by modification time) are deleted until `max_files - 1` remain.
/// Subdirectories and unreadable entries are ignored, a missing directory is a
/// no-op, and deletion errors are ignored (best effort).
fn evict_oldest_files(dir: &std::path::Path, max_files: usize) {
    let listing = match std::fs::read_dir(dir) {
        Ok(listing) => listing,
        Err(_) => return,
    };

    let mut candidates: Vec<(std::time::SystemTime, std::path::PathBuf)> = Vec::new();
    for entry in listing.flatten() {
        let Ok(meta) = entry.metadata() else {
            continue;
        };
        if !meta.is_file() {
            continue;
        }
        // Files without a readable mtime sort first and are evicted first.
        let modified = meta.modified().unwrap_or(std::time::UNIX_EPOCH);
        candidates.push((modified, entry.path()));
    }

    if candidates.len() < max_files {
        return;
    }

    candidates.sort_by(|a, b| a.0.cmp(&b.0));
    let keep = max_files.saturating_sub(1);
    let excess = candidates.len().saturating_sub(keep);
    for (_, path) in candidates.into_iter().take(excess) {
        let _ = std::fs::remove_file(path);
    }
}
641
642async fn a2a_jsonrpc(Json(body): Json<Value>) -> impl IntoResponse {
643    let req: crate::core::a2a::a2a_compat::JsonRpcRequest = match serde_json::from_value(body) {
644        Ok(r) => r,
645        Err(e) => {
646            tracing::debug!("a2a JSON-RPC parse error: {e}");
647            return (
648                StatusCode::BAD_REQUEST,
649                Json(serde_json::json!({
650                    "jsonrpc": "2.0",
651                    "id": null,
652                    "error": {"code": -32700, "message": "invalid request"}
653                })),
654            );
655        }
656    };
657    let resp = crate::core::a2a::a2a_compat::handle_a2a_jsonrpc(&req);
658    let json = serde_json::to_value(resp).unwrap_or_default();
659    (StatusCode::OK, Json(json))
660}
661
662async fn v1_a2a_agent_card(State(state): State<AppState>) -> impl IntoResponse {
663    let card = crate::core::a2a::agent_card::build_agent_card(&state.project_root);
664    (
665        StatusCode::OK,
666        [(header::CONTENT_TYPE, "application/json")],
667        Json(card),
668    )
669}
670
671pub async fn serve(cfg: HttpServerConfig) -> Result<()> {
672    cfg.validate()?;
673
674    let addr: SocketAddr = format!("{}:{}", cfg.host, cfg.port)
675        .parse()
676        .context("invalid host/port")?;
677
678    let project_root = cfg.project_root.to_string_lossy().to_string();
679    // IMPORTANT: Create a fresh server per MCP session in *shared* mode.
680    // This avoids per-client state clobbering while still sharing the Context OS store.
681    let service_project_root = project_root.clone();
682    let service_factory = move || -> Result<LeanCtxServer, std::io::Error> {
683        Ok(LeanCtxServer::new_shared_with_context(
684            &service_project_root,
685            "default",
686            "default",
687        ))
688    };
689    let mcp_http = StreamableHttpService::new(
690        service_factory,
691        Arc::new(
692            rmcp::transport::streamable_http_server::session::local::LocalSessionManager::default(),
693        ),
694        cfg.mcp_http_config(),
695    );
696
697    let state = AppState {
698        token: cfg.auth_token.clone().filter(|t| !t.is_empty()),
699        concurrency: Arc::new(tokio::sync::Semaphore::new(cfg.max_concurrency.max(1))),
700        rate: Arc::new(RateLimiter::new(cfg.max_rps, cfg.rate_burst)),
701        project_root: project_root.clone(),
702        timeout: Duration::from_millis(cfg.request_timeout_ms.max(1)),
703    };
704
705    let app = Router::new()
706        .route("/health", get(health))
707        .route("/v1/manifest", get(v1_manifest))
708        .route("/v1/tools", get(v1_tools))
709        .route("/v1/tools/call", axum::routing::post(v1_tool_call))
710        .route("/v1/events", get(v1_events))
711        .route(
712            "/v1/context/summary",
713            get(context_views::v1_context_summary),
714        )
715        .route("/v1/events/search", get(context_views::v1_events_search))
716        .route("/v1/events/lineage", get(context_views::v1_event_lineage))
717        .route("/v1/metrics", get(v1_metrics))
718        .route("/v1/a2a/handoff", axum::routing::post(v1_a2a_handoff))
719        .route("/v1/a2a/agent-card", get(v1_a2a_agent_card))
720        .route("/.well-known/agent.json", get(v1_a2a_agent_card))
721        .route("/a2a", axum::routing::post(a2a_jsonrpc))
722        .fallback_service(mcp_http)
723        .layer(axum::extract::DefaultBodyLimit::max(cfg.max_body_bytes))
724        .layer(middleware::from_fn_with_state(
725            state.clone(),
726            rate_limit_middleware,
727        ))
728        .layer(middleware::from_fn_with_state(
729            state.clone(),
730            concurrency_middleware,
731        ))
732        .layer(middleware::from_fn_with_state(
733            state.clone(),
734            auth_middleware,
735        ))
736        .with_state(state);
737
738    let listener = tokio::net::TcpListener::bind(addr)
739        .await
740        .with_context(|| format!("bind {addr}"))?;
741
742    tracing::info!(
743        "lean-ctx Streamable HTTP server listening on http://{addr} (project_root={})",
744        cfg.project_root.display()
745    );
746
747    axum::serve(listener, app)
748        .with_graceful_shutdown(async move {
749            let _ = tokio::signal::ctrl_c().await;
750        })
751        .await
752        .context("http server")?;
753    Ok(())
754}
755
/// Runs the same application as `serve`, but over a Unix domain socket at
/// `socket_path`, until Ctrl-C is received.
///
/// An existing file at `socket_path` is removed before binding, and the socket
/// is restricted to mode `0o600` right after it is created.
///
/// # Errors
/// Fails when validation rejects the configuration, when the stale socket
/// cannot be removed, when binding or the permission change fails, or when the
/// server itself returns an error.
#[cfg(unix)]
pub async fn serve_uds(cfg: HttpServerConfig, socket_path: PathBuf) -> Result<()> {
    cfg.validate()?;

    // A leftover socket file from a previous run would make `bind` fail.
    if socket_path.exists() {
        std::fs::remove_file(&socket_path)
            .with_context(|| format!("remove stale socket {}", socket_path.display()))?;
    }

    let project_root = cfg.project_root.to_string_lossy().to_string();
    let service_project_root = project_root.clone();
    // Each call builds a new server in shared mode for the "default" workspace/channel.
    let service_factory = move || -> Result<LeanCtxServer, std::io::Error> {
        Ok(LeanCtxServer::new_shared_with_context(
            &service_project_root,
            "default",
            "default",
        ))
    };
    let mcp_http = StreamableHttpService::new(
        service_factory,
        Arc::new(
            rmcp::transport::streamable_http_server::session::local::LocalSessionManager::default(),
        ),
        cfg.mcp_http_config(),
    );

    let state = AppState {
        // An empty token is treated the same as no token.
        token: cfg.auth_token.clone().filter(|t| !t.is_empty()),
        concurrency: Arc::new(tokio::sync::Semaphore::new(cfg.max_concurrency.max(1))),
        rate: Arc::new(RateLimiter::new(cfg.max_rps, cfg.rate_burst)),
        project_root: project_root.clone(),
        timeout: Duration::from_millis(cfg.request_timeout_ms.max(1)),
    };

    // Same routes and middleware stack as the TCP server in `serve`.
    let app = Router::new()
        .route("/health", get(health))
        .route("/v1/manifest", get(v1_manifest))
        .route("/v1/tools", get(v1_tools))
        .route("/v1/tools/call", axum::routing::post(v1_tool_call))
        .route("/v1/events", get(v1_events))
        .route(
            "/v1/context/summary",
            get(context_views::v1_context_summary),
        )
        .route("/v1/events/search", get(context_views::v1_events_search))
        .route("/v1/events/lineage", get(context_views::v1_event_lineage))
        .route("/v1/metrics", get(v1_metrics))
        .route("/v1/a2a/handoff", axum::routing::post(v1_a2a_handoff))
        .route("/v1/a2a/agent-card", get(v1_a2a_agent_card))
        .route("/.well-known/agent.json", get(v1_a2a_agent_card))
        .route("/a2a", axum::routing::post(a2a_jsonrpc))
        // Anything not matched above is handled by the MCP streamable-HTTP service.
        .fallback_service(mcp_http)
        .layer(axum::extract::DefaultBodyLimit::max(cfg.max_body_bytes))
        .layer(middleware::from_fn_with_state(
            state.clone(),
            rate_limit_middleware,
        ))
        .layer(middleware::from_fn_with_state(
            state.clone(),
            concurrency_middleware,
        ))
        .layer(middleware::from_fn_with_state(
            state.clone(),
            auth_middleware,
        ))
        .with_state(state);

    let listener = tokio::net::UnixListener::bind(&socket_path)
        .with_context(|| format!("bind UDS {}", socket_path.display()))?;

    // Restrict the socket to its owner (rw-------).
    {
        use std::os::unix::fs::PermissionsExt;
        let perms = std::fs::Permissions::from_mode(0o600);
        std::fs::set_permissions(&socket_path, perms)
            .with_context(|| format!("chmod 600 UDS {}", socket_path.display()))?;
    }

    tracing::info!(
        "lean-ctx daemon listening on {} (project_root={})",
        socket_path.display(),
        cfg.project_root.display()
    );

    axum::serve(listener, app.into_make_service())
        .with_graceful_shutdown(async move {
            // Shut down on Ctrl-C; a failure to install the handler is ignored.
            let _ = tokio::signal::ctrl_c().await;
        })
        .await
        .context("uds server")?;
    Ok(())
}
847
848#[cfg(test)]
849mod tests {
850    use super::*;
851    use axum::body::Body;
852    use axum::http::Request;
853    use futures::StreamExt;
854    use rmcp::transport::{StreamableHttpServerConfig, StreamableHttpService};
855    use serde_json::json;
856    use tower::ServiceExt;
857
858    async fn read_first_sse_message(body: Body) -> String {
859        let mut stream = body.into_data_stream();
860        let mut buf: Vec<u8> = Vec::new();
861        for _ in 0..32 {
862            let next = tokio::time::timeout(Duration::from_secs(2), stream.next()).await;
863            let Ok(Some(Ok(bytes))) = next else {
864                break;
865            };
866            buf.extend_from_slice(&bytes);
867            if buf.windows(2).any(|w| w == b"\n\n") {
868                break;
869            }
870        }
871        String::from_utf8_lossy(&buf).to_string()
872    }
873
    /// With a token configured, a request that carries no `Authorization`
    /// header must be answered with `401 Unauthorized` by `auth_middleware`.
    #[tokio::test]
    async fn auth_token_blocks_requests_without_bearer_header() {
        let dir = tempfile::tempdir().expect("tempdir");
        let root_str = dir.path().to_string_lossy().to_string();
        let service_project_root = root_str.clone();
        let service_factory = move || -> Result<LeanCtxServer, std::io::Error> {
            Ok(LeanCtxServer::new_shared_with_context(
                &service_project_root,
                "default",
                "default",
            ))
        };
        let cfg = StreamableHttpServerConfig::default()
            .with_stateful_mode(false)
            .with_json_response(true);

        let mcp_http = StreamableHttpService::new(
            service_factory,
            Arc::new(
                rmcp::transport::streamable_http_server::session::local::LocalSessionManager::default(),
            ),
            cfg,
        );

        let state = AppState {
            token: Some("secret".to_string()),
            concurrency: Arc::new(tokio::sync::Semaphore::new(4)),
            rate: Arc::new(RateLimiter::new(50, 100)),
            project_root: root_str.clone(),
            timeout: Duration::from_millis(30_000),
        };

        // Only the auth layer is installed; rate and concurrency limits are not under test.
        let app = Router::new()
            .fallback_service(mcp_http)
            .layer(middleware::from_fn_with_state(
                state.clone(),
                auth_middleware,
            ))
            .with_state(state);

        let body = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "tools/list",
            "params": {}
        })
        .to_string();

        // Deliberately no Authorization header.
        let req = Request::builder()
            .method("POST")
            .uri("/")
            .header("Host", "localhost")
            .header("Accept", "application/json, text/event-stream")
            .header("Content-Type", "application/json")
            .body(Body::from(body))
            .expect("request");

        let resp = app.clone().oneshot(req).await.expect("resp");
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }
934
935    #[tokio::test]
936    async fn mcp_service_factory_isolates_per_client_state() {
937        let dir = tempfile::tempdir().expect("tempdir");
938        let root_str = dir.path().to_string_lossy().to_string();
939
940        // Mirrors the serve() setup: service_factory must create a fresh server per MCP session.
941        let service_project_root = root_str.clone();
942        let service_factory = move || -> Result<LeanCtxServer, std::convert::Infallible> {
943            Ok(LeanCtxServer::new_shared_with_context(
944                &service_project_root,
945                "default",
946                "default",
947            ))
948        };
949
950        let s1 = service_factory().expect("server 1");
951        let s2 = service_factory().expect("server 2");
952
953        // If the two servers accidentally share the same Arc-backed fields, these writes would
954        // clobber each other. This test stays independent of rmcp's InitializeRequestParams API.
955        *s1.client_name.write().await = "client-a".to_string();
956        *s2.client_name.write().await = "client-b".to_string();
957
958        let a = s1.client_name.read().await.clone();
959        let b = s2.client_name.read().await.clone();
960        assert_eq!(a, "client-a");
961        assert_eq!(b, "client-b");
962    }
963
964    #[tokio::test]
965    async fn rate_limit_returns_429_when_exhausted() {
966        let state = AppState {
967            token: None,
968            concurrency: Arc::new(tokio::sync::Semaphore::new(16)),
969            rate: Arc::new(RateLimiter::new(1, 1)),
970            project_root: ".".to_string(),
971            timeout: Duration::from_millis(30_000),
972        };
973
974        let app = Router::new()
975            .route("/limited", get(|| async { (StatusCode::OK, "ok\n") }))
976            .layer(middleware::from_fn_with_state(
977                state.clone(),
978                rate_limit_middleware,
979            ))
980            .with_state(state);
981
982        let req1 = Request::builder()
983            .method("GET")
984            .uri("/limited")
985            .header("Host", "localhost")
986            .body(Body::empty())
987            .expect("req1");
988        let resp1 = app.clone().oneshot(req1).await.expect("resp1");
989        assert_eq!(resp1.status(), StatusCode::OK);
990
991        let req2 = Request::builder()
992            .method("GET")
993            .uri("/limited")
994            .header("Host", "localhost")
995            .body(Body::empty())
996            .expect("req2");
997        let resp2 = app.clone().oneshot(req2).await.expect("resp2");
998        assert_eq!(resp2.status(), StatusCode::TOO_MANY_REQUESTS);
999    }
1000
1001    #[tokio::test]
1002    async fn events_endpoint_replays_tool_call_event() {
1003        let dir = tempfile::tempdir().expect("tempdir");
1004        std::fs::create_dir_all(dir.path().join(".git")).expect("git marker");
1005        std::fs::write(dir.path().join("a.txt"), "ok").expect("file");
1006        let root_str = dir.path().to_string_lossy().to_string();
1007
1008        let state = AppState {
1009            token: None,
1010            concurrency: Arc::new(tokio::sync::Semaphore::new(16)),
1011            rate: Arc::new(RateLimiter::new(50, 100)),
1012            project_root: root_str.clone(),
1013            timeout: Duration::from_millis(30_000),
1014        };
1015
1016        let app = Router::new()
1017            .route("/v1/tools/call", axum::routing::post(v1_tool_call))
1018            .route("/v1/events", get(v1_events))
1019            .with_state(state);
1020
1021        let body = json!({
1022            "name": "ctx_session",
1023            "arguments": { "action": "status" },
1024            "workspaceId": "ws1",
1025            "channelId": "ch1"
1026        })
1027        .to_string();
1028        let req = Request::builder()
1029            .method("POST")
1030            .uri("/v1/tools/call")
1031            .header("Host", "localhost")
1032            .header("Content-Type", "application/json")
1033            .body(Body::from(body))
1034            .expect("req");
1035        let resp = app.clone().oneshot(req).await.expect("call");
1036        assert_eq!(resp.status(), StatusCode::OK);
1037
1038        // Allow async event persistence to complete (Windows CI disk IO is slower).
1039        tokio::time::sleep(Duration::from_millis(250)).await;
1040
1041        // Subscribe with replay semantics; read the first SSE message.
1042        let req = Request::builder()
1043            .method("GET")
1044            .uri("/v1/events?workspaceId=ws1&channelId=ch1&since=0&limit=1")
1045            .header("Host", "localhost")
1046            .header("Accept", "text/event-stream")
1047            .body(Body::empty())
1048            .expect("req");
1049        let resp = app.clone().oneshot(req).await.expect("events");
1050        assert_eq!(resp.status(), StatusCode::OK);
1051
1052        let msg = read_first_sse_message(resp.into_body()).await;
1053        assert!(msg.contains("event: tool_call_recorded"), "msg={msg:?}");
1054        assert!(msg.contains("\"workspaceId\":\"ws1\""), "msg={msg:?}");
1055        assert!(msg.contains("\"channelId\":\"ch1\""), "msg={msg:?}");
1056    }
1057}