1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3
4use agents_core::agent::AgentHandle;
5use agents_core::messaging::{
6 AgentMessage, CacheControl, MessageContent, MessageMetadata, MessageRole,
7};
8use agents_core::prompts::{
9 BASE_AGENT_PROMPT, FILESYSTEM_SYSTEM_PROMPT, TASK_SYSTEM_PROMPT, TASK_TOOL_DESCRIPTION,
10 WRITE_TODOS_SYSTEM_PROMPT,
11};
12use agents_core::state::AgentStateSnapshot;
13use agents_core::tools::{Tool, ToolBox, ToolContext, ToolResult};
14use agents_toolkit::create_filesystem_tools;
15use async_trait::async_trait;
16use serde::Deserialize;
17
18pub mod token_tracking;
19
20#[derive(Debug, Clone)]
23pub struct ModelRequest {
24 pub system_prompt: String,
25 pub messages: Vec<AgentMessage>,
26}
27
28impl ModelRequest {
29 pub fn new(system_prompt: impl Into<String>, messages: Vec<AgentMessage>) -> Self {
30 Self {
31 system_prompt: system_prompt.into(),
32 messages,
33 }
34 }
35
36 pub fn append_prompt(&mut self, fragment: &str) {
37 if !fragment.is_empty() {
38 self.system_prompt.push_str("\n\n");
39 self.system_prompt.push_str(fragment);
40 }
41 }
42}
43
44pub struct MiddlewareContext<'a> {
46 pub request: &'a mut ModelRequest,
47 pub state: Arc<RwLock<AgentStateSnapshot>>,
48}
49
50impl<'a> MiddlewareContext<'a> {
51 pub fn with_request(
52 request: &'a mut ModelRequest,
53 state: Arc<RwLock<AgentStateSnapshot>>,
54 ) -> Self {
55 Self { request, state }
56 }
57}
58
59#[async_trait]
63pub trait AgentMiddleware: Send + Sync {
64 fn id(&self) -> &'static str;
66
67 fn tools(&self) -> Vec<ToolBox> {
69 Vec::new()
70 }
71
72 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()>;
74
75 async fn before_tool_execution(
91 &self,
92 _tool_name: &str,
93 _tool_args: &serde_json::Value,
94 _call_id: &str,
95 ) -> anyhow::Result<Option<agents_core::hitl::AgentInterrupt>> {
96 Ok(None)
97 }
98}
99
/// Trims the conversation to its most recent messages, replacing the dropped prefix with a
/// single system note.
pub struct SummarizationMiddleware {
    /// How many trailing messages are preserved verbatim.
    pub messages_to_keep: usize,
    /// Text placed at the start of the synthetic note that stands in for dropped messages.
    pub summary_note: String,
}

