agents_runtime/
middleware.rs

1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3
4use agents_core::agent::AgentHandle;
5use agents_core::messaging::{
6    AgentMessage, CacheControl, MessageContent, MessageMetadata, MessageRole,
7};
8use agents_core::prompts::{
9    BASE_AGENT_PROMPT, FILESYSTEM_SYSTEM_PROMPT, TASK_SYSTEM_PROMPT, TASK_TOOL_DESCRIPTION,
10    WRITE_TODOS_SYSTEM_PROMPT,
11};
12use agents_core::state::AgentStateSnapshot;
13use agents_core::tools::{Tool, ToolBox, ToolContext, ToolResult};
14use agents_toolkit::create_filesystem_tools;
15use async_trait::async_trait;
16use serde::Deserialize;
17
18/// Request sent to the underlying language model. Middlewares can augment
19/// the system prompt or mutate the pending message list before the model call.
20#[derive(Debug, Clone)]
21pub struct ModelRequest {
22    pub system_prompt: String,
23    pub messages: Vec<AgentMessage>,
24}
25
26impl ModelRequest {
27    pub fn new(system_prompt: impl Into<String>, messages: Vec<AgentMessage>) -> Self {
28        Self {
29            system_prompt: system_prompt.into(),
30            messages,
31        }
32    }
33
34    pub fn append_prompt(&mut self, fragment: &str) {
35        if !fragment.is_empty() {
36            self.system_prompt.push_str("\n\n");
37            self.system_prompt.push_str(fragment);
38        }
39    }
40}
41
42/// Read/write state handle exposed to middleware implementations.
43pub struct MiddlewareContext<'a> {
44    pub request: &'a mut ModelRequest,
45    pub state: Arc<RwLock<AgentStateSnapshot>>,
46}
47
48impl<'a> MiddlewareContext<'a> {
49    pub fn with_request(
50        request: &'a mut ModelRequest,
51        state: Arc<RwLock<AgentStateSnapshot>>,
52    ) -> Self {
53        Self { request, state }
54    }
55}
56
57/// Middleware hook that can register additional tools and mutate the model request
58/// prior to execution. Mirrors the Python AgentMiddleware contracts but keeps the
59/// interface async-first for future network calls.
60#[async_trait]
61pub trait AgentMiddleware: Send + Sync {
62    /// Unique identifier for logging and diagnostics.
63    fn id(&self) -> &'static str;
64
65    /// Tools to expose when this middleware is active.
66    fn tools(&self) -> Vec<ToolBox> {
67        Vec::new()
68    }
69
70    /// Apply middleware-specific mutations to the pending model request.
71    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()>;
72}
73
74pub struct SummarizationMiddleware {
75    pub messages_to_keep: usize,
76    pub summary_note: String,
77}
78
79impl SummarizationMiddleware {
80    pub fn new(messages_to_keep: usize, summary_note: impl Into<String>) -> Self {
81        Self {
82            messages_to_keep,
83            summary_note: summary_note.into(),
84        }
85    }
86}
87
88#[async_trait]
89impl AgentMiddleware for SummarizationMiddleware {
90    fn id(&self) -> &'static str {
91        "summarization"
92    }
93
94    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
95        if ctx.request.messages.len() > self.messages_to_keep {
96            let dropped = ctx.request.messages.len() - self.messages_to_keep;
97            let mut truncated = ctx
98                .request
99                .messages
100                .split_off(ctx.request.messages.len() - self.messages_to_keep);
101            truncated.insert(
102                0,
103                AgentMessage {
104                    role: MessageRole::System,
105                    content: MessageContent::Text(format!(
106                        "{} ({} earlier messages summarized)",
107                        self.summary_note, dropped
108                    )),
109                    metadata: None,
110                },
111            );
112            ctx.request.messages = truncated;
113        }
114        Ok(())
115    }
116}
117
118pub struct PlanningMiddleware {
119    _state: Arc<RwLock<AgentStateSnapshot>>,
120}
121
122impl PlanningMiddleware {
123    pub fn new(state: Arc<RwLock<AgentStateSnapshot>>) -> Self {
124        Self { _state: state }
125    }
126}
127
128#[async_trait]
129impl AgentMiddleware for PlanningMiddleware {
130    fn id(&self) -> &'static str {
131        "planning"
132    }
133
134    fn tools(&self) -> Vec<ToolBox> {
135        use agents_toolkit::create_todos_tools;
136        create_todos_tools()
137    }
138
139    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
140        ctx.request.append_prompt(WRITE_TODOS_SYSTEM_PROMPT);
141        Ok(())
142    }
143}
144
145pub struct FilesystemMiddleware {
146    _state: Arc<RwLock<AgentStateSnapshot>>,
147}
148
149impl FilesystemMiddleware {
150    pub fn new(state: Arc<RwLock<AgentStateSnapshot>>) -> Self {
151        Self { _state: state }
152    }
153}
154
155#[async_trait]
156impl AgentMiddleware for FilesystemMiddleware {
157    fn id(&self) -> &'static str {
158        "filesystem"
159    }
160
161    fn tools(&self) -> Vec<ToolBox> {
162        create_filesystem_tools()
163    }
164
165    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
166        ctx.request.append_prompt(FILESYSTEM_SYSTEM_PROMPT);
167        Ok(())
168    }
169}
170
171#[derive(Clone)]
172pub struct SubAgentRegistration {
173    pub descriptor: SubAgentDescriptor,
174    pub agent: Arc<dyn AgentHandle>,
175}
176
177struct SubAgentRegistry {
178    agents: HashMap<String, Arc<dyn AgentHandle>>,
179}
180
181impl SubAgentRegistry {
182    fn new(registrations: Vec<SubAgentRegistration>) -> Self {
183        let mut agents = HashMap::new();
184        for reg in registrations {
185            agents.insert(reg.descriptor.name.clone(), reg.agent.clone());
186        }
187        Self { agents }
188    }
189
190    fn available_names(&self) -> Vec<String> {
191        self.agents.keys().cloned().collect()
192    }
193
194    fn get(&self, name: &str) -> Option<Arc<dyn AgentHandle>> {
195        self.agents.get(name).cloned()
196    }
197}
198
199pub struct SubAgentMiddleware {
200    task_tool: ToolBox,
201    descriptors: Vec<SubAgentDescriptor>,
202    _registry: Arc<SubAgentRegistry>,
203}
204
205impl SubAgentMiddleware {
206    pub fn new(registrations: Vec<SubAgentRegistration>) -> Self {
207        let descriptors = registrations.iter().map(|r| r.descriptor.clone()).collect();
208        let registry = Arc::new(SubAgentRegistry::new(registrations));
209        let task_tool: ToolBox = Arc::new(TaskRouterTool::new(registry.clone()));
210        Self {
211            task_tool,
212            descriptors,
213            _registry: registry,
214        }
215    }
216
217    fn prompt_fragment(&self) -> String {
218        let descriptions: Vec<String> = if self.descriptors.is_empty() {
219            vec![String::from("- general-purpose: Default reasoning agent")]
220        } else {
221            self.descriptors
222                .iter()
223                .map(|agent| format!("- {}: {}", agent.name, agent.description))
224                .collect()
225        };
226
227        TASK_TOOL_DESCRIPTION.replace("{other_agents}", &descriptions.join("\n"))
228    }
229}
230
231#[async_trait]
232impl AgentMiddleware for SubAgentMiddleware {
233    fn id(&self) -> &'static str {
234        "subagent"
235    }
236
237    fn tools(&self) -> Vec<ToolBox> {
238        vec![self.task_tool.clone()]
239    }
240
241    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
242        ctx.request.append_prompt(TASK_SYSTEM_PROMPT);
243        ctx.request.append_prompt(&self.prompt_fragment());
244        Ok(())
245    }
246}
247
/// Approval policy for a single tool.
#[derive(Clone, Debug)]
pub struct HitlPolicy {
    /// When `true`, the tool may run without human approval.
    pub allow_auto: bool,
    /// Optional note shown next to the tool in the system prompt.
    pub note: Option<String>,
}

