// mcp_protocol_sdk/protocol/messages.rs

//! MCP protocol message definitions
//!
//! This module defines the MCP-specific message types (request parameters, results and
//! notification parameters) together with their serde serialization/deserialization logic.
//! The messages are carried as JSON-RPC 2.0 payloads and represent the operations supported
//! by the Model Context Protocol.
6
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::collections::HashMap;
10
11use crate::protocol::types::*;
12
// ============================================================================
// Initialization Messages
// ============================================================================
16
17/// Parameters for the initialize request
18#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
19pub struct InitializeParams {
20    /// Information about the client
21    #[serde(rename = "clientInfo")]
22    pub client_info: ClientInfo,
23    /// Capabilities advertised by the client
24    pub capabilities: ClientCapabilities,
25    /// Protocol version being used
26    #[serde(rename = "protocolVersion")]
27    pub protocol_version: String,
28}
29
30/// Result of the initialize request
31#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
32pub struct InitializeResult {
33    /// Information about the server
34    #[serde(rename = "serverInfo")]
35    pub server_info: ServerInfo,
36    /// Capabilities advertised by the server
37    pub capabilities: ServerCapabilities,
38    /// Protocol version being used
39    #[serde(rename = "protocolVersion")]
40    pub protocol_version: String,
41}
42
// ============================================================================
// Tool Messages
// ============================================================================
46
47/// Parameters for the tools/list request
48#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
49pub struct ListToolsParams {
50    /// Optional cursor for pagination
51    #[serde(skip_serializing_if = "Option::is_none")]
52    pub cursor: Option<String>,
53}
54
55/// Result of the tools/list request
56#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
57pub struct ListToolsResult {
58    /// Available tools
59    pub tools: Vec<ToolInfo>,
60    /// Cursor for pagination (if more tools are available)
61    #[serde(rename = "nextCursor", skip_serializing_if = "Option::is_none")]
62    pub next_cursor: Option<String>,
63}
64
65/// Parameters for the tools/call request
66#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
67pub struct CallToolParams {
68    /// Name of the tool to call
69    pub name: String,
70    /// Arguments to pass to the tool
71    #[serde(skip_serializing_if = "Option::is_none")]
72    pub arguments: Option<HashMap<String, Value>>,
73}
74
75/// Result of the tools/call request
76#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
77pub struct CallToolResult {
78    /// Content returned by the tool
79    pub content: Vec<Content>,
80    /// Whether this result represents an error
81    #[serde(rename = "isError", skip_serializing_if = "Option::is_none")]
82    pub is_error: Option<bool>,
83}
84
// ============================================================================
// Resource Messages
// ============================================================================
88
89/// Parameters for the resources/list request
90#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
91pub struct ListResourcesParams {
92    /// Optional cursor for pagination
93    #[serde(skip_serializing_if = "Option::is_none")]
94    pub cursor: Option<String>,
95}
96
97/// Result of the resources/list request
98#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
99pub struct ListResourcesResult {
100    /// Available resources
101    pub resources: Vec<ResourceInfo>,
102    /// Cursor for pagination (if more resources are available)
103    #[serde(rename = "nextCursor", skip_serializing_if = "Option::is_none")]
104    pub next_cursor: Option<String>,
105}
106
107/// Parameters for the resources/read request
108#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
109pub struct ReadResourceParams {
110    /// URI of the resource to read
111    pub uri: String,
112}
113
114/// Result of the resources/read request
115#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
116pub struct ReadResourceResult {
117    /// Contents of the resource
118    pub contents: Vec<ResourceContent>,
119}
120
121/// Parameters for the resources/subscribe request
122#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
123pub struct SubscribeResourceParams {
124    /// URI of the resource to subscribe to
125    pub uri: String,
126}
127
128/// Result of the resources/subscribe request
129#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
130pub struct SubscribeResourceResult {}
131
132/// Parameters for the resources/unsubscribe request
133#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
134pub struct UnsubscribeResourceParams {
135    /// URI of the resource to unsubscribe from
136    pub uri: String,
137}
138
139/// Result of the resources/unsubscribe request
140#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
141pub struct UnsubscribeResourceResult {}
142
143/// Parameters for the resources/updated notification
144#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
145pub struct ResourceUpdatedParams {
146    /// URI of the resource that was updated
147    pub uri: String,
148}
149
150/// Parameters for the resources/list_changed notification
151#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
152pub struct ResourceListChangedParams {}
153
// ============================================================================
// Prompt Messages
// ============================================================================
157
158/// Parameters for the prompts/list request
159#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
160pub struct ListPromptsParams {
161    /// Optional cursor for pagination
162    #[serde(skip_serializing_if = "Option::is_none")]
163    pub cursor: Option<String>,
164}
165
166/// Result of the prompts/list request
167#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
168pub struct ListPromptsResult {
169    /// Available prompts
170    pub prompts: Vec<PromptInfo>,
171    /// Cursor for pagination (if more prompts are available)
172    #[serde(rename = "nextCursor", skip_serializing_if = "Option::is_none")]
173    pub next_cursor: Option<String>,
174}
175
176/// Parameters for the prompts/get request
177#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
178pub struct GetPromptParams {
179    /// Name of the prompt to get
180    pub name: String,
181    /// Arguments to pass to the prompt
182    #[serde(skip_serializing_if = "Option::is_none")]
183    pub arguments: Option<HashMap<String, Value>>,
184}
185
186/// Result of the prompts/get request
187#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
188pub struct GetPromptResult {
189    /// Description of the prompt result
190    #[serde(skip_serializing_if = "Option::is_none")]
191    pub description: Option<String>,
192    /// Messages generated by the prompt
193    pub messages: Vec<PromptMessage>,
194}
195
196/// Parameters for the prompts/list_changed notification
197#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
198pub struct PromptListChangedParams {}
199
// ============================================================================
// Sampling Messages
// ============================================================================
203
204/// Parameters for the sampling/createMessage request
205#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
206pub struct CreateMessageParams {
207    /// Messages to include in the conversation
208    pub messages: Vec<SamplingMessage>,
209    /// Model preferences
210    #[serde(rename = "modelPreferences", skip_serializing_if = "Option::is_none")]
211    pub model_preferences: Option<ModelPreferences>,
212    /// System prompt
213    #[serde(rename = "systemPrompt", skip_serializing_if = "Option::is_none")]
214    pub system_prompt: Option<String>,
215    /// Whether to include context from the current session
216    #[serde(rename = "includeContext", skip_serializing_if = "Option::is_none")]
217    pub include_context: Option<String>,
218    /// Maximum number of tokens to generate
219    #[serde(rename = "maxTokens", skip_serializing_if = "Option::is_none")]
220    pub max_tokens: Option<u32>,
221    /// Sampling temperature (0.0 to 1.0)
222    #[serde(skip_serializing_if = "Option::is_none")]
223    pub temperature: Option<f32>,
224    /// Nucleus sampling parameter (0.0 to 1.0)
225    #[serde(rename = "topP", skip_serializing_if = "Option::is_none")]
226    pub top_p: Option<f32>,
227    /// Stop sequences
228    #[serde(rename = "stopSequences", skip_serializing_if = "Option::is_none")]
229    pub stop_sequences: Option<Vec<String>>,
230    /// Metadata to include with the request
231    #[serde(skip_serializing_if = "Option::is_none")]
232    pub metadata: Option<HashMap<String, Value>>,
233}
234
235/// Result of the sampling/createMessage request
236#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
237pub struct CreateMessageResult {
238    /// Role of the generated message
239    pub role: String,
240    /// Content of the generated message
241    pub content: SamplingContent,
242    /// Model used for generation
243    pub model: String,
244    /// Stop reason
245    #[serde(rename = "stopReason", skip_serializing_if = "Option::is_none")]
246    pub stop_reason: Option<String>,
247}
248
249/// A message in a sampling conversation
250#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
251pub struct SamplingMessage {
252    /// Role of the message (e.g., "user", "assistant")
253    pub role: String,
254    /// Content of the message
255    pub content: SamplingContent,
256}
257
258/// Content for sampling messages
259#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
260#[serde(untagged)]
261pub enum SamplingContent {
262    /// Simple text content
263    Text(String),
264    /// Complex content with multiple parts
265    Complex(Vec<Content>),
266}
267
268/// Model preferences for sampling
269#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
270pub struct ModelPreferences {
271    /// Hints about cost constraints
272    #[serde(rename = "costPriority", skip_serializing_if = "Option::is_none")]
273    pub cost_priority: Option<f32>,
274    /// Hints about speed constraints
275    #[serde(rename = "speedPriority", skip_serializing_if = "Option::is_none")]
276    pub speed_priority: Option<f32>,
277    /// Hints about quality constraints
278    #[serde(rename = "qualityPriority", skip_serializing_if = "Option::is_none")]
279    pub quality_priority: Option<f32>,
280}
281
// ============================================================================
// Tool List Changed Notification
// ============================================================================
285
286/// Parameters for the tools/list_changed notification
287#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
288pub struct ToolListChangedParams {}
289
// ============================================================================
// Ping Messages
// ============================================================================
293
294/// Parameters for the ping request (no parameters)
295#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
296pub struct PingParams {}
297
298/// Result of the ping request (no result)
299#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
300pub struct PingResult {}
301
// ============================================================================
// Logging Messages
// ============================================================================
305
306/// Parameters for the logging/setLevel request
307#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
308pub struct SetLoggingLevelParams {
309    /// The logging level to set
310    pub level: LoggingLevel,
311}
312
313/// Result of the logging/setLevel request
314#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
315pub struct SetLoggingLevelResult {}
316
317/// Logging level enumeration
318#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
319#[serde(rename_all = "lowercase")]
320pub enum LoggingLevel {
321    Debug,
322    Info,
323    Notice,
324    Warning,
325    Error,
326    Critical,
327    Alert,
328    Emergency,
329}
330
331/// Parameters for the logging/message notification
332#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
333pub struct LoggingMessageParams {
334    /// The logging level
335    pub level: LoggingLevel,
336    /// The logger name
337    #[serde(skip_serializing_if = "Option::is_none")]
338    pub logger: Option<String>,
339    /// The log message data
340    pub data: Value,
341}
342
// ============================================================================
// Progress Messages
// ============================================================================
346
347/// Parameters for the progress notification
348#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
349pub struct ProgressParams {
350    /// Unique identifier for the progress operation
351    #[serde(rename = "progressToken")]
352    pub progress_token: String,
353    /// Current progress (0.0 to 1.0)
354    pub progress: f32,
355    /// Optional total count
356    #[serde(skip_serializing_if = "Option::is_none")]
357    pub total: Option<u32>,
358}
359
// ============================================================================
// Message Helpers and Constructors
// ============================================================================
363
364impl InitializeParams {
365    /// Create new initialize parameters
366    pub fn new(
367        client_info: ClientInfo,
368        capabilities: ClientCapabilities,
369        protocol_version: String,
370    ) -> Self {
371        Self {
372            client_info,
373            capabilities,
374            protocol_version,
375        }
376    }
377}
378
379impl InitializeResult {
380    /// Create new initialize result
381    pub fn new(
382        server_info: ServerInfo,
383        capabilities: ServerCapabilities,
384        protocol_version: String,
385    ) -> Self {
386        Self {
387            server_info,
388            capabilities,
389            protocol_version,
390        }
391    }
392}
393
394impl CallToolParams {
395    /// Create new call tool parameters
396    pub fn new(name: String, arguments: Option<HashMap<String, Value>>) -> Self {
397        Self { name, arguments }
398    }
399}
400
401impl ReadResourceParams {
402    /// Create new read resource parameters
403    pub fn new(uri: String) -> Self {
404        Self { uri }
405    }
406}
407
408impl GetPromptParams {
409    /// Create new get prompt parameters
410    pub fn new(name: String, arguments: Option<HashMap<String, Value>>) -> Self {
411        Self { name, arguments }
412    }
413}
414
415impl SamplingMessage {
416    /// Create a user message
417    pub fn user<S: Into<String>>(content: S) -> Self {
418        Self {
419            role: "user".to_string(),
420            content: SamplingContent::Text(content.into()),
421        }
422    }
423
424    /// Create an assistant message
425    pub fn assistant<S: Into<String>>(content: S) -> Self {
426        Self {
427            role: "assistant".to_string(),
428            content: SamplingContent::Text(content.into()),
429        }
430    }
431
432    /// Create a system message
433    pub fn system<S: Into<String>>(content: S) -> Self {
434        Self {
435            role: "system".to_string(),
436            content: SamplingContent::Text(content.into()),
437        }
438    }
439
440    /// Create a message with complex content
441    pub fn with_content<S: Into<String>>(role: S, content: Vec<Content>) -> Self {
442        Self {
443            role: role.into(),
444            content: SamplingContent::Complex(content),
445        }
446    }
447}
448
449impl ModelPreferences {
450    /// Create model preferences with default values
451    pub fn default() -> Self {
452        Self {
453            cost_priority: None,
454            speed_priority: None,
455            quality_priority: None,
456        }
457    }
458
459    /// Set cost priority
460    pub fn with_cost_priority(mut self, priority: f32) -> Self {
461        self.cost_priority = Some(priority);
462        self
463    }
464
465    /// Set speed priority
466    pub fn with_speed_priority(mut self, priority: f32) -> Self {
467        self.speed_priority = Some(priority);
468        self
469    }
470
471    /// Set quality priority
472    pub fn with_quality_priority(mut self, priority: f32) -> Self {
473        self.quality_priority = Some(priority);
474        self
475    }
476}
477
// ============================================================================
// Message Type Constants
// ============================================================================
481
/// MCP protocol revision string targeted by this crate (the `2024-11-05` revision),
/// suitable for the `protocol_version` field of [`InitializeParams`].
pub const MCP_PROTOCOL_VERSION: &str = "2024-11-05";
484
/// JSON-RPC method names for MCP messages.
///
/// Requests use bare `<area>/<action>` names. Notifications live under the
/// `notifications/` prefix, as defined by the MCP 2024-11-05 schema. Peers that follow
/// the spec do not recognise un-prefixed notification names.
pub mod methods {
    /// Initialize the connection (request).
    pub const INITIALIZE: &str = "initialize";
    /// Sent by the client once initialization has completed (notification).
    pub const INITIALIZED: &str = "notifications/initialized";

