1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3
4use agents_core::agent::AgentHandle;
5use agents_core::messaging::{
6 AgentMessage, CacheControl, MessageContent, MessageMetadata, MessageRole,
7};
8use agents_core::prompts::{
9 BASE_AGENT_PROMPT, FILESYSTEM_SYSTEM_PROMPT, TASK_SYSTEM_PROMPT, TASK_TOOL_DESCRIPTION,
10 WRITE_TODOS_SYSTEM_PROMPT,
11};
12use agents_core::state::AgentStateSnapshot;
13use agents_core::tools::{Tool, ToolBox, ToolContext, ToolResult};
14use agents_toolkit::create_filesystem_tools;
15use async_trait::async_trait;
16use serde::Deserialize;
17
18pub mod token_tracking;
19
20#[derive(Debug, Clone)]
23pub struct ModelRequest {
24 pub system_prompt: String,
25 pub messages: Vec<AgentMessage>,
26}
27
28impl ModelRequest {
29 pub fn new(system_prompt: impl Into<String>, messages: Vec<AgentMessage>) -> Self {
30 Self {
31 system_prompt: system_prompt.into(),
32 messages,
33 }
34 }
35
36 pub fn append_prompt(&mut self, fragment: &str) {
37 if !fragment.is_empty() {
38 self.system_prompt.push_str("\n\n");
39 self.system_prompt.push_str(fragment);
40 }
41 }
42}
43
44pub struct MiddlewareContext<'a> {
46 pub request: &'a mut ModelRequest,
47 pub state: Arc<RwLock<AgentStateSnapshot>>,
48}
49
50impl<'a> MiddlewareContext<'a> {
51 pub fn with_request(
52 request: &'a mut ModelRequest,
53 state: Arc<RwLock<AgentStateSnapshot>>,
54 ) -> Self {
55 Self { request, state }
56 }
57}
58
59#[async_trait]
63pub trait AgentMiddleware: Send + Sync {
64 fn id(&self) -> &'static str;
66
67 fn tools(&self) -> Vec<ToolBox> {
69 Vec::new()
70 }
71
72 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()>;
74
75 async fn before_tool_execution(
91 &self,
92 _tool_name: &str,
93 _tool_args: &serde_json::Value,
94 _call_id: &str,
95 ) -> anyhow::Result<Option<agents_core::hitl::AgentInterrupt>> {
96 Ok(None)
97 }
98}
99
/// Configuration for trimming a conversation down to its most recent messages.
pub struct SummarizationMiddleware {
    /// Number of most recent messages to retain.
    pub messages_to_keep: usize,
    /// Text used as the lead-in of the note that replaces trimmed messages.
    pub summary_note: String,
}

