agents_runtime/
middleware.rs

1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3
4use agents_core::agent::AgentHandle;
5use agents_core::messaging::{
6    AgentMessage, CacheControl, MessageContent, MessageMetadata, MessageRole,
7};
8use agents_core::prompts::{
9    BASE_AGENT_PROMPT, FILESYSTEM_SYSTEM_PROMPT, TASK_SYSTEM_PROMPT, TASK_TOOL_DESCRIPTION,
10    WRITE_TODOS_SYSTEM_PROMPT,
11};
12use agents_core::state::AgentStateSnapshot;
13use agents_core::tools::{Tool, ToolBox, ToolContext, ToolResult};
14use agents_toolkit::create_filesystem_tools;
15use async_trait::async_trait;
16use serde::Deserialize;
17
18pub mod token_tracking;
19
20/// Request sent to the underlying language model. Middlewares can augment
21/// the system prompt or mutate the pending message list before the model call.
22#[derive(Debug, Clone)]
23pub struct ModelRequest {
24    pub system_prompt: String,
25    pub messages: Vec<AgentMessage>,
26}
27
28impl ModelRequest {
29    pub fn new(system_prompt: impl Into<String>, messages: Vec<AgentMessage>) -> Self {
30        Self {
31            system_prompt: system_prompt.into(),
32            messages,
33        }
34    }
35
36    pub fn append_prompt(&mut self, fragment: &str) {
37        if !fragment.is_empty() {
38            self.system_prompt.push_str("\n\n");
39            self.system_prompt.push_str(fragment);
40        }
41    }
42}
43
44/// Read/write state handle exposed to middleware implementations.
45pub struct MiddlewareContext<'a> {
46    pub request: &'a mut ModelRequest,
47    pub state: Arc<RwLock<AgentStateSnapshot>>,
48}
49
50impl<'a> MiddlewareContext<'a> {
51    pub fn with_request(
52        request: &'a mut ModelRequest,
53        state: Arc<RwLock<AgentStateSnapshot>>,
54    ) -> Self {
55        Self { request, state }
56    }
57}
58
59/// Middleware hook that can register additional tools and mutate the model request
60/// prior to execution. Mirrors the Python AgentMiddleware contracts but keeps the
61/// interface async-first for future network calls.
62#[async_trait]
63pub trait AgentMiddleware: Send + Sync {
64    /// Unique identifier for logging and diagnostics.
65    fn id(&self) -> &'static str;
66
67    /// Tools to expose when this middleware is active.
68    fn tools(&self) -> Vec<ToolBox> {
69        Vec::new()
70    }
71
72    /// Apply middleware-specific mutations to the pending model request.
73    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()>;
74
75    /// Hook called before tool execution - can return an interrupt to pause execution.
76    ///
77    /// This hook is invoked for each tool call before it executes, allowing middleware
78    /// to intercept and pause execution for human review. If an interrupt is returned,
79    /// the agent will save its state and wait for human approval before continuing.
80    ///
81    /// # Arguments
82    /// * `tool_name` - Name of the tool about to be executed
83    /// * `tool_args` - Arguments that will be passed to the tool
84    /// * `call_id` - Unique identifier for this tool call
85    ///
86    /// # Returns
87    /// * `Ok(Some(interrupt))` - Pause execution and wait for human response
88    /// * `Ok(None)` - Continue with tool execution normally
89    /// * `Err(e)` - Error occurred during interrupt check
90    async fn before_tool_execution(
91        &self,
92        _tool_name: &str,
93        _tool_args: &serde_json::Value,
94        _call_id: &str,
95    ) -> anyhow::Result<Option<agents_core::hitl::AgentInterrupt>> {
96        Ok(None)
97    }
98}
99
100pub struct SummarizationMiddleware {
101    pub messages_to_keep: usize,
102    pub summary_note: String,
103}
104
105impl SummarizationMiddleware {
106    pub fn new(messages_to_keep: usize, summary_note: impl Into<String>) -> Self {
107        Self {
108            messages_to_keep,
109            summary_note: summary_note.into(),
110        }
111    }
112}
113
114#[async_trait]
115impl AgentMiddleware for SummarizationMiddleware {
116    fn id(&self) -> &'static str {
117        "summarization"
118    }
119
120    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
121        if ctx.request.messages.len() > self.messages_to_keep {
122            let dropped = ctx.request.messages.len() - self.messages_to_keep;
123            let mut truncated = ctx
124                .request
125                .messages
126                .split_off(ctx.request.messages.len() - self.messages_to_keep);
127            truncated.insert(
128                0,
129                AgentMessage {
130                    role: MessageRole::System,
131                    content: MessageContent::Text(format!(
132                        "{} ({} earlier messages summarized)",
133                        self.summary_note, dropped
134                    )),
135                    metadata: None,
136                },
137            );
138            ctx.request.messages = truncated;
139        }
140        Ok(())
141    }
142}
143
144pub struct PlanningMiddleware {
145    _state: Arc<RwLock<AgentStateSnapshot>>,
146}
147
148impl PlanningMiddleware {
149    pub fn new(state: Arc<RwLock<AgentStateSnapshot>>) -> Self {
150        Self { _state: state }
151    }
152}
153
154#[async_trait]
155impl AgentMiddleware for PlanningMiddleware {
156    fn id(&self) -> &'static str {
157        "planning"
158    }
159
160    fn tools(&self) -> Vec<ToolBox> {
161        // Match LangChain deepagents: expose the planning tool `write_todos` as the built-in.
162        // (We keep `read_todos` available in the toolkit for opt-in use, but it is not a
163        // default built-in tool.)
164        use agents_toolkit::create_todos_tool;
165        vec![create_todos_tool()]
166    }
167
168    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
169        ctx.request.append_prompt(WRITE_TODOS_SYSTEM_PROMPT);
170        Ok(())
171    }
172}
173
174pub struct FilesystemMiddleware {
175    _state: Arc<RwLock<AgentStateSnapshot>>,
176}
177
178impl FilesystemMiddleware {
179    pub fn new(state: Arc<RwLock<AgentStateSnapshot>>) -> Self {
180        Self { _state: state }
181    }
182}
183
184#[async_trait]
185impl AgentMiddleware for FilesystemMiddleware {
186    fn id(&self) -> &'static str {
187        "filesystem"
188    }
189
190    fn tools(&self) -> Vec<ToolBox> {
191        create_filesystem_tools()
192    }
193
194    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
195        ctx.request.append_prompt(FILESYSTEM_SYSTEM_PROMPT);
196        Ok(())
197    }
198}
199
200#[derive(Clone)]
201pub struct SubAgentRegistration {
202    pub descriptor: SubAgentDescriptor,
203    pub agent: Arc<dyn AgentHandle>,
204}
205
206struct SubAgentRegistry {
207    agents: HashMap<String, Arc<dyn AgentHandle>>,
208}
209
210impl SubAgentRegistry {
211    fn new(registrations: Vec<SubAgentRegistration>) -> Self {
212        let mut agents = HashMap::new();
213        for reg in registrations {
214            agents.insert(reg.descriptor.name.clone(), reg.agent.clone());
215        }
216        Self { agents }
217    }
218
219    fn available_names(&self) -> Vec<String> {
220        self.agents.keys().cloned().collect()
221    }
222
223    fn get(&self, name: &str) -> Option<Arc<dyn AgentHandle>> {
224        self.agents.get(name).cloned()
225    }
226}
227
228pub struct SubAgentMiddleware {
229    task_tool: ToolBox,
230    descriptors: Vec<SubAgentDescriptor>,
231    _registry: Arc<SubAgentRegistry>,
232}
233
234impl SubAgentMiddleware {
235    pub fn new(registrations: Vec<SubAgentRegistration>) -> Self {
236        let descriptors = registrations.iter().map(|r| r.descriptor.clone()).collect();
237        let registry = Arc::new(SubAgentRegistry::new(registrations));
238        let task_tool: ToolBox = Arc::new(TaskRouterTool::new(registry.clone(), None));
239        Self {
240            task_tool,
241            descriptors,
242            _registry: registry,
243        }
244    }
245
246    pub fn new_with_events(
247        registrations: Vec<SubAgentRegistration>,
248        event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
249    ) -> Self {
250        let descriptors = registrations.iter().map(|r| r.descriptor.clone()).collect();
251        let registry = Arc::new(SubAgentRegistry::new(registrations));
252        let task_tool: ToolBox = Arc::new(TaskRouterTool::new(registry.clone(), event_dispatcher));
253        Self {
254            task_tool,
255            descriptors,
256            _registry: registry,
257        }
258    }
259
260    fn prompt_fragment(&self) -> String {
261        let descriptions: Vec<String> = if self.descriptors.is_empty() {
262            vec![String::from("- general-purpose: Default reasoning agent")]
263        } else {
264            self.descriptors
265                .iter()
266                .map(|agent| format!("- {}: {}", agent.name, agent.description))
267                .collect()
268        };
269
270        TASK_TOOL_DESCRIPTION.replace("{other_agents}", &descriptions.join("\n"))
271    }
272}
273
274#[async_trait]
275impl AgentMiddleware for SubAgentMiddleware {
276    fn id(&self) -> &'static str {
277        "subagent"
278    }
279
280    fn tools(&self) -> Vec<ToolBox> {
281        vec![self.task_tool.clone()]
282    }
283
284    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
285        ctx.request.append_prompt(TASK_SYSTEM_PROMPT);
286        ctx.request.append_prompt(&self.prompt_fragment());
287        Ok(())
288    }
289}
290
/// Approval policy for a single tool.
#[derive(Clone, Debug)]
pub struct HitlPolicy {
    /// When `true`, the tool runs without human approval.
    pub allow_auto: bool,
    /// Optional note listed in the system prompt and passed along with the interrupt.
    pub note: Option<String>,
}

