//! Sampling types and callback plumbing for the MCP server
//! (mcp_daemon/server/sampling/mod.rs).

use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::server::error::ServerError;

/// Module-local result type: every fallible sampling operation fails with `ServerError`.
type Result<T> = std::result::Result<T, ServerError>;
11
/// Message role in a sampling conversation.
///
/// Serialized in lowercase (`"user"` / `"assistant"`) on the wire.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    /// Role representing the user in a conversation
    User,
    /// Role representing the AI assistant in a conversation
    Assistant,
}
21
22/// Content type for a message
23#[derive(Debug, Clone, Serialize, Deserialize)]
24#[serde(tag = "type", rename_all = "lowercase")]
25pub enum MessageContent {
26    /// Text content with a string payload
27    Text {
28        /// The text content
29        text: String
30    },
31    /// Image content with base64-encoded data and optional MIME type
32    Image {
33        /// The image data (e.g., base64 encoded)
34        data: String,
35        /// The optional MIME type of the image
36        mime_type: Option<String>
37    },
38}
39
/// A message in a sampling conversation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// The role of the message sender (user or assistant)
    pub role: MessageRole,
    /// The content of the message (text or image)
    pub content: MessageContent,
}
48
49/// Model selection preferences
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct ModelPreferences {
52    /// Optional hints for model selection
53    pub hints: Option<Vec<ModelHint>>,
54    /// Priority for cost optimization (0.0 to 1.0)
55    pub cost_priority: Option<f32>,
56    /// Priority for speed optimization (0.0 to 1.0)
57    pub speed_priority: Option<f32>,
58    /// Priority for intelligence/quality optimization (0.0 to 1.0)
59    pub intelligence_priority: Option<f32>,
60}
61
/// A hint for model selection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelHint {
    /// Optional specific model name to use (e.g. a model family or exact id)
    pub name: Option<String>,
}
68
/// How much MCP server context to include with a sampling request.
///
/// Serialized in camelCase (`"none"` / `"thisServer"` / `"allServers"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum ContextInclusion {
    /// Include no server context
    None,
    /// Include context only from this server
    ThisServer,
    /// Include context from all connected servers
    AllServers,
}
80
81/// Parameters for a sampling request
82#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct SamplingRequest {
84    /// List of messages in the conversation
85    pub messages: Vec<Message>,
86    /// Optional model preferences for sampling
87    pub model_preferences: Option<ModelPreferences>,
88    /// Optional system prompt
89    pub system_prompt: Option<String>,
90    /// Optional context inclusion settings
91    pub include_context: Option<ContextInclusion>,
92    /// Optional temperature parameter for sampling
93    pub temperature: Option<f32>,
94    /// Maximum number of tokens to generate
95    pub max_tokens: u32,
96    /// Optional sequences that will stop generation
97    pub stop_sequences: Option<Vec<String>>,
98    /// Optional metadata for the sampling request
99    pub metadata: Option<HashMap<String, Value>>,
100}
101
102/// Stop reason for a sampling completion
103#[derive(Debug, Clone, Serialize, Deserialize)]
104#[serde(rename_all = "camelCase")]
105pub enum StopReason {
106    /// Generation stopped due to end of turn
107    EndTurn,
108    /// Generation stopped due to stop sequence
109    StopSequence,
110    /// Generation stopped due to max tokens
111    MaxTokens,
112    /// Generation stopped for unknown reason
113    Unknown,
114    #[serde(other)]
115    /// Generation stopped for other reason
116    Other,
117}
118
119/// Result of a sampling request
120#[derive(Debug, Clone, Serialize, Deserialize)]
121pub struct SamplingResult {
122    /// The model to use for sampling
123    pub model: String,
124    /// The reason why generation stopped
125    pub stop_reason: Option<StopReason>,
126    /// The role of the message sender (user or assistant)
127    pub role: MessageRole,
128    /// The content of the message (text or image)
129    pub content: MessageContent,
130}
131
/// A callback that can handle sampling requests.
///
/// The method returns a boxed future so the trait stays object-safe
/// (usable as `dyn SamplingCallback`); a blanket impl covers any suitable
/// `Fn(SamplingRequest) -> impl Future` closure.
pub trait SamplingCallback: Send + Sync {
    /// Calls the sampling function with the given request, yielding the
    /// sampling result or a `ServerError`.
    fn call(
        &self,
        request: SamplingRequest,
    ) -> Pin<Box<dyn Future<Output = Result<SamplingResult>> + Send + 'static>>;
}
140
/// Blanket implementation: any `Send + Sync` closure or function from
/// `SamplingRequest` to a `Send` future of `Result<SamplingResult>` is
/// automatically a `SamplingCallback`.
impl<F, Fut> SamplingCallback for F
where
    F: Fn(SamplingRequest) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<SamplingResult>> + Send + 'static,
{
    fn call(
        &self,
        request: SamplingRequest,
    ) -> Pin<Box<dyn Future<Output = Result<SamplingResult>> + Send + 'static>> {
        // Box::pin erases the concrete future type so it can be returned
        // through the object-safe trait signature.
        Box::pin(self(request))
    }
}
153
// Type aliases for complex future and callback types
/// Boxed future produced by a sampling callback.
type SamplingFuture = Pin<Box<dyn Future<Output = Result<SamplingResult>> + Send + 'static>>;
/// Shared, type-erased sampling callback.
type SamplingCallbackFunc = Arc<dyn Fn(SamplingRequest) -> SamplingFuture + Send + Sync>;

/// A registered sampling handler.
pub(crate) struct RegisteredSampling {
    /// The callback to handle sampling requests
    // NOTE(review): not invoked anywhere visible in this module (hence the
    // dead_code allow) — presumably called from elsewhere in the crate; verify.
    #[allow(dead_code)]
    pub callback: SamplingCallbackFunc,
}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end smoke test: build a full `SamplingRequest`, feed it through
    /// a closure-based callback, and check the synthesized assistant reply.
    #[tokio::test]
    async fn test_sampling_request() {
        let user_message = Message {
            role: MessageRole::User,
            content: MessageContent::Text {
                text: "Hello".to_string(),
            },
        };
        let preferences = ModelPreferences {
            hints: Some(vec![ModelHint {
                name: Some("claude-3".to_string()),
            }]),
            cost_priority: Some(0.5),
            speed_priority: Some(0.8),
            intelligence_priority: Some(0.9),
        };
        let request = SamplingRequest {
            messages: vec![user_message],
            model_preferences: Some(preferences),
            system_prompt: Some("You are a helpful assistant.".to_string()),
            include_context: Some(ContextInclusion::ThisServer),
            temperature: Some(0.7),
            max_tokens: 100,
            stop_sequences: Some(vec!["END".to_string()]),
            metadata: None,
        };

        // A plain closure works as a sampling handler; the cast pins down the
        // boxed-future type the callback machinery expects.
        let handler = |_req: SamplingRequest| {
            Box::pin(async move {
                Ok(SamplingResult {
                    model: "claude-3".to_string(),
                    stop_reason: Some(StopReason::EndTurn),
                    role: MessageRole::Assistant,
                    content: MessageContent::Text {
                        text: "Hi there!".to_string(),
                    },
                })
            }) as Pin<Box<dyn Future<Output = Result<SamplingResult>> + Send>>
        };

        let response = handler(request).await.unwrap();
        assert_eq!(response.model, "claude-3");
        match response.content {
            MessageContent::Text { text } => assert_eq!(text, "Hi there!"),
            _ => panic!("Expected text content"),
        }
    }
}
215}