impl SummarizationMiddleware {
    /// Creates the middleware from a retention count and a note text.
    pub fn new(messages_to_keep: usize, summary_note: impl Into<String>) -> Self {
        let summary_note = summary_note.into();
        SummarizationMiddleware {
            messages_to_keep,
            summary_note,
        }
    }
}
113
114#[async_trait]
115impl AgentMiddleware for SummarizationMiddleware {
116 fn id(&self) -> &'static str {
117 "summarization"
118 }
119
120 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
121 if ctx.request.messages.len() > self.messages_to_keep {
122 let dropped = ctx.request.messages.len() - self.messages_to_keep;
123 let mut truncated = ctx
124 .request
125 .messages
126 .split_off(ctx.request.messages.len() - self.messages_to_keep);
127 truncated.insert(
128 0,
129 AgentMessage {
130 role: MessageRole::System,
131 content: MessageContent::Text(format!(
132 "{} ({} earlier messages summarized)",
133 self.summary_note, dropped
134 )),
135 metadata: None,
136 },
137 );
138 ctx.request.messages = truncated;
139 }
140 Ok(())
141 }
142}
143
144pub struct PlanningMiddleware {
145 _state: Arc<RwLock<AgentStateSnapshot>>,
146}
147
148impl PlanningMiddleware {
149 pub fn new(state: Arc<RwLock<AgentStateSnapshot>>) -> Self {
150 Self { _state: state }
151 }
152}
153
154#[async_trait]
155impl AgentMiddleware for PlanningMiddleware {
156 fn id(&self) -> &'static str {
157 "planning"
158 }
159
160 fn tools(&self) -> Vec<ToolBox> {
161 use agents_toolkit::create_todos_tools;
162 create_todos_tools()
163 }
164
165 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
166 ctx.request.append_prompt(WRITE_TODOS_SYSTEM_PROMPT);
167 Ok(())
168 }
169}
170
171pub struct FilesystemMiddleware {
172 _state: Arc<RwLock<AgentStateSnapshot>>,
173}
174
175impl FilesystemMiddleware {
176 pub fn new(state: Arc<RwLock<AgentStateSnapshot>>) -> Self {
177 Self { _state: state }
178 }
179}
180
181#[async_trait]
182impl AgentMiddleware for FilesystemMiddleware {
183 fn id(&self) -> &'static str {
184 "filesystem"
185 }
186
187 fn tools(&self) -> Vec<ToolBox> {
188 create_filesystem_tools()
189 }
190
191 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
192 ctx.request.append_prompt(FILESYSTEM_SYSTEM_PROMPT);
193 Ok(())
194 }
195}
196
197#[derive(Clone)]
198pub struct SubAgentRegistration {
199 pub descriptor: SubAgentDescriptor,
200 pub agent: Arc<dyn AgentHandle>,
201}
202
203struct SubAgentRegistry {
204 agents: HashMap<String, Arc<dyn AgentHandle>>,
205}
206
207impl SubAgentRegistry {
208 fn new(registrations: Vec<SubAgentRegistration>) -> Self {
209 let mut agents = HashMap::new();
210 for reg in registrations {
211 agents.insert(reg.descriptor.name.clone(), reg.agent.clone());
212 }
213 Self { agents }
214 }
215
216 fn available_names(&self) -> Vec<String> {
217 self.agents.keys().cloned().collect()
218 }
219
220 fn get(&self, name: &str) -> Option<Arc<dyn AgentHandle>> {
221 self.agents.get(name).cloned()
222 }
223}
224
225pub struct SubAgentMiddleware {
226 task_tool: ToolBox,
227 descriptors: Vec<SubAgentDescriptor>,
228 _registry: Arc<SubAgentRegistry>,
229}
230
231impl SubAgentMiddleware {
232 pub fn new(registrations: Vec<SubAgentRegistration>) -> Self {
233 let descriptors = registrations.iter().map(|r| r.descriptor.clone()).collect();
234 let registry = Arc::new(SubAgentRegistry::new(registrations));
235 let task_tool: ToolBox = Arc::new(TaskRouterTool::new(registry.clone(), None));
236 Self {
237 task_tool,
238 descriptors,
239 _registry: registry,
240 }
241 }
242
243 pub fn new_with_events(
244 registrations: Vec<SubAgentRegistration>,
245 event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
246 ) -> Self {
247 let descriptors = registrations.iter().map(|r| r.descriptor.clone()).collect();
248 let registry = Arc::new(SubAgentRegistry::new(registrations));
249 let task_tool: ToolBox = Arc::new(TaskRouterTool::new(registry.clone(), event_dispatcher));
250 Self {
251 task_tool,
252 descriptors,
253 _registry: registry,
254 }
255 }
256
257 fn prompt_fragment(&self) -> String {
258 let descriptions: Vec<String> = if self.descriptors.is_empty() {
259 vec![String::from("- general-purpose: Default reasoning agent")]
260 } else {
261 self.descriptors
262 .iter()
263 .map(|agent| format!("- {}: {}", agent.name, agent.description))
264 .collect()
265 };
266
267 TASK_TOOL_DESCRIPTION.replace("{other_agents}", &descriptions.join("\n"))
268 }
269}
270
271#[async_trait]
272impl AgentMiddleware for SubAgentMiddleware {
273 fn id(&self) -> &'static str {
274 "subagent"
275 }
276
277 fn tools(&self) -> Vec<ToolBox> {
278 vec![self.task_tool.clone()]
279 }
280
281 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
282 ctx.request.append_prompt(TASK_SYSTEM_PROMPT);
283 ctx.request.append_prompt(&self.prompt_fragment());
284 Ok(())
285 }
286}
287
/// Approval policy attached to a single tool.
#[derive(Clone, Debug)]
pub struct HitlPolicy {
    /// When `true`, the tool does not require approval.
    pub allow_auto: bool,
    /// Optional explanation shown alongside the approval requirement.
    pub note: Option<String>,
}

/// Middleware that gates selected tools behind human approval.
pub struct HumanInLoopMiddleware {
    /// Policies keyed by tool name.
    policies: HashMap<String, HitlPolicy>,
}

impl HumanInLoopMiddleware {
    /// Creates the middleware from per-tool policies.
    pub fn new(policies: HashMap<String, HitlPolicy>) -> Self {
        Self { policies }
    }

