// codeprism_mcp/protocol.rs

//! MCP protocol types and JSON-RPC 2.0 implementation.
//!
//! This module implements the core Model Context Protocol (MCP) types according to the specification.
//! All message types follow the JSON-RPC 2.0 format, as required by MCP.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::Notify;
10use tokio::time::Duration;
11
/// MCP protocol versions this server accepts, listed from newest to oldest.
pub const SUPPORTED_PROTOCOL_VERSIONS: &[&str] = &[
    // Latest supported version.
    "2024-11-05",
    // Previous stable version.
    "2024-10-07",
    // Retained for legacy compatibility.
    "2024-09-01",
];

/// Protocol version this server defaults to (also the fallback when negotiation fails).
pub const DEFAULT_PROTOCOL_VERSION: &str = "2024-11-05";
21
22/// Client types we can detect and optimize for
23#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
24pub enum ClientType {
25    Claude,
26    Cursor,
27    VSCode,
28    Unknown(String),
29}
30
31impl ClientType {
32    /// Detect client type from client info
33    pub fn from_client_info(client_info: &ClientInfo) -> Self {
34        let name_lower = client_info.name.to_lowercase();
35
36        if name_lower.contains("claude") {
37            Self::Claude
38        } else if name_lower.contains("cursor") {
39            Self::Cursor
40        } else if name_lower.contains("vscode") || name_lower.contains("vs code") {
41            Self::VSCode
42        } else {
43            Self::Unknown(client_info.name.clone())
44        }
45    }
46
47    /// Get client-specific optimizations
48    pub fn get_optimizations(&self) -> ClientOptimizations {
49        match self {
50            Self::Claude => ClientOptimizations {
51                max_response_size: 100_000,
52                supports_streaming: true,
53                preferred_timeout: Duration::from_secs(30),
54                batch_size_limit: 10,
55            },
56            Self::Cursor => ClientOptimizations {
57                max_response_size: 50_000,
58                supports_streaming: false,
59                preferred_timeout: Duration::from_secs(15),
60                batch_size_limit: 5,
61            },
62            Self::VSCode => ClientOptimizations {
63                max_response_size: 75_000,
64                supports_streaming: true,
65                preferred_timeout: Duration::from_secs(20),
66                batch_size_limit: 7,
67            },
68            Self::Unknown(_) => ClientOptimizations::default(),
69        }
70    }
71}
72
/// Per-client tuning settings, as produced by `ClientType::get_optimizations`.
#[derive(Debug, Clone)]
pub struct ClientOptimizations {
    /// Upper bound on a response's size (unit not fixed here — presumably bytes; confirm at use sites).
    pub max_response_size: usize,
    /// Whether the client is treated as able to handle streamed responses.
    pub supports_streaming: bool,
    /// Timeout to prefer when serving this client.
    pub preferred_timeout: Duration,
    /// Maximum number of items to batch together for this client.
    pub batch_size_limit: usize,
}