    /// Ping to check the connection (request).
    pub const PING: &str = "ping";

    /// List available tools (request).
    pub const TOOLS_LIST: &str = "tools/list";
    /// Call a tool (request).
    pub const TOOLS_CALL: &str = "tools/call";
    /// The tool list changed (notification).
    pub const TOOLS_LIST_CHANGED: &str = "notifications/tools/list_changed";

    /// List available resources (request).
    pub const RESOURCES_LIST: &str = "resources/list";
    /// Read a resource (request).
    pub const RESOURCES_READ: &str = "resources/read";
    /// Subscribe to resource updates (request).
    pub const RESOURCES_SUBSCRIBE: &str = "resources/subscribe";
    /// Unsubscribe from resource updates (request).
    pub const RESOURCES_UNSUBSCRIBE: &str = "resources/unsubscribe";
    /// A subscribed resource was updated (notification).
    pub const RESOURCES_UPDATED: &str = "notifications/resources/updated";
    /// The resource list changed (notification).
    pub const RESOURCES_LIST_CHANGED: &str = "notifications/resources/list_changed";

    /// List available prompts (request).
    pub const PROMPTS_LIST: &str = "prompts/list";
    /// Get a prompt (request).
    pub const PROMPTS_GET: &str = "prompts/get";
    /// The prompt list changed (notification).
    pub const PROMPTS_LIST_CHANGED: &str = "notifications/prompts/list_changed";