    /// Returns the policy for `tool_name` when that tool needs approval.
    ///
    /// Yields `None` for tools without a policy and for policies whose
    /// `allow_auto` flag is set.
    pub fn requires_approval(&self, tool_name: &str) -> Option<&HitlPolicy> {
        self.policies
            .get(tool_name)
            .filter(|policy| !policy.allow_auto)
    }

    /// Builds the prompt section listing every tool that requires approval.
    ///
    /// Returns `None` when no tool requires approval. Entries are sorted by
    /// tool name so the text does not depend on `HashMap` iteration order and
    /// stays identical between runs.
    fn prompt_fragment(&self) -> Option<String> {
        let mut gated: Vec<(&str, &HitlPolicy)> = self
            .policies
            .iter()
            .filter(|(_, policy)| !policy.allow_auto)
            .map(|(tool, policy)| (tool.as_str(), policy))
            .collect();
        if gated.is_empty() {
            return None;
        }
        // Map keys are unique, so an unstable sort cannot reorder equal elements.
        gated.sort_unstable_by(|left, right| left.0.cmp(right.0));

        let pending: Vec<String> = gated
            .into_iter()
            .map(|(tool, policy)| match &policy.note {
                Some(note) => format!("- {tool}: {note}"),
                None => format!("- {tool}: Requires approval"),
            })
            .collect();

        Some(format!(
            "The following tools require human approval before execution:\n{}",
            pending.join("\n")
        ))
    }
}
329
330#[async_trait]
331impl AgentMiddleware for HumanInLoopMiddleware {
332 fn id(&self) -> &'static str {
333 "human-in-loop"
334 }
335
336 async fn before_tool_execution(
337 &self,
338 tool_name: &str,
339 tool_args: &serde_json::Value,
340 call_id: &str,
341 ) -> anyhow::Result<Option<agents_core::hitl::AgentInterrupt>> {
342 if let Some(policy) = self.requires_approval(tool_name) {
343 tracing::warn!(
344 tool_name = %tool_name,
345 call_id = %call_id,
346 policy_note = ?policy.note,
347 "🔒 HITL: Tool execution requires human approval"
348 );
349
350 let interrupt = agents_core::hitl::HitlInterrupt::new(
351 tool_name,
352 tool_args.clone(),
353 call_id,
354 policy.note.clone(),
355 );
356
357 return Ok(Some(agents_core::hitl::AgentInterrupt::HumanInLoop(
358 interrupt,
359 )));
360 }
361
362 Ok(None)
363 }
364
365 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
366 if let Some(fragment) = self.prompt_fragment() {
367 ctx.request.append_prompt(&fragment);
368 }
369 ctx.request.messages.push(AgentMessage {
370 role: MessageRole::System,
371 content: MessageContent::Text(
372 "Tools marked for human approval will emit interrupts requiring external resolution."
373 .into(),
374 ),
375 metadata: None,
376 });
377 Ok(())
378 }
379}
380
381pub struct BaseSystemPromptMiddleware;
382
383#[async_trait]
384impl AgentMiddleware for BaseSystemPromptMiddleware {
385 fn id(&self) -> &'static str {
386 "base-system-prompt"
387 }
388
389 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
390 ctx.request.append_prompt(BASE_AGENT_PROMPT);
391 Ok(())
392 }
393}
394
/// Middleware that supplies either a generated deep-agent prompt or a
/// caller-provided replacement.
pub struct DeepAgentPromptMiddleware {
    /// Instructions passed to the prompt generator when no override is set.
    custom_instructions: String,
    /// When present, used instead of the generated prompt.
    override_system_prompt: Option<String>,
}

impl DeepAgentPromptMiddleware {
    /// Creates a middleware that generates its prompt from `custom_instructions`.
    pub fn new(custom_instructions: impl Into<String>) -> Self {
        let custom_instructions = custom_instructions.into();
        DeepAgentPromptMiddleware {
            custom_instructions,
            override_system_prompt: None,
        }
    }

