agents_runtime/
middleware.rs

1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3
4use agents_core::agent::AgentHandle;
5use agents_core::messaging::{
6    AgentMessage, CacheControl, MessageContent, MessageMetadata, MessageRole,
7};
8use agents_core::prompts::{
9    BASE_AGENT_PROMPT, FILESYSTEM_SYSTEM_PROMPT, TASK_SYSTEM_PROMPT, TASK_TOOL_DESCRIPTION,
10    WRITE_TODOS_SYSTEM_PROMPT,
11};
12use agents_core::state::AgentStateSnapshot;
13use agents_core::tools::{Tool, ToolBox, ToolContext, ToolResult};
14use agents_toolkit::{create_filesystem_tools, create_todos_tool};
15use async_trait::async_trait;
16use serde::Deserialize;
17
18/// Request sent to the underlying language model. Middlewares can augment
19/// the system prompt or mutate the pending message list before the model call.
20#[derive(Debug, Clone)]
21pub struct ModelRequest {
22    pub system_prompt: String,
23    pub messages: Vec<AgentMessage>,
24}
25
26impl ModelRequest {
27    pub fn new(system_prompt: impl Into<String>, messages: Vec<AgentMessage>) -> Self {
28        Self {
29            system_prompt: system_prompt.into(),
30            messages,
31        }
32    }
33
34    pub fn append_prompt(&mut self, fragment: &str) {
35        if !fragment.is_empty() {
36            self.system_prompt.push_str("\n\n");
37            self.system_prompt.push_str(fragment);
38        }
39    }
40}
41
42/// Read/write state handle exposed to middleware implementations.
43pub struct MiddlewareContext<'a> {
44    pub request: &'a mut ModelRequest,
45    pub state: Arc<RwLock<AgentStateSnapshot>>,
46}
47
48impl<'a> MiddlewareContext<'a> {
49    pub fn with_request(
50        request: &'a mut ModelRequest,
51        state: Arc<RwLock<AgentStateSnapshot>>,
52    ) -> Self {
53        Self { request, state }
54    }
55}
56
57/// Middleware hook that can register additional tools and mutate the model request
58/// prior to execution. Mirrors the Python AgentMiddleware contracts but keeps the
59/// interface async-first for future network calls.
60#[async_trait]
61pub trait AgentMiddleware: Send + Sync {
62    /// Unique identifier for logging and diagnostics.
63    fn id(&self) -> &'static str;
64
65    /// Tools to expose when this middleware is active.
66    fn tools(&self) -> Vec<ToolBox> {
67        Vec::new()
68    }
69
70    /// Apply middleware-specific mutations to the pending model request.
71    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()>;
72}
73
/// Keeps only the most recent messages of a request and replaces the older
/// ones with a single system note (no actual summarization is performed).
pub struct SummarizationMiddleware {
    /// Number of trailing messages preserved verbatim.
    pub messages_to_keep: usize,
    /// Text placed in the system message that stands in for dropped history.
    pub summary_note: String,
}

