// lean_ctx/http_server/mod.rs

1use std::net::SocketAddr;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use anyhow::{anyhow, Context, Result};
6use axum::{
7    extract::Json,
8    extract::Query,
9    extract::State,
10    http::{header, Request, StatusCode},
11    middleware::{self, Next},
12    response::sse::{Event as SseEvent, KeepAlive, Sse},
13    response::{IntoResponse, Response},
14    routing::get,
15    Router,
16};
17use futures::Stream;
18use rmcp::transport::{StreamableHttpServerConfig, StreamableHttpService};
19use serde::Deserialize;
20use serde_json::Value;
21use tokio::sync::broadcast;
22use tokio::time::{Duration, Instant};
23
24use crate::core::context_os::ContextOsMetrics;
25use crate::engine::ContextEngine;
26use crate::tools::LeanCtxServer;
27
28pub mod context_views;
29
30#[cfg(feature = "team-server")]
31pub mod team;
32
33/// Wrapper stream that calls `record_sse_disconnect` on drop.
34use std::pin::Pin;
35
36pub(crate) struct SseDisconnectGuard<I> {
37    pub(crate) inner: Pin<Box<dyn Stream<Item = I> + Send>>,
38    pub(crate) metrics: Arc<ContextOsMetrics>,
39}
40
41impl<I> Stream for SseDisconnectGuard<I> {
42    type Item = I;
43
44    fn poll_next(
45        mut self: Pin<&mut Self>,
46        cx: &mut std::task::Context<'_>,
47    ) -> std::task::Poll<Option<Self::Item>> {
48        self.inner.as_mut().poll_next(cx)
49    }
50}
51
52impl<I> Drop for SseDisconnectGuard<I> {
53    fn drop(&mut self) {
54        self.metrics.record_sse_disconnect();
55    }
56}
57
/// Upper bound on the number of characters kept by [`sanitize_id`].
const MAX_ID_LEN: usize = 64;

/// Normalises a caller-supplied identifier.
///
/// Surrounding whitespace is trimmed, every character other than ASCII
/// letters, ASCII digits, `-`, `_` and `.` is dropped, and at most
/// `MAX_ID_LEN` of the surviving characters are kept. If nothing survives,
/// the literal `"default"` is returned.
fn sanitize_id(raw: &str) -> String {
    let permitted = |c: &char| c.is_ascii_alphanumeric() || matches!(*c, '-' | '_' | '.');

    // A blank input simply yields an empty `kept`, so one fallback check covers
    // both "empty input" and "nothing but rejected characters".
    let kept: String = raw
        .trim()
        .chars()
        .filter(permitted)
        .take(MAX_ID_LEN)
        .collect();

    if kept.is_empty() {
        String::from("default")
    } else {
        kept
    }
}
76
/// Settings for the HTTP (and Unix-socket) front end.
#[derive(Debug, Clone)]
pub struct HttpServerConfig {
    /// Host or IP address to listen on.
    pub host: String,
    /// TCP port to listen on.
    pub port: u16,
    /// Project directory handed to every server instance.
    pub project_root: PathBuf,
    /// Bearer token required from clients; `None` leaves requests unauthenticated.
    pub auth_token: Option<String>,
    /// Forwarded to the MCP streamable-HTTP transport's stateful-mode switch.
    pub stateful_mode: bool,
    /// Forwarded to the MCP streamable-HTTP transport's JSON-response switch.
    pub json_response: bool,
    /// Turns off the transport's allowed-hosts check.
    pub disable_host_check: bool,
    /// Explicit allowed-hosts list; when empty the transport defaults apply.
    pub allowed_hosts: Vec<String>,
    /// Maximum accepted request body size, in bytes.
    pub max_body_bytes: usize,
    /// Maximum number of requests processed at the same time.
    pub max_concurrency: usize,
    /// Sustained request rate, in requests per second.
    pub max_rps: u32,
    /// Token-bucket capacity used for short bursts above `max_rps`.
    pub rate_burst: u32,
    /// Per-request timeout, in milliseconds.
    pub request_timeout_ms: u64,
}