    /// Creates a middleware that uses `system_prompt` in place of the generated prompt.
    ///
    /// The custom instructions are left empty.
    pub fn with_override(system_prompt: impl Into<String>) -> Self {
        let replacement: String = system_prompt.into();
        DeepAgentPromptMiddleware {
            custom_instructions: String::new(),
            override_system_prompt: Some(replacement),
        }
    }
}
433
434#[async_trait]
435impl AgentMiddleware for DeepAgentPromptMiddleware {
436 fn id(&self) -> &'static str {
437 "deep-agent-prompt"
438 }
439
440 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
441 let prompt = if let Some(ref override_prompt) = self.override_system_prompt {
442 override_prompt.clone()
444 } else {
445 use crate::prompts::get_deep_agent_system_prompt;
447 get_deep_agent_system_prompt(&self.custom_instructions)
448 };
449 ctx.request.append_prompt(&prompt);
450 Ok(())
451 }
452}
453
/// Settings for marking the system prompt as cacheable.
pub struct AnthropicPromptCachingMiddleware {
    /// Cache time-to-live; an empty value, `"0"` or `"0s"` disables caching.
    pub ttl: String,
    /// Behaviour label for unsupported models. In this file it is only logged.
    pub unsupported_model_behavior: String,
}

impl AnthropicPromptCachingMiddleware {
    /// Creates the middleware from a TTL and an unsupported-model behaviour label.
    pub fn new(ttl: impl Into<String>, unsupported_model_behavior: impl Into<String>) -> Self {
        let ttl = ttl.into();
        let unsupported_model_behavior = unsupported_model_behavior.into();
        AnthropicPromptCachingMiddleware {
            ttl,
            unsupported_model_behavior,
        }
    }

    /// Creates the middleware with a `"5m"` TTL and the `"ignore"` behaviour.
    pub fn with_defaults() -> Self {
        Self::new("5m", "ignore")
    }

