use std::collections::{HashMap, HashSet};
use std::io::{self, BufRead, Write};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use serde_json::json;
#[cfg(test)]
use serde_json::Value;
use tracing::{debug, error, info, warn};
use crate::mcp::protocol::{
JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, RequestId, RpcError, TaskInfo,
ToolCallResult,
};
use crate::mcp::security::SecurityGuard;
use crate::mcp::tools::AppRegistry;
/// Monotonic source for task identifiers; the first id handed out uses `1`.
static TASK_COUNTER: AtomicU64 = AtomicU64::new(1);

/// Returns a fresh task identifier of the form `task-` followed by the counter
/// value zero-padded to at least 16 decimal digits (e.g. `task-0000000000000001`).
///
/// Every call draws a distinct value from a process-wide atomic counter, so ids
/// are unique within the process.
pub(crate) fn next_task_id() -> String {
    // Relaxed is enough: callers only need each fetched value to be unique, not
    // ordered with respect to any other memory operation.
    let sequence = TASK_COUNTER.fetch_add(1, Ordering::Relaxed);
    let id = format!("task-{sequence:016}");
    id
}
/// Lifecycle stage of the MCP session, as tracked by [`Server`].
///
/// `Server::handle` only serves feature requests (tools, resources, prompts,
/// tasks) while in `Running`; `initialize` and `ping` are accepted in any stage.
#[derive(Eq, PartialEq, Debug)]
pub(super) enum Phase {
    /// Initial state of a freshly constructed server.
    Uninitialized,
    /// Intermediate stage; the `notifications/initialized` notification moves
    /// the server from here to `Running`.
    Initializing,
    /// Normal operation: all supported request methods are dispatched.
    Running,
}
/// Progress record for one durable workflow held in `Server::workflows`.
pub(crate) struct WorkflowState {
    /// The workflow's steps.
    pub steps: Vec<crate::durable_steps::DurableStep>,
    /// Step position counter; presumably the index into `steps` of the next
    /// step to execute — confirm against the workflow tool implementation.
    pub current_step: usize,
    /// Results collected so far.
    pub results: Vec<crate::durable_steps::WorkflowResult>,
    /// Completion flag for the workflow.
    pub completed: bool,
}
/// One entry of `Server::tasks`: task metadata plus its (optional) tool result.
pub(crate) struct TaskEntry {
    /// Metadata describing the task.
    pub info: TaskInfo,
    /// Stored tool-call outcome; presumably `None` until the task produces a
    /// result — confirm with the code that fills this in.
    pub result: Option<ToolCallResult>,
}
/// Mutable state of one MCP server session.
///
/// The collections are wrapped in `Arc<Mutex<..>>` so that they can be shared
/// beyond the `&mut self` dispatch path (sharing sites are outside this view).
pub(super) struct Server {
    /// Registry handed to the tool handlers (an `AppRegistry`, shared via `Arc`).
    pub(super) registry: Arc<AppRegistry>,
    /// Current lifecycle stage; gates which requests `handle` will serve.
    pub(super) phase: Phase,
    /// Durable workflows keyed by a string id (presumably the workflow id — confirm).
    pub(super) workflows: Arc<Mutex<HashMap<String, WorkflowState>>>,
    /// Active subscriptions; presumably resource URIs from `resources/subscribe`.
    pub(crate) subscriptions: Arc<Mutex<HashSet<String>>>,
    /// Tasks keyed by string id (presumably ids minted by `next_task_id`).
    pub(crate) tasks: Arc<Mutex<HashMap<String, TaskEntry>>>,
    /// Security policy object consulted by the request handlers.
    pub(super) security: SecurityGuard,
    /// Whether the client advertised sampling support; starts as `false`.
    pub(crate) client_supports_sampling: bool,
    /// Shared state for the watch feature (only compiled with `watch`).
    #[cfg(feature = "watch")]
    pub(super) watch_state: Arc<crate::mcp::tools_watch::WatchState>,
}
impl Server {
pub(super) fn new() -> Self {
Self {
registry: Arc::new(AppRegistry::default()),
phase: Phase::Uninitialized,
workflows: Arc::new(Mutex::new(HashMap::new())),
subscriptions: Arc::new(Mutex::new(HashSet::new())),
tasks: Arc::new(Mutex::new(HashMap::new())),
security: SecurityGuard::new(),
client_supports_sampling: false,
#[cfg(feature = "watch")]
watch_state: Arc::new(crate::mcp::tools_watch::WatchState::new()),
}
}
pub(super) fn handle<W: Write>(
&mut self,
msg: &JsonRpcRequest,
out: &mut W,
) -> Option<JsonRpcResponse> {
debug!(method = %msg.method, "incoming message");
if msg.id.is_none() {
self.handle_notification(msg);
return None;
}
let id = match msg.id.clone() {
Some(id) => id,
None => {
return Some(JsonRpcResponse::err(
RequestId::Number(0),
RpcError::new(RpcError::INVALID_REQUEST, "Missing request id".to_string()),
));
}
};
match msg.method.as_str() {
"initialize" => Some(self.handle_initialize(id, msg.params.as_ref())),
"ping" => Some(Self::handle_ping(id)),
"tools/list" if self.phase == Phase::Running => Some(self.handle_tools_list(id)),
"tools/call" if self.phase == Phase::Running => {
Some(self.handle_tools_call(id, msg.params.as_ref(), out))
}
"resources/list" if self.phase == Phase::Running => {
Some(Self::handle_resources_list(id))
}
"resources/templates/list" if self.phase == Phase::Running => {
Some(Self::handle_resources_templates_list(id))
}
"resources/read" if self.phase == Phase::Running => {
Some(self.handle_resources_read(id, msg.params.as_ref()))
}
"resources/subscribe" if self.phase == Phase::Running => {
Some(self.handle_resources_subscribe(id, msg.params.as_ref()))
}
"resources/unsubscribe" if self.phase == Phase::Running => {
Some(self.handle_resources_unsubscribe(id, msg.params.as_ref()))
}
"prompts/list" if self.phase == Phase::Running => Some(Self::handle_prompts_list(id)),
"prompts/get" if self.phase == Phase::Running => {
Some(Self::handle_prompts_get(id, msg.params.as_ref()))
}
"tasks/list" if self.phase == Phase::Running => Some(self.handle_tasks_list(id)),
"tasks/result" if self.phase == Phase::Running => {
Some(self.handle_tasks_result(id, msg.params.as_ref()))
}
"tasks/cancel" if self.phase == Phase::Running => {
Some(self.handle_tasks_cancel(id, msg.params.as_ref()))
}
method if self.phase != Phase::Running => {
warn!(method, "request before initialized");
Some(JsonRpcResponse::err(
id,
RpcError::new(RpcError::INVALID_REQUEST, "Server not yet initialized"),
))
}
method => {
warn!(method, "method not found");
Some(JsonRpcResponse::err(
id,
RpcError::new(
RpcError::METHOD_NOT_FOUND,
format!("Method not found: {method}"),
),
))
}
}
}
pub(super) fn handle_notification(&mut self, msg: &JsonRpcRequest) {
match msg.method.as_str() {
"notifications/initialized" => {
if self.phase == Phase::Initializing {
self.phase = Phase::Running;
info!("MCP server ready");
}
}
method => debug!(method, "unhandled notification"),
}
}
}
/// Public wrapper around the crate-internal `Server`.
///
/// It exposes only construction and message dispatch, so callers outside this
/// module can drive a session without seeing the server's internals.
pub struct ServerHandle(Server);

