1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3
4use agents_core::agent::AgentHandle;
5use agents_core::messaging::{
6 AgentMessage, CacheControl, MessageContent, MessageMetadata, MessageRole,
7};
8use agents_core::prompts::{
9 BASE_AGENT_PROMPT, FILESYSTEM_SYSTEM_PROMPT, TASK_SYSTEM_PROMPT, TASK_TOOL_DESCRIPTION,
10 WRITE_TODOS_SYSTEM_PROMPT,
11};
12use agents_core::state::AgentStateSnapshot;
13use agents_core::tools::{Tool, ToolBox, ToolContext, ToolResult};
14use agents_toolkit::create_filesystem_tools;
15use async_trait::async_trait;
16use serde::Deserialize;
17
18#[derive(Debug, Clone)]
21pub struct ModelRequest {
22 pub system_prompt: String,
23 pub messages: Vec<AgentMessage>,
24}
25
26impl ModelRequest {
27 pub fn new(system_prompt: impl Into<String>, messages: Vec<AgentMessage>) -> Self {
28 Self {
29 system_prompt: system_prompt.into(),
30 messages,
31 }
32 }
33
34 pub fn append_prompt(&mut self, fragment: &str) {
35 if !fragment.is_empty() {
36 self.system_prompt.push_str("\n\n");
37 self.system_prompt.push_str(fragment);
38 }
39 }
40}
41
42pub struct MiddlewareContext<'a> {
44 pub request: &'a mut ModelRequest,
45 pub state: Arc<RwLock<AgentStateSnapshot>>,
46}
47
48impl<'a> MiddlewareContext<'a> {
49 pub fn with_request(
50 request: &'a mut ModelRequest,
51 state: Arc<RwLock<AgentStateSnapshot>>,
52 ) -> Self {
53 Self { request, state }
54 }
55}
56
57#[async_trait]
61pub trait AgentMiddleware: Send + Sync {
62 fn id(&self) -> &'static str;
64
65 fn tools(&self) -> Vec<ToolBox> {
67 Vec::new()
68 }
69
70 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()>;
72
73 async fn before_tool_execution(
89 &self,
90 _tool_name: &str,
91 _tool_args: &serde_json::Value,
92 _call_id: &str,
93 ) -> anyhow::Result<Option<agents_core::hitl::AgentInterrupt>> {
94 Ok(None)
95 }
96}
97
/// Configuration for trimming long message histories down to a recent tail.
pub struct SummarizationMiddleware {
    /// Number of most-recent messages to retain.
    pub messages_to_keep: usize,
    /// Text used as the lead-in of the placeholder message that replaces dropped history.
    pub summary_note: String,
}

