1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3
4use agents_core::agent::AgentHandle;
5use agents_core::messaging::{
6 AgentMessage, CacheControl, MessageContent, MessageMetadata, MessageRole,
7};
8use agents_core::prompts::{
9 BASE_AGENT_PROMPT, FILESYSTEM_SYSTEM_PROMPT, TASK_SYSTEM_PROMPT, TASK_TOOL_DESCRIPTION,
10 WRITE_TODOS_SYSTEM_PROMPT,
11};
12use agents_core::state::AgentStateSnapshot;
13use agents_core::tools::{Tool, ToolBox, ToolContext, ToolResult};
14use agents_toolkit::create_filesystem_tools;
15use async_trait::async_trait;
16use serde::Deserialize;
17
18pub mod token_tracking;
19
20#[derive(Debug, Clone)]
23pub struct ModelRequest {
24 pub system_prompt: String,
25 pub messages: Vec<AgentMessage>,
26}
27
28impl ModelRequest {
29 pub fn new(system_prompt: impl Into<String>, messages: Vec<AgentMessage>) -> Self {
30 Self {
31 system_prompt: system_prompt.into(),
32 messages,
33 }
34 }
35
36 pub fn append_prompt(&mut self, fragment: &str) {
37 if !fragment.is_empty() {
38 self.system_prompt.push_str("\n\n");
39 self.system_prompt.push_str(fragment);
40 }
41 }
42}
43
44pub struct MiddlewareContext<'a> {
46 pub request: &'a mut ModelRequest,
47 pub state: Arc<RwLock<AgentStateSnapshot>>,
48}
49
50impl<'a> MiddlewareContext<'a> {
51 pub fn with_request(
52 request: &'a mut ModelRequest,
53 state: Arc<RwLock<AgentStateSnapshot>>,
54 ) -> Self {
55 Self { request, state }
56 }
57}
58
59#[async_trait]
63pub trait AgentMiddleware: Send + Sync {
64 fn id(&self) -> &'static str;
66
67 fn tools(&self) -> Vec<ToolBox> {
69 Vec::new()
70 }
71
72 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()>;
74
75 async fn before_tool_execution(
91 &self,
92 _tool_name: &str,
93 _tool_args: &serde_json::Value,
94 _call_id: &str,
95 ) -> anyhow::Result<Option<agents_core::hitl::AgentInterrupt>> {
96 Ok(None)
97 }
98}
99
/// Middleware configuration for trimming long conversations down to the most
/// recent messages.
pub struct SummarizationMiddleware {
    /// Number of most recent messages that are kept untouched.
    pub messages_to_keep: usize,
    /// Text used as the lead-in of the inserted summary notice.
    pub summary_note: String,
}