/// Gates selected tools behind human approval, keyed by tool name.
pub struct HumanInLoopMiddleware {
    policies: HashMap<String, HitlPolicy>,
}

impl HumanInLoopMiddleware {
    /// Creates the middleware from a map of tool name to policy.
    pub fn new(policies: HashMap<String, HitlPolicy>) -> Self {
        Self { policies }
    }

    /// Returns the policy for `tool_name` if it exists and does not allow
    /// automatic execution; `None` means the tool may run without approval.
    pub fn requires_approval(&self, tool_name: &str) -> Option<&HitlPolicy> {
        self.policies
            .get(tool_name)
            .filter(|policy| !policy.allow_auto)
    }

    /// System-prompt fragment listing every approval-gated tool, or `None` when
    /// no tool requires approval.
    ///
    /// Tools are listed in sorted name order: HashMap iteration order is
    /// unspecified, and a byte-stable prompt keeps prompt-prefix caches effective.
    fn prompt_fragment(&self) -> Option<String> {
        let mut gated: Vec<(&String, &HitlPolicy)> = self
            .policies
            .iter()
            .filter(|(_, policy)| !policy.allow_auto)
            .collect();
        if gated.is_empty() {
            return None;
        }
        gated.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));

        let lines: Vec<String> = gated
            .into_iter()
            .map(|(tool, policy)| {
                let note = policy.note.as_deref().unwrap_or("Requires approval");
                format!("- {tool}: {note}")
            })
            .collect();
        Some(format!(
            "The following tools require human approval before execution:\n{}",
            lines.join("\n")
        ))
    }
}
332
333#[async_trait]
334impl AgentMiddleware for HumanInLoopMiddleware {
335    fn id(&self) -> &'static str {
336        "human-in-loop"
337    }
338
339    async fn before_tool_execution(
340        &self,
341        tool_name: &str,
342        tool_args: &serde_json::Value,
343        call_id: &str,
344    ) -> anyhow::Result<Option<agents_core::hitl::AgentInterrupt>> {
345        if let Some(policy) = self.requires_approval(tool_name) {
346            tracing::warn!(
347                tool_name = %tool_name,
348                call_id = %call_id,
349                policy_note = ?policy.note,
350                "🔒 HITL: Tool execution requires human approval"
351            );
352
353            let interrupt = agents_core::hitl::HitlInterrupt::new(
354                tool_name,
355                tool_args.clone(),
356                call_id,
357                policy.note.clone(),
358            );
359
360            return Ok(Some(agents_core::hitl::AgentInterrupt::HumanInLoop(
361                interrupt,
362            )));
363        }
364
365        Ok(None)
366    }
367
368    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
369        if let Some(fragment) = self.prompt_fragment() {
370            ctx.request.append_prompt(&fragment);
371        }
372        ctx.request.messages.push(AgentMessage {
373            role: MessageRole::System,
374            content: MessageContent::Text(
375                "Tools marked for human approval will emit interrupts requiring external resolution."
376                    .into(),
377            ),
378            metadata: None,
379        });
380        Ok(())
381    }
382}
383
384pub struct BaseSystemPromptMiddleware;
385
386#[async_trait]
387impl AgentMiddleware for BaseSystemPromptMiddleware {
388    fn id(&self) -> &'static str {
389        "base-system-prompt"
390    }
391
392    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
393        ctx.request.append_prompt(BASE_AGENT_PROMPT);
394        Ok(())
395    }
396}
397
398/// Deep Agent prompt middleware that injects comprehensive tool usage instructions
399/// and examples to force the LLM to actually call tools instead of just talking about them.
400///
401/// This middleware is inspired by Python's deepagents package and Claude Code's system prompt.
402/// It provides:
403/// - Explicit tool usage rules with imperative language
404/// - JSON or TOON examples of tool calling (configurable for token efficiency)
405/// - Workflow guidance for multi-step tasks
406/// - Few-shot examples for common patterns
407///
408/// The middleware supports three modes:
409/// 1. **JSON mode** (default): Uses JSON examples for tool calls
410/// 2. **TOON mode**: Uses TOON format for 30-60% token reduction
411/// 3. **Override mode**: Uses a completely custom system prompt, bypassing the default
412pub struct DeepAgentPromptMiddleware {
413    custom_instructions: String,
414    /// Format for tool call examples in the system prompt
415    prompt_format: crate::prompts::PromptFormat,
416    /// If set, this completely replaces the default Deep Agent system prompt
417    override_system_prompt: Option<String>,
418}
419
420impl DeepAgentPromptMiddleware {
421    pub fn new(custom_instructions: impl Into<String>) -> Self {
422        Self {
423            custom_instructions: custom_instructions.into(),
424            prompt_format: crate::prompts::PromptFormat::Json,
425            override_system_prompt: None,
426        }
427    }
428
429    /// Create a middleware with a specific prompt format (JSON or TOON).
430    ///
431    /// Use TOON format for 30-60% token reduction in system prompts.
432    /// See: <https://github.com/toon-format/toon>
433    pub fn with_format(
434        custom_instructions: impl Into<String>,
435        format: crate::prompts::PromptFormat,
436    ) -> Self {
437        Self {
438            custom_instructions: custom_instructions.into(),
439            prompt_format: format,
440            override_system_prompt: None,
441        }
442    }
443
444    /// Create a middleware with a completely custom system prompt that bypasses
445    /// the default Deep Agent prompt.
446    ///
447    /// Use this when you need full control over the agent's system prompt.
448    pub fn with_override(system_prompt: impl Into<String>) -> Self {
449        Self {
450            custom_instructions: String::new(),
451            prompt_format: crate::prompts::PromptFormat::Json,
452            override_system_prompt: Some(system_prompt.into()),
453        }
454    }
455}
456
457#[async_trait]
458impl AgentMiddleware for DeepAgentPromptMiddleware {
459    fn id(&self) -> &'static str {
460        "deep-agent-prompt"
461    }
462
463    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
464        let prompt = if let Some(ref override_prompt) = self.override_system_prompt {
465            // Use the custom system prompt directly, bypassing the Deep Agent prompt
466            override_prompt.clone()
467        } else {
468            // Use the formatted Deep Agent prompt based on prompt_format
469            use crate::prompts::get_deep_agent_system_prompt_formatted;
470            get_deep_agent_system_prompt_formatted(&self.custom_instructions, self.prompt_format)
471        };
472        ctx.request.append_prompt(&prompt);
473        Ok(())
474    }
475}
476
/// Anthropic-specific prompt caching middleware. Moves the system prompt into a
/// leading system message marked for caching, to reduce latency on subsequent
/// requests that share the same base prompt.
pub struct AnthropicPromptCachingMiddleware {
    /// Cache TTL such as `"5m"`. Only decides whether caching is enabled (and is
    /// logged); the emitted cache-control entry is always `ephemeral`.
    pub ttl: String,
    /// Recorded in debug logs. NOTE(review): not otherwise acted upon.
    pub unsupported_model_behavior: String,
}

