Skip to main content

aprender_mcp/
types.rs

1//! JSON-RPC 2.0 + MCP protocol types.
2//!
3//! Mirrors `aprender-orchestrate::mcp::types` so downstream tools can depend on
4//! `aprender-mcp` without pulling the entire batuta dependency tree.
5
6#![allow(clippy::disallowed_methods)] // serde_json::json! macro expands to .unwrap() internally
7
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11/// JSON-RPC 2.0 request envelope.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct JsonRpcRequest {
14    pub jsonrpc: String,
15    pub id: Option<serde_json::Value>,
16    pub method: String,
17    #[serde(default)]
18    pub params: serde_json::Value,
19}
20
21/// JSON-RPC 2.0 response envelope.
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct JsonRpcResponse {
24    pub jsonrpc: String,
25    pub id: Option<serde_json::Value>,
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub result: Option<serde_json::Value>,
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub error: Option<JsonRpcError>,
30}
31
32/// JSON-RPC 2.0 notification envelope — server → client, no `id` field.
33///
34/// Used for MCP `notifications/progress` per the 2024-11-05 spec:
35/// <https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/utilities/progress/>.
36/// A notification MUST NOT carry an `id`; that's the serialization rule the
37/// peer uses to route it as a fire-and-forget message.
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct JsonRpcNotification {
40    pub jsonrpc: String,
41    pub method: String,
42    #[serde(default)]
43    pub params: serde_json::Value,
44}
45
46impl JsonRpcNotification {
47    /// Construct a `notifications/progress` payload with the spec-mandated
48    /// `progressToken` echoed back from the originating request's
49    /// `params._meta.progressToken`.
50    ///
51    /// `message` is the opaque JSON payload emitted by the subprocess on one
52    /// line of stdout. The MCP spec's optional `progress`/`total` numeric
53    /// fields are left absent for now — we forward the parsed JSON verbatim
54    /// so higher-level clients can introspect whatever shape the CLI emits.
55    #[must_use]
56    pub fn progress(progress_token: serde_json::Value, message: serde_json::Value) -> Self {
57        Self {
58            jsonrpc: "2.0".to_string(),
59            method: "notifications/progress".to_string(),
60            params: serde_json::json!({
61                "progressToken": progress_token,
62                "message": message,
63            }),
64        }
65    }
66
67    /// Serialize this notification as a single JSON line (no trailing
68    /// newline). Suitable for writing to stdio with `writeln!`.
69    ///
70    /// # Errors
71    /// Returns a `serde_json::Error` if the payload cannot be serialized,
72    /// which in practice only happens for non-finite floats that a well-formed
73    /// subprocess would not produce.
74    pub fn to_json_line(&self) -> Result<String, serde_json::Error> {
75        serde_json::to_string(self)
76    }
77}
78
79/// JSON-RPC 2.0 error object.
80#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct JsonRpcError {
82    pub code: i64,
83    pub message: String,
84    #[serde(skip_serializing_if = "Option::is_none")]
85    pub data: Option<serde_json::Value>,
86}
87
88impl JsonRpcResponse {
89    #[must_use]
90    pub fn success(id: Option<serde_json::Value>, result: serde_json::Value) -> Self {
91        Self {
92            jsonrpc: "2.0".to_string(),
93            id,
94            result: Some(result),
95            error: None,
96        }
97    }
98
99    pub fn error(id: Option<serde_json::Value>, code: i64, message: impl Into<String>) -> Self {
100        Self {
101            jsonrpc: "2.0".to_string(),
102            id,
103            result: None,
104            error: Some(JsonRpcError {
105                code,
106                message: message.into(),
107                data: None,
108            }),
109        }
110    }
111}
112
113/// `initialize` response capabilities block.
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct ServerCapabilities {
116    pub tools: ToolsCapability,
117}
118
119/// Tools capability flags.
120#[derive(Debug, Clone, Serialize, Deserialize)]
121pub struct ToolsCapability {
122    #[serde(rename = "listChanged")]
123    pub list_changed: bool,
124}
125
126/// MCP tool registration record.
127#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct ToolDefinition {
129    pub name: String,
130    pub description: String,
131    #[serde(rename = "inputSchema")]
132    pub input_schema: InputSchema,
133}
134
135/// JSON Schema for a tool's input.
136#[derive(Debug, Clone, Serialize, Deserialize)]
137pub struct InputSchema {
138    #[serde(rename = "type")]
139    pub schema_type: String,
140    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
141    pub properties: HashMap<String, PropertySchema>,
142    #[serde(default, skip_serializing_if = "Vec::is_empty")]
143    pub required: Vec<String>,
144}
145
146/// JSON Schema for one input property.
147#[derive(Debug, Clone, Serialize, Deserialize)]
148pub struct PropertySchema {
149    #[serde(rename = "type")]
150    pub prop_type: String,
151    pub description: String,
152    #[serde(skip_serializing_if = "Option::is_none")]
153    pub r#enum: Option<Vec<String>>,
154}
155
156/// Result returned by `tools/call`.
157#[derive(Debug, Clone, Serialize, Deserialize)]
158pub struct ToolCallResult {
159    pub content: Vec<ContentBlock>,
160    #[serde(rename = "isError", skip_serializing_if = "Option::is_none")]
161    pub is_error: Option<bool>,
162}
163
164/// One content block in a tool result.
165#[derive(Debug, Clone, Serialize, Deserialize)]
166pub struct ContentBlock {
167    #[serde(rename = "type")]
168    pub content_type: String,
169    pub text: String,
170}
171
172impl ContentBlock {
173    pub fn text(content: impl Into<String>) -> Self {
174        Self {
175            content_type: "text".to_string(),
176            text: content.into(),
177        }
178    }
179}
180
181impl ToolCallResult {
182    pub fn success(text: impl Into<String>) -> Self {
183        Self {
184            content: vec![ContentBlock::text(text)],
185            is_error: None,
186        }
187    }
188
189    pub fn error(text: impl Into<String>) -> Self {
190        Self {
191            content: vec![ContentBlock::text(text)],
192            is_error: Some(true),
193        }
194    }
195}
196
#[cfg(test)]
#[allow(clippy::disallowed_methods)] // serde_json::json! expands to code that hits unwrap()
mod tests {
    use super::*;

