agents_runtime/
middleware.rs

1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3
4use agents_core::agent::AgentHandle;
5use agents_core::messaging::{
6    AgentMessage, CacheControl, MessageContent, MessageMetadata, MessageRole,
7};
8use agents_core::prompts::{
9    BASE_AGENT_PROMPT, FILESYSTEM_SYSTEM_PROMPT, TASK_SYSTEM_PROMPT, TASK_TOOL_DESCRIPTION,
10    WRITE_TODOS_SYSTEM_PROMPT,
11};
12use agents_core::state::AgentStateSnapshot;
13use agents_core::tools::{Tool, ToolBox, ToolContext, ToolResult};
14use agents_toolkit::create_filesystem_tools;
15use async_trait::async_trait;
16use serde::Deserialize;
17
18pub mod token_tracking;
19
20/// Request sent to the underlying language model. Middlewares can augment
21/// the system prompt or mutate the pending message list before the model call.
22#[derive(Debug, Clone)]
23pub struct ModelRequest {
24    pub system_prompt: String,
25    pub messages: Vec<AgentMessage>,
26}
27
28impl ModelRequest {
29    pub fn new(system_prompt: impl Into<String>, messages: Vec<AgentMessage>) -> Self {
30        Self {
31            system_prompt: system_prompt.into(),
32            messages,
33        }
34    }
35
36    pub fn append_prompt(&mut self, fragment: &str) {
37        if !fragment.is_empty() {
38            self.system_prompt.push_str("\n\n");
39            self.system_prompt.push_str(fragment);
40        }
41    }
42}
43
44/// Read/write state handle exposed to middleware implementations.
45pub struct MiddlewareContext<'a> {
46    pub request: &'a mut ModelRequest,
47    pub state: Arc<RwLock<AgentStateSnapshot>>,
48}
49
50impl<'a> MiddlewareContext<'a> {
51    pub fn with_request(
52        request: &'a mut ModelRequest,
53        state: Arc<RwLock<AgentStateSnapshot>>,
54    ) -> Self {
55        Self { request, state }
56    }
57}
58
59/// Middleware hook that can register additional tools and mutate the model request
60/// prior to execution. Mirrors the Python AgentMiddleware contracts but keeps the
61/// interface async-first for future network calls.
62#[async_trait]
63pub trait AgentMiddleware: Send + Sync {
64    /// Unique identifier for logging and diagnostics.
65    fn id(&self) -> &'static str;
66
67    /// Tools to expose when this middleware is active.
68    fn tools(&self) -> Vec<ToolBox> {
69        Vec::new()
70    }
71
72    /// Apply middleware-specific mutations to the pending model request.
73    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()>;
74
75    /// Hook called before tool execution - can return an interrupt to pause execution.
76    ///
77    /// This hook is invoked for each tool call before it executes, allowing middleware
78    /// to intercept and pause execution for human review. If an interrupt is returned,
79    /// the agent will save its state and wait for human approval before continuing.
80    ///
81    /// # Arguments
82    /// * `tool_name` - Name of the tool about to be executed
83    /// * `tool_args` - Arguments that will be passed to the tool
84    /// * `call_id` - Unique identifier for this tool call
85    ///
86    /// # Returns
87    /// * `Ok(Some(interrupt))` - Pause execution and wait for human response
88    /// * `Ok(None)` - Continue with tool execution normally
89    /// * `Err(e)` - Error occurred during interrupt check
90    async fn before_tool_execution(
91        &self,
92        _tool_name: &str,
93        _tool_args: &serde_json::Value,
94        _call_id: &str,
95    ) -> anyhow::Result<Option<agents_core::hitl::AgentInterrupt>> {
96        Ok(None)
97    }
98}
99
/// Configuration for trimming the conversation to its most recent messages.
///
/// Note that no model-generated summary is produced: older messages are dropped
/// and replaced by a single system note that records how many were removed.
pub struct SummarizationMiddleware {
    /// Number of trailing messages kept on each request.
    pub messages_to_keep: usize,
    /// Text placed at the start of the inserted system note.
    pub summary_note: String,
}

