use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use anyhow::{anyhow, Context, Result};
use axum::{
extract::Json,
extract::Query,
extract::State,
http::{header, Request, StatusCode},
middleware::{self, Next},
response::sse::{Event as SseEvent, KeepAlive, Sse},
response::{IntoResponse, Response},
routing::get,
Router,
};
use futures::Stream;
use rmcp::transport::{StreamableHttpServerConfig, StreamableHttpService};
use serde::Deserialize;
use serde_json::Value;
use tokio::sync::broadcast;
use tokio::time::{Duration, Instant};
use crate::core::context_os::ContextOsMetrics;
use crate::engine::ContextEngine;
use crate::tools::LeanCtxServer;
pub mod context_views;
#[cfg(feature = "team-server")]
pub mod team;
use std::pin::Pin;
pub(crate) struct SseDisconnectGuard<I> {
pub(crate) inner: Pin<Box<dyn Stream<Item = I> + Send>>,
pub(crate) metrics: Arc<ContextOsMetrics>,
}
impl<I> Stream for SseDisconnectGuard<I> {
type Item = I;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
self.inner.as_mut().poll_next(cx)
}
}
impl<I> Drop for SseDisconnectGuard<I> {
fn drop(&mut self) {
self.metrics.record_sse_disconnect();
}
}
/// Maximum number of characters kept by [`sanitize_id`].
const MAX_ID_LEN: usize = 64;

/// Normalizes a client-supplied workspace/channel identifier.
///
/// Surrounding whitespace is trimmed, every character outside `[A-Za-z0-9._-]` is dropped,
/// and the result is truncated to [`MAX_ID_LEN`] characters. If nothing survives, the
/// identifier falls back to `"default"`.
fn sanitize_id(raw: &str) -> String {
    let is_allowed = |c: &char| c.is_ascii_alphanumeric() || matches!(*c, '-' | '_' | '.');
    let kept: String = raw
        .trim()
        .chars()
        .filter(is_allowed)
        .take(MAX_ID_LEN)
        .collect();
    // An all-whitespace input also ends up empty here, so one check covers both cases.
    if kept.is_empty() {
        String::from("default")
    } else {
        kept
    }
}
/// Settings for the Streamable HTTP / REST server built by `build_app_router`.
#[derive(Clone, Debug)]
pub struct HttpServerConfig {
    /// Host to bind; `validate` requires an auth token unless this is a loopback host.
    pub host: String,
    /// TCP port to bind.
    pub port: u16,
    /// Project root handed to every `LeanCtxServer` the router creates.
    pub project_root: PathBuf,
    /// Bearer token required on all routes except `/health`; `None` or empty disables auth.
    pub auth_token: Option<String>,
    /// Forwarded to rmcp's `StreamableHttpServerConfig::with_stateful_mode`.
    pub stateful_mode: bool,
    /// Forwarded to rmcp's `StreamableHttpServerConfig::with_json_response`.
    pub json_response: bool,
    /// Turns off rmcp's allowed-hosts check (DNS-rebinding protection) for the MCP endpoint.
    pub disable_host_check: bool,
    /// Explicit allowed `Host` values for the MCP endpoint; when empty, defaults are used.
    pub allowed_hosts: Vec<String>,
    /// Request body size limit (bytes), applied via `DefaultBodyLimit`.
    pub max_body_bytes: usize,
    /// Maximum in-flight requests (clamped to at least 1); excess requests get `429`.
    pub max_concurrency: usize,
    /// Sustained request rate of the global token bucket (clamped to at least 1/s).
    pub max_rps: u32,
    /// Token-bucket capacity, i.e. the largest burst admitted at once (clamped to at least 1).
    pub rate_burst: u32,
    /// Timeout for `/v1/tools/call`, in milliseconds (clamped to at least 1 ms).
    pub request_timeout_ms: u64,
}
impl Default for HttpServerConfig {
fn default() -> Self {
let project_root = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
Self {
host: "127.0.0.1".to_string(),
port: 8080,
project_root,
auth_token: None,
stateful_mode: false,
json_response: true,
disable_host_check: false,
allowed_hosts: Vec::new(),
max_body_bytes: 2 * 1024 * 1024,
max_concurrency: 32,
max_rps: 50,
rate_burst: 100,
request_timeout_ms: 30_000,
}
}
}
impl HttpServerConfig {
    /// Checks that the configuration is safe to serve.
    ///
    /// # Errors
    /// Fails when the server would bind to a non-loopback host without an auth token.
    /// A token that is empty or whitespace-only counts as missing, because it cannot act as
    /// a real credential.
    pub fn validate(&self) -> Result<()> {
        let host = self.host.trim().to_lowercase();
        let has_token = self
            .auth_token
            .as_deref()
            .is_some_and(|token| !token.trim().is_empty());
        if !Self::is_loopback_host(&host) && !has_token {
            return Err(anyhow!(
                "Refusing to bind to host='{host}' without auth. Provide --auth-token (or bind to 127.0.0.1)."
            ));
        }
        Ok(())
    }

