1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3
4use agents_core::agent::AgentHandle;
5use agents_core::messaging::{
6 AgentMessage, CacheControl, MessageContent, MessageMetadata, MessageRole,
7};
8use agents_core::prompts::{
9 BASE_AGENT_PROMPT, FILESYSTEM_SYSTEM_PROMPT, TASK_SYSTEM_PROMPT, TASK_TOOL_DESCRIPTION,
10 WRITE_TODOS_SYSTEM_PROMPT,
11};
12use agents_core::state::AgentStateSnapshot;
13use agents_core::tools::{Tool, ToolBox, ToolContext, ToolResult};
14use agents_toolkit::create_filesystem_tools;
15use async_trait::async_trait;
16use serde::Deserialize;
17
18#[derive(Debug, Clone)]
21pub struct ModelRequest {
22 pub system_prompt: String,
23 pub messages: Vec<AgentMessage>,
24}
25
26impl ModelRequest {
27 pub fn new(system_prompt: impl Into<String>, messages: Vec<AgentMessage>) -> Self {
28 Self {
29 system_prompt: system_prompt.into(),
30 messages,
31 }
32 }
33
34 pub fn append_prompt(&mut self, fragment: &str) {
35 if !fragment.is_empty() {
36 self.system_prompt.push_str("\n\n");
37 self.system_prompt.push_str(fragment);
38 }
39 }
40}
41
42pub struct MiddlewareContext<'a> {
44 pub request: &'a mut ModelRequest,
45 pub state: Arc<RwLock<AgentStateSnapshot>>,
46}
47
48impl<'a> MiddlewareContext<'a> {
49 pub fn with_request(
50 request: &'a mut ModelRequest,
51 state: Arc<RwLock<AgentStateSnapshot>>,
52 ) -> Self {
53 Self { request, state }
54 }
55}
56
57#[async_trait]
61pub trait AgentMiddleware: Send + Sync {
62 fn id(&self) -> &'static str;
64
65 fn tools(&self) -> Vec<ToolBox> {
67 Vec::new()
68 }
69
70 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()>;
72}
73
/// Keeps only the most recent messages and replaces the older ones with a
/// short system note stating how many were dropped.
pub struct SummarizationMiddleware {
    /// Number of trailing messages retained verbatim.
    pub messages_to_keep: usize,
    /// Text placed at the start of the note that replaces dropped messages.
    pub summary_note: String,
}

