modelcontextprotocol_server/
sampling.rs

1// mcp-server/src/sampling.rs
2use anyhow::{anyhow, Result};
3use mcp_protocol::types::sampling::{CreateMessageParams, CreateMessageResult};
4use std::sync::Arc;
5use tokio::sync::Mutex;
6
7/// Callback type for the sampling create message
8pub type CreateMessageCallback = Box<dyn Fn(&CreateMessageParams) -> Result<CreateMessageResult> + Send + Sync>;
9
10/// Sampling manager that handles requests for LLM sampling
11pub struct SamplingManager {
12    create_message_callback: Arc<Mutex<Option<CreateMessageCallback>>>,
13}
14
15impl SamplingManager {
16    /// Create a new sampling manager
17    pub fn new() -> Self {
18        Self {
19            create_message_callback: Arc::new(Mutex::new(None)),
20        }
21    }
22    
23    /// Register a create message callback
24    pub fn register_create_message_callback(&self, callback: CreateMessageCallback) {
25        let mut cb = self.create_message_callback.blocking_lock();
26        *cb = Some(callback);
27    }
28    
29    /// Create a message using the registered callback
30    pub async fn create_message(&self, params: &CreateMessageParams) -> Result<CreateMessageResult> {
31        // Get the callback and invoke it with the lock
32        let cb = self.create_message_callback.lock().await;
33        if cb.is_none() {
34            return Err(anyhow!("No create message callback registered"));
35        }
36        
37        // We can't clone the Box<dyn Fn...>, so we'll invoke it while we have the lock
38        let callback_ref = cb.as_ref().unwrap();
39        callback_ref(params)
40    }
41}
42
43impl Default for SamplingManager {
44    fn default() -> Self {
45        Self::new()
46    }
47}