/// Gates selected tools behind human approval and tells the model which
/// tools are gated.
pub struct HumanInLoopMiddleware {
    /// Per-tool policies keyed by tool name.
    policies: HashMap<String, HitlPolicy>,
}

impl HumanInLoopMiddleware {
    /// Creates the middleware from a tool-name → policy map.
    pub fn new(policies: HashMap<String, HitlPolicy>) -> Self {
        Self { policies }
    }

    /// Returns the policy for `tool_name` if that tool needs approval
    /// (`allow_auto == false`). Returns `None` for auto-approved tools and for
    /// tools with no policy at all.
    pub fn requires_approval(&self, tool_name: &str) -> Option<&HitlPolicy> {
        self.policies
            .get(tool_name)
            .filter(|policy| !policy.allow_auto)
    }

    /// Builds the system-prompt section listing every tool that needs
    /// approval, or `None` when no tool does.
    ///
    /// Entries are sorted by tool name. `HashMap` iteration order is
    /// unspecified and varies between runs, so without sorting the system
    /// prompt (and any prompt cache keyed on it) would not be deterministic.
    fn prompt_fragment(&self) -> Option<String> {
        let mut gated: Vec<(&String, &HitlPolicy)> = self
            .policies
            .iter()
            .filter(|(_, policy)| !policy.allow_auto)
            .collect();
        if gated.is_empty() {
            return None;
        }
        gated.sort_unstable_by(|a, b| a.0.cmp(b.0));

        let lines: Vec<String> = gated
            .into_iter()
            .map(|(tool, policy)| {
                let note = policy.note.as_deref().unwrap_or("Requires approval");
                format!("- {tool}: {note}")
            })
            .collect();
        Some(format!(
            "The following tools require human approval before execution:\n{}",
            lines.join("\n")
        ))
    }
}
289
290#[async_trait]
291impl AgentMiddleware for HumanInLoopMiddleware {
292    fn id(&self) -> &'static str {
293        "human-in-loop"
294    }
295
296    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
297        if let Some(fragment) = self.prompt_fragment() {
298            ctx.request.append_prompt(&fragment);
299        }
300        ctx.request.messages.push(AgentMessage {
301            role: MessageRole::System,
302            content: MessageContent::Text(
303                "Tools marked for human approval will emit interrupts requiring external resolution."
304                    .into(),
305            ),
306            metadata: None,
307        });
308        Ok(())
309    }
310}
311
312pub struct BaseSystemPromptMiddleware;
313
314#[async_trait]
315impl AgentMiddleware for BaseSystemPromptMiddleware {
316    fn id(&self) -> &'static str {
317        "base-system-prompt"
318    }
319
320    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
321        ctx.request.append_prompt(BASE_AGENT_PROMPT);
322        Ok(())
323    }
324}
325
326/// Deep Agent prompt middleware that injects comprehensive tool usage instructions
327/// and examples to force the LLM to actually call tools instead of just talking about them.
328///
329/// This middleware is inspired by Python's deepagents package and Claude Code's system prompt.
330/// It provides:
331/// - Explicit tool usage rules with imperative language
332/// - JSON examples of tool calling
333/// - Workflow guidance for multi-step tasks
334/// - Few-shot examples for common patterns
335pub struct DeepAgentPromptMiddleware {
336    custom_instructions: String,
337}
338
339impl DeepAgentPromptMiddleware {
340    pub fn new(custom_instructions: impl Into<String>) -> Self {
341        Self {
342            custom_instructions: custom_instructions.into(),
343        }
344    }
345}
346
347#[async_trait]
348impl AgentMiddleware for DeepAgentPromptMiddleware {
349    fn id(&self) -> &'static str {
350        "deep-agent-prompt"
351    }
352
353    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
354        use crate::prompts::get_deep_agent_system_prompt;
355        let deep_prompt = get_deep_agent_system_prompt(&self.custom_instructions);
356        ctx.request.append_prompt(&deep_prompt);
357        Ok(())
358    }
359}
360
/// Anthropic-specific prompt caching middleware. Moves the system prompt into
/// a leading system message tagged for ephemeral caching, to reduce latency on
/// subsequent requests that share the same base prompt.
pub struct AnthropicPromptCachingMiddleware {
    /// Cache TTL such as `"5m"`; an empty or zero value disables caching.
    pub ttl: String,
    /// How unsupported models should be treated (e.g. `"ignore"`).
    /// NOTE(review): this middleware only logs it; confirm where it is honoured.
    pub unsupported_model_behavior: String,
}

