Skip to main content

talon_cli/mcp/
protocol.rs

1use serde::{Deserialize, Serialize};
2use serde_json::{Value, json};
3use std::panic::{AssertUnwindSafe, catch_unwind};
4
5use super::tool;
6
/// Protocol version string every request must carry and every response echoes.
pub const JSONRPC_VERSION: &str = "2.0";

/// JSON-RPC error code used when the incoming payload could not be parsed.
const PARSE_ERROR: i32 = -32_700;
/// JSON-RPC error code used when the envelope is not a valid 2.0 request.
const INVALID_REQUEST: i32 = -32_600;
/// JSON-RPC error code used when the requested method is unknown.
const METHOD_NOT_FOUND: i32 = -32_601;
12
13#[derive(Debug, Deserialize)]
14pub struct JsonRpcRequest {
15    pub jsonrpc: String,
16    #[serde(default)]
17    pub id: Option<Value>,
18    pub method: String,
19    #[serde(default)]
20    pub params: Option<Value>,
21}
22
23impl JsonRpcRequest {
24    #[must_use]
25    pub const fn is_notification(&self) -> bool {
26        self.id.is_none()
27    }
28}
29
30#[derive(Debug, Serialize)]
31pub struct JsonRpcResponse {
32    pub jsonrpc: &'static str,
33    pub id: Value,
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub result: Option<Value>,
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub error: Option<JsonRpcError>,
38}
39
40#[derive(Debug, Serialize)]
41pub struct JsonRpcError {
42    pub code: i32,
43    pub message: String,
44    #[serde(skip_serializing_if = "Option::is_none")]
45    pub data: Option<Value>,
46}
47
/// What the server loop should do after a request has been handled.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MethodDisposition {
    /// Keep reading further requests.
    Continue,
    /// Stop serving; produced by the `shutdown` method.
    Shutdown,
}
53
54#[must_use]
55pub fn parse_error(data: Value) -> JsonRpcResponse {
56    error_response(Value::Null, PARSE_ERROR, "parse error", Some(data))
57}
58
59#[must_use]
60pub fn handle_request(request: JsonRpcRequest) -> (Option<JsonRpcResponse>, MethodDisposition) {
61    if request.jsonrpc != JSONRPC_VERSION {
62        let id = request.id.unwrap_or(Value::Null);
63        return (
64            Some(error_response(id, INVALID_REQUEST, "invalid request", None)),
65            MethodDisposition::Continue,
66        );
67    }
68
69    match request.method.as_str() {
70        "initialize" => {
71            respond_to_request(request, initialize_result(), MethodDisposition::Continue)
72        }
73        "notifications/initialized" | "initialized" => (None, MethodDisposition::Continue),
74        "tools/list" => respond_to_request(
75            request,
76            tool::tools_list_result(),
77            MethodDisposition::Continue,
78        ),
79        "tools/call" => {
80            let params = request.params.clone();
81            respond_to_request(
82                request,
83                tool::tools_call_result(params),
84                MethodDisposition::Continue,
85            )
86        }
87        "shutdown" => respond_to_request(request, Value::Null, MethodDisposition::Shutdown),
88        _ => {
89            if request.is_notification() {
90                (None, MethodDisposition::Continue)
91            } else {
92                let id = request.id.unwrap_or(Value::Null);
93                (
94                    Some(error_response(
95                        id,
96                        METHOD_NOT_FOUND,
97                        "method not found",
98                        None,
99                    )),
100                    MethodDisposition::Continue,
101                )
102            }
103        }
104    }
105}
106
107/// State-aware variant of [`handle_request`].
108///
109/// For `tools/call` requests, delegates to
110/// [`tool::tools_call_result_with_state`] so that hook tools can access
111/// session state.  All other methods behave identically to [`handle_request`].
112#[must_use]
113pub fn handle_request_with_state(
114    request: JsonRpcRequest,
115    state: &std::sync::Arc<crate::mcp::state::McpServerState>,
116) -> (Option<JsonRpcResponse>, MethodDisposition) {
117    if request.jsonrpc != JSONRPC_VERSION {
118        let id = request.id.unwrap_or(Value::Null);
119        return (
120            Some(error_response(id, INVALID_REQUEST, "invalid request", None)),
121            MethodDisposition::Continue,
122        );
123    }
124
125    match request.method.as_str() {
126        "initialize" => {
127            respond_to_request(request, initialize_result(), MethodDisposition::Continue)
128        }
129        "notifications/initialized" | "initialized" => (None, MethodDisposition::Continue),
130        "tools/list" => respond_to_request(
131            request,
132            tool::tools_list_result(),
133            MethodDisposition::Continue,
134        ),
135        "tools/call" => {
136            let params = request.params.clone();
137            let result = catch_unwind(AssertUnwindSafe(|| {
138                tool::tools_call_result_with_state(params, state)
139            }))
140            .unwrap_or_else(|payload| {
141                crate::mcp::diagnostics::record_caught_panic("tools/call", payload.as_ref());
142                tool::panic_tool_result()
143            });
144            respond_to_request(request, result, MethodDisposition::Continue)
145        }
146        "shutdown" => respond_to_request(request, Value::Null, MethodDisposition::Shutdown),
147        _ => {
148            if request.is_notification() {
149                (None, MethodDisposition::Continue)
150            } else {
151                let id = request.id.unwrap_or(Value::Null);
152                (
153                    Some(error_response(
154                        id,
155                        METHOD_NOT_FOUND,
156                        "method not found",
157                        None,
158                    )),
159                    MethodDisposition::Continue,
160                )
161            }
162        }
163    }
164}
165
166fn respond_to_request(
167    request: JsonRpcRequest,
168    result: Value,
169    disposition: MethodDisposition,
170) -> (Option<JsonRpcResponse>, MethodDisposition) {
171    if request.is_notification() {
172        (None, disposition)
173    } else {
174        (
175            Some(JsonRpcResponse {
176                jsonrpc: JSONRPC_VERSION,
177                id: request.id.unwrap_or(Value::Null),
178                result: Some(result),
179                error: None,
180            }),
181            disposition,
182        )
183    }
184}
185
186fn error_response(id: Value, code: i32, message: &str, data: Option<Value>) -> JsonRpcResponse {
187    JsonRpcResponse {
188        jsonrpc: JSONRPC_VERSION,
189        id,
190        result: None,
191        error: Some(JsonRpcError {
192            code,
193            message: message.to_owned(),
194            data,
195        }),
196    }
197}
198
199fn initialize_result() -> Value {
200    json!({
201        "protocolVersion": "2024-11-05",
202        "capabilities": {
203            "tools": {}
204        },
205        "serverInfo": {
206            "name": "talon",
207            "version": env!("CARGO_PKG_VERSION")
208        }
209    })
210}
211
#[cfg(test)]
mod tests {
    use super::{JsonRpcRequest, MethodDisposition, handle_request, parse_error};
    use color_eyre::eyre::Result;
    use serde_json::{Value, json};

