// mcp_host/protocol/types.rs

//! Core JSON-RPC 2.0 and MCP protocol types.
//!
//! Extracted from the crimson/dictator implementations and enhanced with patterns
//! from the Go mcphost implementation.
5
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8
9/// JSON-RPC 2.0 request
10#[derive(Debug, Clone, Deserialize, Serialize)]
11pub struct JsonRpcRequest {
12    /// JSON-RPC version (always "2.0")
13    pub jsonrpc: String,
14    /// Request ID (null for notifications)
15    #[serde(skip_serializing_if = "Option::is_none")]
16    pub id: Option<Value>,
17    /// Method name
18    pub method: String,
19    /// Method parameters
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub params: Option<Value>,
22}
23
24impl JsonRpcRequest {
25    /// Create a new request with ID
26    pub fn new(id: Value, method: impl Into<String>, params: Option<Value>) -> Self {
27        Self {
28            jsonrpc: "2.0".to_string(),
29            id: Some(id),
30            method: method.into(),
31            params,
32        }
33    }
34
35    /// Create a notification (no ID)
36    pub fn notification(method: impl Into<String>, params: Option<Value>) -> Self {
37        Self {
38            jsonrpc: "2.0".to_string(),
39            id: None,
40            method: method.into(),
41            params,
42        }
43    }
44
45    /// Check if this is a notification
46    pub fn is_notification(&self) -> bool {
47        self.id.is_none()
48    }
49}
50
51/// JSON-RPC 2.0 response
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct JsonRpcResponse {
54    /// JSON-RPC version (always "2.0")
55    pub jsonrpc: String,
56    /// Request ID
57    pub id: Value,
58    /// Successful result
59    #[serde(skip_serializing_if = "Option::is_none")]
60    pub result: Option<Value>,
61    /// Error object
62    #[serde(skip_serializing_if = "Option::is_none")]
63    pub error: Option<JsonRpcError>,
64}
65
66impl JsonRpcResponse {
67    /// Create a successful response
68    pub fn success(id: Value, result: Value) -> Self {
69        Self {
70            jsonrpc: "2.0".to_string(),
71            id,
72            result: Some(result),
73            error: None,
74        }
75    }
76
77    /// Create an error response
78    pub fn error(id: Value, error: JsonRpcError) -> Self {
79        Self {
80            jsonrpc: "2.0".to_string(),
81            id,
82            result: None,
83            error: Some(error),
84        }
85    }
86}
87
88/// JSON-RPC 2.0 error object
89#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct JsonRpcError {
91    /// Error code
92    pub code: i32,
93    /// Error message
94    pub message: String,
95    /// Additional error data
96    #[serde(skip_serializing_if = "Option::is_none")]
97    pub data: Option<Value>,
98}
99
100impl JsonRpcError {
101    /// Create a new error
102    pub fn new(code: i32, message: impl Into<String>) -> Self {
103        Self {
104            code,
105            message: message.into(),
106            data: None,
107        }
108    }
109
110    /// Add additional data to the error
111    pub fn with_data(mut self, data: Value) -> Self {
112        self.data = Some(data);
113        self
114    }
115
116    /// Parse error (-32700)
117    pub fn parse_error(message: impl Into<String>) -> Self {
118        Self::new(-32700, message)
119    }
120
121    /// Invalid request (-32600)
122    pub fn invalid_request(message: impl Into<String>) -> Self {
123        Self::new(-32600, message)
124    }
125
126    /// Method not found (-32601)
127    pub fn method_not_found(message: impl Into<String>) -> Self {
128        Self::new(-32601, message)
129    }
130
131    /// Invalid params (-32602)
132    pub fn invalid_params(message: impl Into<String>) -> Self {
133        Self::new(-32602, message)
134    }
135
136    /// Internal error (-32603)
137    pub fn internal_error(message: impl Into<String>) -> Self {
138        Self::new(-32603, message)
139    }
140}
141
142/// Server/Client implementation information
143#[derive(Debug, Clone, Serialize, Deserialize)]
144pub struct Implementation {
145    /// Implementation name
146    pub name: String,
147    /// Implementation version
148    pub version: String,
149}
150
151impl Implementation {
152    /// Create new implementation info
153    pub fn new(name: impl Into<String>, version: impl Into<String>) -> Self {
154        Self {
155            name: name.into(),
156            version: version.into(),
157        }
158    }
159}
160
161/// Client information with version validation
162#[derive(Debug, Clone, Default)]
163pub struct ClientInfo {
164    /// Client name
165    pub name: String,
166    /// Client version
167    pub version: String,
168}
169
170impl ClientInfo {
171    /// Create new client info
172    pub fn new(name: impl Into<String>, version: impl Into<String>) -> Self {
173        Self {
174            name: name.into(),
175            version: version.into(),
176        }
177    }
178
179    /// Check if client version meets minimum requirements
180    ///
181    /// Uses the configurable MIN_CLIENT_VERSIONS from version module
182    pub fn is_supported(&self, min_versions: &[(&str, &str)]) -> bool {
183        for (name, min_version) in min_versions {
184            if self.name == *name {
185                return version_gte(&self.version, min_version);
186            }
187        }
188        // Unknown clients are allowed by default
189        true
190    }
191}
192
/// Compare dotted numeric versions (simple: `major.minor.patch`).
///
/// Returns `true` if `version >= min`.
///
/// Each version string is normalized before comparison:
/// - surrounding whitespace and a single leading `v`/`V` are ignored (`"v2.0.1"`);
/// - everything from the first `-` or `+` onward (pre-release / build metadata) is
///   ignored, so `"2.0.56-beta"` compares equal to `"2.0.56"`;
/// - the first three dot-separated components are read *positionally* from their
///   leading ASCII digits. A missing or non-numeric component counts as `0`, so
///   `"1.x.3"` is `(1, 0, 3)`, not `(1, 3, 0)`. A component too large for `u32`
///   saturates to `u32::MAX`;
/// - components beyond the third are ignored.
pub fn version_gte(version: &str, min: &str) -> bool {
    /// Reads the leading ASCII digits of one component (`"3rc1"` -> 3, `"x"` -> 0).
    fn component(part: &str) -> u32 {
        let digits_len = part
            .find(|c: char| !c.is_ascii_digit())
            .unwrap_or(part.len());
        let digits = &part[..digits_len];
        if digits.is_empty() {
            0
        } else {
            // A non-empty run of ASCII digits can only fail to parse by overflowing.
            digits.parse().unwrap_or(u32::MAX)
        }
    }

    /// Normalizes a version string into a comparable `(major, minor, patch)` triple.
    fn parse(v: &str) -> (u32, u32, u32) {
        let v = v.trim();
        let v = v.strip_prefix(|c: char| c == 'v' || c == 'V').unwrap_or(v);
        let core = v.split(|c: char| c == '-' || c == '+').next().unwrap_or("");
        let mut parts = core.split('.').map(component);
        let major = parts.next().unwrap_or(0);
        let minor = parts.next().unwrap_or(0);
        let patch = parts.next().unwrap_or(0);
        (major, minor, patch)
    }

    parse(version) >= parse(min)
}
211
212// =============================================================================
213// ROOTS - Client Context Management
214// =============================================================================
215
216/// Represents a root URI provided by the client
217///
218/// Roots are top-level directories or resources that the client wants
219/// the server to be aware of for context.
220#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
221#[serde(rename_all = "camelCase")]
222pub struct Root {
223    /// URI of the root (e.g., file:// or other protocol)
224    pub uri: String,
225    /// Optional human-readable name for the root
226    #[serde(skip_serializing_if = "Option::is_none")]
227    pub name: Option<String>,
228}
229
230/// Result of listing roots
231#[derive(Debug, Clone, Serialize, Deserialize)]
232pub struct ListRootsResult {
233    /// List of roots provided by the client
234    pub roots: Vec<Root>,
235}
236
237/// Notification that the client's root list has changed
238///
239/// Servers should re-request the roots list when receiving this
240#[derive(Debug, Clone, Serialize, Deserialize)]
241pub struct RootsListChangedNotification {}
242
243// =============================================================================
244// SAMPLING - LLM Message Creation
245// =============================================================================
246
247/// Role in a conversation
248#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
249#[serde(rename_all = "lowercase")]
250pub enum Role {
251    /// User message
252    User,
253    /// Assistant message
254    Assistant,
255}
256
257/// A message in a sampling conversation
258#[derive(Debug, Clone, Serialize, Deserialize)]
259#[serde(rename_all = "camelCase")]
260pub struct SamplingMessage {
261    /// Role of the message sender
262    pub role: Role,
263    /// Message content (reuses Content from content module)
264    pub content: Value, // Will be Content type from content module
265}
266
267/// Specifies how much context should be included in sampling requests
268#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
269#[serde(rename_all = "camelCase")]
270pub enum ContextInclusion {
271    /// No context
272    None,
273    /// Context from this server only
274    ThisServer,
275    /// Context from all connected servers
276    AllServers,
277}
278
279/// Model hint for LLM selection
280#[derive(Debug, Clone, Serialize, Deserialize)]
281#[serde(rename_all = "camelCase")]
282pub struct ModelHint {
283    /// Preferred model name (e.g., "claude-3-sonnet", "gpt-4")
284    #[serde(skip_serializing_if = "Option::is_none")]
285    pub name: Option<String>,
286}
287
288/// Model selection preferences
289#[derive(Debug, Clone, Serialize, Deserialize)]
290#[serde(rename_all = "camelCase")]
291pub struct ModelPreferences {
292    /// Model name hints
293    #[serde(skip_serializing_if = "Option::is_none")]
294    pub hints: Option<Vec<ModelHint>>,
295    /// Cost priority (0.0-1.0, higher means prefer cheaper models)
296    #[serde(skip_serializing_if = "Option::is_none")]
297    pub cost_priority: Option<f64>,
298    /// Speed priority (0.0-1.0, higher means prefer faster models)
299    #[serde(skip_serializing_if = "Option::is_none")]
300    pub speed_priority: Option<f64>,
301    /// Intelligence priority (0.0-1.0, higher means prefer smarter models)
302    #[serde(skip_serializing_if = "Option::is_none")]
303    pub intelligence_priority: Option<f64>,
304}
305
306/// Request to create an LLM message/completion
307#[derive(Debug, Clone, Serialize, Deserialize)]
308#[serde(rename_all = "camelCase")]
309pub struct CreateMessageRequest {
310    /// Conversation history
311    pub messages: Vec<SamplingMessage>,
312    /// Model selection preferences
313    #[serde(skip_serializing_if = "Option::is_none")]
314    pub model_preferences: Option<ModelPreferences>,
315    /// System prompt to use
316    #[serde(skip_serializing_if = "Option::is_none")]
317    pub system_prompt: Option<String>,
318    /// How much context to include
319    #[serde(skip_serializing_if = "Option::is_none")]
320    pub include_context: Option<ContextInclusion>,
321    /// Temperature (0.0-1.0+)
322    #[serde(skip_serializing_if = "Option::is_none")]
323    pub temperature: Option<f64>,
324    /// Maximum tokens to generate
325    pub max_tokens: u32,
326    /// Stop sequences
327    #[serde(skip_serializing_if = "Option::is_none")]
328    pub stop_sequences: Option<Vec<String>>,
329    /// Additional metadata
330    #[serde(skip_serializing_if = "Option::is_none")]
331    pub metadata: Option<Value>,
332}
333
334/// Result of creating an LLM message
335#[derive(Debug, Clone, Serialize, Deserialize)]
336#[serde(rename_all = "camelCase")]
337pub struct CreateMessageResult {
338    /// Generated message
339    pub message: SamplingMessage,
340    /// Model used for generation
341    pub model: String,
342    /// Why generation stopped
343    #[serde(skip_serializing_if = "Option::is_none")]
344    pub stop_reason: Option<String>, // "endTurn", "stopSequence", "maxTokens"
345}
346
347impl CreateMessageResult {
348    /// Stop reason: natural end of turn
349    pub const STOP_REASON_END_TURN: &'static str = "endTurn";
350    /// Stop reason: hit a stop sequence
351    pub const STOP_REASON_STOP_SEQUENCE: &'static str = "stopSequence";
352    /// Stop reason: reached max tokens
353    pub const STOP_REASON_MAX_TOKENS: &'static str = "maxTokens";
354}
355
356// =============================================================================
357// ELICITATION - Structured User Input
358// =============================================================================
359
360/// Action taken by user in response to elicitation
361#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
362#[serde(rename_all = "lowercase")]
363pub enum ElicitationAction {
364    /// User approved and provided data
365    Accept,
366    /// User explicitly declined
367    Decline,
368    /// User dismissed without decision
369    Cancel,
370}
371
372/// Request to elicit structured input from user
373#[derive(Debug, Clone, Serialize, Deserialize)]
374#[serde(rename_all = "camelCase")]
375pub struct CreateElicitationRequest {
376    /// Message to show user
377    pub message: String,
378    /// Schema for requested data (imported from elicitation module)
379    pub requested_schema: Value, // Will be ElicitationSchema
380}
381
382/// Result of elicitation request
383#[derive(Debug, Clone, Serialize, Deserialize)]
384pub struct CreateElicitationResult {
385    /// User's action
386    pub action: ElicitationAction,
387    /// User-provided content (only present if action is Accept)
388    #[serde(skip_serializing_if = "Option::is_none")]
389    pub content: Option<Value>,
390}
391
392// ============================================================================
393// Tasks
394// ============================================================================
395
396/// Task status in lifecycle
397#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
398#[serde(rename_all = "snake_case")]
399pub enum TaskStatus {
400    /// The request is currently being processed
401    Working,
402    /// The task is waiting for input (e.g., elicitation or sampling)
403    InputRequired,
404    /// The request completed successfully and results are available
405    Completed,
406    /// The associated request did not complete successfully
407    Failed,
408    /// The request was cancelled before completion
409    Cancelled,
410}
411
412/// Metadata for augmenting a request with task execution
413#[derive(Debug, Clone, Serialize, Deserialize)]
414#[serde(rename_all = "camelCase")]
415pub struct TaskMetadata {
416    /// Requested duration in milliseconds to retain task from creation
417    #[serde(skip_serializing_if = "Option::is_none")]
418    pub ttl: Option<u64>,
419}
420
421/// Data associated with a task
422#[derive(Debug, Clone, Serialize, Deserialize)]
423#[serde(rename_all = "camelCase")]
424pub struct Task {
425    /// The task identifier
426    pub task_id: String,
427    /// Current task state
428    pub status: TaskStatus,
429    /// Optional human-readable message describing the current task state
430    #[serde(skip_serializing_if = "Option::is_none")]
431    pub status_message: Option<String>,
432    /// ISO 8601 timestamp when the task was created
433    pub created_at: String,
434    /// ISO 8601 timestamp when the task was last updated
435    pub last_updated_at: String,
436    /// Actual retention duration from creation in milliseconds, null for unlimited
437    pub ttl: Option<u64>,
438    /// Suggested polling interval in milliseconds
439    #[serde(skip_serializing_if = "Option::is_none")]
440    pub poll_interval: Option<u64>,
441}
442
443/// A response to a task-augmented request
444#[derive(Debug, Clone, Serialize, Deserialize)]
445#[serde(rename_all = "camelCase")]
446pub struct CreateTaskResult {
447    /// The created task
448    pub task: Task,
449}
450
451/// Request params for tasks/get
452#[derive(Debug, Clone, Serialize, Deserialize)]
453#[serde(rename_all = "camelCase")]
454pub struct GetTaskParams {
455    /// The task identifier to query
456    pub task_id: String,
457}
458
459/// Request params for tasks/cancel
460#[derive(Debug, Clone, Serialize, Deserialize)]
461#[serde(rename_all = "camelCase")]
462pub struct CancelTaskParams {
463    /// The task identifier to cancel
464    pub task_id: String,
465}
466
467/// Tool execution metadata
468#[derive(Debug, Clone, Default, Serialize, Deserialize)]
469#[serde(rename_all = "camelCase")]
470pub struct ToolExecution {
471    /// Indicates whether this tool supports task-augmented execution
472    #[serde(skip_serializing_if = "Option::is_none")]
473    pub task_support: Option<TaskSupport>,
474}
475
476/// Task support level for a tool
477#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
478#[serde(rename_all = "lowercase")]
479pub enum TaskSupport {
480    /// Tool does not support task-augmented execution (default when absent)
481    Forbidden,
482    /// Tool may support task-augmented execution
483    Optional,
484    /// Tool requires task-augmented execution
485    Required,
486}
487
488// ============================================================================
489// Logging Types
490// ============================================================================
491
492/// Request to set the minimum logging level
493#[derive(Debug, Clone, Serialize, Deserialize)]
494pub struct SetLevelRequest {
495    /// Minimum log level to receive (debug, info, notice, warning, error, critical, alert, emergency)
496    pub level: String,
497}
498
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_version_gte_equal() {
        for v in ["1.0.0", "2.0.56"] {
            assert!(version_gte(v, v), "{v} should be >= itself");
        }
    }

    #[test]
    fn test_version_gte_major_greater() {
        for (newer, older) in [("2.0.0", "1.0.0"), ("3.0.0", "2.9.99")] {
            assert!(version_gte(newer, older), "{newer} should be >= {older}");
        }
    }

    #[test]
    fn test_version_gte_minor_greater() {
        for (newer, older) in [("1.2.0", "1.1.0"), ("1.10.0", "1.9.0")] {
            assert!(version_gte(newer, older), "{newer} should be >= {older}");
        }
    }

    #[test]
    fn test_version_gte_patch_greater() {
        for (newer, older) in [("1.0.2", "1.0.1"), ("2.0.56", "2.0.55")] {
            assert!(version_gte(newer, older), "{newer} should be >= {older}");
        }
    }

    #[test]
    fn test_version_gte_less_than() {
        for (older, newer) in [("1.0.0", "2.0.0"), ("2.0.55", "2.0.56"), ("1.9.0", "2.0.0")] {
            assert!(!version_gte(older, newer), "{older} should be < {newer}");
        }
    }

    #[test]
    fn test_jsonrpc_request_creation() {
        let params = json!({"key": "value"});
        let request = JsonRpcRequest::new(Value::from(1), "test_method", Some(params));
        assert_eq!(request.jsonrpc, "2.0");
        assert_eq!(request.method, "test_method");
        assert!(request.id.is_some());
    }

    #[test]
    fn test_jsonrpc_notification() {
        let notification = JsonRpcRequest::notification("test_notif", None);
        assert!(notification.is_notification());
        assert!(notification.id.is_none());
    }

    #[test]
    fn test_jsonrpc_response_success() {
        let response = JsonRpcResponse::success(Value::from(1), json!({"status": "ok"}));
        assert!(response.result.is_some());
        assert!(response.error.is_none());
    }

    #[test]
    fn test_jsonrpc_response_error() {
        let response = JsonRpcResponse::error(
            Value::from(1),
            JsonRpcError::method_not_found("test method not found"),
        );
        assert!(response.result.is_none());
        let error = response
            .error
            .expect("error response should carry an error object");
        assert_eq!(error.code, -32601);
    }
}