impl ServerHandle {
    /// Creates a handle around a freshly constructed, uninitialized server.
    #[must_use]
    pub fn new() -> Self {
        let inner = Server::new();
        ServerHandle(inner)
    }

    /// Dispatches one JSON-RPC message to the wrapped server.
    ///
    /// Returns `None` for notifications (messages without an id) and
    /// `Some(response)` for requests. `out` is passed through to the server,
    /// which forwards it to the `tools/call` handler.
    pub fn handle<W: Write>(
        &mut self,
        msg: &JsonRpcRequest,
        out: &mut W,
    ) -> Option<JsonRpcResponse> {
        let ServerHandle(inner) = self;
        inner.handle(msg, out)
    }
}

impl Default for ServerHandle {
    /// Equivalent to [`ServerHandle::new`].
    fn default() -> Self {
        ServerHandle::new()
    }
}
pub fn run_stdio() -> anyhow::Result<()> {
info!("axterminator MCP server starting (stdio)");
let stdin = io::stdin();
let stdout = io::stdout();
let mut stdout_lock = stdout.lock();
let mut server = Server::new();
#[cfg(feature = "watch")]
let mut watch_event_rx: Option<tokio::sync::mpsc::Receiver<crate::watch::WatchEvent>> = None;
for line in stdin.lock().lines() {
let line = line?;
if line.trim().is_empty() {
continue;
}
debug!(bytes = line.len(), "received line");
let msg: JsonRpcRequest = match serde_json::from_str(&line) {
Ok(m) => m,
Err(e) => {
error!(error = %e, "parse error");
let resp = JsonRpcResponse::err(
RequestId::Number(0),
RpcError::new(RpcError::PARSE_ERROR, format!("Parse error: {e}")),
);
write_response(&mut stdout_lock, &resp)?;
continue;
}
};
#[cfg(feature = "watch")]
drain_watch_events(&mut watch_event_rx, &mut stdout_lock);
if let Some(resp) = server.handle(&msg, &mut stdout_lock) {
#[cfg(feature = "watch")]
maybe_capture_watch_receiver(&server, &mut watch_event_rx, &msg.method);
write_response(&mut stdout_lock, &resp)?;
}
#[cfg(feature = "watch")]
drain_watch_events(&mut watch_event_rx, &mut stdout_lock);
}
info!("stdin closed, shutting down");
Ok(())
}
#[cfg(feature = "watch")]
/// Writes every watch event currently queued on `rx` to `out` as a channel
/// notification, without blocking.
///
/// Events for which `event_to_channel_notification` yields `None` are skipped.
/// A failed write is logged, and draining continues. If the sending side has
/// disconnected, the receiver is dropped (`*rx = None`) so later calls stop
/// polling a dead channel. Does nothing when `rx` is `None`.
fn drain_watch_events(
    rx: &mut Option<tokio::sync::mpsc::Receiver<crate::watch::WatchEvent>>,
    out: &mut impl io::Write,
) {
    use crate::mcp::watch_channel::{emit_channel_notification, event_to_channel_notification};
    use tokio::sync::mpsc::error::TryRecvError;

    let Some(receiver) = rx.as_mut() else { return };
    let disconnected = loop {
        match receiver.try_recv() {
            Ok(event) => {
                let Some(params) = event_to_channel_notification(&event) else {
                    continue;
                };
                // Best-effort delivery, but no longer silent about failures.
                if emit_channel_notification(out, params).is_err() {
                    warn!("failed to write watch channel notification");
                }
            }
            Err(TryRecvError::Empty) => break false,
            Err(TryRecvError::Disconnected) => break true,
        }
    };
    if disconnected {
        debug!("watch event channel closed; dropping receiver");
        *rx = None;
    }
}
#[cfg(feature = "watch")]
/// After a `tools/call`, takes any receiver the watch state has pending and
/// installs it in `rx`, replacing the previous one.
///
/// For any other method, and when nothing is pending, `rx` is left untouched.
fn maybe_capture_watch_receiver(
    server: &Server,
    rx: &mut Option<tokio::sync::mpsc::Receiver<crate::watch::WatchEvent>>,
    method: &str,
) {
    let pending = if method == "tools/call" {
        server.watch_state.take_pending_receiver()
    } else {
        None
    };
    if let Some(receiver) = pending {
        *rx = Some(receiver);
    }
}
/// Serializes `resp` as one line of JSON, writes it to `out`, and flushes.
///
/// # Errors
///
/// Propagates any I/O error from writing or flushing.
///
/// # Panics
///
/// Panics if `resp` cannot be serialized, which is treated as a bug.
fn write_response(out: &mut impl Write, resp: &JsonRpcResponse) -> io::Result<()> {
    let payload = serde_json::to_string(resp).expect("response serialization cannot fail");
    debug!(bytes = payload.len(), id = ?resp.id, "sending response");
    out.write_all(payload.as_bytes())?;
    out.write_all(b"\n")?;
    out.flush()
}
pub fn emit_log(out: &mut impl Write, level: &str, message: &str) -> io::Result<()> {
let notif = JsonRpcNotification {
jsonrpc: "2.0",
method: "notifications/message",
params: json!({ "level": level, "data": message }),
};
let json = serde_json::to_string(¬if).expect("notification serialization cannot fail");
writeln!(out, "{json}")?;
out.flush()
}
pub fn notify_resource_changed(out: &mut impl Write, uri: &str) {
let notif = JsonRpcNotification {
jsonrpc: "2.0",
method: "notifications/resources/updated",
params: json!({ "uri": uri }),
};
let json = serde_json::to_string(¬if).expect("notification serialization cannot fail");
let _ = writeln!(out, "{json}");
let _ = out.flush();
}
#[cfg(test)]
#[path = "server_tests.rs"]
mod tests;