poem_mcpserver/protocol/
rpc.rs

1//! JSON-RPC protocol types.
2
3use itertools::Either;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6
7use crate::protocol::{
8    initialize::InitializeRequest,
9    prompts::PromptsListRequest,
10    tool::{ToolsCallRequest, ToolsListRequest},
11};
12
13/// A JSON-RPC request id.
14#[derive(Debug, Serialize, Deserialize)]
15#[serde(untagged)]
16pub enum RequestId {
17    /// A numeric request id.
18    Int(i64),
19    /// A string request id.
20    String(String),
21}
22
23/// A JSON-RPC request body.
24#[derive(Debug, Deserialize)]
25#[serde(tag = "method", rename_all = "camelCase")]
26pub enum Requests {
27    /// Ping.
28    Ping,
29    /// Initialize.
30    Initialize {
31        /// Initialize request parameters.
32        params: InitializeRequest,
33    },
34    /// Initialized notification.
35    #[serde(rename = "notifications/initialized")]
36    Initialized,
37    /// Cancelled notification.
38    #[serde(rename = "notifications/cancelled")]
39    Cancelled {
40        /// The ID of the request to cancel
41        request_id: RequestId,
42        /// An optional reason string that can be logged or displayed
43        reason: Option<String>,
44    },
45    /// Tools list.
46    #[serde(rename = "tools/list")]
47    ToolsList {
48        /// Tools list request parameters.
49        #[serde(default)]
50        params: ToolsListRequest,
51    },
52    /// Call a tool.
53    #[serde(rename = "tools/call")]
54    ToolsCall {
55        /// Tool call request parameters.
56        params: ToolsCallRequest,
57    },
58    /// Prompts list.
59    #[serde(rename = "prompts/list")]
60    PromptsList {
61        /// Prompts list request parameters.
62        #[serde(default)]
63        params: PromptsListRequest,
64    },
65    /// Resources list.
66    #[serde(rename = "resources/list")]
67    ResourcesList {
68        /// Prompts list request parameters.
69        #[serde(default)]
70        params: PromptsListRequest,
71    },
72}
73
74/// A JSON-RPC batch request.
75#[derive(Debug, Deserialize)]
76#[serde(untagged)]
77pub enum BatchRequest {
78    /// A single request.
79    Single(Request),
80    /// A batch of requests.
81    Batch(Vec<Request>),
82}
83
84impl IntoIterator for BatchRequest {
85    type Item = Request;
86    type IntoIter = Either<std::iter::Once<Self::Item>, std::vec::IntoIter<Self::Item>>;
87
88    fn into_iter(self) -> Self::IntoIter {
89        match self {
90            BatchRequest::Single(request) => Either::Left(std::iter::once(request)),
91            BatchRequest::Batch(requests) => Either::Right(requests.into_iter()),
92        }
93    }
94}
95
96impl BatchRequest {
97    /// Return the number of requests in the batch.
98    pub fn len(&self) -> usize {
99        match self {
100            BatchRequest::Single(_) => 1,
101            BatchRequest::Batch(requests) => requests.len(),
102        }
103    }
104
105    /// Return `true` if the batch is empty.
106    pub fn is_empty(&self) -> bool {
107        self.len() == 0
108    }
109
110    /// Return the requests in the batch.
111    pub fn requests(&self) -> &[Request] {
112        match self {
113            BatchRequest::Single(request) => std::slice::from_ref(request),
114            BatchRequest::Batch(requests) => requests,
115        }
116    }
117}
118
119/// A JSON-RPC request.
120#[derive(Debug, Deserialize)]
121#[serde(rename_all = "camelCase")]
122pub struct Request {
123    /// The JSON-RPC version.
124    pub jsonrpc: String,
125    /// The request id.
126    pub id: Option<RequestId>,
127    /// The request body.
128    #[serde(flatten)]
129    pub body: Requests,
130}
131
132impl Request {
133    #[allow(dead_code)]
134    #[inline]
135    pub(crate) fn is_initialize(&self) -> bool {
136        matches!(self.body, Requests::Initialize { .. })
137    }
138}
139
140/// A JSON-RPC response.
141#[derive(Debug, Serialize)]
142#[serde(rename_all = "camelCase")]
143pub struct Response<T = ()> {
144    /// The JSON-RPC version.
145    pub jsonrpc: String,
146    /// The request id.
147    pub id: Option<RequestId>,
148    #[serde(skip_serializing_if = "Option::is_none")]
149    /// The response result.
150    pub result: Option<T>,
151    /// The response error.
152    #[serde(skip_serializing_if = "Option::is_none")]
153    pub error: Option<RpcError>,
154}
155
156impl<T> Response<T>
157where
158    T: Serialize,
159{
160    /// Convert the response body to `serde_json::Value`.
161    #[inline]
162    pub fn map_result_to_value(self) -> Response<Value> {
163        Response {
164            jsonrpc: self.jsonrpc,
165            id: self.id,
166            result: self
167                .result
168                .map(|v| serde_json::to_value(v).expect("serialize result")),
169            error: self.error,
170        }
171    }
172}
173
174/// A JSON-RPC batch response
175#[derive(Debug, Serialize)]
176#[serde(untagged)]
177pub enum BatchResponse<T = ()> {
178    /// A single response.
179    Single(Response<T>),
180    /// A batch of responses.
181    Batch(Vec<Response<T>>),
182}
183
184impl<T> IntoIterator for BatchResponse<T> {
185    type Item = Response<T>;
186    type IntoIter = Either<std::iter::Once<Self::Item>, std::vec::IntoIter<Self::Item>>;
187
188    fn into_iter(self) -> Self::IntoIter {
189        match self {
190            BatchResponse::Single(response) => Either::Left(std::iter::once(response)),
191            BatchResponse::Batch(responses) => Either::Right(responses.into_iter()),
192        }
193    }
194}
195
// Error codes used by the `RpcError` constructors in this module.

