poem_mcpserver/protocol/rpc.rs

1//! JSON-RPC protocol types.
2
3use itertools::Either;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6
7use crate::protocol::{
8    initialize::InitializeRequest,
9    prompts::PromptsListRequest,
10    tool::{ToolsCallRequest, ToolsListRequest},
11};
12
13/// A JSON-RPC request id.
14#[derive(Debug, Serialize, Deserialize)]
15#[serde(untagged)]
16pub enum RequestId {
17    /// A numeric request id.
18    Int(i64),
19    /// A string request id.
20    String(String),
21}
22
23/// A JSON-RPC request body.
24#[derive(Debug, Deserialize)]
25#[serde(tag = "method", rename_all = "camelCase")]
26pub enum Requests {
27    /// Ping.
28    Ping,
29    /// Initialize.
30    Initialize {
31        /// Initialize request parameters.
32        params: InitializeRequest,
33    },
34    /// Initialized notification.
35    #[serde(rename = "notifications/initialized")]
36    Initialized,
37    /// Cancelled notification.
38    #[serde(rename = "notifications/cancelled ")]
39    Cancelled {
40        /// The ID of the request to cancel
41        request_id: RequestId,
42        /// An optional reason string that can be logged or displayed
43        reason: Option<String>,
44    },
45    /// Tools list.
46    #[serde(rename = "tools/list")]
47    ToolsList {
48        /// Tools list request parameters.
49        #[serde(default)]
50        params: ToolsListRequest,
51    },
52    /// Call a tool.
53    #[serde(rename = "tools/call")]
54    ToolsCall {
55        /// Tool call request parameters.
56        params: ToolsCallRequest,
57    },
58    /// Prompts list.
59    #[serde(rename = "prompts/list")]
60    PromptsList {
61        /// Prompts list request parameters.
62        #[serde(default)]
63        params: PromptsListRequest,
64    },
65    /// Resources list.
66    #[serde(rename = "resources/list")]
67    ResourcesList {
68        /// Prompts list request parameters.
69        #[serde(default)]
70        params: PromptsListRequest,
71    },
72}
73
74/// A JSON-RPC batch request.
75#[derive(Debug, Deserialize)]
76#[serde(untagged)]
77pub enum BatchRequest {
78    /// A single request.
79    Single(Request),
80    /// A batch of requests.
81    Batch(Vec<Request>),
82}
83
84impl IntoIterator for BatchRequest {
85    type Item = Request;
86    type IntoIter = Either<std::iter::Once<Self::Item>, std::vec::IntoIter<Self::Item>>;
87
88    fn into_iter(self) -> Self::IntoIter {
89        match self {
90            BatchRequest::Single(request) => Either::Left(std::iter::once(request)),
91            BatchRequest::Batch(requests) => Either::Right(requests.into_iter()),
92        }
93    }
94}
95
96impl BatchRequest {
97    /// Return the number of requests in the batch.
98    pub fn len(&self) -> usize {
99        match self {
100            BatchRequest::Single(_) => 1,
101            BatchRequest::Batch(requests) => requests.len(),
102        }
103    }
104
105    /// Return `true` if the batch is empty.
106    pub fn is_empty(&self) -> bool {
107        self.len() == 0
108    }
109
110    /// Return the requests in the batch.
111    pub fn requests(&self) -> &[Request] {
112        match self {
113            BatchRequest::Single(request) => std::slice::from_ref(request),
114            BatchRequest::Batch(requests) => requests,
115        }
116    }
117}
118
119/// A JSON-RPC request.
120#[derive(Debug, Deserialize)]
121#[serde(rename_all = "camelCase")]
122pub struct Request {
123    /// The JSON-RPC version.
124    pub jsonrpc: String,
125    /// The request id.
126    pub id: Option<RequestId>,
127    /// The request body.
128    #[serde(flatten)]
129    pub body: Requests,
130}
131
132impl Request {
133    #[inline]
134    pub(crate) fn is_initialize(&self) -> bool {
135        matches!(self.body, Requests::Initialize { .. })
136    }
137}
138
139/// A JSON-RPC response.
140#[derive(Debug, Serialize)]
141#[serde(rename_all = "camelCase")]
142pub struct Response<T = ()> {
143    /// The JSON-RPC version.
144    pub jsonrpc: String,
145    /// The request id.
146    pub id: Option<RequestId>,
147    #[serde(skip_serializing_if = "Option::is_none")]
148    /// The response result.
149    pub result: Option<T>,
150    /// The response error.
151    #[serde(skip_serializing_if = "Option::is_none")]
152    pub error: Option<RpcError>,
153}
154
155impl<T> Response<T>
156where
157    T: Serialize,
158{
159    /// Convert the response body to `serde_json::Value`.
160    #[inline]
161    pub fn map_result_to_value(self) -> Response<Value> {
162        Response {
163            jsonrpc: self.jsonrpc,
164            id: self.id,
165            result: self
166                .result
167                .map(|v| serde_json::to_value(v).expect("serialize result")),
168            error: self.error,
169        }
170    }
171}
172
173/// A JSON-RPC batch response
174#[derive(Debug, Serialize)]
175#[serde(untagged)]
176pub enum BatchResponse<T = ()> {
177    /// A single response.
178    Single(Response<T>),
179    /// A batch of responses.
180    Batch(Vec<Response<T>>),
181}
182
183impl<T> IntoIterator for BatchResponse<T> {
184    type Item = Response<T>;
185    type IntoIter = Either<std::iter::Once<Self::Item>, std::vec::IntoIter<Self::Item>>;
186
187    fn into_iter(self) -> Self::IntoIter {
188        match self {
189            BatchResponse::Single(response) => Either::Left(std::iter::once(response)),
190            BatchResponse::Batch(responses) => Either::Right(responses.into_iter()),
191        }
192    }
193}
194
/// Standard JSON-RPC 2.0 code for a parse error.
const PARSE_ERROR: i32 = -32700;
/// Standard JSON-RPC 2.0 code for an invalid request.
const INVALID_REQUEST: i32 = -32600;
/// Standard JSON-RPC 2.0 code for an unknown method.
const METHOD_NOT_FOUND: i32 = -32601;
/// Standard JSON-RPC 2.0 code for invalid method parameters.
const INVALID_PARAMS: i32 = -32602;
/// Standard JSON-RPC 2.0 code for an internal error.
const INTERNAL_ERROR: i32 = -32603;
200
201/// A JSON-RPC error.
202#[derive(Debug, Serialize)]
203#[serde(rename_all = "camelCase")]
204pub struct RpcError<E = ()> {
205    code: i32,
206    message: String,
207    #[serde(skip_serializing_if = "Option::is_none")]
208    data: Option<E>,
209}
210
211impl<E> RpcError<E> {
212    /// Create a new JSON-RPC error with the given code and message.
213    #[inline]
214    pub fn new(code: i32, message: impl Into<String>) -> Self {
215        RpcError {
216            code,
217            message: message.into(),
218            data: None,
219        }
220    }
221
222    /// Attach data to the JSON-RPC error.
223    #[inline]
224    pub fn with_data<Q>(self, data: Q) -> RpcError<Q> {
225        RpcError {
226            code: self.code,
227            message: self.message,
228            data: Some(data),
229        }
230    }
231
232    /// Create a JSON-RPC error with code `PARSE_ERROR(-32700)` and the given
233    /// message.
234    #[inline]
235    pub fn parse_error(message: impl Into<String>) -> Self {
236        RpcError::new(PARSE_ERROR, message)
237    }
238
239    /// Create a JSON-RPC error with code `INVALID_REQUEST(-32600)` and the
240    /// given message.
241    #[inline]
242    pub fn invalid_request(message: impl Into<String>) -> Self {
243        RpcError::new(INVALID_REQUEST, message)
244    }
245
246    /// Create a JSON-RPC error with code `METHOD_NOT_FOUND(-32601)` and the
247    /// given message.
248    #[inline]
249    pub fn method_not_found(message: impl Into<String>) -> Self {
250        RpcError::new(METHOD_NOT_FOUND, message)
251    }
252
253    /// Create a JSON-RPC error with code `INVALID_PARAMS(-32602)` and the given
254    /// message.
255    #[inline]
256    pub fn invalid_params(message: impl Into<String>) -> Self {
257        RpcError::new(INVALID_PARAMS, message)
258    }
259
260    /// Create a JSON-RPC error with code `INTERNAL_ERROR(-32603)` and the given
261    /// message.
262    #[inline]
263    pub fn internal_error(message: impl Into<String>) -> Self {
264        RpcError::new(INTERNAL_ERROR, message)
265    }
266}