mcp_protocol_sdk/protocol/
messages.rs

1//! MCP protocol message definitions
2//!
3//! This module contains all the MCP-specific message types and their serialization/deserialization
4//! logic. These messages follow the JSON-RPC 2.0 specification and represent the various operations
5//! supported by the Model Context Protocol.
6
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::collections::HashMap;
10
11use crate::protocol::types::*;
12
13// ============================================================================
14// Initialization Messages
15// ============================================================================
16
17/// Parameters for the initialize request
18#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
19pub struct InitializeParams {
20    /// Information about the client
21    #[serde(rename = "clientInfo")]
22    pub client_info: ClientInfo,
23    /// Capabilities advertised by the client
24    pub capabilities: ClientCapabilities,
25    /// Protocol version being used
26    #[serde(rename = "protocolVersion")]
27    pub protocol_version: String,
28}
29
30/// Result of the initialize request
31#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
32pub struct InitializeResult {
33    /// Information about the server
34    #[serde(rename = "serverInfo")]
35    pub server_info: ServerInfo,
36    /// Capabilities advertised by the server
37    pub capabilities: ServerCapabilities,
38    /// Protocol version being used
39    #[serde(rename = "protocolVersion")]
40    pub protocol_version: String,
41}
42
43// ============================================================================
44// Tool Messages
45// ============================================================================
46
47/// Parameters for the tools/list request
48#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
49pub struct ListToolsParams {
50    /// Optional cursor for pagination
51    #[serde(skip_serializing_if = "Option::is_none")]
52    pub cursor: Option<String>,
53}
54
55/// Result of the tools/list request
56#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
57pub struct ListToolsResult {
58    /// Available tools
59    pub tools: Vec<ToolInfo>,
60    /// Cursor for pagination (if more tools are available)
61    #[serde(rename = "nextCursor", skip_serializing_if = "Option::is_none")]
62    pub next_cursor: Option<String>,
63}
64
65/// Parameters for the tools/call request
66#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
67pub struct CallToolParams {
68    /// Name of the tool to call
69    pub name: String,
70    /// Arguments to pass to the tool
71    #[serde(skip_serializing_if = "Option::is_none")]
72    pub arguments: Option<HashMap<String, Value>>,
73}
74
75/// Result of the tools/call request
76#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
77pub struct CallToolResult {
78    /// Content returned by the tool
79    pub content: Vec<Content>,
80    /// Whether this result represents an error
81    #[serde(rename = "isError", skip_serializing_if = "Option::is_none")]
82    pub is_error: Option<bool>,
83}
84
85// ============================================================================
86// Resource Messages
87// ============================================================================
88
89/// Parameters for the resources/list request
90#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
91pub struct ListResourcesParams {
92    /// Optional cursor for pagination
93    #[serde(skip_serializing_if = "Option::is_none")]
94    pub cursor: Option<String>,
95}
96
97/// Result of the resources/list request
98#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
99pub struct ListResourcesResult {
100    /// Available resources
101    pub resources: Vec<ResourceInfo>,
102    /// Cursor for pagination (if more resources are available)
103    #[serde(rename = "nextCursor", skip_serializing_if = "Option::is_none")]
104    pub next_cursor: Option<String>,
105}
106
107/// Parameters for the resources/read request
108#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
109pub struct ReadResourceParams {
110    /// URI of the resource to read
111    pub uri: String,
112}
113
114/// Result of the resources/read request
115#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
116pub struct ReadResourceResult {
117    /// Contents of the resource
118    pub contents: Vec<ResourceContent>,
119}
120
121/// Parameters for the resources/subscribe request
122#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
123pub struct SubscribeResourceParams {
124    /// URI of the resource to subscribe to
125    pub uri: String,
126}
127
/// Result of the resources/subscribe request.
///
/// Carries no data: the derived serde impls write it as an empty JSON object
/// (`{}`) and read it back from one.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct SubscribeResourceResult {}
131
132/// Parameters for the resources/unsubscribe request
133#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
134pub struct UnsubscribeResourceParams {
135    /// URI of the resource to unsubscribe from
136    pub uri: String,
137}
138
/// Result of the resources/unsubscribe request.
///
/// Carries no data: the derived serde impls write it as an empty JSON object
/// (`{}`) and read it back from one.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct UnsubscribeResourceResult {}
142
143/// Parameters for the resources/updated notification
144#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
145pub struct ResourceUpdatedParams {
146    /// URI of the resource that was updated
147    pub uri: String,
148}
149
/// Parameters for the notification announcing that the resource list changed.
///
/// Carries no data: the derived serde impls write it as an empty JSON object
/// (`{}`). NOTE(review): a notification sent with `params` omitted or `null`
/// will not deserialize into this type; callers presumably need to handle
/// that case themselves.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ResourceListChangedParams {}
153
154// ============================================================================
155// Prompt Messages
156// ============================================================================
157
158/// Parameters for the prompts/list request
159#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
160pub struct ListPromptsParams {
161    /// Optional cursor for pagination
162    #[serde(skip_serializing_if = "Option::is_none")]
163    pub cursor: Option<String>,
164}
165
166/// Result of the prompts/list request
167#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
168pub struct ListPromptsResult {
169    /// Available prompts
170    pub prompts: Vec<PromptInfo>,
171    /// Cursor for pagination (if more prompts are available)
172    #[serde(rename = "nextCursor", skip_serializing_if = "Option::is_none")]
173    pub next_cursor: Option<String>,
174}
175
176/// Parameters for the prompts/get request
177#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
178pub struct GetPromptParams {
179    /// Name of the prompt to get
180    pub name: String,
181    /// Arguments to pass to the prompt
182    #[serde(skip_serializing_if = "Option::is_none")]
183    pub arguments: Option<HashMap<String, Value>>,
184}
185
186/// Result of the prompts/get request
187#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
188pub struct GetPromptResult {
189    /// Description of the prompt result
190    #[serde(skip_serializing_if = "Option::is_none")]
191    pub description: Option<String>,
192    /// Messages generated by the prompt
193    pub messages: Vec<PromptMessage>,
194}
195
/// Parameters for the notification announcing that the prompt list changed.
///
/// Carries no data: the derived serde impls write it as an empty JSON object
/// (`{}`) and read it back from one.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PromptListChangedParams {}
199
200// ============================================================================
201// Sampling Messages
202// ============================================================================
203
204/// Parameters for the sampling/createMessage request
205#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
206pub struct CreateMessageParams {
207    /// Messages to include in the conversation
208    pub messages: Vec<SamplingMessage>,
209    /// Model preferences
210    #[serde(rename = "modelPreferences", skip_serializing_if = "Option::is_none")]
211    pub model_preferences: Option<ModelPreferences>,
212    /// System prompt
213    #[serde(rename = "systemPrompt", skip_serializing_if = "Option::is_none")]
214    pub system_prompt: Option<String>,
215    /// Whether to include context from the current session
216    #[serde(rename = "includeContext", skip_serializing_if = "Option::is_none")]
217    pub include_context: Option<String>,
218    /// Maximum number of tokens to generate
219    #[serde(rename = "maxTokens", skip_serializing_if = "Option::is_none")]
220    pub max_tokens: Option<u32>,
221    /// Sampling temperature (0.0 to 1.0)
222    #[serde(skip_serializing_if = "Option::is_none")]
223    pub temperature: Option<f32>,
224    /// Nucleus sampling parameter (0.0 to 1.0)
225    #[serde(rename = "topP", skip_serializing_if = "Option::is_none")]
226    pub top_p: Option<f32>,
227    /// Stop sequences
228    #[serde(rename = "stopSequences", skip_serializing_if = "Option::is_none")]
229    pub stop_sequences: Option<Vec<String>>,
230    /// Metadata to include with the request
231    #[serde(skip_serializing_if = "Option::is_none")]
232    pub metadata: Option<HashMap<String, Value>>,
233}
234
235/// Result of the sampling/createMessage request
236#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
237pub struct CreateMessageResult {
238    /// Role of the generated message
239    pub role: String,
240    /// Content of the generated message
241    pub content: SamplingContent,
242    /// Model used for generation
243    pub model: String,
244    /// Stop reason
245    #[serde(rename = "stopReason", skip_serializing_if = "Option::is_none")]
246    pub stop_reason: Option<String>,
247}
248
249/// A message in a sampling conversation
250#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
251pub struct SamplingMessage {
252    /// Role of the message (e.g., "user", "assistant")
253    pub role: String,
254    /// Content of the message
255    pub content: SamplingContent,
256}
257
/// Content for sampling messages.
///
/// The enum is `#[serde(untagged)]`, so no discriminator appears on the wire:
/// `Text` is written as a bare JSON string and `Complex` as a JSON array of
/// `Content` items. When deserializing, serde tries the variants in
/// declaration order and picks the first that fits.
///
/// NOTE(review): the MCP spec describes sampling content as a single content
/// object (e.g. `{"type": "text", "text": "..."}`), which matches neither
/// variant; confirm interop with spec-compliant peers.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum SamplingContent {
    /// Simple text content, serialized as a bare JSON string.
    Text(String),
    /// Complex content with multiple parts, serialized as a JSON array.
    Complex(Vec<Content>),
}
267
268/// Model preferences for sampling
269#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
270pub struct ModelPreferences {
271    /// Hints about cost constraints
272    #[serde(rename = "costPriority", skip_serializing_if = "Option::is_none")]
273    pub cost_priority: Option<f32>,
274    /// Hints about speed constraints
275    #[serde(rename = "speedPriority", skip_serializing_if = "Option::is_none")]
276    pub speed_priority: Option<f32>,
277    /// Hints about quality constraints
278    #[serde(rename = "qualityPriority", skip_serializing_if = "Option::is_none")]
279    pub quality_priority: Option<f32>,
280}
281
282// ============================================================================
283// Tool List Changed Notification
284// ============================================================================
285
/// Parameters for the notification announcing that the tool list changed.
///
/// Carries no data: the derived serde impls write it as an empty JSON object
/// (`{}`) and read it back from one.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolListChangedParams {}
289
290// ============================================================================
291// Ping Messages
292// ============================================================================
293
/// Parameters for the ping request (no parameters).
///
/// Serialized by the derived impl as an empty JSON object (`{}`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PingParams {}
297
/// Result of the ping request (no result).
///
/// Serialized by the derived impl as an empty JSON object (`{}`); receiving
/// it is what signals that the peer is alive.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PingResult {}
301
302// ============================================================================
303// Logging Messages
304// ============================================================================
305
306/// Parameters for the logging/setLevel request
307#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
308pub struct SetLoggingLevelParams {
309    /// The logging level to set
310    pub level: LoggingLevel,
311}
312
/// Result of the logging/setLevel request.
///
/// Carries no data: the derived serde impls write it as an empty JSON object
/// (`{}`) and read it back from one.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct SetLoggingLevelResult {}
316
317/// Logging level enumeration
318#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
319#[serde(rename_all = "lowercase")]
320pub enum LoggingLevel {
321    Debug,
322    Info,
323    Notice,
324    Warning,
325    Error,
326    Critical,
327    Alert,
328    Emergency,
329}
330
331/// Parameters for the logging/message notification
332#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
333pub struct LoggingMessageParams {
334    /// The logging level
335    pub level: LoggingLevel,
336    /// The logger name
337    #[serde(skip_serializing_if = "Option::is_none")]
338    pub logger: Option<String>,
339    /// The log message data
340    pub data: Value,
341}
342
343// ============================================================================
344// Progress Messages
345// ============================================================================
346
347/// Parameters for the progress notification
348#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
349pub struct ProgressParams {
350    /// Unique identifier for the progress operation
351    #[serde(rename = "progressToken")]
352    pub progress_token: String,
353    /// Current progress (0.0 to 1.0)
354    pub progress: f32,
355    /// Optional total count
356    #[serde(skip_serializing_if = "Option::is_none")]
357    pub total: Option<u32>,
358}
359
360// ============================================================================
361// Message Helpers and Constructors
362// ============================================================================
363
364impl InitializeParams {
365    /// Create new initialize parameters
366    pub fn new(
367        client_info: ClientInfo,
368        capabilities: ClientCapabilities,
369        protocol_version: String,
370    ) -> Self {
371        Self {
372            client_info,
373            capabilities,
374            protocol_version,
375        }
376    }
377}
378
379impl InitializeResult {
380    /// Create new initialize result
381    pub fn new(
382        server_info: ServerInfo,
383        capabilities: ServerCapabilities,
384        protocol_version: String,
385    ) -> Self {
386        Self {
387            server_info,
388            capabilities,
389            protocol_version,
390        }
391    }
392}
393
394impl CallToolParams {
395    /// Create new call tool parameters
396    pub fn new(name: String, arguments: Option<HashMap<String, Value>>) -> Self {
397        Self { name, arguments }
398    }
399}
400
401impl ReadResourceParams {
402    /// Create new read resource parameters
403    pub fn new(uri: String) -> Self {
404        Self { uri }
405    }
406}
407
408impl GetPromptParams {
409    /// Create new get prompt parameters
410    pub fn new(name: String, arguments: Option<HashMap<String, Value>>) -> Self {
411        Self { name, arguments }
412    }
413}
414
415impl SamplingMessage {
416    /// Create a user message
417    pub fn user<S: Into<String>>(content: S) -> Self {
418        Self {
419            role: "user".to_string(),
420            content: SamplingContent::Text(content.into()),
421        }
422    }
423
424    /// Create an assistant message
425    pub fn assistant<S: Into<String>>(content: S) -> Self {
426        Self {
427            role: "assistant".to_string(),
428            content: SamplingContent::Text(content.into()),
429        }
430    }
431
432    /// Create a system message
433    pub fn system<S: Into<String>>(content: S) -> Self {
434        Self {
435            role: "system".to_string(),
436            content: SamplingContent::Text(content.into()),
437        }
438    }
439
440    /// Create a message with complex content
441    pub fn with_content<S: Into<String>>(role: S, content: Vec<Content>) -> Self {
442        Self {
443            role: role.into(),
444            content: SamplingContent::Complex(content),
445        }
446    }
447}
448
449impl Default for ModelPreferences {
450    fn default() -> Self {
451        Self::new()
452    }
453}
454
455impl ModelPreferences {
456    /// Create model preferences with default values
457    pub fn new() -> Self {
458        Self {
459            cost_priority: None,
460            speed_priority: None,
461            quality_priority: None,
462        }
463    }
464
465    /// Set cost priority
466    pub fn with_cost_priority(mut self, priority: f32) -> Self {
467        self.cost_priority = Some(priority);
468        self
469    }
470
471    /// Set speed priority
472    pub fn with_speed_priority(mut self, priority: f32) -> Self {
473        self.speed_priority = Some(priority);
474        self
475    }
476
477    /// Set quality priority
478    pub fn with_quality_priority(mut self, priority: f32) -> Self {
479        self.quality_priority = Some(priority);
480        self
481    }
482}
483
484// ============================================================================
485// Message Type Constants
486// ============================================================================
487
/// MCP protocol revision these message types target, as a date string.
///
/// Intended for the `protocol_version` field of `InitializeParams` /
/// `InitializeResult` (the tests in this module use it that way).
pub const MCP_PROTOCOL_VERSION: &str = "2024-11-05";
490
/// JSON-RPC method names for MCP requests and notifications.
///
/// Requests use `<area>/<action>` names. Notifications are namespaced under
/// `notifications/`, as defined by the MCP 2024-11-05 specification. Peers
/// that follow the spec will not recognise a notification sent without that
/// prefix.
pub mod methods {
    // ----- Lifecycle -----