impl SummarizationMiddleware {
    /// Creates the middleware with a keep-limit and the note text to inject.
    pub fn new(messages_to_keep: usize, summary_note: impl Into<String>) -> Self {
        let summary_note = summary_note.into();
        SummarizationMiddleware {
            messages_to_keep,
            summary_note,
        }
    }
}
87
88#[async_trait]
89impl AgentMiddleware for SummarizationMiddleware {
90    fn id(&self) -> &'static str {
91        "summarization"
92    }
93
94    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
95        if ctx.request.messages.len() > self.messages_to_keep {
96            let dropped = ctx.request.messages.len() - self.messages_to_keep;
97            let mut truncated = ctx
98                .request
99                .messages
100                .split_off(ctx.request.messages.len() - self.messages_to_keep);
101            truncated.insert(
102                0,
103                AgentMessage {
104                    role: MessageRole::System,
105                    content: MessageContent::Text(format!(
106                        "{} ({} earlier messages summarized)",
107                        self.summary_note, dropped
108                    )),
109                    metadata: None,
110                },
111            );
112            ctx.request.messages = truncated;
113        }
114        Ok(())
115    }
116}
117
118pub struct PlanningMiddleware {
119    _state: Arc<RwLock<AgentStateSnapshot>>,
120}
121
122impl PlanningMiddleware {
123    pub fn new(state: Arc<RwLock<AgentStateSnapshot>>) -> Self {
124        Self { _state: state }
125    }
126}
127
128#[async_trait]
129impl AgentMiddleware for PlanningMiddleware {
130    fn id(&self) -> &'static str {
131        "planning"
132    }
133
134    fn tools(&self) -> Vec<ToolBox> {
135        vec![create_todos_tool()]
136    }
137
138    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
139        ctx.request.append_prompt(WRITE_TODOS_SYSTEM_PROMPT);
140        Ok(())
141    }
142}
143
144pub struct FilesystemMiddleware {
145    _state: Arc<RwLock<AgentStateSnapshot>>,
146}
147
148impl FilesystemMiddleware {
149    pub fn new(state: Arc<RwLock<AgentStateSnapshot>>) -> Self {
150        Self { _state: state }
151    }
152}
153
154#[async_trait]
155impl AgentMiddleware for FilesystemMiddleware {
156    fn id(&self) -> &'static str {
157        "filesystem"
158    }
159
160    fn tools(&self) -> Vec<ToolBox> {
161        create_filesystem_tools()
162    }
163
164    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
165        ctx.request.append_prompt(FILESYSTEM_SYSTEM_PROMPT);
166        Ok(())
167    }
168}
169
170#[derive(Clone)]
171pub struct SubAgentRegistration {
172    pub descriptor: SubAgentDescriptor,
173    pub agent: Arc<dyn AgentHandle>,
174}
175
176struct SubAgentRegistry {
177    agents: HashMap<String, Arc<dyn AgentHandle>>,
178}
179
180impl SubAgentRegistry {
181    fn new(registrations: Vec<SubAgentRegistration>) -> Self {
182        let mut agents = HashMap::new();
183        for reg in registrations {
184            agents.insert(reg.descriptor.name.clone(), reg.agent.clone());
185        }
186        Self { agents }
187    }
188
189    fn available_names(&self) -> Vec<String> {
190        self.agents.keys().cloned().collect()
191    }
192
193    fn get(&self, name: &str) -> Option<Arc<dyn AgentHandle>> {
194        self.agents.get(name).cloned()
195    }
196}
197
198pub struct SubAgentMiddleware {
199    task_tool: ToolBox,
200    descriptors: Vec<SubAgentDescriptor>,
201    _registry: Arc<SubAgentRegistry>,
202}
203
204impl SubAgentMiddleware {
205    pub fn new(registrations: Vec<SubAgentRegistration>) -> Self {
206        let descriptors = registrations.iter().map(|r| r.descriptor.clone()).collect();
207        let registry = Arc::new(SubAgentRegistry::new(registrations));
208        let task_tool: ToolBox = Arc::new(TaskRouterTool::new(registry.clone()));
209        Self {
210            task_tool,
211            descriptors,
212            _registry: registry,
213        }
214    }
215
216    fn prompt_fragment(&self) -> String {
217        let descriptions: Vec<String> = if self.descriptors.is_empty() {
218            vec![String::from("- general-purpose: Default reasoning agent")]
219        } else {
220            self.descriptors
221                .iter()
222                .map(|agent| format!("- {}: {}", agent.name, agent.description))
223                .collect()
224        };
225
226        TASK_TOOL_DESCRIPTION.replace("{other_agents}", &descriptions.join("\n"))
227    }
228}
229
230#[async_trait]
231impl AgentMiddleware for SubAgentMiddleware {
232    fn id(&self) -> &'static str {
233        "subagent"
234    }
235
236    fn tools(&self) -> Vec<ToolBox> {
237        vec![self.task_tool.clone()]
238    }
239
240    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
241        ctx.request.append_prompt(TASK_SYSTEM_PROMPT);
242        ctx.request.append_prompt(&self.prompt_fragment());
243        Ok(())
244    }
245}
246
/// Approval policy for a single tool.
#[derive(Clone, Debug)]
pub struct HitlPolicy {
    /// When `true`, the tool may run without human approval.
    pub allow_auto: bool,
    /// Optional explanation listed in the system prompt for gated tools.
    pub note: Option<String>,
}