impl SummarizationMiddleware {
    /// Creates the middleware with the given retention count and note text.
    pub fn new(messages_to_keep: usize, summary_note: impl Into<String>) -> Self {
        let summary_note: String = summary_note.into();
        Self {
            summary_note,
            messages_to_keep,
        }
    }
}
113
114#[async_trait]
115impl AgentMiddleware for SummarizationMiddleware {
116 fn id(&self) -> &'static str {
117 "summarization"
118 }
119
120 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
121 if ctx.request.messages.len() > self.messages_to_keep {
122 let dropped = ctx.request.messages.len() - self.messages_to_keep;
123 let mut truncated = ctx
124 .request
125 .messages
126 .split_off(ctx.request.messages.len() - self.messages_to_keep);
127 truncated.insert(
128 0,
129 AgentMessage {
130 role: MessageRole::System,
131 content: MessageContent::Text(format!(
132 "{} ({} earlier messages summarized)",
133 self.summary_note, dropped
134 )),
135 metadata: None,
136 },
137 );
138 ctx.request.messages = truncated;
139 }
140 Ok(())
141 }
142}
143
144pub struct PlanningMiddleware {
145 _state: Arc<RwLock<AgentStateSnapshot>>,
146}
147
148impl PlanningMiddleware {
149 pub fn new(state: Arc<RwLock<AgentStateSnapshot>>) -> Self {
150 Self { _state: state }
151 }
152}
153
154#[async_trait]
155impl AgentMiddleware for PlanningMiddleware {
156 fn id(&self) -> &'static str {
157 "planning"
158 }
159
160 fn tools(&self) -> Vec<ToolBox> {
161 use agents_toolkit::create_todos_tool;
165 vec![create_todos_tool()]
166 }
167
168 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
169 ctx.request.append_prompt(WRITE_TODOS_SYSTEM_PROMPT);
170 Ok(())
171 }
172}
173
174pub struct FilesystemMiddleware {
175 _state: Arc<RwLock<AgentStateSnapshot>>,
176}
177
178impl FilesystemMiddleware {
179 pub fn new(state: Arc<RwLock<AgentStateSnapshot>>) -> Self {
180 Self { _state: state }
181 }
182}
183
184#[async_trait]
185impl AgentMiddleware for FilesystemMiddleware {
186 fn id(&self) -> &'static str {
187 "filesystem"
188 }
189
190 fn tools(&self) -> Vec<ToolBox> {
191 create_filesystem_tools()
192 }
193
194 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
195 ctx.request.append_prompt(FILESYSTEM_SYSTEM_PROMPT);
196 Ok(())
197 }
198}
199
200#[derive(Clone)]
201pub struct SubAgentRegistration {
202 pub descriptor: SubAgentDescriptor,
203 pub agent: Arc<dyn AgentHandle>,
204}
205
206struct SubAgentRegistry {
207 agents: HashMap<String, Arc<dyn AgentHandle>>,
208}
209
210impl SubAgentRegistry {
211 fn new(registrations: Vec<SubAgentRegistration>) -> Self {
212 let mut agents = HashMap::new();
213 for reg in registrations {
214 agents.insert(reg.descriptor.name.clone(), reg.agent.clone());
215 }
216 Self { agents }
217 }
218
219 fn available_names(&self) -> Vec<String> {
220 self.agents.keys().cloned().collect()
221 }
222
223 fn get(&self, name: &str) -> Option<Arc<dyn AgentHandle>> {
224 self.agents.get(name).cloned()
225 }
226}
227
228pub struct SubAgentMiddleware {
229 task_tool: ToolBox,
230 descriptors: Vec<SubAgentDescriptor>,
231 _registry: Arc<SubAgentRegistry>,
232}
233
234impl SubAgentMiddleware {
235 pub fn new(registrations: Vec<SubAgentRegistration>) -> Self {
236 let descriptors = registrations.iter().map(|r| r.descriptor.clone()).collect();
237 let registry = Arc::new(SubAgentRegistry::new(registrations));
238 let task_tool: ToolBox = Arc::new(TaskRouterTool::new(registry.clone(), None));
239 Self {
240 task_tool,
241 descriptors,
242 _registry: registry,
243 }
244 }
245
246 pub fn new_with_events(
247 registrations: Vec<SubAgentRegistration>,
248 event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
249 ) -> Self {
250 let descriptors = registrations.iter().map(|r| r.descriptor.clone()).collect();
251 let registry = Arc::new(SubAgentRegistry::new(registrations));
252 let task_tool: ToolBox = Arc::new(TaskRouterTool::new(registry.clone(), event_dispatcher));
253 Self {
254 task_tool,
255 descriptors,
256 _registry: registry,
257 }
258 }
259
260 fn prompt_fragment(&self) -> String {
261 let descriptions: Vec<String> = if self.descriptors.is_empty() {
262 vec![String::from("- general-purpose: Default reasoning agent")]
263 } else {
264 self.descriptors
265 .iter()
266 .map(|agent| format!("- {}: {}", agent.name, agent.description))
267 .collect()
268 };
269
270 TASK_TOOL_DESCRIPTION.replace("{other_agents}", &descriptions.join("\n"))
271 }
272}
273
274#[async_trait]
275impl AgentMiddleware for SubAgentMiddleware {
276 fn id(&self) -> &'static str {
277 "subagent"
278 }
279
280 fn tools(&self) -> Vec<ToolBox> {
281 vec![self.task_tool.clone()]
282 }
283
284 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
285 ctx.request.append_prompt(TASK_SYSTEM_PROMPT);
286 ctx.request.append_prompt(&self.prompt_fragment());
287 Ok(())
288 }
289}
290
/// Approval policy for a single tool.
#[derive(Clone, Debug)]
pub struct HitlPolicy {
    /// When `true`, the tool runs without human approval.
    pub allow_auto: bool,
    /// Optional explanation, listed in the prompt and attached to approval interrupts.
    pub note: Option<String>,
}

/// Gates tool calls behind human approval according to per-tool policies.
pub struct HumanInLoopMiddleware {
    policies: HashMap<String, HitlPolicy>,
}

impl HumanInLoopMiddleware {
    /// Creates the middleware from a tool-name → policy map.
    pub fn new(policies: HashMap<String, HitlPolicy>) -> Self {
        Self { policies }
    }

    /// Returns the policy for `tool_name` if that tool needs human approval.
    ///
    /// Tools without a policy, or whose policy sets `allow_auto`, yield `None`.
    pub fn requires_approval(&self, tool_name: &str) -> Option<&HitlPolicy> {
        self.policies
            .get(tool_name)
            .filter(|policy| !policy.allow_auto)
    }