impl AnthropicPromptCachingMiddleware {
    /// Creates the middleware with an explicit TTL and unsupported-model behaviour.
    pub fn new(ttl: impl Into<String>, unsupported_model_behavior: impl Into<String>) -> Self {
        Self {
            ttl: ttl.into(),
            unsupported_model_behavior: unsupported_model_behavior.into(),
        }
    }

    /// Five-minute TTL, unsupported models ignored.
    pub fn with_defaults() -> Self {
        Self::new("5m", "ignore")
    }

    /// Returns whether the configured TTL requests caching.
    ///
    /// The TTL is trimmed, and any trailing unit letters (`s`, `m`, `h`, …) are
    /// stripped. Caching is disabled when the value is empty or its numeric
    /// part is zero (`"0"`, `"0s"`, `"0m"`, `"00"`). Any other non-empty value
    /// enables ephemeral caching, including values that are not numeric; the
    /// duration itself is not otherwise interpreted.
    fn should_enable_caching(&self) -> bool {
        let ttl = self.ttl.trim();
        if ttl.is_empty() {
            return false;
        }
        let magnitude = ttl.trim_end_matches(|c: char| c.is_ascii_alphabetic());
        !matches!(magnitude.parse::<u64>(), Ok(0))
    }
}
386
387#[async_trait]
388impl AgentMiddleware for AnthropicPromptCachingMiddleware {
389    fn id(&self) -> &'static str {
390        "anthropic-prompt-caching"
391    }
392
393    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
394        if !self.should_enable_caching() {
395            return Ok(());
396        }
397
398        // Mark system prompt for caching by converting it to a system message with cache control
399        if !ctx.request.system_prompt.is_empty() {
400            let system_message = AgentMessage {
401                role: MessageRole::System,
402                content: MessageContent::Text(ctx.request.system_prompt.clone()),
403                metadata: Some(MessageMetadata {
404                    tool_call_id: None,
405                    cache_control: Some(CacheControl {
406                        cache_type: "ephemeral".to_string(),
407                    }),
408                }),
409            };
410
411            // Insert system message at the beginning of the messages
412            ctx.request.messages.insert(0, system_message);
413
414            // Clear the system_prompt since it's now in messages
415            ctx.request.system_prompt.clear();
416
417            tracing::debug!(
418                ttl = %self.ttl,
419                behavior = %self.unsupported_model_behavior,
420                "Applied Anthropic prompt caching to system message"
421            );
422        }
423
424        Ok(())
425    }
426}
427
428pub struct TaskRouterTool {
429    registry: Arc<SubAgentRegistry>,
430}
431
432impl TaskRouterTool {
433    fn new(registry: Arc<SubAgentRegistry>) -> Self {
434        Self { registry }
435    }
436
437    fn available_subagents(&self) -> Vec<String> {
438        self.registry.available_names()
439    }
440}
441
442#[derive(Debug, Clone, Deserialize)]
443struct TaskInvocationArgs {
444    #[serde(alias = "description")]
445    instruction: String,
446    #[serde(alias = "subagent_type")]
447    agent: String,
448}
449
450#[async_trait]
451impl Tool for TaskRouterTool {
452    fn schema(&self) -> agents_core::tools::ToolSchema {
453        use agents_core::tools::{ToolParameterSchema, ToolSchema};
454        use std::collections::HashMap;
455
456        let mut properties = HashMap::new();
457        properties.insert(
458            "agent".to_string(),
459            ToolParameterSchema::string("Name of the sub-agent to delegate to"),
460        );
461        properties.insert(
462            "instruction".to_string(),
463            ToolParameterSchema::string("Clear instruction for the sub-agent"),
464        );
465
466        ToolSchema::new(
467            "task",
468            "Delegate a task to a specialized sub-agent. Use this when you need specialized expertise or want to break down complex tasks.",
469            ToolParameterSchema::object(
470                "Task delegation parameters",
471                properties,
472                vec!["agent".to_string(), "instruction".to_string()],
473            ),
474        )
475    }
476
477    async fn execute(
478        &self,
479        args: serde_json::Value,
480        ctx: ToolContext,
481    ) -> anyhow::Result<ToolResult> {
482        let args: TaskInvocationArgs = serde_json::from_value(args)?;
483        let available = self.available_subagents();
484
485        if let Some(agent) = self.registry.get(&args.agent) {
486            // Log delegation start
487            tracing::warn!(
488                "🎯 DELEGATING to sub-agent: {} with instruction: {}",
489                args.agent,
490                args.instruction
491            );
492
493            let start_time = std::time::Instant::now();
494            let user_message = AgentMessage {
495                role: MessageRole::User,
496                content: MessageContent::Text(args.instruction.clone()),
497                metadata: None,
498            };
499
500            let response = agent
501                .handle_message(user_message, Arc::new(AgentStateSnapshot::default()))
502                .await?;
503
504            // Log delegation completion
505            let duration = start_time.elapsed();
506            let response_preview = match &response.content {
507                MessageContent::Text(t) => {
508                    if t.len() > 100 {
509                        format!("{}... ({} chars)", &t[..100], t.len())
510                    } else {
511                        t.clone()
512                    }
513                }
514                MessageContent::Json(v) => {
515                    format!("JSON: {} bytes", v.to_string().len())
516                }
517            };
518
519            tracing::warn!(
520                "✅ SUB-AGENT {} COMPLETED in {:?} - Response: {}",
521                args.agent,
522                duration,
523                response_preview
524            );
525
526            // Return sub-agent response as text content, not as a separate tool message
527            // This will be incorporated into the LLM's next response naturally
528            let result_text = match response.content {
529                MessageContent::Text(text) => text,
530                MessageContent::Json(json) => json.to_string(),
531            };
532
533            return Ok(ToolResult::text(&ctx, result_text));
534        }
535
536        tracing::error!(
537            "❌ SUB-AGENT NOT FOUND: {} - Available: {:?}",
538            args.agent,
539            available
540        );
541
542        Ok(ToolResult::text(
543            &ctx,
544            format!(
545                "Sub-agent '{}' not found. Available sub-agents: {}",
546                args.agent,
547                available.join(", ")
548            ),
549        ))
550    }
551}
552
/// Public description of a sub-agent: the name it is registered and addressed
/// under, plus a summary rendered into the system prompt.
#[derive(Debug, Clone)]
pub struct SubAgentDescriptor {
    /// Registry key; the `task` tool's `agent` argument is matched against it.
    pub name: String,
    /// Short summary shown next to the name in the task prompt listing.
    pub description: String,
}
558
/// Unit tests for the middleware implementations and the `task` router tool.
#[cfg(test)]
mod tests {
    use super::*;
    use agents_core::agent::{AgentDescriptor, AgentHandle};
    use agents_core::messaging::{MessageContent, MessageRole};
    use serde_json::json;

