use std::sync::Arc;
use std::time::Instant;
use anyhow::Result;
use axum::body::Bytes;
use axum::extract::{Query, State};
use axum::http::HeaderMap;
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::IntoResponse;
use axum::routing::{get, post};
use axum::Router;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio_stream::wrappers::BroadcastStream;
use tokio_stream::StreamExt;
use crate::mcp::protocol::{
JsonRpcRequest, JsonRpcResponse, INTERNAL_ERROR, INVALID_REQUEST, PARSE_ERROR,
};
use crate::mcp::tools;
use crate::runtime_events;
/// Shared, read-only state handed to every HTTP handler of the daemon.
///
/// A single instance is created in `serve` and shared behind an `Arc`.
pub struct DaemonState {
    /// Tool definitions, built once at startup and passed to request dispatch.
    tools: Arc<[crate::mcp::protocol::Tool]>,
    /// Profile applied to requests that carry no `x-ccd-profile` header.
    default_profile: Option<String>,
    /// Moment the daemon state was created; `/health` derives uptime from it.
    started_at: Instant,
}
/// Starts the CCD daemon HTTP server and runs it until the server stops.
///
/// Routes:
/// - `POST /mcp`    — JSON-RPC request dispatch (`handle_mcp`)
/// - `GET  /events` — server-sent runtime events (`handle_events`)
/// - `GET  /health` — liveness / uptime report (`handle_health`)
///
/// # Parameters
/// - `host`: host name or IP literal to bind. IPv6 literals are accepted both
///   bare (`::1`) and bracketed (`[::1]`).
/// - `port`: TCP port to bind.
/// - `default_profile`: profile used for requests without an `x-ccd-profile`
///   header.
/// - `cors`: when `true`, a permissive CORS layer is added to every route.
///
/// # Errors
/// Returns an error if the listener cannot be bound or if the server
/// terminates with an I/O error.
pub async fn serve(
    host: &str,
    port: u16,
    default_profile: Option<String>,
    cors: bool,
) -> Result<()> {
    let state = Arc::new(DaemonState {
        tools: tools::build_tools().into(),
        default_profile,
        started_at: Instant::now(),
    });

    let mut app = Router::new()
        .route("/mcp", post(handle_mcp))
        .route("/events", get(handle_events))
        .route("/health", get(handle_health))
        .with_state(state);

    if cors {
        use tower_http::cors::CorsLayer;
        app = app.layer(CorsLayer::permissive());
    }

    // Bind through a `(host, port)` pair instead of a formatted "host:port"
    // string: a bare IPv6 literal such as `::1` would otherwise become the
    // unparsable `::1:8080`. Brackets are stripped so `[::1]` keeps working.
    let bind_host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);
    let listener = tokio::net::TcpListener::bind((bind_host, port)).await?;
    tracing::info!("CCD daemon listening on {host}:{port}");

    axum::serve(listener, app).await?;
    Ok(())
}
/// Handles one JSON-RPC request POSTed to `/mcp`.
///
/// Processing steps:
/// 1. Parse the body as JSON; failure yields a `PARSE_ERROR` response with a
///    null id.
/// 2. Decode the JSON into a [`JsonRpcRequest`]; failure yields
///    `INVALID_REQUEST`, echoing the raw `id` member when one was present.
/// 3. A request without an `id` member is not dispatched; it is answered with
///    an empty success result and a null id.
/// 4. Any `jsonrpc` version other than `"2.0"` yields `INVALID_REQUEST`.
/// 5. The effective profile is the `x-ccd-profile` header if present and valid
///    UTF-8, otherwise the daemon's default profile. When `params` is an
///    object, the profile is inserted as `params.profile` and, for
///    `tools/call`, as `params.arguments.profile`, unless the request already
///    supplies those keys.
/// 6. The request is dispatched on the blocking thread pool. If that task
///    fails, an `INTERNAL_ERROR` response carrying the request id is returned.
async fn handle_mcp(
    State(state): State<Arc<DaemonState>>,
    headers: HeaderMap,
    body: Bytes,
) -> impl IntoResponse {
    let value: Value = match serde_json::from_slice(&body) {
        Ok(v) => v,
        Err(e) => {
            return axum::Json(JsonRpcResponse::error(
                Value::Null,
                PARSE_ERROR,
                format!("parse error: {e}"),
            ))
        }
    };

    // Capture the id before `value` is consumed so that a structurally invalid
    // request can still be answered with the id the client sent.
    let raw_id = value.get("id").cloned();
    let request: JsonRpcRequest = match serde_json::from_value(value) {
        Ok(r) => r,
        Err(e) => {
            return axum::Json(JsonRpcResponse::error(
                raw_id.unwrap_or(Value::Null),
                INVALID_REQUEST,
                format!("invalid request: {e}"),
            ))
        }
    };

    // No `id` member: acknowledge with an empty result and do not dispatch.
    let Some(id) = raw_id else {
        return axum::Json(JsonRpcResponse::success(
            Value::Null,
            Value::Object(Default::default()),
        ));
    };

    if request.jsonrpc != "2.0" {
        return axum::Json(JsonRpcResponse::error(
            id,
            INVALID_REQUEST,
            format!("unsupported jsonrpc version: {:?}", request.jsonrpc),
        ));
    }

    // The header takes precedence over the daemon-wide default profile.
    let profile = headers
        .get("x-ccd-profile")
        .and_then(|v| v.to_str().ok())
        .map(String::from)
        .or_else(|| state.default_profile.clone());

    // `request` is owned and not used afterwards, so move its fields out
    // rather than cloning them.
    let method = request.method;
    let mut params = request.params;

    if let (Some(profile_val), Some(obj)) = (profile.as_deref(), params.as_object_mut()) {
        if method == "tools/call" {
            if let Some(args) = obj.get_mut("arguments").and_then(|a| a.as_object_mut()) {
                args.entry("profile")
                    .or_insert_with(|| Value::String(profile_val.to_owned()));
            }
        }
        // Values supplied explicitly by the client are never overwritten.
        obj.entry("profile")
            .or_insert_with(|| Value::String(profile_val.to_owned()));
    }

    let tools = Arc::clone(&state.tools);
    // Kept outside the closure so a failed dispatch can still be correlated
    // with the request that triggered it.
    let failure_id = id.clone();

    let response = tokio::task::spawn_blocking(move || {
        if let Some(ref p) = profile {
            // NOTE(review): this mutates process-wide state. Concurrent
            // requests with different profiles can overwrite each other's
            // value, and `set_var` is an `unsafe fn` from the 2024 edition
            // onward. Presumably downstream code reads `CCD_PROFILE`; confirm
            // and prefer passing the profile explicitly.
            std::env::set_var("CCD_PROFILE", p);
        }
        super::handle_request(&method, &params, id, &tools)
    })
    .await
    .unwrap_or_else(|e| {
        JsonRpcResponse::error(
            failure_id,
            INTERNAL_ERROR,
            format!("dispatch panicked: {e}"),
        )
    });

    axum::Json(response)
}
/// Optional query-string filters accepted by `GET /events`.
///
/// A field left unset places no restriction on the corresponding event
/// attribute.
#[derive(Deserialize, Default)]
struct EventFilter {
    /// When set, only events whose `profile` equals this value are forwarded.
    profile: Option<String>,
    /// When set, only events whose `family` equals this value are forwarded.
    family: Option<String>,
}
async fn handle_events(
State(state): State<Arc<DaemonState>>,
Query(params): Query<EventFilter>,
) -> Sse<impl tokio_stream::Stream<Item = Result<Event, std::convert::Infallible>>> {
let _ = state;
let stream = BroadcastStream::new(runtime_events::subscribe()).filter_map(move |event| {
let Ok(event) = event else {
return None;
};
if let Some(ref profile) = params.profile {
if &event.profile != profile {
return None;
}
}
if let Some(ref family) = params.family {
if &event.family != family {
return None;
}
}
let sse = Event::default()
.id(event.event_id.clone())
.event(&event.family)
.json_data(event.as_ref())
.ok()?;
Some(Ok(sse))
});
Sse::new(stream).keep_alive(KeepAlive::default())
}
/// JSON body returned by `GET /health`.
#[derive(Serialize)]
struct HealthResponse {
    /// Health indicator; `handle_health` always reports `"healthy"`.
    status: &'static str,
    /// Whole seconds elapsed since the daemon state was created.
    uptime_s: u64,
    /// Profile applied to requests that do not name one, if configured.
    default_profile: Option<String>,
    /// Crate version captured at compile time.
    version: &'static str,
}
/// Reports daemon liveness: a fixed status, uptime in seconds, the configured
/// default profile, and the crate version.
async fn handle_health(State(state): State<Arc<DaemonState>>) -> impl IntoResponse {
    let uptime_s = state.started_at.elapsed().as_secs();
    let default_profile = state.default_profile.clone();

    let report = HealthResponse {
        status: "healthy",
        uptime_s,
        default_profile,
        version: env!("CARGO_PKG_VERSION"),
    };
    axum::Json(report)
}