use crate::activity::Activity;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::sync::{Mutex, mpsc};
/// Protocol version string reported as `protocolVersion` in the `initialize` reply.
pub const PROTOCOL_VERSION: &str = "2024-11-05";
/// One tool advertised to the client; serialized as an entry of the `tools/list` reply.
#[derive(Debug, Clone, Serialize)]
pub struct ToolDef {
/// Tool name the client passes back as `params.name` in `tools/call`.
pub name: String,
/// Human-readable description of the tool.
pub description: String,
/// Sent verbatim as `inputSchema`; presumably a JSON Schema object for the tool's
/// `arguments` (the tests use `{"type": "object"}`) — not validated here.
#[serde(rename = "inputSchema")]
pub input_schema: Value,
}
/// One resource advertised to the client; serialized as an entry of the `resources/list` reply.
#[derive(Debug, Clone, Serialize)]
pub struct ResourceDef {
/// Resource URI the client passes back as `params.uri` in `resources/read`.
pub uri: String,
/// Display name of the resource.
pub name: String,
/// Human-readable description of the resource.
pub description: String,
/// MIME type of the resource, serialized as `mimeType`.
#[serde(rename = "mimeType")]
pub mime_type: String,
}
/// Static server description consumed by `serve` / `dispatch_request`.
///
/// `name` and `version` become `serverInfo` and `instructions` is sent as-is in the
/// `initialize` reply; `tools` and `resources` are returned verbatim by `tools/list` and
/// `resources/list`.
#[derive(Debug, Clone)]
pub struct ServerConfig {
pub name: String,
pub version: String,
pub instructions: String,
pub tools: Vec<ToolDef>,
pub resources: Vec<ResourceDef>,
}
/// Successful result of `Handler::read_resource`.
///
/// Sent to the client as the single entry of `contents` in the `resources/read` reply, with
/// `mime_type` renamed to `mimeType`.
#[derive(Debug)]
pub struct ResourceContent {
pub uri: String,
pub mime_type: String,
pub text: String,
}
/// Error type returned by resource handlers; the same type as [`ToolError`].
pub type ResourceError = ToolError;

/// Successful result of a tool call, before it is wrapped into the `tools/call` reply.
#[derive(Debug)]
pub enum ToolOutcome {
    /// Text returned to the client unchanged.
    Text(String),
    /// Structured value, returned to the client as pretty-printed JSON text.
    Json(Value),
}

/// Failure from a tool or resource handler, sent back as a JSON-RPC `error` object.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`], so it can be logged with `{}`
/// and converted into other error types (e.g. `anyhow::Error`) with `?`.
#[derive(Debug)]
pub struct ToolError {
    /// JSON-RPC error code, e.g. [`INVALID_PARAMS`] or [`INTERNAL_ERROR`].
    pub code: i32,
    /// Human-readable description, sent as the error `message`.
    pub message: String,
    /// Optional structured details, sent as the error `data` member when present.
    pub data: Option<Value>,
}

impl ToolError {
    /// Builds an error with an arbitrary JSON-RPC `code`, message `msg`, and no `data`.
    pub fn new(code: i32, msg: impl Into<String>) -> Self {
        Self {
            code,
            message: msg.into(),
            data: None,
        }
    }

    /// Builds an [`INVALID_PARAMS`] (-32602) error with message `msg`.
    pub fn invalid_params(msg: impl Into<String>) -> Self {
        Self::new(INVALID_PARAMS, msg)
    }

    /// Builds an [`INTERNAL_ERROR`] (-32603) error with message `msg`.
    pub fn internal(msg: impl Into<String>) -> Self {
        Self::new(INTERNAL_ERROR, msg)
    }

    /// Attaches structured `data` (replacing any previous value) and returns the error.
    #[must_use]
    pub fn with_data(mut self, data: Value) -> Self {
        self.data = Some(data);
        self
    }
}

impl std::fmt::Display for ToolError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{} (code {})", self.message, self.code)
    }
}

impl std::error::Error for ToolError {}

/// JSON-RPC 2.0: invalid JSON was received.
pub const PARSE_ERROR: i32 = -32700;
/// JSON-RPC 2.0: the JSON sent is not a valid request object.
pub const INVALID_REQUEST: i32 = -32600;
/// JSON-RPC 2.0: the method does not exist or is not available.
pub const METHOD_NOT_FOUND: i32 = -32601;
/// JSON-RPC 2.0: invalid method parameters.
pub const INVALID_PARAMS: i32 = -32602;
/// JSON-RPC 2.0: internal error.
pub const INTERNAL_ERROR: i32 = -32603;
/// Cloneable handle for pushing server-initiated notifications to the client.
///
/// Every clone feeds the same outbound queue that replies use, so notifications are written by
/// the same writer task, one JSON line each. Sending is fire-and-forget: once the writer task
/// has exited, messages are dropped without error.
#[derive(Clone)]
pub struct Notifier {
    out_tx: mpsc::UnboundedSender<OutboundMessage>,
}