impl SummarizationMiddleware {
    /// Builds the middleware from the retention window and the note text.
    pub fn new(messages_to_keep: usize, summary_note: impl Into<String>) -> Self {
        let summary_note = summary_note.into();
        SummarizationMiddleware {
            summary_note,
            messages_to_keep,
        }
    }
}
87
88#[async_trait]
89impl AgentMiddleware for SummarizationMiddleware {
90 fn id(&self) -> &'static str {
91 "summarization"
92 }
93
94 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
95 if ctx.request.messages.len() > self.messages_to_keep {
96 let dropped = ctx.request.messages.len() - self.messages_to_keep;
97 let mut truncated = ctx
98 .request
99 .messages
100 .split_off(ctx.request.messages.len() - self.messages_to_keep);
101 truncated.insert(
102 0,
103 AgentMessage {
104 role: MessageRole::System,
105 content: MessageContent::Text(format!(
106 "{} ({} earlier messages summarized)",
107 self.summary_note, dropped
108 )),
109 metadata: None,
110 },
111 );
112 ctx.request.messages = truncated;
113 }
114 Ok(())
115 }
116}
117
118pub struct PlanningMiddleware {
119 _state: Arc<RwLock<AgentStateSnapshot>>,
120}
121
122impl PlanningMiddleware {
123 pub fn new(state: Arc<RwLock<AgentStateSnapshot>>) -> Self {
124 Self { _state: state }
125 }
126}
127
128#[async_trait]
129impl AgentMiddleware for PlanningMiddleware {
130 fn id(&self) -> &'static str {
131 "planning"
132 }
133
134 fn tools(&self) -> Vec<ToolBox> {
135 use agents_toolkit::create_todos_tools;
136 create_todos_tools()
137 }
138
139 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
140 ctx.request.append_prompt(WRITE_TODOS_SYSTEM_PROMPT);
141 Ok(())
142 }
143}
144
145pub struct FilesystemMiddleware {
146 _state: Arc<RwLock<AgentStateSnapshot>>,
147}
148
149impl FilesystemMiddleware {
150 pub fn new(state: Arc<RwLock<AgentStateSnapshot>>) -> Self {
151 Self { _state: state }
152 }
153}
154
155#[async_trait]
156impl AgentMiddleware for FilesystemMiddleware {
157 fn id(&self) -> &'static str {
158 "filesystem"
159 }
160
161 fn tools(&self) -> Vec<ToolBox> {
162 create_filesystem_tools()
163 }
164
165 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
166 ctx.request.append_prompt(FILESYSTEM_SYSTEM_PROMPT);
167 Ok(())
168 }
169}
170
171#[derive(Clone)]
172pub struct SubAgentRegistration {
173 pub descriptor: SubAgentDescriptor,
174 pub agent: Arc<dyn AgentHandle>,
175}
176
177struct SubAgentRegistry {
178 agents: HashMap<String, Arc<dyn AgentHandle>>,
179}
180
181impl SubAgentRegistry {
182 fn new(registrations: Vec<SubAgentRegistration>) -> Self {
183 let mut agents = HashMap::new();
184 for reg in registrations {
185 agents.insert(reg.descriptor.name.clone(), reg.agent.clone());
186 }
187 Self { agents }
188 }
189
190 fn available_names(&self) -> Vec<String> {
191 self.agents.keys().cloned().collect()
192 }
193
194 fn get(&self, name: &str) -> Option<Arc<dyn AgentHandle>> {
195 self.agents.get(name).cloned()
196 }
197}
198
199pub struct SubAgentMiddleware {
200 task_tool: ToolBox,
201 descriptors: Vec<SubAgentDescriptor>,
202 _registry: Arc<SubAgentRegistry>,
203}
204
205impl SubAgentMiddleware {
206 pub fn new(registrations: Vec<SubAgentRegistration>) -> Self {
207 let descriptors = registrations.iter().map(|r| r.descriptor.clone()).collect();
208 let registry = Arc::new(SubAgentRegistry::new(registrations));
209 let task_tool: ToolBox = Arc::new(TaskRouterTool::new(registry.clone()));
210 Self {
211 task_tool,
212 descriptors,
213 _registry: registry,
214 }
215 }
216
217 fn prompt_fragment(&self) -> String {
218 let descriptions: Vec<String> = if self.descriptors.is_empty() {
219 vec![String::from("- general-purpose: Default reasoning agent")]
220 } else {
221 self.descriptors
222 .iter()
223 .map(|agent| format!("- {}: {}", agent.name, agent.description))
224 .collect()
225 };
226
227 TASK_TOOL_DESCRIPTION.replace("{other_agents}", &descriptions.join("\n"))
228 }
229}
230
231#[async_trait]
232impl AgentMiddleware for SubAgentMiddleware {
233 fn id(&self) -> &'static str {
234 "subagent"
235 }
236
237 fn tools(&self) -> Vec<ToolBox> {
238 vec![self.task_tool.clone()]
239 }
240
241 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
242 ctx.request.append_prompt(TASK_SYSTEM_PROMPT);
243 ctx.request.append_prompt(&self.prompt_fragment());
244 Ok(())
245 }
246}
247
/// Approval policy for a single tool.
#[derive(Clone, Debug)]
pub struct HitlPolicy {
    /// When `true`, the tool may run without human approval.
    pub allow_auto: bool,
    /// Optional explanation listed in the prompt for approval-gated tools.
    pub note: Option<String>,
}

/// Tracks which tools need human approval and tells the model about them.
pub struct HumanInLoopMiddleware {
    policies: HashMap<String, HitlPolicy>,
}

