Skip to main content

awaken_ext_mcp/
sampling.rs

1//! Sampling handler for routing MCP `sampling/createMessage` requests to an LLM.
2//!
3//! Provides the [`SamplingHandler`] trait and a [`DefaultSamplingHandler`]
4//! that bridges MCP sampling requests to an awaken [`LlmExecutor`].
5
6use std::sync::Arc;
7
8use async_trait::async_trait;
9use awaken_runtime_contract::AgentSpec;
10use awaken_runtime_contract::contract::content::ContentBlock;
11use awaken_runtime_contract::contract::executor::{InferenceRequest, LlmExecutor};
12use awaken_runtime_contract::contract::message::Message;
13use mcp::transport::McpTransportError;
14use mcp::{CreateMessageParams, CreateMessageResult, SamplingContent};
15
16/// Handler for MCP `sampling/createMessage` requests from the server.
17///
18/// When an MCP server sends a `sampling/createMessage` request during tool
19/// execution, this handler is invoked to route it to an LLM for inference.
20#[async_trait]
21pub trait SamplingHandler: Send + Sync {
22    async fn handle_create_message(
23        &self,
24        params: CreateMessageParams,
25    ) -> Result<CreateMessageResult, McpTransportError>;
26}
27
28/// Factory that constructs a per-call [`SamplingHandler`] given the agent
29/// initiating an MCP tool call. Lets awaken route server-initiated
30/// `sampling/createMessage` requests to the **calling agent's** LLM
31/// executor — different agents using different models will see their own
32/// LLM respond to MCP sampling, instead of all sharing one fixed handler
33/// at the registry level (the previous design leak documented in
34/// `awaken-ext-mcp` audit).
35///
36/// `for_agent` returns:
37/// - `Some(handler)` — bind this agent's call to that handler;
38///   `sampling/createMessage` during the call routes there.
39/// - `None` — the factory **explicitly refuses** to bind this agent
40///   (e.g. `agent.model_id` doesn't resolve, agent opted out, tenant
41///   has no sampling quota). The manager maps this to
42///   `McpCallSampling::Denied`; the transport then rejects the
43///   server-initiated `sampling/createMessage` with JSON-RPC
44///   method-not-supported. **It does NOT fall back to the registry-
45///   level fixed handler** — falling through would re-introduce the
46///   cross-agent leak this factory exists to prevent.
47///
48/// The "no factory configured at all" case is a separate state
49/// (`McpCallSampling::Inherit`) and is the only path that falls back
50/// to the transport-level fixed handler. See [`McpCallSampling`] in
51/// `awaken_ext_mcp::transport` for the full three-state semantics.
52///
53/// [`McpCallSampling`]: crate::transport::McpCallSampling
54#[async_trait]
55pub trait SamplingHandlerFactory: Send + Sync {
56    async fn for_agent(&self, agent_spec: &AgentSpec) -> Option<Arc<dyn SamplingHandler>>;
57}
58
59/// Trivial factory that ignores the agent and always returns the same
60/// handler. Preserves the pre-R1 behaviour of "one fixed handler for all
61/// agents". New runtime wiring should provide a registry-driven factory
62/// that resolves the agent's `model_id` → provider → `LlmExecutor` and
63/// wraps it in [`DefaultSamplingHandler`] — that's the per-agent fix the
64/// MCP audit identified.
65pub struct FixedSamplingHandlerFactory {
66    handler: Arc<dyn SamplingHandler>,
67}
68
69impl FixedSamplingHandlerFactory {
70    pub fn new(handler: Arc<dyn SamplingHandler>) -> Self {
71        Self { handler }
72    }
73}
74
75#[async_trait]
76impl SamplingHandlerFactory for FixedSamplingHandlerFactory {
77    async fn for_agent(&self, _agent_spec: &AgentSpec) -> Option<Arc<dyn SamplingHandler>> {
78        Some(self.handler.clone())
79    }
80}
81
82/// Default [`SamplingHandler`] that converts MCP sampling requests to awaken
83/// [`InferenceRequest`]s, calls the configured [`LlmExecutor`], and converts
84/// the response back to MCP format.
85pub struct DefaultSamplingHandler {
86    executor: Arc<dyn LlmExecutor>,
87    upstream_model: String,
88}
89
90impl DefaultSamplingHandler {
91    /// Create a new handler backed by the given LLM executor.
92    ///
93    /// `upstream_model` is the model name sent to the configured executor.
94    pub fn new(executor: Arc<dyn LlmExecutor>, upstream_model: impl Into<String>) -> Self {
95        Self {
96            executor,
97            upstream_model: upstream_model.into(),
98        }
99    }
100
101    /// Convert MCP sampling messages to awaken [`Message`] types.
102    ///
103    /// MCP `SamplingContent` is a union of `Text`, `Image`, `Audio`, …
104    /// awaken's [`Message`] today only carries text. Rather than
105    /// silently dropping non-text blocks (which would let the server's
106    /// "describe this image" prompt arrive at the LLM with the image
107    /// stripped — a correctness bug masked as an empty turn), we
108    /// surface the limitation as a typed error.
109    ///
110    /// Returns `Err(unsupported_content_kind)` on the first non-text
111    /// block encountered. Callers should map this to an MCP JSON-RPC
112    /// error so the server learns its sampling request can't be
113    /// serviced and can decide how to proceed (retry with text only,
114    /// fall back to a different client, etc).
115    ///
116    /// Multiple text blocks within a single message are joined with a
117    /// blank line ("\n\n") so `"hello"` + `"world"` becomes
118    /// `"hello\n\nworld"`, not `"helloworld"`. The spec doesn't
119    /// prescribe a join; blank-line is the convention for prose
120    /// paragraphs and avoids accidentally fusing tokens.
121    fn convert_messages(params: &CreateMessageParams) -> Result<Vec<Message>, McpTransportError> {
122        let mut out = Vec::with_capacity(params.messages.len());
123        for msg in &params.messages {
124            let mut text_parts: Vec<&str> = Vec::with_capacity(msg.content.len());
125            for block in &msg.content {
126                match block {
127                    SamplingContent::Text { text: t, .. } => text_parts.push(t.as_str()),
128                    other => {
129                        return Err(McpTransportError::TransportError(format!(
130                            "sampling request contains unsupported content kind: {} \
131                             (awaken's sampling handler only supports text — server should \
132                             retry with a text-only message)",
133                            sampling_content_kind(other)
134                        )));
135                    }
136                }
137            }
138            let joined = text_parts.join("\n\n");
139            out.push(match msg.role {
140                mcp::Role::User => Message::user(joined),
141                mcp::Role::Assistant => Message::assistant(joined),
142            });
143        }
144        Ok(out)
145    }
146
147    /// Build the system prompt content blocks from the params.
148    fn system_blocks(params: &CreateMessageParams) -> Vec<ContentBlock> {
149        match &params.system_prompt {
150            Some(prompt) if !prompt.is_empty() => vec![ContentBlock::text(prompt.clone())],
151            _ => vec![],
152        }
153    }
154
155    /// Convert an awaken `StreamResult` to MCP `CreateMessageResult`.
156    fn convert_result(
157        result: &awaken_runtime_contract::contract::inference::StreamResult,
158        model: &str,
159    ) -> CreateMessageResult {
160        let text = result.text();
161        let content = vec![SamplingContent::Text {
162            text,
163            annotations: None,
164            meta: None,
165        }];
166
167        let stop_reason = result.stop_reason.map(|sr| match sr {
168            awaken_runtime_contract::contract::inference::StopReason::EndTurn => {
169                "endTurn".to_string()
170            }
171            awaken_runtime_contract::contract::inference::StopReason::MaxTokens => {
172                "maxTokens".to_string()
173            }
174            awaken_runtime_contract::contract::inference::StopReason::ToolUse => {
175                "toolUse".to_string()
176            }
177            awaken_runtime_contract::contract::inference::StopReason::StopSequence => {
178                "stopSequence".to_string()
179            }
180        });
181
182        CreateMessageResult {
183            role: mcp::Role::Assistant,
184            content,
185            model: model.to_string(),
186            stop_reason,
187            meta: None,
188        }
189    }
190}
191
192/// Name the variant of [`SamplingContent`] for use in error messages.
193/// Kept as a private free function so it can be unit-tested independently
194/// of [`DefaultSamplingHandler`].
195fn sampling_content_kind(content: &SamplingContent) -> &'static str {
196    match content {
197        SamplingContent::Text { .. } => "text",
198        SamplingContent::Image { .. } => "image",
199        SamplingContent::Audio { .. } => "audio",
200        SamplingContent::ToolUse { .. } => "tool_use",
201        SamplingContent::ToolResult { .. } => "tool_result",
202    }
203}
204
205/// Reject sampling requests whose presence would silently change LLM
206/// behaviour. The MCP spec lets the server specify stop sequences,
207/// context inclusion, tool choice, etc.; awaken's handler currently maps
208/// only a small subset (system prompt, temperature, max_tokens), so
209/// honouring these would produce a different reply than the server
210/// asked for — a class of bug that's invisible until model output goes
211/// subtly wrong. Returning an error puts the burden back on the server
212/// to either retry without the unsupported field or fall over to a
213/// different client.
214///
215/// `modelPreferences` is advisory in MCP sampling. The default handler
216/// uses the model already configured for the agent and ignores those
217/// hints instead of rejecting otherwise interoperable servers.
218///
219/// Returns `Err` with a human-readable description of the offending
220/// field. `Ok(())` means every behavioural field is either absent or
221/// in awaken's supported subset.
222fn reject_unsupported_sampling_fields(
223    params: &CreateMessageParams,
224) -> Result<(), McpTransportError> {
225    let mut unsupported: Vec<&'static str> = Vec::new();
226    if params
227        .stop_sequences
228        .as_ref()
229        .is_some_and(|s| !s.is_empty())
230    {
231        unsupported.push("stopSequences");
232    }
233    if params.include_context.is_some() {
234        unsupported.push("includeContext");
235    }
236    if params.tools.as_ref().is_some_and(|t| !t.is_empty()) {
237        unsupported.push("tools");
238    }
239    if params.tool_choice.is_some() {
240        unsupported.push("toolChoice");
241    }
242    if !unsupported.is_empty() {
243        return Err(McpTransportError::TransportError(format!(
244            "sampling request sets unsupported field(s): {} \
245             (awaken's DefaultSamplingHandler maps systemPrompt, \
246             temperature, maxTokens only; honouring others silently \
247             would change the LLM's reply away from what the server \
248             requested)",
249            unsupported.join(", ")
250        )));
251    }
252    Ok(())
253}
254
255#[async_trait]
256impl SamplingHandler for DefaultSamplingHandler {
257    async fn handle_create_message(
258        &self,
259        params: CreateMessageParams,
260    ) -> Result<CreateMessageResult, McpTransportError> {
261        // Reject BEFORE message conversion so the server sees the
262        // field-level objection even when content happens to be valid.
263        reject_unsupported_sampling_fields(&params)?;
264
265        let messages = Self::convert_messages(&params)?;
266        if messages.is_empty() {
267            return Err(McpTransportError::TransportError(
268                "sampling request contained no messages".to_string(),
269            ));
270        }
271
272        let system = Self::system_blocks(&params);
273
274        let overrides = {
275            let mut ovr =
276                awaken_runtime_contract::contract::inference::InferenceOverride::default();
277            if let Some(temp) = params.temperature {
278                ovr.temperature = Some(temp);
279            }
280            ovr.max_tokens = Some(params.max_tokens);
281            if ovr.temperature.is_none() && ovr.max_tokens.is_none() {
282                None
283            } else {
284                Some(ovr)
285            }
286        };
287
288        let request = InferenceRequest {
289            upstream_model: self.upstream_model.clone(),
290            routing_key: None,
291            messages,
292            tools: vec![],
293            system,
294            overrides,
295            enable_prompt_cache: false,
296        };
297
298        let result =
299            self.executor.execute(request).await.map_err(|e| {
300                McpTransportError::TransportError(format!("LLM execution failed: {e}"))
301            })?;
302
303        Ok(Self::convert_result(&result, &self.upstream_model))
304    }
305}
306
#[cfg(test)]
mod tests {
    // Unit tests for message/result conversion, field rejection, executor
    // routing and the fixed factory.
    use super::*;
    use awaken_runtime_contract::contract::executor::InferenceExecutionError;
    use awaken_runtime_contract::contract::inference::{StopReason, StreamResult, TokenUsage};
    use awaken_runtime_contract::contract::message::Role;
    use mcp::SamplingMessage;
    use std::sync::Mutex;

