1use std::net::SocketAddr;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use anyhow::{anyhow, Context, Result};
6use axum::{
7 extract::Json,
8 extract::Query,
9 extract::State,
10 http::{header, Request, StatusCode},
11 middleware::{self, Next},
12 response::sse::{Event as SseEvent, KeepAlive, Sse},
13 response::{IntoResponse, Response},
14 routing::get,
15 Router,
16};
17use futures::Stream;
18use rmcp::transport::{StreamableHttpServerConfig, StreamableHttpService};
19use serde::Deserialize;
20use serde_json::Value;
21use tokio::sync::broadcast;
22use tokio::time::{Duration, Instant};
23
24use crate::core::context_os::ContextOsMetrics;
25use crate::engine::ContextEngine;
26use crate::tools::LeanCtxServer;
27
28pub mod context_views;
29
30#[cfg(feature = "team-server")]
31pub mod team;
32
33use std::pin::Pin;
35
36pub(crate) struct SseDisconnectGuard<I> {
37 pub(crate) inner: Pin<Box<dyn Stream<Item = I> + Send>>,
38 pub(crate) metrics: Arc<ContextOsMetrics>,
39}
40
41impl<I> Stream for SseDisconnectGuard<I> {
42 type Item = I;
43
44 fn poll_next(
45 mut self: Pin<&mut Self>,
46 cx: &mut std::task::Context<'_>,
47 ) -> std::task::Poll<Option<Self::Item>> {
48 self.inner.as_mut().poll_next(cx)
49 }
50}
51
52impl<I> Drop for SseDisconnectGuard<I> {
53 fn drop(&mut self) {
54 self.metrics.record_sse_disconnect();
55 }
56}
57
/// Settings for the HTTP / Unix-socket server.
#[derive(Debug, Clone)]
pub struct HttpServerConfig {
    /// Host name or IP address to listen on.
    pub host: String,
    /// TCP port to listen on.
    pub port: u16,
    /// Project root handed to the tool server and the request handlers.
    pub project_root: PathBuf,
    /// Optional bearer token; `validate` requires a non-empty one for non-loopback hosts.
    pub auth_token: Option<String>,
    /// Forwarded to the MCP transport's stateful-mode setting.
    pub stateful_mode: bool,
    /// Forwarded to the MCP transport's JSON-response setting.
    pub json_response: bool,
    /// When `true`, the MCP transport's allowed-hosts check is disabled.
    pub disable_host_check: bool,
    /// Explicit allowed-hosts list for the MCP transport (used when non-empty).
    pub allowed_hosts: Vec<String>,
    /// Default request body size limit, in bytes.
    pub max_body_bytes: usize,
    /// Number of permits in the concurrency semaphore.
    pub max_concurrency: usize,
    /// Refill rate of the rate limiter, in requests per second.
    pub max_rps: u32,
    /// Capacity of the rate limiter's token bucket.
    pub rate_burst: u32,
    /// Timeout applied to tool calls, in milliseconds.
    pub request_timeout_ms: u64,
}
74
75impl Default for HttpServerConfig {
76 fn default() -> Self {
77 let project_root = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
78 Self {
79 host: "127.0.0.1".to_string(),
80 port: 8080,
81 project_root,
82 auth_token: None,
83 stateful_mode: false,
84 json_response: true,
85 disable_host_check: false,
86 allowed_hosts: Vec::new(),
87 max_body_bytes: 2 * 1024 * 1024,
88 max_concurrency: 32,
89 max_rps: 50,
90 rate_burst: 100,
91 request_timeout_ms: 30_000,
92 }
93 }
94}
95
96impl HttpServerConfig {
97 pub fn validate(&self) -> Result<()> {
98 let host = self.host.trim().to_lowercase();
99 let is_loopback = host == "127.0.0.1" || host == "localhost" || host == "::1";
100 if !is_loopback && self.auth_token.as_deref().unwrap_or("").is_empty() {
101 return Err(anyhow!(
102 "Refusing to bind to host='{host}' without auth. Provide --auth-token (or bind to 127.0.0.1)."
103 ));
104 }
105 Ok(())
106 }
107
108 fn mcp_http_config(&self) -> StreamableHttpServerConfig {
109 let mut cfg = StreamableHttpServerConfig::default()
110 .with_stateful_mode(self.stateful_mode)
111 .with_json_response(self.json_response);
112
113 if self.disable_host_check {
114 cfg = cfg.disable_allowed_hosts();
115 return cfg;
116 }
117
118 if !self.allowed_hosts.is_empty() {
119 cfg = cfg.with_allowed_hosts(self.allowed_hosts.clone());
120 return cfg;
121 }
122
123 let host = self.host.trim();
125 if host == "127.0.0.1" || host == "localhost" || host == "::1" {
126 cfg.allowed_hosts.push(host.to_string());
127 }
128
129 cfg
130 }
131}
132
133#[derive(Clone)]
134struct AppState {
135 token: Option<String>,
136 concurrency: Arc<tokio::sync::Semaphore>,
137 rate: Arc<RateLimiter>,
138 project_root: String,
139 timeout: Duration,
140}
141
142#[derive(Debug)]
143struct RateLimiter {
144 max_rps: f64,
145 burst: f64,
146 state: tokio::sync::Mutex<RateState>,
147}
148
149#[derive(Debug, Clone, Copy)]
150struct RateState {
151 tokens: f64,
152 last: Instant,
153}
154
155impl RateLimiter {
156 fn new(max_rps: u32, burst: u32) -> Self {
157 let now = Instant::now();
158 Self {
159 max_rps: (max_rps.max(1)) as f64,
160 burst: (burst.max(1)) as f64,
161 state: tokio::sync::Mutex::new(RateState {
162 tokens: (burst.max(1)) as f64,
163 last: now,
164 }),
165 }
166 }
167
168 async fn allow(&self) -> bool {
169 let mut s = self.state.lock().await;
170 let now = Instant::now();
171 let elapsed = now.saturating_duration_since(s.last);
172 let refill = elapsed.as_secs_f64() * self.max_rps;
173 s.tokens = (s.tokens + refill).min(self.burst);
174 s.last = now;
175 if s.tokens >= 1.0 {
176 s.tokens -= 1.0;
177 true
178 } else {
179 false
180 }
181 }
182}
183
184async fn auth_middleware(
185 State(state): State<AppState>,
186 req: Request<axum::body::Body>,
187 next: Next,
188) -> Response {
189 if state.token.is_none() {
190 return next.run(req).await;
191 }
192
193 if req.uri().path() == "/health" {
194 return next.run(req).await;
195 }
196
197 let expected = state.token.as_deref().unwrap_or("");
198 let Some(h) = req.headers().get(header::AUTHORIZATION) else {
199 return StatusCode::UNAUTHORIZED.into_response();
200 };
201 let Ok(s) = h.to_str() else {
202 return StatusCode::UNAUTHORIZED.into_response();
203 };
204 let Some(token) = s
205 .strip_prefix("Bearer ")
206 .or_else(|| s.strip_prefix("bearer "))
207 else {
208 return StatusCode::UNAUTHORIZED.into_response();
209 };
210 if !constant_time_eq(token.as_bytes(), expected.as_bytes()) {
211 return StatusCode::UNAUTHORIZED.into_response();
212 }
213
214 next.run(req).await
215}
216
/// Compares two byte slices without short-circuiting on the first mismatch.
///
/// Slices of different length are reported as unequal immediately; for equal
/// lengths every byte pair is XOR-ed into an accumulator, which is zero only
/// when all bytes match.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }

    let mut diff = 0u8;
    for (lhs, rhs) in a.iter().zip(b) {
        diff |= lhs ^ rhs;
    }
    diff == 0
}
226
227async fn rate_limit_middleware(
228 State(state): State<AppState>,
229 req: Request<axum::body::Body>,
230 next: Next,
231) -> Response {
232 if req.uri().path() == "/health" {
233 return next.run(req).await;
234 }
235 if !state.rate.allow().await {
236 return StatusCode::TOO_MANY_REQUESTS.into_response();
237 }
238 next.run(req).await
239}
240
241async fn concurrency_middleware(
242 State(state): State<AppState>,
243 req: Request<axum::body::Body>,
244 next: Next,
245) -> Response {
246 if req.uri().path() == "/health" {
247 return next.run(req).await;
248 }
249 let Ok(permit) = state.concurrency.clone().try_acquire_owned() else {
250 return StatusCode::TOO_MANY_REQUESTS.into_response();
251 };
252 let resp = next.run(req).await;
253 drop(permit);
254 resp
255}
256
257async fn health() -> impl IntoResponse {
258 (StatusCode::OK, "ok\n")
259}
260
261#[derive(Debug, Deserialize)]
262#[serde(rename_all = "camelCase")]
263struct ToolCallBody {
264 name: String,
265 #[serde(default)]
266 arguments: Option<Value>,
267 #[serde(default)]
268 workspace_id: Option<String>,
269 #[serde(default)]
270 channel_id: Option<String>,
271}
272
273#[derive(Debug, Deserialize)]
274#[serde(rename_all = "camelCase")]
275struct EventsQuery {
276 #[serde(default)]
277 workspace_id: Option<String>,
278 #[serde(default)]
279 channel_id: Option<String>,
280 #[serde(default)]
281 since: Option<i64>,
282 #[serde(default)]
283 limit: Option<usize>,
284}
285
286async fn v1_manifest(State(state): State<AppState>) -> impl IntoResponse {
287 let _ = state;
288 let v = crate::core::mcp_manifest::manifest_value();
289 (StatusCode::OK, Json(v))
290}
291
292#[derive(Debug, Deserialize)]
293#[serde(rename_all = "camelCase")]
294struct ToolsQuery {
295 #[serde(default)]
296 offset: Option<usize>,
297 #[serde(default)]
298 limit: Option<usize>,
299}
300
301async fn v1_tools(State(state): State<AppState>, Query(q): Query<ToolsQuery>) -> impl IntoResponse {
302 let _ = state;
303 let v = crate::core::mcp_manifest::manifest_value();
304 let tools = v
305 .get("tools")
306 .and_then(|t| t.get("granular"))
307 .cloned()
308 .unwrap_or(Value::Array(vec![]));
309
310 let all = tools.as_array().cloned().unwrap_or_default();
311 let total = all.len();
312 let offset = q.offset.unwrap_or(0).min(total);
313 let limit = q.limit.unwrap_or(200).min(500);
314 let page = all.into_iter().skip(offset).take(limit).collect::<Vec<_>>();
315
316 (
317 StatusCode::OK,
318 Json(serde_json::json!({
319 "tools": page,
320 "total": total,
321 "offset": offset,
322 "limit": limit,
323 })),
324 )
325}
326
327async fn v1_tool_call(
328 State(state): State<AppState>,
329 Json(body): Json<ToolCallBody>,
330) -> impl IntoResponse {
331 let ws = body.workspace_id.as_deref().unwrap_or("default");
332 let ch = body.channel_id.as_deref().unwrap_or("default");
333 let server = LeanCtxServer::new_shared_with_context(&state.project_root, ws, ch);
334 let engine = ContextEngine::from_server(server);
335 match tokio::time::timeout(
336 state.timeout,
337 engine.call_tool_value(&body.name, body.arguments),
338 )
339 .await
340 {
341 Ok(Ok(v)) => (StatusCode::OK, Json(serde_json::json!({ "result": v }))).into_response(),
342 Ok(Err(e)) => {
343 tracing::warn!("tool call error: {e}");
344 (
345 StatusCode::BAD_REQUEST,
346 Json(serde_json::json!({ "error": "tool_error", "code": "TOOL_ERROR" })),
347 )
348 .into_response()
349 }
350 Err(_) => (
351 StatusCode::GATEWAY_TIMEOUT,
352 Json(serde_json::json!({ "error": "request_timeout" })),
353 )
354 .into_response(),
355 }
356}
357
358async fn v1_events(
359 State(_state): State<AppState>,
360 Query(q): Query<EventsQuery>,
361) -> Sse<impl Stream<Item = Result<SseEvent, std::convert::Infallible>>> {
362 use crate::core::context_os::{redact_event_payload, ContextEventV1, RedactionLevel};
363
364 let ws = q.workspace_id.unwrap_or_else(|| "default".to_string());
365 let ch = q.channel_id.unwrap_or_else(|| "default".to_string());
366 let since = q.since.unwrap_or(0);
367 let limit = q.limit.unwrap_or(200).min(1000);
368 let redaction = RedactionLevel::RefsOnly;
369
370 let rt = crate::core::context_os::runtime();
371 let replay = rt.bus.read(&ws, &ch, since, limit);
372 let rx = rt.bus.subscribe(&ws, &ch);
373 rt.metrics.record_sse_connect();
374 rt.metrics.record_events_replayed(replay.len() as u64);
375 rt.metrics.record_workspace_active(&ws);
376
377 let bus = rt.bus.clone();
378 let metrics = rt.metrics.clone();
379 let pending: std::collections::VecDeque<ContextEventV1> = replay.into();
380
381 let stream = futures::stream::unfold(
382 (
383 pending,
384 rx,
385 ws.clone(),
386 ch.clone(),
387 since,
388 redaction,
389 bus,
390 metrics,
391 ),
392 |(mut pending, mut rx, ws, ch, mut last_id, redaction, bus, metrics)| async move {
393 if let Some(mut ev) = pending.pop_front() {
394 last_id = ev.id;
395 redact_event_payload(&mut ev, redaction);
396 let data = serde_json::to_string(&ev).unwrap_or_else(|_| "{}".to_string());
397 let evt = SseEvent::default()
398 .id(ev.id.to_string())
399 .event(ev.kind)
400 .data(data);
401 return Some((
402 Ok(evt),
403 (pending, rx, ws, ch, last_id, redaction, bus, metrics),
404 ));
405 }
406
407 loop {
408 match rx.recv().await {
409 Ok(mut ev) if ev.id > last_id => {
410 last_id = ev.id;
411 redact_event_payload(&mut ev, redaction);
412 let data = serde_json::to_string(&ev).unwrap_or_else(|_| "{}".to_string());
413 let evt = SseEvent::default()
414 .id(ev.id.to_string())
415 .event(ev.kind)
416 .data(data);
417 return Some((
418 Ok(evt),
419 (pending, rx, ws, ch, last_id, redaction, bus, metrics),
420 ));
421 }
422 Ok(_) => {}
423 Err(broadcast::error::RecvError::Closed) => return None,
424 Err(broadcast::error::RecvError::Lagged(skipped)) => {
425 let missed = bus.read(&ws, &ch, last_id, skipped as usize);
426 metrics.record_events_replayed(missed.len() as u64);
427 for ev in missed {
428 last_id = last_id.max(ev.id);
429 pending.push_back(ev);
430 }
431 }
432 }
433 }
434 },
435 );
436
437 let metrics_ref = rt.metrics.clone();
438 let guarded = SseDisconnectGuard {
439 inner: Box::pin(stream),
440 metrics: metrics_ref,
441 };
442
443 Sse::new(guarded).keep_alive(KeepAlive::new().interval(Duration::from_secs(15)))
444}
445
446async fn v1_metrics(State(_state): State<AppState>) -> impl IntoResponse {
447 let rt = crate::core::context_os::runtime();
448 let snap = rt.metrics.snapshot();
449 (
450 StatusCode::OK,
451 Json(serde_json::to_value(snap).unwrap_or_default()),
452 )
453}
454
455async fn v1_a2a_handoff(
456 State(state): State<AppState>,
457 Json(body): Json<Value>,
458) -> impl IntoResponse {
459 let envelope = match crate::core::a2a_transport::parse_envelope(
460 &serde_json::to_string(&body).unwrap_or_default(),
461 ) {
462 Ok(env) => env,
463 Err(e) => {
464 return (
465 StatusCode::BAD_REQUEST,
466 Json(serde_json::json!({"error": e})),
467 );
468 }
469 };
470
471 let rt = crate::core::context_os::runtime();
472 rt.bus.append(
473 &state.project_root,
474 "a2a",
475 &crate::core::context_os::ContextEventKindV1::SessionMutated,
476 Some(&envelope.sender.agent_id),
477 serde_json::json!({
478 "type": "handoff_received",
479 "content_type": format!("{:?}", envelope.content_type),
480 "sender": envelope.sender.agent_id,
481 "payload_size": envelope.payload_json.len(),
482 }),
483 );
484
485 match envelope.content_type {
486 crate::core::a2a_transport::TransportContentType::ContextPackage => {
487 let tmp = std::env::temp_dir().join(format!(
488 "lean-ctx-a2a-{}.lctxpkg",
489 chrono::Utc::now().format("%Y%m%d_%H%M%S")
490 ));
491 if let Err(e) = std::fs::write(&tmp, &envelope.payload_json) {
492 return (
493 StatusCode::INTERNAL_SERVER_ERROR,
494 Json(serde_json::json!({"error": format!("write: {e}")})),
495 );
496 }
497 (
498 StatusCode::OK,
499 Json(serde_json::json!({
500 "status": "received",
501 "content_type": "context_package",
502 "stored": tmp.display().to_string(),
503 })),
504 )
505 }
506 crate::core::a2a_transport::TransportContentType::HandoffBundle => {
507 let dir = std::path::Path::new(&state.project_root)
508 .join(".lean-ctx")
509 .join("handoffs");
510 let _ = std::fs::create_dir_all(&dir);
511 let out = dir.join(format!(
512 "received-{}.json",
513 chrono::Utc::now().format("%Y%m%d_%H%M%S")
514 ));
515 if let Err(e) = std::fs::write(&out, &envelope.payload_json) {
516 return (
517 StatusCode::INTERNAL_SERVER_ERROR,
518 Json(serde_json::json!({"error": format!("write: {e}")})),
519 );
520 }
521 (
522 StatusCode::OK,
523 Json(serde_json::json!({
524 "status": "received",
525 "content_type": "handoff_bundle",
526 "stored": out.display().to_string(),
527 })),
528 )
529 }
530 _ => (
531 StatusCode::OK,
532 Json(serde_json::json!({
533 "status": "received",
534 "content_type": format!("{:?}", envelope.content_type),
535 })),
536 ),
537 }
538}
539
540async fn a2a_jsonrpc(Json(body): Json<Value>) -> impl IntoResponse {
541 let req: crate::core::a2a::a2a_compat::JsonRpcRequest = match serde_json::from_value(body) {
542 Ok(r) => r,
543 Err(e) => {
544 return (
545 StatusCode::BAD_REQUEST,
546 Json(serde_json::json!({
547 "jsonrpc": "2.0",
548 "id": null,
549 "error": {"code": -32700, "message": format!("parse error: {e}")}
550 })),
551 );
552 }
553 };
554 let resp = crate::core::a2a::a2a_compat::handle_a2a_jsonrpc(&req);
555 let json = serde_json::to_value(resp).unwrap_or_default();
556 (StatusCode::OK, Json(json))
557}
558
559async fn v1_a2a_agent_card(State(state): State<AppState>) -> impl IntoResponse {
560 let card = crate::core::a2a::agent_card::build_agent_card(&state.project_root);
561 (
562 StatusCode::OK,
563 [(header::CONTENT_TYPE, "application/json")],
564 Json(card),
565 )
566}
567
568pub async fn serve(cfg: HttpServerConfig) -> Result<()> {
569 cfg.validate()?;
570
571 let addr: SocketAddr = format!("{}:{}", cfg.host, cfg.port)
572 .parse()
573 .context("invalid host/port")?;
574
575 let project_root = cfg.project_root.to_string_lossy().to_string();
576 let service_project_root = project_root.clone();
579 let service_factory = move || -> Result<LeanCtxServer, std::io::Error> {
580 Ok(LeanCtxServer::new_shared_with_context(
581 &service_project_root,
582 "default",
583 "default",
584 ))
585 };
586 let mcp_http = StreamableHttpService::new(
587 service_factory,
588 Arc::new(
589 rmcp::transport::streamable_http_server::session::local::LocalSessionManager::default(),
590 ),
591 cfg.mcp_http_config(),
592 );
593
594 let state = AppState {
595 token: cfg.auth_token.clone().filter(|t| !t.is_empty()),
596 concurrency: Arc::new(tokio::sync::Semaphore::new(cfg.max_concurrency.max(1))),
597 rate: Arc::new(RateLimiter::new(cfg.max_rps, cfg.rate_burst)),
598 project_root: project_root.clone(),
599 timeout: Duration::from_millis(cfg.request_timeout_ms.max(1)),
600 };
601
602 let app = Router::new()
603 .route("/health", get(health))
604 .route("/v1/manifest", get(v1_manifest))
605 .route("/v1/tools", get(v1_tools))
606 .route("/v1/tools/call", axum::routing::post(v1_tool_call))
607 .route("/v1/events", get(v1_events))
608 .route(
609 "/v1/context/summary",
610 get(context_views::v1_context_summary),
611 )
612 .route("/v1/events/search", get(context_views::v1_events_search))
613 .route("/v1/events/lineage", get(context_views::v1_event_lineage))
614 .route("/v1/metrics", get(v1_metrics))
615 .route("/v1/a2a/handoff", axum::routing::post(v1_a2a_handoff))
616 .route("/v1/a2a/agent-card", get(v1_a2a_agent_card))
617 .route("/.well-known/agent.json", get(v1_a2a_agent_card))
618 .route("/a2a", axum::routing::post(a2a_jsonrpc))
619 .fallback_service(mcp_http)
620 .layer(axum::extract::DefaultBodyLimit::max(cfg.max_body_bytes))
621 .layer(middleware::from_fn_with_state(
622 state.clone(),
623 rate_limit_middleware,
624 ))
625 .layer(middleware::from_fn_with_state(
626 state.clone(),
627 concurrency_middleware,
628 ))
629 .layer(middleware::from_fn_with_state(
630 state.clone(),
631 auth_middleware,
632 ))
633 .with_state(state);
634
635 let listener = tokio::net::TcpListener::bind(addr)
636 .await
637 .with_context(|| format!("bind {addr}"))?;
638
639 tracing::info!(
640 "lean-ctx Streamable HTTP server listening on http://{addr} (project_root={})",
641 cfg.project_root.display()
642 );
643
644 axum::serve(listener, app)
645 .with_graceful_shutdown(async move {
646 let _ = tokio::signal::ctrl_c().await;
647 })
648 .await
649 .context("http server")?;
650 Ok(())
651}
652
653#[cfg(unix)]
654pub async fn serve_uds(cfg: HttpServerConfig, socket_path: PathBuf) -> Result<()> {
655 cfg.validate()?;
656
657 if socket_path.exists() {
658 std::fs::remove_file(&socket_path)
659 .with_context(|| format!("remove stale socket {}", socket_path.display()))?;
660 }
661
662 let project_root = cfg.project_root.to_string_lossy().to_string();
663 let service_project_root = project_root.clone();
664 let service_factory = move || -> Result<LeanCtxServer, std::io::Error> {
665 Ok(LeanCtxServer::new_shared_with_context(
666 &service_project_root,
667 "default",
668 "default",
669 ))
670 };
671 let mcp_http = StreamableHttpService::new(
672 service_factory,
673 Arc::new(
674 rmcp::transport::streamable_http_server::session::local::LocalSessionManager::default(),
675 ),
676 cfg.mcp_http_config(),
677 );
678
679 let state = AppState {
680 token: cfg.auth_token.clone().filter(|t| !t.is_empty()),
681 concurrency: Arc::new(tokio::sync::Semaphore::new(cfg.max_concurrency.max(1))),
682 rate: Arc::new(RateLimiter::new(cfg.max_rps, cfg.rate_burst)),
683 project_root: project_root.clone(),
684 timeout: Duration::from_millis(cfg.request_timeout_ms.max(1)),
685 };
686
687 let app = Router::new()
688 .route("/health", get(health))
689 .route("/v1/manifest", get(v1_manifest))
690 .route("/v1/tools", get(v1_tools))
691 .route("/v1/tools/call", axum::routing::post(v1_tool_call))
692 .route("/v1/events", get(v1_events))
693 .route(
694 "/v1/context/summary",
695 get(context_views::v1_context_summary),
696 )
697 .route("/v1/events/search", get(context_views::v1_events_search))
698 .route("/v1/events/lineage", get(context_views::v1_event_lineage))
699 .route("/v1/metrics", get(v1_metrics))
700 .route("/v1/a2a/handoff", axum::routing::post(v1_a2a_handoff))
701 .route("/v1/a2a/agent-card", get(v1_a2a_agent_card))
702 .route("/.well-known/agent.json", get(v1_a2a_agent_card))
703 .route("/a2a", axum::routing::post(a2a_jsonrpc))
704 .fallback_service(mcp_http)
705 .layer(axum::extract::DefaultBodyLimit::max(cfg.max_body_bytes))
706 .layer(middleware::from_fn_with_state(
707 state.clone(),
708 rate_limit_middleware,
709 ))
710 .layer(middleware::from_fn_with_state(
711 state.clone(),
712 concurrency_middleware,
713 ))
714 .layer(middleware::from_fn_with_state(
715 state.clone(),
716 auth_middleware,
717 ))
718 .with_state(state);
719
720 let listener = tokio::net::UnixListener::bind(&socket_path)
721 .with_context(|| format!("bind UDS {}", socket_path.display()))?;
722
723 {
724 use std::os::unix::fs::PermissionsExt;
725 let perms = std::fs::Permissions::from_mode(0o600);
726 std::fs::set_permissions(&socket_path, perms)
727 .with_context(|| format!("chmod 600 UDS {}", socket_path.display()))?;
728 }
729
730 tracing::info!(
731 "lean-ctx daemon listening on {} (project_root={})",
732 socket_path.display(),
733 cfg.project_root.display()
734 );
735
736 axum::serve(listener, app.into_make_service())
737 .with_graceful_shutdown(async move {
738 let _ = tokio::signal::ctrl_c().await;
739 })
740 .await
741 .context("uds server")?;
742 Ok(())
743}
744
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Request;
    use futures::StreamExt;
    use rmcp::transport::{StreamableHttpServerConfig, StreamableHttpService};
    use serde_json::json;
    use tower::ServiceExt;

    /// Builds an `AppState` with a 30 s tool timeout for the tests below.
    fn test_state(
        token: Option<&str>,
        permits: usize,
        max_rps: u32,
        burst: u32,
        project_root: &str,
    ) -> AppState {
        AppState {
            token: token.map(str::to_string),
            concurrency: Arc::new(tokio::sync::Semaphore::new(permits)),
            rate: Arc::new(RateLimiter::new(max_rps, burst)),
            project_root: project_root.to_string(),
            timeout: Duration::from_millis(30_000),
        }
    }

    /// Reads body chunks until the first blank line (end of one SSE message),
    /// a 2 s read timeout, the end of the body, or 32 chunks — whichever
    /// comes first — and returns what was collected as lossy UTF-8.
    async fn read_first_sse_message(body: Body) -> String {
        let mut chunks = body.into_data_stream();
        let mut collected: Vec<u8> = Vec::new();
        for _attempt in 0..32 {
            match tokio::time::timeout(Duration::from_secs(2), chunks.next()).await {
                Ok(Some(Ok(bytes))) => collected.extend_from_slice(&bytes),
                _ => break,
            }
            if collected.windows(2).any(|pair| pair == b"\n\n") {
                break;
            }
        }
        String::from_utf8_lossy(&collected).to_string()
    }

    /// With a token configured, an MCP request without an `Authorization`
    /// header must be rejected with 401.
    #[tokio::test]
    async fn auth_token_blocks_requests_without_bearer_header() {
        let dir = tempfile::tempdir().expect("tempdir");
        let root_str = dir.path().to_string_lossy().to_string();

        let factory_root = root_str.clone();
        let service_factory = move || -> Result<LeanCtxServer, std::io::Error> {
            Ok(LeanCtxServer::new_shared_with_context(
                &factory_root,
                "default",
                "default",
            ))
        };
        let mcp_http = StreamableHttpService::new(
            service_factory,
            Arc::new(
                rmcp::transport::streamable_http_server::session::local::LocalSessionManager::default(),
            ),
            StreamableHttpServerConfig::default()
                .with_stateful_mode(false)
                .with_json_response(true),
        );

        let state = test_state(Some("secret"), 4, 50, 100, &root_str);
        let app = Router::new()
            .fallback_service(mcp_http)
            .layer(middleware::from_fn_with_state(
                state.clone(),
                auth_middleware,
            ))
            .with_state(state);

        let payload = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "tools/list",
            "params": {}
        })
        .to_string();

        let req = Request::builder()
            .method("POST")
            .uri("/")
            .header("Host", "localhost")
            .header("Accept", "application/json, text/event-stream")
            .header("Content-Type", "application/json")
            .body(Body::from(payload))
            .expect("request");

        let resp = app.clone().oneshot(req).await.expect("resp");
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    /// Two servers produced by the same factory must not share their
    /// `client_name` state.
    #[tokio::test]
    async fn mcp_service_factory_isolates_per_client_state() {
        let dir = tempfile::tempdir().expect("tempdir");
        let root_str = dir.path().to_string_lossy().to_string();

        let factory_root = root_str.clone();
        let service_factory = move || -> Result<LeanCtxServer, std::convert::Infallible> {
            Ok(LeanCtxServer::new_shared_with_context(
                &factory_root,
                "default",
                "default",
            ))
        };

        let first = service_factory().expect("server 1");
        let second = service_factory().expect("server 2");

        *first.client_name.write().await = "client-a".to_string();
        *second.client_name.write().await = "client-b".to_string();

        let name_a = first.client_name.read().await.clone();
        let name_b = second.client_name.read().await.clone();
        assert_eq!(name_a, "client-a");
        assert_eq!(name_b, "client-b");
    }

    /// With a bucket of one token, the second back-to-back request gets 429.
    #[tokio::test]
    async fn rate_limit_returns_429_when_exhausted() {
        let state = test_state(None, 16, 1, 1, ".");

        let app = Router::new()
            .route("/limited", get(|| async { (StatusCode::OK, "ok\n") }))
            .layer(middleware::from_fn_with_state(
                state.clone(),
                rate_limit_middleware,
            ))
            .with_state(state);

        let first = Request::builder()
            .method("GET")
            .uri("/limited")
            .header("Host", "localhost")
            .body(Body::empty())
            .expect("req1");
        let first_resp = app.clone().oneshot(first).await.expect("resp1");
        assert_eq!(first_resp.status(), StatusCode::OK);

        let second = Request::builder()
            .method("GET")
            .uri("/limited")
            .header("Host", "localhost")
            .body(Body::empty())
            .expect("req2");
        let second_resp = app.clone().oneshot(second).await.expect("resp2");
        assert_eq!(second_resp.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    /// A tool call made for a workspace/channel must show up as the first
    /// replayed event on that workspace/channel's event stream.
    #[tokio::test]
    async fn events_endpoint_replays_tool_call_event() {
        let dir = tempfile::tempdir().expect("tempdir");
        std::fs::create_dir_all(dir.path().join(".git")).expect("git marker");
        std::fs::write(dir.path().join("a.txt"), "ok").expect("file");
        let root_str = dir.path().to_string_lossy().to_string();

        let state = test_state(None, 16, 50, 100, &root_str);
        let app = Router::new()
            .route("/v1/tools/call", axum::routing::post(v1_tool_call))
            .route("/v1/events", get(v1_events))
            .with_state(state);

        let payload = json!({
            "name": "ctx_session",
            "arguments": { "action": "status" },
            "workspaceId": "ws1",
            "channelId": "ch1"
        })
        .to_string();
        let call = Request::builder()
            .method("POST")
            .uri("/v1/tools/call")
            .header("Host", "localhost")
            .header("Content-Type", "application/json")
            .body(Body::from(payload))
            .expect("req");
        let call_resp = app.clone().oneshot(call).await.expect("call");
        assert_eq!(call_resp.status(), StatusCode::OK);

        // Short pause between the tool call and opening the event stream.
        tokio::time::sleep(Duration::from_millis(250)).await;

        let events = Request::builder()
            .method("GET")
            .uri("/v1/events?workspaceId=ws1&channelId=ch1&since=0&limit=1")
            .header("Host", "localhost")
            .header("Accept", "text/event-stream")
            .body(Body::empty())
            .expect("req");
        let events_resp = app.clone().oneshot(events).await.expect("events");
        assert_eq!(events_resp.status(), StatusCode::OK);

        let msg = read_first_sse_message(events_resp.into_body()).await;
        assert!(msg.contains("event: tool_call_recorded"), "msg={msg:?}");
        assert!(msg.contains("\"workspaceId\":\"ws1\""), "msg={msg:?}");
        assert!(msg.contains("\"channelId\":\"ch1\""), "msg={msg:?}");
    }
}