impl SummarizationMiddleware {
    /// Creates the middleware from a retention count and a summary note.
    pub fn new(messages_to_keep: usize, summary_note: impl Into<String>) -> Self {
        let summary_note = summary_note.into();
        SummarizationMiddleware {
            messages_to_keep,
            summary_note,
        }
    }
}
111
112#[async_trait]
113impl AgentMiddleware for SummarizationMiddleware {
114 fn id(&self) -> &'static str {
115 "summarization"
116 }
117
118 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
119 if ctx.request.messages.len() > self.messages_to_keep {
120 let dropped = ctx.request.messages.len() - self.messages_to_keep;
121 let mut truncated = ctx
122 .request
123 .messages
124 .split_off(ctx.request.messages.len() - self.messages_to_keep);
125 truncated.insert(
126 0,
127 AgentMessage {
128 role: MessageRole::System,
129 content: MessageContent::Text(format!(
130 "{} ({} earlier messages summarized)",
131 self.summary_note, dropped
132 )),
133 metadata: None,
134 },
135 );
136 ctx.request.messages = truncated;
137 }
138 Ok(())
139 }
140}
141
142pub struct PlanningMiddleware {
143 _state: Arc<RwLock<AgentStateSnapshot>>,
144}
145
146impl PlanningMiddleware {
147 pub fn new(state: Arc<RwLock<AgentStateSnapshot>>) -> Self {
148 Self { _state: state }
149 }
150}
151
152#[async_trait]
153impl AgentMiddleware for PlanningMiddleware {
154 fn id(&self) -> &'static str {
155 "planning"
156 }
157
158 fn tools(&self) -> Vec<ToolBox> {
159 use agents_toolkit::create_todos_tools;
160 create_todos_tools()
161 }
162
163 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
164 ctx.request.append_prompt(WRITE_TODOS_SYSTEM_PROMPT);
165 Ok(())
166 }
167}
168
169pub struct FilesystemMiddleware {
170 _state: Arc<RwLock<AgentStateSnapshot>>,
171}
172
173impl FilesystemMiddleware {
174 pub fn new(state: Arc<RwLock<AgentStateSnapshot>>) -> Self {
175 Self { _state: state }
176 }
177}
178
179#[async_trait]
180impl AgentMiddleware for FilesystemMiddleware {
181 fn id(&self) -> &'static str {
182 "filesystem"
183 }
184
185 fn tools(&self) -> Vec<ToolBox> {
186 create_filesystem_tools()
187 }
188
189 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
190 ctx.request.append_prompt(FILESYSTEM_SYSTEM_PROMPT);
191 Ok(())
192 }
193}
194
/// Pairs a sub-agent's descriptor with the handle used to invoke it.
#[derive(Clone)]
pub struct SubAgentRegistration {
    /// Name and description of the sub-agent; the name is the registry lookup key.
    pub descriptor: SubAgentDescriptor,
    /// Handle that receives delegated messages.
    pub agent: Arc<dyn AgentHandle>,
}
200
201struct SubAgentRegistry {
202 agents: HashMap<String, Arc<dyn AgentHandle>>,
203}
204
205impl SubAgentRegistry {
206 fn new(registrations: Vec<SubAgentRegistration>) -> Self {
207 let mut agents = HashMap::new();
208 for reg in registrations {
209 agents.insert(reg.descriptor.name.clone(), reg.agent.clone());
210 }
211 Self { agents }
212 }
213
214 fn available_names(&self) -> Vec<String> {
215 self.agents.keys().cloned().collect()
216 }
217
218 fn get(&self, name: &str) -> Option<Arc<dyn AgentHandle>> {
219 self.agents.get(name).cloned()
220 }
221}
222
223pub struct SubAgentMiddleware {
224 task_tool: ToolBox,
225 descriptors: Vec<SubAgentDescriptor>,
226 _registry: Arc<SubAgentRegistry>,
227}
228
229impl SubAgentMiddleware {
230 pub fn new(registrations: Vec<SubAgentRegistration>) -> Self {
231 let descriptors = registrations.iter().map(|r| r.descriptor.clone()).collect();
232 let registry = Arc::new(SubAgentRegistry::new(registrations));
233 let task_tool: ToolBox = Arc::new(TaskRouterTool::new(registry.clone(), None));
234 Self {
235 task_tool,
236 descriptors,
237 _registry: registry,
238 }
239 }
240
241 pub fn new_with_events(
242 registrations: Vec<SubAgentRegistration>,
243 event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
244 ) -> Self {
245 let descriptors = registrations.iter().map(|r| r.descriptor.clone()).collect();
246 let registry = Arc::new(SubAgentRegistry::new(registrations));
247 let task_tool: ToolBox = Arc::new(TaskRouterTool::new(registry.clone(), event_dispatcher));
248 Self {
249 task_tool,
250 descriptors,
251 _registry: registry,
252 }
253 }
254
255 fn prompt_fragment(&self) -> String {
256 let descriptions: Vec<String> = if self.descriptors.is_empty() {
257 vec![String::from("- general-purpose: Default reasoning agent")]
258 } else {
259 self.descriptors
260 .iter()
261 .map(|agent| format!("- {}: {}", agent.name, agent.description))
262 .collect()
263 };
264
265 TASK_TOOL_DESCRIPTION.replace("{other_agents}", &descriptions.join("\n"))
266 }
267}
268
269#[async_trait]
270impl AgentMiddleware for SubAgentMiddleware {
271 fn id(&self) -> &'static str {
272 "subagent"
273 }
274
275 fn tools(&self) -> Vec<ToolBox> {
276 vec![self.task_tool.clone()]
277 }
278
279 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
280 ctx.request.append_prompt(TASK_SYSTEM_PROMPT);
281 ctx.request.append_prompt(&self.prompt_fragment());
282 Ok(())
283 }
284}
285
/// Approval policy for a single tool.
#[derive(Clone, Debug)]
pub struct HitlPolicy {
    /// `true` lets the tool run without approval; `false` makes it require approval.
    pub allow_auto: bool,
    /// Optional explanation shown next to the tool in the prompt fragment.
    pub note: Option<String>,
}

/// Middleware holding per-tool human-approval policies, keyed by tool name.
pub struct HumanInLoopMiddleware {
    policies: HashMap<String, HitlPolicy>,
}