impl Default for HttpServerConfig {
    /// Loopback-only defaults: `127.0.0.1:8080`, no auth token, JSON responses,
    /// and the current working directory (or `.` if it cannot be read) as root.
    fn default() -> Self {
        Self {
            host: String::from("127.0.0.1"),
            port: 8080,
            project_root: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
            auth_token: None,
            stateful_mode: false,
            json_response: true,
            disable_host_check: false,
            allowed_hosts: Vec::new(),
            // 2 MiB
            max_body_bytes: 2 * 1024 * 1024,
            max_concurrency: 32,
            max_rps: 50,
            rate_burst: 100,
            request_timeout_ms: 30_000,
        }
    }
}
114
115impl HttpServerConfig {
116    pub fn validate(&self) -> Result<()> {
117        let host = self.host.trim().to_lowercase();
118        let is_loopback = host == "127.0.0.1" || host == "localhost" || host == "::1";
119        if !is_loopback && self.auth_token.as_deref().unwrap_or("").is_empty() {
120            return Err(anyhow!(
121                "Refusing to bind to host='{host}' without auth. Provide --auth-token (or bind to 127.0.0.1)."
122            ));
123        }
124        Ok(())
125    }
126
127    fn mcp_http_config(&self) -> StreamableHttpServerConfig {
128        let mut cfg = StreamableHttpServerConfig::default()
129            .with_stateful_mode(self.stateful_mode)
130            .with_json_response(self.json_response);
131
132        if self.disable_host_check {
133            tracing::warn!(
134                "⚠ --disable-host-check is active: DNS rebinding protection is OFF. \
135                 Do NOT use this in production or on non-loopback interfaces."
136            );
137            cfg = cfg.disable_allowed_hosts();
138            return cfg;
139        }
140
141        if !self.allowed_hosts.is_empty() {
142            cfg = cfg.with_allowed_hosts(self.allowed_hosts.clone());
143            return cfg;
144        }
145
146        // Keep rmcp's secure loopback defaults; also allow the configured host (if it's loopback).
147        let host = self.host.trim();
148        if host == "127.0.0.1" || host == "localhost" || host == "::1" {
149            cfg.allowed_hosts.push(host.to_string());
150        }
151
152        cfg
153    }
154}
155
156#[derive(Clone)]
157struct AppState {
158    token: Option<String>,
159    concurrency: Arc<tokio::sync::Semaphore>,
160    rate: Arc<RateLimiter>,
161    project_root: String,
162    timeout: Duration,
163    server: LeanCtxServer,
164}
165
166#[derive(Debug)]
167struct RateLimiter {
168    max_rps: f64,
169    burst: f64,
170    state: tokio::sync::Mutex<RateState>,
171}
172
173#[derive(Debug, Clone, Copy)]
174struct RateState {
175    tokens: f64,
176    last: Instant,
177}
178
179impl RateLimiter {
180    fn new(max_rps: u32, burst: u32) -> Self {
181        let now = Instant::now();
182        Self {
183            max_rps: (max_rps.max(1)) as f64,
184            burst: (burst.max(1)) as f64,
185            state: tokio::sync::Mutex::new(RateState {
186                tokens: (burst.max(1)) as f64,
187                last: now,
188            }),
189        }
190    }
191
192    async fn allow(&self) -> bool {
193        let mut s = self.state.lock().await;
194        let now = Instant::now();
195        let elapsed = now.saturating_duration_since(s.last);
196        let refill = elapsed.as_secs_f64() * self.max_rps;
197        s.tokens = (s.tokens + refill).min(self.burst);
198        s.last = now;
199        if s.tokens >= 1.0 {
200            s.tokens -= 1.0;
201            true
202        } else {
203            false
204        }
205    }
206}
207
208async fn auth_middleware(
209    State(state): State<AppState>,
210    req: Request<axum::body::Body>,
211    next: Next,
212) -> Response {
213    if state.token.is_none() {
214        return next.run(req).await;
215    }
216
217    if req.uri().path() == "/health" {
218        return next.run(req).await;
219    }
220
221    let expected = state.token.as_deref().unwrap_or("");
222    let Some(h) = req.headers().get(header::AUTHORIZATION) else {
223        return StatusCode::UNAUTHORIZED.into_response();
224    };
225    let Ok(s) = h.to_str() else {
226        return StatusCode::UNAUTHORIZED.into_response();
227    };
228    let Some(token) = s
229        .strip_prefix("Bearer ")
230        .or_else(|| s.strip_prefix("bearer "))
231    else {
232        return StatusCode::UNAUTHORIZED.into_response();
233    };
234    if !constant_time_eq(token.as_bytes(), expected.as_bytes()) {
235        return StatusCode::UNAUTHORIZED.into_response();
236    }
237
238    next.run(req).await
239}
240
241fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
242    use subtle::ConstantTimeEq;
243    if a.len() != b.len() {
244        return false;
245    }
246    bool::from(a.ct_eq(b))
247}
248
249async fn rate_limit_middleware(
250    State(state): State<AppState>,
251    req: Request<axum::body::Body>,
252    next: Next,
253) -> Response {
254    if !state.rate.allow().await {
255        return StatusCode::TOO_MANY_REQUESTS.into_response();
256    }
257    next.run(req).await
258}
259
260async fn concurrency_middleware(
261    State(state): State<AppState>,
262    req: Request<axum::body::Body>,
263    next: Next,
264) -> Response {
265    let Ok(permit) = state.concurrency.clone().try_acquire_owned() else {
266        return StatusCode::TOO_MANY_REQUESTS.into_response();
267    };
268    let resp = next.run(req).await;
269    drop(permit);
270    resp
271}
272
273async fn health() -> impl IntoResponse {
274    (StatusCode::OK, "ok\n")
275}
276
277#[derive(Debug, Deserialize)]
278#[serde(rename_all = "camelCase")]
279#[allow(dead_code)]
280struct ToolCallBody {
281    name: String,
282    #[serde(default)]
283    arguments: Option<Value>,
284    #[serde(default)]
285    workspace_id: Option<String>,
286    #[serde(default)]
287    channel_id: Option<String>,
288}
289
290#[derive(Debug, Deserialize)]
291#[serde(rename_all = "camelCase")]
292struct EventsQuery {
293    #[serde(default)]
294    workspace_id: Option<String>,
295    #[serde(default)]
296    channel_id: Option<String>,
297    #[serde(default)]
298    since: Option<i64>,
299    #[serde(default)]
300    limit: Option<usize>,
301    /// Comma-separated event kind filter (e.g. `tool_call,session_start`).
302    /// When set, only matching events are delivered via SSE.
303    #[serde(default)]
304    kind: Option<String>,
305}
306
307async fn v1_manifest(State(state): State<AppState>) -> impl IntoResponse {
308    let _ = state;
309    let v = crate::core::mcp_manifest::manifest_value();
310    (StatusCode::OK, Json(v))
311}
312
313#[derive(Debug, Deserialize)]
314#[serde(rename_all = "camelCase")]
315struct ToolsQuery {
316    #[serde(default)]
317    offset: Option<usize>,
318    #[serde(default)]
319    limit: Option<usize>,
320}
321
322async fn v1_tools(State(state): State<AppState>, Query(q): Query<ToolsQuery>) -> impl IntoResponse {
323    let _ = state;
324    let v = crate::core::mcp_manifest::manifest_value();
325    let tools = v
326        .get("tools")
327        .and_then(|t| t.get("granular"))
328        .cloned()
329        .unwrap_or(Value::Array(vec![]));
330
331    let all = tools.as_array().cloned().unwrap_or_default();
332    let total = all.len();
333    let offset = q.offset.unwrap_or(0).min(total);
334    let limit = q.limit.unwrap_or(200).min(500);
335    let page = all.into_iter().skip(offset).take(limit).collect::<Vec<_>>();
336
337    (
338        StatusCode::OK,
339        Json(serde_json::json!({
340            "tools": page,
341            "total": total,
342            "offset": offset,
343            "limit": limit,
344        })),
345    )
346}
347
348async fn v1_tool_call(
349    State(state): State<AppState>,
350    Json(body): Json<ToolCallBody>,
351) -> impl IntoResponse {
352    let engine = ContextEngine::from_server(state.server.clone());
353    match tokio::time::timeout(
354        state.timeout,
355        engine.call_tool_value(&body.name, body.arguments),
356    )
357    .await
358    {
359        Ok(Ok(v)) => (StatusCode::OK, Json(serde_json::json!({ "result": v }))).into_response(),
360        Ok(Err(e)) => {
361            tracing::warn!("tool call error: {e}");
362            (
363                StatusCode::BAD_REQUEST,
364                Json(serde_json::json!({ "error": "tool_error", "code": "TOOL_ERROR" })),
365            )
366                .into_response()
367        }
368        Err(_) => (
369            StatusCode::GATEWAY_TIMEOUT,
370            Json(serde_json::json!({ "error": "request_timeout" })),
371        )
372            .into_response(),
373    }
374}
375
376async fn v1_events(
377    State(state): State<AppState>,
378    Query(q): Query<EventsQuery>,
379) -> Sse<impl Stream<Item = Result<SseEvent, std::convert::Infallible>>> {
380    use crate::core::context_os::{redact_event_payload, ContextEventV1, RedactionLevel};
381
382    let ws = sanitize_id(&q.workspace_id.unwrap_or_else(|| "default".to_string()));
383    let ch = sanitize_id(&q.channel_id.unwrap_or_else(|| "default".to_string()));
384    let _ = &state.project_root;
385    let since = q.since.unwrap_or(0);
386    let limit = q.limit.unwrap_or(200).min(1000);
387    let redaction = RedactionLevel::RefsOnly;
388
389    let kind_filter: Option<Vec<String>> = q
390        .kind
391        .as_deref()
392        .map(|k| k.split(',').map(|s| s.trim().to_string()).collect());
393
394    let rt = crate::core::context_os::runtime();
395    let replay = rt.bus.read(&ws, &ch, since, limit);
396
397    let replay = if let Some(ref kinds) = kind_filter {
398        replay
399            .into_iter()
400            .filter(|ev| kinds.contains(&ev.kind))
401            .collect()
402    } else {
403        replay
404    };
405
406    let rx = if let Some(ref kinds) = kind_filter {
407        let kind_refs: Vec<&str> = kinds.iter().map(String::as_str).collect();
408        let filter = crate::core::context_os::TopicFilter::kinds(&kind_refs);
409        if let Some(sub) = rt.bus.subscribe_filtered(&ws, &ch, filter) {
410            crate::core::context_os::SubscriptionKind::Filtered(sub)
411        } else {
412            tracing::warn!("SSE subscriber limit reached for {ws}/{ch}");
413            let (_, rx) = broadcast::channel::<ContextEventV1>(1);
414            crate::core::context_os::SubscriptionKind::Unfiltered(rx)
415        }
416    } else if let Some(sub) = rt.bus.subscribe(&ws, &ch) {
417        crate::core::context_os::SubscriptionKind::Unfiltered(sub)
418    } else {
419        tracing::warn!("SSE subscriber limit reached for {ws}/{ch}");
420        let (_, rx) = broadcast::channel::<ContextEventV1>(1);
421        crate::core::context_os::SubscriptionKind::Unfiltered(rx)
422    };
423
424    rt.metrics.record_sse_connect();
425    rt.metrics.record_events_replayed(replay.len() as u64);
426    rt.metrics.record_workspace_active(&ws);
427
428    let bus = rt.bus.clone();
429    let metrics = rt.metrics.clone();
430    let pending: std::collections::VecDeque<ContextEventV1> = replay.into();
431
432    let stream = futures::stream::unfold(
433        (
434            pending,
435            rx,
436            ws.clone(),
437            ch.clone(),
438            since,
439            redaction,
440            bus,
441            metrics,
442        ),
443        |(mut pending, mut rx, ws, ch, mut last_id, redaction, bus, metrics)| async move {
444            if let Some(mut ev) = pending.pop_front() {
445                last_id = ev.id;
446                redact_event_payload(&mut ev, redaction);
447                let data = serde_json::to_string(&ev).unwrap_or_else(|_| "{}".to_string());
448                let evt = SseEvent::default()
449                    .id(ev.id.to_string())
450                    .event(ev.kind)
451                    .data(data);
452                return Some((
453                    Ok(evt),
454                    (pending, rx, ws, ch, last_id, redaction, bus, metrics),
455                ));
456            }
457
458            loop {
459                match rx.recv().await {
460                    Ok(mut ev) if ev.id > last_id => {
461                        last_id = ev.id;
462                        redact_event_payload(&mut ev, redaction);
463                        let data = serde_json::to_string(&ev).unwrap_or_else(|_| "{}".to_string());
464                        let evt = SseEvent::default()
465                            .id(ev.id.to_string())
466                            .event(ev.kind)
467                            .data(data);
468                        return Some((
469                            Ok(evt),
470                            (pending, rx, ws, ch, last_id, redaction, bus, metrics),
471                        ));
472                    }
473                    Ok(_) => {}
474                    Err(broadcast::error::RecvError::Closed) => return None,
475                    Err(broadcast::error::RecvError::Lagged(skipped)) => {
476                        let missed = bus.read(&ws, &ch, last_id, skipped as usize);
477                        metrics.record_events_replayed(missed.len() as u64);
478                        for ev in missed {
479                            last_id = last_id.max(ev.id);
480                            pending.push_back(ev);
481                        }
482                    }
483                }
484            }
485        },
486    );
487
488    let metrics_ref = rt.metrics.clone();
489    let guarded = SseDisconnectGuard {
490        inner: Box::pin(stream),
491        metrics: metrics_ref,
492    };
493
494    Sse::new(guarded).keep_alive(KeepAlive::new().interval(Duration::from_secs(15)))
495}
496
497async fn v1_metrics(State(_state): State<AppState>) -> impl IntoResponse {
498    let rt = crate::core::context_os::runtime();
499    let snap = rt.metrics.snapshot();
500    (
501        StatusCode::OK,
502        Json(serde_json::to_value(snap).unwrap_or_default()),
503    )
504}
505
506const MAX_HANDOFF_PAYLOAD_BYTES: usize = 1_000_000;
507const MAX_HANDOFF_FILES: usize = 50;
508
509async fn v1_a2a_handoff(
510    State(state): State<AppState>,
511    Json(body): Json<Value>,
512) -> impl IntoResponse {
513    let envelope = match crate::core::a2a_transport::parse_envelope(
514        &serde_json::to_string(&body).unwrap_or_default(),
515    ) {
516        Ok(env) => env,
517        Err(e) => {
518            tracing::warn!("a2a handoff parse error: {e}");
519            return (
520                StatusCode::BAD_REQUEST,
521                Json(serde_json::json!({"error": "invalid_envelope"})),
522            );
523        }
524    };
525
526    if envelope.payload_json.len() > MAX_HANDOFF_PAYLOAD_BYTES {
527        tracing::warn!(
528            "a2a handoff payload too large: {} bytes (limit {MAX_HANDOFF_PAYLOAD_BYTES})",
529            envelope.payload_json.len()
530        );
531        return (
532            StatusCode::PAYLOAD_TOO_LARGE,
533            Json(serde_json::json!({"error": "payload_too_large"})),
534        );
535    }
536
537    let rt = crate::core::context_os::runtime();
538    rt.bus.append(
539        &state.project_root,
540        "a2a",
541        &crate::core::context_os::ContextEventKindV1::SessionMutated,
542        Some(&envelope.sender.agent_id),
543        serde_json::json!({
544            "type": "handoff_received",
545            "content_type": format!("{:?}", envelope.content_type),
546            "sender": envelope.sender.agent_id,
547            "payload_size": envelope.payload_json.len(),
548        }),
549    );
550
551    match envelope.content_type {
552        crate::core::a2a_transport::TransportContentType::ContextPackage => {
553            let dir = std::path::Path::new(&state.project_root)
554                .join(".lean-ctx")
555                .join("handoffs")
556                .join("packages");
557            let _ = std::fs::create_dir_all(&dir);
558            evict_oldest_files(&dir, MAX_HANDOFF_FILES);
559            let out = dir.join(format!(
560                "ctx-{}.lctxpkg",
561                chrono::Utc::now().format("%Y%m%d_%H%M%S")
562            ));
563            if let Err(e) = std::fs::write(&out, &envelope.payload_json) {
564                tracing::error!("a2a handoff write failed: {e}");
565                return (
566                    StatusCode::INTERNAL_SERVER_ERROR,
567                    Json(serde_json::json!({"error": "write_failed"})),
568                );
569            }
570            (
571                StatusCode::OK,
572                Json(serde_json::json!({
573                    "status": "received",
574                    "content_type": "context_package",
575                })),
576            )
577        }
578        crate::core::a2a_transport::TransportContentType::HandoffBundle => {
579            let dir = std::path::Path::new(&state.project_root)
580                .join(".lean-ctx")
581                .join("handoffs");
582            let _ = std::fs::create_dir_all(&dir);
583            evict_oldest_files(&dir, MAX_HANDOFF_FILES);
584            let out = dir.join(format!(
585                "received-{}.json",
586                chrono::Utc::now().format("%Y%m%d_%H%M%S")
587            ));
588            if let Err(e) = std::fs::write(&out, &envelope.payload_json) {
589                tracing::error!("a2a handoff write failed: {e}");
590                return (
591                    StatusCode::INTERNAL_SERVER_ERROR,
592                    Json(serde_json::json!({"error": "write_failed"})),
593                );
594            }
595            (
596                StatusCode::OK,
597                Json(serde_json::json!({
598                    "status": "received",
599                    "content_type": "handoff_bundle",
600                })),
601            )
602        }
603        _ => (
604            StatusCode::OK,
605            Json(serde_json::json!({
606                "status": "received",
607                "content_type": format!("{:?}", envelope.content_type),
608            })),
609        ),
610    }
611}
612
/// Deletes the oldest regular files in `dir` so that, after one more file is
/// added, the directory holds at most `max_files` files.
///
/// Nothing happens when the directory cannot be listed or currently holds
/// fewer than `max_files` files. Age is judged by modification time (files
/// whose mtime is unavailable count as oldest); sub-directories are ignored
/// and deletion errors are swallowed.
fn evict_oldest_files(dir: &std::path::Path, max_files: usize) {
    let listing = match std::fs::read_dir(dir) {
        Ok(listing) => listing,
        Err(_) => return,
    };

    let mut by_age: Vec<(std::time::SystemTime, std::path::PathBuf)> = Vec::new();
    for entry in listing.filter_map(Result::ok) {
        let Ok(meta) = entry.metadata() else {
            continue;
        };
        if !meta.is_file() {
            continue;
        }
        let mtime = meta.modified().unwrap_or(std::time::UNIX_EPOCH);
        by_age.push((mtime, entry.path()));
    }

    if by_age.len() < max_files {
        return;
    }

    // Oldest first; leave room for exactly one new file.
    by_age.sort_by_key(|(mtime, _)| *mtime);
    let keep = max_files.saturating_sub(1);
    let excess = by_age.len().saturating_sub(keep);
    for (_, path) in by_age.into_iter().take(excess) {
        let _ = std::fs::remove_file(path);
    }
}
638
639async fn a2a_jsonrpc(Json(body): Json<Value>) -> impl IntoResponse {
640    let req: crate::core::a2a::a2a_compat::JsonRpcRequest = match serde_json::from_value(body) {
641        Ok(r) => r,
642        Err(e) => {
643            tracing::debug!("a2a JSON-RPC parse error: {e}");
644            return (
645                StatusCode::BAD_REQUEST,
646                Json(serde_json::json!({
647                    "jsonrpc": "2.0",
648                    "id": null,
649                    "error": {"code": -32700, "message": "invalid request"}
650                })),
651            );
652        }
653    };
654    let resp = crate::core::a2a::a2a_compat::handle_a2a_jsonrpc(&req);
655    let json = serde_json::to_value(resp).unwrap_or_default();
656    (StatusCode::OK, Json(json))
657}
658
659async fn v1_a2a_agent_card(State(state): State<AppState>) -> impl IntoResponse {
660    let card = crate::core::a2a::agent_card::build_agent_card(&state.project_root);
661    (
662        StatusCode::OK,
663        [(header::CONTENT_TYPE, "application/json")],
664        Json(card),
665    )
666}
667
668pub async fn serve(cfg: HttpServerConfig) -> Result<()> {
669    cfg.validate()?;
670
671    let addr: SocketAddr = format!("{}:{}", cfg.host, cfg.port)
672        .parse()
673        .context("invalid host/port")?;
674
675    let project_root = cfg.project_root.to_string_lossy().to_string();
676    // MCP sessions still get a fresh server each (per-client state isolation).
677    let service_project_root = project_root.clone();
678    let service_factory = move || -> Result<LeanCtxServer, std::io::Error> {
679        Ok(LeanCtxServer::new_shared_with_context(
680            &service_project_root,
681            "default",
682            "default",
683        ))
684    };
685    let mcp_http = StreamableHttpService::new(
686        service_factory,
687        Arc::new(
688            rmcp::transport::streamable_http_server::session::local::LocalSessionManager::default(),
689        ),
690        cfg.mcp_http_config(),
691    );
692
693    let rest_server = LeanCtxServer::new_shared_with_context(&project_root, "default", "default");
694
695    let state = AppState {
696        token: cfg.auth_token.clone().filter(|t| !t.is_empty()),
697        concurrency: Arc::new(tokio::sync::Semaphore::new(cfg.max_concurrency.max(1))),
698        rate: Arc::new(RateLimiter::new(cfg.max_rps, cfg.rate_burst)),
699        project_root: project_root.clone(),
700        timeout: Duration::from_millis(cfg.request_timeout_ms.max(1)),
701        server: rest_server,
702    };
703
704    let app = Router::new()
705        .route("/health", get(health))
706        .route("/v1/manifest", get(v1_manifest))
707        .route("/v1/tools", get(v1_tools))
708        .route("/v1/tools/call", axum::routing::post(v1_tool_call))
709        .route("/v1/events", get(v1_events))
710        .route(
711            "/v1/context/summary",
712            get(context_views::v1_context_summary),
713        )
714        .route("/v1/events/search", get(context_views::v1_events_search))
715        .route("/v1/events/lineage", get(context_views::v1_event_lineage))
716        .route("/v1/metrics", get(v1_metrics))
717        .route("/v1/a2a/handoff", axum::routing::post(v1_a2a_handoff))
718        .route("/v1/a2a/agent-card", get(v1_a2a_agent_card))
719        .route("/.well-known/agent.json", get(v1_a2a_agent_card))
720        .route("/a2a", axum::routing::post(a2a_jsonrpc))
721        .fallback_service(mcp_http)
722        .layer(axum::extract::DefaultBodyLimit::max(cfg.max_body_bytes))
723        .layer(middleware::from_fn_with_state(
724            state.clone(),
725            rate_limit_middleware,
726        ))
727        .layer(middleware::from_fn_with_state(
728            state.clone(),
729            concurrency_middleware,
730        ))
731        .layer(middleware::from_fn_with_state(
732            state.clone(),
733            auth_middleware,
734        ))
735        .with_state(state);
736
737    let listener = tokio::net::TcpListener::bind(addr)
738        .await
739        .with_context(|| format!("bind {addr}"))?;
740
741    tracing::info!(
742        "lean-ctx Streamable HTTP server listening on http://{addr} (project_root={})",
743        cfg.project_root.display()
744    );
745
746    axum::serve(listener, app)
747        .with_graceful_shutdown(async move {
748            let _ = tokio::signal::ctrl_c().await;
749        })
750        .await
751        .context("http server")?;
752    Ok(())
753}
754
755#[cfg(unix)]
756pub async fn serve_uds(cfg: HttpServerConfig, socket_path: PathBuf) -> Result<()> {
757    cfg.validate()?;
758
759    if socket_path.exists() {
760        std::fs::remove_file(&socket_path)
761            .with_context(|| format!("remove stale socket {}", socket_path.display()))?;
762    }
763
764    let project_root = cfg.project_root.to_string_lossy().to_string();
765    let service_project_root = project_root.clone();
766    let service_factory = move || -> Result<LeanCtxServer, std::io::Error> {
767        Ok(LeanCtxServer::new_shared_with_context(
768            &service_project_root,
769            "default",
770            "default",
771        ))
772    };
773    let mcp_http = StreamableHttpService::new(
774        service_factory,
775        Arc::new(
776            rmcp::transport::streamable_http_server::session::local::LocalSessionManager::default(),
777        ),
778        cfg.mcp_http_config(),
779    );
780
781    let rest_server = LeanCtxServer::new_shared_with_context(&project_root, "default", "default");
782
783    let state = AppState {
784        token: cfg.auth_token.clone().filter(|t| !t.is_empty()),
785        concurrency: Arc::new(tokio::sync::Semaphore::new(cfg.max_concurrency.max(1))),
786        rate: Arc::new(RateLimiter::new(cfg.max_rps, cfg.rate_burst)),
787        project_root: project_root.clone(),
788        timeout: Duration::from_millis(cfg.request_timeout_ms.max(1)),
789        server: rest_server,
790    };
791
792    let app = Router::new()
793        .route("/health", get(health))
794        .route("/v1/manifest", get(v1_manifest))
795        .route("/v1/tools", get(v1_tools))
796        .route("/v1/tools/call", axum::routing::post(v1_tool_call))
797        .route("/v1/events", get(v1_events))
798        .route(
799            "/v1/context/summary",
800            get(context_views::v1_context_summary),
801        )
802        .route("/v1/events/search", get(context_views::v1_events_search))
803        .route("/v1/events/lineage", get(context_views::v1_event_lineage))
804        .route("/v1/metrics", get(v1_metrics))
805        .route("/v1/a2a/handoff", axum::routing::post(v1_a2a_handoff))
806        .route("/v1/a2a/agent-card", get(v1_a2a_agent_card))
807        .route("/.well-known/agent.json", get(v1_a2a_agent_card))
808        .route("/a2a", axum::routing::post(a2a_jsonrpc))
809        .fallback_service(mcp_http)
810        .layer(axum::extract::DefaultBodyLimit::max(cfg.max_body_bytes))
811        .layer(middleware::from_fn_with_state(
812            state.clone(),
813            rate_limit_middleware,
814        ))
815        .layer(middleware::from_fn_with_state(
816            state.clone(),
817            concurrency_middleware,
818        ))
819        .layer(middleware::from_fn_with_state(
820            state.clone(),
821            auth_middleware,
822        ))
823        .with_state(state);
824
825    let listener = tokio::net::UnixListener::bind(&socket_path)
826        .with_context(|| format!("bind UDS {}", socket_path.display()))?;
827
828    {
829        use std::os::unix::fs::PermissionsExt;
830        let perms = std::fs::Permissions::from_mode(0o600);
831        std::fs::set_permissions(&socket_path, perms)
832            .with_context(|| format!("chmod 600 UDS {}", socket_path.display()))?;
833    }
834
835    tracing::info!(
836        "lean-ctx daemon listening on {} (project_root={})",
837        socket_path.display(),
838        cfg.project_root.display()
839    );
840
841    axum::serve(listener, app.into_make_service())
842        .with_graceful_shutdown(async move {
843            let _ = tokio::signal::ctrl_c().await;
844        })
845        .await
846        .context("uds server")?;
847    Ok(())
848}
849
850#[cfg(test)]
851mod tests {
852    use super::*;
853    use axum::body::Body;
854    use axum::http::Request;
855    use futures::StreamExt;
856    use rmcp::transport::{StreamableHttpServerConfig, StreamableHttpService};
857    use serde_json::json;
858    use tower::ServiceExt;
859
860    async fn read_first_sse_message(body: Body) -> String {
861        let mut stream = body.into_data_stream();
862        let mut buf: Vec<u8> = Vec::new();
863        for _ in 0..32 {
864            let next = tokio::time::timeout(Duration::from_secs(2), stream.next()).await;
865            let Ok(Some(Ok(bytes))) = next else {
866                break;
867            };
868            buf.extend_from_slice(&bytes);
869            if buf.windows(2).any(|w| w == b"\n\n") {
870                break;
871            }
872        }
873        String::from_utf8_lossy(&buf).to_string()
874    }
875
876    #[tokio::test]
877    async fn auth_token_blocks_requests_without_bearer_header() {
878        let dir = tempfile::tempdir().expect("tempdir");
879        let root_str = dir.path().to_string_lossy().to_string();
880        let service_project_root = root_str.clone();
881        let service_factory = move || -> Result<LeanCtxServer, std::io::Error> {
882            Ok(LeanCtxServer::new_shared_with_context(
883                &service_project_root,
884                "default",
885                "default",
886            ))
887        };
888        let cfg = StreamableHttpServerConfig::default()
889            .with_stateful_mode(false)
890            .with_json_response(true);
891
892        let mcp_http = StreamableHttpService::new(
893            service_factory,
894            Arc::new(
895                rmcp::transport::streamable_http_server::session::local::LocalSessionManager::default(),
896            ),
897            cfg,
898        );
899
900        let state = AppState {
901            token: Some("secret".to_string()),
902            concurrency: Arc::new(tokio::sync::Semaphore::new(4)),
903            rate: Arc::new(RateLimiter::new(50, 100)),
904            project_root: root_str.clone(),
905            timeout: Duration::from_secs(30),
906            server: LeanCtxServer::new_shared_with_context(&root_str, "default", "default"),
907        };
908
909        let app = Router::new()
910            .fallback_service(mcp_http)
911            .layer(middleware::from_fn_with_state(
912                state.clone(),
913                auth_middleware,
914            ))
915            .with_state(state);
916
917        let body = json!({
918            "jsonrpc": "2.0",
919            "id": 1,
920            "method": "tools/list",
921            "params": {}
922        })
923        .to_string();
924
925        let req = Request::builder()
926            .method("POST")
927            .uri("/")
928            .header("Host", "localhost")
929            .header("Accept", "application/json, text/event-stream")
930            .header("Content-Type", "application/json")
931            .body(Body::from(body))
932            .expect("request");
933
934        let resp = app.clone().oneshot(req).await.expect("resp");
935        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
936    }
937
938    #[tokio::test]
939    async fn mcp_service_factory_isolates_per_client_state() {
940        let dir = tempfile::tempdir().expect("tempdir");
941        let root_str = dir.path().to_string_lossy().to_string();
942
943        // Mirrors the serve() setup: service_factory must create a fresh server per MCP session.
944        let service_project_root = root_str.clone();
945        let service_factory = move || -> Result<LeanCtxServer, std::convert::Infallible> {
946            Ok(LeanCtxServer::new_shared_with_context(
947                &service_project_root,
948                "default",
949                "default",
950            ))
951        };
952
953        let s1 = service_factory().expect("server 1");
954        let s2 = service_factory().expect("server 2");
955
956        // If the two servers accidentally share the same Arc-backed fields, these writes would
957        // clobber each other. This test stays independent of rmcp's InitializeRequestParams API.
958        *s1.client_name.write().await = "client-a".to_string();
959        *s2.client_name.write().await = "client-b".to_string();
960
961        let a = s1.client_name.read().await.clone();
962        let b = s2.client_name.read().await.clone();
963        assert_eq!(a, "client-a");
964        assert_eq!(b, "client-b");
965    }
966
967    #[tokio::test]
968    async fn rate_limit_returns_429_when_exhausted() {
969        let state = AppState {
970            token: None,
971            concurrency: Arc::new(tokio::sync::Semaphore::new(16)),
972            rate: Arc::new(RateLimiter::new(1, 1)),
973            project_root: ".".to_string(),
974            timeout: Duration::from_secs(30),
975            server: LeanCtxServer::new_shared_with_context(".", "default", "default"),
976        };
977
978        let app = Router::new()
979            .route("/limited", get(|| async { (StatusCode::OK, "ok\n") }))
980            .layer(middleware::from_fn_with_state(
981                state.clone(),
982                rate_limit_middleware,
983            ))
984            .with_state(state);
985
986        let req1 = Request::builder()
987            .method("GET")
988            .uri("/limited")
989            .header("Host", "localhost")
990            .body(Body::empty())
991            .expect("req1");
992        let resp1 = app.clone().oneshot(req1).await.expect("resp1");
993        assert_eq!(resp1.status(), StatusCode::OK);
994
995        let req2 = Request::builder()
996            .method("GET")
997            .uri("/limited")
998            .header("Host", "localhost")
999            .body(Body::empty())
1000            .expect("req2");
1001        let resp2 = app.clone().oneshot(req2).await.expect("resp2");
1002        assert_eq!(resp2.status(), StatusCode::TOO_MANY_REQUESTS);
1003    }
1004
1005    #[tokio::test]
1006    async fn events_endpoint_replays_tool_call_event() {
1007        use crate::core::context_os::{self, ContextEventKindV1};
1008
1009        let dir = tempfile::tempdir().expect("tempdir");
1010        std::fs::create_dir_all(dir.path().join(".git")).expect("git marker");
1011        std::fs::write(dir.path().join("a.txt"), "ok").expect("file");
1012        let root_str = dir.path().to_string_lossy().to_string();
1013
1014        let state = AppState {
1015            token: None,
1016            concurrency: Arc::new(tokio::sync::Semaphore::new(16)),
1017            rate: Arc::new(RateLimiter::new(50, 100)),
1018            project_root: root_str.clone(),
1019            timeout: Duration::from_secs(30),
1020            server: LeanCtxServer::new_shared_with_context(&root_str, "default", "default"),
1021        };
1022
1023        let app = Router::new()
1024            .route("/v1/events", get(v1_events))
1025            .with_state(state);
1026
1027        // Directly append an event to the bus — no fire-and-forget timing dependency.
1028        let rt = context_os::runtime();
1029        rt.bus.append(
1030            "ws1",
1031            "ch1",
1032            &ContextEventKindV1::ToolCallRecorded,
1033            Some("test-agent"),
1034            json!({"tool": "ctx_session", "action": "status"}),
1035        );
1036
1037        let req = Request::builder()
1038            .method("GET")
1039            .uri("/v1/events?workspaceId=ws1&channelId=ch1&since=0&limit=1")
1040            .header("Host", "localhost")
1041            .header("Accept", "text/event-stream")
1042            .body(Body::empty())
1043            .expect("req");
1044        let resp = app.clone().oneshot(req).await.expect("events");
1045        assert_eq!(resp.status(), StatusCode::OK);
1046
1047        let msg = read_first_sse_message(resp.into_body()).await;
1048        assert!(msg.contains("event: tool_call_recorded"), "msg={msg:?}");
1049        assert!(msg.contains("\"ws1\""), "msg={msg:?}");
1050        assert!(msg.contains("\"ch1\""), "msg={msg:?}");
1051    }
1052}