1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
//! MCP Sampling Trait
//!
//! This module defines the high-level trait for implementing MCP sampling.
use async_trait::async_trait;
use turul_mcp_builders::prelude::*;
use turul_mcp_protocol::{
McpResult,
sampling::{CreateMessageRequest, CreateMessageResult},
};
/// Execution-side trait for MCP sampling handlers.
///
/// `McpSampling` builds on `SamplingDefinition`: the definition trait supplies
/// the metadata, and this trait adds the request handler together with
/// optional routing and validation hooks. Sharing the metadata source keeps
/// concrete sampling structs and dynamic implementations consistent.
#[async_trait]
pub trait McpSampling: SamplingDefinition + Send + Sync {
    /// Handle a `sampling/createMessage` request (per the MCP spec).
    ///
    /// Consumes the request and yields the generated message, or an error
    /// through `McpResult`.
    async fn sample(&self, request: CreateMessageRequest) -> McpResult<CreateMessageResult>;

    /// Optional hook: report whether this handler is willing to serve `_request`.
    ///
    /// Override to make sampling conditional on model preferences, context
    /// size, or other factors. The default accepts every request.
    fn can_handle(&self, _request: &CreateMessageRequest) -> bool {
        true
    }

    /// Optional hook: routing priority of this handler.
    ///
    /// When several handlers can serve the same request, handlers with a
    /// higher priority are meant to be tried first. The default is `0`.
    fn priority(&self) -> u32 {
        0
    }

    /// Optional hook: validate a sampling request before it is executed.
    ///
    /// The default implementation only checks `max_tokens`: it must be
    /// greater than zero and must not exceed 1,000,000. Override to perform
    /// additional validation beyond this basic parameter check.
    ///
    /// # Errors
    ///
    /// Returns a validation `McpError` when `max_tokens` is outside that range.
    async fn validate_request(&self, request: &CreateMessageRequest) -> McpResult<()> {
        // Pick the failure message first; any in-range value returns early.
        let problem = match request.params.max_tokens {
            0 => "max_tokens must be greater than 0",
            requested if requested > 1_000_000 => "max_tokens exceeds maximum allowed value",
            _ => return Ok(()),
        };
        Err(turul_mcp_protocol::McpError::validation(problem))
    }
}
/// Build `CreateMessageParams` from an `McpSampling` trait object.
///
/// This is a thin convenience wrapper: it calls `to_create_params()` on the
/// given handler and returns the result unchanged.
///
/// NOTE(review): `to_create_params` is not defined in this file; it is
/// presumably provided by `SamplingDefinition`, a supertrait of
/// `McpSampling` — confirm against that trait.
pub fn sampling_to_params(
sampling: &dyn McpSampling,
) -> turul_mcp_protocol::sampling::CreateMessageParams {
sampling.to_create_params()
}
#[cfg(test)]
mod tests {
    use super::*;
    use turul_mcp_protocol::sampling::SamplingMessage;

    /// Minimal sampling handler used to exercise `McpSampling` and its
    /// default methods.
    struct TestSampling {
        messages: Vec<SamplingMessage>,
        max_tokens: u32,
        temperature: Option<f64>,
    }

    impl TestSampling {
        /// Handler with no messages, no temperature and the given token limit.
        fn with_max_tokens(max_tokens: u32) -> Self {
            Self {
                messages: vec![],
                max_tokens,
                temperature: None,
            }
        }
    }

    // Implement the fine-grained definition traits (HasSamplingConfig,
    // HasSamplingContext, etc.); the remaining ones use their defaults.
    impl HasSamplingConfig for TestSampling {
        fn max_tokens(&self) -> u32 {
            self.max_tokens
        }

        fn temperature(&self) -> Option<f64> {
            self.temperature
        }
    }

    impl HasSamplingContext for TestSampling {
        fn messages(&self) -> &[SamplingMessage] {
            &self.messages
        }
    }

    impl HasModelPreferences for TestSampling {}
    impl HasIcons for TestSampling {}
    impl HasSamplingTools for TestSampling {}

    // SamplingDefinition is expected to come from a blanket impl over the
    // traits above, so it is not implemented by hand here.

    #[async_trait]
    impl McpSampling for TestSampling {
        async fn sample(&self, _request: CreateMessageRequest) -> McpResult<CreateMessageResult> {
            // Simulate message generation with a fixed assistant reply.
            let response_message = SamplingMessage {
                role: turul_mcp_protocol::sampling::Role::Assistant,
                content: turul_mcp_protocol::prompts::ContentBlock::Text {
                    text: "Generated response".to_string(),
                    annotations: None,
                    meta: None,
                },
            };
            Ok(CreateMessageResult::new(
                response_message.role,
                response_message.content,
                "test-model",
            ))
        }
    }

    /// Build a `sampling/createMessage` request from the handler's own parameters.
    fn request_for(sampling: &TestSampling) -> CreateMessageRequest {
        CreateMessageRequest {
            method: "sampling/createMessage".to_string(),
            params: sampling.to_create_params(),
        }
    }

    #[test]
    fn test_sampling_trait() {
        let sampling = TestSampling {
            messages: vec![],
            max_tokens: 100,
            temperature: Some(0.7),
        };
        assert_eq!(sampling.max_tokens(), 100);
        assert_eq!(sampling.temperature(), Some(0.7));
    }

    #[tokio::test]
    async fn test_sampling_validation() {
        // A zero token limit taken from the definition must be rejected.
        let sampling = TestSampling::with_max_tokens(0);
        let request = request_for(&sampling);
        let result = sampling.validate_request(&request).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_validation_rejects_excessive_max_tokens() {
        let sampling = TestSampling::with_max_tokens(100);
        let mut request = request_for(&sampling);
        // Set the field directly so the test does not depend on how the
        // parameters were derived from the definition.
        request.params.max_tokens = 1_000_001;
        assert!(sampling.validate_request(&request).await.is_err());
    }

    #[tokio::test]
    async fn test_validation_accepts_values_within_bounds() {
        let sampling = TestSampling::with_max_tokens(100);
        let mut request = request_for(&sampling);

        // Both ends of the accepted range are inclusive.
        request.params.max_tokens = 1;
        assert!(sampling.validate_request(&request).await.is_ok());

        request.params.max_tokens = 1_000_000;
        assert!(sampling.validate_request(&request).await.is_ok());
    }

    #[test]
    fn test_default_routing_hooks() {
        let sampling = TestSampling::with_max_tokens(100);
        let request = request_for(&sampling);
        assert!(sampling.can_handle(&request));
        assert_eq!(sampling.priority(), 0);
    }

    #[tokio::test]
    async fn test_sample_returns_message() {
        let sampling = TestSampling::with_max_tokens(100);
        let request = request_for(&sampling);
        assert!(sampling.sample(request).await.is_ok());
    }

    #[test]
    fn test_sampling_to_params_matches_definition() {
        let sampling = TestSampling::with_max_tokens(100);
        // The free function must agree with calling the method directly.
        assert_eq!(
            sampling_to_params(&sampling).max_tokens,
            sampling.to_create_params().max_tokens
        );
    }
}