// mcp_probe_core/messages/sampling.rs

//! Sampling-related message types for MCP LLM completion requests.
//!
//! This module provides types for:
//! - Server-to-client LLM completion requests
//! - Completion parameters (temperature, max tokens, etc.)
//! - Completion responses with generated content
//! - Model selection and configuration
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::collections::HashMap;
12
/// Request from server to client for LLM completion.
///
/// This allows MCP servers to request LLM completions from the client,
/// enabling servers to leverage the client's LLM capabilities.
///
/// NOTE(review): the MCP spec models sampling as `sampling/createMessage`
/// with the parameters at the top level of `params`; confirm that wrapping
/// them in a single `argument` field matches the peer's expected wire shape.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CompleteRequest {
    /// The completion argument carrying messages and sampling parameters.
    pub argument: CompletionArgument,
}
22
23/// Arguments for a completion request.
24#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
25pub struct CompletionArgument {
26    /// Messages for the completion
27    pub messages: Vec<SamplingMessage>,
28
29    /// Optional model selection
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub model_preferences: Option<ModelPreferences>,
32
33    /// System prompt for the completion
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub system_prompt: Option<String>,
36
37    /// Include context about tools available to the model
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub include_context: Option<String>,
40
41    /// Temperature for sampling (0.0 to 1.0)
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub temperature: Option<f64>,
44
45    /// Maximum number of tokens to generate
46    #[serde(skip_serializing_if = "Option::is_none")]
47    pub max_tokens: Option<i32>,
48
49    /// Stop sequences for completion
50    #[serde(skip_serializing_if = "Option::is_none")]
51    pub stop_sequences: Option<Vec<String>>,
52
53    /// Additional metadata for the request
54    #[serde(flatten)]
55    pub metadata: HashMap<String, Value>,
56}
57
58impl CompletionArgument {
59    /// Create a new completion argument with messages.
60    pub fn new(messages: Vec<SamplingMessage>) -> Self {
61        Self {
62            messages,
63            model_preferences: None,
64            system_prompt: None,
65            include_context: None,
66            temperature: None,
67            max_tokens: None,
68            stop_sequences: None,
69            metadata: HashMap::new(),
70        }
71    }
72
73    /// Set model preferences.
74    pub fn with_model_preferences(mut self, preferences: ModelPreferences) -> Self {
75        self.model_preferences = Some(preferences);
76        self
77    }
78
79    /// Set system prompt.
80    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
81        self.system_prompt = Some(prompt.into());
82        self
83    }
84
85    /// Set temperature.
86    pub fn with_temperature(mut self, temperature: f64) -> Self {
87        self.temperature = Some(temperature);
88        self
89    }
90
91    /// Set maximum tokens.
92    pub fn with_max_tokens(mut self, max_tokens: i32) -> Self {
93        self.max_tokens = Some(max_tokens);
94        self
95    }
96
97    /// Add stop sequences.
98    pub fn with_stop_sequences(mut self, sequences: Vec<String>) -> Self {
99        self.stop_sequences = Some(sequences);
100        self
101    }
102
103    /// Add metadata.
104    pub fn with_metadata(mut self, key: impl Into<String>, value: Value) -> Self {
105        self.metadata.insert(key.into(), value);
106        self
107    }
108}
109
110/// Model preferences for completion requests.
111#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
112pub struct ModelPreferences {
113    /// Preferred model names in order of preference
114    #[serde(skip_serializing_if = "Option::is_none")]
115    pub models: Option<Vec<String>>,
116
117    /// Minimum cost tier acceptable
118    #[serde(skip_serializing_if = "Option::is_none")]
119    pub cost_priority: Option<CostPriority>,
120
121    /// Minimum speed tier acceptable
122    #[serde(skip_serializing_if = "Option::is_none")]
123    pub speed_priority: Option<SpeedPriority>,
124
125    /// Minimum intelligence tier acceptable
126    #[serde(skip_serializing_if = "Option::is_none")]
127    pub intelligence_priority: Option<IntelligencePriority>,
128}
129
130impl ModelPreferences {
131    /// Create new model preferences.
132    pub fn new() -> Self {
133        Self {
134            models: None,
135            cost_priority: None,
136            speed_priority: None,
137            intelligence_priority: None,
138        }
139    }
140
141    /// Set preferred models.
142    pub fn with_models(mut self, models: Vec<String>) -> Self {
143        self.models = Some(models);
144        self
145    }
146
147    /// Set cost priority.
148    pub fn with_cost_priority(mut self, priority: CostPriority) -> Self {
149        self.cost_priority = Some(priority);
150        self
151    }
152
153    /// Set speed priority.
154    pub fn with_speed_priority(mut self, priority: SpeedPriority) -> Self {
155        self.speed_priority = Some(priority);
156        self
157    }
158
159    /// Set intelligence priority.
160    pub fn with_intelligence_priority(mut self, priority: IntelligencePriority) -> Self {
161        self.intelligence_priority = Some(priority);
162        self
163    }
164}
165
166impl Default for ModelPreferences {
167    fn default() -> Self {
168        Self::new()
169    }
170}
171
/// Cost priority levels for model selection.
///
/// Serializes to lowercase strings: `"low"`, `"medium"`, `"high"`.
///
/// NOTE(review): the MCP spec expresses `costPriority` as a 0.0–1.0 number
/// rather than a tier enum — confirm the peer accepts string tiers.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum CostPriority {
    /// Lowest cost models
    Low,
    /// Medium cost models
    Medium,
    /// High cost models acceptable
    High,
}
183
/// Speed priority levels for model selection.
///
/// Serializes to lowercase strings: `"low"`, `"medium"`, `"high"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SpeedPriority {
    /// Slowest acceptable speed
    Low,
    /// Medium speed required
    Medium,
    /// High speed required
    High,
}
195
/// Intelligence priority levels for model selection.
///
/// Serializes to lowercase strings: `"low"`, `"medium"`, `"high"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum IntelligencePriority {
    /// Basic intelligence level
    Low,
    /// Medium intelligence level
    Medium,
    /// High intelligence level required
    High,
}
207
/// A message in a sampling request.
///
/// A role (`system` / `user` / `assistant`) paired with text or image
/// content; see the convenience constructors on the `impl` block.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SamplingMessage {
    /// Role of the message
    pub role: MessageRole,

    /// Content of the message
    pub content: SamplingContent,
}
217
218impl SamplingMessage {
219    /// Create a new sampling message.
220    pub fn new(role: MessageRole, content: SamplingContent) -> Self {
221        Self { role, content }
222    }
223
224    /// Create a system message.
225    pub fn system(content: impl Into<String>) -> Self {
226        Self::new(MessageRole::System, SamplingContent::text(content))
227    }
228
229    /// Create a user message.
230    pub fn user(content: impl Into<String>) -> Self {
231        Self::new(MessageRole::User, SamplingContent::text(content))
232    }
233
234    /// Create an assistant message.
235    pub fn assistant(content: impl Into<String>) -> Self {
236        Self::new(MessageRole::Assistant, SamplingContent::text(content))
237    }
238}
239
/// Role of a message in sampling.
///
/// Serializes to lowercase strings: `"system"`, `"user"`, `"assistant"`.
///
/// NOTE(review): MCP `SamplingMessage.role` only admits `user`/`assistant`
/// (the system prompt travels separately as `systemPrompt`) — confirm peers
/// tolerate a `system` role here.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    /// System message
    System,
    /// User message
    User,
    /// Assistant message
    Assistant,
}
251
/// Content of a sampling message.
///
/// Internally tagged on `"type"`: serializes as
/// `{"type": "text", "text": …}` or
/// `{"type": "image", "data": …, "mimeType": …}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SamplingContent {
    /// Text content
    #[serde(rename = "text")]
    Text {
        /// The text content
        text: String,
    },

    /// Image content
    #[serde(rename = "image")]
    Image {
        /// Image data (base64 or URL)
        data: String,

        /// MIME type of the image; serialized as `mimeType` per MCP
        #[serde(rename = "mimeType")]
        mime_type: String,
    },
}
274
275impl SamplingContent {
276    /// Create text content.
277    pub fn text(text: impl Into<String>) -> Self {
278        Self::Text { text: text.into() }
279    }
280
281    /// Create image content.
282    pub fn image(data: impl Into<String>, mime_type: impl Into<String>) -> Self {
283        Self::Image {
284            data: data.into(),
285            mime_type: mime_type.into(),
286        }
287    }
288}
289
290/// Response to a completion request.
291#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
292pub struct CompleteResponse {
293    /// Completion result
294    pub completion: CompletionResult,
295
296    /// Model used for the completion
297    #[serde(skip_serializing_if = "Option::is_none")]
298    pub model: Option<String>,
299
300    /// Stop reason for the completion
301    #[serde(skip_serializing_if = "Option::is_none")]
302    pub stop_reason: Option<StopReason>,
303}
304
/// Result of a completion request.
///
/// Internally tagged on `"type"`; currently only a text variant exists, so
/// values serialize as `{"type": "text", "text": …}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum CompletionResult {
    /// Text completion result
    #[serde(rename = "text")]
    Text {
        /// Generated text
        text: String,
    },
}
316
317impl CompletionResult {
318    /// Create a text completion result.
319    pub fn text(text: impl Into<String>) -> Self {
320        Self::Text { text: text.into() }
321    }
322}
323
/// Reason why completion stopped.
///
/// Serializes to snake_case strings (`"end_turn"`, `"max_tokens"`,
/// `"stop_sequence"`, `"tool_use"`), as pinned by the unit tests below.
///
/// NOTE(review): the MCP spec's example `stopReason` values are camelCase
/// (`"endTurn"`, `"maxTokens"`, `"stopSequence"`) — confirm which spelling
/// peers emit before changing the `rename_all` policy (the tests would need
/// updating in lockstep).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// Reached end of sequence naturally
    EndTurn,
    /// Hit maximum token limit
    MaxTokens,
    /// Encountered stop sequence
    StopSequence,
    /// Tool call was made
    ToolUse,
}
337
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builder methods should record every optional parameter.
    #[test]
    fn test_completion_argument_creation() {
        let arg = CompletionArgument::new(vec![
            SamplingMessage::system("You are a helpful assistant"),
            SamplingMessage::user("Hello, how are you?"),
        ])
        .with_temperature(0.7)
        .with_max_tokens(1000)
        .with_system_prompt("Be helpful")
        .with_metadata("priority", json!("high"));

        assert_eq!(arg.temperature, Some(0.7));
        assert_eq!(arg.max_tokens, Some(1000));
        assert_eq!(arg.system_prompt, Some("Be helpful".to_string()));
        assert_eq!(arg.metadata.get("priority"), Some(&json!("high")));
    }

    /// Each preference builder should populate its corresponding field.
    #[test]
    fn test_model_preferences() {
        let wanted = vec!["gpt-4".to_string(), "claude-3".to_string()];
        let prefs = ModelPreferences::new()
            .with_models(wanted.clone())
            .with_cost_priority(CostPriority::Medium)
            .with_speed_priority(SpeedPriority::High)
            .with_intelligence_priority(IntelligencePriority::High);

        assert_eq!(prefs.models, Some(wanted));
        assert_eq!(prefs.cost_priority, Some(CostPriority::Medium));
        assert_eq!(prefs.speed_priority, Some(SpeedPriority::High));
        assert_eq!(
            prefs.intelligence_priority,
            Some(IntelligencePriority::High)
        );
    }

    /// The role shortcuts should assign the matching `MessageRole`.
    #[test]
    fn test_sampling_message_creation() {
        assert_eq!(
            SamplingMessage::system("You are helpful").role,
            MessageRole::System
        );
        assert_eq!(SamplingMessage::user("Hello").role, MessageRole::User);
        assert_eq!(
            SamplingMessage::assistant("Hi there!").role,
            MessageRole::Assistant
        );
    }

    /// Text content serializes as an internally tagged object.
    #[test]
    fn test_sampling_content_text() {
        let serialized = serde_json::to_value(SamplingContent::text("Hello world")).unwrap();
        assert_eq!(serialized["type"], "text");
        assert_eq!(serialized["text"], "Hello world");
    }

    /// Image content serializes with the camelCase `mimeType` key.
    #[test]
    fn test_sampling_content_image() {
        let serialized =
            serde_json::to_value(SamplingContent::image("base64data", "image/png")).unwrap();
        assert_eq!(serialized["type"], "image");
        assert_eq!(serialized["data"], "base64data");
        assert_eq!(serialized["mimeType"], "image/png");
    }

    /// Completion results carry the `"text"` tag and generated text.
    #[test]
    fn test_completion_result() {
        let serialized =
            serde_json::to_value(CompletionResult::text("Generated response")).unwrap();
        assert_eq!(serialized["type"], "text");
        assert_eq!(serialized["text"], "Generated response");
    }

    /// Priority enums serialize to lowercase tier names.
    #[test]
    fn test_priority_serialization() {
        assert_eq!(
            serde_json::to_string(&CostPriority::Low).unwrap(),
            "\"low\""
        );
        assert_eq!(
            serde_json::to_string(&SpeedPriority::Medium).unwrap(),
            "\"medium\""
        );
        assert_eq!(
            serde_json::to_string(&IntelligencePriority::High).unwrap(),
            "\"high\""
        );
    }

    /// Stop reasons serialize to snake_case strings.
    #[test]
    fn test_stop_reason_serialization() {
        let cases = [
            (StopReason::EndTurn, "\"end_turn\""),
            (StopReason::MaxTokens, "\"max_tokens\""),
            (StopReason::StopSequence, "\"stop_sequence\""),
            (StopReason::ToolUse, "\"tool_use\""),
        ];

        for (reason, expected) in &cases {
            assert_eq!(serde_json::to_string(reason).unwrap(), *expected);
        }
    }
}