impl Notifier {
    /// Sends a `notifications/claude/channel` notification carrying `content` and `meta`.
    ///
    /// `meta` is expected to be a JSON object. Every value in it is coerced to a string: string
    /// values are kept, anything else becomes its compact JSON text (`3` → `"3"`, `null` →
    /// `"null"`). A non-object `meta` is replaced by `{}`. The `kind` entry (or `?` when absent)
    /// and the content length in bytes are logged at info level.
    pub fn channel(&self, content: impl Into<String>, meta: Value) {
        let content: String = content.into();

        let mut string_meta = serde_json::Map::new();
        if let Value::Object(entries) = meta {
            for (key, value) in entries {
                let text = match value {
                    Value::String(s) => s,
                    non_string => non_string.to_string(),
                };
                string_meta.insert(key, Value::String(text));
            }
        }

        // After coercion every meta value is a string, so `kind` is either that string or `?`.
        let kind = match string_meta.get("kind") {
            Some(Value::String(k)) => k.clone(),
            _ => "?".to_string(),
        };
        let serialized_len = content.len();

        let meta = Value::Object(string_meta);
        let params = serde_json::json!({
            "content": content,
            "meta": meta,
        });
        log::info!(
            "[trace-pushpipe] site=Notifier::channel kind={kind} content_len={serialized_len}",
        );
        self.send_raw("notifications/claude/channel", params);
    }

    /// Sends a `notifications/message` log notification attributed to logger `"marshal"`.
    ///
    /// `level` is forwarded verbatim and not validated here.
    pub fn log(&self, level: &str, data: Value) {
        self.send_raw(
            "notifications/message",
            serde_json::json!({ "level": level, "logger": "marshal", "data": data }),
        );
    }

    /// Queues a notification with an arbitrary `method` and `params` for the writer task.
    pub fn send_raw(&self, method: &str, params: Value) {
        let message = OutboundMessage::Notification {
            method: method.to_owned(),
            params,
        };
        // Best-effort: a send only fails after the receiving writer task has gone away.
        let _ = self.out_tx.send(message);
    }
}
/// One inbound JSON-RPC message, decoded from a single input line.
#[derive(Debug, Deserialize)]
struct JsonRpcRequest {
/// Deserialized but never read (hence `allow(dead_code)`); the version is not checked.
/// Being typed `Option<String>`, a non-string `jsonrpc` makes the whole line fail to decode.
#[serde(default)]
#[allow(dead_code)]
jsonrpc: Option<String>,
/// Absent or `null` decodes to `None`; `serve` treats such a message as a notification
/// and ignores it.
#[serde(default)]
id: Option<Value>,
method: String,
/// Absent `params` decodes to `Value::Null`.
#[serde(default)]
params: Value,
}
/// Successful reply for request `id`; the writer task always sets `jsonrpc` to `"2.0"`.
#[derive(Debug, Serialize)]
struct JsonRpcSuccess<'a> {
jsonrpc: &'static str,
id: &'a Value,
result: Value,
}
/// Error reply for request `id`; the writer task always sets `jsonrpc` to `"2.0"`.
#[derive(Debug, Serialize)]
struct JsonRpcErrorReply<'a> {
jsonrpc: &'static str,
id: &'a Value,
error: JsonRpcErrorBody,
}
/// The JSON-RPC `error` object; `data` is omitted from the output when `None`.
#[derive(Debug, Serialize)]
struct JsonRpcErrorBody {
code: i32,
message: String,
#[serde(skip_serializing_if = "Option::is_none")]
data: Option<Value>,
}
/// Server-to-client notification (no `id`, so no reply is expected).
#[derive(Debug, Serialize)]
struct JsonRpcNotification {
jsonrpc: &'static str,
method: String,
params: Value,
}
/// Everything queued for the writer task: request replies, error replies, and notifications.
/// Each message is serialized by the writer into one newline-terminated JSON line.
#[derive(Debug)]
enum OutboundMessage {
Reply { id: Value, result: Value },
Error { id: Value, body: JsonRpcErrorBody },
Notification { method: String, params: Value },
}
/// Boxed, `Send` future returned by [`Handler::call_tool`]; may borrow its inputs for `'a`.
pub type ToolFuture<'a> =
    Pin<Box<dyn Future<Output = Result<ToolOutcome, ToolError>> + Send + 'a>>;