    /// Builds the system-prompt section listing approval-gated tools, or `None` if none are.
    ///
    /// Lines are sorted by tool name: `HashMap` iteration order is unspecified, and a prompt
    /// that changes between otherwise identical requests defeats prompt caching and makes
    /// runs non-reproducible.
    fn prompt_fragment(&self) -> Option<String> {
        let mut gated: Vec<(&String, &HitlPolicy)> = self
            .policies
            .iter()
            .filter(|(_, policy)| !policy.allow_auto)
            .collect();
        if gated.is_empty() {
            return None;
        }
        gated.sort_unstable_by(|a, b| a.0.cmp(b.0));

        let pending: Vec<String> = gated
            .into_iter()
            .map(|(tool, policy)| match &policy.note {
                Some(note) => format!("- {tool}: {note}"),
                None => format!("- {tool}: Requires approval"),
            })
            .collect();
        Some(format!(
            "The following tools require human approval before execution:\n{}",
            pending.join("\n")
        ))
    }
}
332
333#[async_trait]
334impl AgentMiddleware for HumanInLoopMiddleware {
335 fn id(&self) -> &'static str {
336 "human-in-loop"
337 }
338
339 async fn before_tool_execution(
340 &self,
341 tool_name: &str,
342 tool_args: &serde_json::Value,
343 call_id: &str,
344 ) -> anyhow::Result<Option<agents_core::hitl::AgentInterrupt>> {
345 if let Some(policy) = self.requires_approval(tool_name) {
346 tracing::warn!(
347 tool_name = %tool_name,
348 call_id = %call_id,
349 policy_note = ?policy.note,
350 "🔒 HITL: Tool execution requires human approval"
351 );
352
353 let interrupt = agents_core::hitl::HitlInterrupt::new(
354 tool_name,
355 tool_args.clone(),
356 call_id,
357 policy.note.clone(),
358 );
359
360 return Ok(Some(agents_core::hitl::AgentInterrupt::HumanInLoop(
361 interrupt,
362 )));
363 }
364
365 Ok(None)
366 }
367
368 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
369 if let Some(fragment) = self.prompt_fragment() {
370 ctx.request.append_prompt(&fragment);
371 }
372 ctx.request.messages.push(AgentMessage {
373 role: MessageRole::System,
374 content: MessageContent::Text(
375 "Tools marked for human approval will emit interrupts requiring external resolution."
376 .into(),
377 ),
378 metadata: None,
379 });
380 Ok(())
381 }
382}
383
384pub struct BaseSystemPromptMiddleware;
385
386#[async_trait]
387impl AgentMiddleware for BaseSystemPromptMiddleware {
388 fn id(&self) -> &'static str {
389 "base-system-prompt"
390 }
391
392 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
393 ctx.request.append_prompt(BASE_AGENT_PROMPT);
394 Ok(())
395 }
396}
397
398pub struct DeepAgentPromptMiddleware {
413 custom_instructions: String,
414 prompt_format: crate::prompts::PromptFormat,
416 override_system_prompt: Option<String>,
418}
419
420impl DeepAgentPromptMiddleware {
421 pub fn new(custom_instructions: impl Into<String>) -> Self {
422 Self {
423 custom_instructions: custom_instructions.into(),
424 prompt_format: crate::prompts::PromptFormat::Json,
425 override_system_prompt: None,
426 }
427 }
428
429 pub fn with_format(
434 custom_instructions: impl Into<String>,
435 format: crate::prompts::PromptFormat,
436 ) -> Self {
437 Self {
438 custom_instructions: custom_instructions.into(),
439 prompt_format: format,
440 override_system_prompt: None,
441 }
442 }
443
444 pub fn with_override(system_prompt: impl Into<String>) -> Self {
449 Self {
450 custom_instructions: String::new(),
451 prompt_format: crate::prompts::PromptFormat::Json,
452 override_system_prompt: Some(system_prompt.into()),
453 }
454 }
455}
456
457#[async_trait]
458impl AgentMiddleware for DeepAgentPromptMiddleware {
459 fn id(&self) -> &'static str {
460 "deep-agent-prompt"
461 }
462
463 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
464 let prompt = if let Some(ref override_prompt) = self.override_system_prompt {
465 override_prompt.clone()
467 } else {
468 use crate::prompts::get_deep_agent_system_prompt_formatted;
470 get_deep_agent_system_prompt_formatted(&self.custom_instructions, self.prompt_format)
471 };
472 ctx.request.append_prompt(&prompt);
473 Ok(())
474 }
475}
476
/// Moves the system prompt into a leading system message tagged with an `"ephemeral"`
/// cache-control marker for Anthropic prompt caching.
pub struct AnthropicPromptCachingMiddleware {
    /// Cache time-to-live such as `"5m"`; a blank or zero value disables caching.
    pub ttl: String,
    /// Behaviour for models without caching support (e.g. `"ignore"`).
    /// NOTE(review): this file only logs it; confirm where it is acted upon.
    pub unsupported_model_behavior: String,
}

