mcpkit_server/capability/
sampling.rs

1//! Sampling capability implementation.
2//!
3//! This module provides support for LLM sampling requests
4//! in MCP servers.
5
6use crate::context::Context;
7use mcpkit_core::error::McpError;
8use mcpkit_core::types::content::Role;
9use mcpkit_core::types::sampling::{
10    CreateMessageRequest, CreateMessageResult, IncludeContext, ModelPreferences, SamplingMessage,
11    StopReason,
12};
13use std::future::Future;
14use std::pin::Pin;
15
/// A boxed async function for handling sampling requests.
///
/// The handler receives the request plus a borrowed [`Context`], and returns a
/// boxed future. The higher-ranked lifetime `'a` ties the future to the
/// context borrow, so the handler may keep using the context across awaits.
/// `Send + Sync` bounds allow the handler to be shared across threads.
pub type BoxedSamplingFn = Box<
    dyn for<'a> Fn(
            CreateMessageRequest,
            &'a Context<'a>,
        ) -> Pin<
            Box<dyn Future<Output = Result<CreateMessageResult, McpError>> + Send + 'a>,
        > + Send
        + Sync,
>;
26
/// Service for handling sampling requests.
///
/// This allows clients to request LLM completions through the server.
/// With no handler installed (the default), `is_supported` reports `false`
/// and `create_message` returns an invalid-request error.
pub struct SamplingService {
    // Installed sampling callback; `None` means sampling is not supported.
    handler: Option<BoxedSamplingFn>,
}
33
34impl Default for SamplingService {
35    fn default() -> Self {
36        Self::new()
37    }
38}
39
40impl SamplingService {
41    /// Create a new sampling service without a handler.
42    #[must_use]
43    pub fn new() -> Self {
44        Self { handler: None }
45    }
46
47    /// Set the sampling handler.
48    pub fn with_handler<F, Fut>(mut self, handler: F) -> Self
49    where
50        F: Fn(CreateMessageRequest, &Context<'_>) -> Fut + Send + Sync + 'static,
51        Fut: Future<Output = Result<CreateMessageResult, McpError>> + Send + 'static,
52    {
53        self.handler = Some(Box::new(move |req, ctx| Box::pin(handler(req, ctx))));
54        self
55    }
56
57    /// Check if sampling is supported.
58    #[must_use]
59    pub fn is_supported(&self) -> bool {
60        self.handler.is_some()
61    }
62
63    /// Create a message (perform sampling).
64    pub async fn create_message(
65        &self,
66        request: CreateMessageRequest,
67        ctx: &Context<'_>,
68    ) -> Result<CreateMessageResult, McpError> {
69        let handler = self
70            .handler
71            .as_ref()
72            .ok_or_else(|| McpError::invalid_request("Sampling not supported"))?;
73
74        (handler)(request, ctx).await
75    }
76}
77
78/// Builder for creating sampling requests.
/// Builder for creating sampling requests.
pub struct SamplingRequestBuilder {
    // Conversation messages, in insertion order.
    messages: Vec<SamplingMessage>,
    // Optional hints for model selection.
    model_preferences: Option<ModelPreferences>,
    // Optional system prompt to send with the request.
    system_prompt: Option<String>,
    // How much server context to include with the request.
    include_context: Option<IncludeContext>,
    // Token limit; `build` falls back to 1024 when unset.
    max_tokens: Option<u32>,
    // Sampling temperature; `None` leaves the model default.
    temperature: Option<f64>,
    // Sequences that stop generation; an empty list becomes `None` in `build`.
    stop_sequences: Vec<String>,
}
88
89impl Default for SamplingRequestBuilder {
90    fn default() -> Self {
91        Self::new()
92    }
93}
94
95impl SamplingRequestBuilder {
96    /// Create a new request builder.
97    #[must_use]
98    pub const fn new() -> Self {
99        Self {
100            messages: Vec::new(),
101            model_preferences: None,
102            system_prompt: None,
103            include_context: None,
104            max_tokens: None,
105            temperature: None,
106            stop_sequences: Vec::new(),
107        }
108    }
109
110    /// Add a user message.
111    pub fn user(mut self, content: impl Into<String>) -> Self {
112        self.messages.push(SamplingMessage::user(content.into()));
113        self
114    }
115
116    /// Add an assistant message.
117    pub fn assistant(mut self, content: impl Into<String>) -> Self {
118        self.messages
119            .push(SamplingMessage::assistant(content.into()));
120        self
121    }
122
123    /// Add a message.
124    #[must_use]
125    pub fn message(mut self, msg: SamplingMessage) -> Self {
126        self.messages.push(msg);
127        self
128    }
129
130    /// Set model preferences.
131    #[must_use]
132    pub fn model_preferences(mut self, prefs: ModelPreferences) -> Self {
133        self.model_preferences = Some(prefs);
134        self
135    }
136
137    /// Set context inclusion.
138    #[must_use]
139    pub const fn include_context(mut self, context: IncludeContext) -> Self {
140        self.include_context = Some(context);
141        self
142    }
143
144    /// Set the system prompt.
145    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
146        self.system_prompt = Some(prompt.into());
147        self
148    }
149
150    /// Set the maximum tokens.
151    #[must_use]
152    pub const fn max_tokens(mut self, tokens: u32) -> Self {
153        self.max_tokens = Some(tokens);
154        self
155    }
156
157    /// Set the temperature.
158    #[must_use]
159    pub const fn temperature(mut self, temp: f64) -> Self {
160        self.temperature = Some(temp);
161        self
162    }
163
164    /// Add a stop sequence.
165    pub fn stop_sequence(mut self, seq: impl Into<String>) -> Self {
166        self.stop_sequences.push(seq.into());
167        self
168    }
169
170    /// Build the request.
171    #[must_use]
172    pub fn build(self) -> CreateMessageRequest {
173        CreateMessageRequest {
174            messages: self.messages,
175            model_preferences: self.model_preferences,
176            system_prompt: self.system_prompt,
177            include_context: self.include_context,
178            max_tokens: self.max_tokens.unwrap_or(1024),
179            temperature: self.temperature,
180            stop_sequences: if self.stop_sequences.is_empty() {
181                None
182            } else {
183                Some(self.stop_sequences)
184            },
185            metadata: None,
186        }
187    }
188}
189
190/// Builder for creating sampling results.
/// Builder for creating sampling results.
pub struct SamplingResultBuilder {
    // Role of the produced message; fixed to `Role::Assistant` by `new`
    // (no setter is exposed).
    role: Role,
    // Text content of the result; empty until `content` is called.
    content: String,
    // Identifier of the model that produced the result.
    model: String,
    // Why generation stopped, if known.
    stop_reason: Option<StopReason>,
}
197
198impl SamplingResultBuilder {
199    /// Create a new result builder.
200    pub fn new(model: impl Into<String>) -> Self {
201        Self {
202            role: Role::Assistant,
203            content: String::new(),
204            model: model.into(),
205            stop_reason: None,
206        }
207    }
208
209    /// Set the content.
210    pub fn content(mut self, content: impl Into<String>) -> Self {
211        self.content = content.into();
212        self
213    }
214
215    /// Set the stop reason.
216    #[must_use]
217    pub const fn stop_reason(mut self, reason: StopReason) -> Self {
218        self.stop_reason = Some(reason);
219        self
220    }
221
222    /// Mark as stopped due to end turn.
223    #[must_use]
224    pub const fn end_turn(mut self) -> Self {
225        self.stop_reason = Some(StopReason::EndTurn);
226        self
227    }
228
229    /// Mark as stopped due to max tokens.
230    #[must_use]
231    pub const fn max_tokens_reached(mut self) -> Self {
232        self.stop_reason = Some(StopReason::MaxTokens);
233        self
234    }
235
236    /// Build the result.
237    #[must_use]
238    pub fn build(self) -> CreateMessageResult {
239        CreateMessageResult {
240            role: self.role,
241            content: mcpkit_core::types::content::Content::text(self.content),
242            model: self.model,
243            stop_reason: self.stop_reason,
244        }
245    }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sampling_request_builder() {
        let req = SamplingRequestBuilder::new()
            .system_prompt("You are a helpful assistant")
            .user("Hello!")
            .max_tokens(100)
            .temperature(0.7)
            .build();

        // One user message, explicit limits, and the system prompt round-trip.
        assert_eq!(req.messages.len(), 1);
        assert_eq!(req.max_tokens, 100);
        assert_eq!(req.temperature, Some(0.7));
        assert_eq!(
            req.system_prompt.as_deref(),
            Some("You are a helpful assistant")
        );
    }

    #[test]
    fn test_sampling_result_builder() {
        let built = SamplingResultBuilder::new("gpt-4")
            .content("Hello! How can I help you?")
            .end_turn()
            .build();

        assert_eq!(built.role, Role::Assistant);
        assert_eq!(built.model, "gpt-4");
        assert_eq!(built.stop_reason, Some(StopReason::EndTurn));
    }

    #[test]
    fn test_sampling_service_default() {
        // A fresh service has no handler, so sampling is unsupported.
        assert!(!SamplingService::new().is_supported());
    }

    #[tokio::test]
    async fn test_sampling_service_with_handler() {
        let handler = |_req: CreateMessageRequest, _ctx: &Context<'_>| async {
            Ok(SamplingResultBuilder::new("test-model")
                .content("Test response")
                .end_turn()
                .build())
        };
        let service = SamplingService::new().with_handler(handler);

        assert!(service.is_supported());
    }
}