impl SummarizationMiddleware {
    /// Creates the middleware with a retention count and a summary note.
    pub fn new(messages_to_keep: usize, summary_note: impl Into<String>) -> Self {
        let summary_note = summary_note.into();
        SummarizationMiddleware {
            messages_to_keep,
            summary_note,
        }
    }
}
113
114#[async_trait]
115impl AgentMiddleware for SummarizationMiddleware {
116 fn id(&self) -> &'static str {
117 "summarization"
118 }
119
120 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
121 if ctx.request.messages.len() > self.messages_to_keep {
122 let dropped = ctx.request.messages.len() - self.messages_to_keep;
123 let mut truncated = ctx
124 .request
125 .messages
126 .split_off(ctx.request.messages.len() - self.messages_to_keep);
127 truncated.insert(
128 0,
129 AgentMessage {
130 role: MessageRole::System,
131 content: MessageContent::Text(format!(
132 "{} ({} earlier messages summarized)",
133 self.summary_note, dropped
134 )),
135 metadata: None,
136 },
137 );
138 ctx.request.messages = truncated;
139 }
140 Ok(())
141 }
142}
143
144pub struct PlanningMiddleware {
145 _state: Arc<RwLock<AgentStateSnapshot>>,
146}
147
148impl PlanningMiddleware {
149 pub fn new(state: Arc<RwLock<AgentStateSnapshot>>) -> Self {
150 Self { _state: state }
151 }
152}
153
154#[async_trait]
155impl AgentMiddleware for PlanningMiddleware {
156 fn id(&self) -> &'static str {
157 "planning"
158 }
159
160 fn tools(&self) -> Vec<ToolBox> {
161 use agents_toolkit::create_todos_tools;
162 create_todos_tools()
163 }
164
165 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
166 ctx.request.append_prompt(WRITE_TODOS_SYSTEM_PROMPT);
167 Ok(())
168 }
169}
170
171pub struct FilesystemMiddleware {
172 _state: Arc<RwLock<AgentStateSnapshot>>,
173}
174
175impl FilesystemMiddleware {
176 pub fn new(state: Arc<RwLock<AgentStateSnapshot>>) -> Self {
177 Self { _state: state }
178 }
179}
180
181#[async_trait]
182impl AgentMiddleware for FilesystemMiddleware {
183 fn id(&self) -> &'static str {
184 "filesystem"
185 }
186
187 fn tools(&self) -> Vec<ToolBox> {
188 create_filesystem_tools()
189 }
190
191 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
192 ctx.request.append_prompt(FILESYSTEM_SYSTEM_PROMPT);
193 Ok(())
194 }
195}
196
197#[derive(Clone)]
198pub struct SubAgentRegistration {
199 pub descriptor: SubAgentDescriptor,
200 pub agent: Arc<dyn AgentHandle>,
201}
202
203struct SubAgentRegistry {
204 agents: HashMap<String, Arc<dyn AgentHandle>>,
205}
206
207impl SubAgentRegistry {
208 fn new(registrations: Vec<SubAgentRegistration>) -> Self {
209 let mut agents = HashMap::new();
210 for reg in registrations {
211 agents.insert(reg.descriptor.name.clone(), reg.agent.clone());
212 }
213 Self { agents }
214 }
215
216 fn available_names(&self) -> Vec<String> {
217 self.agents.keys().cloned().collect()
218 }
219
220 fn get(&self, name: &str) -> Option<Arc<dyn AgentHandle>> {
221 self.agents.get(name).cloned()
222 }
223}
224
225pub struct SubAgentMiddleware {
226 task_tool: ToolBox,
227 descriptors: Vec<SubAgentDescriptor>,
228 _registry: Arc<SubAgentRegistry>,
229}
230
231impl SubAgentMiddleware {
232 pub fn new(registrations: Vec<SubAgentRegistration>) -> Self {
233 let descriptors = registrations.iter().map(|r| r.descriptor.clone()).collect();
234 let registry = Arc::new(SubAgentRegistry::new(registrations));
235 let task_tool: ToolBox = Arc::new(TaskRouterTool::new(registry.clone(), None));
236 Self {
237 task_tool,
238 descriptors,
239 _registry: registry,
240 }
241 }
242
243 pub fn new_with_events(
244 registrations: Vec<SubAgentRegistration>,
245 event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
246 ) -> Self {
247 let descriptors = registrations.iter().map(|r| r.descriptor.clone()).collect();
248 let registry = Arc::new(SubAgentRegistry::new(registrations));
249 let task_tool: ToolBox = Arc::new(TaskRouterTool::new(registry.clone(), event_dispatcher));
250 Self {
251 task_tool,
252 descriptors,
253 _registry: registry,
254 }
255 }
256
257 fn prompt_fragment(&self) -> String {
258 let descriptions: Vec<String> = if self.descriptors.is_empty() {
259 vec![String::from("- general-purpose: Default reasoning agent")]
260 } else {
261 self.descriptors
262 .iter()
263 .map(|agent| format!("- {}: {}", agent.name, agent.description))
264 .collect()
265 };
266
267 TASK_TOOL_DESCRIPTION.replace("{other_agents}", &descriptions.join("\n"))
268 }
269}
270
271#[async_trait]
272impl AgentMiddleware for SubAgentMiddleware {
273 fn id(&self) -> &'static str {
274 "subagent"
275 }
276
277 fn tools(&self) -> Vec<ToolBox> {
278 vec![self.task_tool.clone()]
279 }
280
281 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
282 ctx.request.append_prompt(TASK_SYSTEM_PROMPT);
283 ctx.request.append_prompt(&self.prompt_fragment());
284 Ok(())
285 }
286}
287
/// Approval policy attached to a single tool.
#[derive(Clone, Debug)]
pub struct HitlPolicy {
    /// When `true`, the tool is not treated as requiring approval.
    pub allow_auto: bool,
    /// Optional note rendered next to the tool in the approval prompt.
    pub note: Option<String>,
}

/// Middleware holding per-tool human-approval policies.
pub struct HumanInLoopMiddleware {
    policies: HashMap<String, HitlPolicy>,
}