    /// Initialize the connection (request).
    pub const INITIALIZE: &str = "initialize";
    /// Sent by the client once initialization is complete (notification).
    pub const INITIALIZED: &str = "notifications/initialized";
    /// Ping to check the connection (request).
    pub const PING: &str = "ping";
    /// Cancel a previously issued request (notification).
    pub const CANCELLED: &str = "notifications/cancelled";

    // ----- Tools -----

    /// List available tools.
    pub const TOOLS_LIST: &str = "tools/list";
    /// Call a tool.
    pub const TOOLS_CALL: &str = "tools/call";
    /// Notification when the tool list changes.
    pub const TOOLS_LIST_CHANGED: &str = "notifications/tools/list_changed";

    // ----- Resources -----

    /// List available resources.
    pub const RESOURCES_LIST: &str = "resources/list";
    /// List available resource templates.
    pub const RESOURCES_TEMPLATES_LIST: &str = "resources/templates/list";
    /// Read a resource.
    pub const RESOURCES_READ: &str = "resources/read";
    /// Subscribe to resource updates.
    pub const RESOURCES_SUBSCRIBE: &str = "resources/subscribe";
    /// Unsubscribe from resource updates.
    pub const RESOURCES_UNSUBSCRIBE: &str = "resources/unsubscribe";
    /// Notification when a subscribed resource is updated.
    pub const RESOURCES_UPDATED: &str = "notifications/resources/updated";
    /// Notification when the resource list changes.
    pub const RESOURCES_LIST_CHANGED: &str = "notifications/resources/list_changed";

