// mcpkit_server/capability/sampling.rs
use crate::context::Context;
use mcpkit_core::error::McpError;
use mcpkit_core::types::content::Role;
use mcpkit_core::types::sampling::{
    CreateMessageRequest, CreateMessageResult, IncludeContext, ModelPreferences,
    SamplingMessage, StopReason,
};
use std::fmt;
use std::future::Future;
use std::pin::Pin;

16pub type BoxedSamplingFn = Box<
18 dyn for<'a> Fn(
19 CreateMessageRequest,
20 &'a Context<'a>,
21 ) -> Pin<Box<dyn Future<Output = Result<CreateMessageResult, McpError>> + Send + 'a>>
22 + Send
23 + Sync,
24>;
25
26pub struct SamplingService {
30 handler: Option<BoxedSamplingFn>,
31}
32
33impl Default for SamplingService {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39impl SamplingService {
40 pub fn new() -> Self {
42 Self { handler: None }
43 }
44
45 pub fn with_handler<F, Fut>(mut self, handler: F) -> Self
47 where
48 F: Fn(CreateMessageRequest, &Context<'_>) -> Fut + Send + Sync + 'static,
49 Fut: Future<Output = Result<CreateMessageResult, McpError>> + Send + 'static,
50 {
51 self.handler = Some(Box::new(move |req, ctx| Box::pin(handler(req, ctx))));
52 self
53 }
54
55 pub fn is_supported(&self) -> bool {
57 self.handler.is_some()
58 }
59
60 pub async fn create_message(
62 &self,
63 request: CreateMessageRequest,
64 ctx: &Context<'_>,
65 ) -> Result<CreateMessageResult, McpError> {
66 let handler = self.handler.as_ref().ok_or_else(|| {
67 McpError::invalid_request("Sampling not supported")
68 })?;
69
70 (handler)(request, ctx).await
71 }
72}
73
74pub struct SamplingRequestBuilder {
76 messages: Vec<SamplingMessage>,
77 model_preferences: Option<ModelPreferences>,
78 system_prompt: Option<String>,
79 include_context: Option<IncludeContext>,
80 max_tokens: Option<u32>,
81 temperature: Option<f64>,
82 stop_sequences: Vec<String>,
83}
84
85impl Default for SamplingRequestBuilder {
86 fn default() -> Self {
87 Self::new()
88 }
89}
90
91impl SamplingRequestBuilder {
92 pub fn new() -> Self {
94 Self {
95 messages: Vec::new(),
96 model_preferences: None,
97 system_prompt: None,
98 include_context: None,
99 max_tokens: None,
100 temperature: None,
101 stop_sequences: Vec::new(),
102 }
103 }
104
105 pub fn user(mut self, content: impl Into<String>) -> Self {
107 self.messages.push(SamplingMessage::user(content.into()));
108 self
109 }
110
111 pub fn assistant(mut self, content: impl Into<String>) -> Self {
113 self.messages.push(SamplingMessage::assistant(content.into()));
114 self
115 }
116
117 pub fn message(mut self, msg: SamplingMessage) -> Self {
119 self.messages.push(msg);
120 self
121 }
122
123 pub fn model_preferences(mut self, prefs: ModelPreferences) -> Self {
125 self.model_preferences = Some(prefs);
126 self
127 }
128
129 pub fn include_context(mut self, context: IncludeContext) -> Self {
131 self.include_context = Some(context);
132 self
133 }
134
135 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
137 self.system_prompt = Some(prompt.into());
138 self
139 }
140
141 pub fn max_tokens(mut self, tokens: u32) -> Self {
143 self.max_tokens = Some(tokens);
144 self
145 }
146
147 pub fn temperature(mut self, temp: f64) -> Self {
149 self.temperature = Some(temp);
150 self
151 }
152
153 pub fn stop_sequence(mut self, seq: impl Into<String>) -> Self {
155 self.stop_sequences.push(seq.into());
156 self
157 }
158
159 pub fn build(self) -> CreateMessageRequest {
161 CreateMessageRequest {
162 messages: self.messages,
163 model_preferences: self.model_preferences,
164 system_prompt: self.system_prompt,
165 include_context: self.include_context,
166 max_tokens: self.max_tokens.unwrap_or(1024),
167 temperature: self.temperature,
168 stop_sequences: if self.stop_sequences.is_empty() {
169 None
170 } else {
171 Some(self.stop_sequences)
172 },
173 metadata: None,
174 }
175 }
176}
177
178pub struct SamplingResultBuilder {
180 role: Role,
181 content: String,
182 model: String,
183 stop_reason: Option<StopReason>,
184}
185
186impl SamplingResultBuilder {
187 pub fn new(model: impl Into<String>) -> Self {
189 Self {
190 role: Role::Assistant,
191 content: String::new(),
192 model: model.into(),
193 stop_reason: None,
194 }
195 }
196
197 pub fn content(mut self, content: impl Into<String>) -> Self {
199 self.content = content.into();
200 self
201 }
202
203 pub fn stop_reason(mut self, reason: StopReason) -> Self {
205 self.stop_reason = Some(reason);
206 self
207 }
208
209 pub fn end_turn(mut self) -> Self {
211 self.stop_reason = Some(StopReason::EndTurn);
212 self
213 }
214
215 pub fn max_tokens_reached(mut self) -> Self {
217 self.stop_reason = Some(StopReason::MaxTokens);
218 self
219 }
220
221 pub fn build(self) -> CreateMessageResult {
223 CreateMessageResult {
224 role: self.role,
225 content: mcpkit_core::types::content::Content::text(self.content),
226 model: self.model,
227 stop_reason: self.stop_reason,
228 }
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sampling_request_builder() {
        let request = SamplingRequestBuilder::new()
            .system_prompt("You are a helpful assistant")
            .user("Hello!")
            .max_tokens(100)
            .temperature(0.7)
            .build();

        assert_eq!(request.messages.len(), 1);
        assert_eq!(request.max_tokens, 100);
        assert_eq!(request.temperature, Some(0.7));
        assert_eq!(
            request.system_prompt.as_deref(),
            Some("You are a helpful assistant")
        );
    }

    #[test]
    fn test_sampling_request_builder_defaults() {
        let request = SamplingRequestBuilder::new().user("Hi").build();

        // Unset max_tokens falls back to the builder's default of 1024.
        assert_eq!(request.max_tokens, 1024);
        assert!(request.temperature.is_none());
        assert!(request.system_prompt.is_none());
        assert!(request.model_preferences.is_none());
        assert!(request.include_context.is_none());
        // No stop sequences pushed -> field omitted rather than an empty list.
        assert!(request.stop_sequences.is_none());
        assert!(request.metadata.is_none());
    }

    #[test]
    fn test_sampling_request_builder_accumulates() {
        let request = SamplingRequestBuilder::new()
            .user("Hello")
            .assistant("Hi there")
            .user("How are you?")
            .stop_sequence("STOP")
            .stop_sequence("END")
            .build();

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.stop_sequences,
            Some(vec!["STOP".to_string(), "END".to_string()])
        );
    }

    #[test]
    fn test_sampling_request_builder_default_trait() {
        let request = SamplingRequestBuilder::default().build();
        assert!(request.messages.is_empty());
        assert_eq!(request.max_tokens, 1024);
    }

    #[test]
    fn test_sampling_result_builder() {
        let result = SamplingResultBuilder::new("gpt-4")
            .content("Hello! How can I help you?")
            .end_turn()
            .build();

        assert_eq!(result.role, Role::Assistant);
        assert_eq!(result.model, "gpt-4");
        assert_eq!(result.stop_reason, Some(StopReason::EndTurn));
    }

    #[test]
    fn test_sampling_result_builder_defaults() {
        let result = SamplingResultBuilder::new("test-model").build();

        assert_eq!(result.role, Role::Assistant);
        assert_eq!(result.model, "test-model");
        assert!(result.stop_reason.is_none());
    }

    #[test]
    fn test_sampling_result_builder_max_tokens() {
        let result = SamplingResultBuilder::new("test-model")
            .max_tokens_reached()
            .build();

        assert_eq!(result.stop_reason, Some(StopReason::MaxTokens));
    }

    #[test]
    fn test_sampling_service_default() {
        let service = SamplingService::new();
        assert!(!service.is_supported());
    }

    #[test]
    fn test_sampling_service_default_trait() {
        assert!(!SamplingService::default().is_supported());
    }

    #[tokio::test]
    async fn test_sampling_service_with_handler() {
        let service = SamplingService::new().with_handler(|_req, _ctx| async {
            Ok(SamplingResultBuilder::new("test-model")
                .content("Test response")
                .end_turn()
                .build())
        });

        assert!(service.is_supported());
    }
}