use serde_json::Value;
use std::collections::HashMap;

use turul_mcp_protocol::prompts::ContentBlock;
use turul_mcp_protocol::sampling::{
    CreateMessageParams, CreateMessageRequest, ModelHint, ModelPreferences, Role, SamplingMessage,
};

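/// Fluent builder for constructing `sampling/createMessage` parameters and requests.
///
/// Illustrative usage (mirrors the tests below):
///
/// ```ignore
/// let params = MessageBuilder::new()
///     .max_tokens(2000)
///     .system("You are a helpful assistant.")
///     .user_text("Hello, how are you?")
///     .temperature(0.7)
///     .build_params();
/// ```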
pub struct MessageBuilder {
    messages: Vec<SamplingMessage>,
    model_preferences: Option<ModelPreferences>,
    system_prompt: Option<String>,
    include_context: Option<String>,
    temperature: Option<f64>,
    max_tokens: u32,
    stop_sequences: Option<Vec<String>>,
    metadata: Option<Value>,
    meta: Option<HashMap<String, Value>>,
}

impl MessageBuilder {
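    /// Creates an empty builder with `max_tokens` defaulting to 1000.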
    pub fn new() -> Self {
        Self {
            messages: Vec::new(),
            model_preferences: None,
            system_prompt: None,
            include_context: None,
            temperature: None,
            max_tokens: 1000,
            stop_sequences: None,
            metadata: None,
            meta: None,
        }
    }

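    /// Sets the maximum number of tokens the model may generate.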
    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = max_tokens;
        self
    }

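    /// Appends a pre-built [`SamplingMessage`] to the conversation.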
    pub fn message(mut self, message: SamplingMessage) -> Self {
        self.messages.push(message);
        self
    }

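    /// Appends a system-role text message.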
    pub fn system(mut self, content: impl Into<String>) -> Self {
        self.messages.push(SamplingMessage {
            role: Role::System,
            content: ContentBlock::text(content),
        });
        self
    }

    pub fn user_text(mut self, content: impl Into<String>) -> Self {
        self.messages.push(SamplingMessage {
            role: Role::User,
            content: ContentBlock::text(content),
        });
        self
    }

    pub fn user_image(mut self, data: impl Into<String>, mime_type: impl Into<String>) -> Self {
        self.messages.push(SamplingMessage {
            role: Role::User,
            content: ContentBlock::image(data, mime_type),
        });
        self
    }

    pub fn assistant_text(mut self, content: impl Into<String>) -> Self {
        self.messages.push(SamplingMessage {
            role: Role::Assistant,
            content: ContentBlock::text(content),
        });
        self
    }

    pub fn model_preferences(mut self, preferences: ModelPreferences) -> Self {
        self.model_preferences = Some(preferences);
        self
    }

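    /// Configures model preferences through a closure over [`ModelPreferencesBuilder`].
    ///
    /// ```ignore
    /// let params = MessageBuilder::new()
    ///     .user_text("Test message")
    ///     .with_model_preferences(|prefs| prefs.prefer_claude_sonnet().cost_priority(0.8))
    ///     .build_params();
    /// ```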
    pub fn with_model_preferences<F>(mut self, f: F) -> Self
    where
        F: FnOnce(ModelPreferencesBuilder) -> ModelPreferencesBuilder,
    {
        let builder = ModelPreferencesBuilder::new();
        let preferences = f(builder).build();
        self.model_preferences = Some(preferences);
        self
    }

    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = Some(prompt.into());
        self
    }

    pub fn include_context(mut self, context: impl Into<String>) -> Self {
        self.include_context = Some(context.into());
        self
    }

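    /// Sets the sampling temperature, clamped to the range `0.0..=2.0`.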
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature.clamp(0.0, 2.0));
        self
    }

    pub fn stop_sequences(mut self, sequences: Vec<String>) -> Self {
        self.stop_sequences = Some(sequences);
        self
    }

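    /// Appends a single stop sequence, initializing the list if it does not exist yet.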
    pub fn stop_sequence(mut self, sequence: impl Into<String>) -> Self {
        if let Some(ref mut sequences) = self.stop_sequences {
            sequences.push(sequence.into());
        } else {
            self.stop_sequences = Some(vec![sequence.into()]);
        }
        self
    }

    pub fn metadata(mut self, metadata: Value) -> Self {
        self.metadata = Some(metadata);
        self
    }

    pub fn meta(mut self, meta: HashMap<String, Value>) -> Self {
        self.meta = Some(meta);
        self
    }

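    /// Consumes the builder and produces [`CreateMessageParams`].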
    pub fn build_params(self) -> CreateMessageParams {
        let mut params = CreateMessageParams::new(self.messages, self.max_tokens);

        if let Some(preferences) = self.model_preferences {
            params = params.with_model_preferences(preferences);
        }
        if let Some(prompt) = self.system_prompt {
            params = params.with_system_prompt(prompt);
        }
        if let Some(temp) = self.temperature {
            params = params.with_temperature(temp);
        }
        if let Some(sequences) = self.stop_sequences {
            params = params.with_stop_sequences(sequences);
        }
        if let Some(meta) = self.meta {
            params = params.with_meta(meta);
        }

        params.include_context = self.include_context;
        params.metadata = self.metadata;

        params
    }

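    /// Consumes the builder and wraps the parameters in a complete `sampling/createMessage` request.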
    pub fn build_request(self) -> CreateMessageRequest {
        CreateMessageRequest {
            method: "sampling/createMessage".to_string(),
            params: self.build_params(),
        }
    }
}

impl Default for MessageBuilder {
    fn default() -> Self {
        Self::new()
    }
}

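/// Builder for [`ModelPreferences`]: model hints plus cost, speed, and intelligence priorities.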
pub struct ModelPreferencesBuilder {
    hints: Vec<ModelHint>,
    cost_priority: Option<f64>,
    speed_priority: Option<f64>,
    intelligence_priority: Option<f64>,
}

impl ModelPreferencesBuilder {
    pub fn new() -> Self {
        Self {
            hints: Vec::new(),
            cost_priority: None,
            speed_priority: None,
            intelligence_priority: None,
        }
    }

    pub fn hint(mut self, hint: ModelHint) -> Self {
        self.hints.push(hint);
        self
    }

    pub fn prefer_claude_sonnet(self) -> Self {
        self.hint(ModelHint::Claude35Sonnet20241022)
    }

    pub fn prefer_claude_haiku(self) -> Self {
        self.hint(ModelHint::Claude35Haiku20241022)
    }

    pub fn prefer_gpt4o(self) -> Self {
        self.hint(ModelHint::Gpt4o)
    }

    pub fn prefer_gpt4o_mini(self) -> Self {
        self.hint(ModelHint::Gpt4oMini)
    }

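    /// Hints toward fast, low-cost models (Claude 3.5 Haiku, then GPT-4o mini).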
    pub fn prefer_fast(self) -> Self {
        self.hint(ModelHint::Claude35Haiku20241022)
            .hint(ModelHint::Gpt4oMini)
    }

    pub fn prefer_quality(self) -> Self {
        self.hint(ModelHint::Claude35Sonnet20241022)
            .hint(ModelHint::Gpt4o)
    }

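    /// Sets the cost priority, clamped to the range `0.0..=1.0`.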
    pub fn cost_priority(mut self, priority: f64) -> Self {
        self.cost_priority = Some(priority.clamp(0.0, 1.0));
        self
    }

    pub fn speed_priority(mut self, priority: f64) -> Self {
        self.speed_priority = Some(priority.clamp(0.0, 1.0));
        self
    }

    pub fn intelligence_priority(mut self, priority: f64) -> Self {
        self.intelligence_priority = Some(priority.clamp(0.0, 1.0));
        self
    }

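    /// Builds the [`ModelPreferences`], omitting `hints` when none were added.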
    pub fn build(self) -> ModelPreferences {
        ModelPreferences {
            hints: if self.hints.is_empty() {
                None
            } else {
                Some(self.hints)
            },
            cost_priority: self.cost_priority,
            speed_priority: self.speed_priority,
            intelligence_priority: self.intelligence_priority,
        }
    }
}

impl Default for ModelPreferencesBuilder {
    fn default() -> Self {
        Self::new()
    }
}

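/// Convenience constructors for [`SamplingMessage`] values.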
pub trait SamplingMessageExt {
    fn system(content: impl Into<String>) -> SamplingMessage;
    fn user_text(content: impl Into<String>) -> SamplingMessage;
    fn user_image(data: impl Into<String>, mime_type: impl Into<String>) -> SamplingMessage;
    fn assistant_text(content: impl Into<String>) -> SamplingMessage;
}

impl SamplingMessageExt for SamplingMessage {
    fn system(content: impl Into<String>) -> Self {
        Self {
            role: Role::System,
            content: ContentBlock::text(content),
        }
    }

    fn user_text(content: impl Into<String>) -> Self {
        Self {
            role: Role::User,
            content: ContentBlock::text(content),
        }
    }

    fn user_image(data: impl Into<String>, mime_type: impl Into<String>) -> Self {
        Self {
            role: Role::User,
            content: ContentBlock::image(data, mime_type),
        }
    }

    fn assistant_text(content: impl Into<String>) -> Self {
        Self {
            role: Role::Assistant,
            content: ContentBlock::text(content),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_message_builder_basic() {
        let params = MessageBuilder::new()
            .max_tokens(2000)
            .system("You are a helpful assistant.")
            .user_text("Hello, how are you?")
            .assistant_text("I'm doing well, thank you!")
            .temperature(0.7)
            .build_params();

        assert_eq!(params.messages.len(), 3);
        assert_eq!(params.max_tokens, 2000);
        assert_eq!(params.temperature, Some(0.7));

        assert_eq!(params.messages[0].role, Role::System);
        if let ContentBlock::Text { text, .. } = &params.messages[0].content {
            assert_eq!(text, "You are a helpful assistant.");
        } else {
            panic!("Expected text content");
        }
    }

    #[test]
    fn test_message_builder_model_preferences() {
        let params = MessageBuilder::new()
            .user_text("Test message")
            .with_model_preferences(|prefs| {
                prefs
                    .prefer_claude_sonnet()
                    .cost_priority(0.8)
                    .speed_priority(0.6)
                    .intelligence_priority(0.9)
            })
            .build_params();

        let preferences = params
            .model_preferences
            .expect("Expected model preferences");
        assert_eq!(preferences.hints.as_ref().unwrap().len(), 1);
        assert_eq!(
            preferences.hints.as_ref().unwrap()[0],
            ModelHint::Claude35Sonnet20241022
        );
        assert_eq!(preferences.cost_priority, Some(0.8));
        assert_eq!(preferences.speed_priority, Some(0.6));
        assert_eq!(preferences.intelligence_priority, Some(0.9));
    }

    #[test]
    fn test_message_builder_stop_sequences() {
        let params = MessageBuilder::new()
            .user_text("Generate some code")
            .stop_sequence("```")
            .stop_sequence("\n\n")
            .build_params();

        let sequences = params.stop_sequences.expect("Expected stop sequences");
        assert_eq!(sequences.len(), 2);
        assert_eq!(sequences[0], "```");
        assert_eq!(sequences[1], "\n\n");
    }

    #[test]
    fn test_message_builder_complete_request() {
        let request = MessageBuilder::new()
            .system_prompt("You are a coding assistant")
            .user_text("Write a function to calculate fibonacci numbers")
            .temperature(0.3)
            .max_tokens(500)
            .metadata(json!({"request_id": "12345"}))
            .build_request();

        assert_eq!(request.method, "sampling/createMessage");
        assert_eq!(request.params.max_tokens, 500);
        assert_eq!(request.params.temperature, Some(0.3));
        assert_eq!(
            request.params.system_prompt,
            Some("You are a coding assistant".to_string())
        );
        assert!(request.params.metadata.is_some());
    }

    #[test]
    fn test_model_preferences_builder() {
        let preferences = ModelPreferencesBuilder::new()
            .prefer_fast()
            .cost_priority(0.9)
            .speed_priority(0.8)
            .build();

        let hints = preferences.hints.expect("Expected hints");
        assert_eq!(hints.len(), 2);
        assert!(hints.contains(&ModelHint::Claude35Haiku20241022));
        assert!(hints.contains(&ModelHint::Gpt4oMini));
        assert_eq!(preferences.cost_priority, Some(0.9));
        assert_eq!(preferences.speed_priority, Some(0.8));
    }

    #[test]
    fn test_sampling_message_convenience_methods() {
        let system_msg = SamplingMessage::system("System prompt");
        assert_eq!(system_msg.role, Role::System);

        let user_msg = SamplingMessage::user_text("User input");
        assert_eq!(user_msg.role, Role::User);

        let assistant_msg = SamplingMessage::assistant_text("Assistant response");
        assert_eq!(assistant_msg.role, Role::Assistant);

        let image_msg = SamplingMessage::user_image("base64data", "image/png");
        assert_eq!(image_msg.role, Role::User);
        if let ContentBlock::Image {
            data, mime_type, ..
        } = &image_msg.content
        {
            assert_eq!(data, "base64data");
            assert_eq!(mime_type, "image/png");
        } else {
            panic!("Expected image content");
        }
    }

    #[test]
    fn test_temperature_clamping() {
        let params = MessageBuilder::new()
            .user_text("Test")
            .temperature(5.0)
            .build_params();

        assert_eq!(params.temperature, Some(2.0));

        let params2 = MessageBuilder::new()
            .user_text("Test")
            .temperature(-1.0)
            .build_params();

        assert_eq!(params2.temperature, Some(0.0));
    }
}