    // ----- Prompts -----

    /// List available prompts.
    pub const PROMPTS_LIST: &str = "prompts/list";
    /// Get a prompt.
    pub const PROMPTS_GET: &str = "prompts/get";
    /// Notification when the prompt list changes.
    pub const PROMPTS_LIST_CHANGED: &str = "notifications/prompts/list_changed";

    // ----- Sampling -----

    /// Create a message using sampling.
    pub const SAMPLING_CREATE_MESSAGE: &str = "sampling/createMessage";

    // ----- Roots -----

    /// List the client's filesystem roots.
    pub const ROOTS_LIST: &str = "roots/list";
    /// Notification when the client's root list changes.
    pub const ROOTS_LIST_CHANGED: &str = "notifications/roots/list_changed";

    // ----- Completion -----

    /// Request argument auto-completion suggestions.
    pub const COMPLETION_COMPLETE: &str = "completion/complete";

    // ----- Logging -----

    /// Set logging level.
    pub const LOGGING_SET_LEVEL: &str = "logging/setLevel";
    /// Log message notification.
    pub const LOGGING_MESSAGE: &str = "notifications/message";

    // ----- Progress -----

    /// Progress notification.
    pub const PROGRESS: &str = "notifications/progress";
}
537
#[cfg(test)]
mod tests {
    // Unit tests for the wire shape of message structs, their convenience
    // constructors, and the method-name constants.
    use super::*;
    use serde_json::json;