    /// Test-only middleware that pushes a fixed directive straight onto the
    /// system prompt (without going through `append_prompt`).
    struct AppendPromptMiddleware;

    #[async_trait]
    impl AgentMiddleware for AppendPromptMiddleware {
        fn id(&self) -> &'static str {
            "append-prompt"
        }

        async fn modify_model_request(
            &self,
            ctx: &mut MiddlewareContext<'_>,
        ) -> anyhow::Result<()> {
            ctx.request.system_prompt.push_str("\nExtra directives.");
            Ok(())
        }
    }

    // A middleware's edits are visible through the context's request.
    #[tokio::test]
    async fn middleware_mutates_prompt() {
        let mut request = ModelRequest::new(
            "System",
            vec![AgentMessage {
                role: MessageRole::User,
                content: MessageContent::Text("Hi".into()),
                metadata: None,
            }],
        );
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);
        let middleware = AppendPromptMiddleware;
        middleware.modify_model_request(&mut ctx).await.unwrap();
        assert!(ctx.request.system_prompt.contains("Extra directives"));
    }

    // Planning exposes `write_todos` and mentions it in the system prompt.
    #[tokio::test]
    async fn planning_middleware_registers_write_todos() {
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let middleware = PlanningMiddleware::new(state);
        let tool_names: Vec<_> = middleware
            .tools()
            .iter()
            .map(|t| t.schema().name.clone())
            .collect();
        assert!(tool_names.contains(&"write_todos".to_string()));

        let mut request = ModelRequest::new("System", vec![]);
        let mut ctx = MiddlewareContext::with_request(
            &mut request,
            Arc::new(RwLock::new(AgentStateSnapshot::default())),
        );
        middleware.modify_model_request(&mut ctx).await.unwrap();
        assert!(ctx.request.system_prompt.contains("write_todos"));
    }

    // Filesystem middleware exposes the four core file tools.
    #[tokio::test]
    async fn filesystem_middleware_registers_tools() {
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let middleware = FilesystemMiddleware::new(state);
        let tool_names: Vec<_> = middleware
            .tools()
            .iter()
            .map(|t| t.schema().name.clone())
            .collect();
        for expected in ["ls", "read_file", "write_file", "edit_file"] {
            assert!(tool_names.contains(&expected.to_string()));
        }
    }

    // With 3 messages and a window of 2: one summary note + the last 2 messages.
    #[tokio::test]
    async fn summarization_middleware_trims_messages() {
        let middleware = SummarizationMiddleware::new(2, "Summary note");
        let mut request = ModelRequest::new(
            "System",
            vec![
                AgentMessage {
                    role: MessageRole::User,
                    content: MessageContent::Text("one".into()),
                    metadata: None,
                },
                AgentMessage {
                    role: MessageRole::Agent,
                    content: MessageContent::Text("two".into()),
                    metadata: None,
                },
                AgentMessage {
                    role: MessageRole::User,
                    content: MessageContent::Text("three".into()),
                    metadata: None,
                },
            ],
        );
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);
        middleware.modify_model_request(&mut ctx).await.unwrap();
        assert_eq!(ctx.request.messages.len(), 3);
        match &ctx.request.messages[0].content {
            MessageContent::Text(text) => assert!(text.contains("Summary note")),
            other => panic!("expected text, got {other:?}"),
        }
    }

    /// Sub-agent stand-in that ignores its input and always replies
    /// `"stub-response"`.
    struct StubAgent;

    #[async_trait]
    impl AgentHandle for StubAgent {
        async fn describe(&self) -> AgentDescriptor {
            AgentDescriptor {
                name: "stub".into(),
                version: "0.0.1".into(),
                description: None,
            }
        }

        async fn handle_message(
            &self,
            _input: AgentMessage,
            _state: Arc<AgentStateSnapshot>,
        ) -> anyhow::Result<AgentMessage> {
            Ok(AgentMessage {
                role: MessageRole::Agent,
                content: MessageContent::Text("stub-response".into()),
                metadata: None,
            })
        }
    }

    // Unknown sub-agent names produce a text result, not an error.
    #[tokio::test]
    async fn task_router_reports_unknown_subagent() {
        let registry = Arc::new(SubAgentRegistry::new(vec![]));
        let task_tool = TaskRouterTool::new(registry.clone());
        let state = Arc::new(AgentStateSnapshot::default());
        let ctx = ToolContext::new(state);

        let response = task_tool
            .execute(
                json!({
                    "instruction": "Do something",
                    "agent": "unknown"
                }),
                ctx,
            )
            .await
            .unwrap();

        match response {
            ToolResult::Message(msg) => match msg.content {
                MessageContent::Text(text) => {
                    assert!(text.contains("Sub-agent 'unknown' not found"))
                }
                other => panic!("expected text, got {other:?}"),
            },
            _ => panic!("expected message"),
        }
    }

    // Registered sub-agents appear in the prompt, and the `task` tool is exposed.
    #[tokio::test]
    async fn subagent_middleware_appends_prompt() {
        let subagents = vec![SubAgentRegistration {
            descriptor: SubAgentDescriptor {
                name: "research-agent".into(),
                description: "Deep research specialist".into(),
            },
            agent: Arc::new(StubAgent),
        }];
        let middleware = SubAgentMiddleware::new(subagents);

        let mut request = ModelRequest::new("System", vec![]);
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);
        middleware.modify_model_request(&mut ctx).await.unwrap();

        assert!(ctx.request.system_prompt.contains("research-agent"));
        let tool_names: Vec<_> = middleware
            .tools()
            .iter()
            .map(|t| t.schema().name.clone())
            .collect();
        assert!(tool_names.contains(&"task".to_string()));
    }

    // Uses the serde aliases (`description`, `subagent_type`) and checks that the
    // tool call id from the context ends up on the returned message.
    #[tokio::test]
    async fn task_router_invokes_registered_subagent() {
        let registry = Arc::new(SubAgentRegistry::new(vec![SubAgentRegistration {
            descriptor: SubAgentDescriptor {
                name: "stub-agent".into(),
                description: "Stub".into(),
            },
            agent: Arc::new(StubAgent),
        }]));
        let task_tool = TaskRouterTool::new(registry.clone());
        let state = Arc::new(AgentStateSnapshot::default());
        let ctx = ToolContext::new(state).with_call_id(Some("call-42".into()));
        let response = task_tool
            .execute(
                json!({
                    "description": "do work",
                    "subagent_type": "stub-agent"
                }),
                ctx,
            )
            .await
            .unwrap();

        match response {
            ToolResult::Message(msg) => {
                assert_eq!(msg.metadata.unwrap().tool_call_id.unwrap(), "call-42");
                match msg.content {
                    MessageContent::Text(text) => assert_eq!(text, "stub-response"),
                    other => panic!("expected text, got {other:?}"),
                }
            }
            _ => panic!("expected message"),
        }
    }

    // Gated tools and their notes are listed in the system prompt.
    #[tokio::test]
    async fn human_in_loop_appends_prompt() {
        let middleware = HumanInLoopMiddleware::new(HashMap::from([(
            "danger-tool".into(),
            HitlPolicy {
                allow_auto: false,
                note: Some("Requires security review".into()),
            },
        )]));
        let mut request = ModelRequest::new("System", vec![]);
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);
        middleware.modify_model_request(&mut ctx).await.unwrap();
        assert!(ctx
            .request
            .system_prompt
            .contains("danger-tool: Requires security review"));
    }

    // With caching on, the system prompt becomes a leading cached system message.
    #[tokio::test]
    async fn anthropic_prompt_caching_moves_system_prompt_to_messages() {
        let middleware = AnthropicPromptCachingMiddleware::new("5m", "ignore");
        let mut request = ModelRequest::new(
            "This is the system prompt",
            vec![AgentMessage {
                role: MessageRole::User,
                content: MessageContent::Text("Hello".into()),
                metadata: None,
            }],
        );
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);

        // Apply the middleware
        middleware.modify_model_request(&mut ctx).await.unwrap();

        // System prompt should be cleared
        assert!(ctx.request.system_prompt.is_empty());

        // Should have added a system message with cache control at the beginning
        assert_eq!(ctx.request.messages.len(), 2);

        let system_message = &ctx.request.messages[0];
        assert!(matches!(system_message.role, MessageRole::System));
        assert_eq!(
            system_message.content.as_text().unwrap(),
            "This is the system prompt"
        );

        // Check cache control metadata
        let metadata = system_message.metadata.as_ref().unwrap();
        let cache_control = metadata.cache_control.as_ref().unwrap();
        assert_eq!(cache_control.cache_type, "ephemeral");

        // Original user message should still be there
        let user_message = &ctx.request.messages[1];
        assert!(matches!(user_message.role, MessageRole::User));
        assert_eq!(user_message.content.as_text().unwrap(), "Hello");
    }

    // A "0" TTL disables caching entirely: nothing is moved.
    #[tokio::test]
    async fn anthropic_prompt_caching_disabled_with_zero_ttl() {
        let middleware = AnthropicPromptCachingMiddleware::new("0", "ignore");
        let mut request = ModelRequest::new("This is the system prompt", vec![]);
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);

        // Apply the middleware
        middleware.modify_model_request(&mut ctx).await.unwrap();

        // System prompt should be unchanged
        assert_eq!(ctx.request.system_prompt, "This is the system prompt");
        assert_eq!(ctx.request.messages.len(), 0);
    }

    // An empty system prompt means there is nothing to cache: no message is added.
    #[tokio::test]
    async fn anthropic_prompt_caching_no_op_with_empty_system_prompt() {
        let middleware = AnthropicPromptCachingMiddleware::new("5m", "ignore");
        let mut request = ModelRequest::new(
            "",
            vec![AgentMessage {
                role: MessageRole::User,
                content: MessageContent::Text("Hello".into()),
                metadata: None,
            }],
        );
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);

        // Apply the middleware
        middleware.modify_model_request(&mut ctx).await.unwrap();

        // Should be unchanged
        assert!(ctx.request.system_prompt.is_empty());
        assert_eq!(ctx.request.messages.len(), 1);
        assert!(matches!(ctx.request.messages[0].role, MessageRole::User));
    }
}