/// Boxed, `Send` future returned by [`Handler::read_resource`]; may borrow its inputs for `'a`.
pub type ResourceFuture<'a> =
    Pin<Box<dyn Future<Output = Result<ResourceContent, ResourceError>> + Send + 'a>>;

/// Application callbacks the server invokes for `tools/call` and `resources/read`.
pub trait Handler: Send + Sync + 'static {
    /// Runs tool `name` with the request's `arguments` (`Value::Null` when absent).
    ///
    /// `notifier` can be used to push notifications while the call is in progress.
    fn call_tool<'a>(
        &'a self,
        name: &'a str,
        args: &'a Value,
        notifier: &'a Notifier,
    ) -> ToolFuture<'a>;

    /// Reads the resource at `uri`.
    ///
    /// The default implementation serves nothing: it always fails with `METHOD_NOT_FOUND`
    /// and the message `no resource at '<uri>'`.
    fn read_resource<'a>(&'a self, uri: &'a str) -> ResourceFuture<'a> {
        let message = format!("no resource at '{uri}'");
        Box::pin(async move {
            Err(ResourceError {
                code: METHOD_NOT_FOUND,
                message,
                data: None,
            })
        })
    }
}

/// Alternate name for [`Handler`].
pub use Handler as ToolHandler;
/// Runs [`serve`] over the process's standard input and standard output.
///
/// See [`serve`] for the message loop, `on_ready` timing, and shutdown behaviour.
pub async fn serve_stdio<H, F>(
    config: ServerConfig,
    handler: Arc<H>,
    activity: Arc<Activity>,
    on_ready: F,
) -> Result<()>
where
    H: ToolHandler,
    F: FnOnce(Notifier) + Send + 'static,
{
    let (input, output) = (tokio::io::stdin(), tokio::io::stdout());
    serve(config, handler, activity, on_ready, input, output).await
}
pub async fn serve<H, F, R, W>(
config: ServerConfig,
handler: Arc<H>,
activity: Arc<Activity>,
on_ready: F,
reader: R,
writer: W,
) -> Result<()>
where
H: ToolHandler,
F: FnOnce(Notifier) + Send + 'static,
R: tokio::io::AsyncRead + Unpin + Send + 'static,
W: tokio::io::AsyncWrite + Unpin + Send + 'static,
{
let (out_tx, mut out_rx) = mpsc::unbounded_channel::<OutboundMessage>();
let notifier = Notifier { out_tx };
let writer = Arc::new(Mutex::new(writer));
let writer_task = {
let writer = Arc::clone(&writer);
tokio::spawn(async move {
while let Some(msg) = out_rx.recv().await {
let line = match &msg {
OutboundMessage::Reply { id, result } => {
let r = JsonRpcSuccess {
jsonrpc: "2.0",
id,
result: result.clone(),
};
match serde_json::to_string(&r) {
Ok(s) => s,
Err(e) => {
log::warn!("serialize reply: {e}");
continue;
}
}
}
OutboundMessage::Error { id, body } => {
let r = JsonRpcErrorReply {
jsonrpc: "2.0",
id,
error: JsonRpcErrorBody {
code: body.code,
message: body.message.clone(),
data: body.data.clone(),
},
};
match serde_json::to_string(&r) {
Ok(s) => s,
Err(e) => {
log::warn!("serialize error: {e}");
continue;
}
}
}
OutboundMessage::Notification { method, params } => {
let n = JsonRpcNotification {
jsonrpc: "2.0",
method: method.clone(),
params: params.clone(),
};
match serde_json::to_string(&n) {
Ok(s) => s,
Err(e) => {
log::warn!("serialize notification: {e}");
continue;
}
}
}
};
let is_notification = matches!(&msg, OutboundMessage::Notification { .. });
let notification_kind = match &msg {
OutboundMessage::Notification { method, params } => {
let kind = params
.get("meta")
.and_then(|m| m.get("kind"))
.and_then(|v| v.as_str())
.unwrap_or("?");
Some((method.clone(), kind.to_string()))
}
_ => None,
};
let line_bytes = line.len();
let mut w = writer.lock().await;
if w.write_all(line.as_bytes()).await.is_err() {
break;
}
if w.write_all(b"\n").await.is_err() {
break;
}
if w.flush().await.is_err() {
break;
}
if is_notification && let Some((method, kind)) = notification_kind {
log::info!(
"[trace-pushpipe] site=writer_task wrote method={method} kind={kind} bytes={line_bytes}",
);
}
}
})
};
on_ready(notifier.clone());
let mut reader = BufReader::new(reader);
let mut line = String::new();
loop {
line.clear();
let n = match reader.read_line(&mut line).await {
Ok(n) => n,
Err(e) => {
log::warn!("stdin read error: {e}");
break;
}
};
if n == 0 {
break; }
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
let req: JsonRpcRequest = match serde_json::from_str(trimmed) {
Ok(r) => r,
Err(e) => {
log::warn!("malformed json-rpc: {e} | line: {trimmed}");
continue;
}
};
match req.id {
Some(id) => dispatch_request(
req.method,
req.params,
id,
&config,
Arc::clone(&handler),
Arc::clone(&activity),
notifier.clone(),
),
None => {
log::debug!("ignoring notification: {}", req.method);
}
}
}
drop(notifier); let _ = writer_task.await;
Ok(())
}
fn dispatch_request<H>(
method: String,
params: Value,
id: Value,
config: &ServerConfig,
handler: Arc<H>,
activity: Arc<Activity>,
notifier: Notifier,
) where
H: ToolHandler,
{
activity.bump();
match method.as_str() {
"initialize" => {
let result = serde_json::json!({
"protocolVersion": PROTOCOL_VERSION,
"capabilities": {
"tools": {},
"resources": { "subscribe": false, "listChanged": false },
"logging": {},
"experimental": {
"claude/channel": {}
}
},
"serverInfo": {
"name": config.name,
"version": config.version,
},
"instructions": config.instructions,
});
let _ = notifier.out_tx.send(OutboundMessage::Reply { id, result });
}
"ping" => {
let _ = notifier.out_tx.send(OutboundMessage::Reply {
id,
result: serde_json::json!({}),
});
}
"tools/list" => {
let result = serde_json::json!({ "tools": &config.tools });
let _ = notifier.out_tx.send(OutboundMessage::Reply { id, result });
}
"resources/list" => {
let result = serde_json::json!({ "resources": &config.resources });
let _ = notifier.out_tx.send(OutboundMessage::Reply { id, result });
}
"resources/read" => {
activity.start();
tokio::spawn(async move {
let uri = params
.get("uri")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let result = handler.read_resource(&uri).await;
let msg = match result {
Ok(content) => OutboundMessage::Reply {
id,
result: serde_json::json!({
"contents": [{
"uri": content.uri,
"mimeType": content.mime_type,
"text": content.text,
}]
}),
},
Err(e) => OutboundMessage::Error {
id,
body: JsonRpcErrorBody {
code: e.code,
message: e.message,
data: e.data,
},
},
};
let _ = notifier.out_tx.send(msg);
activity.end();
});
}
"tools/call" => {
activity.start();
tokio::spawn(async move {
let name = params
.get("name")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let args = params.get("arguments").cloned().unwrap_or(Value::Null);
if !name.is_empty() {
activity.record_tool(&name);
}
let result = handler.call_tool(&name, &args, ¬ifier).await;
let msg = match result {
Ok(outcome) => OutboundMessage::Reply {
id,
result: tool_outcome_to_result(outcome),
},
Err(e) => OutboundMessage::Error {
id,
body: JsonRpcErrorBody {
code: e.code,
message: e.message,
data: e.data,
},
},
};
let _ = notifier.out_tx.send(msg);
activity.end();
});
}
other => {
let body = JsonRpcErrorBody {
code: METHOD_NOT_FOUND,
message: format!("method '{other}' not implemented"),
data: None,
};
let _ = notifier.out_tx.send(OutboundMessage::Error { id, body });
}
}
}
/// Wraps a successful [`ToolOutcome`] as a `tools/call` result: a single `text` content item
/// plus `isError: false`.
///
/// `Text` is sent unchanged; `Json` is pretty-printed, falling back to compact JSON if
/// pretty-printing fails.
fn tool_outcome_to_result(outcome: ToolOutcome) -> Value {
    let text = match outcome {
        ToolOutcome::Json(v) => match serde_json::to_string_pretty(&v) {
            Ok(pretty) => pretty,
            Err(_) => v.to_string(),
        },
        ToolOutcome::Text(s) => s,
    };
    let text_item = serde_json::json!({ "type": "text", "text": text });
    serde_json::json!({
        "content": [text_item],
        "isError": false,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{DuplexStream, duplex};
    use tokio::task::JoinHandle;

    /// Buffer size of each in-memory pipe between the test client and the server.
    const PIPE_CAPACITY: usize = 64 * 1024;

    /// Handler whose every tool call succeeds with `{"tool": <name>, "args": <arguments>}`.
    struct EchoHandler;

    impl ToolHandler for EchoHandler {
        fn call_tool<'a>(
            &'a self,
            name: &'a str,
            args: &'a Value,
            _notifier: &'a Notifier,
        ) -> ToolFuture<'a> {
            Box::pin(async move {
                let echoed = serde_json::json!({
                    "tool": name,
                    "args": args,
                });
                Ok(ToolOutcome::Json(echoed))
            })
        }
    }

    /// Config exposing a single `echo` tool and no resources.
    fn make_config() -> ServerConfig {
        let echo = ToolDef {
            name: "echo".into(),
            description: "echo".into(),
            input_schema: serde_json::json!({"type": "object"}),
        };
        ServerConfig {
            name: "test".into(),
            version: "0.0.0".into(),
            instructions: "test instructions".into(),
            tools: vec![echo],
            resources: vec![],
        }
    }

    /// Spawns `serve` over two in-memory pipes.
    ///
    /// Returns the client's write half, the client's read half, and the server task.
    fn start_server<F>(on_ready: F) -> (DuplexStream, DuplexStream, JoinHandle<anyhow::Result<()>>)
    where
        F: FnOnce(Notifier) + Send + 'static,
    {
        let (client_w, server_r) = duplex(PIPE_CAPACITY);
        let (server_w, client_r) = duplex(PIPE_CAPACITY);
        let server = tokio::spawn(serve(
            make_config(),
            Arc::new(EchoHandler),
            Arc::new(Activity::new()),
            on_ready,
            server_r,
            server_w,
        ));
        (client_w, client_r, server)
    }

    /// Sends `msg` to the server as one newline-terminated line.
    async fn send_json(client_w: &mut DuplexStream, msg: &Value) {
        client_w
            .write_all(format!("{msg}\n").as_bytes())
            .await
            .unwrap();
    }

    /// Reads one line from the server and parses it as JSON.
    async fn read_json(client_r: &mut DuplexStream) -> Value {
        let mut line = String::new();
        BufReader::new(client_r).read_line(&mut line).await.unwrap();
        serde_json::from_str(line.trim()).unwrap()
    }

    #[tokio::test]
    async fn initialize_returns_capabilities() {
        let (mut client_w, mut client_r, server) = start_server(|_| {});
        let req = serde_json::json!({
            "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": {}
        });
        send_json(&mut client_w, &req).await;

        let resp = read_json(&mut client_r).await;
        assert_eq!(resp["id"], 1);
        assert_eq!(resp["result"]["protocolVersion"], PROTOCOL_VERSION);
        assert_eq!(resp["result"]["serverInfo"]["name"], "test");
        assert!(resp["result"]["capabilities"]["experimental"]["claude/channel"].is_object());

        drop(client_w);
        let _ = server.await;
    }

    #[tokio::test]
    async fn tools_list_returns_definitions() {
        let (mut client_w, mut client_r, server) = start_server(|_| {});
        client_w
            .write_all(b"{\"jsonrpc\":\"2.0\",\"id\":2,\"method\":\"tools/list\"}\n")
            .await
            .unwrap();

        let resp = read_json(&mut client_r).await;
        assert_eq!(resp["id"], 2);
        assert_eq!(resp["result"]["tools"][0]["name"], "echo");

        drop(client_w);
        let _ = server.await;
    }

    #[tokio::test]
    async fn tools_call_dispatches_and_serializes() {
        let (mut client_w, mut client_r, server) = start_server(|_| {});
        let req = serde_json::json!({
            "jsonrpc": "2.0", "id": 3, "method": "tools/call",
            "params": { "name": "echo", "arguments": {"hello": "world"} }
        });
        send_json(&mut client_w, &req).await;

        let resp = read_json(&mut client_r).await;
        assert_eq!(resp["id"], 3);
        let text = resp["result"]["content"][0]["text"].as_str().unwrap();
        let inner: Value = serde_json::from_str(text).unwrap();
        assert_eq!(inner["tool"], "echo");
        assert_eq!(inner["args"]["hello"], "world");

        drop(client_w);
        let _ = server.await;
    }

    #[tokio::test]
    async fn on_ready_callback_fires_before_initialize_handshake() {
        let (client_w, mut client_r, server) = start_server(|notifier| {
            notifier.channel("hello", serde_json::json!({"source": "test"}));
        });

        let note = read_json(&mut client_r).await;
        assert_eq!(note["method"], "notifications/claude/channel");
        assert_eq!(note["params"]["content"], "hello");
        assert_eq!(note["params"]["meta"]["source"], "test");

        drop(client_w);
        let _ = server.await;
    }
}