mcpkit_server/capability/
sampling.rs1use crate::context::Context;
7use mcpkit_core::error::McpError;
8use mcpkit_core::types::content::Role;
9use mcpkit_core::types::sampling::{
10 CreateMessageRequest, CreateMessageResult, IncludeContext, ModelPreferences, SamplingMessage,
11 StopReason,
12};
13use std::future::Future;
14use std::pin::Pin;
15
16pub type BoxedSamplingFn = Box<
18 dyn for<'a> Fn(
19 CreateMessageRequest,
20 &'a Context<'a>,
21 ) -> Pin<
22 Box<dyn Future<Output = Result<CreateMessageResult, McpError>> + Send + 'a>,
23 > + Send
24 + Sync,
25>;
26
27pub struct SamplingService {
31 handler: Option<BoxedSamplingFn>,
32}
33
34impl Default for SamplingService {
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40impl SamplingService {
41 #[must_use]
43 pub fn new() -> Self {
44 Self { handler: None }
45 }
46
47 pub fn with_handler<F, Fut>(mut self, handler: F) -> Self
49 where
50 F: Fn(CreateMessageRequest, &Context<'_>) -> Fut + Send + Sync + 'static,
51 Fut: Future<Output = Result<CreateMessageResult, McpError>> + Send + 'static,
52 {
53 self.handler = Some(Box::new(move |req, ctx| Box::pin(handler(req, ctx))));
54 self
55 }
56
57 #[must_use]
59 pub fn is_supported(&self) -> bool {
60 self.handler.is_some()
61 }
62
63 pub async fn create_message(
65 &self,
66 request: CreateMessageRequest,
67 ctx: &Context<'_>,
68 ) -> Result<CreateMessageResult, McpError> {
69 let handler = self
70 .handler
71 .as_ref()
72 .ok_or_else(|| McpError::invalid_request("Sampling not supported"))?;
73
74 (handler)(request, ctx).await
75 }
76}
77
78pub struct SamplingRequestBuilder {
80 messages: Vec<SamplingMessage>,
81 model_preferences: Option<ModelPreferences>,
82 system_prompt: Option<String>,
83 include_context: Option<IncludeContext>,
84 max_tokens: Option<u32>,
85 temperature: Option<f64>,
86 stop_sequences: Vec<String>,
87}
88
89impl Default for SamplingRequestBuilder {
90 fn default() -> Self {
91 Self::new()
92 }
93}
94
95impl SamplingRequestBuilder {
96 #[must_use]
98 pub const fn new() -> Self {
99 Self {
100 messages: Vec::new(),
101 model_preferences: None,
102 system_prompt: None,
103 include_context: None,
104 max_tokens: None,
105 temperature: None,
106 stop_sequences: Vec::new(),
107 }
108 }
109
110 pub fn user(mut self, content: impl Into<String>) -> Self {
112 self.messages.push(SamplingMessage::user(content.into()));
113 self
114 }
115
116 pub fn assistant(mut self, content: impl Into<String>) -> Self {
118 self.messages
119 .push(SamplingMessage::assistant(content.into()));
120 self
121 }
122
123 #[must_use]
125 pub fn message(mut self, msg: SamplingMessage) -> Self {
126 self.messages.push(msg);
127 self
128 }
129
130 #[must_use]
132 pub fn model_preferences(mut self, prefs: ModelPreferences) -> Self {
133 self.model_preferences = Some(prefs);
134 self
135 }
136
137 #[must_use]
139 pub const fn include_context(mut self, context: IncludeContext) -> Self {
140 self.include_context = Some(context);
141 self
142 }
143
144 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
146 self.system_prompt = Some(prompt.into());
147 self
148 }
149
150 #[must_use]
152 pub const fn max_tokens(mut self, tokens: u32) -> Self {
153 self.max_tokens = Some(tokens);
154 self
155 }
156
157 #[must_use]
159 pub const fn temperature(mut self, temp: f64) -> Self {
160 self.temperature = Some(temp);
161 self
162 }
163
164 pub fn stop_sequence(mut self, seq: impl Into<String>) -> Self {
166 self.stop_sequences.push(seq.into());
167 self
168 }
169
170 #[must_use]
172 pub fn build(self) -> CreateMessageRequest {
173 CreateMessageRequest {
174 messages: self.messages,
175 model_preferences: self.model_preferences,
176 system_prompt: self.system_prompt,
177 include_context: self.include_context,
178 max_tokens: self.max_tokens.unwrap_or(1024),
179 temperature: self.temperature,
180 stop_sequences: if self.stop_sequences.is_empty() {
181 None
182 } else {
183 Some(self.stop_sequences)
184 },
185 metadata: None,
186 }
187 }
188}
189
190pub struct SamplingResultBuilder {
192 role: Role,
193 content: String,
194 model: String,
195 stop_reason: Option<StopReason>,
196}
197
198impl SamplingResultBuilder {
199 pub fn new(model: impl Into<String>) -> Self {
201 Self {
202 role: Role::Assistant,
203 content: String::new(),
204 model: model.into(),
205 stop_reason: None,
206 }
207 }
208
209 pub fn content(mut self, content: impl Into<String>) -> Self {
211 self.content = content.into();
212 self
213 }
214
215 #[must_use]
217 pub const fn stop_reason(mut self, reason: StopReason) -> Self {
218 self.stop_reason = Some(reason);
219 self
220 }
221
222 #[must_use]
224 pub const fn end_turn(mut self) -> Self {
225 self.stop_reason = Some(StopReason::EndTurn);
226 self
227 }
228
229 #[must_use]
231 pub const fn max_tokens_reached(mut self) -> Self {
232 self.stop_reason = Some(StopReason::MaxTokens);
233 self
234 }
235
236 #[must_use]
238 pub fn build(self) -> CreateMessageResult {
239 CreateMessageResult {
240 role: self.role,
241 content: mcpkit_core::types::content::Content::text(self.content),
242 model: self.model,
243 stop_reason: self.stop_reason,
244 }
245 }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sampling_request_builder() {
        let request = SamplingRequestBuilder::new()
            .system_prompt("You are a helpful assistant")
            .user("Hello!")
            .max_tokens(100)
            .temperature(0.7)
            .build();

        assert_eq!(request.messages.len(), 1);
        assert_eq!(request.max_tokens, 100);
        assert_eq!(request.temperature, Some(0.7));
        assert_eq!(
            request.system_prompt.as_deref(),
            Some("You are a helpful assistant")
        );
    }

    #[test]
    fn test_sampling_request_builder_defaults() {
        let request = SamplingRequestBuilder::new().user("Hi").build();

        assert_eq!(request.messages.len(), 1);
        // Unset `max_tokens` falls back to the builder's default of 1024.
        assert_eq!(request.max_tokens, 1024);
        assert_eq!(request.temperature, None);
        assert!(request.system_prompt.is_none());
        assert!(request.model_preferences.is_none());
        assert!(request.include_context.is_none());
        // No stop sequences added => field is `None`, not an empty list.
        assert!(request.stop_sequences.is_none());
        assert!(request.metadata.is_none());
    }

    #[test]
    fn test_sampling_request_builder_conversation_and_stop_sequences() {
        let request = SamplingRequestBuilder::new()
            .user("Hello")
            .assistant("Hi there")
            .user("Tell me a joke")
            .stop_sequence("END")
            .stop_sequence("STOP")
            .build();

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.stop_sequences,
            Some(vec!["END".to_string(), "STOP".to_string()])
        );
    }

    #[test]
    fn test_sampling_result_builder() {
        let result = SamplingResultBuilder::new("gpt-4")
            .content("Hello! How can I help you?")
            .end_turn()
            .build();

        assert_eq!(result.role, Role::Assistant);
        assert_eq!(result.model, "gpt-4");
        assert_eq!(result.stop_reason, Some(StopReason::EndTurn));
    }

    #[test]
    fn test_sampling_result_builder_stop_reasons() {
        let unset = SamplingResultBuilder::new("m").build();
        assert_eq!(unset.role, Role::Assistant);
        assert_eq!(unset.stop_reason, None);

        let truncated = SamplingResultBuilder::new("m").max_tokens_reached().build();
        assert_eq!(truncated.stop_reason, Some(StopReason::MaxTokens));

        let explicit = SamplingResultBuilder::new("m")
            .stop_reason(StopReason::EndTurn)
            .build();
        assert_eq!(explicit.stop_reason, Some(StopReason::EndTurn));
    }

    #[test]
    fn test_sampling_service_default() {
        let service = SamplingService::new();
        assert!(!service.is_supported());

        let defaulted = SamplingService::default();
        assert!(!defaulted.is_supported());
    }

    #[tokio::test]
    async fn test_sampling_service_with_handler() {
        let service = SamplingService::new().with_handler(|_req, _ctx| async {
            Ok(SamplingResultBuilder::new("test-model")
                .content("Test response")
                .end_turn()
                .build())
        });

        assert!(service.is_supported());
    }
}
299}