impl Default for ClientOptimizations {
    /// Settings used when the client could not be identified.
    fn default() -> Self {
        let preferred_timeout = Duration::from_secs(30);
        ClientOptimizations {
            batch_size_limit: 5,
            max_response_size: 75_000,
            preferred_timeout,
            supports_streaming: false,
        }
    }
}
92
93/// Protocol version negotiation result
94#[derive(Debug, Clone)]
95pub struct VersionNegotiation {
96    pub agreed_version: String,
97    pub client_version: String,
98    pub server_versions: Vec<String>,
99    pub compatibility_level: CompatibilityLevel,
100    pub warnings: Vec<String>,
101}
102
/// How well a client's protocol version matches one the server supports.
///
/// The derived `PartialOrd`/`Ord` follow declaration order, so the variants are
/// listed from worst to best (`Incompatible < Limited < Compatible < Full`).
/// Version negotiation compares levels with `>`, so do not reorder them.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum CompatibilityLevel {
    /// No usable overlap; the connection should be rejected.
    Incompatible,
    /// Usable, but some features may not work.
    Limited,
    /// Works, with only minor differences between the versions.
    Compatible,
    /// Identical versions.
    Full,
}
115
116impl VersionNegotiation {
117    /// Negotiate protocol version between client and server
118    pub fn negotiate(client_version: &str) -> Self {
119        let server_versions: Vec<String> = SUPPORTED_PROTOCOL_VERSIONS
120            .iter()
121            .map(|v| v.to_string())
122            .collect();
123        let mut warnings = Vec::new();
124
125        // Check if client version is supported
126        let (agreed_version, compatibility_level) =
127            if SUPPORTED_PROTOCOL_VERSIONS.contains(&client_version) {
128                (client_version.to_string(), CompatibilityLevel::Full)
129            } else {
130                // Try to find a compatible version
131                let parsed_client = parse_version(client_version);
132                let mut best_match = None;
133                let mut best_compatibility = CompatibilityLevel::Incompatible;
134
135                for &server_version in SUPPORTED_PROTOCOL_VERSIONS {
136                    let parsed_server = parse_version(server_version);
137                    let compatibility = determine_compatibility(&parsed_client, &parsed_server);
138
139                    if compatibility > best_compatibility {
140                        best_match = Some(server_version.to_string());
141                        best_compatibility = compatibility;
142                    }
143                }
144
145                match best_match {
146                    Some(version) => {
147                        warnings.push(format!(
148                        "Client version {} not directly supported, using {} with {} compatibility",
149                        client_version, version,
150                        match best_compatibility {
151                            CompatibilityLevel::Full => "full",
152                            CompatibilityLevel::Compatible => "high",
153                            CompatibilityLevel::Limited => "limited",
154                            CompatibilityLevel::Incompatible => "no",
155                        }
156                    ));
157                        (version, best_compatibility)
158                    }
159                    None => {
160                        warnings.push(format!(
161                            "Client version {} is incompatible with supported versions: {:?}",
162                            client_version, SUPPORTED_PROTOCOL_VERSIONS
163                        ));
164                        (
165                            DEFAULT_PROTOCOL_VERSION.to_string(),
166                            CompatibilityLevel::Incompatible,
167                        )
168                    }
169                }
170            };
171
172        Self {
173            agreed_version,
174            client_version: client_version.to_string(),
175            server_versions,
176            compatibility_level,
177            warnings,
178        }
179    }
180
181    /// Check if this negotiation allows the connection
182    pub fn is_acceptable(&self) -> bool {
183        self.compatibility_level != CompatibilityLevel::Incompatible
184    }
185}
186
187/// Parsed version components
188#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
189struct ParsedVersion {
190    year: u32,
191    month: u32,
192    day: u32,
193}
194
195/// Parse a version string in YYYY-MM-DD format
196fn parse_version(version: &str) -> ParsedVersion {
197    let parts: Vec<&str> = version.split('-').collect();
198    if parts.len() == 3 {
199        ParsedVersion {
200            year: parts[0].parse().unwrap_or(0),
201            month: parts[1].parse().unwrap_or(0),
202            day: parts[2].parse().unwrap_or(0),
203        }
204    } else {
205        ParsedVersion {
206            year: 0,
207            month: 0,
208            day: 0,
209        }
210    }
211}
212
213/// Determine compatibility level between two versions
214fn determine_compatibility(client: &ParsedVersion, server: &ParsedVersion) -> CompatibilityLevel {
215    if client == server {
216        return CompatibilityLevel::Full;
217    }
218
219    // Same year and month = compatible
220    if client.year == server.year && client.month == server.month {
221        return CompatibilityLevel::Compatible;
222    }
223
224    // Within 6 months = limited compatibility
225    let client_days = client.year * 365 + client.month * 30 + client.day;
226    let server_days = server.year * 365 + server.month * 30 + server.day;
227    let diff_days = (client_days as i32 - server_days as i32).abs();
228
229    if diff_days <= 180 {
230        // ~6 months
231        CompatibilityLevel::Limited
232    } else {
233        CompatibilityLevel::Incompatible
234    }
235}
236
237/// JSON-RPC 2.0 Request message
238#[derive(Debug, Clone, Serialize, Deserialize)]
239pub struct JsonRpcRequest {
240    /// JSON-RPC version, must be "2.0"
241    pub jsonrpc: String,
242    /// Request ID (number or string)
243    pub id: serde_json::Value,
244    /// Method name
245    pub method: String,
246    /// Optional parameters
247    #[serde(skip_serializing_if = "Option::is_none")]
248    pub params: Option<serde_json::Value>,
249}
250
251/// JSON-RPC 2.0 Response message
252#[derive(Debug, Clone, Serialize, Deserialize)]
253pub struct JsonRpcResponse {
254    /// JSON-RPC version, must be "2.0"
255    pub jsonrpc: String,
256    /// Request ID matching the original request
257    pub id: serde_json::Value,
258    /// Successful result (mutually exclusive with error)
259    #[serde(skip_serializing_if = "Option::is_none")]
260    pub result: Option<serde_json::Value>,
261    /// Error information (mutually exclusive with result)
262    #[serde(skip_serializing_if = "Option::is_none")]
263    pub error: Option<JsonRpcError>,
264}
265
266/// JSON-RPC 2.0 Notification message (no response expected)
267#[derive(Debug, Clone, Serialize, Deserialize)]
268pub struct JsonRpcNotification {
269    /// JSON-RPC version, must be "2.0"
270    pub jsonrpc: String,
271    /// Method name
272    pub method: String,
273    /// Optional parameters
274    #[serde(skip_serializing_if = "Option::is_none")]
275    pub params: Option<serde_json::Value>,
276}
277
278/// JSON-RPC 2.0 Error object
279#[derive(Debug, Clone, Serialize, Deserialize)]
280pub struct JsonRpcError {
281    /// Error code
282    pub code: i32,
283    /// Error message
284    pub message: String,
285    /// Optional additional error data
286    #[serde(skip_serializing_if = "Option::is_none")]
287    pub data: Option<serde_json::Value>,
288}
289
290/// Cancellation notification parameters
291#[derive(Debug, Clone, Serialize, Deserialize)]
292pub struct CancellationParams {
293    /// Request ID being cancelled
294    pub id: serde_json::Value,
295    /// Optional reason for cancellation
296    #[serde(skip_serializing_if = "Option::is_none")]
297    pub reason: Option<String>,
298}
299
300/// Cancellation token for request cancellation
301#[derive(Debug, Clone)]
302pub struct CancellationToken {
303    /// Notifier for cancellation
304    notify: Arc<Notify>,
305    /// Whether the token is cancelled
306    cancelled: Arc<std::sync::atomic::AtomicBool>,
307    /// Request ID associated with this token
308    request_id: serde_json::Value,
309}
310
311impl CancellationToken {
312    /// Create a new cancellation token
313    pub fn new(request_id: serde_json::Value) -> Self {
314        Self {
315            notify: Arc::new(Notify::new()),
316            cancelled: Arc::new(std::sync::atomic::AtomicBool::new(false)),
317            request_id,
318        }
319    }
320
321    /// Check if cancellation was requested
322    pub fn is_cancelled(&self) -> bool {
323        self.cancelled.load(std::sync::atomic::Ordering::Relaxed)
324    }
325
326    /// Cancel this token
327    pub fn cancel(&self) {
328        self.cancelled
329            .store(true, std::sync::atomic::Ordering::Relaxed);
330        self.notify.notify_waiters();
331    }
332
333    /// Wait for cancellation
334    pub async fn cancelled(&self) {
335        if self.is_cancelled() {
336            return;
337        }
338        self.notify.notified().await;
339    }
340
341    /// Get the request ID
342    pub fn request_id(&self) -> &serde_json::Value {
343        &self.request_id
344    }
345
346    /// Run an operation with timeout and cancellation
347    pub async fn with_timeout<F, T>(
348        &self,
349        timeout: Duration,
350        operation: F,
351    ) -> Result<T, CancellationError>
352    where
353        F: std::future::Future<Output = T>,
354    {
355        tokio::select! {
356            result = operation => Ok(result),
357            _ = self.cancelled() => Err(CancellationError::Cancelled),
358            _ = tokio::time::sleep(timeout) => Err(CancellationError::Timeout),
359        }
360    }
361}
362
363/// Cancellation error types
364#[derive(Debug, Clone, thiserror::Error)]
365pub enum CancellationError {
366    /// Operation was cancelled
367    #[error("Operation was cancelled")]
368    Cancelled,
369    /// Operation timed out
370    #[error("Operation timed out")]
371    Timeout,
372}
373
374/// MCP Initialize request parameters
375#[derive(Debug, Clone, Serialize, Deserialize)]
376pub struct InitializeParams {
377    /// Protocol version supported by client
378    #[serde(rename = "protocolVersion")]
379    pub protocol_version: String,
380    /// Client capabilities
381    pub capabilities: ClientCapabilities,
382    /// Client implementation information
383    #[serde(rename = "clientInfo")]
384    pub client_info: ClientInfo,
385}
386
387/// MCP Initialize response
388#[derive(Debug, Clone, Serialize, Deserialize)]
389pub struct InitializeResult {
390    /// Protocol version supported by server
391    #[serde(rename = "protocolVersion")]
392    pub protocol_version: String,
393    /// Server capabilities
394    pub capabilities: ServerCapabilities,
395    /// Server implementation information
396    #[serde(rename = "serverInfo")]
397    pub server_info: ServerInfo,
398}
399
400/// Client capabilities
401#[derive(Debug, Clone, Serialize, Deserialize, Default)]
402pub struct ClientCapabilities {
403    /// Experimental capabilities
404    #[serde(skip_serializing_if = "Option::is_none")]
405    pub experimental: Option<HashMap<String, serde_json::Value>>,
406    /// Sampling capability
407    #[serde(skip_serializing_if = "Option::is_none")]
408    pub sampling: Option<SamplingCapability>,
409}
410
411/// Server capabilities
412#[derive(Debug, Clone, Serialize, Deserialize, Default)]
413pub struct ServerCapabilities {
414    /// Experimental capabilities
415    #[serde(skip_serializing_if = "Option::is_none")]
416    pub experimental: Option<HashMap<String, serde_json::Value>>,
417    /// Resources capability
418    #[serde(skip_serializing_if = "Option::is_none")]
419    pub resources: Option<crate::resources::ResourceCapabilities>,
420    /// Tools capability
421    #[serde(skip_serializing_if = "Option::is_none")]
422    pub tools: Option<crate::tools::ToolCapabilities>,
423    /// Prompts capability
424    #[serde(skip_serializing_if = "Option::is_none")]
425    pub prompts: Option<crate::prompts::PromptCapabilities>,
426}
427
428/// Sampling capability
429#[derive(Debug, Clone, Serialize, Deserialize)]
430pub struct SamplingCapability {}
431
432/// Client information
433#[derive(Debug, Clone, Serialize, Deserialize)]
434pub struct ClientInfo {
435    /// Client name
436    pub name: String,
437    /// Client version
438    pub version: String,
439}
440
441/// Server information
442#[derive(Debug, Clone, Serialize, Deserialize)]
443pub struct ServerInfo {
444    /// Server name
445    pub name: String,
446    /// Server version
447    pub version: String,
448}
449
450impl JsonRpcRequest {
451    /// Create a new JSON-RPC request
452    pub fn new(id: serde_json::Value, method: String, params: Option<serde_json::Value>) -> Self {
453        Self {
454            jsonrpc: "2.0".to_string(),
455            id,
456            method,
457            params,
458        }
459    }
460}
461
462impl JsonRpcResponse {
463    /// Create a successful JSON-RPC response
464    pub fn success(id: serde_json::Value, result: serde_json::Value) -> Self {
465        Self {
466            jsonrpc: "2.0".to_string(),
467            id,
468            result: Some(result),
469            error: None,
470        }
471    }
472
473    /// Create an error JSON-RPC response
474    pub fn error(id: serde_json::Value, error: JsonRpcError) -> Self {
475        Self {
476            jsonrpc: "2.0".to_string(),
477            id,
478            result: None,
479            error: Some(error),
480        }
481    }
482}
483
484impl JsonRpcNotification {
485    /// Create a new JSON-RPC notification
486    pub fn new(method: String, params: Option<serde_json::Value>) -> Self {
487        Self {
488            jsonrpc: "2.0".to_string(),
489            method,
490            params,
491        }
492    }
493}
494
495impl JsonRpcError {
496    /// Standard JSON-RPC error codes
497    pub const PARSE_ERROR: i32 = -32700;
498    pub const INVALID_REQUEST: i32 = -32600;
499    pub const METHOD_NOT_FOUND: i32 = -32601;
500    pub const INVALID_PARAMS: i32 = -32602;
501    pub const INTERNAL_ERROR: i32 = -32603;
502
503    /// Create a new JSON-RPC error
504    pub fn new(code: i32, message: String, data: Option<serde_json::Value>) -> Self {
505        Self {
506            code,
507            message,
508            data,
509        }
510    }
511
512    /// Create a method not found error
513    pub fn method_not_found(method: &str) -> Self {
514        Self::new(
515            Self::METHOD_NOT_FOUND,
516            format!("Method not found: {}", method),
517            None,
518        )
519    }
520
521    /// Create an invalid parameters error
522    pub fn invalid_params(message: String) -> Self {
523        Self::new(Self::INVALID_PARAMS, message, None)
524    }
525
526    /// Create an internal error
527    pub fn internal_error(message: String) -> Self {
528        Self::new(Self::INTERNAL_ERROR, message, None)
529    }
530}
531
#[cfg(test)]
mod tests {
    use super::*;