/// Middleware that tells the model which tools need human approval.
pub struct HumanInLoopMiddleware {
    policies: HashMap<String, HitlPolicy>,
}

impl HumanInLoopMiddleware {
    /// Creates the middleware from a tool-name -> policy map.
    pub fn new(policies: HashMap<String, HitlPolicy>) -> Self {
        Self { policies }
    }

    /// Returns the policy for `tool_name` if that tool is not auto-approved.
    pub fn requires_approval(&self, tool_name: &str) -> Option<&HitlPolicy> {
        self.policies
            .get(tool_name)
            .filter(|policy| !policy.allow_auto)
    }

    /// Builds the system-prompt section listing gated tools, or `None` when
    /// every tool is auto-approved.
    fn prompt_fragment(&self) -> Option<String> {
        let mut gated: Vec<(&String, &HitlPolicy)> = self
            .policies
            .iter()
            .filter(|(_, policy)| !policy.allow_auto)
            .collect();
        if gated.is_empty() {
            return None;
        }
        // HashMap iteration order depends on a randomized per-map seed; sort
        // by tool name so the generated prompt text is reproducible.
        gated.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));

        let lines: Vec<String> = gated
            .into_iter()
            .map(|(tool, policy)| {
                let note = policy.note.as_deref().unwrap_or("Requires approval");
                format!("- {tool}: {note}")
            })
            .collect();
        Some(format!(
            "The following tools require human approval before execution:\n{}",
            lines.join("\n")
        ))
    }
}
288
289#[async_trait]
290impl AgentMiddleware for HumanInLoopMiddleware {
291    fn id(&self) -> &'static str {
292        "human-in-loop"
293    }
294
295    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
296        if let Some(fragment) = self.prompt_fragment() {
297            ctx.request.append_prompt(&fragment);
298        }
299        ctx.request.messages.push(AgentMessage {
300            role: MessageRole::System,
301            content: MessageContent::Text(
302                "Tools marked for human approval will emit interrupts requiring external resolution."
303                    .into(),
304            ),
305            metadata: None,
306        });
307        Ok(())
308    }
309}
310
311pub struct BaseSystemPromptMiddleware;
312
313#[async_trait]
314impl AgentMiddleware for BaseSystemPromptMiddleware {
315    fn id(&self) -> &'static str {
316        "base-system-prompt"
317    }
318
319    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
320        ctx.request.append_prompt(BASE_AGENT_PROMPT);
321        Ok(())
322    }
323}
324
/// Anthropic-specific prompt caching middleware. Marks the system prompt for
/// caching to reduce latency on subsequent requests that share the same base
/// prompt.
pub struct AnthropicPromptCachingMiddleware {
    /// Cache TTL such as `"5m"`. Empty or zero-valued TTLs (`"0"`, `"0s"`,
    /// `"0m"`, ...) disable caching; otherwise the duration is not yet used.
    pub ttl: String,
    /// Requested behavior for models without caching support. Currently only
    /// reported in debug logging.
    pub unsupported_model_behavior: String,
}

impl AnthropicPromptCachingMiddleware {
    /// Creates the middleware with an explicit TTL and unsupported-model behavior.
    pub fn new(ttl: impl Into<String>, unsupported_model_behavior: impl Into<String>) -> Self {
        Self {
            ttl: ttl.into(),
            unsupported_model_behavior: unsupported_model_behavior.into(),
        }
    }