impl HumanInLoopMiddleware {
    /// Creates the middleware from a map of tool name to policy.
    pub fn new(policies: HashMap<String, HitlPolicy>) -> Self {
        Self { policies }
    }

    /// Returns the policy for `tool_name` when one exists and it does not
    /// allow automatic execution; `None` otherwise.
    pub fn requires_approval(&self, tool_name: &str) -> Option<&HitlPolicy> {
        self.policies
            .get(tool_name)
            .filter(|policy| !policy.allow_auto)
    }

    /// Builds the prompt section listing every tool that needs approval, or
    /// `None` when no tool does.
    ///
    /// Entries are sorted by tool name: `HashMap` iteration order is
    /// unspecified, and an unsorted listing would make the generated prompt
    /// text differ between otherwise identical configurations.
    fn prompt_fragment(&self) -> Option<String> {
        let mut pending: Vec<(&String, &HitlPolicy)> = self
            .policies
            .iter()
            .filter(|(_, policy)| !policy.allow_auto)
            .collect();
        if pending.is_empty() {
            return None;
        }
        pending.sort_unstable_by(|a, b| a.0.cmp(b.0));

        let lines: Vec<String> = pending
            .into_iter()
            .map(|(tool, policy)| {
                let reason = policy.note.as_deref().unwrap_or("Requires approval");
                format!("- {tool}: {reason}")
            })
            .collect();

        Some(format!(
            "The following tools require human approval before execution:\n{}",
            lines.join("\n")
        ))
    }
}
329
330#[async_trait]
331impl AgentMiddleware for HumanInLoopMiddleware {
332 fn id(&self) -> &'static str {
333 "human-in-loop"
334 }
335
336 async fn before_tool_execution(
337 &self,
338 tool_name: &str,
339 tool_args: &serde_json::Value,
340 call_id: &str,
341 ) -> anyhow::Result<Option<agents_core::hitl::AgentInterrupt>> {
342 if let Some(policy) = self.requires_approval(tool_name) {
343 tracing::warn!(
344 tool_name = %tool_name,
345 call_id = %call_id,
346 policy_note = ?policy.note,
347 "🔒 HITL: Tool execution requires human approval"
348 );
349
350 let interrupt = agents_core::hitl::HitlInterrupt::new(
351 tool_name,
352 tool_args.clone(),
353 call_id,
354 policy.note.clone(),
355 );
356
357 return Ok(Some(agents_core::hitl::AgentInterrupt::HumanInLoop(
358 interrupt,
359 )));
360 }
361
362 Ok(None)
363 }
364
365 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
366 if let Some(fragment) = self.prompt_fragment() {
367 ctx.request.append_prompt(&fragment);
368 }
369 ctx.request.messages.push(AgentMessage {
370 role: MessageRole::System,
371 content: MessageContent::Text(
372 "Tools marked for human approval will emit interrupts requiring external resolution."
373 .into(),
374 ),
375 metadata: None,
376 });
377 Ok(())
378 }
379}
380
381pub struct BaseSystemPromptMiddleware;
382
383#[async_trait]
384impl AgentMiddleware for BaseSystemPromptMiddleware {
385 fn id(&self) -> &'static str {
386 "base-system-prompt"
387 }
388
389 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
390 ctx.request.append_prompt(BASE_AGENT_PROMPT);
391 Ok(())
392 }
393}
394
395pub struct DeepAgentPromptMiddleware {
410 custom_instructions: String,
411 prompt_format: crate::prompts::PromptFormat,
413 override_system_prompt: Option<String>,
415}
416
417impl DeepAgentPromptMiddleware {
418 pub fn new(custom_instructions: impl Into<String>) -> Self {
419 Self {
420 custom_instructions: custom_instructions.into(),
421 prompt_format: crate::prompts::PromptFormat::Json,
422 override_system_prompt: None,
423 }
424 }
425
426 pub fn with_format(
431 custom_instructions: impl Into<String>,
432 format: crate::prompts::PromptFormat,
433 ) -> Self {
434 Self {
435 custom_instructions: custom_instructions.into(),
436 prompt_format: format,
437 override_system_prompt: None,
438 }
439 }
440
441 pub fn with_override(system_prompt: impl Into<String>) -> Self {
446 Self {
447 custom_instructions: String::new(),
448 prompt_format: crate::prompts::PromptFormat::Json,
449 override_system_prompt: Some(system_prompt.into()),
450 }
451 }
452}
453
454#[async_trait]
455impl AgentMiddleware for DeepAgentPromptMiddleware {
456 fn id(&self) -> &'static str {
457 "deep-agent-prompt"
458 }
459
460 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
461 let prompt = if let Some(ref override_prompt) = self.override_system_prompt {
462 override_prompt.clone()
464 } else {
465 use crate::prompts::get_deep_agent_system_prompt_formatted;
467 get_deep_agent_system_prompt_formatted(&self.custom_instructions, self.prompt_format)
468 };
469 ctx.request.append_prompt(&prompt);
470 Ok(())
471 }
472}
473
/// Configuration for the prompt-caching middleware.
pub struct AnthropicPromptCachingMiddleware {
    /// TTL setting as text; an empty string, `"0"` or `"0s"` turns caching off.
    pub ttl: String,
    /// Setting for unsupported models; in this file it is only logged.
    pub unsupported_model_behavior: String,
}

