mcp_kit/server/
sampling.rs1use crate::error::McpResult;
22use crate::protocol::{JsonRpcRequest, RequestId};
23use crate::types::sampling::{
24 CreateMessageRequest, CreateMessageResult, IncludeContext, ModelPreferences, SamplingMessage,
25};
26use serde_json::Value;
27use std::future::Future;
28use std::pin::Pin;
29use std::sync::atomic::{AtomicU64, Ordering};
30use std::sync::Arc;
31use tokio::sync::{mpsc, oneshot};
32
33pub trait SamplingClient: Send + Sync {
37 fn create_message(
39 &self,
40 request: CreateMessageRequest,
41 ) -> Pin<Box<dyn Future<Output = McpResult<CreateMessageResult>> + Send + '_>>;
42}
43
44#[derive(Debug, Clone, Default)]
46pub struct SamplingRequestBuilder {
47 messages: Vec<SamplingMessage>,
48 model_preferences: Option<ModelPreferences>,
49 system_prompt: Option<String>,
50 max_tokens: u32,
51 stop_sequences: Option<Vec<String>>,
52 temperature: Option<f64>,
53 metadata: Option<Value>,
54 include_context: Option<IncludeContext>,
55}
56
57impl SamplingRequestBuilder {
58 pub fn new() -> Self {
60 Self {
61 max_tokens: 1000,
62 ..Default::default()
63 }
64 }
65
66 pub fn user_message(mut self, content: impl Into<String>) -> Self {
68 self.messages.push(SamplingMessage::user_text(content));
69 self
70 }
71
72 pub fn assistant_message(mut self, content: impl Into<String>) -> Self {
74 self.messages.push(SamplingMessage::assistant_text(content));
75 self
76 }
77
78 pub fn message(mut self, message: SamplingMessage) -> Self {
80 self.messages.push(message);
81 self
82 }
83
84 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
86 self.system_prompt = Some(prompt.into());
87 self
88 }
89
90 pub fn max_tokens(mut self, tokens: u32) -> Self {
92 self.max_tokens = tokens;
93 self
94 }
95
96 pub fn stop_sequences(mut self, sequences: Vec<String>) -> Self {
98 self.stop_sequences = Some(sequences);
99 self
100 }
101
102 pub fn temperature(mut self, temp: f64) -> Self {
104 self.temperature = Some(temp);
105 self
106 }
107
108 pub fn model_preferences(mut self, prefs: ModelPreferences) -> Self {
110 self.model_preferences = Some(prefs);
111 self
112 }
113
114 pub fn include_context(mut self, include: IncludeContext) -> Self {
116 self.include_context = Some(include);
117 self
118 }
119
120 pub fn metadata(mut self, metadata: Value) -> Self {
122 self.metadata = Some(metadata);
123 self
124 }
125
126 pub fn build(self) -> CreateMessageRequest {
128 CreateMessageRequest {
129 messages: self.messages,
130 model_preferences: self.model_preferences,
131 system_prompt: self.system_prompt,
132 include_context: self.include_context,
133 temperature: self.temperature,
134 max_tokens: self.max_tokens,
135 stop_sequences: self.stop_sequences,
136 metadata: self.metadata,
137 }
138 }
139}
140
141#[derive(Clone)]
145pub struct ChannelSamplingClient {
146 request_tx: mpsc::Sender<(JsonRpcRequest, oneshot::Sender<McpResult<Value>>)>,
147 next_id: Arc<AtomicU64>,
148}
149
150impl ChannelSamplingClient {
151 pub fn new(
153 request_tx: mpsc::Sender<(JsonRpcRequest, oneshot::Sender<McpResult<Value>>)>,
154 ) -> Self {
155 Self {
156 request_tx,
157 next_id: Arc::new(AtomicU64::new(1)),
158 }
159 }
160
161 fn next_request_id(&self) -> RequestId {
162 RequestId::Number(self.next_id.fetch_add(1, Ordering::SeqCst) as i64)
163 }
164}
165
166impl SamplingClient for ChannelSamplingClient {
167 fn create_message(
168 &self,
169 request: CreateMessageRequest,
170 ) -> Pin<Box<dyn Future<Output = McpResult<CreateMessageResult>> + Send + '_>> {
171 Box::pin(async move {
172 let (response_tx, response_rx) = oneshot::channel();
173
174 let rpc_request = JsonRpcRequest {
175 jsonrpc: "2.0".to_string(),
176 id: self.next_request_id(),
177 method: "sampling/createMessage".to_string(),
178 params: Some(serde_json::to_value(&request)?),
179 };
180
181 self.request_tx
182 .send((rpc_request, response_tx))
183 .await
184 .map_err(|_| {
185 crate::error::McpError::InternalError("Sampling channel closed".to_string())
186 })?;
187
188 let result = response_rx.await.map_err(|_| {
189 crate::error::McpError::InternalError("Response channel closed".to_string())
190 })??;
191
192 serde_json::from_value(result).map_err(|e| {
193 crate::error::McpError::InternalError(format!("Invalid sampling response: {}", e))
194 })
195 })
196 }
197}
198
/// [`SamplingClient`] stand-in for peers that do not support sampling.
///
/// Every `create_message` call resolves to `McpError::InvalidRequest`. As a
/// stateless unit struct it is `Copy`, `Debug`, and equality-comparable.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NoOpSamplingClient;
202
203impl SamplingClient for NoOpSamplingClient {
204 fn create_message(
205 &self,
206 _request: CreateMessageRequest,
207 ) -> Pin<Box<dyn Future<Output = McpResult<CreateMessageResult>> + Send + '_>> {
208 Box::pin(async move {
209 Err(crate::error::McpError::InvalidRequest(
210 "Client does not support sampling".to_string(),
211 ))
212 })
213 }
214}