impl HumanInLoopMiddleware {
    /// Creates the middleware from a map of tool name to policy.
    pub fn new(policies: HashMap<String, HitlPolicy>) -> Self {
        Self { policies }
    }

    /// Returns the policy for `tool_name` when one exists and it does not allow
    /// automatic execution; `None` means the tool may run without approval.
    pub fn requires_approval(&self, tool_name: &str) -> Option<&HitlPolicy> {
        self.policies
            .get(tool_name)
            .filter(|policy| !policy.allow_auto)
    }

    /// Builds the prompt paragraph listing approval-gated tools, or `None` if
    /// no tool is gated.
    ///
    /// Entries are sorted by tool name. `HashMap` iteration order is
    /// unspecified, so without sorting the system prompt would differ between
    /// otherwise identical requests.
    fn prompt_fragment(&self) -> Option<String> {
        let mut gated: Vec<(&String, &HitlPolicy)> = self
            .policies
            .iter()
            .filter(|(_, policy)| !policy.allow_auto)
            .collect();
        if gated.is_empty() {
            return None;
        }
        gated.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));

        let pending: Vec<String> = gated
            .into_iter()
            .map(|(tool, policy)| match &policy.note {
                Some(note) => format!("- {tool}: {note}"),
                None => format!("- {tool}: Requires approval"),
            })
            .collect();
        Some(format!(
            "The following tools require human approval before execution:\n{}",
            pending.join("\n")
        ))
    }
}
289
290#[async_trait]
291impl AgentMiddleware for HumanInLoopMiddleware {
292 fn id(&self) -> &'static str {
293 "human-in-loop"
294 }
295
296 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
297 if let Some(fragment) = self.prompt_fragment() {
298 ctx.request.append_prompt(&fragment);
299 }
300 ctx.request.messages.push(AgentMessage {
301 role: MessageRole::System,
302 content: MessageContent::Text(
303 "Tools marked for human approval will emit interrupts requiring external resolution."
304 .into(),
305 ),
306 metadata: None,
307 });
308 Ok(())
309 }
310}
311
312pub struct BaseSystemPromptMiddleware;
313
314#[async_trait]
315impl AgentMiddleware for BaseSystemPromptMiddleware {
316 fn id(&self) -> &'static str {
317 "base-system-prompt"
318 }
319
320 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
321 ctx.request.append_prompt(BASE_AGENT_PROMPT);
322 Ok(())
323 }
324}
325
326pub struct DeepAgentPromptMiddleware {
336 custom_instructions: String,
337}
338
339impl DeepAgentPromptMiddleware {
340 pub fn new(custom_instructions: impl Into<String>) -> Self {
341 Self {
342 custom_instructions: custom_instructions.into(),
343 }
344 }
345}
346
347#[async_trait]
348impl AgentMiddleware for DeepAgentPromptMiddleware {
349 fn id(&self) -> &'static str {
350 "deep-agent-prompt"
351 }
352
353 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
354 use crate::prompts::get_deep_agent_system_prompt;
355 let deep_prompt = get_deep_agent_system_prompt(&self.custom_instructions);
356 ctx.request.append_prompt(&deep_prompt);
357 Ok(())
358 }
359}
360
/// Moves the system prompt into a leading `System` message tagged with
/// ephemeral cache control, for Anthropic prompt caching.
pub struct AnthropicPromptCachingMiddleware {
    /// Cache time-to-live such as `"5m"`; blank or zero values disable caching.
    pub ttl: String,
    /// Behaviour label for models without caching support. In this file it is
    /// only included in a debug log.
    pub unsupported_model_behavior: String,
}

impl AnthropicPromptCachingMiddleware {
    /// Creates the middleware with an explicit TTL and unsupported-model behaviour.
    pub fn new(ttl: impl Into<String>, unsupported_model_behavior: impl Into<String>) -> Self {
        Self {
            ttl: ttl.into(),
            unsupported_model_behavior: unsupported_model_behavior.into(),
        }
    }

    /// Defaults: `ttl = "5m"`, `unsupported_model_behavior = "ignore"`.
    pub fn with_defaults() -> Self {
        Self::new("5m", "ignore")
    }