    #[test]
    fn handle_request_returns_initialize_response_when_request_has_id() -> Result<()> {
        let request: JsonRpcRequest = serde_json::from_value(json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {}
        }))?;

        let (response, disposition) = handle_request(request);
        let response = serde_json::to_value(response)?;

        assert_eq!(disposition, MethodDisposition::Continue);
        assert_eq!(response["result"]["serverInfo"]["name"], "talon");
        assert_eq!(response["id"], 1);
        Ok(())
    }

    #[test]
    fn handle_request_suppresses_response_for_initialized_notification() -> Result<()> {
        let request: JsonRpcRequest = serde_json::from_value(json!({
            "jsonrpc": "2.0",
            "method": "notifications/initialized"
        }))?;

        let (response, disposition) = handle_request(request);

        assert!(response.is_none());
        assert_eq!(disposition, MethodDisposition::Continue);
        Ok(())
    }

    #[test]
    fn handle_request_marks_shutdown_after_response() -> Result<()> {
        let request: JsonRpcRequest = serde_json::from_value(json!({
            "jsonrpc": "2.0",
            "id": "stop",
            "method": "shutdown"
        }))?;

        let (response, disposition) = handle_request(request);
        let response = serde_json::to_value(response)?;

        assert_eq!(disposition, MethodDisposition::Shutdown);
        assert_eq!(response["id"], Value::String("stop".to_owned()));
        Ok(())
    }

    #[test]
    fn handle_request_shuts_down_silently_for_shutdown_notification() -> Result<()> {
        let request: JsonRpcRequest = serde_json::from_value(json!({
            "jsonrpc": "2.0",
            "method": "shutdown"
        }))?;

        let (response, disposition) = handle_request(request);

        // Notifications never get a reply, but the shutdown still takes effect.
        assert!(response.is_none());
        assert_eq!(disposition, MethodDisposition::Shutdown);
        Ok(())
    }

    #[test]
    fn handle_request_rejects_generic_talon_tool_call() -> Result<()> {
        let request: JsonRpcRequest = serde_json::from_value(json!({
            "jsonrpc": "2.0",
            "id": "call",
            "method": "tools/call",
            "params": {
                "name": "talon",
                "arguments": { "action": "status" }
            }
        }))?;

        let (response, disposition) = handle_request(request);
        let response = serde_json::to_value(response)?;

        assert_eq!(disposition, MethodDisposition::Continue);
        assert_eq!(response["id"], "call");
        assert_eq!(response["result"]["structuredContent"]["action"], "talon");
        assert_eq!(response["result"]["structuredContent"]["ok"], false);
        assert_eq!(response["result"]["isError"], true);
        Ok(())
    }

    #[test]
    fn handle_request_reports_method_not_found_for_unknown_request() -> Result<()> {
        let request: JsonRpcRequest = serde_json::from_value(json!({
            "jsonrpc": "2.0",
            "id": 7,
            "method": "does/not/exist"
        }))?;

        let (response, disposition) = handle_request(request);
        let response = serde_json::to_value(response)?;

        assert_eq!(disposition, MethodDisposition::Continue);
        assert_eq!(response["id"], 7);
        assert_eq!(response["error"]["code"], -32601);
        assert_eq!(response["error"]["message"], "method not found");
        assert!(response.get("result").is_none());
        Ok(())
    }

    #[test]
    fn handle_request_ignores_unknown_notification() -> Result<()> {
        let request: JsonRpcRequest = serde_json::from_value(json!({
            "jsonrpc": "2.0",
            "method": "notifications/unknown"
        }))?;

        let (response, disposition) = handle_request(request);

        assert!(response.is_none());
        assert_eq!(disposition, MethodDisposition::Continue);
        Ok(())
    }

    #[test]
    fn handle_request_rejects_wrong_jsonrpc_version() -> Result<()> {
        let request: JsonRpcRequest = serde_json::from_value(json!({
            "jsonrpc": "1.0",
            "id": "old",
            "method": "initialize"
        }))?;

        let (response, disposition) = handle_request(request);
        let response = serde_json::to_value(response)?;

        assert_eq!(disposition, MethodDisposition::Continue);
        assert_eq!(response["id"], "old");
        assert_eq!(response["error"]["code"], -32600);
        assert_eq!(response["error"]["message"], "invalid request");
        Ok(())
    }

    #[test]
    fn parse_error_uses_null_id_and_carries_data() -> Result<()> {
        let response = serde_json::to_value(parse_error(json!("unexpected EOF")))?;

        assert_eq!(response["jsonrpc"], "2.0");
        assert_eq!(response["id"], Value::Null);
        assert_eq!(response["error"]["code"], -32700);
        assert_eq!(response["error"]["message"], "parse error");
        assert_eq!(response["error"]["data"], "unexpected EOF");
        Ok(())
    }
}
284        assert_eq!(response["result"]["structuredContent"]["ok"], false);
285        assert_eq!(response["result"]["isError"], true);
286        Ok(())
287    }
288}