/// Code used by `RpcError::parse_error`.
const PARSE_ERROR: i32 = -32700;
/// Code used by `RpcError::invalid_request`.
const INVALID_REQUEST: i32 = -32600;
/// Code used by `RpcError::method_not_found`.
const METHOD_NOT_FOUND: i32 = -32601;
/// Code used by `RpcError::invalid_params`.
const INVALID_PARAMS: i32 = -32602;
/// Code used by `RpcError::internal_error`.
const INTERNAL_ERROR: i32 = -32603;
201
202/// A JSON-RPC error.
203#[derive(Debug, Serialize)]
204#[serde(rename_all = "camelCase")]
205pub struct RpcError<E = ()> {
206    code: i32,
207    message: String,
208    #[serde(skip_serializing_if = "Option::is_none")]
209    data: Option<E>,
210}
211
212impl<E> RpcError<E> {
213    /// Create a new JSON-RPC error with the given code and message.
214    #[inline]
215    pub fn new(code: i32, message: impl Into<String>) -> Self {
216        RpcError {
217            code,
218            message: message.into(),
219            data: None,
220        }
221    }
222
223    /// Attach data to the JSON-RPC error.
224    #[inline]
225    pub fn with_data<Q>(self, data: Q) -> RpcError<Q> {
226        RpcError {
227            code: self.code,
228            message: self.message,
229            data: Some(data),
230        }
231    }
232
233    /// Create a JSON-RPC error with code `PARSE_ERROR(-32700)` and the given
234    /// message.
235    #[inline]
236    pub fn parse_error(message: impl Into<String>) -> Self {
237        RpcError::new(PARSE_ERROR, message)
238    }
239
240    /// Create a JSON-RPC error with code `INVALID_REQUEST(-32600)` and the
241    /// given message.
242    #[inline]
243    pub fn invalid_request(message: impl Into<String>) -> Self {
244        RpcError::new(INVALID_REQUEST, message)
245    }
246
247    /// Create a JSON-RPC error with code `METHOD_NOT_FOUND(-32601)` and the
248    /// given message.
249    #[inline]
250    pub fn method_not_found(message: impl Into<String>) -> Self {
251        RpcError::new(METHOD_NOT_FOUND, message)
252    }
253
254    /// Create a JSON-RPC error with code `INVALID_PARAMS(-32602)` and the given
255    /// message.
256    #[inline]
257    pub fn invalid_params(message: impl Into<String>) -> Self {
258        RpcError::new(INVALID_PARAMS, message)
259    }
260
261    /// Create a JSON-RPC error with code `INTERNAL_ERROR(-32603)` and the given
262    /// message.
263    #[inline]
264    pub fn internal_error(message: impl Into<String>) -> Self {
265        RpcError::new(INTERNAL_ERROR, message)
266    }
267}