    /// Reports whether the configured TTL enables caching.
    fn should_enable_caching(&self) -> bool {
        !matches!(self.ttl.as_str(), "" | "0" | "0s")
    }
}
479
480#[async_trait]
481impl AgentMiddleware for AnthropicPromptCachingMiddleware {
482 fn id(&self) -> &'static str {
483 "anthropic-prompt-caching"
484 }
485
486 async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
487 if !self.should_enable_caching() {
488 return Ok(());
489 }
490
491 if !ctx.request.system_prompt.is_empty() {
493 let system_message = AgentMessage {
494 role: MessageRole::System,
495 content: MessageContent::Text(ctx.request.system_prompt.clone()),
496 metadata: Some(MessageMetadata {
497 tool_call_id: None,
498 cache_control: Some(CacheControl {
499 cache_type: "ephemeral".to_string(),
500 }),
501 }),
502 };
503
504 ctx.request.messages.insert(0, system_message);
506
507 ctx.request.system_prompt.clear();
509
510 tracing::debug!(
511 ttl = %self.ttl,
512 behavior = %self.unsupported_model_behavior,
513 "Applied Anthropic prompt caching to system message"
514 );
515 }
516
517 Ok(())
518 }
519}
520
521pub struct TaskRouterTool {
522 registry: Arc<SubAgentRegistry>,
523 event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
524 delegation_depth: Arc<RwLock<u32>>,
525}
526
527impl TaskRouterTool {
528 fn new(
529 registry: Arc<SubAgentRegistry>,
530 event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
531 ) -> Self {
532 Self {
533 registry,
534 event_dispatcher,
535 delegation_depth: Arc::new(RwLock::new(0)),
536 }
537 }
538
539 fn available_subagents(&self) -> Vec<String> {
540 self.registry.available_names()
541 }
542
543 fn emit_event(&self, event: agents_core::events::AgentEvent) {
544 if let Some(dispatcher) = &self.event_dispatcher {
545 let dispatcher_clone = dispatcher.clone();
546 tokio::spawn(async move {
547 dispatcher_clone.dispatch(event).await;
548 });
549 }
550 }
551
552 fn create_event_metadata(&self) -> agents_core::events::EventMetadata {
553 agents_core::events::EventMetadata::new(
554 "default".to_string(),
555 uuid::Uuid::new_v4().to_string(),
556 None,
557 )
558 }
559
560 fn get_delegation_depth(&self) -> u32 {
561 *self.delegation_depth.read().unwrap_or_else(|_| {
562 tracing::warn!("Failed to read delegation depth, defaulting to 0");
563 panic!("RwLock poisoned")
564 })
565 }
566
567 fn increment_delegation_depth(&self) {
568 if let Ok(mut depth) = self.delegation_depth.write() {
569 *depth += 1;
570 }
571 }
572
573 fn decrement_delegation_depth(&self) {
574 if let Ok(mut depth) = self.delegation_depth.write() {
575 if *depth > 0 {
576 *depth -= 1;
577 }
578 }
579 }
580}
581
582#[derive(Debug, Clone, Deserialize)]
583struct TaskInvocationArgs {
584 #[serde(alias = "description")]
585 instruction: String,
586 #[serde(alias = "subagent_type")]
587 agent: String,
588}
589
590#[async_trait]
591impl Tool for TaskRouterTool {
592 fn schema(&self) -> agents_core::tools::ToolSchema {
593 use agents_core::tools::{ToolParameterSchema, ToolSchema};
594 use std::collections::HashMap;
595
596 let mut properties = HashMap::new();
597 properties.insert(
598 "agent".to_string(),
599 ToolParameterSchema::string("Name of the sub-agent to delegate to"),
600 );
601 properties.insert(
602 "instruction".to_string(),
603 ToolParameterSchema::string("Clear instruction for the sub-agent"),
604 );
605
606 ToolSchema::new(
607 "task",
608 "Delegate a task to a specialized sub-agent. Use this when you need specialized expertise or want to break down complex tasks.",
609 ToolParameterSchema::object(
610 "Task delegation parameters",
611 properties,
612 vec!["agent".to_string(), "instruction".to_string()],
613 ),
614 )
615 }
616
617 async fn execute(
618 &self,
619 args: serde_json::Value,
620 ctx: ToolContext,
621 ) -> anyhow::Result<ToolResult> {
622 let args: TaskInvocationArgs = serde_json::from_value(args)?;
623 let available = self.available_subagents();
624
625 if let Some(agent) = self.registry.get(&args.agent) {
626 self.increment_delegation_depth();
628 let current_depth = self.get_delegation_depth();
629
630 let instruction_summary = if args.instruction.len() > 100 {
632 format!("{}...", &args.instruction[..100])
633 } else {
634 args.instruction.clone()
635 };
636
637 self.emit_event(agents_core::events::AgentEvent::SubAgentStarted(
639 agents_core::events::SubAgentStartedEvent {
640 metadata: self.create_event_metadata(),
641 agent_name: args.agent.clone(),
642 instruction_summary: instruction_summary.clone(),
643 delegation_depth: current_depth,
644 },
645 ));
646
647 tracing::warn!(
649 "🎯 DELEGATING to sub-agent: {} (depth: {}) with instruction: {}",
650 args.agent,
651 current_depth,
652 args.instruction
653 );
654
655 let start_time = std::time::Instant::now();
656 let user_message = AgentMessage {
657 role: MessageRole::User,
658 content: MessageContent::Text(args.instruction.clone()),
659 metadata: None,
660 };
661
662 let response = agent
663 .handle_message(user_message, ctx.state.clone())
664 .await?;
665
666 let duration = start_time.elapsed();
668 let duration_ms = duration.as_millis() as u64;
669
670 let response_preview = match &response.content {
672 MessageContent::Text(t) => {
673 if t.len() > 100 {
674 format!("{}...", &t[..100])
675 } else {
676 t.clone()
677 }
678 }
679 MessageContent::Json(v) => {
680 let json_str = v.to_string();
681 if json_str.len() > 100 {
682 format!("{}...", &json_str[..100])
683 } else {
684 json_str
685 }
686 }
687 };
688
689 self.emit_event(agents_core::events::AgentEvent::SubAgentCompleted(
691 agents_core::events::SubAgentCompletedEvent {
692 metadata: self.create_event_metadata(),
693 agent_name: args.agent.clone(),
694 duration_ms,
695 result_summary: response_preview.clone(),
696 },
697 ));
698
699 tracing::warn!(
701 "✅ SUB-AGENT {} COMPLETED in {:?} - Response: {}",
702 args.agent,
703 duration,
704 response_preview
705 );
706
707 self.decrement_delegation_depth();
709
710 let result_text = match response.content {
713 MessageContent::Text(text) => text,
714 MessageContent::Json(json) => json.to_string(),
715 };
716
717 return Ok(ToolResult::text(&ctx, result_text));
718 }
719
720 tracing::error!(
721 "❌ SUB-AGENT NOT FOUND: {} - Available: {:?}",
722 args.agent,
723 available
724 );
725
726 Ok(ToolResult::text(
727 &ctx,
728 format!(
729 "Sub-agent '{}' not found. Available sub-agents: {}",
730 args.agent,
731 available.join(", ")
732 ),
733 ))
734 }
735}
736
/// Name and human-readable description of a sub-agent.
#[derive(Clone, Debug)]
pub struct SubAgentDescriptor {
    /// Name under which the sub-agent is registered and listed.
    pub name: String,
    /// Short description shown next to the name in the prompt listing.
    pub description: String,
}
742
#[cfg(test)]
mod tests {
    use super::*;
    use agents_core::agent::{AgentDescriptor, AgentHandle};
    use agents_core::messaging::{MessageContent, MessageRole};
    use serde_json::json;