impl HumanInLoopMiddleware {
    /// Creates the middleware from a tool-name → policy map.
    pub fn new(policies: HashMap<String, HitlPolicy>) -> Self {
        Self { policies }
    }

    /// Returns the policy for `tool_name` when that tool requires approval.
    ///
    /// Tools without a policy, or whose policy has `allow_auto == true`, yield `None`.
    pub fn requires_approval(&self, tool_name: &str) -> Option<&HitlPolicy> {
        self.policies
            .get(tool_name)
            .filter(|policy| !policy.allow_auto)
    }

    /// Builds the prompt section listing every tool that requires approval, or `None` when
    /// there is no such tool.
    ///
    /// Entries are ordered by tool name. `HashMap` iteration order is unspecified, and the
    /// prompt text must be identical from one request to the next.
    fn prompt_fragment(&self) -> Option<String> {
        let mut gated: Vec<(&String, &HitlPolicy)> = self
            .policies
            .iter()
            .filter(|(_, policy)| !policy.allow_auto)
            .collect();
        if gated.is_empty() {
            return None;
        }
        gated.sort_unstable_by(|a, b| a.0.cmp(b.0));

        let pending: Vec<String> = gated
            .into_iter()
            .map(|(tool, policy)| match &policy.note {
                Some(note) => format!("- {tool}: {note}"),
                None => format!("- {tool}: Requires approval"),
            })
            .collect();

        Some(format!(
            "The following tools require human approval before execution:\n{}",
            pending.join("\n")
        ))
    }
}
327
328#[async_trait]
329impl AgentMiddleware for HumanInLoopMiddleware {
330 fn id(&self) -> &'static str {
331 "human-in-loop"
332 }
333
334 async fn before_tool_execution(
335 &self,
336 tool_name: &str,
337 tool_args: &serde_json::Value,
338 call_id: &str,
339 ) -> anyhow::Result<Option<agents_core::hitl::AgentInterrupt>> {
340 if let Some(policy) = self.requires_approval(tool_name) {
341 tracing::warn!(
342 tool_name = %tool_name,
343 call_id = %call_id,
344 policy_note = ?policy.note,
345 "🔒 HITL: Tool execution requires human approval"
346 );
347
348 let interrupt = agents_core::hitl::HitlInterrupt::new(
349 tool_name,
350 tool_args.clone(),
351 call_id,
352 policy.note.clone(),
353 );
354
355 return Ok(Some(agents_core::hitl::AgentInterrupt::HumanInLoop(
356 interrupt,
357 )));
358 }
359
360 Ok(None)
361 }
362
363 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
364 if let Some(fragment) = self.prompt_fragment() {
365 ctx.request.append_prompt(&fragment);
366 }
367 ctx.request.messages.push(AgentMessage {
368 role: MessageRole::System,
369 content: MessageContent::Text(
370 "Tools marked for human approval will emit interrupts requiring external resolution."
371 .into(),
372 ),
373 metadata: None,
374 });
375 Ok(())
376 }
377}
378
379pub struct BaseSystemPromptMiddleware;
380
381#[async_trait]
382impl AgentMiddleware for BaseSystemPromptMiddleware {
383 fn id(&self) -> &'static str {
384 "base-system-prompt"
385 }
386
387 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
388 ctx.request.append_prompt(BASE_AGENT_PROMPT);
389 Ok(())
390 }
391}
392
/// Middleware that contributes the deep-agent system prompt built from custom instructions.
pub struct DeepAgentPromptMiddleware {
    /// Caller-supplied instructions passed to the prompt builder.
    custom_instructions: String,
}

impl DeepAgentPromptMiddleware {
    /// Creates the middleware from the custom instructions text.
    pub fn new(custom_instructions: impl Into<String>) -> Self {
        let custom_instructions = custom_instructions.into();
        DeepAgentPromptMiddleware {
            custom_instructions,
        }
    }
}
413
414#[async_trait]
415impl AgentMiddleware for DeepAgentPromptMiddleware {
416 fn id(&self) -> &'static str {
417 "deep-agent-prompt"
418 }
419
420 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
421 use crate::prompts::get_deep_agent_system_prompt;
422 let deep_prompt = get_deep_agent_system_prompt(&self.custom_instructions);
423 ctx.request.append_prompt(&deep_prompt);
424 Ok(())
425 }
426}
427
/// Settings for marking the system prompt as cacheable.
pub struct AnthropicPromptCachingMiddleware {
    /// Cache time-to-live as text. Empty, `"0"` and `"0s"` disable caching.
    pub ttl: String,
    /// Behaviour label for unsupported models.
    // NOTE(review): in this file the value is only logged; confirm where it is acted on.
    pub unsupported_model_behavior: String,
}