    /// Returns `false` when the TTL is blank or a zero duration.
    ///
    /// A TTL counts as zero when the remainder is non-empty and made only of
    /// `0` digits. The remainder is what is left after trimming whitespace and
    /// stripping any trailing alphabetic unit suffix (`s`, `m`, `h`, ...).
    /// So `"0"`, `"0s"`, `"0m"`, `"00h"` and `" 0 "` all disable caching.
    /// Every other value enables it.
    fn should_enable_caching(&self) -> bool {
        let ttl = self.ttl.trim();
        if ttl.is_empty() {
            return false;
        }
        let magnitude = ttl.trim_end_matches(|c: char| c.is_ascii_alphabetic());
        let is_zero = !magnitude.is_empty() && magnitude.chars().all(|c| c == '0');
        !is_zero
    }
}
386
387#[async_trait]
388impl AgentMiddleware for AnthropicPromptCachingMiddleware {
389 fn id(&self) -> &'static str {
390 "anthropic-prompt-caching"
391 }
392
393 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
394 if !self.should_enable_caching() {
395 return Ok(());
396 }
397
398 if !ctx.request.system_prompt.is_empty() {
400 let system_message = AgentMessage {
401 role: MessageRole::System,
402 content: MessageContent::Text(ctx.request.system_prompt.clone()),
403 metadata: Some(MessageMetadata {
404 tool_call_id: None,
405 cache_control: Some(CacheControl {
406 cache_type: "ephemeral".to_string(),
407 }),
408 }),
409 };
410
411 ctx.request.messages.insert(0, system_message);
413
414 ctx.request.system_prompt.clear();
416
417 tracing::debug!(
418 ttl = %self.ttl,
419 behavior = %self.unsupported_model_behavior,
420 "Applied Anthropic prompt caching to system message"
421 );
422 }
423
424 Ok(())
425 }
426}
427
428pub struct TaskRouterTool {
429 registry: Arc<SubAgentRegistry>,
430}
431
432impl TaskRouterTool {
433 fn new(registry: Arc<SubAgentRegistry>) -> Self {
434 Self { registry }
435 }
436
437 fn available_subagents(&self) -> Vec<String> {
438 self.registry.available_names()
439 }
440}
441
442#[derive(Debug, Clone, Deserialize)]
443struct TaskInvocationArgs {
444 #[serde(alias = "description")]
445 instruction: String,
446 #[serde(alias = "subagent_type")]
447 agent: String,
448}
449
450#[async_trait]
451impl Tool for TaskRouterTool {
452 fn schema(&self) -> agents_core::tools::ToolSchema {
453 use agents_core::tools::{ToolParameterSchema, ToolSchema};
454 use std::collections::HashMap;
455
456 let mut properties = HashMap::new();
457 properties.insert(
458 "agent".to_string(),
459 ToolParameterSchema::string("Name of the sub-agent to delegate to"),
460 );
461 properties.insert(
462 "instruction".to_string(),
463 ToolParameterSchema::string("Clear instruction for the sub-agent"),
464 );
465
466 ToolSchema::new(
467 "task",
468 "Delegate a task to a specialized sub-agent. Use this when you need specialized expertise or want to break down complex tasks.",
469 ToolParameterSchema::object(
470 "Task delegation parameters",
471 properties,
472 vec!["agent".to_string(), "instruction".to_string()],
473 ),
474 )
475 }
476
477 async fn execute(
478 &self,
479 args: serde_json::Value,
480 ctx: ToolContext,
481 ) -> anyhow::Result<ToolResult> {
482 let args: TaskInvocationArgs = serde_json::from_value(args)?;
483 let available = self.available_subagents();
484
485 if let Some(agent) = self.registry.get(&args.agent) {
486 tracing::warn!(
488 "🎯 DELEGATING to sub-agent: {} with instruction: {}",
489 args.agent,
490 args.instruction
491 );
492
493 let start_time = std::time::Instant::now();
494 let user_message = AgentMessage {
495 role: MessageRole::User,
496 content: MessageContent::Text(args.instruction.clone()),
497 metadata: None,
498 };
499
500 let response = agent
501 .handle_message(user_message, Arc::new(AgentStateSnapshot::default()))
502 .await?;
503
504 let duration = start_time.elapsed();
506 let response_preview = match &response.content {
507 MessageContent::Text(t) => {
508 if t.len() > 100 {
509 format!("{}... ({} chars)", &t[..100], t.len())
510 } else {
511 t.clone()
512 }
513 }
514 MessageContent::Json(v) => {
515 format!("JSON: {} bytes", v.to_string().len())
516 }
517 };
518
519 tracing::warn!(
520 "✅ SUB-AGENT {} COMPLETED in {:?} - Response: {}",
521 args.agent,
522 duration,
523 response_preview
524 );
525
526 let result_text = match response.content {
529 MessageContent::Text(text) => text,
530 MessageContent::Json(json) => json.to_string(),
531 };
532
533 return Ok(ToolResult::text(&ctx, result_text));
534 }
535
536 tracing::error!(
537 "❌ SUB-AGENT NOT FOUND: {} - Available: {:?}",
538 args.agent,
539 available
540 );
541
542 Ok(ToolResult::text(
543 &ctx,
544 format!(
545 "Sub-agent '{}' not found. Available sub-agents: {}",
546 args.agent,
547 available.join(", ")
548 ),
549 ))
550 }
551}
552
/// Public description of a sub-agent, as listed to the model.
#[derive(Debug, Clone)]
pub struct SubAgentDescriptor {
    /// Name used to route `task` calls to this sub-agent.
    pub name: String,
    /// Short summary of what the sub-agent does.
    pub description: String,
}
558
// Unit tests for the middleware implementations and the `task` router tool.
#[cfg(test)]
mod tests {
    use super::*;
    use agents_core::agent::{AgentDescriptor, AgentHandle};
    use agents_core::messaging::{MessageContent, MessageRole};
    use serde_json::json;

