Skip to main content

klieo_mcp_server/
sampling.rs

1//! MCP `sampling/createMessage` wire-format types.
2//!
3//! Sampling is the server-initiated path that asks the MCP client to
4//! run an LLM completion on the server's behalf. The shape of the
5//! request/response envelope is fixed by the MCP spec revision the
6//! crate advertises ([`crate::MCP_PROTOCOL_VERSION`]); this module
7//! captures it as plain Rust types with `serde` renames mapping
8//! the snake_case Rust fields onto the camelCase JSON the wire uses.
9//!
10//! ## Text-content-only scope
11//!
12//! `SamplingContent` carries the `text` variant only. Image, audio, and
13//! resource content blocks are part of the MCP spec but intentionally
14//! out of scope here — a content blob tagged `image` / `audio` /
15//! `resource` will fail to deserialise with serde's "unknown variant"
16//! error, surfacing as an `UnsupportedContent` failure at the call
17//! site rather than silently dropping non-text payloads. See ADR-027
18//! for the rationale.
19//!
20//! ## Wire-format renames
21//!
22//! Field names on the JSON wire follow MCP's camelCase convention
23//! (`maxTokens`, `systemPrompt`, `modelPreferences`, `stopSequences`,
24//! `costPriority`, `speedPriority`, `intelligencePriority`,
25//! `stopReason`). Rust fields use snake_case and rely on
26//! `#[serde(rename = "...")]` to bridge the two. Optional fields use
27//! `skip_serializing_if = "Option::is_none"` so the serialised
28//! payload omits absent values rather than emitting `null`s the spec
29//! does not require.
30
31use serde::{Deserialize, Serialize};
32
33/// One `sampling/createMessage` request envelope sent from a klieo MCP
34/// server to its client. Carries the message history plus optional
35/// model preferences, a system prompt, sampling controls, and an upper
36/// bound on the number of tokens the client may generate.
37#[non_exhaustive]
38#[derive(Clone, Debug, Serialize, Deserialize)]
39pub struct SamplingRequest {
40    /// Conversation history the client should complete. Ordered
41    /// oldest-first; the client's reply continues from the last entry.
42    pub messages: Vec<SamplingMessage>,
43
44    /// Optional hints to the client about which model family / cost
45    /// / latency / intelligence profile to use. Clients MAY ignore
46    /// these and pick their own model.
47    #[serde(rename = "modelPreferences", skip_serializing_if = "Option::is_none")]
48    pub model_preferences: Option<ModelPreferences>,
49
50    /// Optional system prompt to prepend to the conversation. When
51    /// `None`, the client uses its own default (or no system prompt).
52    #[serde(rename = "systemPrompt", skip_serializing_if = "Option::is_none")]
53    pub system_prompt: Option<String>,
54
55    /// Upper bound on the number of tokens the client may generate.
56    /// Required by the MCP spec; clients MUST honour it.
57    #[serde(rename = "maxTokens")]
58    pub max_tokens: u32,
59
60    /// Optional sampling temperature. Range and default are
61    /// client-defined; common practice is 0.0–2.0 with 1.0 as default.
62    #[serde(skip_serializing_if = "Option::is_none")]
63    pub temperature: Option<f32>,
64
65    /// Optional list of stop sequences. Generation halts as soon as
66    /// the model emits any of these strings.
67    #[serde(rename = "stopSequences", skip_serializing_if = "Option::is_none")]
68    pub stop_sequences: Option<Vec<String>>,
69}
70
71impl SamplingRequest {
72    /// Constructs a new [`SamplingRequest`] with required fields.
73    /// Optional fields default to `None`.
74    pub fn new(messages: Vec<SamplingMessage>, max_tokens: u32) -> Self {
75        Self {
76            messages,
77            model_preferences: None,
78            system_prompt: None,
79            max_tokens,
80            temperature: None,
81            stop_sequences: None,
82        }
83    }
84}
85
86/// One message in a [`SamplingRequest::messages`] history. The MCP
87/// spec defines `role` as the free-form string `"user"` or
88/// `"assistant"`; richer typing is intentionally deferred to the
89/// transport boundary.
90#[non_exhaustive]
91#[derive(Clone, Debug, Serialize, Deserialize)]
92pub struct SamplingMessage {
93    /// Speaker role. Wire values: `"user"` or `"assistant"`.
94    pub role: String,
95
96    /// Message payload. Only [`SamplingContent::Text`] is supported
97    /// in this cluster; see the module-level note on unsupported
98    /// content variants.
99    pub content: SamplingContent,
100}
101
102impl SamplingMessage {
103    /// Constructs a new [`SamplingMessage`] with the given role and content.
104    ///
105    /// Required for external callers after `#[non_exhaustive]` prevents
106    /// struct-literal construction outside the crate.
107    pub fn new(role: impl Into<String>, content: SamplingContent) -> Self {
108        Self {
109            role: role.into(),
110            content,
111        }
112    }
113}
114
115/// Discriminated content payload for a [`SamplingMessage`] or
116/// [`SamplingResponse`]. Serialised with `{"type":"<variant>", ...}`
117/// per the MCP wire format.
118///
119/// Only the `text` variant is implemented in this cluster. An
120/// incoming blob tagged `"image"`, `"audio"`, or `"resource"` will
121/// fail to deserialise with serde's "unknown variant" error, which
122/// the call site surfaces as an unsupported-content failure. See
123/// ADR-027 for the scope decision.
124#[derive(Clone, Debug, Serialize, Deserialize)]
125#[serde(tag = "type", rename_all = "lowercase")]
126#[non_exhaustive]
127pub enum SamplingContent {
128    /// Plain UTF-8 text payload.
129    Text {
130        /// The text content of the message.
131        text: String,
132    },
133}
134
135/// Optional hints from the server to the client about which model to
136/// use for a [`SamplingRequest`]. All fields are advisory — the
137/// client may pick any model it has access to and ignore these.
138///
139/// Priority fields are normalised to the range `0.0..=1.0`; the
140/// client weighs them against its own catalogue when choosing a model.
141#[non_exhaustive]
142#[derive(Clone, Debug, Default, Serialize, Deserialize)]
143pub struct ModelPreferences {
144    /// Optional ordered list of model-name hints. The first hint
145    /// whose name the client recognises wins.
146    #[serde(skip_serializing_if = "Option::is_none")]
147    pub hints: Option<Vec<ModelHint>>,
148
149    /// Optional cost-priority weight (`0.0` = cost-insensitive, `1.0`
150    /// = cheapest model preferred).
151    #[serde(rename = "costPriority", skip_serializing_if = "Option::is_none")]
152    pub cost_priority: Option<f32>,
153
154    /// Optional speed-priority weight (`0.0` = latency-insensitive,
155    /// `1.0` = fastest model preferred).
156    #[serde(rename = "speedPriority", skip_serializing_if = "Option::is_none")]
157    pub speed_priority: Option<f32>,
158
159    /// Optional intelligence-priority weight (`0.0` = simplest model
160    /// acceptable, `1.0` = most capable model preferred).
161    #[serde(
162        rename = "intelligencePriority",
163        skip_serializing_if = "Option::is_none"
164    )]
165    pub intelligence_priority: Option<f32>,
166}
167
168/// One model-name hint inside [`ModelPreferences::hints`]. The wire
169/// shape is `{"name": "..."}` per the MCP spec.
170#[non_exhaustive]
171#[derive(Clone, Debug, Serialize, Deserialize)]
172pub struct ModelHint {
173    /// Model family / name string the server is hinting at (e.g.
174    /// `"claude-3-5-sonnet"`, `"gpt-4o"`).
175    pub name: String,
176}
177
178/// The `sampling/createMessage` response envelope returned by an MCP
179/// client. Carries the generated message, the model the client
180/// actually used, and an optional `stopReason` describing why
181/// generation halted.
182#[non_exhaustive]
183#[derive(Clone, Debug, Deserialize)]
184pub struct SamplingResponse {
185    /// Speaker role of the generated message. Typically
186    /// `"assistant"`.
187    pub role: String,
188
189    /// Generated content payload.
190    pub content: SamplingContent,
191
192    /// Identifier of the model the client used to generate
193    /// `content`. May differ from any hint the server supplied.
194    pub model: String,
195
196    /// Optional reason generation stopped (e.g. `"endTurn"`,
197    /// `"maxTokens"`, `"stopSequence"`). Wire field is
198    /// camelCase `stopReason`.
199    #[serde(rename = "stopReason")]
200    pub stop_reason: Option<String>,
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a text payload in fixtures.
    fn text(body: &str) -> SamplingContent {
        SamplingContent::Text { text: body.into() }
    }

    /// A fully-populated request must expose every multi-word field under
    /// its camelCase wire name and never under the Rust snake_case name;
    /// a minimal request must omit all optional keys.
    #[test]
    fn sampling_request_serialises_camelcase_field_names() {
        let req = SamplingRequest {
            messages: vec![SamplingMessage {
                role: "user".into(),
                content: text("hello"),
            }],
            model_preferences: Some(ModelPreferences {
                hints: Some(vec![ModelHint {
                    name: "claude-3-5-sonnet".into(),
                }]),
                cost_priority: Some(0.2),
                speed_priority: Some(0.5),
                intelligence_priority: Some(0.9),
            }),
            system_prompt: Some("be concise".into()),
            max_tokens: 256,
            temperature: Some(0.7),
            stop_sequences: Some(vec!["STOP".into()]),
        };
        let json = serde_json::to_value(&req).expect("serialises");

        for key in [
            "maxTokens",
            "systemPrompt",
            "modelPreferences",
            "stopSequences",
        ] {
            assert!(json.get(key).is_some(), "{} key present", key);
        }
        for key in [
            "max_tokens",
            "system_prompt",
            "model_preferences",
            "stop_sequences",
        ] {
            assert!(json.get(key).is_none(), "no snake_case leak: {}", key);
        }

        // Spot-check values, not just key presence. Float fields are left
        // to presence checks because f32 -> JSON widening is inexact.
        assert_eq!(json["maxTokens"], 256);
        assert_eq!(json["systemPrompt"], "be concise");
        assert_eq!(json["stopSequences"][0], "STOP");
        assert_eq!(json["messages"][0]["role"], "user");
        assert_eq!(json["messages"][0]["content"]["type"], "text");
        assert_eq!(json["messages"][0]["content"]["text"], "hello");

        let prefs = &json["modelPreferences"];
        for key in ["costPriority", "speedPriority", "intelligencePriority"] {
            assert!(prefs.get(key).is_some(), "{} present", key);
        }
        for key in ["cost_priority", "speed_priority", "intelligence_priority"] {
            assert!(prefs.get(key).is_none(), "no snake_case leak: {}", key);
        }
        assert_eq!(prefs["hints"][0]["name"], "claude-3-5-sonnet");

        // Absent optional fields are omitted, not serialised as null.
        let minimal = SamplingRequest {
            messages: vec![],
            model_preferences: None,
            system_prompt: None,
            max_tokens: 16,
            temperature: None,
            stop_sequences: None,
        };
        let minimal_json = serde_json::to_value(&minimal).expect("serialises");
        for key in [
            "systemPrompt",
            "modelPreferences",
            "stopSequences",
            "temperature",
        ] {
            assert!(minimal_json.get(key).is_none(), "{} must be omitted", key);
        }
        assert_eq!(minimal_json["maxTokens"], 16);
    }

    #[test]
    fn sampling_response_deserialises_text_content() {
        let fixture = r#"{"role":"assistant","content":{"type":"text","text":"hi"},"model":"test","stopReason":"endTurn"}"#;
        let resp: SamplingResponse = serde_json::from_str(fixture).expect("parses");
        assert_eq!(resp.role, "assistant");
        assert_eq!(resp.model, "test");
        assert_eq!(resp.stop_reason.as_deref(), Some("endTurn"));
        match resp.content {
            SamplingContent::Text { text } => assert_eq!(text, "hi"),
        }
    }

    /// `stopReason` is optional on the wire; a response without it must
    /// still parse and leave `stop_reason` unset.
    #[test]
    fn sampling_response_without_stop_reason_parses_as_none() {
        let fixture = r#"{"role":"assistant","content":{"type":"text","text":"hi"},"model":"test"}"#;
        let resp: SamplingResponse = serde_json::from_str(fixture).expect("parses");
        assert!(resp.stop_reason.is_none());
    }

    #[test]
    fn sampling_content_image_fails_to_deserialise() {
        let fixture = r#"{"type":"image","data":"...","mimeType":"image/png"}"#;
        let err = serde_json::from_str::<SamplingContent>(fixture)
            .expect_err("image-tagged content must be rejected");
        // Check the failure is the unknown-variant rejection and not some
        // unrelated parse error in the fixture.
        let message = err.to_string();
        assert!(
            message.contains("unknown variant"),
            "image-tagged content must surface as unknown-variant error, got: {}",
            message
        );
    }

    #[test]
    fn sampling_request_new_defaults_optional_fields_to_none() {
        let msg = SamplingMessage::new("user", text("hi"));
        let req = SamplingRequest::new(vec![msg], 128);
        assert_eq!(req.messages.len(), 1);
        assert_eq!(req.max_tokens, 128);
        assert!(req.model_preferences.is_none());
        assert!(req.system_prompt.is_none());
        assert!(req.temperature.is_none());
        assert!(req.stop_sequences.is_none());
    }

    #[test]
    fn sampling_message_new_accepts_str_and_string_role() {
        let m1 = SamplingMessage::new("user", text("a"));
        let m2 = SamplingMessage::new(String::from("assistant"), text("b"));
        assert_eq!(m1.role, "user");
        assert_eq!(m2.role, "assistant");
        assert!(matches!(m1.content, SamplingContent::Text { .. }));
        assert!(matches!(m2.content, SamplingContent::Text { .. }));
    }
}