impl AnthropicPromptCachingMiddleware {
    /// Creates the middleware from a TTL and an unsupported-model behaviour label.
    pub fn new(ttl: impl Into<String>, unsupported_model_behavior: impl Into<String>) -> Self {
        let ttl = ttl.into();
        let unsupported_model_behavior = unsupported_model_behavior.into();
        AnthropicPromptCachingMiddleware {
            ttl,
            unsupported_model_behavior,
        }
    }

    /// Creates the middleware with a `"5m"` TTL and `"ignore"` behaviour.
    pub fn with_defaults() -> Self {
        Self::new("5m", "ignore")
    }

    /// Returns `true` unless the TTL is empty, `"0"` or `"0s"`.
    fn should_enable_caching(&self) -> bool {
        !matches!(self.ttl.as_str(), "" | "0" | "0s")
    }
}
453
454#[async_trait]
455impl AgentMiddleware for AnthropicPromptCachingMiddleware {
456 fn id(&self) -> &'static str {
457 "anthropic-prompt-caching"
458 }
459
460 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
461 if !self.should_enable_caching() {
462 return Ok(());
463 }
464
465 if !ctx.request.system_prompt.is_empty() {
467 let system_message = AgentMessage {
468 role: MessageRole::System,
469 content: MessageContent::Text(ctx.request.system_prompt.clone()),
470 metadata: Some(MessageMetadata {
471 tool_call_id: None,
472 cache_control: Some(CacheControl {
473 cache_type: "ephemeral".to_string(),
474 }),
475 }),
476 };
477
478 ctx.request.messages.insert(0, system_message);
480
481 ctx.request.system_prompt.clear();
483
484 tracing::debug!(
485 ttl = %self.ttl,
486 behavior = %self.unsupported_model_behavior,
487 "Applied Anthropic prompt caching to system message"
488 );
489 }
490
491 Ok(())
492 }
493}
494
495pub struct TaskRouterTool {
496 registry: Arc<SubAgentRegistry>,
497 event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
498 delegation_depth: Arc<RwLock<u32>>,
499}
500
501impl TaskRouterTool {
502 fn new(
503 registry: Arc<SubAgentRegistry>,
504 event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
505 ) -> Self {
506 Self {
507 registry,
508 event_dispatcher,
509 delegation_depth: Arc::new(RwLock::new(0)),
510 }
511 }
512
513 fn available_subagents(&self) -> Vec<String> {
514 self.registry.available_names()
515 }
516
517 fn emit_event(&self, event: agents_core::events::AgentEvent) {
518 if let Some(dispatcher) = &self.event_dispatcher {
519 let dispatcher_clone = dispatcher.clone();
520 tokio::spawn(async move {
521 dispatcher_clone.dispatch(event).await;
522 });
523 }
524 }
525
526 fn create_event_metadata(&self) -> agents_core::events::EventMetadata {
527 agents_core::events::EventMetadata::new(
528 "default".to_string(),
529 uuid::Uuid::new_v4().to_string(),
530 None,
531 )
532 }
533
534 fn get_delegation_depth(&self) -> u32 {
535 *self.delegation_depth.read().unwrap_or_else(|_| {
536 tracing::warn!("Failed to read delegation depth, defaulting to 0");
537 panic!("RwLock poisoned")
538 })
539 }
540
541 fn increment_delegation_depth(&self) {
542 if let Ok(mut depth) = self.delegation_depth.write() {
543 *depth += 1;
544 }
545 }
546
547 fn decrement_delegation_depth(&self) {
548 if let Ok(mut depth) = self.delegation_depth.write() {
549 if *depth > 0 {
550 *depth -= 1;
551 }
552 }
553 }
554}
555
/// Arguments accepted by the `task` tool, deserialized from the tool call's JSON payload.
#[derive(Debug, Clone, Deserialize)]
struct TaskInvocationArgs {
    /// Instruction forwarded to the sub-agent; also accepted under the key `description`.
    #[serde(alias = "description")]
    instruction: String,
    /// Name of the target sub-agent; also accepted under the key `subagent_type`.
    #[serde(alias = "subagent_type")]
    agent: String,
}
563
564#[async_trait]
565impl Tool for TaskRouterTool {
566 fn schema(&self) -> agents_core::tools::ToolSchema {
567 use agents_core::tools::{ToolParameterSchema, ToolSchema};
568 use std::collections::HashMap;
569
570 let mut properties = HashMap::new();
571 properties.insert(
572 "agent".to_string(),
573 ToolParameterSchema::string("Name of the sub-agent to delegate to"),
574 );
575 properties.insert(
576 "instruction".to_string(),
577 ToolParameterSchema::string("Clear instruction for the sub-agent"),
578 );
579
580 ToolSchema::new(
581 "task",
582 "Delegate a task to a specialized sub-agent. Use this when you need specialized expertise or want to break down complex tasks.",
583 ToolParameterSchema::object(
584 "Task delegation parameters",
585 properties,
586 vec!["agent".to_string(), "instruction".to_string()],
587 ),
588 )
589 }
590
591 async fn execute(
592 &self,
593 args: serde_json::Value,
594 ctx: ToolContext,
595 ) -> anyhow::Result<ToolResult> {
596 let args: TaskInvocationArgs = serde_json::from_value(args)?;
597 let available = self.available_subagents();
598
599 if let Some(agent) = self.registry.get(&args.agent) {
600 self.increment_delegation_depth();
602 let current_depth = self.get_delegation_depth();
603
604 let instruction_summary = if args.instruction.len() > 100 {
606 format!("{}...", &args.instruction[..100])
607 } else {
608 args.instruction.clone()
609 };
610
611 self.emit_event(agents_core::events::AgentEvent::SubAgentStarted(
613 agents_core::events::SubAgentStartedEvent {
614 metadata: self.create_event_metadata(),
615 agent_name: args.agent.clone(),
616 instruction_summary: instruction_summary.clone(),
617 delegation_depth: current_depth,
618 },
619 ));
620
621 tracing::warn!(
623 "🎯 DELEGATING to sub-agent: {} (depth: {}) with instruction: {}",
624 args.agent,
625 current_depth,
626 args.instruction
627 );
628
629 let start_time = std::time::Instant::now();
630 let user_message = AgentMessage {
631 role: MessageRole::User,
632 content: MessageContent::Text(args.instruction.clone()),
633 metadata: None,
634 };
635
636 let response = agent
637 .handle_message(user_message, Arc::new(AgentStateSnapshot::default()))
638 .await?;
639
640 let duration = start_time.elapsed();
642 let duration_ms = duration.as_millis() as u64;
643
644 let response_preview = match &response.content {
646 MessageContent::Text(t) => {
647 if t.len() > 100 {
648 format!("{}...", &t[..100])
649 } else {
650 t.clone()
651 }
652 }
653 MessageContent::Json(v) => {
654 let json_str = v.to_string();
655 if json_str.len() > 100 {
656 format!("{}...", &json_str[..100])
657 } else {
658 json_str
659 }
660 }
661 };
662
663 self.emit_event(agents_core::events::AgentEvent::SubAgentCompleted(
665 agents_core::events::SubAgentCompletedEvent {
666 metadata: self.create_event_metadata(),
667 agent_name: args.agent.clone(),
668 duration_ms,
669 result_summary: response_preview.clone(),
670 },
671 ));
672
673 tracing::warn!(
675 "✅ SUB-AGENT {} COMPLETED in {:?} - Response: {}",
676 args.agent,
677 duration,
678 response_preview
679 );
680
681 self.decrement_delegation_depth();
683
684 let result_text = match response.content {
687 MessageContent::Text(text) => text,
688 MessageContent::Json(json) => json.to_string(),
689 };
690
691 return Ok(ToolResult::text(&ctx, result_text));
692 }
693
694 tracing::error!(
695 "❌ SUB-AGENT NOT FOUND: {} - Available: {:?}",
696 args.agent,
697 available
698 );
699
700 Ok(ToolResult::text(
701 &ctx,
702 format!(
703 "Sub-agent '{}' not found. Available sub-agents: {}",
704 args.agent,
705 available.join(", ")
706 ),
707 ))
708 }
709}
710
/// Name and human-readable description of a sub-agent.
#[derive(Debug, Clone)]
pub struct SubAgentDescriptor {
    /// Name under which the sub-agent is registered and listed in the prompt.
    pub name: String,
    /// Short description shown next to the name in the prompt listing.
    pub description: String,
}
716
#[cfg(test)]
mod tests {
    use super::*;
    use agents_core::agent::{AgentDescriptor, AgentHandle};
    use agents_core::messaging::{MessageContent, MessageRole};
    use serde_json::json;