    /// Executor double.
    ///
    /// It always replies with `response_text` and records the most recent
    /// request, so tests can assert on exactly what the handler forwarded.
    struct MockLlm {
        response_text: String,
        last_request: Mutex<Option<InferenceRequest>>,
    }

    impl MockLlm {
        fn new(response_text: &str) -> Arc<Self> {
            Arc::new(Self {
                response_text: response_text.to_string(),
                last_request: Mutex::new(None),
            })
        }

        /// Take the last request seen by `execute`, if any.
        fn take_last_request(&self) -> Option<InferenceRequest> {
            self.last_request.lock().unwrap().take()
        }
    }

    #[async_trait]
    impl LlmExecutor for MockLlm {
        async fn execute(
            &self,
            request: InferenceRequest,
        ) -> Result<StreamResult, InferenceExecutionError> {
            *self.last_request.lock().unwrap() = Some(request);
            Ok(StreamResult {
                content: vec![ContentBlock::text(self.response_text.clone())],
                tool_calls: vec![],
                usage: Some(TokenUsage {
                    prompt_tokens: Some(10),
                    completion_tokens: Some(5),
                    total_tokens: Some(15),
                    ..Default::default()
                }),
                stop_reason: Some(StopReason::EndTurn),
                has_incomplete_tool_calls: false,
            })
        }