impl AnthropicPromptCachingMiddleware {
    /// Creates the middleware with an explicit TTL and unsupported-model behaviour.
    pub fn new(ttl: impl Into<String>, unsupported_model_behavior: impl Into<String>) -> Self {
        Self {
            ttl: ttl.into(),
            unsupported_model_behavior: unsupported_model_behavior.into(),
        }
    }

    /// Defaults: a `"5m"` TTL and `"ignore"` for unsupported models.
    pub fn with_defaults() -> Self {
        Self::new("5m", "ignore")
    }

    /// Whether caching is enabled for the configured TTL.
    ///
    /// Caching is off when the TTL is blank or its numeric part is zero, whatever unit
    /// suffix or surrounding whitespace it carries (`"0"`, `"0s"`, `"0m"`, `"00"`, `" 0 "`).
    /// Every other value, including non-numeric ones, leaves caching on.
    fn should_enable_caching(&self) -> bool {
        let ttl = self.ttl.trim();
        if ttl.is_empty() {
            return false;
        }
        // Strip a trailing alphabetic unit such as `s`, `m`, `h` or `ms`.
        let magnitude = ttl.trim_end_matches(|c: char| c.is_ascii_alphabetic());
        let is_zero =
            magnitude.contains('0') && magnitude.chars().all(|c| c == '0' || c == '.');
        !is_zero
    }
}
502
503#[async_trait]
504impl AgentMiddleware for AnthropicPromptCachingMiddleware {
505 fn id(&self) -> &'static str {
506 "anthropic-prompt-caching"
507 }
508
509 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
510 if !self.should_enable_caching() {
511 return Ok(());
512 }
513
514 if !ctx.request.system_prompt.is_empty() {
516 let system_message = AgentMessage {
517 role: MessageRole::System,
518 content: MessageContent::Text(ctx.request.system_prompt.clone()),
519 metadata: Some(MessageMetadata {
520 tool_call_id: None,
521 cache_control: Some(CacheControl {
522 cache_type: "ephemeral".to_string(),
523 }),
524 }),
525 };
526
527 ctx.request.messages.insert(0, system_message);
529
530 ctx.request.system_prompt.clear();
532
533 tracing::debug!(
534 ttl = %self.ttl,
535 behavior = %self.unsupported_model_behavior,
536 "Applied Anthropic prompt caching to system message"
537 );
538 }
539
540 Ok(())
541 }
542}
543
544pub struct TaskRouterTool {
545 registry: Arc<SubAgentRegistry>,
546 event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
547 delegation_depth: Arc<RwLock<u32>>,
548}
549
550impl TaskRouterTool {
551 fn new(
552 registry: Arc<SubAgentRegistry>,
553 event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
554 ) -> Self {
555 Self {
556 registry,
557 event_dispatcher,
558 delegation_depth: Arc::new(RwLock::new(0)),
559 }
560 }
561
562 fn available_subagents(&self) -> Vec<String> {
563 self.registry.available_names()
564 }
565
566 fn emit_event(&self, event: agents_core::events::AgentEvent) {
567 if let Some(dispatcher) = &self.event_dispatcher {
568 let dispatcher_clone = dispatcher.clone();
569 tokio::spawn(async move {
570 dispatcher_clone.dispatch(event).await;
571 });
572 }
573 }
574
575 fn create_event_metadata(&self) -> agents_core::events::EventMetadata {
576 agents_core::events::EventMetadata::new(
577 "default".to_string(),
578 uuid::Uuid::new_v4().to_string(),
579 None,
580 )
581 }
582
583 fn get_delegation_depth(&self) -> u32 {
584 *self.delegation_depth.read().unwrap_or_else(|_| {
585 tracing::warn!("Failed to read delegation depth, defaulting to 0");
586 panic!("RwLock poisoned")
587 })
588 }
589
590 fn increment_delegation_depth(&self) {
591 if let Ok(mut depth) = self.delegation_depth.write() {
592 *depth += 1;
593 }
594 }
595
596 fn decrement_delegation_depth(&self) {
597 if let Ok(mut depth) = self.delegation_depth.write() {
598 if *depth > 0 {
599 *depth -= 1;
600 }
601 }
602 }
603}
604
605#[derive(Debug, Clone, Deserialize)]
606struct TaskInvocationArgs {
607 #[serde(alias = "description")]
608 instruction: String,
609 #[serde(alias = "subagent_type")]
610 agent: String,
611}
612
613#[async_trait]
614impl Tool for TaskRouterTool {
615 fn schema(&self) -> agents_core::tools::ToolSchema {
616 use agents_core::tools::{ToolParameterSchema, ToolSchema};
617 use std::collections::HashMap;
618
619 let mut properties = HashMap::new();
620 properties.insert(
621 "agent".to_string(),
622 ToolParameterSchema::string("Name of the sub-agent to delegate to"),
623 );
624 properties.insert(
625 "instruction".to_string(),
626 ToolParameterSchema::string("Clear instruction for the sub-agent"),
627 );
628
629 ToolSchema::new(
630 "task",
631 "Delegate a task to a specialized sub-agent. Use this when you need specialized expertise or want to break down complex tasks.",
632 ToolParameterSchema::object(
633 "Task delegation parameters",
634 properties,
635 vec!["agent".to_string(), "instruction".to_string()],
636 ),
637 )
638 }
639
640 async fn execute(
641 &self,
642 args: serde_json::Value,
643 ctx: ToolContext,
644 ) -> anyhow::Result<ToolResult> {
645 let args: TaskInvocationArgs = serde_json::from_value(args)?;
646 let available = self.available_subagents();
647
648 if let Some(agent) = self.registry.get(&args.agent) {
649 self.increment_delegation_depth();
651 let current_depth = self.get_delegation_depth();
652
653 let instruction_summary = if args.instruction.chars().count() > 100 {
655 format!("{:.100}...", &args.instruction)
656 } else {
657 args.instruction.clone()
658 };
659
660 self.emit_event(agents_core::events::AgentEvent::SubAgentStarted(
662 agents_core::events::SubAgentStartedEvent {
663 metadata: self.create_event_metadata(),
664 agent_name: args.agent.clone(),
665 instruction_summary: instruction_summary.clone(),
666 delegation_depth: current_depth,
667 },
668 ));
669
670 tracing::warn!(
672 "🎯 DELEGATING to sub-agent: {} (depth: {}) with instruction: {}",
673 args.agent,
674 current_depth,
675 args.instruction
676 );
677
678 let start_time = std::time::Instant::now();
679 let user_message = AgentMessage {
680 role: MessageRole::User,
681 content: MessageContent::Text(args.instruction.clone()),
682 metadata: None,
683 };
684
685 let response = agent
686 .handle_message(user_message, ctx.state.clone())
687 .await?;
688
689 let duration = start_time.elapsed();
691 let duration_ms = duration.as_millis() as u64;
692
693 let response_preview = match &response.content {
695 MessageContent::Text(t) => {
696 if t.chars().count() > 100 {
697 format!("{:.100}...", t)
698 } else {
699 t.clone()
700 }
701 }
702 MessageContent::Json(v) => {
703 let json_str = v.to_string();
704 if json_str.chars().count() > 100 {
705 format!("{:.100}...", json_str)
706 } else {
707 json_str
708 }
709 }
710 };
711
712 self.emit_event(agents_core::events::AgentEvent::SubAgentCompleted(
714 agents_core::events::SubAgentCompletedEvent {
715 metadata: self.create_event_metadata(),
716 agent_name: args.agent.clone(),
717 duration_ms,
718 result_summary: response_preview.clone(),
719 },
720 ));
721
722 tracing::warn!(
724 "✅ SUB-AGENT {} COMPLETED in {:?} - Response: {}",
725 args.agent,
726 duration,
727 response_preview
728 );
729
730 self.decrement_delegation_depth();
732
733 let result_text = match response.content {
736 MessageContent::Text(text) => text,
737 MessageContent::Json(json) => json.to_string(),
738 };
739
740 return Ok(ToolResult::text(&ctx, result_text));
741 }
742
743 tracing::error!(
744 "❌ SUB-AGENT NOT FOUND: {} - Available: {:?}",
745 args.agent,
746 available
747 );
748
749 Ok(ToolResult::text(
750 &ctx,
751 format!(
752 "Sub-agent '{}' not found. Available sub-agents: {}",
753 args.agent,
754 available.join(", ")
755 ),
756 ))
757 }
758}
759
/// Public description of a sub-agent, used to advertise it in the `task` tool prompt.
#[derive(Debug, Clone)]
pub struct SubAgentDescriptor {
    /// Registry key the model uses to select this sub-agent.
    pub name: String,
    /// Short summary shown to the model next to the name.
    pub description: String,
}
765
// Unit tests for the middleware pipeline and the `task` router tool. Sub-agents are
// replaced by `StubAgent`, which always replies "stub-response".
#[cfg(test)]
mod tests {
    use super::*;
    use agents_core::agent::{AgentDescriptor, AgentHandle};
    use agents_core::messaging::{MessageContent, MessageRole};
    use serde_json::json;

