1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::Notify;
10use tokio::time::Duration;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct JsonRpcRequest {
15 pub jsonrpc: String,
17 pub id: serde_json::Value,
19 pub method: String,
21 #[serde(skip_serializing_if = "Option::is_none")]
23 pub params: Option<serde_json::Value>,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct JsonRpcResponse {
29 pub jsonrpc: String,
31 pub id: serde_json::Value,
33 #[serde(skip_serializing_if = "Option::is_none")]
35 pub result: Option<serde_json::Value>,
36 #[serde(skip_serializing_if = "Option::is_none")]
38 pub error: Option<JsonRpcError>,
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct JsonRpcNotification {
44 pub jsonrpc: String,
46 pub method: String,
48 #[serde(skip_serializing_if = "Option::is_none")]
50 pub params: Option<serde_json::Value>,
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct JsonRpcError {
56 pub code: i32,
58 pub message: String,
60 #[serde(skip_serializing_if = "Option::is_none")]
62 pub data: Option<serde_json::Value>,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct CancellationParams {
68 pub id: serde_json::Value,
70 #[serde(skip_serializing_if = "Option::is_none")]
72 pub reason: Option<String>,
73}
74
75#[derive(Debug, Clone)]
77pub struct CancellationToken {
78 notify: Arc<Notify>,
80 cancelled: Arc<std::sync::atomic::AtomicBool>,
82 request_id: serde_json::Value,
84}
85
86impl CancellationToken {
87 pub fn new(request_id: serde_json::Value) -> Self {
89 Self {
90 notify: Arc::new(Notify::new()),
91 cancelled: Arc::new(std::sync::atomic::AtomicBool::new(false)),
92 request_id,
93 }
94 }
95
96 pub fn is_cancelled(&self) -> bool {
98 self.cancelled.load(std::sync::atomic::Ordering::Relaxed)
99 }
100
101 pub fn cancel(&self) {
103 self.cancelled
104 .store(true, std::sync::atomic::Ordering::Relaxed);
105 self.notify.notify_waiters();
106 }
107
108 pub async fn cancelled(&self) {
110 if self.is_cancelled() {
111 return;
112 }
113 self.notify.notified().await;
114 }
115
116 pub fn request_id(&self) -> &serde_json::Value {
118 &self.request_id
119 }
120
121 pub async fn with_timeout<F, T>(
123 &self,
124 timeout: Duration,
125 operation: F,
126 ) -> Result<T, CancellationError>
127 where
128 F: std::future::Future<Output = T>,
129 {
130 tokio::select! {
131 result = operation => Ok(result),
132 _ = self.cancelled() => Err(CancellationError::Cancelled),
133 _ = tokio::time::sleep(timeout) => Err(CancellationError::Timeout),
134 }
135 }
136}
137
138#[derive(Debug, Clone, thiserror::Error)]
140pub enum CancellationError {
141 #[error("Operation was cancelled")]
143 Cancelled,
144 #[error("Operation timed out")]
146 Timeout,
147}
148
149#[derive(Debug, Clone, Serialize, Deserialize)]
151pub struct InitializeParams {
152 #[serde(rename = "protocolVersion")]
154 pub protocol_version: String,
155 pub capabilities: ClientCapabilities,
157 #[serde(rename = "clientInfo")]
159 pub client_info: ClientInfo,
160}
161
162#[derive(Debug, Clone, Serialize, Deserialize)]
164pub struct InitializeResult {
165 #[serde(rename = "protocolVersion")]
167 pub protocol_version: String,
168 pub capabilities: ServerCapabilities,
170 #[serde(rename = "serverInfo")]
172 pub server_info: ServerInfo,
173}
174
175#[derive(Debug, Clone, Serialize, Deserialize, Default)]
177pub struct ClientCapabilities {
178 #[serde(skip_serializing_if = "Option::is_none")]
180 pub experimental: Option<HashMap<String, serde_json::Value>>,
181 #[serde(skip_serializing_if = "Option::is_none")]
183 pub sampling: Option<SamplingCapability>,
184}
185
186#[derive(Debug, Clone, Serialize, Deserialize, Default)]
188pub struct ServerCapabilities {
189 #[serde(skip_serializing_if = "Option::is_none")]
191 pub experimental: Option<HashMap<String, serde_json::Value>>,
192 #[serde(skip_serializing_if = "Option::is_none")]
194 pub resources: Option<crate::resources::ResourceCapabilities>,
195 #[serde(skip_serializing_if = "Option::is_none")]
197 pub tools: Option<crate::tools::ToolCapabilities>,
198 #[serde(skip_serializing_if = "Option::is_none")]
200 pub prompts: Option<crate::prompts::PromptCapabilities>,
201}
202
203#[derive(Debug, Clone, Serialize, Deserialize)]
205pub struct SamplingCapability {}
206
207#[derive(Debug, Clone, Serialize, Deserialize)]
209pub struct ClientInfo {
210 pub name: String,
212 pub version: String,
214}
215
216#[derive(Debug, Clone, Serialize, Deserialize)]
218pub struct ServerInfo {
219 pub name: String,
221 pub version: String,
223}
224
225impl JsonRpcRequest {
226 pub fn new(id: serde_json::Value, method: String, params: Option<serde_json::Value>) -> Self {
228 Self {
229 jsonrpc: "2.0".to_string(),
230 id,
231 method,
232 params,
233 }
234 }
235}
236
237impl JsonRpcResponse {
238 pub fn success(id: serde_json::Value, result: serde_json::Value) -> Self {
240 Self {
241 jsonrpc: "2.0".to_string(),
242 id,
243 result: Some(result),
244 error: None,
245 }
246 }
247
248 pub fn error(id: serde_json::Value, error: JsonRpcError) -> Self {
250 Self {
251 jsonrpc: "2.0".to_string(),
252 id,
253 result: None,
254 error: Some(error),
255 }
256 }
257}
258
259impl JsonRpcNotification {
260 pub fn new(method: String, params: Option<serde_json::Value>) -> Self {
262 Self {
263 jsonrpc: "2.0".to_string(),
264 method,
265 params,
266 }
267 }
268}
269
270impl JsonRpcError {
271 pub const PARSE_ERROR: i32 = -32700;
273 pub const INVALID_REQUEST: i32 = -32600;
274 pub const METHOD_NOT_FOUND: i32 = -32601;
275 pub const INVALID_PARAMS: i32 = -32602;
276 pub const INTERNAL_ERROR: i32 = -32603;
277
278 pub fn new(code: i32, message: String, data: Option<serde_json::Value>) -> Self {
280 Self {
281 code,
282 message,
283 data,
284 }
285 }
286
287 pub fn method_not_found(method: &str) -> Self {
289 Self::new(
290 Self::METHOD_NOT_FOUND,
291 format!("Method not found: {}", method),
292 None,
293 )
294 }
295
296 pub fn invalid_params(message: String) -> Self {
298 Self::new(Self::INVALID_PARAMS, message, None)
299 }
300
301 pub fn internal_error(message: String) -> Self {
303 Self::new(Self::INTERNAL_ERROR, message, None)
304 }
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_json_rpc_request_serialization() {
        let request = JsonRpcRequest::new(
            serde_json::Value::Number(1.into()),
            "test_method".to_string(),
            Some(serde_json::json!({"param": "value"})),
        );

        let json = serde_json::to_string(&request).unwrap();
        let deserialized: JsonRpcRequest = serde_json::from_str(&json).unwrap();

        assert_eq!(request.jsonrpc, deserialized.jsonrpc);
        assert_eq!(request.id, deserialized.id);
        assert_eq!(request.method, deserialized.method);
        assert_eq!(request.params, deserialized.params);
    }

    #[test]
    fn test_json_rpc_request_omits_missing_params() {
        let request = JsonRpcRequest::new(serde_json::json!("abc"), "ping".to_string(), None);

        // `params: None` must be left off the wire, not sent as `null`.
        let value = serde_json::to_value(&request).unwrap();
        assert_eq!(value["jsonrpc"], "2.0");
        assert!(value.get("params").is_none());
    }

    #[test]
    fn test_json_rpc_response_success() {
        let response = JsonRpcResponse::success(
            serde_json::Value::Number(1.into()),
            serde_json::json!({"success": true}),
        );

        assert_eq!(response.jsonrpc, "2.0");
        assert!(response.result.is_some());
        assert!(response.error.is_none());

        // A success response must not carry an `error` key at all.
        let value = serde_json::to_value(&response).unwrap();
        assert_eq!(value["result"], serde_json::json!({"success": true}));
        assert!(value.get("error").is_none());
    }

    #[test]
    fn test_json_rpc_response_error() {
        let error = JsonRpcError::method_not_found("unknown_method");
        let response = JsonRpcResponse::error(serde_json::Value::Number(1.into()), error);

        assert_eq!(response.jsonrpc, "2.0");
        assert!(response.result.is_none());
        let error = response
            .error
            .as_ref()
            .expect("error response must carry an error object");
        assert_eq!(error.code, JsonRpcError::METHOD_NOT_FOUND);
        assert_eq!(error.message, "Method not found: unknown_method");

        // An error response must not carry a `result` key at all.
        let value = serde_json::to_value(&response).unwrap();
        assert!(value.get("result").is_none());
    }

    #[test]
    fn test_initialize_params() {
        let params = InitializeParams {
            protocol_version: "2024-11-05".to_string(),
            capabilities: ClientCapabilities::default(),
            client_info: ClientInfo {
                name: "test-client".to_string(),
                version: "1.0.0".to_string(),
            },
        };

        let json = serde_json::to_string(&params).unwrap();
        let deserialized: InitializeParams = serde_json::from_str(&json).unwrap();

        assert_eq!(params.protocol_version, deserialized.protocol_version);
        assert_eq!(params.client_info.name, deserialized.client_info.name);
        assert_eq!(params.client_info.version, deserialized.client_info.version);

        // Wire format uses camelCase keys; empty capabilities serialize as `{}`.
        let value = serde_json::to_value(&params).unwrap();
        assert_eq!(value["protocolVersion"], "2024-11-05");
        assert_eq!(value["clientInfo"]["name"], "test-client");
        assert_eq!(value["capabilities"], serde_json::json!({}));
    }

    #[test]
    fn test_cancellation_token_state_is_shared_between_clones() {
        let token = CancellationToken::new(serde_json::json!(7));
        let clone = token.clone();

        assert!(!token.is_cancelled());
        clone.cancel();
        assert!(token.is_cancelled());
        assert!(clone.is_cancelled());
        assert_eq!(token.request_id(), &serde_json::json!(7));
    }
}