impl SummarizationMiddleware {
    /// Builds the middleware from a retention window and the note text.
    pub fn new(messages_to_keep: usize, summary_note: impl Into<String>) -> Self {
        let summary_note = summary_note.into();
        SummarizationMiddleware {
            summary_note,
            messages_to_keep,
        }
    }
}
113
114#[async_trait]
115impl AgentMiddleware for SummarizationMiddleware {
116    fn id(&self) -> &'static str {
117        "summarization"
118    }
119
120    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
121        if ctx.request.messages.len() > self.messages_to_keep {
122            let dropped = ctx.request.messages.len() - self.messages_to_keep;
123            let mut truncated = ctx
124                .request
125                .messages
126                .split_off(ctx.request.messages.len() - self.messages_to_keep);
127            truncated.insert(
128                0,
129                AgentMessage {
130                    role: MessageRole::System,
131                    content: MessageContent::Text(format!(
132                        "{} ({} earlier messages summarized)",
133                        self.summary_note, dropped
134                    )),
135                    metadata: None,
136                },
137            );
138            ctx.request.messages = truncated;
139        }
140        Ok(())
141    }
142}
143
144pub struct PlanningMiddleware {
145    _state: Arc<RwLock<AgentStateSnapshot>>,
146}
147
148impl PlanningMiddleware {
149    pub fn new(state: Arc<RwLock<AgentStateSnapshot>>) -> Self {
150        Self { _state: state }
151    }
152}
153
154#[async_trait]
155impl AgentMiddleware for PlanningMiddleware {
156    fn id(&self) -> &'static str {
157        "planning"
158    }
159
160    fn tools(&self) -> Vec<ToolBox> {
161        use agents_toolkit::create_todos_tools;
162        create_todos_tools()
163    }
164
165    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
166        ctx.request.append_prompt(WRITE_TODOS_SYSTEM_PROMPT);
167        Ok(())
168    }
169}
170
171pub struct FilesystemMiddleware {
172    _state: Arc<RwLock<AgentStateSnapshot>>,
173}
174
175impl FilesystemMiddleware {
176    pub fn new(state: Arc<RwLock<AgentStateSnapshot>>) -> Self {
177        Self { _state: state }
178    }
179}
180
181#[async_trait]
182impl AgentMiddleware for FilesystemMiddleware {
183    fn id(&self) -> &'static str {
184        "filesystem"
185    }
186
187    fn tools(&self) -> Vec<ToolBox> {
188        create_filesystem_tools()
189    }
190
191    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
192        ctx.request.append_prompt(FILESYSTEM_SYSTEM_PROMPT);
193        Ok(())
194    }
195}
196
197#[derive(Clone)]
198pub struct SubAgentRegistration {
199    pub descriptor: SubAgentDescriptor,
200    pub agent: Arc<dyn AgentHandle>,
201}
202
203struct SubAgentRegistry {
204    agents: HashMap<String, Arc<dyn AgentHandle>>,
205}
206
207impl SubAgentRegistry {
208    fn new(registrations: Vec<SubAgentRegistration>) -> Self {
209        let mut agents = HashMap::new();
210        for reg in registrations {
211            agents.insert(reg.descriptor.name.clone(), reg.agent.clone());
212        }
213        Self { agents }
214    }
215
216    fn available_names(&self) -> Vec<String> {
217        self.agents.keys().cloned().collect()
218    }
219
220    fn get(&self, name: &str) -> Option<Arc<dyn AgentHandle>> {
221        self.agents.get(name).cloned()
222    }
223}
224
225pub struct SubAgentMiddleware {
226    task_tool: ToolBox,
227    descriptors: Vec<SubAgentDescriptor>,
228    _registry: Arc<SubAgentRegistry>,
229}
230
231impl SubAgentMiddleware {
232    pub fn new(registrations: Vec<SubAgentRegistration>) -> Self {
233        let descriptors = registrations.iter().map(|r| r.descriptor.clone()).collect();
234        let registry = Arc::new(SubAgentRegistry::new(registrations));
235        let task_tool: ToolBox = Arc::new(TaskRouterTool::new(registry.clone(), None));
236        Self {
237            task_tool,
238            descriptors,
239            _registry: registry,
240        }
241    }
242
243    pub fn new_with_events(
244        registrations: Vec<SubAgentRegistration>,
245        event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
246    ) -> Self {
247        let descriptors = registrations.iter().map(|r| r.descriptor.clone()).collect();
248        let registry = Arc::new(SubAgentRegistry::new(registrations));
249        let task_tool: ToolBox = Arc::new(TaskRouterTool::new(registry.clone(), event_dispatcher));
250        Self {
251            task_tool,
252            descriptors,
253            _registry: registry,
254        }
255    }
256
257    fn prompt_fragment(&self) -> String {
258        let descriptions: Vec<String> = if self.descriptors.is_empty() {
259            vec![String::from("- general-purpose: Default reasoning agent")]
260        } else {
261            self.descriptors
262                .iter()
263                .map(|agent| format!("- {}: {}", agent.name, agent.description))
264                .collect()
265        };
266
267        TASK_TOOL_DESCRIPTION.replace("{other_agents}", &descriptions.join("\n"))
268    }
269}
270
271#[async_trait]
272impl AgentMiddleware for SubAgentMiddleware {
273    fn id(&self) -> &'static str {
274        "subagent"
275    }
276
277    fn tools(&self) -> Vec<ToolBox> {
278        vec![self.task_tool.clone()]
279    }
280
281    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
282        ctx.request.append_prompt(TASK_SYSTEM_PROMPT);
283        ctx.request.append_prompt(&self.prompt_fragment());
284        Ok(())
285    }
286}
287
/// Approval policy for one tool, stored per tool name in `HumanInLoopMiddleware`.
#[derive(Clone, Debug)]
pub struct HitlPolicy {
    /// When `true`, the tool may run without human approval.
    pub allow_auto: bool,
    /// Optional note shown in the prompt and passed along with approval interrupts.
    pub note: Option<String>,
}

/// Middleware that gates selected tools behind human approval.
pub struct HumanInLoopMiddleware {
    policies: HashMap<String, HitlPolicy>,
}

impl HumanInLoopMiddleware {
    /// Creates the middleware from a map of tool name to policy.
    pub fn new(policies: HashMap<String, HitlPolicy>) -> Self {
        Self { policies }
    }