    /// Defaults: a `"5m"` TTL and `"ignore"` for unsupported models.
    pub fn with_defaults() -> Self {
        Self::new("5m", "ignore")
    }

    /// Returns `true` when `ttl` requests caching.
    ///
    /// The TTL is trimmed and its leading run of ASCII digits inspected. An
    /// empty TTL, or one whose numeric part is all zeros (`"0"`, `"0s"`,
    /// `"00m"`), disables caching regardless of unit. Any other non-empty
    /// value, including one without a leading number, enables ephemeral
    /// caching.
    fn should_enable_caching(&self) -> bool {
        let ttl = self.ttl.trim();
        if ttl.is_empty() {
            return false;
        }
        let digits_end = ttl
            .find(|c: char| !c.is_ascii_digit())
            .unwrap_or(ttl.len());
        let number = &ttl[..digits_end];
        number.is_empty() || number.bytes().any(|b| b != b'0')
    }
}
350
351#[async_trait]
352impl AgentMiddleware for AnthropicPromptCachingMiddleware {
353    fn id(&self) -> &'static str {
354        "anthropic-prompt-caching"
355    }
356
357    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
358        if !self.should_enable_caching() {
359            return Ok(());
360        }
361
362        // Mark system prompt for caching by converting it to a system message with cache control
363        if !ctx.request.system_prompt.is_empty() {
364            let system_message = AgentMessage {
365                role: MessageRole::System,
366                content: MessageContent::Text(ctx.request.system_prompt.clone()),
367                metadata: Some(MessageMetadata {
368                    tool_call_id: None,
369                    cache_control: Some(CacheControl {
370                        cache_type: "ephemeral".to_string(),
371                    }),
372                }),
373            };
374
375            // Insert system message at the beginning of the messages
376            ctx.request.messages.insert(0, system_message);
377
378            // Clear the system_prompt since it's now in messages
379            ctx.request.system_prompt.clear();
380
381            tracing::debug!(
382                ttl = %self.ttl,
383                behavior = %self.unsupported_model_behavior,
384                "Applied Anthropic prompt caching to system message"
385            );
386        }
387
388        Ok(())
389    }
390}
391
392pub struct TaskRouterTool {
393    registry: Arc<SubAgentRegistry>,
394}
395
396impl TaskRouterTool {
397    fn new(registry: Arc<SubAgentRegistry>) -> Self {
398        Self { registry }
399    }
400
401    fn available_subagents(&self) -> Vec<String> {
402        self.registry.available_names()
403    }
404}
405
406#[derive(Debug, Clone, Deserialize)]
407struct TaskInvocationArgs {
408    description: String,
409    subagent_type: String,
410}
411
412#[async_trait]
413impl Tool for TaskRouterTool {
414    fn schema(&self) -> agents_core::tools::ToolSchema {
415        use agents_core::tools::{ToolParameterSchema, ToolSchema};
416        use std::collections::HashMap;
417
418        let mut properties = HashMap::new();
419        properties.insert(
420            "description".to_string(),
421            ToolParameterSchema::string("Description of the task for the subagent"),
422        );
423        properties.insert(
424            "subagent_type".to_string(),
425            ToolParameterSchema::string("Type of subagent to use"),
426        );
427
428        ToolSchema::new(
429            "task",
430            "Delegate a task to a specialized subagent",
431            ToolParameterSchema::object(
432                "Task parameters",
433                properties,
434                vec!["description".to_string(), "subagent_type".to_string()],
435            ),
436        )
437    }
438
439    async fn execute(
440        &self,
441        args: serde_json::Value,
442        ctx: ToolContext,
443    ) -> anyhow::Result<ToolResult> {
444        let args: TaskInvocationArgs = serde_json::from_value(args)?;
445        let available = self.available_subagents();
446
447        if let Some(agent) = self.registry.get(&args.subagent_type) {
448            let user_message = AgentMessage {
449                role: MessageRole::User,
450                content: MessageContent::Text(args.description.clone()),
451                metadata: None,
452            };
453            let response = agent
454                .handle_message(user_message, Arc::new(AgentStateSnapshot::default()))
455                .await?;
456
457            return Ok(ToolResult::Message(AgentMessage {
458                role: MessageRole::Tool,
459                content: response.content,
460                metadata: ctx.tool_call_id.map(|id| MessageMetadata {
461                    tool_call_id: Some(id),
462                    cache_control: None,
463                }),
464            }));
465        }
466
467        Ok(ToolResult::text(
468            &ctx,
469            format!(
470                "Unknown subagent '{subagent}'. Available: {available:?}",
471                subagent = args.subagent_type,
472                available = available
473            ),
474        ))
475    }
476}
477
/// Identity of a sub-agent: the routing name used as `subagent_type` and a
/// one-line description listed in the `task` tool prompt.
#[derive(Clone, Debug)]
pub struct SubAgentDescriptor {
    /// Routing key matched against the `subagent_type` argument.
    pub name: String,
    /// Short summary shown to the model.
    pub description: String,
}
483
#[cfg(test)]
mod tests {
    // Unit tests for the middleware implementations and the `task` router tool.
    use super::*;
    use agents_core::agent::{AgentDescriptor, AgentHandle};
    use agents_core::messaging::{MessageContent, MessageRole};
    use serde_json::json;

