mcpkit_server/capability/
sampling.rs

1//! Sampling capability implementation.
2//!
3//! This module provides support for LLM sampling requests
4//! in MCP servers.
5
6use crate::context::Context;
7use mcpkit_core::error::McpError;
8use mcpkit_core::types::content::Role;
9use mcpkit_core::types::sampling::{
10    CreateMessageRequest, CreateMessageResult, IncludeContext, ModelPreferences,
11    SamplingMessage, StopReason,
12};
13use std::future::Future;
14use std::pin::Pin;
15
16/// A boxed async function for handling sampling requests.
17pub type BoxedSamplingFn = Box<
18    dyn for<'a> Fn(
19            CreateMessageRequest,
20            &'a Context<'a>,
21        ) -> Pin<Box<dyn Future<Output = Result<CreateMessageResult, McpError>> + Send + 'a>>
22        + Send
23        + Sync,
24>;
25
26/// Service for handling sampling requests.
27///
28/// This allows clients to request LLM completions through the server.
29pub struct SamplingService {
30    handler: Option<BoxedSamplingFn>,
31}
32
33impl Default for SamplingService {
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
39impl SamplingService {
40    /// Create a new sampling service without a handler.
41    pub fn new() -> Self {
42        Self { handler: None }
43    }
44
45    /// Set the sampling handler.
46    pub fn with_handler<F, Fut>(mut self, handler: F) -> Self
47    where
48        F: Fn(CreateMessageRequest, &Context<'_>) -> Fut + Send + Sync + 'static,
49        Fut: Future<Output = Result<CreateMessageResult, McpError>> + Send + 'static,
50    {
51        self.handler = Some(Box::new(move |req, ctx| Box::pin(handler(req, ctx))));
52        self
53    }
54
55    /// Check if sampling is supported.
56    pub fn is_supported(&self) -> bool {
57        self.handler.is_some()
58    }
59
60    /// Create a message (perform sampling).
61    pub async fn create_message(
62        &self,
63        request: CreateMessageRequest,
64        ctx: &Context<'_>,
65    ) -> Result<CreateMessageResult, McpError> {
66        let handler = self.handler.as_ref().ok_or_else(|| {
67            McpError::invalid_request("Sampling not supported")
68        })?;
69
70        (handler)(request, ctx).await
71    }
72}
73
74/// Builder for creating sampling requests.
75pub struct SamplingRequestBuilder {
76    messages: Vec<SamplingMessage>,
77    model_preferences: Option<ModelPreferences>,
78    system_prompt: Option<String>,
79    include_context: Option<IncludeContext>,
80    max_tokens: Option<u32>,
81    temperature: Option<f64>,
82    stop_sequences: Vec<String>,
83}
84
85impl Default for SamplingRequestBuilder {
86    fn default() -> Self {
87        Self::new()
88    }
89}
90
91impl SamplingRequestBuilder {
92    /// Create a new request builder.
93    pub fn new() -> Self {
94        Self {
95            messages: Vec::new(),
96            model_preferences: None,
97            system_prompt: None,
98            include_context: None,
99            max_tokens: None,
100            temperature: None,
101            stop_sequences: Vec::new(),
102        }
103    }
104
105    /// Add a user message.
106    pub fn user(mut self, content: impl Into<String>) -> Self {
107        self.messages.push(SamplingMessage::user(content.into()));
108        self
109    }
110
111    /// Add an assistant message.
112    pub fn assistant(mut self, content: impl Into<String>) -> Self {
113        self.messages.push(SamplingMessage::assistant(content.into()));
114        self
115    }
116
117    /// Add a message.
118    pub fn message(mut self, msg: SamplingMessage) -> Self {
119        self.messages.push(msg);
120        self
121    }
122
123    /// Set model preferences.
124    pub fn model_preferences(mut self, prefs: ModelPreferences) -> Self {
125        self.model_preferences = Some(prefs);
126        self
127    }
128
129    /// Set context inclusion.
130    pub fn include_context(mut self, context: IncludeContext) -> Self {
131        self.include_context = Some(context);
132        self
133    }
134
135    /// Set the system prompt.
136    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
137        self.system_prompt = Some(prompt.into());
138        self
139    }
140
141    /// Set the maximum tokens.
142    pub fn max_tokens(mut self, tokens: u32) -> Self {
143        self.max_tokens = Some(tokens);
144        self
145    }
146
147    /// Set the temperature.
148    pub fn temperature(mut self, temp: f64) -> Self {
149        self.temperature = Some(temp);
150        self
151    }
152
153    /// Add a stop sequence.
154    pub fn stop_sequence(mut self, seq: impl Into<String>) -> Self {
155        self.stop_sequences.push(seq.into());
156        self
157    }
158
159    /// Build the request.
160    pub fn build(self) -> CreateMessageRequest {
161        CreateMessageRequest {
162            messages: self.messages,
163            model_preferences: self.model_preferences,
164            system_prompt: self.system_prompt,
165            include_context: self.include_context,
166            max_tokens: self.max_tokens.unwrap_or(1024),
167            temperature: self.temperature,
168            stop_sequences: if self.stop_sequences.is_empty() {
169                None
170            } else {
171                Some(self.stop_sequences)
172            },
173            metadata: None,
174        }
175    }
176}
177
178/// Builder for creating sampling results.
179pub struct SamplingResultBuilder {
180    role: Role,
181    content: String,
182    model: String,
183    stop_reason: Option<StopReason>,
184}
185
186impl SamplingResultBuilder {
187    /// Create a new result builder.
188    pub fn new(model: impl Into<String>) -> Self {
189        Self {
190            role: Role::Assistant,
191            content: String::new(),
192            model: model.into(),
193            stop_reason: None,
194        }
195    }
196
197    /// Set the content.
198    pub fn content(mut self, content: impl Into<String>) -> Self {
199        self.content = content.into();
200        self
201    }
202
203    /// Set the stop reason.
204    pub fn stop_reason(mut self, reason: StopReason) -> Self {
205        self.stop_reason = Some(reason);
206        self
207    }
208
209    /// Mark as stopped due to end turn.
210    pub fn end_turn(mut self) -> Self {
211        self.stop_reason = Some(StopReason::EndTurn);
212        self
213    }
214
215    /// Mark as stopped due to max tokens.
216    pub fn max_tokens_reached(mut self) -> Self {
217        self.stop_reason = Some(StopReason::MaxTokens);
218        self
219    }
220
221    /// Build the result.
222    pub fn build(self) -> CreateMessageResult {
223        CreateMessageResult {
224            role: self.role,
225            content: mcpkit_core::types::content::Content::text(self.content),
226            model: self.model,
227            stop_reason: self.stop_reason,
228        }
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sampling_request_builder() {
        let request = SamplingRequestBuilder::new()
            .system_prompt("You are a helpful assistant")
            .user("Hello!")
            .max_tokens(100)
            .temperature(0.7)
            .build();

        assert_eq!(request.messages.len(), 1);
        assert_eq!(request.max_tokens, 100);
        assert_eq!(request.temperature, Some(0.7));
        assert_eq!(
            request.system_prompt.as_deref(),
            Some("You are a helpful assistant")
        );
    }

    #[test]
    fn test_sampling_request_builder_defaults() {
        let request = SamplingRequestBuilder::new().build();

        assert!(request.messages.is_empty());
        // An unset token limit falls back to the builder's default of 1024.
        assert_eq!(request.max_tokens, 1024);
        assert_eq!(request.temperature, None);
        assert!(request.system_prompt.is_none());
        assert!(request.model_preferences.is_none());
        assert!(request.include_context.is_none());
        // No stop sequences added means `None`, not `Some(vec![])`.
        assert!(request.stop_sequences.is_none());
        assert!(request.metadata.is_none());
    }

    #[test]
    fn test_sampling_request_builder_collects_messages_and_stop_sequences() {
        let request = SamplingRequestBuilder::default()
            .user("Hi")
            .assistant("Hello")
            .user("Bye")
            .stop_sequence("END")
            .stop_sequence(String::from("\n\n"))
            .build();

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.stop_sequences,
            Some(vec!["END".to_string(), "\n\n".to_string()])
        );
    }

    #[test]
    fn test_sampling_result_builder() {
        let result = SamplingResultBuilder::new("gpt-4")
            .content("Hello! How can I help you?")
            .end_turn()
            .build();

        assert_eq!(result.role, Role::Assistant);
        assert_eq!(result.model, "gpt-4");
        assert_eq!(result.stop_reason, Some(StopReason::EndTurn));
    }

    #[test]
    fn test_sampling_result_builder_max_tokens_reached() {
        let result = SamplingResultBuilder::new(String::from("m"))
            .max_tokens_reached()
            .build();

        assert_eq!(result.model, "m");
        assert_eq!(result.stop_reason, Some(StopReason::MaxTokens));
    }

    #[test]
    fn test_sampling_result_builder_without_stop_reason() {
        let result = SamplingResultBuilder::new("m").content("x").build();

        assert_eq!(result.role, Role::Assistant);
        assert_eq!(result.stop_reason, None);
    }

    #[test]
    fn test_sampling_result_builder_last_stop_reason_wins() {
        let result = SamplingResultBuilder::new("m")
            .end_turn()
            .stop_reason(StopReason::MaxTokens)
            .build();

        assert_eq!(result.stop_reason, Some(StopReason::MaxTokens));
    }

    #[test]
    fn test_sampling_service_default() {
        assert!(!SamplingService::new().is_supported());
        assert!(!SamplingService::default().is_supported());
    }

    #[test]
    fn test_sampling_service_with_handler() {
        // Installing a handler is synchronous, so no async runtime is needed.
        let service = SamplingService::new().with_handler(|_req, _ctx| async {
            Ok(SamplingResultBuilder::new("test-model")
                .content("Test response")
                .end_turn()
                .build())
        });

        assert!(service.is_supported());
    }
}