// modelcontextprotocol_server/sampling.rs

use std::sync::{Arc, PoisonError, RwLock};

use anyhow::{anyhow, Result};
use mcp_protocol::types::sampling::{CreateMessageParams, CreateMessageResult};
use tokio::sync::Mutex;

7pub type CreateMessageCallback = Box<dyn Fn(&CreateMessageParams) -> Result<CreateMessageResult> + Send + Sync>;
9
10pub struct SamplingManager {
12 create_message_callback: Arc<Mutex<Option<CreateMessageCallback>>>,
13}
14
15impl SamplingManager {
16 pub fn new() -> Self {
18 Self {
19 create_message_callback: Arc::new(Mutex::new(None)),
20 }
21 }
22
23 pub fn register_create_message_callback(&self, callback: CreateMessageCallback) {
25 let mut cb = self.create_message_callback.blocking_lock();
26 *cb = Some(callback);
27 }
28
29 pub async fn create_message(&self, params: &CreateMessageParams) -> Result<CreateMessageResult> {
31 let cb = self.create_message_callback.lock().await;
33 if cb.is_none() {
34 return Err(anyhow!("No create message callback registered"));
35 }
36
37 let callback_ref = cb.as_ref().unwrap();
39 callback_ref(params)
40 }
41}
42
43impl Default for SamplingManager {
44 fn default() -> Self {
45 Self::new()
46 }
47}