    // `InitializeParams` must serialize with camelCase keys
    // (`clientInfo`, `protocolVersion`).
    #[test]
    fn test_initialize_params_serialization() {
        let params = InitializeParams::new(
            ClientInfo {
                name: "test-client".to_string(),
                version: "1.0.0".to_string(),
            },
            ClientCapabilities::default(),
            MCP_PROTOCOL_VERSION.to_string(),
        );

        let json = serde_json::to_value(&params).unwrap();
        assert_eq!(json["clientInfo"]["name"], "test-client");
        assert_eq!(json["protocolVersion"], MCP_PROTOCOL_VERSION);
    }

    // Tool arguments of mixed JSON types round into the `arguments` object
    // unchanged.
    #[test]
    fn test_call_tool_params() {
        let mut args = HashMap::new();
        args.insert("param1".to_string(), json!("value1"));
        args.insert("param2".to_string(), json!(42));

        let params = CallToolParams::new("test_tool".to_string(), Some(args));
        let json = serde_json::to_value(&params).unwrap();

        assert_eq!(json["name"], "test_tool");
        assert_eq!(json["arguments"]["param1"], "value1");
        assert_eq!(json["arguments"]["param2"], 42);
    }

    // The role helpers set the expected role string and wrap the text in
    // `SamplingContent::Text`.
    #[test]
    fn test_sampling_message_creation() {
        let user_msg = SamplingMessage::user("Hello, world!");
        assert_eq!(user_msg.role, "user");

        if let SamplingContent::Text(text) = user_msg.content {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text content");
        }

        let assistant_msg = SamplingMessage::assistant("Hello back!");
        assert_eq!(assistant_msg.role, "assistant");
    }

