Skip to main content

entelix_mcp/
sampling.rs

1//! `SamplingProvider` + request/response shapes — client-side
2//! answer to the server-initiated `sampling/createMessage`
3//! request (MCP 2024-11-05 §"Sampling").
4//!
5//! Sampling lets an MCP server ask the client to run an LLM
6//! completion on its behalf — typically when the server needs
7//! reasoning capability it doesn't own (e.g., a server
8//! orchestrating tool dispatch wants the agent's LLM to choose
9//! the next tool). The server provides the conversation
10//! prefix, optional sampling parameters, and gets back a
11//! finalized assistant message.
12//!
13//! ## Why a trait, not a `ChatModel` adapter shipped here
14//!
15//! A "wire `ChatModel` directly" adapter would force this
16//! crate to depend on the `ChatModel` surface. Instead the
17//! trait stays minimal and operators write a 20-line wrapper
18//! that converts MCP messages → `entelix_core::ir::Message` →
19//! `ChatModel::invoke` → MCP response. The conversion is
20//! deployment-specific (which model, which prompt envelope,
21//! which IR translation) and doesn't generalise cleanly into
22//! the trait surface.
23//!
24//! ## No `ExecutionContext` parameter
25//!
26//! Mirrors [`crate::RootsProvider`] and
27//! [`crate::ElicitationProvider`]: server-initiated requests
28//! arrive on a background SSE listener, not in the middle of
29//! a client-driven call. No honest context to thread.
30
31use async_trait::async_trait;
32use serde::{Deserialize, Serialize};
33use serde_json::Value;
34
35use crate::error::McpResult;
36
37/// One conversation message in a sampling request.
38#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
39pub struct SamplingMessage {
40    /// Speaker role — either `user` or `assistant` per MCP
41    /// spec. The variants are stringly-typed at the wire level
42    /// because MCP doesn't enumerate them — the agent is free
43    /// to pass through whatever role the server requested.
44    pub role: String,
45    /// Message body. Text is by far the common case; image /
46    /// audio variants exist for multimodal servers.
47    pub content: SamplingContent,
48}
49
50/// Body of one [`SamplingMessage`]. Tagged by `type` field on
51/// the wire — matches the MCP spec's content-block shape.
52#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
53#[serde(tag = "type", rename_all = "lowercase")]
54#[non_exhaustive]
55pub enum SamplingContent {
56    /// Plain text content.
57    Text {
58        /// UTF-8 text body.
59        text: String,
60    },
61    /// Image content (base64-encoded data + MIME type).
62    Image {
63        /// Base64-encoded image bytes.
64        data: String,
65        /// MIME type (e.g., `image/png`, `image/jpeg`).
66        #[serde(rename = "mimeType")]
67        mime_type: String,
68    },
69    /// Audio content (base64-encoded data + MIME type).
70    Audio {
71        /// Base64-encoded audio bytes.
72        data: String,
73        /// MIME type (e.g., `audio/wav`).
74        #[serde(rename = "mimeType")]
75        mime_type: String,
76    },
77}
78
79/// Operator hints + priorities the server passes to bias
80/// model selection. All fields optional — the provider
81/// chooses how to honour them (or ignores them entirely).
82#[derive(Clone, Debug, Default, PartialEq, Deserialize, Serialize)]
83pub struct ModelPreferences {
84    /// Suggested model names, best-fit-first. Provider may
85    /// match against any of them; spec encourages substring
86    /// matching (e.g., hint `"claude-3-sonnet"` matches
87    /// `"claude-3-sonnet-20240229"`).
88    #[serde(default, skip_serializing_if = "Vec::is_empty")]
89    pub hints: Vec<ModelHint>,
90    /// Cost-vs-quality preference in `[0.0, 1.0]`. Higher =
91    /// prefer cheaper models.
92    #[serde(
93        default,
94        skip_serializing_if = "Option::is_none",
95        rename = "costPriority"
96    )]
97    pub cost_priority: Option<f64>,
98    /// Speed-vs-quality preference in `[0.0, 1.0]`. Higher =
99    /// prefer faster models.
100    #[serde(
101        default,
102        skip_serializing_if = "Option::is_none",
103        rename = "speedPriority"
104    )]
105    pub speed_priority: Option<f64>,
106    /// Intelligence-vs-cost preference in `[0.0, 1.0]`. Higher
107    /// = prefer more capable models.
108    #[serde(
109        default,
110        skip_serializing_if = "Option::is_none",
111        rename = "intelligencePriority"
112    )]
113    pub intelligence_priority: Option<f64>,
114}
115
116/// One model hint (a name string the server suggests).
117#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)]
118pub struct ModelHint {
119    /// Hint string (e.g., `"claude-3-sonnet"`).
120    pub name: String,
121}
122
123/// How much surrounding context the server wants the client
124/// to include in the sampling call. The client decides how to
125/// honour this (the spec is intentionally vague — it's a
126/// hint, not a contract).
127#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Deserialize, Serialize)]
128#[serde(rename_all = "camelCase")]
129#[non_exhaustive]
130pub enum IncludeContext {
131    /// No context outside the supplied messages.
132    #[default]
133    None,
134    /// Include context from this MCP server only.
135    ThisServer,
136    /// Include context from every MCP server the client knows.
137    AllServers,
138}
139
140/// Sampling request as it arrives from the server. Mirrors
141/// the spec's `sampling/createMessage` params block.
142#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
143pub struct SamplingRequest {
144    /// Conversation prefix the server wants completed.
145    pub messages: Vec<SamplingMessage>,
146    /// Optional model selection hints + priorities.
147    #[serde(
148        default,
149        skip_serializing_if = "Option::is_none",
150        rename = "modelPreferences"
151    )]
152    pub model_preferences: Option<ModelPreferences>,
153    /// Optional system prompt to prepend.
154    #[serde(
155        default,
156        skip_serializing_if = "Option::is_none",
157        rename = "systemPrompt"
158    )]
159    pub system_prompt: Option<String>,
160    /// Whether to include surrounding context.
161    #[serde(
162        default,
163        skip_serializing_if = "Option::is_none",
164        rename = "includeContext"
165    )]
166    pub include_context: Option<IncludeContext>,
167    /// Sampling temperature (vendor-defined range — typically
168    /// `[0.0, 1.0]` or `[0.0, 2.0]`).
169    #[serde(default, skip_serializing_if = "Option::is_none")]
170    pub temperature: Option<f64>,
171    /// Token cap for the completion. Servers SHOULD set this;
172    /// providers that pass through to a vendor demanding the
173    /// field (Anthropic) reject the request when missing.
174    #[serde(default, skip_serializing_if = "Option::is_none", rename = "maxTokens")]
175    pub max_tokens: Option<u32>,
176    /// Stop sequences (model halts as soon as one is generated).
177    #[serde(
178        default,
179        skip_serializing_if = "Vec::is_empty",
180        rename = "stopSequences"
181    )]
182    pub stop_sequences: Vec<String>,
183    /// Vendor-opaque metadata the server attached; passed
184    /// through to the provider verbatim.
185    #[serde(default, skip_serializing_if = "Option::is_none")]
186    pub metadata: Option<Value>,
187}
188
189/// Sampling response the client sends back. Mirrors the spec's
190/// `sampling/createMessage` result block.
191#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
192pub struct SamplingResponse {
193    /// Model identifier the provider used (e.g.,
194    /// `"claude-3-sonnet-20240229"`). Surfaced for the
195    /// server's audit / cost accounting.
196    pub model: String,
197    /// Why generation stopped. MCP spec uses string tokens
198    /// (`endTurn`, `stopSequence`, `maxTokens`); the field
199    /// stays stringly-typed to mirror the wire shape.
200    #[serde(rename = "stopReason")]
201    pub stop_reason: String,
202    /// Role of the produced message — typically `assistant`.
203    pub role: String,
204    /// Generated content. Same shape as request messages.
205    pub content: SamplingContent,
206}
207
208/// Async source-of-truth for sampling completions. Mirrors
209/// the `*Provider` taxonomy — async, single-purpose,
210/// replaceable.
211///
212/// Operators wire one provider per server through
213/// [`crate::McpServerConfig::with_sampling_provider`]. Most
214/// production providers wrap a `ChatModel` from
215/// `entelix_core` — convert MCP messages to IR, dispatch, map
216/// the response back. The trait stays minimal so the
217/// conversion choices stay operator-side.
218#[async_trait]
219pub trait SamplingProvider: Send + Sync + 'static + std::fmt::Debug {
220    /// Resolve one server-initiated sampling request — call
221    /// the underlying LLM (or stub it) and return the result.
222    async fn sample(&self, request: SamplingRequest) -> McpResult<SamplingResponse>;
223}
224
225/// In-memory [`SamplingProvider`] returning a fixed response.
226///
227/// Useful for tests and for deployments that want a
228/// deterministic stub (e.g., during local development before
229/// a real LLM is wired).
230#[derive(Clone, Debug)]
231pub struct StaticSamplingProvider {
232    response: SamplingResponse,
233}
234
235impl StaticSamplingProvider {
236    /// Wrap a fixed response.
237    #[must_use]
238    pub const fn new(response: SamplingResponse) -> Self {
239        Self { response }
240    }
241
242    /// Convenience: text-only response with `endTurn` stop reason.
243    #[must_use]
244    pub fn text(model: impl Into<String>, text: impl Into<String>) -> Self {
245        Self {
246            response: SamplingResponse {
247                model: model.into(),
248                stop_reason: "endTurn".into(),
249                role: "assistant".into(),
250                content: SamplingContent::Text { text: text.into() },
251            },
252        }
253    }
254}
255
256#[async_trait]
257impl SamplingProvider for StaticSamplingProvider {
258    async fn sample(&self, _request: SamplingRequest) -> McpResult<SamplingResponse> {
259        Ok(self.response.clone())
260    }
261}
262
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Text content carries the `type: "text"` discriminator.
    #[test]
    fn text_content_serializes_with_type_tag() {
        let content = SamplingContent::Text {
            text: String::from("hello"),
        };
        let wire = serde_json::to_value(&content).unwrap();
        let expected = json!({"type": "text", "text": "hello"});
        assert_eq!(wire, expected);
    }

    /// Image content uses the camelCase `mimeType` key.
    #[test]
    fn image_content_serializes_with_mime_type() {
        let content = SamplingContent::Image {
            data: String::from("AAAA"),
            mime_type: String::from("image/png"),
        };
        let wire = serde_json::to_value(&content).unwrap();
        let expected = json!({"type": "image", "data": "AAAA", "mimeType": "image/png"});
        assert_eq!(wire, expected);
    }

    /// A request using the optional camelCase fields parses
    /// into the matching struct fields.
    #[test]
    fn request_deserializes_from_wire_shape_with_optional_fields() {
        let raw = json!({
            "messages": [
                {"role": "user", "content": {"type": "text", "text": "hi"}}
            ],
            "modelPreferences": {
                "hints": [{"name": "claude-3-sonnet"}],
                "intelligencePriority": 0.9
            },
            "systemPrompt": "be concise",
            "includeContext": "thisServer",
            "temperature": 0.7,
            "maxTokens": 256
        });
        let request = serde_json::from_value::<SamplingRequest>(raw).unwrap();

        assert_eq!(request.messages.len(), 1);
        let expected_body = SamplingContent::Text { text: "hi".into() };
        assert_eq!(request.messages[0].content, expected_body);

        let prefs = request.model_preferences.as_ref().unwrap();
        assert_eq!(prefs.hints[0].name, "claude-3-sonnet");
        assert_eq!(prefs.intelligence_priority, Some(0.9));

        assert_eq!(request.system_prompt.as_deref(), Some("be concise"));
        assert_eq!(request.include_context, Some(IncludeContext::ThisServer));
        assert_eq!(request.max_tokens, Some(256));
    }

    /// Only `messages` is required; everything else defaults.
    #[test]
    fn request_deserializes_minimal_messages_only() {
        let raw = json!({
            "messages": [{"role": "user", "content": {"type": "text", "text": "x"}}]
        });
        let request = serde_json::from_value::<SamplingRequest>(raw).unwrap();

        assert!(request.model_preferences.is_none());
        assert!(request.system_prompt.is_none());
        assert!(request.include_context.is_none());
        assert!(request.temperature.is_none());
        assert!(request.max_tokens.is_none());
        assert!(request.stop_sequences.is_empty());
    }

    /// The response emits `stopReason` (camelCase) on the wire.
    #[test]
    fn response_serializes_with_stop_reason_camel_case() {
        let response = SamplingResponse {
            model: String::from("claude-3"),
            stop_reason: String::from("endTurn"),
            role: String::from("assistant"),
            content: SamplingContent::Text {
                text: String::from("done"),
            },
        };
        let wire = serde_json::to_value(&response).unwrap();

        assert_eq!(wire["model"], "claude-3");
        assert_eq!(wire["stopReason"], "endTurn");
        assert_eq!(wire["content"]["type"], "text");
    }

    /// `StaticSamplingProvider::text` answers any request with
    /// the configured text reply.
    #[tokio::test]
    async fn static_text_provider_returns_configured_response() {
        let provider = StaticSamplingProvider::text("claude-3", "ack");
        let empty_request = SamplingRequest {
            messages: Vec::new(),
            model_preferences: None,
            system_prompt: None,
            include_context: None,
            temperature: None,
            max_tokens: None,
            stop_sequences: Vec::new(),
            metadata: None,
        };

        let reply = provider.sample(empty_request).await.unwrap();

        assert_eq!(reply.model, "claude-3");
        assert_eq!(reply.stop_reason, "endTurn");
        assert_eq!(reply.role, "assistant");
        let expected_body = SamplingContent::Text { text: "ack".into() };
        assert_eq!(reply.content, expected_body);
    }
}