    /// Returns `true` if `host` is `localhost` (any case) or an IPv4/IPv6 loopback literal,
    /// with or without IPv6 brackets (e.g. `127.0.0.1`, `127.0.0.2`, `::1`, `[::1]`).
    fn is_loopback_host(host: &str) -> bool {
        let host = host.trim();
        let host = host
            .strip_prefix('[')
            .and_then(|inner| inner.strip_suffix(']'))
            .unwrap_or(host);
        host.eq_ignore_ascii_case("localhost")
            || host
                .parse::<std::net::IpAddr>()
                .is_ok_and(|ip| ip.is_loopback())
    }

    /// Builds the rmcp Streamable HTTP configuration for the MCP fallback service.
    ///
    /// Host checking is handled in this order:
    /// - If `disable_host_check` is set, the allowed-hosts check is switched off, with a warning.
    /// - Otherwise an explicit `allowed_hosts` list wins.
    /// - Otherwise the bind host is appended to rmcp's default list, but only when it is
    ///   exactly `127.0.0.1`, `localhost` or `::1`.
    fn mcp_http_config(&self) -> StreamableHttpServerConfig {
        let mut cfg = StreamableHttpServerConfig::default()
            .with_stateful_mode(self.stateful_mode)
            .with_json_response(self.json_response);
        if self.disable_host_check {
            tracing::warn!(
                "⚠ --disable-host-check is active: DNS rebinding protection is OFF. \
                 Do NOT use this in production or on non-loopback interfaces."
            );
            cfg = cfg.disable_allowed_hosts();
            return cfg;
        }
        if !self.allowed_hosts.is_empty() {
            cfg = cfg.with_allowed_hosts(self.allowed_hosts.clone());
            return cfg;
        }
        let host = self.host.trim();
        if host == "127.0.0.1" || host == "localhost" || host == "::1" {
            cfg.allowed_hosts.push(host.to_string());
        }
        cfg
    }
}
/// Shared state handed to every middleware and REST handler (cloned per request).
#[derive(Clone)]
struct AppState {
    /// Expected bearer token; `None` disables authentication.
    token: Option<String>,
    /// Limits in-flight requests; see `concurrency_middleware`.
    concurrency: Arc<tokio::sync::Semaphore>,
    /// Global token-bucket limiter; see `rate_limit_middleware`.
    rate: Arc<RateLimiter>,
    /// Project root as a (lossy) UTF-8 string.
    project_root: String,
    /// Upper bound on a single `/v1/tools/call` execution.
    timeout: Duration,
    /// Server instance used by the REST tool-call endpoint.
    server: LeanCtxServer,
}
/// Token-bucket rate limiter: holds up to `burst` tokens, refills at `max_rps` tokens per
/// second, and each admitted request spends one token.
#[derive(Debug)]
struct RateLimiter {
    /// Refill rate, in tokens per second (at least 1).
    max_rps: f64,
    /// Bucket capacity (at least 1).
    burst: f64,
    /// Current bucket level plus the time it was last refilled.
    state: tokio::sync::Mutex<RateState>,
}

/// Mutable bucket contents guarded by [`RateLimiter::state`].
#[derive(Debug, Clone, Copy)]
struct RateState {
    tokens: f64,
    last: Instant,
}

impl RateLimiter {
    /// Creates a limiter that starts with a full bucket. Zero arguments are clamped to 1.
    fn new(max_rps: u32, burst: u32) -> Self {
        let capacity = f64::from(burst.max(1));
        RateLimiter {
            max_rps: f64::from(max_rps.max(1)),
            burst: capacity,
            state: tokio::sync::Mutex::new(RateState {
                tokens: capacity,
                last: Instant::now(),
            }),
        }
    }