    /// The numeric request id `1` used throughout these tests.
    fn id_one() -> serde_json::Value {
        serde_json::Value::Number(1.into())
    }

    /// Build a `ClientInfo` with the given name and version.
    fn client(name: &str, version: &str) -> ClientInfo {
        ClientInfo {
            name: name.to_string(),
            version: version.to_string(),
        }
    }

    #[test]
    fn test_json_rpc_request_serialization() {
        let original = JsonRpcRequest::new(
            id_one(),
            "test_method".to_string(),
            Some(serde_json::json!({"param": "value"})),
        );

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: JsonRpcRequest = serde_json::from_str(&encoded).unwrap();

        assert_eq!(original.jsonrpc, decoded.jsonrpc);
        assert_eq!(original.id, decoded.id);
        assert_eq!(original.method, decoded.method);
        assert_eq!(original.params, decoded.params);
    }

    #[test]
    fn test_json_rpc_response_success() {
        let response =
            JsonRpcResponse::success(id_one(), serde_json::json!({"success": true}));

        assert_eq!(response.jsonrpc, "2.0");
        assert!(response.result.is_some());
        assert!(response.error.is_none());
    }

    #[test]
    fn test_json_rpc_response_error() {
        let response =
            JsonRpcResponse::error(id_one(), JsonRpcError::method_not_found("unknown_method"));

        assert_eq!(response.jsonrpc, "2.0");
        assert!(response.result.is_none());
        assert!(response.error.is_some());
    }