    // Minimal middleware that appends a fixed directive directly to the system prompt.
    struct AppendPromptMiddleware;

    #[async_trait]
    impl AgentMiddleware for AppendPromptMiddleware {
        fn id(&self) -> &'static str {
            "append-prompt"
        }

        async fn modify_model_request(
            &self,
            ctx: &mut MiddlewareContext<'_>,
        ) -> anyhow::Result<()> {
            ctx.request.system_prompt.push_str("\nExtra directives.");
            Ok(())
        }
    }

    // A middleware can mutate the system prompt through the context.
    #[tokio::test]
    async fn middleware_mutates_prompt() {
        let mut request = ModelRequest::new(
            "System",
            vec![AgentMessage {
                role: MessageRole::User,
                content: MessageContent::Text("Hi".into()),
                metadata: None,
            }],
        );
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);
        let middleware = AppendPromptMiddleware;
        middleware.modify_model_request(&mut ctx).await.unwrap();
        assert!(ctx.request.system_prompt.contains("Extra directives"));
    }

    // Planning exposes `write_todos` and its prompt mentions that tool.
    #[tokio::test]
    async fn planning_middleware_registers_write_todos() {
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let middleware = PlanningMiddleware::new(state);
        let tool_names: Vec<_> = middleware
            .tools()
            .iter()
            .map(|t| t.schema().name.clone())
            .collect();
        assert!(tool_names.contains(&"write_todos".to_string()));

        let mut request = ModelRequest::new("System", vec![]);
        let mut ctx = MiddlewareContext::with_request(
            &mut request,
            Arc::new(RwLock::new(AgentStateSnapshot::default())),
        );
        middleware.modify_model_request(&mut ctx).await.unwrap();
        assert!(ctx.request.system_prompt.contains("write_todos"));
    }

    // Filesystem middleware exposes the four core file tools.
    #[tokio::test]
    async fn filesystem_middleware_registers_tools() {
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let middleware = FilesystemMiddleware::new(state);
        let tool_names: Vec<_> = middleware
            .tools()
            .iter()
            .map(|t| t.schema().name.clone())
            .collect();
        for expected in ["ls", "read_file", "write_file", "edit_file"] {
            assert!(tool_names.contains(&expected.to_string()));
        }
    }

    // Keep-limit 2 with 3 messages: one is dropped and a summary note is
    // prepended, so the total stays at 3 with the note first.
    #[tokio::test]
    async fn summarization_middleware_trims_messages() {
        let middleware = SummarizationMiddleware::new(2, "Summary note");
        let mut request = ModelRequest::new(
            "System",
            vec![
                AgentMessage {
                    role: MessageRole::User,
                    content: MessageContent::Text("one".into()),
                    metadata: None,
                },
                AgentMessage {
                    role: MessageRole::Agent,
                    content: MessageContent::Text("two".into()),
                    metadata: None,
                },
                AgentMessage {
                    role: MessageRole::User,
                    content: MessageContent::Text("three".into()),
                    metadata: None,
                },
            ],
        );
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);
        middleware.modify_model_request(&mut ctx).await.unwrap();
        assert_eq!(ctx.request.messages.len(), 3);
        match &ctx.request.messages[0].content {
            MessageContent::Text(text) => assert!(text.contains("Summary note")),
            other => panic!("expected text, got {other:?}"),
        }
    }

    // Sub-agent stub that always replies with a fixed "stub-response" text.
    struct StubAgent;

    #[async_trait]
    impl AgentHandle for StubAgent {
        async fn describe(&self) -> AgentDescriptor {
            AgentDescriptor {
                name: "stub".into(),
                version: "0.0.1".into(),
                description: None,
            }
        }

        async fn handle_message(
            &self,
            _input: AgentMessage,
            _state: Arc<AgentStateSnapshot>,
        ) -> anyhow::Result<AgentMessage> {
            Ok(AgentMessage {
                role: MessageRole::Agent,
                content: MessageContent::Text("stub-response".into()),
                metadata: None,
            })
        }
    }

    // Routing to an unregistered name yields a text result, not an error.
    #[tokio::test]
    async fn task_router_reports_unknown_subagent() {
        let registry = Arc::new(SubAgentRegistry::new(vec![]));
        let task_tool = TaskRouterTool::new(registry.clone());
        let state = Arc::new(AgentStateSnapshot::default());
        let ctx = ToolContext::new(state);

        let response = task_tool
            .execute(
                json!({
                    "description": "Do something",
                    "subagent_type": "unknown"
                }),
                ctx,
            )
            .await
            .unwrap();

        match response {
            ToolResult::Message(msg) => match msg.content {
                MessageContent::Text(text) => assert!(text.contains("Unknown subagent")),
                other => panic!("expected text, got {other:?}"),
            },
            _ => panic!("expected message"),
        }
    }

    // Registered sub-agents appear in the prompt and the `task` tool is exposed.
    #[tokio::test]
    async fn subagent_middleware_appends_prompt() {
        let subagents = vec![SubAgentRegistration {
            descriptor: SubAgentDescriptor {
                name: "research-agent".into(),
                description: "Deep research specialist".into(),
            },
            agent: Arc::new(StubAgent),
        }];
        let middleware = SubAgentMiddleware::new(subagents);

        let mut request = ModelRequest::new("System", vec![]);
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);
        middleware.modify_model_request(&mut ctx).await.unwrap();

        assert!(ctx.request.system_prompt.contains("research-agent"));
        let tool_names: Vec<_> = middleware
            .tools()
            .iter()
            .map(|t| t.schema().name.clone())
            .collect();
        assert!(tool_names.contains(&"task".to_string()));
    }

    // A routed task returns the sub-agent's reply tagged with the caller's tool_call_id.
    #[tokio::test]
    async fn task_router_invokes_registered_subagent() {
        let registry = Arc::new(SubAgentRegistry::new(vec![SubAgentRegistration {
            descriptor: SubAgentDescriptor {
                name: "stub-agent".into(),
                description: "Stub".into(),
            },
            agent: Arc::new(StubAgent),
        }]));
        let task_tool = TaskRouterTool::new(registry.clone());
        let state = Arc::new(AgentStateSnapshot::default());
        let ctx = ToolContext::new(state).with_call_id(Some("call-42".into()));
        let response = task_tool
            .execute(
                json!({
                    "description": "do work",
                    "subagent_type": "stub-agent"
                }),
                ctx,
            )
            .await
            .unwrap();

        match response {
            ToolResult::Message(msg) => {
                assert_eq!(msg.metadata.unwrap().tool_call_id.unwrap(), "call-42");
                match msg.content {
                    MessageContent::Text(text) => assert_eq!(text, "stub-response"),
                    other => panic!("expected text, got {other:?}"),
                }
            }
            _ => panic!("expected message"),
        }
    }

    // A gated tool and its note are listed in the system prompt.
    #[tokio::test]
    async fn human_in_loop_appends_prompt() {
        let middleware = HumanInLoopMiddleware::new(HashMap::from([(
            "danger-tool".into(),
            HitlPolicy {
                allow_auto: false,
                note: Some("Requires security review".into()),
            },
        )]));
        let mut request = ModelRequest::new("System", vec![]);
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);
        middleware.modify_model_request(&mut ctx).await.unwrap();
        assert!(ctx
            .request
            .system_prompt
            .contains("danger-tool: Requires security review"));
    }

    // With a non-zero TTL the system prompt moves into a leading cached system message.
    #[tokio::test]
    async fn anthropic_prompt_caching_moves_system_prompt_to_messages() {
        let middleware = AnthropicPromptCachingMiddleware::new("5m", "ignore");
        let mut request = ModelRequest::new(
            "This is the system prompt",
            vec![AgentMessage {
                role: MessageRole::User,
                content: MessageContent::Text("Hello".into()),
                metadata: None,
            }],
        );
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);

        // Apply the middleware
        middleware.modify_model_request(&mut ctx).await.unwrap();

        // System prompt should be cleared
        assert!(ctx.request.system_prompt.is_empty());

        // Should have added a system message with cache control at the beginning
        assert_eq!(ctx.request.messages.len(), 2);

        let system_message = &ctx.request.messages[0];
        assert!(matches!(system_message.role, MessageRole::System));
        assert_eq!(
            system_message.content.as_text().unwrap(),
            "This is the system prompt"
        );

        // Check cache control metadata
        let metadata = system_message.metadata.as_ref().unwrap();
        let cache_control = metadata.cache_control.as_ref().unwrap();
        assert_eq!(cache_control.cache_type, "ephemeral");

        // Original user message should still be there
        let user_message = &ctx.request.messages[1];
        assert!(matches!(user_message.role, MessageRole::User));
        assert_eq!(user_message.content.as_text().unwrap(), "Hello");
    }

    // A "0" TTL disables caching and leaves the request unchanged.
    #[tokio::test]
    async fn anthropic_prompt_caching_disabled_with_zero_ttl() {
        let middleware = AnthropicPromptCachingMiddleware::new("0", "ignore");
        let mut request = ModelRequest::new("This is the system prompt", vec![]);
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);

        // Apply the middleware
        middleware.modify_model_request(&mut ctx).await.unwrap();

        // System prompt should be unchanged
        assert_eq!(ctx.request.system_prompt, "This is the system prompt");
        assert_eq!(ctx.request.messages.len(), 0);
    }

    // An empty system prompt means there is nothing to cache.
    #[tokio::test]
    async fn anthropic_prompt_caching_no_op_with_empty_system_prompt() {
        let middleware = AnthropicPromptCachingMiddleware::new("5m", "ignore");
        let mut request = ModelRequest::new(
            "",
            vec![AgentMessage {
                role: MessageRole::User,
                content: MessageContent::Text("Hello".into()),
                metadata: None,
            }],
        );
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);

        // Apply the middleware
        middleware.modify_model_request(&mut ctx).await.unwrap();

        // Should be unchanged
        assert!(ctx.request.system_prompt.is_empty());
        assert_eq!(ctx.request.messages.len(), 1);
        assert!(matches!(ctx.request.messages[0].role, MessageRole::User));
    }
}