    /// Refills the bucket for the time elapsed since the last call, then takes one token if
    /// available. Returns whether the request is admitted.
    async fn allow(&self) -> bool {
        let mut bucket = self.state.lock().await;
        let now = Instant::now();
        let refill = now.saturating_duration_since(bucket.last).as_secs_f64() * self.max_rps;
        bucket.tokens = self.burst.min(bucket.tokens + refill);
        bucket.last = now;
        let admitted = bucket.tokens >= 1.0;
        if admitted {
            bucket.tokens -= 1.0;
        }
        admitted
    }
}
/// Bearer-token authentication.
///
/// - With no token configured, every request passes.
/// - Otherwise every path except `/health` must carry `Authorization: Bearer <token>`
///   matching the configured token (compared in constant time); failures get `401`.
async fn auth_middleware(
    State(state): State<AppState>,
    req: Request<axum::body::Body>,
    next: Next,
) -> Response {
    let Some(expected) = state.token.as_deref() else {
        return next.run(req).await;
    };
    // The liveness probe stays reachable without credentials.
    if req.uri().path() == "/health" {
        return next.run(req).await;
    }
    let authorized = req
        .headers()
        .get(header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .and_then(bearer_token)
        .is_some_and(|token| constant_time_eq(token.as_bytes(), expected.as_bytes()));
    if !authorized {
        return StatusCode::UNAUTHORIZED.into_response();
    }
    next.run(req).await
}

/// Extracts the credential from an `Authorization` header value of the form
/// `Bearer <token>`.
///
/// The scheme name is matched case-insensitively (RFC 7235 §2.1), so `BEARER x` is also
/// accepted. Whitespace around the token is ignored. Returns `None` for other schemes or
/// an empty token.
fn bearer_token(value: &str) -> Option<&str> {
    let (scheme, token) = value.trim().split_once(' ')?;
    let token = token.trim();
    (scheme.eq_ignore_ascii_case("bearer") && !token.is_empty()).then_some(token)
}
/// Compares two byte strings without data-dependent early exit on their contents.
///
/// Inputs of different length compare unequal immediately. Only the length can leak through
/// timing; the bytes themselves are compared with `subtle`'s constant-time equality.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    use subtle::ConstantTimeEq;
    a.len() == b.len() && bool::from(a.ct_eq(b))
}
/// Rejects the request with `429 Too Many Requests` when the shared token bucket is empty;
/// otherwise passes it on.
async fn rate_limit_middleware(
    State(state): State<AppState>,
    req: Request<axum::body::Body>,
    next: Next,
) -> Response {
    let admitted = state.rate.allow().await;
    if admitted {
        next.run(req).await
    } else {
        StatusCode::TOO_MANY_REQUESTS.into_response()
    }
}
async fn concurrency_middleware(
State(state): State<AppState>,
req: Request<axum::body::Body>,
next: Next,
) -> Response {
let Ok(permit) = state.concurrency.clone().try_acquire_owned() else {
return StatusCode::TOO_MANY_REQUESTS.into_response();
};
let resp = next.run(req).await;
drop(permit);
resp
}
/// `GET /health`: liveness probe. It always answers `200 OK` with body `"ok\n"` and is
/// exempt from bearer authentication.
async fn health() -> impl IntoResponse {
    const BODY: &str = "ok\n";
    (StatusCode::OK, BODY)
}
async fn v1_shutdown() -> impl IntoResponse {
tokio::spawn(async {
tokio::time::sleep(Duration::from_millis(100)).await;
std::process::exit(0);
});
(StatusCode::OK, "shutting down\n")
}
/// JSON body of `POST /v1/tools/call` (camelCase keys).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
// `workspace_id` / `channel_id` are accepted for forward compatibility but not read by
// `v1_tool_call` yet, hence the lint allowance.
#[allow(dead_code)]
struct ToolCallBody {
    /// Tool to invoke.
    name: String,
    /// Tool arguments, forwarded as-is.
    #[serde(default)]
    arguments: Option<Value>,
    /// Target workspace (currently unused).
    #[serde(default)]
    workspace_id: Option<String>,
    /// Target channel (currently unused).
    #[serde(default)]
    channel_id: Option<String>,
}
/// Query string of `GET /v1/events` (camelCase keys).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct EventsQuery {
    /// Workspace to stream; sanitized, defaults to `"default"`.
    #[serde(default)]
    workspace_id: Option<String>,
    /// Channel to stream; sanitized, defaults to `"default"`.
    #[serde(default)]
    channel_id: Option<String>,
    /// Replay starting point (event id); defaults to 0.
    #[serde(default)]
    since: Option<i64>,
    /// Maximum number of replayed events; defaults to 200, capped at 1000.
    #[serde(default)]
    limit: Option<usize>,
    /// Comma-separated list of event kinds to keep; absent means all kinds.
    #[serde(default)]
    kind: Option<String>,
}
/// `GET /v1/manifest`: returns the full MCP manifest as JSON.
async fn v1_manifest(State(_): State<AppState>) -> impl IntoResponse {
    let manifest = crate::core::mcp_manifest::manifest_value();
    (StatusCode::OK, Json(manifest))
}
/// Pagination parameters of `GET /v1/tools` (camelCase keys).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ToolsQuery {
    /// Index of the first tool to return; defaults to 0 and is clamped to the tool count.
    #[serde(default)]
    offset: Option<usize>,
    /// Page size; defaults to 200, capped at 500.
    #[serde(default)]
    limit: Option<usize>,
}
async fn v1_tools(State(state): State<AppState>, Query(q): Query<ToolsQuery>) -> impl IntoResponse {
let _ = state;
let v = crate::core::mcp_manifest::manifest_value();
let tools = v
.get("tools")
.and_then(|t| t.get("granular"))
.cloned()
.unwrap_or(Value::Array(vec![]));
let all = tools.as_array().cloned().unwrap_or_default();
let total = all.len();
let offset = q.offset.unwrap_or(0).min(total);
let limit = q.limit.unwrap_or(200).min(500);
let page = all.into_iter().skip(offset).take(limit).collect::<Vec<_>>();
(
StatusCode::OK,
Json(serde_json::json!({
"tools": page,
"total": total,
"offset": offset,
"limit": limit,
})),
)
}
async fn v1_tool_call(
State(state): State<AppState>,
Json(body): Json<ToolCallBody>,
) -> impl IntoResponse {
let engine = ContextEngine::from_server(state.server.clone());
match tokio::time::timeout(
state.timeout,
engine.call_tool_value(&body.name, body.arguments),
)
.await
{
Ok(Ok(v)) => (StatusCode::OK, Json(serde_json::json!({ "result": v }))).into_response(),
Ok(Err(e)) => {
tracing::warn!("tool call error: {e}");
(
StatusCode::BAD_REQUEST,
Json(serde_json::json!({ "error": "tool_error", "code": "TOOL_ERROR" })),
)
.into_response()
}
Err(_) => (
StatusCode::GATEWAY_TIMEOUT,
Json(serde_json::json!({ "error": "request_timeout" })),
)
.into_response(),
}
}
/// `GET /v1/events`: Server-Sent Events feed of context-OS events for one
/// workspace/channel.
///
/// Query parameters (camelCase):
/// - `workspaceId`, `channelId`: passed through `sanitize_id`; default `"default"`.
/// - `since`: replay events after this id; default `0`.
/// - `limit`: cap on replayed events; default 200, max 1000.
/// - `kind`: comma-separated event kinds to keep. Blank entries are ignored, and an
///   all-blank value disables filtering.
///
/// The stream first yields events read from the bus (replay), then live events from a bus
/// subscription. Each frame carries the event id as `id`, the kind as `event`, and the event
/// JSON, redacted with `RedactionLevel::RefsOnly`, as `data`. If the bus refuses the
/// subscription (subscriber limit), only the replay is sent and the stream then ends.
///
/// When the live subscription lags, the skipped events are re-read from the bus (applying
/// the kind filter) and delivered immediately and in order, before any newer live event.
async fn v1_events(
    State(_state): State<AppState>,
    Query(q): Query<EventsQuery>,
) -> Sse<impl Stream<Item = Result<SseEvent, std::convert::Infallible>>> {
    use crate::core::context_os::{redact_event_payload, ContextEventV1, RedactionLevel};

    /// Redacts `ev` and renders it as a single SSE frame.
    fn to_sse_event(mut ev: ContextEventV1, redaction: RedactionLevel) -> SseEvent {
        redact_event_payload(&mut ev, redaction);
        let data = serde_json::to_string(&ev).unwrap_or_else(|_| "{}".to_string());
        SseEvent::default()
            .id(ev.id.to_string())
            .event(ev.kind)
            .data(data)
    }

    /// Keeps only events whose kind is listed in `kinds`; `None` keeps everything.
    fn retain_kinds(
        events: Vec<ContextEventV1>,
        kinds: Option<&[String]>,
    ) -> Vec<ContextEventV1> {
        match kinds {
            Some(kinds) => events
                .into_iter()
                .filter(|ev| kinds.contains(&ev.kind))
                .collect(),
            None => events,
        }
    }

    let ws = sanitize_id(&q.workspace_id.unwrap_or_else(|| "default".to_string()));
    let ch = sanitize_id(&q.channel_id.unwrap_or_else(|| "default".to_string()));
    let since = q.since.unwrap_or(0);
    let limit = q.limit.unwrap_or(200).min(1000);
    let redaction = RedactionLevel::RefsOnly;
    let kind_filter: Option<Vec<String>> = q
        .kind
        .as_deref()
        .map(|k| {
            k.split(',')
                .map(str::trim)
                .filter(|s| !s.is_empty())
                .map(str::to_string)
                .collect::<Vec<_>>()
        })
        .filter(|kinds| !kinds.is_empty());

    let rt = crate::core::context_os::runtime();

    // Subscribe BEFORE reading the replay. Otherwise, events appended between the read and
    // the subscription would be lost. Any overlap between replay and live delivery is dropped
    // by the `id > last_id` check in the stream below.
    let rx = if let Some(ref kinds) = kind_filter {
        let kind_refs: Vec<&str> = kinds.iter().map(String::as_str).collect();
        let filter = crate::core::context_os::TopicFilter::kinds(&kind_refs);
        if let Some(sub) = rt.bus.subscribe_filtered(&ws, &ch, filter) {
            crate::core::context_os::SubscriptionKind::Filtered(sub)
        } else {
            tracing::warn!("SSE subscriber limit reached for {ws}/{ch}");
            // Sender dropped at once: `recv` reports `Closed`, ending the stream after replay.
            let (_, rx) = broadcast::channel::<ContextEventV1>(1);
            crate::core::context_os::SubscriptionKind::Unfiltered(rx)
        }
    } else if let Some(sub) = rt.bus.subscribe(&ws, &ch) {
        crate::core::context_os::SubscriptionKind::Unfiltered(sub)
    } else {
        tracing::warn!("SSE subscriber limit reached for {ws}/{ch}");
        let (_, rx) = broadcast::channel::<ContextEventV1>(1);
        crate::core::context_os::SubscriptionKind::Unfiltered(rx)
    };

    let replay = retain_kinds(rt.bus.read(&ws, &ch, since, limit), kind_filter.as_deref());

    rt.metrics.record_sse_connect();
    rt.metrics.record_events_replayed(replay.len() as u64);
    rt.metrics.record_workspace_active(&ws);

    let bus = rt.bus.clone();
    let metrics = rt.metrics.clone();
    let pending: std::collections::VecDeque<ContextEventV1> = replay.into();
    let stream = futures::stream::unfold(
        (
            pending,
            rx,
            ws,
            ch,
            since,
            redaction,
            kind_filter,
            bus,
            metrics,
        ),
        |(mut pending, mut rx, ws, ch, mut last_id, redaction, kinds, bus, metrics)| async move {
            loop {
                // Queued events (initial replay or lag recovery) always go out first, so
                // recovered events are never stuck behind, or reordered after, live ones.
                if let Some(ev) = pending.pop_front() {
                    last_id = last_id.max(ev.id);
                    let evt = to_sse_event(ev, redaction);
                    return Some((
                        Ok(evt),
                        (pending, rx, ws, ch, last_id, redaction, kinds, bus, metrics),
                    ));
                }
                match rx.recv().await {
                    Ok(ev) if ev.id > last_id => {
                        last_id = ev.id;
                        let evt = to_sse_event(ev, redaction);
                        return Some((
                            Ok(evt),
                            (pending, rx, ws, ch, last_id, redaction, kinds, bus, metrics),
                        ));
                    }
                    // Already delivered (e.g. also present in the replay).
                    Ok(_) => {}
                    Err(broadcast::error::RecvError::Closed) => return None,
                    Err(broadcast::error::RecvError::Lagged(skipped)) => {
                        let missed = retain_kinds(
                            bus.read(&ws, &ch, last_id, skipped as usize),
                            kinds.as_deref(),
                        );
                        metrics.record_events_replayed(missed.len() as u64);
                        pending.extend(missed);
                    }
                }
            }
        },
    );

    let guarded = SseDisconnectGuard {
        inner: Box::pin(stream),
        metrics: rt.metrics.clone(),
    };
    Sse::new(guarded).keep_alive(KeepAlive::new().interval(Duration::from_secs(15)))
}
/// `GET /v1/metrics`: JSON snapshot of the context-OS runtime metrics. If the snapshot
/// cannot be serialized, the body is JSON `null`.
async fn v1_metrics(State(_state): State<AppState>) -> impl IntoResponse {
    let runtime = crate::core::context_os::runtime();
    let snapshot = runtime.metrics.snapshot();
    let body = serde_json::to_value(snapshot).unwrap_or_default();
    (StatusCode::OK, Json(body))
}
const MAX_HANDOFF_PAYLOAD_BYTES: usize = 1_000_000;
const MAX_HANDOFF_FILES: usize = 50;
async fn v1_a2a_handoff(
State(state): State<AppState>,
Json(body): Json<Value>,
) -> impl IntoResponse {
let envelope = match crate::core::a2a_transport::parse_envelope(
&serde_json::to_string(&body).unwrap_or_default(),
) {
Ok(env) => env,
Err(e) => {
tracing::warn!("a2a handoff parse error: {e}");
return (
StatusCode::BAD_REQUEST,
Json(serde_json::json!({"error": "invalid_envelope"})),
);
}
};
if envelope.payload_json.len() > MAX_HANDOFF_PAYLOAD_BYTES {
tracing::warn!(
"a2a handoff payload too large: {} bytes (limit {MAX_HANDOFF_PAYLOAD_BYTES})",
envelope.payload_json.len()
);
return (
StatusCode::PAYLOAD_TOO_LARGE,
Json(serde_json::json!({"error": "payload_too_large"})),
);
}
let rt = crate::core::context_os::runtime();
rt.bus.append(
&state.project_root,
"a2a",
&crate::core::context_os::ContextEventKindV1::SessionMutated,
Some(&envelope.sender.agent_id),
serde_json::json!({
"type": "handoff_received",
"content_type": format!("{:?}", envelope.content_type),
"sender": envelope.sender.agent_id,
"payload_size": envelope.payload_json.len(),
}),
);
match envelope.content_type {
crate::core::a2a_transport::TransportContentType::ContextPackage => {
let dir = std::path::Path::new(&state.project_root)
.join(".lean-ctx")
.join("handoffs")
.join("packages");
let _ = std::fs::create_dir_all(&dir);
evict_oldest_files(&dir, MAX_HANDOFF_FILES);
let out = dir.join(format!(
"ctx-{}.lctxpkg",
chrono::Utc::now().format("%Y%m%d_%H%M%S")
));
if let Err(e) = std::fs::write(&out, &envelope.payload_json) {
tracing::error!("a2a handoff write failed: {e}");
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({"error": "write_failed"})),
);
}
(
StatusCode::OK,
Json(serde_json::json!({
"status": "received",
"content_type": "context_package",
})),
)
}
crate::core::a2a_transport::TransportContentType::HandoffBundle => {
let dir = std::path::Path::new(&state.project_root)
.join(".lean-ctx")
.join("handoffs");
let _ = std::fs::create_dir_all(&dir);
evict_oldest_files(&dir, MAX_HANDOFF_FILES);
let out = dir.join(format!(
"received-{}.json",
chrono::Utc::now().format("%Y%m%d_%H%M%S")
));
if let Err(e) = std::fs::write(&out, &envelope.payload_json) {
tracing::error!("a2a handoff write failed: {e}");
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({"error": "write_failed"})),
);
}
(
StatusCode::OK,
Json(serde_json::json!({
"status": "received",
"content_type": "handoff_bundle",
})),
)
}
_ => (
StatusCode::OK,
Json(serde_json::json!({
"status": "received",
"content_type": format!("{:?}", envelope.content_type),
})),
),
}
}
/// Makes room for one new file in `dir`. If `dir` already holds `max_files` or more regular
/// files, the oldest (by modification time) are deleted until `max_files - 1` remain.
///
/// Subdirectories are ignored. Files whose mtime cannot be read count as oldest. Equal
/// mtimes keep directory-listing order. All I/O errors are ignored (best effort).
fn evict_oldest_files(dir: &std::path::Path, max_files: usize) {
    let listing = match std::fs::read_dir(dir) {
        Ok(listing) => listing,
        Err(_) => return,
    };
    let mut files: Vec<(std::time::SystemTime, std::path::PathBuf)> = Vec::new();
    for entry in listing.flatten() {
        let Ok(meta) = entry.metadata() else {
            continue;
        };
        if !meta.is_file() {
            continue;
        }
        let mtime = meta.modified().unwrap_or(std::time::UNIX_EPOCH);
        files.push((mtime, entry.path()));
    }
    if files.len() < max_files {
        return;
    }
    // Stable sort, so ties keep their listing order.
    files.sort_by(|a, b| a.0.cmp(&b.0));
    let keep = max_files.saturating_sub(1);
    let excess = files.len().saturating_sub(keep);
    for (_, path) in &files[..excess] {
        let _ = std::fs::remove_file(path);
    }
}
/// `POST /a2a`: JSON-RPC 2.0 endpoint for A2A-compatible clients.
///
/// The body has already been parsed as JSON by axum's `Json` extractor, which rejects
/// non-JSON bodies before this handler runs. A body that is valid JSON but not a valid
/// request object therefore gets the JSON-RPC "Invalid Request" error (`-32600`, with
/// `"id": null` as the spec requires). The previous `-32700` code means "Parse error"
/// (malformed JSON), which cannot occur here.
async fn a2a_jsonrpc(Json(body): Json<Value>) -> impl IntoResponse {
    let req: crate::core::a2a::a2a_compat::JsonRpcRequest = match serde_json::from_value(body) {
        Ok(r) => r,
        Err(e) => {
            tracing::debug!("a2a JSON-RPC parse error: {e}");
            return (
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({
                    "jsonrpc": "2.0",
                    "id": null,
                    "error": {"code": -32600, "message": "invalid request"}
                })),
            );
        }
    };
    let resp = crate::core::a2a::a2a_compat::handle_a2a_jsonrpc(&req);
    // If the response cannot be serialized, the body falls back to JSON `null`.
    let json = serde_json::to_value(resp).unwrap_or_default();
    (StatusCode::OK, Json(json))
}
/// `GET /v1/a2a/agent-card` and `GET /.well-known/agent.json`: the A2A agent card for this
/// project, served as `application/json`.
async fn v1_a2a_agent_card(State(state): State<AppState>) -> impl IntoResponse {
    let agent_card = crate::core::a2a::agent_card::build_agent_card(&state.project_root);
    let content_type = [(header::CONTENT_TYPE, "application/json")];
    (StatusCode::OK, content_type, Json(agent_card))
}
fn build_app_router(cfg: &HttpServerConfig) -> Router {
let project_root = cfg.project_root.to_string_lossy().to_string();
let service_project_root = project_root.clone();
let service_factory = move || -> Result<LeanCtxServer, std::io::Error> {
Ok(LeanCtxServer::new_shared_with_context(
&service_project_root,
"default",
"default",
))
};
let mcp_http = StreamableHttpService::new(
service_factory,
Arc::new(
rmcp::transport::streamable_http_server::session::local::LocalSessionManager::default(),
),
cfg.mcp_http_config(),
);
let rest_server = LeanCtxServer::new_shared_with_context(&project_root, "default", "default");
let state = AppState {
token: cfg.auth_token.clone().filter(|t| !t.is_empty()),
concurrency: Arc::new(tokio::sync::Semaphore::new(cfg.max_concurrency.max(1))),
rate: Arc::new(RateLimiter::new(cfg.max_rps, cfg.rate_burst)),
project_root,
timeout: Duration::from_millis(cfg.request_timeout_ms.max(1)),
server: rest_server,
};
Router::new()
.route("/health", get(health))
.route("/v1/shutdown", axum::routing::post(v1_shutdown))
.route("/v1/manifest", get(v1_manifest))
.route("/v1/tools", get(v1_tools))
.route("/v1/tools/call", axum::routing::post(v1_tool_call))
.route("/v1/events", get(v1_events))
.route(
"/v1/context/summary",
get(context_views::v1_context_summary),
)
.route("/v1/events/search", get(context_views::v1_events_search))
.route("/v1/events/lineage", get(context_views::v1_event_lineage))
.route("/v1/metrics", get(v1_metrics))
.route("/v1/a2a/handoff", axum::routing::post(v1_a2a_handoff))
.route("/v1/a2a/agent-card", get(v1_a2a_agent_card))
.route("/.well-known/agent.json", get(v1_a2a_agent_card))
.route("/a2a", axum::routing::post(a2a_jsonrpc))
.fallback_service(mcp_http)
.layer(axum::extract::DefaultBodyLimit::max(cfg.max_body_bytes))
.layer(middleware::from_fn_with_state(
state.clone(),
rate_limit_middleware,
))
.layer(middleware::from_fn_with_state(
state.clone(),
concurrency_middleware,
))
.layer(middleware::from_fn_with_state(
state.clone(),
auth_middleware,
))
.with_state(state)
}
/// Runs the HTTP server on `cfg.host:cfg.port` until Ctrl-C is received.
///
/// # Errors
/// Returns an error if:
/// - `cfg.validate()` rejects the configuration;
/// - the host cannot be resolved or bound;
/// - the server loop fails.
pub async fn serve(cfg: HttpServerConfig) -> Result<()> {
    crate::core::protocol::set_mcp_context(true);
    cfg.validate()?;
    // Bind via `(host, port)` rather than parsing `"host:port"` as a `SocketAddr`. That
    // parse rejects hostnames such as `localhost` and unbracketed IPv6 literals like `::1`,
    // both of which `validate` accepts as loopback. Optional `[...]` brackets are stripped.
    let host = cfg.host.trim();
    let host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);
    let app = build_app_router(&cfg);
    let listener = tokio::net::TcpListener::bind((host, cfg.port))
        .await
        .with_context(|| format!("bind {host}:{}", cfg.port))?;
    let addr: SocketAddr = listener
        .local_addr()
        .context("read bound listener address")?;
    tracing::info!(
        "lean-ctx Streamable HTTP server listening on http://{addr} (project_root={})",
        cfg.project_root.display()
    );
    axum::serve(listener, app)
        .with_graceful_shutdown(async move {
            let _ = tokio::signal::ctrl_c().await;
        })
        .await
        .context("http server")?;
    Ok(())
}
/// Runs the same application over a local IPC endpoint until Ctrl-C is received.
///
/// - On Unix the router is served on the socket created by `crate::ipc::bind_listener`.
/// - On Windows, named pipes are not supported yet and an error is returned.
///
/// # Errors
/// Returns an error if validation, binding, or serving fails, or on unsupported transports.
pub async fn serve_ipc(cfg: HttpServerConfig, addr: crate::ipc::DaemonAddr) -> Result<()> {
    cfg.validate()?;
    match &addr {
        #[cfg(unix)]
        crate::ipc::DaemonAddr::Unix(path) => {
            let app = build_app_router(&cfg);
            let listener = crate::ipc::bind_listener(&addr)?;
            tracing::info!(
                "lean-ctx daemon listening on {} (project_root={})",
                path.display(),
                cfg.project_root.display()
            );
            let on_ctrl_c = async {
                let _ = tokio::signal::ctrl_c().await;
            };
            axum::serve(listener, app.into_make_service())
                .with_graceful_shutdown(on_ctrl_c)
                .await
                .context("ipc server")?;
            Ok(())
        }
        #[cfg(windows)]
        crate::ipc::DaemonAddr::NamedPipe(_) => {
            anyhow::bail!("Named pipe server not yet supported — use TCP mode on Windows for now");
        }
    }
}
// Integration-style tests that drive the router/middleware in-process via `tower::ServiceExt::oneshot`.
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Request;
    use futures::StreamExt;
    use rmcp::transport::{StreamableHttpServerConfig, StreamableHttpService};
    use serde_json::json;
    use tower::ServiceExt;

    /// Collects body chunks until the first SSE frame terminator (`\n\n`) is seen. It also
    /// stops after 32 chunks, a 2 s wait for a chunk, end of stream, or a body error. Returns
    /// the bytes read so far as lossy UTF-8.
    async fn read_first_sse_message(body: Body) -> String {
        let mut stream = body.into_data_stream();
        let mut buf: Vec<u8> = Vec::new();
        for _ in 0..32 {
            let next = tokio::time::timeout(Duration::from_secs(2), stream.next()).await;
            // Timeout, end of stream and body errors all end the read.
            let Ok(Some(Ok(bytes))) = next else {
                break;
            };
            buf.extend_from_slice(&bytes);
            if buf.windows(2).any(|w| w == b"\n\n") {
                break;
            }
        }
        String::from_utf8_lossy(&buf).to_string()
    }

    /// With a token configured, an MCP request without an `Authorization` header is rejected
    /// by `auth_middleware` before reaching the rmcp service.
    #[tokio::test]
    async fn auth_token_blocks_requests_without_bearer_header() {
        let dir = tempfile::tempdir().expect("tempdir");
        let root_str = dir.path().to_string_lossy().to_string();
        let service_project_root = root_str.clone();
        let service_factory = move || -> Result<LeanCtxServer, std::io::Error> {
            Ok(LeanCtxServer::new_shared_with_context(
                &service_project_root,
                "default",
                "default",
            ))
        };
        let cfg = StreamableHttpServerConfig::default()
            .with_stateful_mode(false)
            .with_json_response(true);
        let mcp_http = StreamableHttpService::new(
            service_factory,
            Arc::new(
                rmcp::transport::streamable_http_server::session::local::LocalSessionManager::default(),
            ),
            cfg,
        );
        let state = AppState {
            token: Some("secret".to_string()),
            concurrency: Arc::new(tokio::sync::Semaphore::new(4)),
            rate: Arc::new(RateLimiter::new(50, 100)),
            project_root: root_str.clone(),
            timeout: Duration::from_secs(30),
            server: LeanCtxServer::new_shared_with_context(&root_str, "default", "default"),
        };
        // Only the auth layer is installed, so a 401 can come from nowhere else.
        let app = Router::new()
            .fallback_service(mcp_http)
            .layer(middleware::from_fn_with_state(
                state.clone(),
                auth_middleware,
            ))
            .with_state(state);
        let body = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "tools/list",
            "params": {}
        })
        .to_string();
        let req = Request::builder()
            .method("POST")
            .uri("/")
            .header("Host", "localhost")
            .header("Accept", "application/json, text/event-stream")
            .header("Content-Type", "application/json")
            .body(Body::from(body))
            .expect("request");
        let resp = app.clone().oneshot(req).await.expect("resp");
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    /// Two servers produced by the same factory do not share `client_name`, i.e. the factory
    /// yields independent per-client state.
    #[tokio::test]
    async fn mcp_service_factory_isolates_per_client_state() {
        let dir = tempfile::tempdir().expect("tempdir");
        let root_str = dir.path().to_string_lossy().to_string();
        let service_project_root = root_str.clone();
        let service_factory = move || -> Result<LeanCtxServer, std::convert::Infallible> {
            Ok(LeanCtxServer::new_shared_with_context(
                &service_project_root,
                "default",
                "default",
            ))
        };
        let s1 = service_factory().expect("server 1");
        let s2 = service_factory().expect("server 2");
        *s1.client_name.write().await = "client-a".to_string();
        *s2.client_name.write().await = "client-b".to_string();
        let a = s1.client_name.read().await.clone();
        let b = s2.client_name.read().await.clone();
        assert_eq!(a, "client-a");
        assert_eq!(b, "client-b");
    }

    /// A bucket of one token refilling at 1/s admits the first request and rejects an
    /// immediate second one with 429.
    #[tokio::test]
    async fn rate_limit_returns_429_when_exhausted() {
        let state = AppState {
            token: None,
            concurrency: Arc::new(tokio::sync::Semaphore::new(16)),
            rate: Arc::new(RateLimiter::new(1, 1)),
            project_root: ".".to_string(),
            timeout: Duration::from_secs(30),
            server: LeanCtxServer::new_shared_with_context(".", "default", "default"),
        };
        let app = Router::new()
            .route("/limited", get(|| async { (StatusCode::OK, "ok\n") }))
            .layer(middleware::from_fn_with_state(
                state.clone(),
                rate_limit_middleware,
            ))
            .with_state(state);
        let req1 = Request::builder()
            .method("GET")
            .uri("/limited")
            .header("Host", "localhost")
            .body(Body::empty())
            .expect("req1");
        let resp1 = app.clone().oneshot(req1).await.expect("resp1");
        assert_eq!(resp1.status(), StatusCode::OK);
        let req2 = Request::builder()
            .method("GET")
            .uri("/limited")
            .header("Host", "localhost")
            .body(Body::empty())
            .expect("req2");
        let resp2 = app.clone().oneshot(req2).await.expect("resp2");
        assert_eq!(resp2.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    /// An event appended to the global bus before connecting is replayed as the first SSE
    /// frame, tagged with its kind and containing the workspace/channel ids.
    #[tokio::test]
    async fn events_endpoint_replays_tool_call_event() {
        use crate::core::context_os::{self, ContextEventKindV1};
        let dir = tempfile::tempdir().expect("tempdir");
        std::fs::create_dir_all(dir.path().join(".git")).expect("git marker");
        std::fs::write(dir.path().join("a.txt"), "ok").expect("file");
        let root_str = dir.path().to_string_lossy().to_string();
        let state = AppState {
            token: None,
            concurrency: Arc::new(tokio::sync::Semaphore::new(16)),
            rate: Arc::new(RateLimiter::new(50, 100)),
            project_root: root_str.clone(),
            timeout: Duration::from_secs(30),
            server: LeanCtxServer::new_shared_with_context(&root_str, "default", "default"),
        };
        let app = Router::new()
            .route("/v1/events", get(v1_events))
            .with_state(state);
        // NOTE(review): `runtime()` is process-global, so other tests touching ws1/ch1 could interfere.
        let rt = context_os::runtime();
        rt.bus.append(
            "ws1",
            "ch1",
            &ContextEventKindV1::ToolCallRecorded,
            Some("test-agent"),
            json!({"tool": "ctx_session", "action": "status"}),
        );
        let req = Request::builder()
            .method("GET")
            .uri("/v1/events?workspaceId=ws1&channelId=ch1&since=0&limit=1")
            .header("Host", "localhost")
            .header("Accept", "text/event-stream")
            .body(Body::empty())
            .expect("req");
        let resp = app.clone().oneshot(req).await.expect("events");
        assert_eq!(resp.status(), StatusCode::OK);
        let msg = read_first_sse_message(resp.into_body()).await;
        assert!(msg.contains("event: tool_call_recorded"), "msg={msg:?}");
        assert!(msg.contains("\"ws1\""), "msg={msg:?}");
        assert!(msg.contains("\"ch1\""), "msg={msg:?}");
    }
}