1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::Notify;
10use tokio::time::Duration;
11
/// Protocol versions this server can speak, listed newest first.
///
/// Version negotiation walks this list in order and only replaces its current
/// best candidate with a strictly better one, so earlier entries win ties.
pub const SUPPORTED_PROTOCOL_VERSIONS: &[&str] = &[
    "2024-11-05",
    "2024-10-07",
    "2024-09-01",
];

/// Version reported when none of the supported versions is compatible with the
/// version a client requested.
pub const DEFAULT_PROTOCOL_VERSION: &str = "2024-11-05";
21
22#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
24pub enum ClientType {
25 Claude,
26 Cursor,
27 VSCode,
28 Unknown(String),
29}
30
31impl ClientType {
32 pub fn from_client_info(client_info: &ClientInfo) -> Self {
34 let name_lower = client_info.name.to_lowercase();
35
36 if name_lower.contains("claude") {
37 Self::Claude
38 } else if name_lower.contains("cursor") {
39 Self::Cursor
40 } else if name_lower.contains("vscode") || name_lower.contains("vs code") {
41 Self::VSCode
42 } else {
43 Self::Unknown(client_info.name.clone())
44 }
45 }
46
47 pub fn get_optimizations(&self) -> ClientOptimizations {
49 match self {
50 Self::Claude => ClientOptimizations {
51 max_response_size: 100_000,
52 supports_streaming: true,
53 preferred_timeout: Duration::from_secs(30),
54 batch_size_limit: 10,
55 },
56 Self::Cursor => ClientOptimizations {
57 max_response_size: 50_000,
58 supports_streaming: false,
59 preferred_timeout: Duration::from_secs(15),
60 batch_size_limit: 5,
61 },
62 Self::VSCode => ClientOptimizations {
63 max_response_size: 75_000,
64 supports_streaming: true,
65 preferred_timeout: Duration::from_secs(20),
66 batch_size_limit: 7,
67 },
68 Self::Unknown(_) => ClientOptimizations::default(),
69 }
70 }
71}
72
/// Per-client tuning knobs chosen by [`ClientType::get_optimizations`].
#[derive(Debug, Clone)]
pub struct ClientOptimizations {
    /// Upper bound on a single response's size (presumably bytes; confirm with callers).
    pub max_response_size: usize,
    /// Whether responses to this client may be streamed.
    pub supports_streaming: bool,
    /// Timeout this client is expected to tolerate.
    pub preferred_timeout: Duration,
    /// Maximum number of items to send in one batch.
    pub batch_size_limit: usize,
}
81
82impl Default for ClientOptimizations {
83 fn default() -> Self {
84 Self {
85 max_response_size: 75_000,
86 supports_streaming: false,
87 preferred_timeout: Duration::from_secs(30),
88 batch_size_limit: 5,
89 }
90 }
91}
92
93#[derive(Debug, Clone)]
95pub struct VersionNegotiation {
96 pub agreed_version: String,
97 pub client_version: String,
98 pub server_versions: Vec<String>,
99 pub compatibility_level: CompatibilityLevel,
100 pub warnings: Vec<String>,
101}
102
/// How closely a requested protocol version matches a supported one.
///
/// Variants are declared from worst to best. The derived `Ord` therefore ranks
/// them, which lets negotiation keep the "greatest" level it has seen.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum CompatibilityLevel {
    /// Too far apart to be used.
    Incompatible,
    /// Different month, but within the negotiation's day window.
    Limited,
    /// Same year and month, different day.
    Compatible,
    /// Identical version.
    Full,
}
115
116impl VersionNegotiation {
117 pub fn negotiate(client_version: &str) -> Self {
119 let server_versions: Vec<String> = SUPPORTED_PROTOCOL_VERSIONS
120 .iter()
121 .map(|v| v.to_string())
122 .collect();
123 let mut warnings = Vec::new();
124
125 let (agreed_version, compatibility_level) =
127 if SUPPORTED_PROTOCOL_VERSIONS.contains(&client_version) {
128 (client_version.to_string(), CompatibilityLevel::Full)
129 } else {
130 let parsed_client = parse_version(client_version);
132 let mut best_match = None;
133 let mut best_compatibility = CompatibilityLevel::Incompatible;
134
135 for &server_version in SUPPORTED_PROTOCOL_VERSIONS {
136 let parsed_server = parse_version(server_version);
137 let compatibility = determine_compatibility(&parsed_client, &parsed_server);
138
139 if compatibility > best_compatibility {
140 best_match = Some(server_version.to_string());
141 best_compatibility = compatibility;
142 }
143 }
144
145 match best_match {
146 Some(version) => {
147 warnings.push(format!(
148 "Client version {} not directly supported, using {} with {} compatibility",
149 client_version, version,
150 match best_compatibility {
151 CompatibilityLevel::Full => "full",
152 CompatibilityLevel::Compatible => "high",
153 CompatibilityLevel::Limited => "limited",
154 CompatibilityLevel::Incompatible => "no",
155 }
156 ));
157 (version, best_compatibility)
158 }
159 None => {
160 warnings.push(format!(
161 "Client version {} is incompatible with supported versions: {:?}",
162 client_version, SUPPORTED_PROTOCOL_VERSIONS
163 ));
164 (
165 DEFAULT_PROTOCOL_VERSION.to_string(),
166 CompatibilityLevel::Incompatible,
167 )
168 }
169 }
170 };
171
172 Self {
173 agreed_version,
174 client_version: client_version.to_string(),
175 server_versions,
176 compatibility_level,
177 warnings,
178 }
179 }
180
181 pub fn is_acceptable(&self) -> bool {
183 self.compatibility_level != CompatibilityLevel::Incompatible
184 }
185}
186
/// Numeric components of a `YYYY-MM-DD` protocol version string.
///
/// Field order matters: the derived `Ord` compares year first, then month, then day.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct ParsedVersion {
    year: u32,
    month: u32,
    day: u32,
}
194
195fn parse_version(version: &str) -> ParsedVersion {
197 let parts: Vec<&str> = version.split('-').collect();
198 if parts.len() == 3 {
199 ParsedVersion {
200 year: parts[0].parse().unwrap_or(0),
201 month: parts[1].parse().unwrap_or(0),
202 day: parts[2].parse().unwrap_or(0),
203 }
204 } else {
205 ParsedVersion {
206 year: 0,
207 month: 0,
208 day: 0,
209 }
210 }
211}
212
213fn determine_compatibility(client: &ParsedVersion, server: &ParsedVersion) -> CompatibilityLevel {
215 if client == server {
216 return CompatibilityLevel::Full;
217 }
218
219 if client.year == server.year && client.month == server.month {
221 return CompatibilityLevel::Compatible;
222 }
223
224 let client_days = client.year * 365 + client.month * 30 + client.day;
226 let server_days = server.year * 365 + server.month * 30 + server.day;
227 let diff_days = (client_days as i32 - server_days as i32).abs();
228
229 if diff_days <= 180 {
230 CompatibilityLevel::Limited
232 } else {
233 CompatibilityLevel::Incompatible
234 }
235}
236
237#[derive(Debug, Clone, Serialize, Deserialize)]
239pub struct JsonRpcRequest {
240 pub jsonrpc: String,
242 pub id: serde_json::Value,
244 pub method: String,
246 #[serde(skip_serializing_if = "Option::is_none")]
248 pub params: Option<serde_json::Value>,
249}
250
251#[derive(Debug, Clone, Serialize, Deserialize)]
253pub struct JsonRpcResponse {
254 pub jsonrpc: String,
256 pub id: serde_json::Value,
258 #[serde(skip_serializing_if = "Option::is_none")]
260 pub result: Option<serde_json::Value>,
261 #[serde(skip_serializing_if = "Option::is_none")]
263 pub error: Option<JsonRpcError>,
264}
265
266#[derive(Debug, Clone, Serialize, Deserialize)]
268pub struct JsonRpcNotification {
269 pub jsonrpc: String,
271 pub method: String,
273 #[serde(skip_serializing_if = "Option::is_none")]
275 pub params: Option<serde_json::Value>,
276}
277
278#[derive(Debug, Clone, Serialize, Deserialize)]
280pub struct JsonRpcError {
281 pub code: i32,
283 pub message: String,
285 #[serde(skip_serializing_if = "Option::is_none")]
287 pub data: Option<serde_json::Value>,
288}
289
290#[derive(Debug, Clone, Serialize, Deserialize)]
292pub struct CancellationParams {
293 pub id: serde_json::Value,
295 #[serde(skip_serializing_if = "Option::is_none")]
297 pub reason: Option<String>,
298}
299
300#[derive(Debug, Clone)]
302pub struct CancellationToken {
303 notify: Arc<Notify>,
305 cancelled: Arc<std::sync::atomic::AtomicBool>,
307 request_id: serde_json::Value,
309}
310
311impl CancellationToken {
312 pub fn new(request_id: serde_json::Value) -> Self {
314 Self {
315 notify: Arc::new(Notify::new()),
316 cancelled: Arc::new(std::sync::atomic::AtomicBool::new(false)),
317 request_id,
318 }
319 }
320
321 pub fn is_cancelled(&self) -> bool {
323 self.cancelled.load(std::sync::atomic::Ordering::Relaxed)
324 }
325
326 pub fn cancel(&self) {
328 self.cancelled
329 .store(true, std::sync::atomic::Ordering::Relaxed);
330 self.notify.notify_waiters();
331 }
332
333 pub async fn cancelled(&self) {
335 if self.is_cancelled() {
336 return;
337 }
338 self.notify.notified().await;
339 }
340
341 pub fn request_id(&self) -> &serde_json::Value {
343 &self.request_id
344 }
345
346 pub async fn with_timeout<F, T>(
348 &self,
349 timeout: Duration,
350 operation: F,
351 ) -> Result<T, CancellationError>
352 where
353 F: std::future::Future<Output = T>,
354 {
355 tokio::select! {
356 result = operation => Ok(result),
357 _ = self.cancelled() => Err(CancellationError::Cancelled),
358 _ = tokio::time::sleep(timeout) => Err(CancellationError::Timeout),
359 }
360 }
361}
362
363#[derive(Debug, Clone, thiserror::Error)]
365pub enum CancellationError {
366 #[error("Operation was cancelled")]
368 Cancelled,
369 #[error("Operation timed out")]
371 Timeout,
372}
373
374#[derive(Debug, Clone, Serialize, Deserialize)]
376pub struct InitializeParams {
377 #[serde(rename = "protocolVersion")]
379 pub protocol_version: String,
380 pub capabilities: ClientCapabilities,
382 #[serde(rename = "clientInfo")]
384 pub client_info: ClientInfo,
385}
386
387#[derive(Debug, Clone, Serialize, Deserialize)]
389pub struct InitializeResult {
390 #[serde(rename = "protocolVersion")]
392 pub protocol_version: String,
393 pub capabilities: ServerCapabilities,
395 #[serde(rename = "serverInfo")]
397 pub server_info: ServerInfo,
398}
399
400#[derive(Debug, Clone, Serialize, Deserialize, Default)]
402pub struct ClientCapabilities {
403 #[serde(skip_serializing_if = "Option::is_none")]
405 pub experimental: Option<HashMap<String, serde_json::Value>>,
406 #[serde(skip_serializing_if = "Option::is_none")]
408 pub sampling: Option<SamplingCapability>,
409}
410
411#[derive(Debug, Clone, Serialize, Deserialize, Default)]
413pub struct ServerCapabilities {
414 #[serde(skip_serializing_if = "Option::is_none")]
416 pub experimental: Option<HashMap<String, serde_json::Value>>,
417 #[serde(skip_serializing_if = "Option::is_none")]
419 pub resources: Option<crate::resources::ResourceCapabilities>,
420 #[serde(skip_serializing_if = "Option::is_none")]
422 pub tools: Option<crate::tools::ToolCapabilities>,
423 #[serde(skip_serializing_if = "Option::is_none")]
425 pub prompts: Option<crate::prompts::PromptCapabilities>,
426}
427
428#[derive(Debug, Clone, Serialize, Deserialize)]
430pub struct SamplingCapability {}
431
432#[derive(Debug, Clone, Serialize, Deserialize)]
434pub struct ClientInfo {
435 pub name: String,
437 pub version: String,
439}
440
441#[derive(Debug, Clone, Serialize, Deserialize)]
443pub struct ServerInfo {
444 pub name: String,
446 pub version: String,
448}
449
450impl JsonRpcRequest {
451 pub fn new(id: serde_json::Value, method: String, params: Option<serde_json::Value>) -> Self {
453 Self {
454 jsonrpc: "2.0".to_string(),
455 id,
456 method,
457 params,
458 }
459 }
460}
461
462impl JsonRpcResponse {
463 pub fn success(id: serde_json::Value, result: serde_json::Value) -> Self {
465 Self {
466 jsonrpc: "2.0".to_string(),
467 id,
468 result: Some(result),
469 error: None,
470 }
471 }
472
473 pub fn error(id: serde_json::Value, error: JsonRpcError) -> Self {
475 Self {
476 jsonrpc: "2.0".to_string(),
477 id,
478 result: None,
479 error: Some(error),
480 }
481 }
482}
483
484impl JsonRpcNotification {
485 pub fn new(method: String, params: Option<serde_json::Value>) -> Self {
487 Self {
488 jsonrpc: "2.0".to_string(),
489 method,
490 params,
491 }
492 }
493}
494
495impl JsonRpcError {
496 pub const PARSE_ERROR: i32 = -32700;
498 pub const INVALID_REQUEST: i32 = -32600;
499 pub const METHOD_NOT_FOUND: i32 = -32601;
500 pub const INVALID_PARAMS: i32 = -32602;
501 pub const INTERNAL_ERROR: i32 = -32603;
502
503 pub fn new(code: i32, message: String, data: Option<serde_json::Value>) -> Self {
505 Self {
506 code,
507 message,
508 data,
509 }
510 }
511
512 pub fn method_not_found(method: &str) -> Self {
514 Self::new(
515 Self::METHOD_NOT_FOUND,
516 format!("Method not found: {}", method),
517 None,
518 )
519 }
520
521 pub fn invalid_params(message: String) -> Self {
523 Self::new(Self::INVALID_PARAMS, message, None)
524 }
525
526 pub fn internal_error(message: String) -> Self {
528 Self::new(Self::INTERNAL_ERROR, message, None)
529 }
530}
531
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_json_rpc_request_serialization() {
        let request = JsonRpcRequest::new(
            serde_json::Value::Number(1.into()),
            "test_method".to_string(),
            Some(serde_json::json!({"param": "value"})),
        );

        let json = serde_json::to_string(&request).unwrap();
        let deserialized: JsonRpcRequest = serde_json::from_str(&json).unwrap();

        assert_eq!(request.jsonrpc, deserialized.jsonrpc);
        assert_eq!(request.id, deserialized.id);
        assert_eq!(request.method, deserialized.method);
        assert_eq!(request.params, deserialized.params);
    }

    #[test]
    fn test_json_rpc_response_success() {
        let response = JsonRpcResponse::success(
            serde_json::Value::Number(1.into()),
            serde_json::json!({"success": true}),
        );

        assert_eq!(response.jsonrpc, "2.0");
        assert!(response.result.is_some());
        assert!(response.error.is_none());
    }

    #[test]
    fn test_json_rpc_response_error() {
        let error = JsonRpcError::method_not_found("unknown_method");
        let response = JsonRpcResponse::error(serde_json::Value::Number(1.into()), error);

        assert_eq!(response.jsonrpc, "2.0");
        assert!(response.result.is_none());
        assert!(response.error.is_some());
    }

    #[test]
    fn test_json_rpc_response_omits_absent_fields() {
        let response = JsonRpcResponse::success(
            serde_json::Value::Number(1.into()),
            serde_json::json!({"ok": true}),
        );

        let value = serde_json::to_value(&response).unwrap();
        assert!(value.get("error").is_none());
        assert_eq!(value["jsonrpc"], "2.0");
    }

    #[test]
    fn test_json_rpc_error_constructors() {
        let not_found = JsonRpcError::method_not_found("foo");
        assert_eq!(not_found.code, JsonRpcError::METHOD_NOT_FOUND);
        assert_eq!(not_found.message, "Method not found: foo");
        assert!(not_found.data.is_none());

        assert_eq!(
            JsonRpcError::invalid_params("bad".to_string()).code,
            JsonRpcError::INVALID_PARAMS
        );
        assert_eq!(
            JsonRpcError::internal_error("boom".to_string()).code,
            JsonRpcError::INTERNAL_ERROR
        );
    }

    #[test]
    fn test_initialize_params() {
        let params = InitializeParams {
            protocol_version: "2024-11-05".to_string(),
            capabilities: ClientCapabilities::default(),
            client_info: ClientInfo {
                name: "test-client".to_string(),
                version: "1.0.0".to_string(),
            },
        };

        let json = serde_json::to_string(&params).unwrap();
        let deserialized: InitializeParams = serde_json::from_str(&json).unwrap();

        assert_eq!(params.protocol_version, deserialized.protocol_version);
        assert_eq!(params.client_info.name, deserialized.client_info.name);
    }

    #[test]
    fn test_client_type_detection() {
        let claude_client = ClientInfo {
            name: "Claude Desktop".to_string(),
            version: "1.0.0".to_string(),
        };
        assert_eq!(
            ClientType::from_client_info(&claude_client),
            ClientType::Claude
        );

        let cursor_client = ClientInfo {
            name: "Cursor Editor".to_string(),
            version: "2.0.0".to_string(),
        };
        assert_eq!(
            ClientType::from_client_info(&cursor_client),
            ClientType::Cursor
        );

        let vscode_client = ClientInfo {
            name: "VS Code".to_string(),
            version: "1.80.0".to_string(),
        };
        assert_eq!(
            ClientType::from_client_info(&vscode_client),
            ClientType::VSCode
        );

        let unknown_client = ClientInfo {
            name: "Custom Client".to_string(),
            version: "1.0.0".to_string(),
        };
        assert_eq!(
            ClientType::from_client_info(&unknown_client),
            ClientType::Unknown("Custom Client".to_string())
        );
    }

    #[test]
    fn test_client_optimizations() {
        let claude_opts = ClientType::Claude.get_optimizations();
        assert_eq!(claude_opts.max_response_size, 100_000);
        assert!(claude_opts.supports_streaming);
        assert_eq!(claude_opts.batch_size_limit, 10);

        let cursor_opts = ClientType::Cursor.get_optimizations();
        assert_eq!(cursor_opts.max_response_size, 50_000);
        assert!(!cursor_opts.supports_streaming);
        assert_eq!(cursor_opts.batch_size_limit, 5);
    }

    #[test]
    fn test_unknown_client_uses_default_optimizations() {
        let opts = ClientType::Unknown("x".to_string()).get_optimizations();
        let defaults = ClientOptimizations::default();

        assert_eq!(opts.max_response_size, defaults.max_response_size);
        assert_eq!(opts.supports_streaming, defaults.supports_streaming);
        assert_eq!(opts.preferred_timeout, defaults.preferred_timeout);
        assert_eq!(opts.batch_size_limit, defaults.batch_size_limit);
    }

    #[test]
    fn test_version_negotiation_exact_match() {
        let negotiation = VersionNegotiation::negotiate("2024-11-05");

        assert_eq!(negotiation.agreed_version, "2024-11-05");
        assert_eq!(negotiation.compatibility_level, CompatibilityLevel::Full);
        assert!(negotiation.warnings.is_empty());
        assert!(negotiation.is_acceptable());
    }

    #[test]
    fn test_version_negotiation_compatible() {
        let negotiation = VersionNegotiation::negotiate("2024-11-01");

        assert_eq!(
            negotiation.compatibility_level,
            CompatibilityLevel::Compatible
        );
        assert!(negotiation.is_acceptable());
        assert!(!negotiation.warnings.is_empty());
    }

    #[test]
    fn test_version_negotiation_limited() {
        let negotiation = VersionNegotiation::negotiate("2024-08-15");

        assert_eq!(negotiation.compatibility_level, CompatibilityLevel::Limited);
        assert!(negotiation.is_acceptable());
    }

    #[test]
    fn test_version_negotiation_incompatible() {
        let negotiation = VersionNegotiation::negotiate("2023-01-01");

        assert_eq!(
            negotiation.compatibility_level,
            CompatibilityLevel::Incompatible
        );
        assert!(!negotiation.is_acceptable());
    }

    #[test]
    fn test_version_negotiation_incompatible_falls_back_to_default() {
        let negotiation = VersionNegotiation::negotiate("2023-01-01");

        assert_eq!(negotiation.agreed_version, DEFAULT_PROTOCOL_VERSION);
        assert_eq!(negotiation.warnings.len(), 1);
    }

    #[test]
    fn test_parse_version() {
        let parsed = parse_version("2024-11-05");
        assert_eq!(parsed.year, 2024);
        assert_eq!(parsed.month, 11);
        assert_eq!(parsed.day, 5);

        let invalid = parse_version("invalid");
        assert_eq!(invalid.year, 0);
        assert_eq!(invalid.month, 0);
        assert_eq!(invalid.day, 0);
    }

    #[test]
    fn test_compatibility_determination() {
        let v1 = parse_version("2024-11-05");
        let v2 = parse_version("2024-11-05");
        assert_eq!(determine_compatibility(&v1, &v2), CompatibilityLevel::Full);

        let v3 = parse_version("2024-11-01");
        assert_eq!(
            determine_compatibility(&v1, &v3),
            CompatibilityLevel::Compatible
        );

        let v4 = parse_version("2024-08-01");
        assert_eq!(
            determine_compatibility(&v1, &v4),
            CompatibilityLevel::Limited
        );

        let v5 = parse_version("2023-01-01");
        assert_eq!(
            determine_compatibility(&v1, &v5),
            CompatibilityLevel::Incompatible
        );
    }

    #[test]
    fn test_cancellation_token() {
        let token = CancellationToken::new(serde_json::Value::Number(1.into()));
        assert!(!token.is_cancelled());

        token.cancel();
        assert!(token.is_cancelled());
    }

    #[test]
    fn test_cancellation_token_clones_share_state() {
        let token = CancellationToken::new(serde_json::Value::Number(7.into()));
        let clone = token.clone();

        clone.cancel();

        assert!(token.is_cancelled());
        assert_eq!(token.request_id(), &serde_json::Value::Number(7.into()));
    }
}
732}