    /// Builds a metadata-free text message with the given role.
    fn text_message(role: MessageRole, text: &str) -> AgentMessage {
        AgentMessage {
            role,
            content: MessageContent::Text(text.into()),
            metadata: None,
        }
    }

    /// Fresh default agent state behind the lock type middleware contexts expect.
    fn shared_state() -> Arc<RwLock<AgentStateSnapshot>> {
        Arc::new(RwLock::new(AgentStateSnapshot::default()))
    }

    /// Single-entry policy map in which `tool` requires approval.
    fn approval_required(tool: &str, note: Option<&str>) -> HashMap<String, HitlPolicy> {
        HashMap::from([(
            tool.to_string(),
            HitlPolicy {
                allow_auto: false,
                note: note.map(String::from),
            },
        )])
    }

    /// Names of the tools a middleware contributes.
    fn tool_names(middleware: &dyn AgentMiddleware) -> Vec<String> {
        middleware
            .tools()
            .iter()
            .map(|tool| tool.schema().name.clone())
            .collect()
    }

    /// Minimal middleware that writes straight into the system prompt.
    struct AppendPromptMiddleware;

    #[async_trait]
    impl AgentMiddleware for AppendPromptMiddleware {
        fn id(&self) -> &'static str {
            "append-prompt"
        }