    #[test]
    fn test_initialize_params() {
        let params = InitializeParams {
            protocol_version: "2024-11-05".to_string(),
            capabilities: ClientCapabilities::default(),
            client_info: client("test-client", "1.0.0"),
        };

        let encoded = serde_json::to_string(&params).unwrap();
        let decoded: InitializeParams = serde_json::from_str(&encoded).unwrap();

        assert_eq!(params.protocol_version, decoded.protocol_version);
        assert_eq!(params.client_info.name, decoded.client_info.name);
    }

    #[test]
    fn test_client_type_detection() {
        let cases = [
            ("Claude Desktop", "1.0.0", ClientType::Claude),
            ("Cursor Editor", "2.0.0", ClientType::Cursor),
            ("VS Code", "1.80.0", ClientType::VSCode),
            (
                "Custom Client",
                "1.0.0",
                ClientType::Unknown("Custom Client".to_string()),
            ),
        ];

        for (name, version, expected) in cases.iter() {
            assert_eq!(
                ClientType::from_client_info(&client(name, version)),
                *expected,
                "client name {:?}",
                name
            );
        }
    }

    #[test]
    fn test_client_optimizations() {
        let claude = ClientType::Claude.get_optimizations();
        assert_eq!(claude.max_response_size, 100_000);
        assert!(claude.supports_streaming);
        assert_eq!(claude.batch_size_limit, 10);

        let cursor = ClientType::Cursor.get_optimizations();
        assert_eq!(cursor.max_response_size, 50_000);
        assert!(!cursor.supports_streaming);
        assert_eq!(cursor.batch_size_limit, 5);
    }