    // Each builder call stores its value; exact f32 equality is fine here
    // because the values are stored, not computed.
    #[test]
    fn test_model_preferences_builder() {
        let prefs = ModelPreferences::default()
            .with_cost_priority(0.8)
            .with_speed_priority(0.6)
            .with_quality_priority(0.9);

        assert_eq!(prefs.cost_priority, Some(0.8));
        assert_eq!(prefs.speed_priority, Some(0.6));
        assert_eq!(prefs.quality_priority, Some(0.9));
    }

    // The URI is emitted verbatim under the `uri` key.
    #[test]
    fn test_read_resource_params() {
        let params = ReadResourceParams::new("file:///path/to/file.txt".to_string());
        let json = serde_json::to_value(&params).unwrap();
        assert_eq!(json["uri"], "file:///path/to/file.txt");
    }

    // `rename_all = "lowercase"` applies in both directions.
    #[test]
    fn test_logging_level_serialization() {
        let level = LoggingLevel::Warning;
        let json = serde_json::to_value(&level).unwrap();
        assert_eq!(json, "warning");

        let deserialized: LoggingLevel = serde_json::from_value(json!("error")).unwrap();
        assert_eq!(deserialized, LoggingLevel::Error);
    }

    // Spot-checks request method names (notification names are not covered).
    #[test]
    fn test_method_constants() {
        assert_eq!(methods::INITIALIZE, "initialize");
        assert_eq!(methods::TOOLS_LIST, "tools/list");
        assert_eq!(methods::TOOLS_CALL, "tools/call");
        assert_eq!(methods::RESOURCES_READ, "resources/read");
        assert_eq!(methods::PROMPTS_GET, "prompts/get");
        assert_eq!(methods::SAMPLING_CREATE_MESSAGE, "sampling/createMessage");
    }
}