    /// Returns the policy for `tool_name` when that tool must not run automatically;
    /// `None` for unknown tools and for tools with `allow_auto == true`.
    pub fn requires_approval(&self, tool_name: &str) -> Option<&HitlPolicy> {
        self.policies
            .get(tool_name)
            .filter(|policy| !policy.allow_auto)
    }

    /// Builds the system-prompt paragraph listing tools that need approval, or
    /// `None` when every configured tool may run automatically.
    ///
    /// Entries are sorted by tool name. `HashMap` iteration order is randomized per
    /// process, so without sorting the system prompt would differ between runs
    /// (which also defeats prompt caching).
    fn prompt_fragment(&self) -> Option<String> {
        let mut gated: Vec<(&String, &HitlPolicy)> = self
            .policies
            .iter()
            .filter(|(_, policy)| !policy.allow_auto)
            .collect();
        if gated.is_empty() {
            return None;
        }
        gated.sort_unstable_by(|a, b| a.0.cmp(b.0));

        let lines: Vec<String> = gated
            .into_iter()
            .map(|(tool, policy)| match &policy.note {
                Some(note) => format!("- {tool}: {note}"),
                None => format!("- {tool}: Requires approval"),
            })
            .collect();
        Some(format!(
            "The following tools require human approval before execution:\n{}",
            lines.join("\n")
        ))
    }
}
329
330#[async_trait]
331impl AgentMiddleware for HumanInLoopMiddleware {
332    fn id(&self) -> &'static str {
333        "human-in-loop"
334    }
335
336    async fn before_tool_execution(
337        &self,
338        tool_name: &str,
339        tool_args: &serde_json::Value,
340        call_id: &str,
341    ) -> anyhow::Result<Option<agents_core::hitl::AgentInterrupt>> {
342        if let Some(policy) = self.requires_approval(tool_name) {
343            tracing::warn!(
344                tool_name = %tool_name,
345                call_id = %call_id,
346                policy_note = ?policy.note,
347                "🔒 HITL: Tool execution requires human approval"
348            );
349
350            let interrupt = agents_core::hitl::HitlInterrupt::new(
351                tool_name,
352                tool_args.clone(),
353                call_id,
354                policy.note.clone(),
355            );
356
357            return Ok(Some(agents_core::hitl::AgentInterrupt::HumanInLoop(
358                interrupt,
359            )));
360        }
361
362        Ok(None)
363    }
364
365    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
366        if let Some(fragment) = self.prompt_fragment() {
367            ctx.request.append_prompt(&fragment);
368        }
369        ctx.request.messages.push(AgentMessage {
370            role: MessageRole::System,
371            content: MessageContent::Text(
372                "Tools marked for human approval will emit interrupts requiring external resolution."
373                    .into(),
374            ),
375            metadata: None,
376        });
377        Ok(())
378    }
379}
380
381pub struct BaseSystemPromptMiddleware;
382
383#[async_trait]
384impl AgentMiddleware for BaseSystemPromptMiddleware {
385    fn id(&self) -> &'static str {
386        "base-system-prompt"
387    }
388
389    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
390        ctx.request.append_prompt(BASE_AGENT_PROMPT);
391        Ok(())
392    }
393}
394
395/// Deep Agent prompt middleware that injects comprehensive tool usage instructions
396/// and examples to force the LLM to actually call tools instead of just talking about them.
397///
398/// This middleware is inspired by Python's deepagents package and Claude Code's system prompt.
399/// It provides:
400/// - Explicit tool usage rules with imperative language
401/// - JSON or TOON examples of tool calling (configurable for token efficiency)
402/// - Workflow guidance for multi-step tasks
403/// - Few-shot examples for common patterns
404///
405/// The middleware supports three modes:
406/// 1. **JSON mode** (default): Uses JSON examples for tool calls
407/// 2. **TOON mode**: Uses TOON format for 30-60% token reduction
408/// 3. **Override mode**: Uses a completely custom system prompt, bypassing the default
409pub struct DeepAgentPromptMiddleware {
410    custom_instructions: String,
411    /// Format for tool call examples in the system prompt
412    prompt_format: crate::prompts::PromptFormat,
413    /// If set, this completely replaces the default Deep Agent system prompt
414    override_system_prompt: Option<String>,
415}
416
417impl DeepAgentPromptMiddleware {
418    pub fn new(custom_instructions: impl Into<String>) -> Self {
419        Self {
420            custom_instructions: custom_instructions.into(),
421            prompt_format: crate::prompts::PromptFormat::Json,
422            override_system_prompt: None,
423        }
424    }
425
426    /// Create a middleware with a specific prompt format (JSON or TOON).
427    ///
428    /// Use TOON format for 30-60% token reduction in system prompts.
429    /// See: <https://github.com/toon-format/toon>
430    pub fn with_format(
431        custom_instructions: impl Into<String>,
432        format: crate::prompts::PromptFormat,
433    ) -> Self {
434        Self {
435            custom_instructions: custom_instructions.into(),
436            prompt_format: format,
437            override_system_prompt: None,
438        }
439    }
440
441    /// Create a middleware with a completely custom system prompt that bypasses
442    /// the default Deep Agent prompt.
443    ///
444    /// Use this when you need full control over the agent's system prompt.
445    pub fn with_override(system_prompt: impl Into<String>) -> Self {
446        Self {
447            custom_instructions: String::new(),
448            prompt_format: crate::prompts::PromptFormat::Json,
449            override_system_prompt: Some(system_prompt.into()),
450        }
451    }
452}
453
454#[async_trait]
455impl AgentMiddleware for DeepAgentPromptMiddleware {
456    fn id(&self) -> &'static str {
457        "deep-agent-prompt"
458    }
459
460    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
461        let prompt = if let Some(ref override_prompt) = self.override_system_prompt {
462            // Use the custom system prompt directly, bypassing the Deep Agent prompt
463            override_prompt.clone()
464        } else {
465            // Use the formatted Deep Agent prompt based on prompt_format
466            use crate::prompts::get_deep_agent_system_prompt_formatted;
467            get_deep_agent_system_prompt_formatted(&self.custom_instructions, self.prompt_format)
468        };
469        ctx.request.append_prompt(&prompt);
470        Ok(())
471    }
472}
473
/// Anthropic-specific prompt caching middleware. Marks system prompts for caching
/// to reduce latency on subsequent requests with the same base prompt.
pub struct AnthropicPromptCachingMiddleware {
    /// Requested cache TTL such as `"5m"`; a zero or empty TTL disables caching.
    pub ttl: String,
    /// Label for how unsupported models should be treated; currently only logged.
    pub unsupported_model_behavior: String,
}

