// codeprism_mcp/protocol.rs

//! MCP protocol types and a JSON-RPC 2.0 implementation.
//!
//! This module implements the core Model Context Protocol (MCP) types according to the specification.
//! All message types follow the JSON-RPC 2.0 format, as MCP requires.

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use tokio::sync::Notify;
use tokio::time::Duration;

12/// JSON-RPC 2.0 Request message
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct JsonRpcRequest {
15    /// JSON-RPC version, must be "2.0"
16    pub jsonrpc: String,
17    /// Request ID (number or string)
18    pub id: serde_json::Value,
19    /// Method name
20    pub method: String,
21    /// Optional parameters
22    #[serde(skip_serializing_if = "Option::is_none")]
23    pub params: Option<serde_json::Value>,
24}
25
26/// JSON-RPC 2.0 Response message
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct JsonRpcResponse {
29    /// JSON-RPC version, must be "2.0"
30    pub jsonrpc: String,
31    /// Request ID matching the original request
32    pub id: serde_json::Value,
33    /// Successful result (mutually exclusive with error)
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub result: Option<serde_json::Value>,
36    /// Error information (mutually exclusive with result)
37    #[serde(skip_serializing_if = "Option::is_none")]
38    pub error: Option<JsonRpcError>,
39}
40
41/// JSON-RPC 2.0 Notification message (no response expected)
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct JsonRpcNotification {
44    /// JSON-RPC version, must be "2.0"
45    pub jsonrpc: String,
46    /// Method name
47    pub method: String,
48    /// Optional parameters
49    #[serde(skip_serializing_if = "Option::is_none")]
50    pub params: Option<serde_json::Value>,
51}
52
53/// JSON-RPC 2.0 Error object
54#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct JsonRpcError {
56    /// Error code
57    pub code: i32,
58    /// Error message
59    pub message: String,
60    /// Optional additional error data
61    #[serde(skip_serializing_if = "Option::is_none")]
62    pub data: Option<serde_json::Value>,
63}
64
65/// Cancellation notification parameters
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct CancellationParams {
68    /// Request ID being cancelled
69    pub id: serde_json::Value,
70    /// Optional reason for cancellation
71    #[serde(skip_serializing_if = "Option::is_none")]
72    pub reason: Option<String>,
73}
74
75/// Cancellation token for request cancellation
76#[derive(Debug, Clone)]
77pub struct CancellationToken {
78    /// Notifier for cancellation
79    notify: Arc<Notify>,
80    /// Whether the token is cancelled
81    cancelled: Arc<std::sync::atomic::AtomicBool>,
82    /// Request ID associated with this token
83    request_id: serde_json::Value,
84}
85
86impl CancellationToken {
87    /// Create a new cancellation token
88    pub fn new(request_id: serde_json::Value) -> Self {
89        Self {
90            notify: Arc::new(Notify::new()),
91            cancelled: Arc::new(std::sync::atomic::AtomicBool::new(false)),
92            request_id,
93        }
94    }
95
96    /// Check if cancellation was requested
97    pub fn is_cancelled(&self) -> bool {
98        self.cancelled.load(std::sync::atomic::Ordering::Relaxed)
99    }
100
101    /// Cancel this token
102    pub fn cancel(&self) {
103        self.cancelled
104            .store(true, std::sync::atomic::Ordering::Relaxed);
105        self.notify.notify_waiters();
106    }
107
108    /// Wait for cancellation
109    pub async fn cancelled(&self) {
110        if self.is_cancelled() {
111            return;
112        }
113        self.notify.notified().await;
114    }
115
116    /// Get the request ID
117    pub fn request_id(&self) -> &serde_json::Value {
118        &self.request_id
119    }
120
121    /// Run an operation with timeout and cancellation
122    pub async fn with_timeout<F, T>(
123        &self,
124        timeout: Duration,
125        operation: F,
126    ) -> Result<T, CancellationError>
127    where
128        F: std::future::Future<Output = T>,
129    {
130        tokio::select! {
131            result = operation => Ok(result),
132            _ = self.cancelled() => Err(CancellationError::Cancelled),
133            _ = tokio::time::sleep(timeout) => Err(CancellationError::Timeout),
134        }
135    }
136}
137
138/// Cancellation error types
139#[derive(Debug, Clone, thiserror::Error)]
140pub enum CancellationError {
141    /// Operation was cancelled
142    #[error("Operation was cancelled")]
143    Cancelled,
144    /// Operation timed out
145    #[error("Operation timed out")]
146    Timeout,
147}
148
149/// MCP Initialize request parameters
150#[derive(Debug, Clone, Serialize, Deserialize)]
151pub struct InitializeParams {
152    /// Protocol version supported by client
153    #[serde(rename = "protocolVersion")]
154    pub protocol_version: String,
155    /// Client capabilities
156    pub capabilities: ClientCapabilities,
157    /// Client implementation information
158    #[serde(rename = "clientInfo")]
159    pub client_info: ClientInfo,
160}
161
162/// MCP Initialize response
163#[derive(Debug, Clone, Serialize, Deserialize)]
164pub struct InitializeResult {
165    /// Protocol version supported by server
166    #[serde(rename = "protocolVersion")]
167    pub protocol_version: String,
168    /// Server capabilities
169    pub capabilities: ServerCapabilities,
170    /// Server implementation information
171    #[serde(rename = "serverInfo")]
172    pub server_info: ServerInfo,
173}
174
175/// Client capabilities
176#[derive(Debug, Clone, Serialize, Deserialize, Default)]
177pub struct ClientCapabilities {
178    /// Experimental capabilities
179    #[serde(skip_serializing_if = "Option::is_none")]
180    pub experimental: Option<HashMap<String, serde_json::Value>>,
181    /// Sampling capability
182    #[serde(skip_serializing_if = "Option::is_none")]
183    pub sampling: Option<SamplingCapability>,
184}
185
186/// Server capabilities
187#[derive(Debug, Clone, Serialize, Deserialize, Default)]
188pub struct ServerCapabilities {
189    /// Experimental capabilities
190    #[serde(skip_serializing_if = "Option::is_none")]
191    pub experimental: Option<HashMap<String, serde_json::Value>>,
192    /// Resources capability
193    #[serde(skip_serializing_if = "Option::is_none")]
194    pub resources: Option<crate::resources::ResourceCapabilities>,
195    /// Tools capability
196    #[serde(skip_serializing_if = "Option::is_none")]
197    pub tools: Option<crate::tools::ToolCapabilities>,
198    /// Prompts capability
199    #[serde(skip_serializing_if = "Option::is_none")]
200    pub prompts: Option<crate::prompts::PromptCapabilities>,
201}
202
203/// Sampling capability
204#[derive(Debug, Clone, Serialize, Deserialize)]
205pub struct SamplingCapability {}
206
207/// Client information
208#[derive(Debug, Clone, Serialize, Deserialize)]
209pub struct ClientInfo {
210    /// Client name
211    pub name: String,
212    /// Client version
213    pub version: String,
214}
215
216/// Server information
217#[derive(Debug, Clone, Serialize, Deserialize)]
218pub struct ServerInfo {
219    /// Server name
220    pub name: String,
221    /// Server version
222    pub version: String,
223}
224
225impl JsonRpcRequest {
226    /// Create a new JSON-RPC request
227    pub fn new(id: serde_json::Value, method: String, params: Option<serde_json::Value>) -> Self {
228        Self {
229            jsonrpc: "2.0".to_string(),
230            id,
231            method,
232            params,
233        }
234    }
235}
236
237impl JsonRpcResponse {
238    /// Create a successful JSON-RPC response
239    pub fn success(id: serde_json::Value, result: serde_json::Value) -> Self {
240        Self {
241            jsonrpc: "2.0".to_string(),
242            id,
243            result: Some(result),
244            error: None,
245        }
246    }
247
248    /// Create an error JSON-RPC response
249    pub fn error(id: serde_json::Value, error: JsonRpcError) -> Self {
250        Self {
251            jsonrpc: "2.0".to_string(),
252            id,
253            result: None,
254            error: Some(error),
255        }
256    }
257}
258
259impl JsonRpcNotification {
260    /// Create a new JSON-RPC notification
261    pub fn new(method: String, params: Option<serde_json::Value>) -> Self {
262        Self {
263            jsonrpc: "2.0".to_string(),
264            method,
265            params,
266        }
267    }
268}
269
270impl JsonRpcError {
271    /// Standard JSON-RPC error codes
272    pub const PARSE_ERROR: i32 = -32700;
273    pub const INVALID_REQUEST: i32 = -32600;
274    pub const METHOD_NOT_FOUND: i32 = -32601;
275    pub const INVALID_PARAMS: i32 = -32602;
276    pub const INTERNAL_ERROR: i32 = -32603;
277
278    /// Create a new JSON-RPC error
279    pub fn new(code: i32, message: String, data: Option<serde_json::Value>) -> Self {
280        Self {
281            code,
282            message,
283            data,
284        }
285    }
286
287    /// Create a method not found error
288    pub fn method_not_found(method: &str) -> Self {
289        Self::new(
290            Self::METHOD_NOT_FOUND,
291            format!("Method not found: {}", method),
292            None,
293        )
294    }
295
296    /// Create an invalid parameters error
297    pub fn invalid_params(message: String) -> Self {
298        Self::new(Self::INVALID_PARAMS, message, None)
299    }
300
301    /// Create an internal error
302    pub fn internal_error(message: String) -> Self {
303        Self::new(Self::INTERNAL_ERROR, message, None)
304    }
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_json_rpc_request_serialization() {
        let original = JsonRpcRequest::new(
            json!(1),
            "test_method".to_string(),
            Some(json!({ "param": "value" })),
        );

        // Round-trip through a JSON string, then compare field by field.
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: JsonRpcRequest = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.jsonrpc, original.jsonrpc);
        assert_eq!(decoded.id, original.id);
        assert_eq!(decoded.method, original.method);
        assert_eq!(decoded.params, original.params);
    }

    #[test]
    fn test_json_rpc_response_success() {
        let payload = json!({ "success": true });
        let response = JsonRpcResponse::success(json!(1), payload);

        assert_eq!(response.jsonrpc, "2.0");
        assert!(response.result.is_some(), "success response must carry a result");
        assert!(response.error.is_none(), "success response must not carry an error");
    }

    #[test]
    fn test_json_rpc_response_error() {
        let failure = JsonRpcError::method_not_found("unknown_method");
        let response = JsonRpcResponse::error(json!(1), failure);

        assert_eq!(response.jsonrpc, "2.0");
        assert!(response.result.is_none(), "error response must not carry a result");
        assert!(response.error.is_some(), "error response must carry an error");
    }

    #[test]
    fn test_initialize_params() {
        let client_info = ClientInfo {
            name: "test-client".to_string(),
            version: "1.0.0".to_string(),
        };
        let original = InitializeParams {
            protocol_version: "2024-11-05".to_string(),
            capabilities: ClientCapabilities::default(),
            client_info,
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: InitializeParams = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.protocol_version, original.protocol_version);
        assert_eq!(decoded.client_info.name, original.client_info.name);
    }
}