    // ---- shared fixtures -------------------------------------------------

    /// Default agent state wrapped the way `MiddlewareContext` expects it.
    fn shared_state() -> Arc<RwLock<AgentStateSnapshot>> {
        Arc::new(RwLock::new(AgentStateSnapshot::default()))
    }

    /// Plain text message with no metadata.
    fn message(role: MessageRole, text: &str) -> AgentMessage {
        AgentMessage {
            role,
            content: MessageContent::Text(text.into()),
            metadata: None,
        }
    }

    /// Schema names of every tool a middleware contributes.
    fn tool_names<M: AgentMiddleware>(middleware: &M) -> Vec<String> {
        middleware
            .tools()
            .iter()
            .map(|tool| tool.schema().name.clone())
            .collect()
    }

    /// Unwraps text content, panicking on any other content kind.
    fn expect_text(content: &MessageContent) -> &str {
        match content {
            MessageContent::Text(text) => text.as_str(),
            other => panic!("expected text, got {other:?}"),
        }
    }

    /// Human-in-the-loop middleware configured with exactly one policy.
    fn single_policy(tool: &str, allow_auto: bool, note: Option<&str>) -> HumanInLoopMiddleware {
        let policy = HitlPolicy {
            allow_auto,
            note: note.map(String::from),
        };
        HumanInLoopMiddleware::new(HashMap::from([(tool.to_string(), policy)]))
    }

    /// Runs the pre-execution hook and unwraps the outer `Result`.
    async fn gate(
        middleware: &HumanInLoopMiddleware,
        tool: &str,
        args: &serde_json::Value,
        call_id: &str,
    ) -> Option<agents_core::hitl::AgentInterrupt> {
        middleware
            .before_tool_execution(tool, args, call_id)
            .await
            .unwrap()
    }

    /// Asserts that an interrupt was produced and returns it.
    fn expect_interrupt(
        result: Option<agents_core::hitl::AgentInterrupt>,
    ) -> agents_core::hitl::AgentInterrupt {
        assert!(result.is_some());
        result.unwrap()
    }

    // ---- generic middleware behaviour --------------------------------------

    struct AppendPromptMiddleware;

    #[async_trait]
    impl AgentMiddleware for AppendPromptMiddleware {
        fn id(&self) -> &'static str {
            "append-prompt"
        }