impl AnthropicPromptCachingMiddleware {
    /// Creates the middleware from a TTL and an unsupported-model behavior.
    pub fn new(ttl: impl Into<String>, unsupported_model_behavior: impl Into<String>) -> Self {
        let ttl = ttl.into();
        let unsupported_model_behavior = unsupported_model_behavior.into();
        AnthropicPromptCachingMiddleware {
            ttl,
            unsupported_model_behavior,
        }
    }

    /// Creates the middleware with a `"5m"` TTL and the `"ignore"` behavior.
    pub fn with_defaults() -> Self {
        Self::new("5m", "ignore")
    }

    /// Reports whether caching is on: any TTL other than an empty string,
    /// `"0"` or `"0s"` enables it.
    fn should_enable_caching(&self) -> bool {
        !matches!(self.ttl.as_str(), "" | "0" | "0s")
    }
}
499
500#[async_trait]
501impl AgentMiddleware for AnthropicPromptCachingMiddleware {
502 fn id(&self) -> &'static str {
503 "anthropic-prompt-caching"
504 }
505
506 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
507 if !self.should_enable_caching() {
508 return Ok(());
509 }
510
511 if !ctx.request.system_prompt.is_empty() {
513 let system_message = AgentMessage {
514 role: MessageRole::System,
515 content: MessageContent::Text(ctx.request.system_prompt.clone()),
516 metadata: Some(MessageMetadata {
517 tool_call_id: None,
518 cache_control: Some(CacheControl {
519 cache_type: "ephemeral".to_string(),
520 }),
521 }),
522 };
523
524 ctx.request.messages.insert(0, system_message);
526
527 ctx.request.system_prompt.clear();
529
530 tracing::debug!(
531 ttl = %self.ttl,
532 behavior = %self.unsupported_model_behavior,
533 "Applied Anthropic prompt caching to system message"
534 );
535 }
536
537 Ok(())
538 }
539}
540
541pub struct TaskRouterTool {
542 registry: Arc<SubAgentRegistry>,
543 event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
544 delegation_depth: Arc<RwLock<u32>>,
545}
546
547impl TaskRouterTool {
548 fn new(
549 registry: Arc<SubAgentRegistry>,
550 event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
551 ) -> Self {
552 Self {
553 registry,
554 event_dispatcher,
555 delegation_depth: Arc::new(RwLock::new(0)),
556 }
557 }
558
559 fn available_subagents(&self) -> Vec<String> {
560 self.registry.available_names()
561 }
562
563 fn emit_event(&self, event: agents_core::events::AgentEvent) {
564 if let Some(dispatcher) = &self.event_dispatcher {
565 let dispatcher_clone = dispatcher.clone();
566 tokio::spawn(async move {
567 dispatcher_clone.dispatch(event).await;
568 });
569 }
570 }
571
572 fn create_event_metadata(&self) -> agents_core::events::EventMetadata {
573 agents_core::events::EventMetadata::new(
574 "default".to_string(),
575 uuid::Uuid::new_v4().to_string(),
576 None,
577 )
578 }
579
580 fn get_delegation_depth(&self) -> u32 {
581 *self.delegation_depth.read().unwrap_or_else(|_| {
582 tracing::warn!("Failed to read delegation depth, defaulting to 0");
583 panic!("RwLock poisoned")
584 })
585 }
586
587 fn increment_delegation_depth(&self) {
588 if let Ok(mut depth) = self.delegation_depth.write() {
589 *depth += 1;
590 }
591 }
592
593 fn decrement_delegation_depth(&self) {
594 if let Ok(mut depth) = self.delegation_depth.write() {
595 if *depth > 0 {
596 *depth -= 1;
597 }
598 }
599 }
600}
601
602#[derive(Debug, Clone, Deserialize)]
603struct TaskInvocationArgs {
604 #[serde(alias = "description")]
605 instruction: String,
606 #[serde(alias = "subagent_type")]
607 agent: String,
608}
609
610#[async_trait]
611impl Tool for TaskRouterTool {
612 fn schema(&self) -> agents_core::tools::ToolSchema {
613 use agents_core::tools::{ToolParameterSchema, ToolSchema};
614 use std::collections::HashMap;
615
616 let mut properties = HashMap::new();
617 properties.insert(
618 "agent".to_string(),
619 ToolParameterSchema::string("Name of the sub-agent to delegate to"),
620 );
621 properties.insert(
622 "instruction".to_string(),
623 ToolParameterSchema::string("Clear instruction for the sub-agent"),
624 );
625
626 ToolSchema::new(
627 "task",
628 "Delegate a task to a specialized sub-agent. Use this when you need specialized expertise or want to break down complex tasks.",
629 ToolParameterSchema::object(
630 "Task delegation parameters",
631 properties,
632 vec!["agent".to_string(), "instruction".to_string()],
633 ),
634 )
635 }
636
637 async fn execute(
638 &self,
639 args: serde_json::Value,
640 ctx: ToolContext,
641 ) -> anyhow::Result<ToolResult> {
642 let args: TaskInvocationArgs = serde_json::from_value(args)?;
643 let available = self.available_subagents();
644
645 if let Some(agent) = self.registry.get(&args.agent) {
646 self.increment_delegation_depth();
648 let current_depth = self.get_delegation_depth();
649
650 let instruction_summary = if args.instruction.len() > 100 {
652 format!("{}...", &args.instruction[..100])
653 } else {
654 args.instruction.clone()
655 };
656
657 self.emit_event(agents_core::events::AgentEvent::SubAgentStarted(
659 agents_core::events::SubAgentStartedEvent {
660 metadata: self.create_event_metadata(),
661 agent_name: args.agent.clone(),
662 instruction_summary: instruction_summary.clone(),
663 delegation_depth: current_depth,
664 },
665 ));
666
667 tracing::warn!(
669 "🎯 DELEGATING to sub-agent: {} (depth: {}) with instruction: {}",
670 args.agent,
671 current_depth,
672 args.instruction
673 );
674
675 let start_time = std::time::Instant::now();
676 let user_message = AgentMessage {
677 role: MessageRole::User,
678 content: MessageContent::Text(args.instruction.clone()),
679 metadata: None,
680 };
681
682 let response = agent
683 .handle_message(user_message, ctx.state.clone())
684 .await?;
685
686 let duration = start_time.elapsed();
688 let duration_ms = duration.as_millis() as u64;
689
690 let response_preview = match &response.content {
692 MessageContent::Text(t) => {
693 if t.len() > 100 {
694 format!("{}...", &t[..100])
695 } else {
696 t.clone()
697 }
698 }
699 MessageContent::Json(v) => {
700 let json_str = v.to_string();
701 if json_str.len() > 100 {
702 format!("{}...", &json_str[..100])
703 } else {
704 json_str
705 }
706 }
707 };
708
709 self.emit_event(agents_core::events::AgentEvent::SubAgentCompleted(
711 agents_core::events::SubAgentCompletedEvent {
712 metadata: self.create_event_metadata(),
713 agent_name: args.agent.clone(),
714 duration_ms,
715 result_summary: response_preview.clone(),
716 },
717 ));
718
719 tracing::warn!(
721 "✅ SUB-AGENT {} COMPLETED in {:?} - Response: {}",
722 args.agent,
723 duration,
724 response_preview
725 );
726
727 self.decrement_delegation_depth();
729
730 let result_text = match response.content {
733 MessageContent::Text(text) => text,
734 MessageContent::Json(json) => json.to_string(),
735 };
736
737 return Ok(ToolResult::text(&ctx, result_text));
738 }
739
740 tracing::error!(
741 "❌ SUB-AGENT NOT FOUND: {} - Available: {:?}",
742 args.agent,
743 available
744 );
745
746 Ok(ToolResult::text(
747 &ctx,
748 format!(
749 "Sub-agent '{}' not found. Available sub-agents: {}",
750 args.agent,
751 available.join(", ")
752 ),
753 ))
754 }
755}
756
/// Name and description of a sub-agent, as rendered in the delegation prompt.
#[derive(Clone, Debug)]
pub struct SubAgentDescriptor {
    /// Name of the sub-agent; also the key it is registered under.
    pub name: String,
    /// Short description shown next to the name.
    pub description: String,
}
762
#[cfg(test)]
mod tests {
    use super::*;
    use agents_core::agent::{AgentDescriptor, AgentHandle};
    use agents_core::messaging::{MessageContent, MessageRole};
    use serde_json::json;