    // Minimal middleware used to check that `modify_model_request` can mutate the prompt.
    struct AppendPromptMiddleware;

    #[async_trait]
    impl AgentMiddleware for AppendPromptMiddleware {
        fn id(&self) -> &'static str {
            "append-prompt"
        }

        async fn modify_model_request(
            &self,
            ctx: &mut MiddlewareContext<'_>,
        ) -> anyhow::Result<()> {
            ctx.request.system_prompt.push_str("\nExtra directives.");
            Ok(())
        }
    }

    #[tokio::test]
    async fn middleware_mutates_prompt() {
        let mut request = ModelRequest::new(
            "System",
            vec![AgentMessage {
                role: MessageRole::User,
                content: MessageContent::Text("Hi".into()),
                metadata: None,
            }],
        );
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);
        let middleware = AppendPromptMiddleware;
        middleware.modify_model_request(&mut ctx).await.unwrap();
        assert!(ctx.request.system_prompt.contains("Extra directives"));
    }

    // The planning middleware must expose a `write_todos` tool and mention it in the prompt.
    #[tokio::test]
    async fn planning_middleware_registers_write_todos() {
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let middleware = PlanningMiddleware::new(state);
        let tool_names: Vec<_> = middleware
            .tools()
            .iter()
            .map(|t| t.schema().name.clone())
            .collect();
        assert!(tool_names.contains(&"write_todos".to_string()));

        let mut request = ModelRequest::new("System", vec![]);
        let mut ctx = MiddlewareContext::with_request(
            &mut request,
            Arc::new(RwLock::new(AgentStateSnapshot::default())),
        );
        middleware.modify_model_request(&mut ctx).await.unwrap();
        assert!(ctx.request.system_prompt.contains("write_todos"));
    }

    // The filesystem middleware must expose at least these four tools.
    #[tokio::test]
    async fn filesystem_middleware_registers_tools() {
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let middleware = FilesystemMiddleware::new(state);
        let tool_names: Vec<_> = middleware
            .tools()
            .iter()
            .map(|t| t.schema().name.clone())
            .collect();
        for expected in ["ls", "read_file", "write_file", "edit_file"] {
            assert!(tool_names.contains(&expected.to_string()));
        }
    }

    // With `messages_to_keep = 2` and three messages, one is dropped and replaced by a
    // leading summary note: 1 note + 2 kept messages = 3.
    #[tokio::test]
    async fn summarization_middleware_trims_messages() {
        let middleware = SummarizationMiddleware::new(2, "Summary note");
        let mut request = ModelRequest::new(
            "System",
            vec![
                AgentMessage {
                    role: MessageRole::User,
                    content: MessageContent::Text("one".into()),
                    metadata: None,
                },
                AgentMessage {
                    role: MessageRole::Agent,
                    content: MessageContent::Text("two".into()),
                    metadata: None,
                },
                AgentMessage {
                    role: MessageRole::User,
                    content: MessageContent::Text("three".into()),
                    metadata: None,
                },
            ],
        );
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);
        middleware.modify_model_request(&mut ctx).await.unwrap();
        assert_eq!(ctx.request.messages.len(), 3);
        match &ctx.request.messages[0].content {
            MessageContent::Text(text) => assert!(text.contains("Summary note")),
            other => panic!("expected text, got {other:?}"),
        }
    }

    // Sub-agent double: ignores its input and always replies "stub-response".
    struct StubAgent;

    #[async_trait]
    impl AgentHandle for StubAgent {
        async fn describe(&self) -> AgentDescriptor {
            AgentDescriptor {
                name: "stub".into(),
                version: "0.0.1".into(),
                description: None,
            }
        }

        async fn handle_message(
            &self,
            _input: AgentMessage,
            _state: Arc<AgentStateSnapshot>,
        ) -> anyhow::Result<AgentMessage> {
            Ok(AgentMessage {
                role: MessageRole::Agent,
                content: MessageContent::Text("stub-response".into()),
                metadata: None,
            })
        }
    }

    // An unknown agent name yields a textual "not found" result rather than an error.
    #[tokio::test]
    async fn task_router_reports_unknown_subagent() {
        let registry = Arc::new(SubAgentRegistry::new(vec![]));
        let task_tool = TaskRouterTool::new(registry.clone(), None);
        let state = Arc::new(AgentStateSnapshot::default());
        let ctx = ToolContext::new(state);

        let response = task_tool
            .execute(
                json!({
                    "instruction": "Do something",
                    "agent": "unknown"
                }),
                ctx,
            )
            .await
            .unwrap();

        match response {
            ToolResult::Message(msg) => match msg.content {
                MessageContent::Text(text) => {
                    assert!(text.contains("Sub-agent 'unknown' not found"))
                }
                other => panic!("expected text, got {other:?}"),
            },
            _ => panic!("expected message"),
        }
    }

    // Registered sub-agents appear in the prompt, and the middleware exposes the `task` tool.
    #[tokio::test]
    async fn subagent_middleware_appends_prompt() {
        let subagents = vec![SubAgentRegistration {
            descriptor: SubAgentDescriptor {
                name: "research-agent".into(),
                description: "Deep research specialist".into(),
            },
            agent: Arc::new(StubAgent),
        }];
        let middleware = SubAgentMiddleware::new(subagents);

        let mut request = ModelRequest::new("System", vec![]);
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);
        middleware.modify_model_request(&mut ctx).await.unwrap();

        assert!(ctx.request.system_prompt.contains("research-agent"));
        let tool_names: Vec<_> = middleware
            .tools()
            .iter()
            .map(|t| t.schema().name.clone())
            .collect();
        assert!(tool_names.contains(&"task".to_string()));
    }

    // Uses the `description` / `subagent_type` argument aliases, and checks that the call id
    // from the tool context is attached to the result.
    #[tokio::test]
    async fn task_router_invokes_registered_subagent() {
        let registry = Arc::new(SubAgentRegistry::new(vec![SubAgentRegistration {
            descriptor: SubAgentDescriptor {
                name: "stub-agent".into(),
                description: "Stub".into(),
            },
            agent: Arc::new(StubAgent),
        }]));
        let task_tool = TaskRouterTool::new(registry.clone(), None);
        let state = Arc::new(AgentStateSnapshot::default());
        let ctx = ToolContext::new(state).with_call_id(Some("call-42".into()));
        let response = task_tool
            .execute(
                json!({
                    "description": "do work",
                    "subagent_type": "stub-agent"
                }),
                ctx,
            )
            .await
            .unwrap();

        match response {
            ToolResult::Message(msg) => {
                assert_eq!(msg.metadata.unwrap().tool_call_id.unwrap(), "call-42");
                match msg.content {
                    MessageContent::Text(text) => assert_eq!(text, "stub-response"),
                    other => panic!("expected text, got {other:?}"),
                }
            }
            _ => panic!("expected message"),
        }
    }

    // Approval-gated tools and their notes are listed in the system prompt.
    #[tokio::test]
    async fn human_in_loop_appends_prompt() {
        let middleware = HumanInLoopMiddleware::new(HashMap::from([(
            "danger-tool".into(),
            HitlPolicy {
                allow_auto: false,
                note: Some("Requires security review".into()),
            },
        )]));
        let mut request = ModelRequest::new("System", vec![]);
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);
        middleware.modify_model_request(&mut ctx).await.unwrap();
        assert!(ctx
            .request
            .system_prompt
            .contains("danger-tool: Requires security review"));
    }

    // A non-zero TTL moves the system prompt into a leading, cache-tagged system message.
    #[tokio::test]
    async fn anthropic_prompt_caching_moves_system_prompt_to_messages() {
        let middleware = AnthropicPromptCachingMiddleware::new("5m", "ignore");
        let mut request = ModelRequest::new(
            "This is the system prompt",
            vec![AgentMessage {
                role: MessageRole::User,
                content: MessageContent::Text("Hello".into()),
                metadata: None,
            }],
        );
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);

        middleware.modify_model_request(&mut ctx).await.unwrap();

        // The prompt text has been moved out of `system_prompt`...
        assert!(ctx.request.system_prompt.is_empty());

        // ...and prepended as a message, so the message count grows by one.
        assert_eq!(ctx.request.messages.len(), 2);

        let system_message = &ctx.request.messages[0];
        assert!(matches!(system_message.role, MessageRole::System));
        assert_eq!(
            system_message.content.as_text().unwrap(),
            "This is the system prompt"
        );

        // The moved prompt carries an "ephemeral" cache-control marker.
        let metadata = system_message.metadata.as_ref().unwrap();
        let cache_control = metadata.cache_control.as_ref().unwrap();
        assert_eq!(cache_control.cache_type, "ephemeral");

        // The original user message follows, unchanged.
        let user_message = &ctx.request.messages[1];
        assert!(matches!(user_message.role, MessageRole::User));
        assert_eq!(user_message.content.as_text().unwrap(), "Hello");
    }

    // A TTL of "0" disables caching: prompt and messages are left untouched.
    #[tokio::test]
    async fn anthropic_prompt_caching_disabled_with_zero_ttl() {
        let middleware = AnthropicPromptCachingMiddleware::new("0", "ignore");
        let mut request = ModelRequest::new("This is the system prompt", vec![]);
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);

        middleware.modify_model_request(&mut ctx).await.unwrap();

        assert_eq!(ctx.request.system_prompt, "This is the system prompt");
        assert_eq!(ctx.request.messages.len(), 0);
    }

    // With an empty system prompt there is nothing to cache, so no message is inserted.
    #[tokio::test]
    async fn anthropic_prompt_caching_no_op_with_empty_system_prompt() {
        let middleware = AnthropicPromptCachingMiddleware::new("5m", "ignore");
        let mut request = ModelRequest::new(
            "",
            vec![AgentMessage {
                role: MessageRole::User,
                content: MessageContent::Text("Hello".into()),
                metadata: None,
            }],
        );
        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
        let mut ctx = MiddlewareContext::with_request(&mut request, state);

        middleware.modify_model_request(&mut ctx).await.unwrap();

        assert!(ctx.request.system_prompt.is_empty());
        assert_eq!(ctx.request.messages.len(), 1);
    }

    // A tool whose policy has `allow_auto: false` yields a HumanInLoop interrupt that echoes
    // the tool name, arguments, call id and policy note.
    #[tokio::test]
    async fn hitl_creates_interrupt_for_disallowed_tool() {
        let mut policies = HashMap::new();
        policies.insert(
            "dangerous_tool".to_string(),
            HitlPolicy {
                allow_auto: false,
                note: Some("Requires security review".to_string()),
            },
        );

        let middleware = HumanInLoopMiddleware::new(policies);
        let tool_args = json!({"action": "delete_all"});

        let result = middleware
            .before_tool_execution("dangerous_tool", &tool_args, "call_123")
            .await
            .unwrap();

        assert!(result.is_some());
        let interrupt = result.unwrap();

        match interrupt {
            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
                assert_eq!(hitl.tool_name, "dangerous_tool");
                assert_eq!(hitl.tool_args, tool_args);
                assert_eq!(hitl.call_id, "call_123");
                assert_eq!(
                    hitl.policy_note,
                    Some("Requires security review".to_string())
                );
            }
        }
    }

    // `allow_auto: true` lets the tool run without an interrupt.
    #[tokio::test]
    async fn hitl_no_interrupt_for_allowed_tool() {
        let mut policies = HashMap::new();
        policies.insert(
            "safe_tool".to_string(),
            HitlPolicy {
                allow_auto: true,
                note: None,
            },
        );

        let middleware = HumanInLoopMiddleware::new(policies);
        let tool_args = json!({"action": "read"});

        let result = middleware
            .before_tool_execution("safe_tool", &tool_args, "call_456")
            .await
            .unwrap();

        assert!(result.is_none());
    }

    // Tools with no configured policy are never interrupted.
    #[tokio::test]
    async fn hitl_no_interrupt_for_unlisted_tool() {
        let policies = HashMap::new();
        let middleware = HumanInLoopMiddleware::new(policies);
        let tool_args = json!({"action": "anything"});

        let result = middleware
            .before_tool_execution("unlisted_tool", &tool_args, "call_789")
            .await
            .unwrap();

        assert!(result.is_none());
    }

    // Structured arguments are carried into the interrupt intact.
    #[tokio::test]
    async fn hitl_interrupt_includes_correct_details() {
        let mut policies = HashMap::new();
        policies.insert(
            "critical_tool".to_string(),
            HitlPolicy {
                allow_auto: false,
                note: Some("Critical operation - requires approval".to_string()),
            },
        );

        let middleware = HumanInLoopMiddleware::new(policies);
        let tool_args = json!({
            "database": "production",
            "operation": "drop_table"
        });

        let result = middleware
            .before_tool_execution("critical_tool", &tool_args, "call_critical_1")
            .await
            .unwrap();

        assert!(result.is_some());
        let interrupt = result.unwrap();

        match interrupt {
            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
                assert_eq!(hitl.tool_name, "critical_tool");
                assert_eq!(hitl.tool_args["database"], "production");
                assert_eq!(hitl.tool_args["operation"], "drop_table");
                assert_eq!(hitl.call_id, "call_critical_1");
                assert!(hitl.policy_note.is_some());
                assert!(hitl.policy_note.unwrap().contains("Critical operation"));
            }
        }
    }

    // A policy without a note still interrupts, with `policy_note` passed through as `None`.
    #[tokio::test]
    async fn hitl_interrupt_without_policy_note() {
        let mut policies = HashMap::new();
        policies.insert(
            "tool_no_note".to_string(),
            HitlPolicy {
                allow_auto: false,
                note: None,
            },
        );

        let middleware = HumanInLoopMiddleware::new(policies);
        let tool_args = json!({"param": "value"});

        let result = middleware
            .before_tool_execution("tool_no_note", &tool_args, "call_no_note")
            .await
            .unwrap();

        assert!(result.is_some());
        let interrupt = result.unwrap();

        match interrupt {
            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
                assert_eq!(hitl.tool_name, "tool_no_note");
                assert_eq!(hitl.policy_note, None);
            }
        }
    }
}