1use std::sync::Arc;
7
8use async_trait::async_trait;
9use awaken_runtime_contract::AgentSpec;
10use awaken_runtime_contract::contract::content::ContentBlock;
11use awaken_runtime_contract::contract::executor::{InferenceRequest, LlmExecutor};
12use awaken_runtime_contract::contract::message::Message;
13use mcp::transport::McpTransportError;
14use mcp::{CreateMessageParams, CreateMessageResult, SamplingContent};
15
#[async_trait]
/// Answers MCP sampling (`createMessage`) requests that a server sends to
/// this client.
///
/// Implementations must be `Send + Sync` so one handler can be shared behind
/// an `Arc` (see `SamplingHandlerFactory::for_agent`).
pub trait SamplingHandler: Send + Sync {
    /// Produces the reply for a single sampling request.
    ///
    /// # Errors
    ///
    /// Failures are reported as `McpTransportError`.
    async fn handle_create_message(
        &self,
        params: CreateMessageParams,
    ) -> Result<CreateMessageResult, McpTransportError>;
}
27
#[async_trait]
/// Chooses which [`SamplingHandler`] (if any) serves sampling requests for a
/// given agent.
pub trait SamplingHandlerFactory: Send + Sync {
    /// Returns the handler for `agent_spec`.
    ///
    /// NOTE(review): `None` presumably means the agent gets no sampling
    /// support — confirm against callers.
    async fn for_agent(&self, agent_spec: &AgentSpec) -> Option<Arc<dyn SamplingHandler>>;
}
58
59pub struct FixedSamplingHandlerFactory {
66 handler: Arc<dyn SamplingHandler>,
67}
68
69impl FixedSamplingHandlerFactory {
70 pub fn new(handler: Arc<dyn SamplingHandler>) -> Self {
71 Self { handler }
72 }
73}
74
75#[async_trait]
76impl SamplingHandlerFactory for FixedSamplingHandlerFactory {
77 async fn for_agent(&self, _agent_spec: &AgentSpec) -> Option<Arc<dyn SamplingHandler>> {
78 Some(self.handler.clone())
79 }
80}
81
/// [`SamplingHandler`] that runs sampling requests through an
/// [`LlmExecutor`] against one fixed upstream model.
pub struct DefaultSamplingHandler {
    /// Executor every sampling request is forwarded to.
    executor: Arc<dyn LlmExecutor>,
    /// Model name placed on each `InferenceRequest` and echoed back as the
    /// `model` of the sampling result.
    upstream_model: String,
}
89
90impl DefaultSamplingHandler {
91 pub fn new(executor: Arc<dyn LlmExecutor>, upstream_model: impl Into<String>) -> Self {
95 Self {
96 executor,
97 upstream_model: upstream_model.into(),
98 }
99 }
100
101 fn convert_messages(params: &CreateMessageParams) -> Result<Vec<Message>, McpTransportError> {
122 let mut out = Vec::with_capacity(params.messages.len());
123 for msg in ¶ms.messages {
124 let mut text_parts: Vec<&str> = Vec::with_capacity(msg.content.len());
125 for block in &msg.content {
126 match block {
127 SamplingContent::Text { text: t, .. } => text_parts.push(t.as_str()),
128 other => {
129 return Err(McpTransportError::TransportError(format!(
130 "sampling request contains unsupported content kind: {} \
131 (awaken's sampling handler only supports text — server should \
132 retry with a text-only message)",
133 sampling_content_kind(other)
134 )));
135 }
136 }
137 }
138 let joined = text_parts.join("\n\n");
139 out.push(match msg.role {
140 mcp::Role::User => Message::user(joined),
141 mcp::Role::Assistant => Message::assistant(joined),
142 });
143 }
144 Ok(out)
145 }
146
147 fn system_blocks(params: &CreateMessageParams) -> Vec<ContentBlock> {
149 match ¶ms.system_prompt {
150 Some(prompt) if !prompt.is_empty() => vec![ContentBlock::text(prompt.clone())],
151 _ => vec![],
152 }
153 }
154
155 fn convert_result(
157 result: &awaken_runtime_contract::contract::inference::StreamResult,
158 model: &str,
159 ) -> CreateMessageResult {
160 let text = result.text();
161 let content = vec![SamplingContent::Text {
162 text,
163 annotations: None,
164 meta: None,
165 }];
166
167 let stop_reason = result.stop_reason.map(|sr| match sr {
168 awaken_runtime_contract::contract::inference::StopReason::EndTurn => {
169 "endTurn".to_string()
170 }
171 awaken_runtime_contract::contract::inference::StopReason::MaxTokens => {
172 "maxTokens".to_string()
173 }
174 awaken_runtime_contract::contract::inference::StopReason::ToolUse => {
175 "toolUse".to_string()
176 }
177 awaken_runtime_contract::contract::inference::StopReason::StopSequence => {
178 "stopSequence".to_string()
179 }
180 });
181
182 CreateMessageResult {
183 role: mcp::Role::Assistant,
184 content,
185 model: model.to_string(),
186 stop_reason,
187 meta: None,
188 }
189 }
190}
191
192fn sampling_content_kind(content: &SamplingContent) -> &'static str {
196 match content {
197 SamplingContent::Text { .. } => "text",
198 SamplingContent::Image { .. } => "image",
199 SamplingContent::Audio { .. } => "audio",
200 SamplingContent::ToolUse { .. } => "tool_use",
201 SamplingContent::ToolResult { .. } => "tool_result",
202 }
203}
204
205fn reject_unsupported_sampling_fields(
223 params: &CreateMessageParams,
224) -> Result<(), McpTransportError> {
225 let mut unsupported: Vec<&'static str> = Vec::new();
226 if params
227 .stop_sequences
228 .as_ref()
229 .is_some_and(|s| !s.is_empty())
230 {
231 unsupported.push("stopSequences");
232 }
233 if params.include_context.is_some() {
234 unsupported.push("includeContext");
235 }
236 if params.tools.as_ref().is_some_and(|t| !t.is_empty()) {
237 unsupported.push("tools");
238 }
239 if params.tool_choice.is_some() {
240 unsupported.push("toolChoice");
241 }
242 if !unsupported.is_empty() {
243 return Err(McpTransportError::TransportError(format!(
244 "sampling request sets unsupported field(s): {} \
245 (awaken's DefaultSamplingHandler maps systemPrompt, \
246 temperature, maxTokens only; honouring others silently \
247 would change the LLM's reply away from what the server \
248 requested)",
249 unsupported.join(", ")
250 )));
251 }
252 Ok(())
253}
254
255#[async_trait]
256impl SamplingHandler for DefaultSamplingHandler {
257 async fn handle_create_message(
258 &self,
259 params: CreateMessageParams,
260 ) -> Result<CreateMessageResult, McpTransportError> {
261 reject_unsupported_sampling_fields(¶ms)?;
264
265 let messages = Self::convert_messages(¶ms)?;
266 if messages.is_empty() {
267 return Err(McpTransportError::TransportError(
268 "sampling request contained no messages".to_string(),
269 ));
270 }
271
272 let system = Self::system_blocks(¶ms);
273
274 let overrides = {
275 let mut ovr =
276 awaken_runtime_contract::contract::inference::InferenceOverride::default();
277 if let Some(temp) = params.temperature {
278 ovr.temperature = Some(temp);
279 }
280 ovr.max_tokens = Some(params.max_tokens);
281 if ovr.temperature.is_none() && ovr.max_tokens.is_none() {
282 None
283 } else {
284 Some(ovr)
285 }
286 };
287
288 let request = InferenceRequest {
289 upstream_model: self.upstream_model.clone(),
290 routing_key: None,
291 messages,
292 tools: vec![],
293 system,
294 overrides,
295 enable_prompt_cache: false,
296 };
297
298 let result =
299 self.executor.execute(request).await.map_err(|e| {
300 McpTransportError::TransportError(format!("LLM execution failed: {e}"))
301 })?;
302
303 Ok(Self::convert_result(&result, &self.upstream_model))
304 }
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use awaken_runtime_contract::contract::inference::{StopReason, StreamResult, TokenUsage};
    use awaken_runtime_contract::contract::message::Role;
    use mcp::SamplingMessage;

    /// Executor double: always replies with `response_text`, a fixed token
    /// usage and an `EndTurn` stop reason.
    struct MockLlm {
        response_text: String,
    }

    #[async_trait]
    impl LlmExecutor for MockLlm {
        async fn execute(
            &self,
            _request: InferenceRequest,
        ) -> Result<
            StreamResult,
            awaken_runtime_contract::contract::executor::InferenceExecutionError,
        > {
            let usage = TokenUsage {
                prompt_tokens: Some(10),
                completion_tokens: Some(5),
                total_tokens: Some(15),
                ..Default::default()
            };
            Ok(StreamResult {
                content: vec![ContentBlock::text(self.response_text.clone())],
                tool_calls: vec![],
                usage: Some(usage),
                stop_reason: Some(StopReason::EndTurn),
                has_incomplete_tool_calls: false,
            })
        }

        fn name(&self) -> &str {
            "mock"
        }
    }

    /// A plain text content block with no annotations or metadata.
    fn text_block(text: &str) -> SamplingContent {
        SamplingContent::Text {
            text: text.to_string(),
            annotations: None,
            meta: None,
        }
    }

    /// A sampling message from `role` carrying `content`.
    fn sampling_message(role: mcp::Role, content: Vec<SamplingContent>) -> SamplingMessage {
        SamplingMessage {
            role,
            content,
            meta: None,
        }
    }

    /// A request with `messages`, `maxTokens` 1024 and every optional field unset.
    fn params_with(messages: Vec<SamplingMessage>) -> CreateMessageParams {
        CreateMessageParams {
            messages,
            model_preferences: None,
            system_prompt: None,
            include_context: None,
            temperature: None,
            max_tokens: 1024,
            stop_sequences: None,
            metadata: None,
            tools: None,
            tool_choice: None,
            task: None,
            meta: None,
        }
    }

    /// A request holding one user message with a single text block.
    fn make_params(text: &str) -> CreateMessageParams {
        params_with(vec![sampling_message(
            mcp::Role::User,
            vec![text_block(text)],
        )])
    }

    /// A handler for `model` backed by a `MockLlm` that replies `reply`.
    fn mock_handler(reply: &str, model: &str) -> DefaultSamplingHandler {
        let executor = Arc::new(MockLlm {
            response_text: reply.to_string(),
        });
        DefaultSamplingHandler::new(executor, model)
    }

    #[test]
    fn convert_messages_maps_roles() {
        let params = params_with(vec![
            sampling_message(mcp::Role::User, vec![text_block("hello")]),
            sampling_message(mcp::Role::Assistant, vec![text_block("hi there")]),
        ]);
        let msgs =
            DefaultSamplingHandler::convert_messages(&params).expect("text-only converts cleanly");
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0].role, Role::User);
        assert_eq!(msgs[0].text(), "hello");
        assert_eq!(msgs[1].role, Role::Assistant);
        assert_eq!(msgs[1].text(), "hi there");
    }

    #[test]
    fn convert_messages_rejects_image_content() {
        // Mixed text + image: the image must fail the whole conversion
        // rather than being silently dropped.
        let params = params_with(vec![sampling_message(
            mcp::Role::User,
            vec![
                text_block("describe this:"),
                SamplingContent::Image {
                    data: "base64-blob".into(),
                    mime_type: "image/png".into(),
                    annotations: None,
                    meta: None,
                },
            ],
        )]);
        let err =
            DefaultSamplingHandler::convert_messages(&params).expect_err("image must be rejected");
        let msg = format!("{err}");
        assert!(
            msg.contains("image"),
            "error should identify the offending content kind, got: {msg}"
        );
    }

    #[test]
    fn convert_messages_rejects_audio_content() {
        let params = params_with(vec![sampling_message(
            mcp::Role::User,
            vec![SamplingContent::Audio {
                data: "base64-blob".into(),
                mime_type: "audio/wav".into(),
                annotations: None,
                meta: None,
            }],
        )]);
        let err =
            DefaultSamplingHandler::convert_messages(&params).expect_err("audio must be rejected");
        assert!(format!("{err}").contains("audio"));
    }

    #[test]
    fn sampling_content_kind_names_each_variant() {
        assert_eq!(sampling_content_kind(&text_block("x")), "text");
        assert_eq!(
            sampling_content_kind(&SamplingContent::Image {
                data: "x".into(),
                mime_type: "image/png".into(),
                annotations: None,
                meta: None,
            }),
            "image"
        );
        assert_eq!(
            sampling_content_kind(&SamplingContent::Audio {
                data: "x".into(),
                mime_type: "audio/wav".into(),
                annotations: None,
                meta: None,
            }),
            "audio"
        );
    }

    #[test]
    fn system_blocks_from_params() {
        let mut params = make_params("test");
        assert!(DefaultSamplingHandler::system_blocks(&params).is_empty());

        params.system_prompt = Some("Be helpful".into());
        let blocks = DefaultSamplingHandler::system_blocks(&params);
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::Text { text } => assert_eq!(text, "Be helpful"),
            _ => panic!("expected text block"),
        }
    }

    #[test]
    fn convert_result_maps_stop_reasons() {
        let result = StreamResult {
            content: vec![ContentBlock::text("response")],
            tool_calls: vec![],
            usage: None,
            stop_reason: Some(StopReason::EndTurn),
            has_incomplete_tool_calls: false,
        };
        let mcp_result = DefaultSamplingHandler::convert_result(&result, "test-model");
        assert_eq!(mcp_result.model, "test-model");
        assert_eq!(mcp_result.stop_reason.as_deref(), Some("endTurn"));
        assert!(matches!(mcp_result.role, mcp::Role::Assistant));
        assert_eq!(mcp_result.content.len(), 1);
    }

    #[test]
    fn convert_messages_joins_multi_text_with_blank_line() {
        let params = params_with(vec![sampling_message(
            mcp::Role::User,
            vec![text_block("hello"), text_block("world")],
        )]);
        let msgs = DefaultSamplingHandler::convert_messages(&params).unwrap();
        assert_eq!(msgs.len(), 1);
        assert_eq!(msgs[0].text(), "hello\n\nworld");
    }

    #[tokio::test]
    async fn handle_create_message_rejects_stop_sequences() {
        let handler = mock_handler("ignored", "m");
        let mut params = make_params("hi");
        params.stop_sequences = Some(vec!["STOP".into()]);
        let err = handler
            .handle_create_message(params)
            .await
            .expect_err("stopSequences must be rejected");
        let msg = format!("{err}");
        assert!(msg.contains("stopSequences"), "got: {msg}");
    }

    #[tokio::test]
    async fn handle_create_message_rejects_tool_choice() {
        let handler = mock_handler("ignored", "m");
        let mut params = make_params("hi");
        params.tool_choice = Some(mcp::ToolChoice {
            mode: Some(mcp::ToolChoiceMode::Required),
        });
        let err = handler
            .handle_create_message(params)
            .await
            .expect_err("toolChoice must be rejected");
        assert!(format!("{err}").contains("toolChoice"));
    }

    #[tokio::test]
    async fn handle_create_message_rejects_include_context() {
        let handler = mock_handler("ignored", "m");
        let mut params = make_params("hi");
        params.include_context = Some("thisServer".into());
        let err = handler
            .handle_create_message(params)
            .await
            .expect_err("must reject");
        let msg = format!("{err}");
        assert!(msg.contains("includeContext"), "got: {msg}");
        assert!(!msg.contains("modelPreferences"), "got: {msg}");
    }

    #[tokio::test]
    async fn default_sampling_handler_ignores_model_preferences() {
        let handler = mock_handler("ok", "configured-model");
        let mut params = make_params("hi");
        params.model_preferences = Some(mcp::ModelPreferences {
            hints: None,
            cost_priority: None,
            speed_priority: None,
            intelligence_priority: None,
        });

        let result = handler
            .handle_create_message(params)
            .await
            .expect("modelPreferences are advisory and should not fail basic sampling");

        assert_eq!(result.model, "configured-model");
    }

    #[tokio::test]
    async fn default_sampling_handler_routes_to_executor() {
        let handler = mock_handler("I can help!", "test-model");

        let result = handler
            .handle_create_message(make_params("help me"))
            .await
            .unwrap();

        assert_eq!(result.model, "test-model");
        assert!(matches!(result.role, mcp::Role::Assistant));
        match &result.content[0] {
            SamplingContent::Text { text, .. } => assert_eq!(text, "I can help!"),
            _ => panic!("expected text content"),
        }
        assert_eq!(result.stop_reason.as_deref(), Some("endTurn"));
    }

    #[tokio::test]
    async fn default_sampling_handler_empty_messages_returns_error() {
        let handler = mock_handler("", "test-model");

        let err = handler.handle_create_message(params_with(vec![])).await;
        assert!(err.is_err());
    }

    #[tokio::test]
    async fn fixed_factory_returns_same_handler_regardless_of_agent() {
        // Two agents with different models must still resolve to the very
        // same handler instance.
        let handler: Arc<dyn SamplingHandler> =
            Arc::new(mock_handler("shared", "shared-model"));
        let factory = FixedSamplingHandlerFactory::new(Arc::clone(&handler));

        let spec_a = AgentSpec {
            id: "a".into(),
            model_id: "claude-opus".into(),
            system_prompt: "".into(),
            ..Default::default()
        };
        let spec_b = AgentSpec {
            id: "b".into(),
            model_id: "gpt-5".into(),
            system_prompt: "".into(),
            ..Default::default()
        };

        let resolved_a = factory.for_agent(&spec_a).await.expect("Some handler");
        let resolved_b = factory.for_agent(&spec_b).await.expect("Some handler");
        assert!(Arc::ptr_eq(&resolved_a, &handler));
        assert!(Arc::ptr_eq(&resolved_b, &handler));
    }

    #[tokio::test]
    async fn default_sampling_handler_passes_overrides() {
        let handler = mock_handler("ok", "model-v1");

        let mut params = make_params("test");
        params.temperature = Some(0.7);
        params.max_tokens = 512;
        params.system_prompt = Some("System".into());

        let result = handler.handle_create_message(params).await.unwrap();
        assert_eq!(result.model, "model-v1");
    }
}