        fn name(&self) -> &str {
            "mock"
        }
    }

    fn text_block(text: &str) -> SamplingContent {
        SamplingContent::Text {
            text: text.into(),
            annotations: None,
            meta: None,
        }
    }

    fn image_block() -> SamplingContent {
        SamplingContent::Image {
            data: "base64-blob".into(),
            mime_type: "image/png".into(),
            annotations: None,
            meta: None,
        }
    }

    fn audio_block() -> SamplingContent {
        SamplingContent::Audio {
            data: "base64-blob".into(),
            mime_type: "audio/wav".into(),
            annotations: None,
            meta: None,
        }
    }

    fn sampling_message(role: mcp::Role, content: Vec<SamplingContent>) -> SamplingMessage {
        SamplingMessage {
            role,
            content,
            meta: None,
        }
    }

    /// Params carrying `messages`, with every optional field cleared and
    /// `max_tokens` set to 1024.
    fn params_with(messages: Vec<SamplingMessage>) -> CreateMessageParams {
        CreateMessageParams {
            messages,
            model_preferences: None,
            system_prompt: None,
            include_context: None,
            temperature: None,
            max_tokens: 1024,
            stop_sequences: None,
            metadata: None,
            tools: None,
            tool_choice: None,
            task: None,
            meta: None,
        }
    }