    // Minimal middleware that appends a fixed line to the system prompt.
    struct AppendPromptMiddleware;

    #[async_trait]
    impl AgentMiddleware for AppendPromptMiddleware {
        fn id(&self) -> &'static str {
            "append-prompt"
        }

        async fn modify_model_request(
            &self,
            ctx: &mut MiddlewareContext<'_>,
        ) -> anyhow::Result<()> {
            ctx.request.system_prompt.push_str("\nExtra directives.");
            Ok(())
        }
    }

    // A middleware can mutate the system prompt through the context.
    #[tokio::test]
    async fn middleware_mutates_prompt() {
        let mut request = ModelRequest::new(
            "System",
            vec![AgentMessage {
                role: MessageRole::User,
                content: MessageContent::Text("Hi".into()),
                metadata: None,
            }],
        );
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);
        let middleware = AppendPromptMiddleware;
        middleware.modify_model_request(&mut ctx).await.unwrap();
        assert!(ctx.request.system_prompt.contains("Extra directives"));
    }

    // Planning exposes a `write_todos` tool and mentions it in the prompt.
    #[tokio::test]
    async fn planning_middleware_registers_write_todos() {
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let middleware = PlanningMiddleware::new(state);
        let tool_names: Vec<_> = middleware
            .tools()
            .iter()
            .map(|t| t.schema().name.clone())
            .collect();
        assert!(tool_names.contains(&"write_todos".to_string()));

        let mut request = ModelRequest::new("System", vec![]);
        let mut ctx = MiddlewareContext::with_request(
            &mut request,
            Arc::new(RwLock::new(AgentStateSnapshot::default())),
        );
        middleware.modify_model_request(&mut ctx).await.unwrap();
        assert!(ctx.request.system_prompt.contains("write_todos"));
    }

    // Filesystem middleware exposes the four basic file tools.
    #[tokio::test]
    async fn filesystem_middleware_registers_tools() {
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let middleware = FilesystemMiddleware::new(state);
        let tool_names: Vec<_> = middleware
            .tools()
            .iter()
            .map(|t| t.schema().name.clone())
            .collect();
        for expected in ["ls", "read_file", "write_file", "edit_file"] {
            assert!(tool_names.contains(&expected.to_string()));
        }
    }

    // Keeping 2 of 3 messages drops one. A summary note is prepended, so the
    // result still has 3 messages and the note comes first.
    #[tokio::test]
    async fn summarization_middleware_trims_messages() {
        let middleware = SummarizationMiddleware::new(2, "Summary note");
        let mut request = ModelRequest::new(
            "System",
            vec![
                AgentMessage {
                    role: MessageRole::User,
                    content: MessageContent::Text("one".into()),
                    metadata: None,
                },
                AgentMessage {
                    role: MessageRole::Agent,
                    content: MessageContent::Text("two".into()),
                    metadata: None,
                },
                AgentMessage {
                    role: MessageRole::User,
                    content: MessageContent::Text("three".into()),
                    metadata: None,
                },
            ],
        );
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);
        middleware.modify_model_request(&mut ctx).await.unwrap();
        assert_eq!(ctx.request.messages.len(), 3);
        match &ctx.request.messages[0].content {
            MessageContent::Text(text) => assert!(text.contains("Summary note")),
            other => panic!("expected text, got {other:?}"),
        }
    }

    // Sub-agent stand-in that always replies with "stub-response".
    struct StubAgent;

    #[async_trait]
    impl AgentHandle for StubAgent {
        async fn describe(&self) -> AgentDescriptor {
            AgentDescriptor {
                name: "stub".into(),
                version: "0.0.1".into(),
                description: None,
            }
        }

        async fn handle_message(
            &self,
            _input: AgentMessage,
            _state: Arc<AgentStateSnapshot>,
        ) -> anyhow::Result<AgentMessage> {
            Ok(AgentMessage {
                role: MessageRole::Agent,
                content: MessageContent::Text("stub-response".into()),
                metadata: None,
            })
        }
    }

    // Routing to an unregistered name yields a text result, not an error.
    #[tokio::test]
    async fn task_router_reports_unknown_subagent() {
        let registry = Arc::new(SubAgentRegistry::new(vec![]));
        let task_tool = TaskRouterTool::new(registry.clone());
        let state = Arc::new(AgentStateSnapshot::default());
        let ctx = ToolContext::new(state);

        let response = task_tool
            .execute(
                json!({
                    "instruction": "Do something",
                    "agent": "unknown"
                }),
                ctx,
            )
            .await
            .unwrap();

        match response {
            ToolResult::Message(msg) => match msg.content {
                MessageContent::Text(text) => {
                    assert!(text.contains("Sub-agent 'unknown' not found"))
                }
                other => panic!("expected text, got {other:?}"),
            },
            _ => panic!("expected message"),
        }
    }

    // Registered sub-agents appear in the prompt, and the `task` tool is exposed.
    #[tokio::test]
    async fn subagent_middleware_appends_prompt() {
        let subagents = vec![SubAgentRegistration {
            descriptor: SubAgentDescriptor {
                name: "research-agent".into(),
                description: "Deep research specialist".into(),
            },
            agent: Arc::new(StubAgent),
        }];
        let middleware = SubAgentMiddleware::new(subagents);

        let mut request = ModelRequest::new("System", vec![]);
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);
        middleware.modify_model_request(&mut ctx).await.unwrap();

        assert!(ctx.request.system_prompt.contains("research-agent"));
        let tool_names: Vec<_> = middleware
            .tools()
            .iter()
            .map(|t| t.schema().name.clone())
            .collect();
        assert!(tool_names.contains(&"task".to_string()));
    }

    // Uses the `description` / `subagent_type` argument aliases. Checks that
    // the sub-agent's reply comes back with the call id taken from the context.
    #[tokio::test]
    async fn task_router_invokes_registered_subagent() {
        let registry = Arc::new(SubAgentRegistry::new(vec![SubAgentRegistration {
            descriptor: SubAgentDescriptor {
                name: "stub-agent".into(),
                description: "Stub".into(),
            },
            agent: Arc::new(StubAgent),
        }]));
        let task_tool = TaskRouterTool::new(registry.clone());
        let state = Arc::new(AgentStateSnapshot::default());
        let ctx = ToolContext::new(state).with_call_id(Some("call-42".into()));
        let response = task_tool
            .execute(
                json!({
                    "description": "do work",
                    "subagent_type": "stub-agent"
                }),
                ctx,
            )
            .await
            .unwrap();

        match response {
            ToolResult::Message(msg) => {
                assert_eq!(msg.metadata.unwrap().tool_call_id.unwrap(), "call-42");
                match msg.content {
                    MessageContent::Text(text) => assert_eq!(text, "stub-response"),
                    other => panic!("expected text, got {other:?}"),
                }
            }
            _ => panic!("expected message"),
        }
    }

    // A gated tool and its note are listed in the system prompt.
    #[tokio::test]
    async fn human_in_loop_appends_prompt() {
        let middleware = HumanInLoopMiddleware::new(HashMap::from([(
            "danger-tool".into(),
            HitlPolicy {
                allow_auto: false,
                note: Some("Requires security review".into()),
            },
        )]));
        let mut request = ModelRequest::new("System", vec![]);
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);
        middleware.modify_model_request(&mut ctx).await.unwrap();
        assert!(ctx
            .request
            .system_prompt
            .contains("danger-tool: Requires security review"));
    }

    // With caching enabled, the system prompt becomes an ephemeral-cached
    // System message placed before the existing user message.
    #[tokio::test]
    async fn anthropic_prompt_caching_moves_system_prompt_to_messages() {
        let middleware = AnthropicPromptCachingMiddleware::new("5m", "ignore");
        let mut request = ModelRequest::new(
            "This is the system prompt",
            vec![AgentMessage {
                role: MessageRole::User,
                content: MessageContent::Text("Hello".into()),
                metadata: None,
            }],
        );
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);

        middleware.modify_model_request(&mut ctx).await.unwrap();

        // The prompt was moved out of `system_prompt`...
        assert!(ctx.request.system_prompt.is_empty());

        // ...into one extra message.
        assert_eq!(ctx.request.messages.len(), 2);

        let system_message = &ctx.request.messages[0];
        assert!(matches!(system_message.role, MessageRole::System));
        assert_eq!(
            system_message.content.as_text().unwrap(),
            "This is the system prompt"
        );

        // The moved prompt carries ephemeral cache-control metadata.
        let metadata = system_message.metadata.as_ref().unwrap();
        let cache_control = metadata.cache_control.as_ref().unwrap();
        assert_eq!(cache_control.cache_type, "ephemeral");

        // The original user message follows, unchanged.
        let user_message = &ctx.request.messages[1];
        assert!(matches!(user_message.role, MessageRole::User));
        assert_eq!(user_message.content.as_text().unwrap(), "Hello");
    }

    // A TTL of "0" disables caching; the request is left as-is.
    #[tokio::test]
    async fn anthropic_prompt_caching_disabled_with_zero_ttl() {
        let middleware = AnthropicPromptCachingMiddleware::new("0", "ignore");
        let mut request = ModelRequest::new("This is the system prompt", vec![]);
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);

        middleware.modify_model_request(&mut ctx).await.unwrap();

        assert_eq!(ctx.request.system_prompt, "This is the system prompt");
        assert_eq!(ctx.request.messages.len(), 0);
    }

    // An empty system prompt means nothing is moved or inserted.
    #[tokio::test]
    async fn anthropic_prompt_caching_no_op_with_empty_system_prompt() {
        let middleware = AnthropicPromptCachingMiddleware::new("5m", "ignore");
        let mut request = ModelRequest::new(
            "",
            vec![AgentMessage {
                role: MessageRole::User,
                content: MessageContent::Text("Hello".into()),
                metadata: None,
            }],
        );
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);

        middleware.modify_model_request(&mut ctx).await.unwrap();

        assert!(ctx.request.system_prompt.is_empty());
        assert_eq!(ctx.request.messages.len(), 1);
        assert!(matches!(ctx.request.messages[0].role, MessageRole::User));
    }
}
878}