Skip to main content

mcp_kit/server/
sampling.rs

//! Sampling API for server-initiated LLM requests.
//!
//! This module allows MCP servers to request LLM completions from clients
//! that support the sampling capability. This enables agentic workflows
//! where the server can leverage the client's LLM for complex tasks.
//!
//! # Example
//! ```rust,ignore
//! use mcp_kit::server::sampling::{SamplingClient, SamplingRequestBuilder};
//!
//! async fn agentic_tool(client: impl SamplingClient) -> Result<String, Error> {
//!     let request = SamplingRequestBuilder::new()
//!         .user_message("Analyze this data and suggest improvements")
//!         .max_tokens(1000)
//!         .build();
//!
//!     let response = client.create_message(request).await?;
//!     Ok(response.content.text().unwrap_or_default())
//! }
//! ```
20
21use crate::error::McpResult;
22use crate::protocol::{JsonRpcRequest, RequestId};
23use crate::types::sampling::{
24    CreateMessageRequest, CreateMessageResult, IncludeContext, ModelPreferences, SamplingMessage,
25};
26use serde_json::Value;
27use std::future::Future;
28use std::pin::Pin;
29use std::sync::atomic::{AtomicU64, Ordering};
30use std::sync::Arc;
31use tokio::sync::{mpsc, oneshot};
32
33/// A client for making sampling requests to the MCP client.
34///
35/// Implementations handle the actual communication with the client.
36pub trait SamplingClient: Send + Sync {
37    /// Send a sampling/createMessage request to the client.
38    fn create_message(
39        &self,
40        request: CreateMessageRequest,
41    ) -> Pin<Box<dyn Future<Output = McpResult<CreateMessageResult>> + Send + '_>>;
42}
43
44/// Builder for creating sampling requests.
45#[derive(Debug, Clone, Default)]
46pub struct SamplingRequestBuilder {
47    messages: Vec<SamplingMessage>,
48    model_preferences: Option<ModelPreferences>,
49    system_prompt: Option<String>,
50    max_tokens: u32,
51    stop_sequences: Option<Vec<String>>,
52    temperature: Option<f64>,
53    metadata: Option<Value>,
54    include_context: Option<IncludeContext>,
55}
56
57impl SamplingRequestBuilder {
58    /// Create a new sampling request builder.
59    pub fn new() -> Self {
60        Self {
61            max_tokens: 1000,
62            ..Default::default()
63        }
64    }
65
66    /// Add a user message.
67    pub fn user_message(mut self, content: impl Into<String>) -> Self {
68        self.messages.push(SamplingMessage::user_text(content));
69        self
70    }
71
72    /// Add an assistant message (for multi-turn conversations).
73    pub fn assistant_message(mut self, content: impl Into<String>) -> Self {
74        self.messages.push(SamplingMessage::assistant_text(content));
75        self
76    }
77
78    /// Add a raw message with role and content.
79    pub fn message(mut self, message: SamplingMessage) -> Self {
80        self.messages.push(message);
81        self
82    }
83
84    /// Set the system prompt.
85    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
86        self.system_prompt = Some(prompt.into());
87        self
88    }
89
90    /// Set maximum tokens for the response.
91    pub fn max_tokens(mut self, tokens: u32) -> Self {
92        self.max_tokens = tokens;
93        self
94    }
95
96    /// Set stop sequences.
97    pub fn stop_sequences(mut self, sequences: Vec<String>) -> Self {
98        self.stop_sequences = Some(sequences);
99        self
100    }
101
102    /// Set temperature for sampling.
103    pub fn temperature(mut self, temp: f64) -> Self {
104        self.temperature = Some(temp);
105        self
106    }
107
108    /// Set model preferences.
109    pub fn model_preferences(mut self, prefs: ModelPreferences) -> Self {
110        self.model_preferences = Some(prefs);
111        self
112    }
113
114    /// Set whether to include MCP context.
115    pub fn include_context(mut self, include: IncludeContext) -> Self {
116        self.include_context = Some(include);
117        self
118    }
119
120    /// Set custom metadata.
121    pub fn metadata(mut self, metadata: Value) -> Self {
122        self.metadata = Some(metadata);
123        self
124    }
125
126    /// Build the request.
127    pub fn build(self) -> CreateMessageRequest {
128        CreateMessageRequest {
129            messages: self.messages,
130            model_preferences: self.model_preferences,
131            system_prompt: self.system_prompt,
132            include_context: self.include_context,
133            temperature: self.temperature,
134            max_tokens: self.max_tokens,
135            stop_sequences: self.stop_sequences,
136            metadata: self.metadata,
137        }
138    }
139}
140
141/// Channel-based sampling client implementation.
142///
143/// Sends requests through a channel to be forwarded to the MCP client.
144#[derive(Clone)]
145pub struct ChannelSamplingClient {
146    request_tx: mpsc::Sender<(JsonRpcRequest, oneshot::Sender<McpResult<Value>>)>,
147    next_id: Arc<AtomicU64>,
148}
149
150impl ChannelSamplingClient {
151    /// Create a new channel-based sampling client.
152    pub fn new(
153        request_tx: mpsc::Sender<(JsonRpcRequest, oneshot::Sender<McpResult<Value>>)>,
154    ) -> Self {
155        Self {
156            request_tx,
157            next_id: Arc::new(AtomicU64::new(1)),
158        }
159    }
160
161    fn next_request_id(&self) -> RequestId {
162        RequestId::Number(self.next_id.fetch_add(1, Ordering::SeqCst) as i64)
163    }
164}
165
166impl SamplingClient for ChannelSamplingClient {
167    fn create_message(
168        &self,
169        request: CreateMessageRequest,
170    ) -> Pin<Box<dyn Future<Output = McpResult<CreateMessageResult>> + Send + '_>> {
171        Box::pin(async move {
172            let (response_tx, response_rx) = oneshot::channel();
173
174            let rpc_request = JsonRpcRequest {
175                jsonrpc: "2.0".to_string(),
176                id: self.next_request_id(),
177                method: "sampling/createMessage".to_string(),
178                params: Some(serde_json::to_value(&request)?),
179            };
180
181            self.request_tx
182                .send((rpc_request, response_tx))
183                .await
184                .map_err(|_| {
185                    crate::error::McpError::InternalError("Sampling channel closed".to_string())
186                })?;
187
188            let result = response_rx.await.map_err(|_| {
189                crate::error::McpError::InternalError("Response channel closed".to_string())
190            })??;
191
192            serde_json::from_value(result).map_err(|e| {
193                crate::error::McpError::InternalError(format!("Invalid sampling response: {}", e))
194            })
195        })
196    }
197}
198
199/// No-op sampling client for when client doesn't support sampling.
200#[derive(Clone, Default)]
201pub struct NoOpSamplingClient;
202
203impl SamplingClient for NoOpSamplingClient {
204    fn create_message(
205        &self,
206        _request: CreateMessageRequest,
207    ) -> Pin<Box<dyn Future<Output = McpResult<CreateMessageResult>> + Send + '_>> {
208        Box::pin(async move {
209            Err(crate::error::McpError::InvalidRequest(
210                "Client does not support sampling".to_string(),
211            ))
212        })
213    }
214}