mcp_daemon/server/sampling/
mod.rs1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use crate::server::error::ServerError;
9
10type Result<T> = std::result::Result<T, ServerError>;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14#[serde(rename_all = "lowercase")]
15pub enum MessageRole {
16 User,
18 Assistant,
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
24#[serde(tag = "type", rename_all = "lowercase")]
25pub enum MessageContent {
26 Text {
28 text: String
30 },
31 Image {
33 data: String,
35 mime_type: Option<String>
37 },
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct Message {
43 pub role: MessageRole,
45 pub content: MessageContent,
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct ModelPreferences {
52 pub hints: Option<Vec<ModelHint>>,
54 pub cost_priority: Option<f32>,
56 pub speed_priority: Option<f32>,
58 pub intelligence_priority: Option<f32>,
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct ModelHint {
65 pub name: Option<String>,
67}
68
69#[derive(Debug, Clone, Serialize, Deserialize)]
71#[serde(rename_all = "camelCase")]
72pub enum ContextInclusion {
73 None,
75 ThisServer,
77 AllServers,
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct SamplingRequest {
84 pub messages: Vec<Message>,
86 pub model_preferences: Option<ModelPreferences>,
88 pub system_prompt: Option<String>,
90 pub include_context: Option<ContextInclusion>,
92 pub temperature: Option<f32>,
94 pub max_tokens: u32,
96 pub stop_sequences: Option<Vec<String>>,
98 pub metadata: Option<HashMap<String, Value>>,
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
104#[serde(rename_all = "camelCase")]
105pub enum StopReason {
106 EndTurn,
108 StopSequence,
110 MaxTokens,
112 Unknown,
114 #[serde(other)]
115 Other,
117}
118
119#[derive(Debug, Clone, Serialize, Deserialize)]
121pub struct SamplingResult {
122 pub model: String,
124 pub stop_reason: Option<StopReason>,
126 pub role: MessageRole,
128 pub content: MessageContent,
130}
131
132pub trait SamplingCallback: Send + Sync {
134 fn call(
136 &self,
137 request: SamplingRequest,
138 ) -> Pin<Box<dyn Future<Output = Result<SamplingResult>> + Send + 'static>>;
139}
140
141impl<F, Fut> SamplingCallback for F
142where
143 F: Fn(SamplingRequest) -> Fut + Send + Sync + 'static,
144 Fut: Future<Output = Result<SamplingResult>> + Send + 'static,
145{
146 fn call(
147 &self,
148 request: SamplingRequest,
149 ) -> Pin<Box<dyn Future<Output = Result<SamplingResult>> + Send + 'static>> {
150 Box::pin(self(request))
151 }
152}
153
154type SamplingFuture = Pin<Box<dyn Future<Output = Result<SamplingResult>> + Send + 'static>>;
156type SamplingCallbackFunc = Arc<dyn Fn(SamplingRequest) -> SamplingFuture + Send + Sync>;
157
158pub(crate) struct RegisteredSampling {
160 #[allow(dead_code)]
162 pub callback: SamplingCallbackFunc,
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request with every optional field except `metadata` populated.
    fn sample_request() -> SamplingRequest {
        SamplingRequest {
            messages: vec![Message {
                role: MessageRole::User,
                content: MessageContent::Text {
                    text: "Hello".to_string(),
                },
            }],
            model_preferences: Some(ModelPreferences {
                hints: Some(vec![ModelHint {
                    name: Some("claude-3".to_string()),
                }]),
                cost_priority: Some(0.5),
                speed_priority: Some(0.8),
                intelligence_priority: Some(0.9),
            }),
            system_prompt: Some("You are a helpful assistant.".to_string()),
            include_context: Some(ContextInclusion::ThisServer),
            temperature: Some(0.7),
            max_tokens: 100,
            stop_sequences: Some(vec!["END".to_string()]),
            metadata: None,
        }
    }

    /// A plain closure must be usable as a `SamplingCallback` and yield its result.
    #[tokio::test]
    async fn test_sampling_request() {
        let request = sample_request();

        let callback = |_req: SamplingRequest| {
            Box::pin(async move {
                Ok(SamplingResult {
                    model: "claude-3".to_string(),
                    stop_reason: Some(StopReason::EndTurn),
                    role: MessageRole::Assistant,
                    content: MessageContent::Text {
                        text: "Hi there!".to_string(),
                    },
                })
            }) as Pin<Box<dyn Future<Output = Result<SamplingResult>> + Send>>
        };

        // Fully qualified call so dispatch goes through the blanket
        // `SamplingCallback` impl rather than invoking the closure directly.
        let result = SamplingCallback::call(&callback, request).await.unwrap();

        assert_eq!(result.model, "claude-3");
        match result.content {
            MessageContent::Text { text } => assert_eq!(text, "Hi there!"),
            MessageContent::Image { .. } => panic!("Expected text content"),
        }
    }
}