    // ---- shared fixtures -------------------------------------------------

    /// Fresh default state wrapped the way `MiddlewareContext` expects it.
    fn shared_state() -> Arc<RwLock<AgentStateSnapshot>> {
        Arc::new(RwLock::new(AgentStateSnapshot::default()))
    }

    /// Plain text message without metadata.
    fn text_message(role: MessageRole, text: &str) -> AgentMessage {
        AgentMessage {
            role,
            content: MessageContent::Text(text.into()),
            metadata: None,
        }
    }

    /// Schema names of every tool a middleware contributes.
    fn tool_names<M: AgentMiddleware>(middleware: &M) -> Vec<String> {
        middleware
            .tools()
            .iter()
            .map(|tool| tool.schema().name.clone())
            .collect()
    }

    /// Text payload of a message content, panicking on any other variant.
    fn text_of(content: &MessageContent) -> &str {
        match content {
            MessageContent::Text(text) => text.as_str(),
            other => panic!("expected text, got {other:?}"),
        }
    }

    /// Human-in-the-loop middleware holding a single policy.
    fn hitl_with(tool: &str, allow_auto: bool, note: Option<&str>) -> HumanInLoopMiddleware {
        let mut policies = HashMap::new();
        policies.insert(
            tool.to_string(),
            HitlPolicy {
                allow_auto,
                note: note.map(str::to_string),
            },
        );
        HumanInLoopMiddleware::new(policies)
    }