    /// Create a message using sampling (request).
    pub const SAMPLING_CREATE_MESSAGE: &str = "sampling/createMessage";

    /// Set the logging level (request).
    pub const LOGGING_SET_LEVEL: &str = "logging/setLevel";
    /// Log message (notification).
    pub const LOGGING_MESSAGE: &str = "notifications/message";

    /// Progress update (notification).
    pub const PROGRESS: &str = "notifications/progress";
}
531
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_initialize_params_serialization() {
        let client_info = ClientInfo {
            name: "test-client".to_string(),
            version: "1.0.0".to_string(),
        };
        let params = InitializeParams::new(
            client_info,
            ClientCapabilities::default(),
            MCP_PROTOCOL_VERSION.to_string(),
        );

        let encoded = serde_json::to_value(&params).unwrap();
        assert_eq!(encoded["clientInfo"]["name"], "test-client");
        assert_eq!(encoded["protocolVersion"], MCP_PROTOCOL_VERSION);
    }

    #[test]
    fn test_call_tool_params() {
        let args: HashMap<String, _> = vec![
            ("param1".to_string(), json!("value1")),
            ("param2".to_string(), json!(42)),
        ]
        .into_iter()
        .collect();

        let encoded =
            serde_json::to_value(&CallToolParams::new("test_tool".to_string(), Some(args)))
                .unwrap();

        assert_eq!(encoded["name"], "test_tool");
        assert_eq!(encoded["arguments"]["param1"], "value1");
        assert_eq!(encoded["arguments"]["param2"], 42);
    }

    #[test]
    fn test_sampling_message_creation() {
        let user_msg = SamplingMessage::user("Hello, world!");
        assert_eq!(user_msg.role, "user");
        match user_msg.content {
            SamplingContent::Text(text) => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected text content"),
        }

        let assistant_msg = SamplingMessage::assistant("Hello back!");
        assert_eq!(assistant_msg.role, "assistant");
    }

    #[test]
    fn test_model_preferences_builder() {
        // The setters touch independent fields, so call order does not matter.
        let prefs = ModelPreferences::default()
            .with_quality_priority(0.9)
            .with_speed_priority(0.6)
            .with_cost_priority(0.8);

        assert_eq!(
            (prefs.cost_priority, prefs.speed_priority, prefs.quality_priority),
            (Some(0.8), Some(0.6), Some(0.9))
        );
    }

    #[test]
    fn test_read_resource_params() {
        let uri = "file:///path/to/file.txt";
        let encoded = serde_json::to_value(&ReadResourceParams::new(uri.to_string())).unwrap();
        assert_eq!(encoded["uri"], uri);
    }

    #[test]
    fn test_logging_level_serialization() {
        let encoded = serde_json::to_value(&LoggingLevel::Warning).unwrap();
        assert_eq!(encoded, "warning");

        let decoded: LoggingLevel = serde_json::from_value(json!("error")).unwrap();
        assert_eq!(decoded, LoggingLevel::Error);
    }

    #[test]
    fn test_method_constants() {
        let expected = [
            (methods::INITIALIZE, "initialize"),
            (methods::TOOLS_LIST, "tools/list"),
            (methods::TOOLS_CALL, "tools/call"),
            (methods::RESOURCES_READ, "resources/read"),
            (methods::PROMPTS_GET, "prompts/get"),
            (methods::SAMPLING_CREATE_MESSAGE, "sampling/createMessage"),
        ];
        for (actual, wanted) in expected.iter() {
            assert_eq!(actual, wanted);
        }
    }
}
620}