1use async_trait::async_trait;
32use serde::{Deserialize, Serialize};
33use serde_json::Value;
34
35use crate::error::McpResult;
36
37#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
39pub struct SamplingMessage {
40 pub role: String,
45 pub content: SamplingContent,
48}
49
50#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
53#[serde(tag = "type", rename_all = "lowercase")]
54#[non_exhaustive]
55pub enum SamplingContent {
56 Text {
58 text: String,
60 },
61 Image {
63 data: String,
65 #[serde(rename = "mimeType")]
67 mime_type: String,
68 },
69 Audio {
71 data: String,
73 #[serde(rename = "mimeType")]
75 mime_type: String,
76 },
77}
78
79#[derive(Clone, Debug, Default, PartialEq, Deserialize, Serialize)]
83pub struct ModelPreferences {
84 #[serde(default, skip_serializing_if = "Vec::is_empty")]
89 pub hints: Vec<ModelHint>,
90 #[serde(
93 default,
94 skip_serializing_if = "Option::is_none",
95 rename = "costPriority"
96 )]
97 pub cost_priority: Option<f64>,
98 #[serde(
101 default,
102 skip_serializing_if = "Option::is_none",
103 rename = "speedPriority"
104 )]
105 pub speed_priority: Option<f64>,
106 #[serde(
109 default,
110 skip_serializing_if = "Option::is_none",
111 rename = "intelligencePriority"
112 )]
113 pub intelligence_priority: Option<f64>,
114}
115
116#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)]
118pub struct ModelHint {
119 pub name: String,
121}
122
123#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Deserialize, Serialize)]
128#[serde(rename_all = "camelCase")]
129#[non_exhaustive]
130pub enum IncludeContext {
131 #[default]
133 None,
134 ThisServer,
136 AllServers,
138}
139
140#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
143pub struct SamplingRequest {
144 pub messages: Vec<SamplingMessage>,
146 #[serde(
148 default,
149 skip_serializing_if = "Option::is_none",
150 rename = "modelPreferences"
151 )]
152 pub model_preferences: Option<ModelPreferences>,
153 #[serde(
155 default,
156 skip_serializing_if = "Option::is_none",
157 rename = "systemPrompt"
158 )]
159 pub system_prompt: Option<String>,
160 #[serde(
162 default,
163 skip_serializing_if = "Option::is_none",
164 rename = "includeContext"
165 )]
166 pub include_context: Option<IncludeContext>,
167 #[serde(default, skip_serializing_if = "Option::is_none")]
170 pub temperature: Option<f64>,
171 #[serde(default, skip_serializing_if = "Option::is_none", rename = "maxTokens")]
175 pub max_tokens: Option<u32>,
176 #[serde(
178 default,
179 skip_serializing_if = "Vec::is_empty",
180 rename = "stopSequences"
181 )]
182 pub stop_sequences: Vec<String>,
183 #[serde(default, skip_serializing_if = "Option::is_none")]
186 pub metadata: Option<Value>,
187}
188
189#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
192pub struct SamplingResponse {
193 pub model: String,
197 #[serde(rename = "stopReason")]
201 pub stop_reason: String,
202 pub role: String,
204 pub content: SamplingContent,
206}
207
208#[async_trait]
219pub trait SamplingProvider: Send + Sync + 'static + std::fmt::Debug {
220 async fn sample(&self, request: SamplingRequest) -> McpResult<SamplingResponse>;
223}
224
225#[derive(Clone, Debug)]
231pub struct StaticSamplingProvider {
232 response: SamplingResponse,
233}
234
235impl StaticSamplingProvider {
236 #[must_use]
238 pub const fn new(response: SamplingResponse) -> Self {
239 Self { response }
240 }
241
242 #[must_use]
244 pub fn text(model: impl Into<String>, text: impl Into<String>) -> Self {
245 Self {
246 response: SamplingResponse {
247 model: model.into(),
248 stop_reason: "endTurn".into(),
249 role: "assistant".into(),
250 content: SamplingContent::Text { text: text.into() },
251 },
252 }
253 }
254}
255
256#[async_trait]
257impl SamplingProvider for StaticSamplingProvider {
258 async fn sample(&self, _request: SamplingRequest) -> McpResult<SamplingResponse> {
259 Ok(self.response.clone())
260 }
261}
262
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Serializes `value` into a JSON tree, panicking if serde rejects it.
    fn to_json<T: serde::Serialize>(value: &T) -> serde_json::Value {
        serde_json::to_value(value).unwrap()
    }

    /// Shorthand for a text content block.
    fn text_content(body: &str) -> SamplingContent {
        SamplingContent::Text {
            text: body.to_owned(),
        }
    }

    #[test]
    fn text_content_serializes_with_type_tag() {
        let expected = json!({"type": "text", "text": "hello"});
        assert_eq!(to_json(&text_content("hello")), expected);
    }

    #[test]
    fn image_content_serializes_with_mime_type() {
        let image = SamplingContent::Image {
            data: String::from("AAAA"),
            mime_type: String::from("image/png"),
        };
        let expected = json!({"type": "image", "data": "AAAA", "mimeType": "image/png"});
        assert_eq!(to_json(&image), expected);
    }

    #[test]
    fn request_deserializes_from_wire_shape_with_optional_fields() {
        let wire = json!({
            "messages": [
                {"role": "user", "content": {"type": "text", "text": "hi"}}
            ],
            "modelPreferences": {
                "hints": [{"name": "claude-3-sonnet"}],
                "intelligencePriority": 0.9
            },
            "systemPrompt": "be concise",
            "includeContext": "thisServer",
            "temperature": 0.7,
            "maxTokens": 256
        });
        let request: SamplingRequest = serde_json::from_value(wire).unwrap();

        assert_eq!(request.messages.len(), 1);
        assert_eq!(request.messages[0].content, text_content("hi"));

        let prefs = request.model_preferences.as_ref().unwrap();
        assert_eq!(prefs.hints[0].name, "claude-3-sonnet");
        assert_eq!(prefs.intelligence_priority, Some(0.9));

        assert_eq!(request.system_prompt.as_deref(), Some("be concise"));
        assert_eq!(request.include_context, Some(IncludeContext::ThisServer));
        assert_eq!(request.max_tokens, Some(256));
    }

    #[test]
    fn request_deserializes_minimal_messages_only() {
        let wire = json!({
            "messages": [{"role": "user", "content": {"type": "text", "text": "x"}}]
        });
        let request: SamplingRequest = serde_json::from_value(wire).unwrap();

        assert_eq!(request.model_preferences, None);
        assert_eq!(request.system_prompt, None);
        assert_eq!(request.include_context, None);
        assert_eq!(request.temperature, None);
        assert_eq!(request.max_tokens, None);
        assert!(request.stop_sequences.is_empty());
    }

    #[test]
    fn response_serializes_with_stop_reason_camel_case() {
        let response = SamplingResponse {
            model: String::from("claude-3"),
            stop_reason: String::from("endTurn"),
            role: String::from("assistant"),
            content: text_content("done"),
        };
        let wire = to_json(&response);

        assert_eq!(wire["model"], "claude-3");
        assert_eq!(wire["stopReason"], "endTurn");
        assert_eq!(wire["content"]["type"], "text");
    }

    #[tokio::test]
    async fn static_text_provider_returns_configured_response() {
        let provider = StaticSamplingProvider::text("claude-3", "ack");
        let request = SamplingRequest {
            messages: Vec::new(),
            model_preferences: None,
            system_prompt: None,
            include_context: None,
            temperature: None,
            max_tokens: None,
            stop_sequences: Vec::new(),
            metadata: None,
        };

        let response = provider.sample(request).await.unwrap();

        // Compare the whole struct: model, stop reason, role and content.
        let expected = SamplingResponse {
            model: String::from("claude-3"),
            stop_reason: String::from("endTurn"),
            role: String::from("assistant"),
            content: text_content("ack"),
        };
        assert_eq!(response, expected);
    }
}