turbomcp_server/
sampling.rs

//! Server-initiated sampling support for TurboMCP
//!
//! This module provides helper functions for tools to make sampling requests
//! to clients, enabling server-initiated LLM interactions. It also exposes
//! client elicitation (requesting user input) and roots listing.
5
6use crate::{ServerError, ServerResult};
7use turbomcp_protocol::RequestContext;
8use turbomcp_protocol::types::{
9    CreateMessageRequest, CreateMessageResult, ElicitRequest, ElicitResult, ListRootsResult,
10};
11
12/// Extension trait for RequestContext to provide sampling capabilities
13///
14/// Note: We use `async-trait` to address the async fn in trait warning
15#[async_trait::async_trait]
16pub trait SamplingExt {
17    /// Send a sampling/createMessage request to the client
18    async fn create_message(
19        &self,
20        request: CreateMessageRequest,
21    ) -> ServerResult<CreateMessageResult>;
22
23    /// Send an elicitation request to the client for user input
24    async fn elicit(&self, request: ElicitRequest) -> ServerResult<ElicitResult>;
25
26    /// List client's root capabilities
27    async fn list_roots(&self) -> ServerResult<ListRootsResult>;
28}
29
30#[async_trait::async_trait]
31impl SamplingExt for RequestContext {
32    async fn create_message(
33        &self,
34        request: CreateMessageRequest,
35    ) -> ServerResult<CreateMessageResult> {
36        let capabilities = self
37            .server_to_client()
38            .ok_or_else(|| ServerError::Handler {
39                message: "No server capabilities available for sampling requests".to_string(),
40                context: Some("sampling".to_string()),
41            })?;
42
43        // Fully typed - no serialization needed!
44        capabilities
45            .create_message(request, self.clone())
46            .await
47            .map_err(|e| ServerError::Handler {
48                message: format!("Sampling request failed: {}", e),
49                context: Some("sampling".to_string()),
50            })
51    }
52
53    async fn elicit(&self, request: ElicitRequest) -> ServerResult<ElicitResult> {
54        let capabilities = self
55            .server_to_client()
56            .ok_or_else(|| ServerError::Handler {
57                message: "No server capabilities available for elicitation requests".to_string(),
58                context: Some("elicitation".to_string()),
59            })?;
60
61        // Fully typed - no serialization needed!
62        capabilities
63            .elicit(request, self.clone())
64            .await
65            .map_err(|e| ServerError::Handler {
66                message: format!("Elicitation request failed: {}", e),
67                context: Some("elicitation".to_string()),
68            })
69    }
70
71    async fn list_roots(&self) -> ServerResult<ListRootsResult> {
72        let capabilities = self
73            .server_to_client()
74            .ok_or_else(|| ServerError::Handler {
75                message: "No server capabilities available for roots listing".to_string(),
76                context: Some("roots".to_string()),
77            })?;
78
79        // Fully typed - no serialization needed!
80        capabilities
81            .list_roots(self.clone())
82            .await
83            .map_err(|e| ServerError::Handler {
84                message: format!("Roots listing failed: {}", e),
85                context: Some("roots".to_string()),
86            })
87    }
88}