impl AnthropicPromptCachingMiddleware {
    /// Creates the middleware with an explicit TTL and unsupported-model behavior label.
    pub fn new(ttl: impl Into<String>, unsupported_model_behavior: impl Into<String>) -> Self {
        Self {
            ttl: ttl.into(),
            unsupported_model_behavior: unsupported_model_behavior.into(),
        }
    }

    /// Defaults: TTL `"5m"`, behavior `"ignore"`.
    pub fn with_defaults() -> Self {
        Self::new("5m", "ignore")
    }

    /// Decides whether caching is requested by the TTL string.
    ///
    /// Surrounding whitespace is ignored. A blank TTL, or one whose numeric part is
    /// zero regardless of unit (`"0"`, `"0s"`, `"0m"`, `"0h"`, ...), disables caching.
    /// Any other value — including strings that are not `<number><unit>` —
    /// enables ephemeral caching; the TTL itself is not forwarded to the API.
    fn should_enable_caching(&self) -> bool {
        let ttl = self.ttl.trim();
        if ttl.is_empty() {
            return false;
        }
        // Strip a trailing alphabetic unit suffix ("s", "m", "h", "ms", ...).
        let amount = ttl.trim_end_matches(|c: char| c.is_ascii_alphabetic());
        match amount.parse::<u64>() {
            Ok(value) => value != 0,
            // Unrecognized format: keep the permissive behavior and enable caching.
            Err(_) => true,
        }
    }
}
499
500#[async_trait]
501impl AgentMiddleware for AnthropicPromptCachingMiddleware {
502    fn id(&self) -> &'static str {
503        "anthropic-prompt-caching"
504    }
505
506    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
507        if !self.should_enable_caching() {
508            return Ok(());
509        }
510
511        // Mark system prompt for caching by converting it to a system message with cache control
512        if !ctx.request.system_prompt.is_empty() {
513            let system_message = AgentMessage {
514                role: MessageRole::System,
515                content: MessageContent::Text(ctx.request.system_prompt.clone()),
516                metadata: Some(MessageMetadata {
517                    tool_call_id: None,
518                    cache_control: Some(CacheControl {
519                        cache_type: "ephemeral".to_string(),
520                    }),
521                }),
522            };
523
524            // Insert system message at the beginning of the messages
525            ctx.request.messages.insert(0, system_message);
526
527            // Clear the system_prompt since it's now in messages
528            ctx.request.system_prompt.clear();
529
530            tracing::debug!(
531                ttl = %self.ttl,
532                behavior = %self.unsupported_model_behavior,
533                "Applied Anthropic prompt caching to system message"
534            );
535        }
536
537        Ok(())
538    }
539}
540
541pub struct TaskRouterTool {
542    registry: Arc<SubAgentRegistry>,
543    event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
544    delegation_depth: Arc<RwLock<u32>>,
545}
546
547impl TaskRouterTool {
548    fn new(
549        registry: Arc<SubAgentRegistry>,
550        event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
551    ) -> Self {
552        Self {
553            registry,
554            event_dispatcher,
555            delegation_depth: Arc::new(RwLock::new(0)),
556        }
557    }
558
559    fn available_subagents(&self) -> Vec<String> {
560        self.registry.available_names()
561    }
562
563    fn emit_event(&self, event: agents_core::events::AgentEvent) {
564        if let Some(dispatcher) = &self.event_dispatcher {
565            let dispatcher_clone = dispatcher.clone();
566            tokio::spawn(async move {
567                dispatcher_clone.dispatch(event).await;
568            });
569        }
570    }
571
572    fn create_event_metadata(&self) -> agents_core::events::EventMetadata {
573        agents_core::events::EventMetadata::new(
574            "default".to_string(),
575            uuid::Uuid::new_v4().to_string(),
576            None,
577        )
578    }
579
580    fn get_delegation_depth(&self) -> u32 {
581        *self.delegation_depth.read().unwrap_or_else(|_| {
582            tracing::warn!("Failed to read delegation depth, defaulting to 0");
583            panic!("RwLock poisoned")
584        })
585    }
586
587    fn increment_delegation_depth(&self) {
588        if let Ok(mut depth) = self.delegation_depth.write() {
589            *depth += 1;
590        }
591    }
592
593    fn decrement_delegation_depth(&self) {
594        if let Ok(mut depth) = self.delegation_depth.write() {
595            if *depth > 0 {
596                *depth -= 1;
597            }
598        }
599    }
600}
601
602#[derive(Debug, Clone, Deserialize)]
603struct TaskInvocationArgs {
604    #[serde(alias = "description")]
605    instruction: String,
606    #[serde(alias = "subagent_type")]
607    agent: String,
608}
609
610#[async_trait]
611impl Tool for TaskRouterTool {
612    fn schema(&self) -> agents_core::tools::ToolSchema {
613        use agents_core::tools::{ToolParameterSchema, ToolSchema};
614        use std::collections::HashMap;
615
616        let mut properties = HashMap::new();
617        properties.insert(
618            "agent".to_string(),
619            ToolParameterSchema::string("Name of the sub-agent to delegate to"),
620        );
621        properties.insert(
622            "instruction".to_string(),
623            ToolParameterSchema::string("Clear instruction for the sub-agent"),
624        );
625
626        ToolSchema::new(
627            "task",
628            "Delegate a task to a specialized sub-agent. Use this when you need specialized expertise or want to break down complex tasks.",
629            ToolParameterSchema::object(
630                "Task delegation parameters",
631                properties,
632                vec!["agent".to_string(), "instruction".to_string()],
633            ),
634        )
635    }
636
637    async fn execute(
638        &self,
639        args: serde_json::Value,
640        ctx: ToolContext,
641    ) -> anyhow::Result<ToolResult> {
642        let args: TaskInvocationArgs = serde_json::from_value(args)?;
643        let available = self.available_subagents();
644
645        if let Some(agent) = self.registry.get(&args.agent) {
646            // Increment delegation depth
647            self.increment_delegation_depth();
648            let current_depth = self.get_delegation_depth();
649
650            // Truncate instruction for event
651            let instruction_summary = if args.instruction.len() > 100 {
652                format!("{}...", &args.instruction[..100])
653            } else {
654                args.instruction.clone()
655            };
656
657            // Emit: SubAgentStarted event
658            self.emit_event(agents_core::events::AgentEvent::SubAgentStarted(
659                agents_core::events::SubAgentStartedEvent {
660                    metadata: self.create_event_metadata(),
661                    agent_name: args.agent.clone(),
662                    instruction_summary: instruction_summary.clone(),
663                    delegation_depth: current_depth,
664                },
665            ));
666
667            // Log delegation start
668            tracing::warn!(
669                "🎯 DELEGATING to sub-agent: {} (depth: {}) with instruction: {}",
670                args.agent,
671                current_depth,
672                args.instruction
673            );
674
675            let start_time = std::time::Instant::now();
676            let user_message = AgentMessage {
677                role: MessageRole::User,
678                content: MessageContent::Text(args.instruction.clone()),
679                metadata: None,
680            };
681
682            let response = agent
683                .handle_message(user_message, ctx.state.clone())
684                .await?;
685
686            // Calculate duration
687            let duration = start_time.elapsed();
688            let duration_ms = duration.as_millis() as u64;
689
690            // Create response preview
691            let response_preview = match &response.content {
692                MessageContent::Text(t) => {
693                    if t.len() > 100 {
694                        format!("{}...", &t[..100])
695                    } else {
696                        t.clone()
697                    }
698                }
699                MessageContent::Json(v) => {
700                    let json_str = v.to_string();
701                    if json_str.len() > 100 {
702                        format!("{}...", &json_str[..100])
703                    } else {
704                        json_str
705                    }
706                }
707            };
708
709            // Emit: SubAgentCompleted event
710            self.emit_event(agents_core::events::AgentEvent::SubAgentCompleted(
711                agents_core::events::SubAgentCompletedEvent {
712                    metadata: self.create_event_metadata(),
713                    agent_name: args.agent.clone(),
714                    duration_ms,
715                    result_summary: response_preview.clone(),
716                },
717            ));
718
719            // Log delegation completion
720            tracing::warn!(
721                "✅ SUB-AGENT {} COMPLETED in {:?} - Response: {}",
722                args.agent,
723                duration,
724                response_preview
725            );
726
727            // Decrement delegation depth
728            self.decrement_delegation_depth();
729
730            // Return sub-agent response as text content, not as a separate tool message
731            // This will be incorporated into the LLM's next response naturally
732            let result_text = match response.content {
733                MessageContent::Text(text) => text,
734                MessageContent::Json(json) => json.to_string(),
735            };
736
737            return Ok(ToolResult::text(&ctx, result_text));
738        }
739
740        tracing::error!(
741            "❌ SUB-AGENT NOT FOUND: {} - Available: {:?}",
742            args.agent,
743            available
744        );
745
746        Ok(ToolResult::text(
747            &ctx,
748            format!(
749                "Sub-agent '{}' not found. Available sub-agents: {}",
750                args.agent,
751                available.join(", ")
752            ),
753        ))
754    }
755}
756
/// Public description of a sub-agent as advertised in the `task` tool prompt.
#[derive(Debug, Clone)]
pub struct SubAgentDescriptor {
    /// Routing key the model passes as the `agent` argument.
    pub name: String,
    /// One-line summary rendered next to the name in the system prompt.
    pub description: String,
}
762
763#[cfg(test)]
764mod tests {
765    use super::*;
766    use agents_core::agent::{AgentDescriptor, AgentHandle};
767    use agents_core::messaging::{MessageContent, MessageRole};
768    use serde_json::json;
769
770    struct AppendPromptMiddleware;
771
772    #[async_trait]
773    impl AgentMiddleware for AppendPromptMiddleware {
774        fn id(&self) -> &'static str {
775            "append-prompt"
776        }
777
778        async fn modify_model_request(
779            &self,
780            ctx: &mut MiddlewareContext<'_>,
781        ) -> anyhow::Result<()> {
782            ctx.request.system_prompt.push_str("\nExtra directives.");
783            Ok(())
784        }
785    }
786
787    #[tokio::test]
788    async fn middleware_mutates_prompt() {
789        let mut request = ModelRequest::new(
790            "System",
791            vec![AgentMessage {
792                role: MessageRole::User,
793                content: MessageContent::Text("Hi".into()),
794                metadata: None,
795            }],
796        );
797        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
798        let mut ctx = MiddlewareContext::with_request(&mut request, state);
799        let middleware = AppendPromptMiddleware;
800        middleware.modify_model_request(&mut ctx).await.unwrap();
801        assert!(ctx.request.system_prompt.contains("Extra directives"));
802    }
803
804    #[tokio::test]
805    async fn planning_middleware_registers_write_todos() {
806        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
807        let middleware = PlanningMiddleware::new(state);
808        let tool_names: Vec<_> = middleware
809            .tools()
810            .iter()
811            .map(|t| t.schema().name.clone())
812            .collect();
813        assert!(tool_names.contains(&"write_todos".to_string()));
814
815        let mut request = ModelRequest::new("System", vec![]);
816        let mut ctx = MiddlewareContext::with_request(
817            &mut request,
818            Arc::new(RwLock::new(AgentStateSnapshot::default())),
819        );
820        middleware.modify_model_request(&mut ctx).await.unwrap();
821        assert!(ctx.request.system_prompt.contains("write_todos"));
822    }
823
824    #[tokio::test]
825    async fn filesystem_middleware_registers_tools() {
826        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
827        let middleware = FilesystemMiddleware::new(state);
828        let tool_names: Vec<_> = middleware
829            .tools()
830            .iter()
831            .map(|t| t.schema().name.clone())
832            .collect();
833        for expected in ["ls", "read_file", "write_file", "edit_file"] {
834            assert!(tool_names.contains(&expected.to_string()));
835        }
836    }
837
838    #[tokio::test]
839    async fn summarization_middleware_trims_messages() {
840        let middleware = SummarizationMiddleware::new(2, "Summary note");
841        let mut request = ModelRequest::new(
842            "System",
843            vec![
844                AgentMessage {
845                    role: MessageRole::User,
846                    content: MessageContent::Text("one".into()),
847                    metadata: None,
848                },
849                AgentMessage {
850                    role: MessageRole::Agent,
851                    content: MessageContent::Text("two".into()),
852                    metadata: None,
853                },
854                AgentMessage {
855                    role: MessageRole::User,
856                    content: MessageContent::Text("three".into()),
857                    metadata: None,
858                },
859            ],
860        );
861        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
862        let mut ctx = MiddlewareContext::with_request(&mut request, state);
863        middleware.modify_model_request(&mut ctx).await.unwrap();
864        assert_eq!(ctx.request.messages.len(), 3);
865        match &ctx.request.messages[0].content {
866            MessageContent::Text(text) => assert!(text.contains("Summary note")),
867            other => panic!("expected text, got {other:?}"),
868        }
869    }
870
871    struct StubAgent;
872
873    #[async_trait]
874    impl AgentHandle for StubAgent {
875        async fn describe(&self) -> AgentDescriptor {
876            AgentDescriptor {
877                name: "stub".into(),
878                version: "0.0.1".into(),
879                description: None,
880            }
881        }
882
883        async fn handle_message(
884            &self,
885            _input: AgentMessage,
886            _state: Arc<AgentStateSnapshot>,
887        ) -> anyhow::Result<AgentMessage> {
888            Ok(AgentMessage {
889                role: MessageRole::Agent,
890                content: MessageContent::Text("stub-response".into()),
891                metadata: None,
892            })
893        }
894    }
895
896    #[tokio::test]
897    async fn task_router_reports_unknown_subagent() {
898        let registry = Arc::new(SubAgentRegistry::new(vec![]));
899        let task_tool = TaskRouterTool::new(registry.clone(), None);
900        let state = Arc::new(AgentStateSnapshot::default());
901        let ctx = ToolContext::new(state);
902
903        let response = task_tool
904            .execute(
905                json!({
906                    "instruction": "Do something",
907                    "agent": "unknown"
908                }),
909                ctx,
910            )
911            .await
912            .unwrap();
913
914        match response {
915            ToolResult::Message(msg) => match msg.content {
916                MessageContent::Text(text) => {
917                    assert!(text.contains("Sub-agent 'unknown' not found"))
918                }
919                other => panic!("expected text, got {other:?}"),
920            },
921            _ => panic!("expected message"),
922        }
923    }
924
925    #[tokio::test]
926    async fn subagent_middleware_appends_prompt() {
927        let subagents = vec![SubAgentRegistration {
928            descriptor: SubAgentDescriptor {
929                name: "research-agent".into(),
930                description: "Deep research specialist".into(),
931            },
932            agent: Arc::new(StubAgent),
933        }];
934        let middleware = SubAgentMiddleware::new(subagents);
935
936        let mut request = ModelRequest::new("System", vec![]);
937        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
938        let mut ctx = MiddlewareContext::with_request(&mut request, state);
939        middleware.modify_model_request(&mut ctx).await.unwrap();
940
941        assert!(ctx.request.system_prompt.contains("research-agent"));
942        let tool_names: Vec<_> = middleware
943            .tools()
944            .iter()
945            .map(|t| t.schema().name.clone())
946            .collect();
947        assert!(tool_names.contains(&"task".to_string()));
948    }
949
950    #[tokio::test]
951    async fn task_router_invokes_registered_subagent() {
952        let registry = Arc::new(SubAgentRegistry::new(vec![SubAgentRegistration {
953            descriptor: SubAgentDescriptor {
954                name: "stub-agent".into(),
955                description: "Stub".into(),
956            },
957            agent: Arc::new(StubAgent),
958        }]));
959        let task_tool = TaskRouterTool::new(registry.clone(), None);
960        let state = Arc::new(AgentStateSnapshot::default());
961        let ctx = ToolContext::new(state).with_call_id(Some("call-42".into()));
962        let response = task_tool
963            .execute(
964                json!({
965                    "description": "do work",
966                    "subagent_type": "stub-agent"
967                }),
968                ctx,
969            )
970            .await
971            .unwrap();
972
973        match response {
974            ToolResult::Message(msg) => {
975                assert_eq!(msg.metadata.unwrap().tool_call_id.unwrap(), "call-42");
976                match msg.content {
977                    MessageContent::Text(text) => assert_eq!(text, "stub-response"),
978                    other => panic!("expected text, got {other:?}"),
979                }
980            }
981            _ => panic!("expected message"),
982        }
983    }
984
985    #[tokio::test]
986    async fn human_in_loop_appends_prompt() {
987        let middleware = HumanInLoopMiddleware::new(HashMap::from([(
988            "danger-tool".into(),
989            HitlPolicy {
990                allow_auto: false,
991                note: Some("Requires security review".into()),
992            },
993        )]));
994        let mut request = ModelRequest::new("System", vec![]);
995        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
996        let mut ctx = MiddlewareContext::with_request(&mut request, state);
997        middleware.modify_model_request(&mut ctx).await.unwrap();
998        assert!(ctx
999            .request
1000            .system_prompt
1001            .contains("danger-tool: Requires security review"));
1002    }
1003
1004    #[tokio::test]
1005    async fn anthropic_prompt_caching_moves_system_prompt_to_messages() {
1006        let middleware = AnthropicPromptCachingMiddleware::new("5m", "ignore");
1007        let mut request = ModelRequest::new(
1008            "This is the system prompt",
1009            vec![AgentMessage {
1010                role: MessageRole::User,
1011                content: MessageContent::Text("Hello".into()),
1012                metadata: None,
1013            }],
1014        );
1015        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
1016        let mut ctx = MiddlewareContext::with_request(&mut request, state);
1017
1018        // Apply the middleware
1019        middleware.modify_model_request(&mut ctx).await.unwrap();
1020
1021        // System prompt should be cleared
1022        assert!(ctx.request.system_prompt.is_empty());
1023
1024        // Should have added a system message with cache control at the beginning
1025        assert_eq!(ctx.request.messages.len(), 2);
1026
1027        let system_message = &ctx.request.messages[0];
1028        assert!(matches!(system_message.role, MessageRole::System));
1029        assert_eq!(
1030            system_message.content.as_text().unwrap(),
1031            "This is the system prompt"
1032        );
1033
1034        // Check cache control metadata
1035        let metadata = system_message.metadata.as_ref().unwrap();
1036        let cache_control = metadata.cache_control.as_ref().unwrap();
1037        assert_eq!(cache_control.cache_type, "ephemeral");
1038
1039        // Original user message should still be there
1040        let user_message = &ctx.request.messages[1];
1041        assert!(matches!(user_message.role, MessageRole::User));
1042        assert_eq!(user_message.content.as_text().unwrap(), "Hello");
1043    }
1044
1045    #[tokio::test]
1046    async fn anthropic_prompt_caching_disabled_with_zero_ttl() {
1047        let middleware = AnthropicPromptCachingMiddleware::new("0", "ignore");
1048        let mut request = ModelRequest::new("This is the system prompt", vec![]);
1049        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
1050        let mut ctx = MiddlewareContext::with_request(&mut request, state);
1051
1052        // Apply the middleware
1053        middleware.modify_model_request(&mut ctx).await.unwrap();
1054
1055        // System prompt should be unchanged
1056        assert_eq!(ctx.request.system_prompt, "This is the system prompt");
1057        assert_eq!(ctx.request.messages.len(), 0);
1058    }
1059
1060    #[tokio::test]
1061    async fn anthropic_prompt_caching_no_op_with_empty_system_prompt() {
1062        let middleware = AnthropicPromptCachingMiddleware::new("5m", "ignore");
1063        let mut request = ModelRequest::new(
1064            "",
1065            vec![AgentMessage {
1066                role: MessageRole::User,
1067                content: MessageContent::Text("Hello".into()),
1068                metadata: None,
1069            }],
1070        );
1071        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
1072        let mut ctx = MiddlewareContext::with_request(&mut request, state);
1073
1074        // Apply the middleware
1075        middleware.modify_model_request(&mut ctx).await.unwrap();
1076
1077        // System prompt should remain empty
1078        assert!(ctx.request.system_prompt.is_empty());
1079        // No system message should be added
1080        assert_eq!(ctx.request.messages.len(), 1);
1081    }
1082
1083    // ========== HITL Interrupt Creation Tests ==========
1084
1085    #[tokio::test]
1086    async fn hitl_creates_interrupt_for_disallowed_tool() {
1087        let mut policies = HashMap::new();
1088        policies.insert(
1089            "dangerous_tool".to_string(),
1090            HitlPolicy {
1091                allow_auto: false,
1092                note: Some("Requires security review".to_string()),
1093            },
1094        );
1095
1096        let middleware = HumanInLoopMiddleware::new(policies);
1097        let tool_args = json!({"action": "delete_all"});
1098
1099        let result = middleware
1100            .before_tool_execution("dangerous_tool", &tool_args, "call_123")
1101            .await
1102            .unwrap();
1103
1104        assert!(result.is_some());
1105        let interrupt = result.unwrap();
1106
1107        match interrupt {
1108            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
1109                assert_eq!(hitl.tool_name, "dangerous_tool");
1110                assert_eq!(hitl.tool_args, tool_args);
1111                assert_eq!(hitl.call_id, "call_123");
1112                assert_eq!(
1113                    hitl.policy_note,
1114                    Some("Requires security review".to_string())
1115                );
1116            }
1117        }
1118    }
1119
1120    #[tokio::test]
1121    async fn hitl_no_interrupt_for_allowed_tool() {
1122        let mut policies = HashMap::new();
1123        policies.insert(
1124            "safe_tool".to_string(),
1125            HitlPolicy {
1126                allow_auto: true,
1127                note: None,
1128            },
1129        );
1130
1131        let middleware = HumanInLoopMiddleware::new(policies);
1132        let tool_args = json!({"action": "read"});
1133
1134        let result = middleware
1135            .before_tool_execution("safe_tool", &tool_args, "call_456")
1136            .await
1137            .unwrap();
1138
1139        assert!(result.is_none());
1140    }
1141
1142    #[tokio::test]
1143    async fn hitl_no_interrupt_for_unlisted_tool() {
1144        let policies = HashMap::new();
1145        let middleware = HumanInLoopMiddleware::new(policies);
1146        let tool_args = json!({"action": "anything"});
1147
1148        let result = middleware
1149            .before_tool_execution("unlisted_tool", &tool_args, "call_789")
1150            .await
1151            .unwrap();
1152
1153        assert!(result.is_none());
1154    }
1155
1156    #[tokio::test]
1157    async fn hitl_interrupt_includes_correct_details() {
1158        let mut policies = HashMap::new();
1159        policies.insert(
1160            "critical_tool".to_string(),
1161            HitlPolicy {
1162                allow_auto: false,
1163                note: Some("Critical operation - requires approval".to_string()),
1164            },
1165        );
1166
1167        let middleware = HumanInLoopMiddleware::new(policies);
1168        let tool_args = json!({
1169            "database": "production",
1170            "operation": "drop_table"
1171        });
1172
1173        let result = middleware
1174            .before_tool_execution("critical_tool", &tool_args, "call_critical_1")
1175            .await
1176            .unwrap();
1177
1178        assert!(result.is_some());
1179        let interrupt = result.unwrap();
1180
1181        match interrupt {
1182            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
1183                assert_eq!(hitl.tool_name, "critical_tool");
1184                assert_eq!(hitl.tool_args["database"], "production");
1185                assert_eq!(hitl.tool_args["operation"], "drop_table");
1186                assert_eq!(hitl.call_id, "call_critical_1");
1187                assert!(hitl.policy_note.is_some());
1188                assert!(hitl.policy_note.unwrap().contains("Critical operation"));
1189                // Verify timestamp exists (created_at field is populated)
1190                // The actual timestamp value is tested in agents-core/hitl.rs
1191            }
1192        }
1193    }
1194
1195    #[tokio::test]
1196    async fn hitl_interrupt_without_policy_note() {
1197        let mut policies = HashMap::new();
1198        policies.insert(
1199            "tool_no_note".to_string(),
1200            HitlPolicy {
1201                allow_auto: false,
1202                note: None,
1203            },
1204        );
1205
1206        let middleware = HumanInLoopMiddleware::new(policies);
1207        let tool_args = json!({"param": "value"});
1208
1209        let result = middleware
1210            .before_tool_execution("tool_no_note", &tool_args, "call_no_note")
1211            .await
1212            .unwrap();
1213
1214        assert!(result.is_some());
1215        let interrupt = result.unwrap();
1216
1217        match interrupt {
1218            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
1219                assert_eq!(hitl.tool_name, "tool_no_note");
1220                assert_eq!(hitl.policy_note, None);
1221            }
1222        }
1223    }
1224}