    /// Registration of a `StubAgent` under the given name and description.
    fn stub_registration(name: &str, description: &str) -> SubAgentRegistration {
        SubAgentRegistration {
            descriptor: SubAgentDescriptor {
                name: name.into(),
                description: description.into(),
            },
            agent: Arc::new(StubAgent),
        }
    }

    struct AppendPromptMiddleware;

    #[async_trait]
    impl AgentMiddleware for AppendPromptMiddleware {
        fn id(&self) -> &'static str {
            "append-prompt"
        }

        async fn modify_model_request(
            &self,
            ctx: &mut MiddlewareContext<'_>,
        ) -> anyhow::Result<()> {
            ctx.request.system_prompt.push_str("\nExtra directives.");
            Ok(())
        }
    }

    struct StubAgent;

    #[async_trait]
    impl AgentHandle for StubAgent {
        async fn describe(&self) -> AgentDescriptor {
            AgentDescriptor {
                name: "stub".into(),
                version: "0.0.1".into(),
                description: None,
            }
        }

        async fn handle_message(
            &self,
            _input: AgentMessage,
            _state: Arc<AgentStateSnapshot>,
        ) -> anyhow::Result<AgentMessage> {
            Ok(text_message(MessageRole::Agent, "stub-response"))
        }
    }

    // ---- prompt-oriented middleware --------------------------------------

    #[tokio::test]
    async fn middleware_mutates_prompt() {
        let mut request =
            ModelRequest::new("System", vec![text_message(MessageRole::User, "Hi")]);
        let mut ctx = MiddlewareContext::with_request(&mut request, shared_state());

        AppendPromptMiddleware
            .modify_model_request(&mut ctx)
            .await
            .unwrap();

        assert!(ctx.request.system_prompt.contains("Extra directives"));
    }

    #[tokio::test]
    async fn planning_middleware_registers_write_todos() {
        let middleware = PlanningMiddleware::new(shared_state());
        assert!(tool_names(&middleware).contains(&"write_todos".to_string()));

        let mut request = ModelRequest::new("System", vec![]);
        let mut ctx = MiddlewareContext::with_request(&mut request, shared_state());
        middleware.modify_model_request(&mut ctx).await.unwrap();
        assert!(ctx.request.system_prompt.contains("write_todos"));
    }

    #[tokio::test]
    async fn filesystem_middleware_registers_tools() {
        let middleware = FilesystemMiddleware::new(shared_state());
        let names = tool_names(&middleware);
        for expected in ["ls", "read_file", "write_file", "edit_file"] {
            assert!(names.contains(&expected.to_string()));
        }
    }

    #[tokio::test]
    async fn summarization_middleware_trims_messages() {
        let middleware = SummarizationMiddleware::new(2, "Summary note");
        let mut request = ModelRequest::new(
            "System",
            vec![
                text_message(MessageRole::User, "one"),
                text_message(MessageRole::Agent, "two"),
                text_message(MessageRole::User, "three"),
            ],
        );
        let mut ctx = MiddlewareContext::with_request(&mut request, shared_state());

        middleware.modify_model_request(&mut ctx).await.unwrap();

        // Two kept messages plus the inserted notice.
        assert_eq!(ctx.request.messages.len(), 3);
        assert!(text_of(&ctx.request.messages[0].content).contains("Summary note"));
    }

    // ---- sub-agent delegation --------------------------------------------

    #[tokio::test]
    async fn task_router_reports_unknown_subagent() {
        let registry = Arc::new(SubAgentRegistry::new(vec![]));
        let task_tool = TaskRouterTool::new(registry.clone(), None);
        let ctx = ToolContext::new(Arc::new(AgentStateSnapshot::default()));

        let payload = json!({
            "instruction": "Do something",
            "agent": "unknown"
        });
        let response = task_tool.execute(payload, ctx).await.unwrap();

        match response {
            ToolResult::Message(msg) => {
                assert!(text_of(&msg.content).contains("Sub-agent 'unknown' not found"))
            }
            _ => panic!("expected message"),
        }
    }

    #[tokio::test]
    async fn subagent_middleware_appends_prompt() {
        let middleware = SubAgentMiddleware::new(vec![stub_registration(
            "research-agent",
            "Deep research specialist",
        )]);

        let mut request = ModelRequest::new("System", vec![]);
        let mut ctx = MiddlewareContext::with_request(&mut request, shared_state());
        middleware.modify_model_request(&mut ctx).await.unwrap();

        assert!(ctx.request.system_prompt.contains("research-agent"));
        assert!(tool_names(&middleware).contains(&"task".to_string()));
    }

    #[tokio::test]
    async fn task_router_invokes_registered_subagent() {
        let registry = Arc::new(SubAgentRegistry::new(vec![stub_registration(
            "stub-agent",
            "Stub",
        )]));
        let task_tool = TaskRouterTool::new(registry.clone(), None);
        let ctx = ToolContext::new(Arc::new(AgentStateSnapshot::default()))
            .with_call_id(Some("call-42".into()));

        // Uses the alias keys accepted by the argument struct.
        let payload = json!({
            "description": "do work",
            "subagent_type": "stub-agent"
        });
        let response = task_tool.execute(payload, ctx).await.unwrap();

        match response {
            ToolResult::Message(msg) => {
                assert_eq!(msg.metadata.unwrap().tool_call_id.unwrap(), "call-42");
                assert_eq!(text_of(&msg.content), "stub-response");
            }
            _ => panic!("expected message"),
        }
    }

    // ---- human-in-the-loop prompt ----------------------------------------

    #[tokio::test]
    async fn human_in_loop_appends_prompt() {
        let middleware = hitl_with("danger-tool", false, Some("Requires security review"));
        let mut request = ModelRequest::new("System", vec![]);
        let mut ctx = MiddlewareContext::with_request(&mut request, shared_state());

        middleware.modify_model_request(&mut ctx).await.unwrap();

        assert!(ctx
            .request
            .system_prompt
            .contains("danger-tool: Requires security review"));
    }

    // ---- prompt caching ---------------------------------------------------

    #[tokio::test]
    async fn anthropic_prompt_caching_moves_system_prompt_to_messages() {
        let middleware = AnthropicPromptCachingMiddleware::new("5m", "ignore");
        let mut request = ModelRequest::new(
            "This is the system prompt",
            vec![text_message(MessageRole::User, "Hello")],
        );
        let mut ctx = MiddlewareContext::with_request(&mut request, shared_state());

        middleware.modify_model_request(&mut ctx).await.unwrap();

        assert!(ctx.request.system_prompt.is_empty());
        assert_eq!(ctx.request.messages.len(), 2);

        let system_message = &ctx.request.messages[0];
        assert!(matches!(system_message.role, MessageRole::System));
        assert_eq!(
            system_message.content.as_text().unwrap(),
            "This is the system prompt"
        );

        let metadata = system_message.metadata.as_ref().unwrap();
        let cache_control = metadata.cache_control.as_ref().unwrap();
        assert_eq!(cache_control.cache_type, "ephemeral");

        let user_message = &ctx.request.messages[1];
        assert!(matches!(user_message.role, MessageRole::User));
        assert_eq!(user_message.content.as_text().unwrap(), "Hello");
    }

    #[tokio::test]
    async fn anthropic_prompt_caching_disabled_with_zero_ttl() {
        let middleware = AnthropicPromptCachingMiddleware::new("0", "ignore");
        let mut request = ModelRequest::new("This is the system prompt", vec![]);
        let mut ctx = MiddlewareContext::with_request(&mut request, shared_state());

        middleware.modify_model_request(&mut ctx).await.unwrap();

        assert_eq!(ctx.request.system_prompt, "This is the system prompt");
        assert_eq!(ctx.request.messages.len(), 0);
    }

    #[tokio::test]
    async fn anthropic_prompt_caching_no_op_with_empty_system_prompt() {
        let middleware = AnthropicPromptCachingMiddleware::new("5m", "ignore");
        let mut request = ModelRequest::new("", vec![text_message(MessageRole::User, "Hello")]);
        let mut ctx = MiddlewareContext::with_request(&mut request, shared_state());

        middleware.modify_model_request(&mut ctx).await.unwrap();

        assert!(ctx.request.system_prompt.is_empty());
        assert_eq!(ctx.request.messages.len(), 1);
    }

    // ---- human-in-the-loop interrupts --------------------------------------

    #[tokio::test]
    async fn hitl_creates_interrupt_for_disallowed_tool() {
        let middleware = hitl_with("dangerous_tool", false, Some("Requires security review"));
        let tool_args = json!({"action": "delete_all"});

        let result = middleware
            .before_tool_execution("dangerous_tool", &tool_args, "call_123")
            .await
            .unwrap();

        assert!(result.is_some());
        match result.unwrap() {
            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
                assert_eq!(hitl.tool_name, "dangerous_tool");
                assert_eq!(hitl.tool_args, tool_args);
                assert_eq!(hitl.call_id, "call_123");
                assert_eq!(
                    hitl.policy_note,
                    Some("Requires security review".to_string())
                );
            }
        }
    }

    #[tokio::test]
    async fn hitl_no_interrupt_for_allowed_tool() {
        let middleware = hitl_with("safe_tool", true, None);
        let tool_args = json!({"action": "read"});

        let result = middleware
            .before_tool_execution("safe_tool", &tool_args, "call_456")
            .await
            .unwrap();

        assert!(result.is_none());
    }

    #[tokio::test]
    async fn hitl_no_interrupt_for_unlisted_tool() {
        let middleware = HumanInLoopMiddleware::new(HashMap::new());
        let tool_args = json!({"action": "anything"});

        let result = middleware
            .before_tool_execution("unlisted_tool", &tool_args, "call_789")
            .await
            .unwrap();

        assert!(result.is_none());
    }

    #[tokio::test]
    async fn hitl_interrupt_includes_correct_details() {
        let middleware = hitl_with(
            "critical_tool",
            false,
            Some("Critical operation - requires approval"),
        );
        let tool_args = json!({
            "database": "production",
            "operation": "drop_table"
        });

        let result = middleware
            .before_tool_execution("critical_tool", &tool_args, "call_critical_1")
            .await
            .unwrap();

        assert!(result.is_some());
        match result.unwrap() {
            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
                assert_eq!(hitl.tool_name, "critical_tool");
                assert_eq!(hitl.tool_args["database"], "production");
                assert_eq!(hitl.tool_args["operation"], "drop_table");
                assert_eq!(hitl.call_id, "call_critical_1");
                assert!(hitl.policy_note.is_some());
                assert!(hitl.policy_note.unwrap().contains("Critical operation"));
            }
        }
    }

    #[tokio::test]
    async fn hitl_interrupt_without_policy_note() {
        let middleware = hitl_with("tool_no_note", false, None);
        let tool_args = json!({"param": "value"});

        let result = middleware
            .before_tool_execution("tool_no_note", &tool_args, "call_no_note")
            .await
            .unwrap();

        assert!(result.is_some());
        match result.unwrap() {
            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
                assert_eq!(hitl.tool_name, "tool_no_note");
                assert_eq!(hitl.policy_note, None);
            }
        }
    }
}