impl AnthropicPromptCachingMiddleware {
    /// Creates the middleware from a TTL string and an unsupported-model behavior label.
    pub fn new(ttl: impl Into<String>, unsupported_model_behavior: impl Into<String>) -> Self {
        Self {
            ttl: ttl.into(),
            unsupported_model_behavior: unsupported_model_behavior.into(),
        }
    }

    /// Default configuration: `ttl = "5m"`, `unsupported_model_behavior = "ignore"`.
    pub fn with_defaults() -> Self {
        Self::new("5m", "ignore")
    }

    /// Whether the configured TTL requests caching.
    ///
    /// Surrounding whitespace is ignored. An empty TTL, or one whose numeric part
    /// is zero (`"0"`, `"0s"`, `"0m"`, `"00h"`), disables caching. Any other value
    /// enables it, including values without a parseable number (kept lenient).
    fn should_enable_caching(&self) -> bool {
        let ttl = self.ttl.trim();
        if ttl.is_empty() {
            return false;
        }
        // Strip a unit suffix such as "s", "m" or "h" to get the magnitude.
        let magnitude = ttl.trim_end_matches(|c: char| c.is_ascii_alphabetic());
        match magnitude.parse::<u64>() {
            Ok(amount) => amount != 0,
            Err(_) => true,
        }
    }
}
502
503#[async_trait]
504impl AgentMiddleware for AnthropicPromptCachingMiddleware {
505    fn id(&self) -> &'static str {
506        "anthropic-prompt-caching"
507    }
508
509    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
510        if !self.should_enable_caching() {
511            return Ok(());
512        }
513
514        // Mark system prompt for caching by converting it to a system message with cache control
515        if !ctx.request.system_prompt.is_empty() {
516            let system_message = AgentMessage {
517                role: MessageRole::System,
518                content: MessageContent::Text(ctx.request.system_prompt.clone()),
519                metadata: Some(MessageMetadata {
520                    tool_call_id: None,
521                    cache_control: Some(CacheControl {
522                        cache_type: "ephemeral".to_string(),
523                    }),
524                }),
525            };
526
527            // Insert system message at the beginning of the messages
528            ctx.request.messages.insert(0, system_message);
529
530            // Clear the system_prompt since it's now in messages
531            ctx.request.system_prompt.clear();
532
533            tracing::debug!(
534                ttl = %self.ttl,
535                behavior = %self.unsupported_model_behavior,
536                "Applied Anthropic prompt caching to system message"
537            );
538        }
539
540        Ok(())
541    }
542}
543
544pub struct TaskRouterTool {
545    registry: Arc<SubAgentRegistry>,
546    event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
547    delegation_depth: Arc<RwLock<u32>>,
548}
549
550impl TaskRouterTool {
551    fn new(
552        registry: Arc<SubAgentRegistry>,
553        event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
554    ) -> Self {
555        Self {
556            registry,
557            event_dispatcher,
558            delegation_depth: Arc::new(RwLock::new(0)),
559        }
560    }
561
562    fn available_subagents(&self) -> Vec<String> {
563        self.registry.available_names()
564    }
565
566    fn emit_event(&self, event: agents_core::events::AgentEvent) {
567        if let Some(dispatcher) = &self.event_dispatcher {
568            let dispatcher_clone = dispatcher.clone();
569            tokio::spawn(async move {
570                dispatcher_clone.dispatch(event).await;
571            });
572        }
573    }
574
575    fn create_event_metadata(&self) -> agents_core::events::EventMetadata {
576        agents_core::events::EventMetadata::new(
577            "default".to_string(),
578            uuid::Uuid::new_v4().to_string(),
579            None,
580        )
581    }
582
583    fn get_delegation_depth(&self) -> u32 {
584        *self.delegation_depth.read().unwrap_or_else(|_| {
585            tracing::warn!("Failed to read delegation depth, defaulting to 0");
586            panic!("RwLock poisoned")
587        })
588    }
589
590    fn increment_delegation_depth(&self) {
591        if let Ok(mut depth) = self.delegation_depth.write() {
592            *depth += 1;
593        }
594    }
595
596    fn decrement_delegation_depth(&self) {
597        if let Ok(mut depth) = self.delegation_depth.write() {
598            if *depth > 0 {
599                *depth -= 1;
600            }
601        }
602    }
603}
604
605#[derive(Debug, Clone, Deserialize)]
606struct TaskInvocationArgs {
607    #[serde(alias = "description")]
608    instruction: String,
609    #[serde(alias = "subagent_type")]
610    agent: String,
611}
612
613#[async_trait]
614impl Tool for TaskRouterTool {
615    fn schema(&self) -> agents_core::tools::ToolSchema {
616        use agents_core::tools::{ToolParameterSchema, ToolSchema};
617        use std::collections::HashMap;
618
619        let mut properties = HashMap::new();
620        properties.insert(
621            "agent".to_string(),
622            ToolParameterSchema::string("Name of the sub-agent to delegate to"),
623        );
624        properties.insert(
625            "instruction".to_string(),
626            ToolParameterSchema::string("Clear instruction for the sub-agent"),
627        );
628
629        ToolSchema::new(
630            "task",
631            "Delegate a task to a specialized sub-agent. Use this when you need specialized expertise or want to break down complex tasks.",
632            ToolParameterSchema::object(
633                "Task delegation parameters",
634                properties,
635                vec!["agent".to_string(), "instruction".to_string()],
636            ),
637        )
638    }
639
640    async fn execute(
641        &self,
642        args: serde_json::Value,
643        ctx: ToolContext,
644    ) -> anyhow::Result<ToolResult> {
645        let args: TaskInvocationArgs = serde_json::from_value(args)?;
646        let available = self.available_subagents();
647
648        if let Some(agent) = self.registry.get(&args.agent) {
649            // Increment delegation depth
650            self.increment_delegation_depth();
651            let current_depth = self.get_delegation_depth();
652
653            // Truncate instruction for event
654            let instruction_summary = if args.instruction.chars().count() > 100 {
655                format!("{:.100}...", &args.instruction)
656            } else {
657                args.instruction.clone()
658            };
659
660            // Emit: SubAgentStarted event
661            self.emit_event(agents_core::events::AgentEvent::SubAgentStarted(
662                agents_core::events::SubAgentStartedEvent {
663                    metadata: self.create_event_metadata(),
664                    agent_name: args.agent.clone(),
665                    instruction_summary: instruction_summary.clone(),
666                    delegation_depth: current_depth,
667                },
668            ));
669
670            // Log delegation start
671            tracing::warn!(
672                "🎯 DELEGATING to sub-agent: {} (depth: {}) with instruction: {}",
673                args.agent,
674                current_depth,
675                args.instruction
676            );
677
678            let start_time = std::time::Instant::now();
679            let user_message = AgentMessage {
680                role: MessageRole::User,
681                content: MessageContent::Text(args.instruction.clone()),
682                metadata: None,
683            };
684
685            let response = agent
686                .handle_message(user_message, ctx.state.clone())
687                .await?;
688
689            // Calculate duration
690            let duration = start_time.elapsed();
691            let duration_ms = duration.as_millis() as u64;
692
693            // Create response preview
694            let response_preview = match &response.content {
695                MessageContent::Text(t) => {
696                    if t.chars().count() > 100 {
697                        format!("{:.100}...", t)
698                    } else {
699                        t.clone()
700                    }
701                }
702                MessageContent::Json(v) => {
703                    let json_str = v.to_string();
704                    if json_str.chars().count() > 100 {
705                        format!("{:.100}...", json_str)
706                    } else {
707                        json_str
708                    }
709                }
710            };
711
712            // Emit: SubAgentCompleted event
713            self.emit_event(agents_core::events::AgentEvent::SubAgentCompleted(
714                agents_core::events::SubAgentCompletedEvent {
715                    metadata: self.create_event_metadata(),
716                    agent_name: args.agent.clone(),
717                    duration_ms,
718                    result_summary: response_preview.clone(),
719                },
720            ));
721
722            // Log delegation completion
723            tracing::warn!(
724                "✅ SUB-AGENT {} COMPLETED in {:?} - Response: {}",
725                args.agent,
726                duration,
727                response_preview
728            );
729
730            // Decrement delegation depth
731            self.decrement_delegation_depth();
732
733            // Return sub-agent response as text content, not as a separate tool message
734            // This will be incorporated into the LLM's next response naturally
735            let result_text = match response.content {
736                MessageContent::Text(text) => text,
737                MessageContent::Json(json) => json.to_string(),
738            };
739
740            return Ok(ToolResult::text(&ctx, result_text));
741        }
742
743        tracing::error!(
744            "❌ SUB-AGENT NOT FOUND: {} - Available: {:?}",
745            args.agent,
746            available
747        );
748
749        Ok(ToolResult::text(
750            &ctx,
751            format!(
752                "Sub-agent '{}' not found. Available sub-agents: {}",
753                args.agent,
754                available.join(", ")
755            ),
756        ))
757    }
758}
759
/// Public description of a sub-agent: the name used to route `task` calls to it
/// and the summary shown to the model.
#[derive(Debug, Clone)]
pub struct SubAgentDescriptor {
    /// Routing key matched against the `agent` argument of the `task` tool.
    pub name: String,
    /// Short description listed for the model in the system prompt.
    pub description: String,
}
765
766#[cfg(test)]
767mod tests {
768    use super::*;
769    use agents_core::agent::{AgentDescriptor, AgentHandle};
770    use agents_core::messaging::{MessageContent, MessageRole};
771    use serde_json::json;
772
773    struct AppendPromptMiddleware;
774
775    #[async_trait]
776    impl AgentMiddleware for AppendPromptMiddleware {
777        fn id(&self) -> &'static str {
778            "append-prompt"
779        }
780
781        async fn modify_model_request(
782            &self,
783            ctx: &mut MiddlewareContext<'_>,
784        ) -> anyhow::Result<()> {
785            ctx.request.system_prompt.push_str("\nExtra directives.");
786            Ok(())
787        }
788    }
789
790    #[tokio::test]
791    async fn middleware_mutates_prompt() {
792        let mut request = ModelRequest::new(
793            "System",
794            vec![AgentMessage {
795                role: MessageRole::User,
796                content: MessageContent::Text("Hi".into()),
797                metadata: None,
798            }],
799        );
800        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
801        let mut ctx = MiddlewareContext::with_request(&mut request, state);
802        let middleware = AppendPromptMiddleware;
803        middleware.modify_model_request(&mut ctx).await.unwrap();
804        assert!(ctx.request.system_prompt.contains("Extra directives"));
805    }
806
807    #[tokio::test]
808    async fn planning_middleware_registers_write_todos() {
809        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
810        let middleware = PlanningMiddleware::new(state);
811        let tool_names: Vec<_> = middleware
812            .tools()
813            .iter()
814            .map(|t| t.schema().name.clone())
815            .collect();
816        assert!(tool_names.contains(&"write_todos".to_string()));
817
818        let mut request = ModelRequest::new("System", vec![]);
819        let mut ctx = MiddlewareContext::with_request(
820            &mut request,
821            Arc::new(RwLock::new(AgentStateSnapshot::default())),
822        );
823        middleware.modify_model_request(&mut ctx).await.unwrap();
824        assert!(ctx.request.system_prompt.contains("write_todos"));
825    }
826
827    #[tokio::test]
828    async fn filesystem_middleware_registers_tools() {
829        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
830        let middleware = FilesystemMiddleware::new(state);
831        let tool_names: Vec<_> = middleware
832            .tools()
833            .iter()
834            .map(|t| t.schema().name.clone())
835            .collect();
836        for expected in ["ls", "read_file", "write_file", "edit_file"] {
837            assert!(tool_names.contains(&expected.to_string()));
838        }
839    }
840
841    #[tokio::test]
842    async fn summarization_middleware_trims_messages() {
843        let middleware = SummarizationMiddleware::new(2, "Summary note");
844        let mut request = ModelRequest::new(
845            "System",
846            vec![
847                AgentMessage {
848                    role: MessageRole::User,
849                    content: MessageContent::Text("one".into()),
850                    metadata: None,
851                },
852                AgentMessage {
853                    role: MessageRole::Agent,
854                    content: MessageContent::Text("two".into()),
855                    metadata: None,
856                },
857                AgentMessage {
858                    role: MessageRole::User,
859                    content: MessageContent::Text("three".into()),
860                    metadata: None,
861                },
862            ],
863        );
864        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
865        let mut ctx = MiddlewareContext::with_request(&mut request, state);
866        middleware.modify_model_request(&mut ctx).await.unwrap();
867        assert_eq!(ctx.request.messages.len(), 3);
868        match &ctx.request.messages[0].content {
869            MessageContent::Text(text) => assert!(text.contains("Summary note")),
870            other => panic!("expected text, got {other:?}"),
871        }
872    }
873
874    struct StubAgent;
875
876    #[async_trait]
877    impl AgentHandle for StubAgent {
878        async fn describe(&self) -> AgentDescriptor {
879            AgentDescriptor {
880                name: "stub".into(),
881                version: "0.0.1".into(),
882                description: None,
883            }
884        }
885
886        async fn handle_message(
887            &self,
888            _input: AgentMessage,
889            _state: Arc<AgentStateSnapshot>,
890        ) -> anyhow::Result<AgentMessage> {
891            Ok(AgentMessage {
892                role: MessageRole::Agent,
893                content: MessageContent::Text("stub-response".into()),
894                metadata: None,
895            })
896        }
897    }
898
899    #[tokio::test]
900    async fn task_router_reports_unknown_subagent() {
901        let registry = Arc::new(SubAgentRegistry::new(vec![]));
902        let task_tool = TaskRouterTool::new(registry.clone(), None);
903        let state = Arc::new(AgentStateSnapshot::default());
904        let ctx = ToolContext::new(state);
905
906        let response = task_tool
907            .execute(
908                json!({
909                    "instruction": "Do something",
910                    "agent": "unknown"
911                }),
912                ctx,
913            )
914            .await
915            .unwrap();
916
917        match response {
918            ToolResult::Message(msg) => match msg.content {
919                MessageContent::Text(text) => {
920                    assert!(text.contains("Sub-agent 'unknown' not found"))
921                }
922                other => panic!("expected text, got {other:?}"),
923            },
924            _ => panic!("expected message"),
925        }
926    }
927
928    #[tokio::test]
929    async fn subagent_middleware_appends_prompt() {
930        let subagents = vec![SubAgentRegistration {
931            descriptor: SubAgentDescriptor {
932                name: "research-agent".into(),
933                description: "Deep research specialist".into(),
934            },
935            agent: Arc::new(StubAgent),
936        }];
937        let middleware = SubAgentMiddleware::new(subagents);
938
939        let mut request = ModelRequest::new("System", vec![]);
940        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
941        let mut ctx = MiddlewareContext::with_request(&mut request, state);
942        middleware.modify_model_request(&mut ctx).await.unwrap();
943
944        assert!(ctx.request.system_prompt.contains("research-agent"));
945        let tool_names: Vec<_> = middleware
946            .tools()
947            .iter()
948            .map(|t| t.schema().name.clone())
949            .collect();
950        assert!(tool_names.contains(&"task".to_string()));
951    }
952
953    #[tokio::test]
954    async fn task_router_invokes_registered_subagent() {
955        let registry = Arc::new(SubAgentRegistry::new(vec![SubAgentRegistration {
956            descriptor: SubAgentDescriptor {
957                name: "stub-agent".into(),
958                description: "Stub".into(),
959            },
960            agent: Arc::new(StubAgent),
961        }]));
962        let task_tool = TaskRouterTool::new(registry.clone(), None);
963        let state = Arc::new(AgentStateSnapshot::default());
964        let ctx = ToolContext::new(state).with_call_id(Some("call-42".into()));
965        let response = task_tool
966            .execute(
967                json!({
968                    "description": "do work",
969                    "subagent_type": "stub-agent"
970                }),
971                ctx,
972            )
973            .await
974            .unwrap();
975
976        match response {
977            ToolResult::Message(msg) => {
978                assert_eq!(msg.metadata.unwrap().tool_call_id.unwrap(), "call-42");
979                match msg.content {
980                    MessageContent::Text(text) => assert_eq!(text, "stub-response"),
981                    other => panic!("expected text, got {other:?}"),
982                }
983            }
984            _ => panic!("expected message"),
985        }
986    }
987
988    #[tokio::test]
989    async fn human_in_loop_appends_prompt() {
990        let middleware = HumanInLoopMiddleware::new(HashMap::from([(
991            "danger-tool".into(),
992            HitlPolicy {
993                allow_auto: false,
994                note: Some("Requires security review".into()),
995            },
996        )]));
997        let mut request = ModelRequest::new("System", vec![]);
998        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
999        let mut ctx = MiddlewareContext::with_request(&mut request, state);
1000        middleware.modify_model_request(&mut ctx).await.unwrap();
1001        assert!(ctx
1002            .request
1003            .system_prompt
1004            .contains("danger-tool: Requires security review"));
1005    }
1006
1007    #[tokio::test]
1008    async fn anthropic_prompt_caching_moves_system_prompt_to_messages() {
1009        let middleware = AnthropicPromptCachingMiddleware::new("5m", "ignore");
1010        let mut request = ModelRequest::new(
1011            "This is the system prompt",
1012            vec![AgentMessage {
1013                role: MessageRole::User,
1014                content: MessageContent::Text("Hello".into()),
1015                metadata: None,
1016            }],
1017        );
1018        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
1019        let mut ctx = MiddlewareContext::with_request(&mut request, state);
1020
1021        // Apply the middleware
1022        middleware.modify_model_request(&mut ctx).await.unwrap();
1023
1024        // System prompt should be cleared
1025        assert!(ctx.request.system_prompt.is_empty());
1026
1027        // Should have added a system message with cache control at the beginning
1028        assert_eq!(ctx.request.messages.len(), 2);
1029
1030        let system_message = &ctx.request.messages[0];
1031        assert!(matches!(system_message.role, MessageRole::System));
1032        assert_eq!(
1033            system_message.content.as_text().unwrap(),
1034            "This is the system prompt"
1035        );
1036
1037        // Check cache control metadata
1038        let metadata = system_message.metadata.as_ref().unwrap();
1039        let cache_control = metadata.cache_control.as_ref().unwrap();
1040        assert_eq!(cache_control.cache_type, "ephemeral");
1041
1042        // Original user message should still be there
1043        let user_message = &ctx.request.messages[1];
1044        assert!(matches!(user_message.role, MessageRole::User));
1045        assert_eq!(user_message.content.as_text().unwrap(), "Hello");
1046    }
1047
1048    #[tokio::test]
1049    async fn anthropic_prompt_caching_disabled_with_zero_ttl() {
1050        let middleware = AnthropicPromptCachingMiddleware::new("0", "ignore");
1051        let mut request = ModelRequest::new("This is the system prompt", vec![]);
1052        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
1053        let mut ctx = MiddlewareContext::with_request(&mut request, state);
1054
1055        // Apply the middleware
1056        middleware.modify_model_request(&mut ctx).await.unwrap();
1057
1058        // System prompt should be unchanged
1059        assert_eq!(ctx.request.system_prompt, "This is the system prompt");
1060        assert_eq!(ctx.request.messages.len(), 0);
1061    }
1062
1063    #[tokio::test]
1064    async fn anthropic_prompt_caching_no_op_with_empty_system_prompt() {
1065        let middleware = AnthropicPromptCachingMiddleware::new("5m", "ignore");
1066        let mut request = ModelRequest::new(
1067            "",
1068            vec![AgentMessage {
1069                role: MessageRole::User,
1070                content: MessageContent::Text("Hello".into()),
1071                metadata: None,
1072            }],
1073        );
1074        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
1075        let mut ctx = MiddlewareContext::with_request(&mut request, state);
1076
1077        // Apply the middleware
1078        middleware.modify_model_request(&mut ctx).await.unwrap();
1079
1080        // System prompt should remain empty
1081        assert!(ctx.request.system_prompt.is_empty());
1082        // No system message should be added
1083        assert_eq!(ctx.request.messages.len(), 1);
1084    }
1085
1086    // ========== HITL Interrupt Creation Tests ==========
1087
1088    #[tokio::test]
1089    async fn hitl_creates_interrupt_for_disallowed_tool() {
1090        let mut policies = HashMap::new();
1091        policies.insert(
1092            "dangerous_tool".to_string(),
1093            HitlPolicy {
1094                allow_auto: false,
1095                note: Some("Requires security review".to_string()),
1096            },
1097        );
1098
1099        let middleware = HumanInLoopMiddleware::new(policies);
1100        let tool_args = json!({"action": "delete_all"});
1101
1102        let result = middleware
1103            .before_tool_execution("dangerous_tool", &tool_args, "call_123")
1104            .await
1105            .unwrap();
1106
1107        assert!(result.is_some());
1108        let interrupt = result.unwrap();
1109
1110        match interrupt {
1111            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
1112                assert_eq!(hitl.tool_name, "dangerous_tool");
1113                assert_eq!(hitl.tool_args, tool_args);
1114                assert_eq!(hitl.call_id, "call_123");
1115                assert_eq!(
1116                    hitl.policy_note,
1117                    Some("Requires security review".to_string())
1118                );
1119            }
1120        }
1121    }
1122
1123    #[tokio::test]
1124    async fn hitl_no_interrupt_for_allowed_tool() {
1125        let mut policies = HashMap::new();
1126        policies.insert(
1127            "safe_tool".to_string(),
1128            HitlPolicy {
1129                allow_auto: true,
1130                note: None,
1131            },
1132        );
1133
1134        let middleware = HumanInLoopMiddleware::new(policies);
1135        let tool_args = json!({"action": "read"});
1136
1137        let result = middleware
1138            .before_tool_execution("safe_tool", &tool_args, "call_456")
1139            .await
1140            .unwrap();
1141
1142        assert!(result.is_none());
1143    }
1144
1145    #[tokio::test]
1146    async fn hitl_no_interrupt_for_unlisted_tool() {
1147        let policies = HashMap::new();
1148        let middleware = HumanInLoopMiddleware::new(policies);
1149        let tool_args = json!({"action": "anything"});
1150
1151        let result = middleware
1152            .before_tool_execution("unlisted_tool", &tool_args, "call_789")
1153            .await
1154            .unwrap();
1155
1156        assert!(result.is_none());
1157    }
1158
1159    #[tokio::test]
1160    async fn hitl_interrupt_includes_correct_details() {
1161        let mut policies = HashMap::new();
1162        policies.insert(
1163            "critical_tool".to_string(),
1164            HitlPolicy {
1165                allow_auto: false,
1166                note: Some("Critical operation - requires approval".to_string()),
1167            },
1168        );
1169
1170        let middleware = HumanInLoopMiddleware::new(policies);
1171        let tool_args = json!({
1172            "database": "production",
1173            "operation": "drop_table"
1174        });
1175
1176        let result = middleware
1177            .before_tool_execution("critical_tool", &tool_args, "call_critical_1")
1178            .await
1179            .unwrap();
1180
1181        assert!(result.is_some());
1182        let interrupt = result.unwrap();
1183
1184        match interrupt {
1185            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
1186                assert_eq!(hitl.tool_name, "critical_tool");
1187                assert_eq!(hitl.tool_args["database"], "production");
1188                assert_eq!(hitl.tool_args["operation"], "drop_table");
1189                assert_eq!(hitl.call_id, "call_critical_1");
1190                assert!(hitl.policy_note.is_some());
1191                assert!(hitl.policy_note.unwrap().contains("Critical operation"));
1192                // Verify timestamp exists (created_at field is populated)
1193                // The actual timestamp value is tested in agents-core/hitl.rs
1194            }
1195        }
1196    }
1197
1198    #[tokio::test]
1199    async fn hitl_interrupt_without_policy_note() {
1200        let mut policies = HashMap::new();
1201        policies.insert(
1202            "tool_no_note".to_string(),
1203            HitlPolicy {
1204                allow_auto: false,
1205                note: None,
1206            },
1207        );
1208
1209        let middleware = HumanInLoopMiddleware::new(policies);
1210        let tool_args = json!({"param": "value"});
1211
1212        let result = middleware
1213            .before_tool_execution("tool_no_note", &tool_args, "call_no_note")
1214            .await
1215            .unwrap();
1216
1217        assert!(result.is_some());
1218        let interrupt = result.unwrap();
1219
1220        match interrupt {
1221            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
1222                assert_eq!(hitl.tool_name, "tool_no_note");
1223                assert_eq!(hitl.policy_note, None);
1224            }
1225        }
1226    }
1227}