    /// A success response keeps its fields through a serialize → deserialize
    /// cycle and does not emit an `error` key on the wire.
    #[test]
    fn json_rpc_response_success_round_trips() {
        let resp = JsonRpcResponse::success(Some(serde_json::json!(1)), serde_json::json!("ok"));
        assert!(resp.result.is_some());
        assert!(resp.error.is_none());
        assert_eq!(resp.jsonrpc, "2.0");

        // Actually exercise the round trip the test name promises.
        let wire = serde_json::to_string(&resp).expect("serialize");
        assert!(!wire.contains("\"error\""), "absent error must be skipped");
        let back: JsonRpcResponse = serde_json::from_str(&wire).expect("deserialize");
        assert_eq!(back.jsonrpc, "2.0");
        assert_eq!(back.id, Some(serde_json::json!(1)));
        assert_eq!(back.result, Some(serde_json::json!("ok")));
        assert!(back.error.is_none());
    }

    /// An error response carries the given code and message and no result.
    #[test]
    fn json_rpc_response_error_sets_code() {
        let resp = JsonRpcResponse::error(Some(serde_json::json!(2)), -32600, "Invalid Request");
        assert!(resp.result.is_none());
        let err = resp.error.expect("error present");
        assert_eq!(err.code, -32600);
        assert_eq!(err.message, "Invalid Request");
        assert!(err.data.is_none());
    }

    /// `ToolCallResult::success` wraps the text in one block and leaves the
    /// error flag unset.
    #[test]
    fn tool_call_result_success_has_no_error_flag() {
        let result = ToolCallResult::success("hello");
        assert_eq!(result.content.len(), 1);
        assert_eq!(result.content[0].text, "hello");
        assert!(result.is_error.is_none());
    }

    /// `ToolCallResult::error` sets the error flag.
    #[test]
    fn tool_call_result_error_flags_error() {
        let result = ToolCallResult::error("fail");
        assert_eq!(result.is_error, Some(true));
    }

    /// `ContentBlock::text` tags the block as `"text"`.
    #[test]
    fn content_block_text_defaults_type() {
        let block = ContentBlock::text("test");
        assert_eq!(block.content_type, "text");
    }

    /// A well-formed `tools/list` request deserializes with its method and id.
    #[test]
    fn json_rpc_request_deserializes_tools_list() {
        let json = r#"{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}"#;
        let req: JsonRpcRequest = serde_json::from_str(json).expect("deserialize");
        assert_eq!(req.method, "tools/list");
        assert_eq!(req.id, Some(serde_json::json!(1)));
    }

    /// FALSIFY-MCP-PROGRESS-001 (wire-format unit): a `notifications/progress`
    /// serializes without an `id` field and carries `progressToken` +
    /// `message` inside `params`.
    #[test]
    fn json_rpc_notification_progress_has_no_id_field() {
        let notif = JsonRpcNotification::progress(
            serde_json::json!("tok-1"),
            serde_json::json!({"step": 1, "loss": 0.42}),
        );
        let json = notif.to_json_line().expect("serialize");
        assert!(json.contains("\"jsonrpc\":\"2.0\""));
        assert!(json.contains("\"method\":\"notifications/progress\""));
        assert!(json.contains("\"progressToken\":\"tok-1\""));
        // The doc comment above promises `message` too; pin it.
        assert!(json.contains("\"message\":"), "params must carry message");
        assert!(!json.contains("\"id\""), "notifications MUST NOT carry id");
        assert!(!json.contains('\n'), "must be a single line");
    }

    /// Progress tokens may be numbers as well as strings.
    #[test]
    fn json_rpc_notification_accepts_numeric_token() {
        let notif = JsonRpcNotification::progress(serde_json::json!(7), serde_json::json!("tick"));
        let json = notif.to_json_line().expect("serialize");
        assert!(json.contains("\"progressToken\":7"));
    }

    /// An empty schema serializes its `type` and skips the empty collections.
    #[test]
    fn input_schema_serializes_object_type() {
        let schema = InputSchema {
            schema_type: "object".to_string(),
            properties: HashMap::new(),
            required: vec![],
        };
        let json = serde_json::to_string(&schema).expect("serialize");
        assert!(json.contains("\"type\":\"object\""));
        assert!(!json.contains("\"properties\""));
        assert!(!json.contains("\"required\""));
    }
}