    #[test]
    fn test_version_negotiation_exact_match() {
        let outcome = VersionNegotiation::negotiate("2024-11-05");

        assert_eq!(outcome.agreed_version, "2024-11-05");
        assert_eq!(outcome.compatibility_level, CompatibilityLevel::Full);
        assert!(outcome.warnings.is_empty());
        assert!(outcome.is_acceptable());
    }

    #[test]
    fn test_version_negotiation_compatible() {
        let outcome = VersionNegotiation::negotiate("2024-11-01");

        assert_eq!(outcome.compatibility_level, CompatibilityLevel::Compatible);
        assert!(outcome.is_acceptable());
        assert!(!outcome.warnings.is_empty());
    }

    #[test]
    fn test_version_negotiation_limited() {
        let outcome = VersionNegotiation::negotiate("2024-08-15");

        assert_eq!(outcome.compatibility_level, CompatibilityLevel::Limited);
        assert!(outcome.is_acceptable());
    }

    #[test]
    fn test_version_negotiation_incompatible() {
        let outcome = VersionNegotiation::negotiate("2023-01-01");

        assert_eq!(outcome.compatibility_level, CompatibilityLevel::Incompatible);
        assert!(!outcome.is_acceptable());
    }

    #[test]
    fn test_parse_version() {
        let parsed = parse_version("2024-11-05");
        assert_eq!((parsed.year, parsed.month, parsed.day), (2024, 11, 5));

        let invalid = parse_version("invalid");
        assert_eq!((invalid.year, invalid.month, invalid.day), (0, 0, 0));
    }

    #[test]
    fn test_compatibility_determination() {
        let reference = parse_version("2024-11-05");
        let expectations = [
            ("2024-11-05", CompatibilityLevel::Full),
            ("2024-11-01", CompatibilityLevel::Compatible),
            ("2024-08-01", CompatibilityLevel::Limited),
            ("2023-01-01", CompatibilityLevel::Incompatible),
        ];

        for (other, expected) in expectations.iter() {
            assert_eq!(
                determine_compatibility(&reference, &parse_version(other)),
                *expected,
                "comparing against {:?}",
                other
            );
        }
    }

    #[test]
    fn test_cancellation_token() {
        let token = CancellationToken::new(id_one());
        assert!(!token.is_cancelled());

        token.cancel();
        assert!(token.is_cancelled());
    }
}