        async fn modify_model_request(
            &self,
            ctx: &mut MiddlewareContext<'_>,
        ) -> anyhow::Result<()> {
            ctx.request.system_prompt.push_str("\nExtra directives.");
            Ok(())
        }
    }

    #[tokio::test]
    async fn middleware_mutates_prompt() {
        let mut request = ModelRequest::new("System", vec![message(MessageRole::User, "Hi")]);
        let mut ctx = MiddlewareContext::with_request(&mut request, shared_state());
        let middleware = AppendPromptMiddleware;
        middleware.modify_model_request(&mut ctx).await.unwrap();
        assert!(ctx.request.system_prompt.contains("Extra directives"));
    }

    #[tokio::test]
    async fn planning_middleware_registers_write_todos() {
        let middleware = PlanningMiddleware::new(shared_state());
        assert!(tool_names(&middleware).contains(&"write_todos".to_string()));

        let mut request = ModelRequest::new("System", vec![]);
        let mut ctx = MiddlewareContext::with_request(&mut request, shared_state());
        middleware.modify_model_request(&mut ctx).await.unwrap();
        assert!(ctx.request.system_prompt.contains("write_todos"));
    }

    #[tokio::test]
    async fn filesystem_middleware_registers_tools() {
        let middleware = FilesystemMiddleware::new(shared_state());
        let names = tool_names(&middleware);
        for expected in ["ls", "read_file", "write_file", "edit_file"] {
            assert!(names.contains(&expected.to_string()));
        }
    }

    #[tokio::test]
    async fn summarization_middleware_trims_messages() {
        let middleware = SummarizationMiddleware::new(2, "Summary note");
        let mut request = ModelRequest::new(
            "System",
            vec![
                message(MessageRole::User, "one"),
                message(MessageRole::Agent, "two"),
                message(MessageRole::User, "three"),
            ],
        );
        let mut ctx = MiddlewareContext::with_request(&mut request, shared_state());
        middleware.modify_model_request(&mut ctx).await.unwrap();

        // Two kept messages plus the inserted summary note.
        assert_eq!(ctx.request.messages.len(), 3);
        assert!(expect_text(&ctx.request.messages[0].content).contains("Summary note"));
    }

    // ---- sub-agent routing -------------------------------------------------

    struct StubAgent;

    #[async_trait]
    impl AgentHandle for StubAgent {
        async fn describe(&self) -> AgentDescriptor {
            AgentDescriptor {
                name: "stub".into(),
                version: "0.0.1".into(),
                description: None,
            }
        }

        async fn handle_message(
            &self,
            _input: AgentMessage,
            _state: Arc<AgentStateSnapshot>,
        ) -> anyhow::Result<AgentMessage> {
            Ok(message(MessageRole::Agent, "stub-response"))
        }
    }

    #[tokio::test]
    async fn task_router_reports_unknown_subagent() {
        let registry = Arc::new(SubAgentRegistry::new(vec![]));
        let task_tool = TaskRouterTool::new(registry.clone(), None);
        let ctx = ToolContext::new(Arc::new(AgentStateSnapshot::default()));

        let response = task_tool
            .execute(
                json!({
                    "instruction": "Do something",
                    "agent": "unknown"
                }),
                ctx,
            )
            .await
            .unwrap();

        match response {
            ToolResult::Message(msg) => {
                assert!(expect_text(&msg.content).contains("Sub-agent 'unknown' not found"))
            }
            _ => panic!("expected message"),
        }
    }

    #[tokio::test]
    async fn subagent_middleware_appends_prompt() {
        let middleware = SubAgentMiddleware::new(vec![SubAgentRegistration {
            descriptor: SubAgentDescriptor {
                name: "research-agent".into(),
                description: "Deep research specialist".into(),
            },
            agent: Arc::new(StubAgent),
        }]);

        let mut request = ModelRequest::new("System", vec![]);
        let mut ctx = MiddlewareContext::with_request(&mut request, shared_state());
        middleware.modify_model_request(&mut ctx).await.unwrap();

        assert!(ctx.request.system_prompt.contains("research-agent"));
        assert!(tool_names(&middleware).contains(&"task".to_string()));
    }

    #[tokio::test]
    async fn task_router_invokes_registered_subagent() {
        let registry = Arc::new(SubAgentRegistry::new(vec![SubAgentRegistration {
            descriptor: SubAgentDescriptor {
                name: "stub-agent".into(),
                description: "Stub".into(),
            },
            agent: Arc::new(StubAgent),
        }]));
        let task_tool = TaskRouterTool::new(registry.clone(), None);
        let ctx = ToolContext::new(Arc::new(AgentStateSnapshot::default()))
            .with_call_id(Some("call-42".into()));

        // Uses the aliased argument names on purpose.
        let response = task_tool
            .execute(
                json!({
                    "description": "do work",
                    "subagent_type": "stub-agent"
                }),
                ctx,
            )
            .await
            .unwrap();

        match response {
            ToolResult::Message(msg) => {
                assert_eq!(msg.metadata.unwrap().tool_call_id.unwrap(), "call-42");
                assert_eq!(expect_text(&msg.content), "stub-response");
            }
            _ => panic!("expected message"),
        }
    }

    // ---- human-in-the-loop prompt -----------------------------------------

    #[tokio::test]
    async fn human_in_loop_appends_prompt() {
        let middleware = single_policy("danger-tool", false, Some("Requires security review"));
        let mut request = ModelRequest::new("System", vec![]);
        let mut ctx = MiddlewareContext::with_request(&mut request, shared_state());
        middleware.modify_model_request(&mut ctx).await.unwrap();
        assert!(ctx
            .request
            .system_prompt
            .contains("danger-tool: Requires security review"));
    }

    // ---- prompt caching ----------------------------------------------------

    #[tokio::test]
    async fn anthropic_prompt_caching_moves_system_prompt_to_messages() {
        let middleware = AnthropicPromptCachingMiddleware::new("5m", "ignore");
        let mut request = ModelRequest::new(
            "This is the system prompt",
            vec![message(MessageRole::User, "Hello")],
        );
        let mut ctx = MiddlewareContext::with_request(&mut request, shared_state());

        middleware.modify_model_request(&mut ctx).await.unwrap();

        assert!(ctx.request.system_prompt.is_empty());
        assert_eq!(ctx.request.messages.len(), 2);

        let system_message = &ctx.request.messages[0];
        assert!(matches!(system_message.role, MessageRole::System));
        assert_eq!(
            system_message.content.as_text().unwrap(),
            "This is the system prompt"
        );

        let metadata = system_message.metadata.as_ref().unwrap();
        let cache_control = metadata.cache_control.as_ref().unwrap();
        assert_eq!(cache_control.cache_type, "ephemeral");

        let user_message = &ctx.request.messages[1];
        assert!(matches!(user_message.role, MessageRole::User));
        assert_eq!(user_message.content.as_text().unwrap(), "Hello");
    }

    #[tokio::test]
    async fn anthropic_prompt_caching_disabled_with_zero_ttl() {
        let middleware = AnthropicPromptCachingMiddleware::new("0", "ignore");
        let mut request = ModelRequest::new("This is the system prompt", vec![]);
        let mut ctx = MiddlewareContext::with_request(&mut request, shared_state());

        middleware.modify_model_request(&mut ctx).await.unwrap();

        assert_eq!(ctx.request.system_prompt, "This is the system prompt");
        assert_eq!(ctx.request.messages.len(), 0);
    }

    #[tokio::test]
    async fn anthropic_prompt_caching_no_op_with_empty_system_prompt() {
        let middleware = AnthropicPromptCachingMiddleware::new("5m", "ignore");
        let mut request = ModelRequest::new("", vec![message(MessageRole::User, "Hello")]);
        let mut ctx = MiddlewareContext::with_request(&mut request, shared_state());

        middleware.modify_model_request(&mut ctx).await.unwrap();

        assert!(ctx.request.system_prompt.is_empty());
        assert_eq!(ctx.request.messages.len(), 1);
    }

    // ---- human-in-the-loop interrupts ---------------------------------------

    #[tokio::test]
    async fn hitl_creates_interrupt_for_disallowed_tool() {
        let middleware = single_policy("dangerous_tool", false, Some("Requires security review"));
        let tool_args = json!({"action": "delete_all"});

        let result = gate(&middleware, "dangerous_tool", &tool_args, "call_123").await;

        match expect_interrupt(result) {
            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
                assert_eq!(hitl.tool_name, "dangerous_tool");
                assert_eq!(hitl.tool_args, tool_args);
                assert_eq!(hitl.call_id, "call_123");
                assert_eq!(
                    hitl.policy_note,
                    Some("Requires security review".to_string())
                );
            }
        }
    }

    #[tokio::test]
    async fn hitl_no_interrupt_for_allowed_tool() {
        let middleware = single_policy("safe_tool", true, None);
        let tool_args = json!({"action": "read"});

        let result = gate(&middleware, "safe_tool", &tool_args, "call_456").await;

        assert!(result.is_none());
    }

    #[tokio::test]
    async fn hitl_no_interrupt_for_unlisted_tool() {
        let middleware = HumanInLoopMiddleware::new(HashMap::new());
        let tool_args = json!({"action": "anything"});

        let result = gate(&middleware, "unlisted_tool", &tool_args, "call_789").await;

        assert!(result.is_none());
    }

    #[tokio::test]
    async fn hitl_interrupt_includes_correct_details() {
        let middleware = single_policy(
            "critical_tool",
            false,
            Some("Critical operation - requires approval"),
        );
        let tool_args = json!({
            "database": "production",
            "operation": "drop_table"
        });

        let result = gate(&middleware, "critical_tool", &tool_args, "call_critical_1").await;

        match expect_interrupt(result) {
            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
                assert_eq!(hitl.tool_name, "critical_tool");
                assert_eq!(hitl.tool_args["database"], "production");
                assert_eq!(hitl.tool_args["operation"], "drop_table");
                assert_eq!(hitl.call_id, "call_critical_1");
                assert!(hitl.policy_note.is_some());
                assert!(hitl.policy_note.unwrap().contains("Critical operation"));
            }
        }
    }

    #[tokio::test]
    async fn hitl_interrupt_without_policy_note() {
        let middleware = single_policy("tool_no_note", false, None);
        let tool_args = json!({"param": "value"});

        let result = gate(&middleware, "tool_no_note", &tool_args, "call_no_note").await;

        match expect_interrupt(result) {
            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
                assert_eq!(hitl.tool_name, "tool_no_note");
                assert_eq!(hitl.policy_note, None);
            }
        }
    }
}