        async fn modify_model_request(
            &self,
            ctx: &mut MiddlewareContext<'_>,
        ) -> anyhow::Result<()> {
            ctx.request.system_prompt.push_str("\nExtra directives.");
            Ok(())
        }
    }

    /// A middleware can mutate the system prompt through the context.
    #[tokio::test]
    async fn middleware_mutates_prompt() {
        let mut request =
            ModelRequest::new("System", vec![text_message(MessageRole::User, "Hi")]);
        let mut ctx = MiddlewareContext::with_request(&mut request, shared_state());
        let middleware = AppendPromptMiddleware;
        middleware.modify_model_request(&mut ctx).await.unwrap();
        assert!(ctx.request.system_prompt.contains("Extra directives"));
    }

    /// Planning middleware exposes `write_todos` and mentions it in the prompt.
    #[tokio::test]
    async fn planning_middleware_registers_write_todos() {
        let middleware = PlanningMiddleware::new(shared_state());
        assert!(tool_names(&middleware).contains(&"write_todos".to_string()));

        let mut request = ModelRequest::new("System", vec![]);
        let mut ctx = MiddlewareContext::with_request(&mut request, shared_state());
        middleware.modify_model_request(&mut ctx).await.unwrap();
        assert!(ctx.request.system_prompt.contains("write_todos"));
    }

    /// Filesystem middleware exposes the four file tools.
    #[tokio::test]
    async fn filesystem_middleware_registers_tools() {
        let middleware = FilesystemMiddleware::new(shared_state());
        let names = tool_names(&middleware);
        for expected in ["ls", "read_file", "write_file", "edit_file"] {
            assert!(names.contains(&expected.to_string()));
        }
    }

    /// Three messages with a limit of two become: summary note + the last two.
    #[tokio::test]
    async fn summarization_middleware_trims_messages() {
        let middleware = SummarizationMiddleware::new(2, "Summary note");
        let mut request = ModelRequest::new(
            "System",
            vec![
                text_message(MessageRole::User, "one"),
                text_message(MessageRole::Agent, "two"),
                text_message(MessageRole::User, "three"),
            ],
        );
        let mut ctx = MiddlewareContext::with_request(&mut request, shared_state());
        middleware.modify_model_request(&mut ctx).await.unwrap();
        assert_eq!(ctx.request.messages.len(), 3);
        match &ctx.request.messages[0].content {
            MessageContent::Text(text) => assert!(text.contains("Summary note")),
            other => panic!("expected text, got {other:?}"),
        }
    }

    /// Sub-agent stand-in that always answers `stub-response`.
    struct StubAgent;

    #[async_trait]
    impl AgentHandle for StubAgent {
        async fn describe(&self) -> AgentDescriptor {
            AgentDescriptor {
                name: "stub".into(),
                version: "0.0.1".into(),
                description: None,
            }
        }

        async fn handle_message(
            &self,
            _input: AgentMessage,
            _state: Arc<AgentStateSnapshot>,
        ) -> anyhow::Result<AgentMessage> {
            Ok(text_message(MessageRole::Agent, "stub-response"))
        }
    }

    /// An unknown agent name yields a "not found" text result rather than an error.
    #[tokio::test]
    async fn task_router_reports_unknown_subagent() {
        let registry = Arc::new(SubAgentRegistry::new(vec![]));
        let task_tool = TaskRouterTool::new(registry, None);
        let ctx = ToolContext::new(Arc::new(AgentStateSnapshot::default()));

        let payload = json!({
            "instruction": "Do something",
            "agent": "unknown"
        });
        let response = task_tool.execute(payload, ctx).await.unwrap();

        match response {
            ToolResult::Message(msg) => match msg.content {
                MessageContent::Text(text) => {
                    assert!(text.contains("Sub-agent 'unknown' not found"))
                }
                other => panic!("expected text, got {other:?}"),
            },
            _ => panic!("expected message"),
        }
    }

    /// Sub-agent middleware lists registered agents in the prompt and exposes the `task` tool.
    #[tokio::test]
    async fn subagent_middleware_appends_prompt() {
        let subagents = vec![SubAgentRegistration {
            descriptor: SubAgentDescriptor {
                name: "research-agent".into(),
                description: "Deep research specialist".into(),
            },
            agent: Arc::new(StubAgent),
        }];
        let middleware = SubAgentMiddleware::new(subagents);

        let mut request = ModelRequest::new("System", vec![]);
        let mut ctx = MiddlewareContext::with_request(&mut request, shared_state());
        middleware.modify_model_request(&mut ctx).await.unwrap();

        assert!(ctx.request.system_prompt.contains("research-agent"));
        assert!(tool_names(&middleware).contains(&"task".to_string()));
    }

    /// The router accepts the alias keys, calls the agent and propagates the call id.
    #[tokio::test]
    async fn task_router_invokes_registered_subagent() {
        let registry = Arc::new(SubAgentRegistry::new(vec![SubAgentRegistration {
            descriptor: SubAgentDescriptor {
                name: "stub-agent".into(),
                description: "Stub".into(),
            },
            agent: Arc::new(StubAgent),
        }]));
        let task_tool = TaskRouterTool::new(registry, None);
        let ctx = ToolContext::new(Arc::new(AgentStateSnapshot::default()))
            .with_call_id(Some("call-42".into()));

        let payload = json!({
            "description": "do work",
            "subagent_type": "stub-agent"
        });
        let response = task_tool.execute(payload, ctx).await.unwrap();

        match response {
            ToolResult::Message(msg) => {
                assert_eq!(msg.metadata.unwrap().tool_call_id.unwrap(), "call-42");
                match msg.content {
                    MessageContent::Text(text) => assert_eq!(text, "stub-response"),
                    other => panic!("expected text, got {other:?}"),
                }
            }
            _ => panic!("expected message"),
        }
    }

    /// Gated tools and their notes are listed in the system prompt.
    #[tokio::test]
    async fn human_in_loop_appends_prompt() {
        let middleware = HumanInLoopMiddleware::new(approval_required(
            "danger-tool",
            Some("Requires security review"),
        ));
        let mut request = ModelRequest::new("System", vec![]);
        let mut ctx = MiddlewareContext::with_request(&mut request, shared_state());
        middleware.modify_model_request(&mut ctx).await.unwrap();
        assert!(ctx
            .request
            .system_prompt
            .contains("danger-tool: Requires security review"));
    }

    /// With caching on, the system prompt becomes the first message, tagged `ephemeral`.
    #[tokio::test]
    async fn anthropic_prompt_caching_moves_system_prompt_to_messages() {
        let middleware = AnthropicPromptCachingMiddleware::new("5m", "ignore");
        let mut request = ModelRequest::new(
            "This is the system prompt",
            vec![text_message(MessageRole::User, "Hello")],
        );
        let mut ctx = MiddlewareContext::with_request(&mut request, shared_state());

        middleware.modify_model_request(&mut ctx).await.unwrap();

        assert!(ctx.request.system_prompt.is_empty());
        assert_eq!(ctx.request.messages.len(), 2);

        let system_message = &ctx.request.messages[0];
        assert!(matches!(system_message.role, MessageRole::System));
        assert_eq!(
            system_message.content.as_text().unwrap(),
            "This is the system prompt"
        );

        let metadata = system_message.metadata.as_ref().unwrap();
        let cache_control = metadata.cache_control.as_ref().unwrap();
        assert_eq!(cache_control.cache_type, "ephemeral");

        let user_message = &ctx.request.messages[1];
        assert!(matches!(user_message.role, MessageRole::User));
        assert_eq!(user_message.content.as_text().unwrap(), "Hello");
    }

    /// A TTL of `"0"` leaves the request untouched.
    #[tokio::test]
    async fn anthropic_prompt_caching_disabled_with_zero_ttl() {
        let middleware = AnthropicPromptCachingMiddleware::new("0", "ignore");
        let mut request = ModelRequest::new("This is the system prompt", vec![]);
        let mut ctx = MiddlewareContext::with_request(&mut request, shared_state());

        middleware.modify_model_request(&mut ctx).await.unwrap();

        assert_eq!(ctx.request.system_prompt, "This is the system prompt");
        assert_eq!(ctx.request.messages.len(), 0);
    }

    /// An empty system prompt produces no extra message.
    #[tokio::test]
    async fn anthropic_prompt_caching_no_op_with_empty_system_prompt() {
        let middleware = AnthropicPromptCachingMiddleware::new("5m", "ignore");
        let mut request =
            ModelRequest::new("", vec![text_message(MessageRole::User, "Hello")]);
        let mut ctx = MiddlewareContext::with_request(&mut request, shared_state());

        middleware.modify_model_request(&mut ctx).await.unwrap();

        assert!(ctx.request.system_prompt.is_empty());
        assert_eq!(ctx.request.messages.len(), 1);
    }

    /// A gated tool yields an interrupt carrying name, args, call id and note.
    #[tokio::test]
    async fn hitl_creates_interrupt_for_disallowed_tool() {
        let middleware = HumanInLoopMiddleware::new(approval_required(
            "dangerous_tool",
            Some("Requires security review"),
        ));
        let tool_args = json!({"action": "delete_all"});

        let result = middleware
            .before_tool_execution("dangerous_tool", &tool_args, "call_123")
            .await
            .unwrap();

        assert!(result.is_some());
        let interrupt = result.unwrap();

        match interrupt {
            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
                assert_eq!(hitl.tool_name, "dangerous_tool");
                assert_eq!(hitl.tool_args, tool_args);
                assert_eq!(hitl.call_id, "call_123");
                assert_eq!(
                    hitl.policy_note,
                    Some("Requires security review".to_string())
                );
            }
        }
    }

    /// A tool whose policy allows auto-execution is not interrupted.
    #[tokio::test]
    async fn hitl_no_interrupt_for_allowed_tool() {
        let mut policies = HashMap::new();
        policies.insert(
            "safe_tool".to_string(),
            HitlPolicy {
                allow_auto: true,
                note: None,
            },
        );

        let middleware = HumanInLoopMiddleware::new(policies);
        let tool_args = json!({"action": "read"});

        let result = middleware
            .before_tool_execution("safe_tool", &tool_args, "call_456")
            .await
            .unwrap();

        assert!(result.is_none());
    }

    /// A tool with no policy at all is not interrupted.
    #[tokio::test]
    async fn hitl_no_interrupt_for_unlisted_tool() {
        let middleware = HumanInLoopMiddleware::new(HashMap::new());
        let tool_args = json!({"action": "anything"});

        let result = middleware
            .before_tool_execution("unlisted_tool", &tool_args, "call_789")
            .await
            .unwrap();

        assert!(result.is_none());
    }

    /// The interrupt preserves individual argument fields and the policy note text.
    #[tokio::test]
    async fn hitl_interrupt_includes_correct_details() {
        let middleware = HumanInLoopMiddleware::new(approval_required(
            "critical_tool",
            Some("Critical operation - requires approval"),
        ));
        let tool_args = json!({
            "database": "production",
            "operation": "drop_table"
        });

        let result = middleware
            .before_tool_execution("critical_tool", &tool_args, "call_critical_1")
            .await
            .unwrap();

        assert!(result.is_some());
        let interrupt = result.unwrap();

        match interrupt {
            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
                assert_eq!(hitl.tool_name, "critical_tool");
                assert_eq!(hitl.tool_args["database"], "production");
                assert_eq!(hitl.tool_args["operation"], "drop_table");
                assert_eq!(hitl.call_id, "call_critical_1");
                assert!(hitl.policy_note.is_some());
                assert!(hitl.policy_note.unwrap().contains("Critical operation"));
            }
        }
    }

    /// A gated tool without a note yields an interrupt whose note is `None`.
    #[tokio::test]
    async fn hitl_interrupt_without_policy_note() {
        let middleware = HumanInLoopMiddleware::new(approval_required("tool_no_note", None));
        let tool_args = json!({"param": "value"});

        let result = middleware
            .before_tool_execution("tool_no_note", &tool_args, "call_no_note")
            .await
            .unwrap();

        assert!(result.is_some());
        let interrupt = result.unwrap();

        match interrupt {
            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
                assert_eq!(hitl.tool_name, "tool_no_note");
                assert_eq!(hitl.policy_note, None);
            }
        }
    }
}