1use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::collections::HashMap;
10
11use crate::protocol::types::*;
12
13#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
19pub struct InitializeParams {
20 #[serde(rename = "clientInfo")]
22 pub client_info: ClientInfo,
23 pub capabilities: ClientCapabilities,
25 #[serde(rename = "protocolVersion")]
27 pub protocol_version: String,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
32pub struct InitializeResult {
33 #[serde(rename = "serverInfo")]
35 pub server_info: ServerInfo,
36 pub capabilities: ServerCapabilities,
38 #[serde(rename = "protocolVersion")]
40 pub protocol_version: String,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
49pub struct ListToolsParams {
50 #[serde(skip_serializing_if = "Option::is_none")]
52 pub cursor: Option<String>,
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
57pub struct ListToolsResult {
58 pub tools: Vec<ToolInfo>,
60 #[serde(rename = "nextCursor", skip_serializing_if = "Option::is_none")]
62 pub next_cursor: Option<String>,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
67pub struct CallToolParams {
68 pub name: String,
70 #[serde(skip_serializing_if = "Option::is_none")]
72 pub arguments: Option<HashMap<String, Value>>,
73}
74
75#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
77pub struct CallToolResult {
78 pub content: Vec<Content>,
80 #[serde(rename = "isError", skip_serializing_if = "Option::is_none")]
82 pub is_error: Option<bool>,
83}
84
85#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
91pub struct ListResourcesParams {
92 #[serde(skip_serializing_if = "Option::is_none")]
94 pub cursor: Option<String>,
95}
96
97#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
99pub struct ListResourcesResult {
100 pub resources: Vec<ResourceInfo>,
102 #[serde(rename = "nextCursor", skip_serializing_if = "Option::is_none")]
104 pub next_cursor: Option<String>,
105}
106
107#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
109pub struct ReadResourceParams {
110 pub uri: String,
112}
113
114#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
116pub struct ReadResourceResult {
117 pub contents: Vec<ResourceContent>,
119}
120
121#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
123pub struct SubscribeResourceParams {
124 pub uri: String,
126}
127
128#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
130pub struct SubscribeResourceResult {}
131
132#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
134pub struct UnsubscribeResourceParams {
135 pub uri: String,
137}
138
139#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
141pub struct UnsubscribeResourceResult {}
142
143#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
145pub struct ResourceUpdatedParams {
146 pub uri: String,
148}
149
150#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
152pub struct ResourceListChangedParams {}
153
154#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
160pub struct ListPromptsParams {
161 #[serde(skip_serializing_if = "Option::is_none")]
163 pub cursor: Option<String>,
164}
165
166#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
168pub struct ListPromptsResult {
169 pub prompts: Vec<PromptInfo>,
171 #[serde(rename = "nextCursor", skip_serializing_if = "Option::is_none")]
173 pub next_cursor: Option<String>,
174}
175
176#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
178pub struct GetPromptParams {
179 pub name: String,
181 #[serde(skip_serializing_if = "Option::is_none")]
183 pub arguments: Option<HashMap<String, Value>>,
184}
185
186#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
188pub struct GetPromptResult {
189 #[serde(skip_serializing_if = "Option::is_none")]
191 pub description: Option<String>,
192 pub messages: Vec<PromptMessage>,
194}
195
196#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
198pub struct PromptListChangedParams {}
199
200#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
206pub struct CreateMessageParams {
207 pub messages: Vec<SamplingMessage>,
209 #[serde(rename = "modelPreferences", skip_serializing_if = "Option::is_none")]
211 pub model_preferences: Option<ModelPreferences>,
212 #[serde(rename = "systemPrompt", skip_serializing_if = "Option::is_none")]
214 pub system_prompt: Option<String>,
215 #[serde(rename = "includeContext", skip_serializing_if = "Option::is_none")]
217 pub include_context: Option<String>,
218 #[serde(rename = "maxTokens", skip_serializing_if = "Option::is_none")]
220 pub max_tokens: Option<u32>,
221 #[serde(skip_serializing_if = "Option::is_none")]
223 pub temperature: Option<f32>,
224 #[serde(rename = "topP", skip_serializing_if = "Option::is_none")]
226 pub top_p: Option<f32>,
227 #[serde(rename = "stopSequences", skip_serializing_if = "Option::is_none")]
229 pub stop_sequences: Option<Vec<String>>,
230 #[serde(skip_serializing_if = "Option::is_none")]
232 pub metadata: Option<HashMap<String, Value>>,
233}
234
235#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
237pub struct CreateMessageResult {
238 pub role: String,
240 pub content: SamplingContent,
242 pub model: String,
244 #[serde(rename = "stopReason", skip_serializing_if = "Option::is_none")]
246 pub stop_reason: Option<String>,
247}
248
249#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
251pub struct SamplingMessage {
252 pub role: String,
254 pub content: SamplingContent,
256}
257
258#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
260#[serde(untagged)]
261pub enum SamplingContent {
262 Text(String),
264 Complex(Vec<Content>),
266}
267
268#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
270pub struct ModelPreferences {
271 #[serde(rename = "costPriority", skip_serializing_if = "Option::is_none")]
273 pub cost_priority: Option<f32>,
274 #[serde(rename = "speedPriority", skip_serializing_if = "Option::is_none")]
276 pub speed_priority: Option<f32>,
277 #[serde(rename = "qualityPriority", skip_serializing_if = "Option::is_none")]
279 pub quality_priority: Option<f32>,
280}
281
282#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
288pub struct ToolListChangedParams {}
289
290#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
296pub struct PingParams {}
297
298#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
300pub struct PingResult {}
301
302#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
308pub struct SetLoggingLevelParams {
309 pub level: LoggingLevel,
311}
312
313#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
315pub struct SetLoggingLevelResult {}
316
317#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
319#[serde(rename_all = "lowercase")]
320pub enum LoggingLevel {
321 Debug,
322 Info,
323 Notice,
324 Warning,
325 Error,
326 Critical,
327 Alert,
328 Emergency,
329}
330
331#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
333pub struct LoggingMessageParams {
334 pub level: LoggingLevel,
336 #[serde(skip_serializing_if = "Option::is_none")]
338 pub logger: Option<String>,
339 pub data: Value,
341}
342
343#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
349pub struct ProgressParams {
350 #[serde(rename = "progressToken")]
352 pub progress_token: String,
353 pub progress: f32,
355 #[serde(skip_serializing_if = "Option::is_none")]
357 pub total: Option<u32>,
358}
359
360impl InitializeParams {
365 pub fn new(
367 client_info: ClientInfo,
368 capabilities: ClientCapabilities,
369 protocol_version: String,
370 ) -> Self {
371 Self {
372 client_info,
373 capabilities,
374 protocol_version,
375 }
376 }
377}
378
379impl InitializeResult {
380 pub fn new(
382 server_info: ServerInfo,
383 capabilities: ServerCapabilities,
384 protocol_version: String,
385 ) -> Self {
386 Self {
387 server_info,
388 capabilities,
389 protocol_version,
390 }
391 }
392}
393
394impl CallToolParams {
395 pub fn new(name: String, arguments: Option<HashMap<String, Value>>) -> Self {
397 Self { name, arguments }
398 }
399}
400
401impl ReadResourceParams {
402 pub fn new(uri: String) -> Self {
404 Self { uri }
405 }
406}
407
408impl GetPromptParams {
409 pub fn new(name: String, arguments: Option<HashMap<String, Value>>) -> Self {
411 Self { name, arguments }
412 }
413}
414
415impl SamplingMessage {
416 pub fn user<S: Into<String>>(content: S) -> Self {
418 Self {
419 role: "user".to_string(),
420 content: SamplingContent::Text(content.into()),
421 }
422 }
423
424 pub fn assistant<S: Into<String>>(content: S) -> Self {
426 Self {
427 role: "assistant".to_string(),
428 content: SamplingContent::Text(content.into()),
429 }
430 }
431
432 pub fn system<S: Into<String>>(content: S) -> Self {
434 Self {
435 role: "system".to_string(),
436 content: SamplingContent::Text(content.into()),
437 }
438 }
439
440 pub fn with_content<S: Into<String>>(role: S, content: Vec<Content>) -> Self {
442 Self {
443 role: role.into(),
444 content: SamplingContent::Complex(content),
445 }
446 }
447}
448
449impl Default for ModelPreferences {
450 fn default() -> Self {
451 Self::new()
452 }
453}
454
455impl ModelPreferences {
456 pub fn new() -> Self {
458 Self {
459 cost_priority: None,
460 speed_priority: None,
461 quality_priority: None,
462 }
463 }
464
465 pub fn with_cost_priority(mut self, priority: f32) -> Self {
467 self.cost_priority = Some(priority);
468 self
469 }
470
471 pub fn with_speed_priority(mut self, priority: f32) -> Self {
473 self.speed_priority = Some(priority);
474 self
475 }
476
477 pub fn with_quality_priority(mut self, priority: f32) -> Self {
479 self.quality_priority = Some(priority);
480 self
481 }
482}
483
/// MCP protocol revision implemented by these types; used as the `protocolVersion` value.
pub const MCP_PROTOCOL_VERSION: &str = "2024-11-05";

/// JSON-RPC method names.
///
/// Requests use `<area>/<action>`. Notifications live under the `notifications/` prefix,
/// as in the MCP 2024-11-05 schema. Without the prefix, spec-compliant peers would not
/// recognize list-changed, resource-updated, log-message or progress notifications.
pub mod methods {
    // --- Lifecycle ---
    pub const INITIALIZE: &str = "initialize";
    /// Notification sent by the client after it has processed the `initialize` result.
    pub const INITIALIZED: &str = "notifications/initialized";

    pub const PING: &str = "ping";

    // --- Tools ---
    pub const TOOLS_LIST: &str = "tools/list";
    pub const TOOLS_CALL: &str = "tools/call";
    pub const TOOLS_LIST_CHANGED: &str = "notifications/tools/list_changed";

    // --- Resources ---
    pub const RESOURCES_LIST: &str = "resources/list";
    pub const RESOURCES_READ: &str = "resources/read";
    pub const RESOURCES_SUBSCRIBE: &str = "resources/subscribe";
    pub const RESOURCES_UNSUBSCRIBE: &str = "resources/unsubscribe";
    pub const RESOURCES_UPDATED: &str = "notifications/resources/updated";
    pub const RESOURCES_LIST_CHANGED: &str = "notifications/resources/list_changed";

    // --- Prompts ---
    pub const PROMPTS_LIST: &str = "prompts/list";
    pub const PROMPTS_GET: &str = "prompts/get";
    pub const PROMPTS_LIST_CHANGED: &str = "notifications/prompts/list_changed";

    // --- Sampling ---
    pub const SAMPLING_CREATE_MESSAGE: &str = "sampling/createMessage";

    // --- Logging ---
    pub const LOGGING_SET_LEVEL: &str = "logging/setLevel";
    /// Log-message notification; the spec name has no `logging/` segment.
    pub const LOGGING_MESSAGE: &str = "notifications/message";

    // --- Progress ---
    pub const PROGRESS: &str = "notifications/progress";
}
537
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_initialize_params_serialization() {
        let params = InitializeParams::new(
            ClientInfo {
                name: "test-client".to_string(),
                version: "1.0.0".to_string(),
            },
            ClientCapabilities::default(),
            MCP_PROTOCOL_VERSION.to_string(),
        );

        // `&params` — the original text was mis-encoded as `¶ms`, which does not compile.
        let json = serde_json::to_value(&params).unwrap();
        assert_eq!(json["clientInfo"]["name"], "test-client");
        assert_eq!(json["protocolVersion"], MCP_PROTOCOL_VERSION);
    }

    #[test]
    fn test_call_tool_params() {
        let mut args = HashMap::new();
        args.insert("param1".to_string(), json!("value1"));
        args.insert("param2".to_string(), json!(42));

        let params = CallToolParams::new("test_tool".to_string(), Some(args));
        let json = serde_json::to_value(&params).unwrap();

        assert_eq!(json["name"], "test_tool");
        assert_eq!(json["arguments"]["param1"], "value1");
        assert_eq!(json["arguments"]["param2"], 42);
    }

    #[test]
    fn test_sampling_message_creation() {
        let user_msg = SamplingMessage::user("Hello, world!");
        assert_eq!(user_msg.role, "user");

        if let SamplingContent::Text(text) = user_msg.content {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text content");
        }

        let assistant_msg = SamplingMessage::assistant("Hello back!");
        assert_eq!(assistant_msg.role, "assistant");
    }

    #[test]
    fn test_model_preferences_builder() {
        let prefs = ModelPreferences::default()
            .with_cost_priority(0.8)
            .with_speed_priority(0.6)
            .with_quality_priority(0.9);

        assert_eq!(prefs.cost_priority, Some(0.8));
        assert_eq!(prefs.speed_priority, Some(0.6));
        assert_eq!(prefs.quality_priority, Some(0.9));
    }

    #[test]
    fn test_model_preferences_serializes_camel_case_and_skips_none() {
        let prefs = ModelPreferences::new().with_cost_priority(0.5);
        let json = serde_json::to_value(&prefs).unwrap();
        assert_eq!(json, json!({ "costPriority": 0.5 }));
    }

    #[test]
    fn test_list_tools_result_omits_absent_cursor() {
        let result = ListToolsResult {
            tools: Vec::new(),
            next_cursor: None,
        };
        let json = serde_json::to_value(&result).unwrap();
        assert_eq!(json, json!({ "tools": [] }));
    }

    #[test]
    fn test_read_resource_params() {
        let params = ReadResourceParams::new("file:///path/to/file.txt".to_string());
        let json = serde_json::to_value(&params).unwrap();
        assert_eq!(json["uri"], "file:///path/to/file.txt");
    }

    #[test]
    fn test_logging_level_serialization() {
        let level = LoggingLevel::Warning;
        let json = serde_json::to_value(&level).unwrap();
        assert_eq!(json, "warning");

        let deserialized: LoggingLevel = serde_json::from_value(json!("error")).unwrap();
        assert_eq!(deserialized, LoggingLevel::Error);
    }

    #[test]
    fn test_method_constants() {
        assert_eq!(methods::INITIALIZE, "initialize");
        assert_eq!(methods::TOOLS_LIST, "tools/list");
        assert_eq!(methods::TOOLS_CALL, "tools/call");
        assert_eq!(methods::RESOURCES_READ, "resources/read");
        assert_eq!(methods::PROMPTS_GET, "prompts/get");
        assert_eq!(methods::SAMPLING_CREATE_MESSAGE, "sampling/createMessage");
    }
}
626}