    /// Params with a single user text message.
    fn make_params(text: &str) -> CreateMessageParams {
        params_with(vec![sampling_message(
            mcp::Role::User,
            vec![text_block(text)],
        )])
    }

    fn stream_result(text: &str, stop_reason: Option<StopReason>) -> StreamResult {
        StreamResult {
            content: vec![ContentBlock::text(text)],
            tool_calls: vec![],
            usage: None,
            stop_reason,
            has_incomplete_tool_calls: false,
        }
    }

    #[test]
    fn convert_messages_maps_roles() {
        let params = params_with(vec![
            sampling_message(mcp::Role::User, vec![text_block("hello")]),
            sampling_message(mcp::Role::Assistant, vec![text_block("hi there")]),
        ]);
        let msgs =
            DefaultSamplingHandler::convert_messages(&params).expect("text-only converts cleanly");
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0].role, Role::User);
        assert_eq!(msgs[0].text(), "hello");
        assert_eq!(msgs[1].role, Role::Assistant);
        assert_eq!(msgs[1].text(), "hi there");
    }

    #[test]
    fn convert_messages_rejects_image_content() {
        // A "describe this image" request must not reach the LLM with the
        // image silently stripped. That would yield nonsense answers that
        // look like success. The handler must surface an error naming the
        // offending content kind instead.
        let params = params_with(vec![sampling_message(
            mcp::Role::User,
            vec![text_block("describe this:"), image_block()],
        )]);
        let err =
            DefaultSamplingHandler::convert_messages(&params).expect_err("image must be rejected");
        let msg = format!("{err}");
        assert!(
            msg.contains("image"),
            "error should identify the offending content kind, got: {msg}"
        );
    }

    #[test]
    fn convert_messages_rejects_audio_content() {
        let params = params_with(vec![sampling_message(
            mcp::Role::User,
            vec![audio_block()],
        )]);
        let err =
            DefaultSamplingHandler::convert_messages(&params).expect_err("audio must be rejected");
        assert!(format!("{err}").contains("audio"));
    }

    #[test]
    fn convert_messages_joins_multi_text_with_blank_line() {
        // A separator-less join would turn ["hello", "world"] into
        // "helloworld". The blank-line join keeps the paragraph boundary
        // the server sent.
        let params = params_with(vec![sampling_message(
            mcp::Role::User,
            vec![text_block("hello"), text_block("world")],
        )]);
        let msgs = DefaultSamplingHandler::convert_messages(&params).unwrap();
        assert_eq!(msgs.len(), 1);
        assert_eq!(msgs[0].text(), "hello\n\nworld");
    }

    #[test]
    fn sampling_content_kind_names_each_variant() {
        // Lock in the strings used in error messages, so a refactor of the
        // helper cannot silently change them.
        assert_eq!(sampling_content_kind(&text_block("x")), "text");
        assert_eq!(sampling_content_kind(&image_block()), "image");
        assert_eq!(sampling_content_kind(&audio_block()), "audio");
    }

    #[test]
    fn system_blocks_from_params() {
        let mut params = make_params("test");
        assert!(DefaultSamplingHandler::system_blocks(&params).is_empty());

        // An empty prompt is treated the same as an absent one.
        params.system_prompt = Some("".into());
        assert!(DefaultSamplingHandler::system_blocks(&params).is_empty());

        params.system_prompt = Some("Be helpful".into());
        let blocks = DefaultSamplingHandler::system_blocks(&params);
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::Text { text } => assert_eq!(text, "Be helpful"),
            _ => panic!("expected text block"),
        }
    }

    #[test]
    fn convert_result_maps_every_stop_reason() {
        let cases = [
            (Some(StopReason::EndTurn), Some("endTurn")),
            (Some(StopReason::MaxTokens), Some("maxTokens")),
            (Some(StopReason::ToolUse), Some("toolUse")),
            (Some(StopReason::StopSequence), Some("stopSequence")),
            (None, None),
        ];
        for (stop_reason, expected) in cases {
            let mapped = DefaultSamplingHandler::convert_result(
                &stream_result("response", stop_reason),
                "test-model",
            );
            assert_eq!(
                mapped.stop_reason.as_deref(),
                expected,
                "expected {expected:?}"
            );
        }
    }

    #[test]
    fn convert_result_wraps_text_as_assistant_reply() {
        let mapped = DefaultSamplingHandler::convert_result(
            &stream_result("response", Some(StopReason::EndTurn)),
            "test-model",
        );
        assert_eq!(mapped.model, "test-model");
        assert!(matches!(mapped.role, mcp::Role::Assistant));
        assert_eq!(mapped.content.len(), 1);
        match &mapped.content[0] {
            SamplingContent::Text { text, .. } => assert_eq!(text, "response"),
            _ => panic!("expected text content"),
        }
    }

    #[tokio::test]
    async fn handle_create_message_rejects_stop_sequences() {
        let executor = MockLlm::new("ignored");
        let handler = DefaultSamplingHandler::new(executor.clone(), "m");
        let mut params = make_params("hi");
        params.stop_sequences = Some(vec!["STOP".into()]);
        let err = handler
            .handle_create_message(params)
            .await
            .expect_err("stopSequences must be rejected");
        let msg = format!("{err}");
        assert!(msg.contains("stopSequences"), "got: {msg}");
        assert!(
            executor.take_last_request().is_none(),
            "a rejected request must never reach the executor"
        );
    }

    #[tokio::test]
    async fn handle_create_message_rejects_tool_choice() {
        let executor = MockLlm::new("ignored");
        let handler = DefaultSamplingHandler::new(executor.clone(), "m");
        let mut params = make_params("hi");
        params.tool_choice = Some(mcp::ToolChoice {
            mode: Some(mcp::ToolChoiceMode::Required),
        });
        let err = handler
            .handle_create_message(params)
            .await
            .expect_err("toolChoice must be rejected");
        assert!(format!("{err}").contains("toolChoice"));
        assert!(executor.take_last_request().is_none());
    }

    #[tokio::test]
    async fn handle_create_message_rejects_include_context() {
        let executor = MockLlm::new("ignored");
        let handler = DefaultSamplingHandler::new(executor.clone(), "m");
        let mut params = make_params("hi");
        params.include_context = Some("thisServer".into());
        let err = handler
            .handle_create_message(params)
            .await
            .expect_err("must reject");
        let msg = format!("{err}");
        assert!(msg.contains("includeContext"), "got: {msg}");
        assert!(!msg.contains("modelPreferences"), "got: {msg}");
        assert!(executor.take_last_request().is_none());
    }

    #[tokio::test]
    async fn default_sampling_handler_ignores_model_preferences() {
        let handler = DefaultSamplingHandler::new(MockLlm::new("ok"), "configured-model");
        let mut params = make_params("hi");
        params.model_preferences = Some(mcp::ModelPreferences {
            hints: None,
            cost_priority: None,
            speed_priority: None,
            intelligence_priority: None,
        });

        let result = handler
            .handle_create_message(params)
            .await
            .expect("modelPreferences are advisory and should not fail basic sampling");

        assert_eq!(result.model, "configured-model");
    }

    #[tokio::test]
    async fn default_sampling_handler_routes_to_executor() {
        let handler = DefaultSamplingHandler::new(MockLlm::new("I can help!"), "test-model");

        let result = handler
            .handle_create_message(make_params("help me"))
            .await
            .unwrap();

        assert_eq!(result.model, "test-model");
        assert!(matches!(result.role, mcp::Role::Assistant));
        match &result.content[0] {
            SamplingContent::Text { text, .. } => assert_eq!(text, "I can help!"),
            _ => panic!("expected text content"),
        }
        assert_eq!(result.stop_reason.as_deref(), Some("endTurn"));
    }

    #[tokio::test]
    async fn default_sampling_handler_empty_messages_returns_error() {
        let handler = DefaultSamplingHandler::new(MockLlm::new(""), "test-model");

        let err = handler
            .handle_create_message(params_with(vec![]))
            .await
            .expect_err("a request without messages must fail");
        let msg = format!("{err}");
        assert!(msg.contains("no messages"), "got: {msg}");
    }

    #[tokio::test]
    async fn fixed_factory_returns_same_handler_regardless_of_agent() {
        // The fixed factory keeps the legacy behaviour: every agent gets the
        // same handler. Per-agent routing only applies when callers wire a
        // registry-driven factory.
        let handler: Arc<dyn SamplingHandler> = Arc::new(DefaultSamplingHandler::new(
            MockLlm::new("shared"),
            "shared-model",
        ));
        let factory = FixedSamplingHandlerFactory::new(Arc::clone(&handler));

        let spec_a = AgentSpec {
            id: "a".into(),
            model_id: "claude-opus".into(),
            system_prompt: "".into(),
            ..Default::default()
        };
        let spec_b = AgentSpec {
            id: "b".into(),
            model_id: "gpt-5".into(),
            system_prompt: "".into(),
            ..Default::default()
        };

        let resolved_a = factory.for_agent(&spec_a).await.expect("Some handler");
        let resolved_b = factory.for_agent(&spec_b).await.expect("Some handler");
        // Same Arc identity regardless of agent.
        assert!(Arc::ptr_eq(&resolved_a, &handler));
        assert!(Arc::ptr_eq(&resolved_b, &handler));
    }

    #[tokio::test]
    async fn default_sampling_handler_forwards_request_fields() {
        // The mock records the InferenceRequest, so the assertions below
        // check what was actually forwarded, not just that the call succeeded.
        let executor = MockLlm::new("ok");
        let handler = DefaultSamplingHandler::new(executor.clone(), "model-v1");

        let mut params = make_params("test");
        params.temperature = Some(0.7);
        params.max_tokens = 512;
        params.system_prompt = Some("System".into());

        let result = handler.handle_create_message(params).await.unwrap();
        assert_eq!(result.model, "model-v1");

        let request = executor
            .take_last_request()
            .expect("executor must have been called");
        assert_eq!(request.upstream_model, "model-v1");
        assert_eq!(request.messages.len(), 1);
        assert_eq!(request.messages[0].text(), "test");
        assert!(request.tools.is_empty());
        assert!(!request.enable_prompt_cache);
        assert_eq!(request.system.len(), 1);
        match &request.system[0] {
            ContentBlock::Text { text } => assert_eq!(text, "System"),
            _ => panic!("expected text system block"),
        }
        let overrides = request
            .overrides
            .as_ref()
            .expect("overrides must be forwarded");
        assert_eq!(overrides.temperature, Some(0.7));
        assert_eq!(overrides.max_tokens, Some(512));
    }
}