agents_runtime/
middleware.rs

1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3
4use agents_core::agent::AgentHandle;
5use agents_core::messaging::{
6    AgentMessage, CacheControl, MessageContent, MessageMetadata, MessageRole,
7};
8use agents_core::prompts::{
9    BASE_AGENT_PROMPT, FILESYSTEM_SYSTEM_PROMPT, TASK_SYSTEM_PROMPT, TASK_TOOL_DESCRIPTION,
10    WRITE_TODOS_SYSTEM_PROMPT,
11};
12use agents_core::state::AgentStateSnapshot;
13use agents_core::tools::{Tool, ToolBox, ToolContext, ToolResult};
14use agents_toolkit::create_filesystem_tools;
15use async_trait::async_trait;
16use serde::Deserialize;
17
18pub mod token_tracking;
19
20/// Request sent to the underlying language model. Middlewares can augment
21/// the system prompt or mutate the pending message list before the model call.
22#[derive(Debug, Clone)]
23pub struct ModelRequest {
24    pub system_prompt: String,
25    pub messages: Vec<AgentMessage>,
26}
27
28impl ModelRequest {
29    pub fn new(system_prompt: impl Into<String>, messages: Vec<AgentMessage>) -> Self {
30        Self {
31            system_prompt: system_prompt.into(),
32            messages,
33        }
34    }
35
36    pub fn append_prompt(&mut self, fragment: &str) {
37        if !fragment.is_empty() {
38            self.system_prompt.push_str("\n\n");
39            self.system_prompt.push_str(fragment);
40        }
41    }
42}
43
44/// Read/write state handle exposed to middleware implementations.
45pub struct MiddlewareContext<'a> {
46    pub request: &'a mut ModelRequest,
47    pub state: Arc<RwLock<AgentStateSnapshot>>,
48}
49
50impl<'a> MiddlewareContext<'a> {
51    pub fn with_request(
52        request: &'a mut ModelRequest,
53        state: Arc<RwLock<AgentStateSnapshot>>,
54    ) -> Self {
55        Self { request, state }
56    }
57}
58
/// Middleware hook that can register additional tools and mutate the model request
/// prior to execution. Mirrors the Python AgentMiddleware contract but keeps the
/// interface async-first for future network calls.
///
/// Implementors must be `Send + Sync`. Only [`AgentMiddleware::id`] and
/// [`AgentMiddleware::modify_model_request`] are required; the other methods
/// have no-op defaults.
#[async_trait]
pub trait AgentMiddleware: Send + Sync {
    /// Unique identifier for logging and diagnostics.
    fn id(&self) -> &'static str;

    /// Tools to expose when this middleware is active.
    ///
    /// The default implementation contributes no tools.
    fn tools(&self) -> Vec<ToolBox> {
        Vec::new()
    }

    /// Apply middleware-specific mutations to the pending model request.
    ///
    /// The request is reachable through `ctx.request` and the shared agent
    /// state through `ctx.state`.
    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()>;

    /// Hook called before tool execution - can return an interrupt to pause execution.
    ///
    /// This hook is invoked for each tool call before it executes, allowing middleware
    /// to intercept and pause execution for human review. If an interrupt is returned,
    /// the agent will save its state and wait for human approval before continuing.
    ///
    /// The default implementation never interrupts.
    ///
    /// # Arguments
    /// * `tool_name` - Name of the tool about to be executed
    /// * `tool_args` - Arguments that will be passed to the tool
    /// * `call_id` - Unique identifier for this tool call
    ///
    /// # Returns
    /// * `Ok(Some(interrupt))` - Pause execution and wait for human response
    /// * `Ok(None)` - Continue with tool execution normally
    /// * `Err(e)` - Error occurred during interrupt check
    async fn before_tool_execution(
        &self,
        _tool_name: &str,
        _tool_args: &serde_json::Value,
        _call_id: &str,
    ) -> anyhow::Result<Option<agents_core::hitl::AgentInterrupt>> {
        Ok(None)
    }
}
99
/// Configuration for trimming long conversations down to their most recent
/// messages, with a note standing in for what was removed.
pub struct SummarizationMiddleware {
    /// Number of trailing messages retained verbatim.
    pub messages_to_keep: usize,
    /// Text placed at the start of the replacement summary message.
    pub summary_note: String,
}

impl SummarizationMiddleware {
    /// Creates the middleware from a retention count and a summary note.
    pub fn new(messages_to_keep: usize, summary_note: impl Into<String>) -> Self {
        let summary_note = summary_note.into();
        SummarizationMiddleware {
            messages_to_keep,
            summary_note,
        }
    }
}
113
114#[async_trait]
115impl AgentMiddleware for SummarizationMiddleware {
116    fn id(&self) -> &'static str {
117        "summarization"
118    }
119
120    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
121        if ctx.request.messages.len() > self.messages_to_keep {
122            let dropped = ctx.request.messages.len() - self.messages_to_keep;
123            let mut truncated = ctx
124                .request
125                .messages
126                .split_off(ctx.request.messages.len() - self.messages_to_keep);
127            truncated.insert(
128                0,
129                AgentMessage {
130                    role: MessageRole::System,
131                    content: MessageContent::Text(format!(
132                        "{} ({} earlier messages summarized)",
133                        self.summary_note, dropped
134                    )),
135                    metadata: None,
136                },
137            );
138            ctx.request.messages = truncated;
139        }
140        Ok(())
141    }
142}
143
144pub struct PlanningMiddleware {
145    _state: Arc<RwLock<AgentStateSnapshot>>,
146}
147
148impl PlanningMiddleware {
149    pub fn new(state: Arc<RwLock<AgentStateSnapshot>>) -> Self {
150        Self { _state: state }
151    }
152}
153
154#[async_trait]
155impl AgentMiddleware for PlanningMiddleware {
156    fn id(&self) -> &'static str {
157        "planning"
158    }
159
160    fn tools(&self) -> Vec<ToolBox> {
161        use agents_toolkit::create_todos_tools;
162        create_todos_tools()
163    }
164
165    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
166        ctx.request.append_prompt(WRITE_TODOS_SYSTEM_PROMPT);
167        Ok(())
168    }
169}
170
171pub struct FilesystemMiddleware {
172    _state: Arc<RwLock<AgentStateSnapshot>>,
173}
174
175impl FilesystemMiddleware {
176    pub fn new(state: Arc<RwLock<AgentStateSnapshot>>) -> Self {
177        Self { _state: state }
178    }
179}
180
181#[async_trait]
182impl AgentMiddleware for FilesystemMiddleware {
183    fn id(&self) -> &'static str {
184        "filesystem"
185    }
186
187    fn tools(&self) -> Vec<ToolBox> {
188        create_filesystem_tools()
189    }
190
191    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
192        ctx.request.append_prompt(FILESYSTEM_SYSTEM_PROMPT);
193        Ok(())
194    }
195}
196
/// Pairs a sub-agent's descriptor with the handle used to invoke it.
#[derive(Clone)]
pub struct SubAgentRegistration {
    /// Name and description of the sub-agent.
    pub descriptor: SubAgentDescriptor,
    /// Shared handle to the sub-agent implementation.
    pub agent: Arc<dyn AgentHandle>,
}
202
203struct SubAgentRegistry {
204    agents: HashMap<String, Arc<dyn AgentHandle>>,
205}
206
207impl SubAgentRegistry {
208    fn new(registrations: Vec<SubAgentRegistration>) -> Self {
209        let mut agents = HashMap::new();
210        for reg in registrations {
211            agents.insert(reg.descriptor.name.clone(), reg.agent.clone());
212        }
213        Self { agents }
214    }
215
216    fn available_names(&self) -> Vec<String> {
217        self.agents.keys().cloned().collect()
218    }
219
220    fn get(&self, name: &str) -> Option<Arc<dyn AgentHandle>> {
221        self.agents.get(name).cloned()
222    }
223}
224
225pub struct SubAgentMiddleware {
226    task_tool: ToolBox,
227    descriptors: Vec<SubAgentDescriptor>,
228    _registry: Arc<SubAgentRegistry>,
229}
230
231impl SubAgentMiddleware {
232    pub fn new(registrations: Vec<SubAgentRegistration>) -> Self {
233        let descriptors = registrations.iter().map(|r| r.descriptor.clone()).collect();
234        let registry = Arc::new(SubAgentRegistry::new(registrations));
235        let task_tool: ToolBox = Arc::new(TaskRouterTool::new(registry.clone(), None));
236        Self {
237            task_tool,
238            descriptors,
239            _registry: registry,
240        }
241    }
242
243    pub fn new_with_events(
244        registrations: Vec<SubAgentRegistration>,
245        event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
246    ) -> Self {
247        let descriptors = registrations.iter().map(|r| r.descriptor.clone()).collect();
248        let registry = Arc::new(SubAgentRegistry::new(registrations));
249        let task_tool: ToolBox = Arc::new(TaskRouterTool::new(registry.clone(), event_dispatcher));
250        Self {
251            task_tool,
252            descriptors,
253            _registry: registry,
254        }
255    }
256
257    fn prompt_fragment(&self) -> String {
258        let descriptions: Vec<String> = if self.descriptors.is_empty() {
259            vec![String::from("- general-purpose: Default reasoning agent")]
260        } else {
261            self.descriptors
262                .iter()
263                .map(|agent| format!("- {}: {}", agent.name, agent.description))
264                .collect()
265        };
266
267        TASK_TOOL_DESCRIPTION.replace("{other_agents}", &descriptions.join("\n"))
268    }
269}
270
271#[async_trait]
272impl AgentMiddleware for SubAgentMiddleware {
273    fn id(&self) -> &'static str {
274        "subagent"
275    }
276
277    fn tools(&self) -> Vec<ToolBox> {
278        vec![self.task_tool.clone()]
279    }
280
281    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
282        ctx.request.append_prompt(TASK_SYSTEM_PROMPT);
283        ctx.request.append_prompt(&self.prompt_fragment());
284        Ok(())
285    }
286}
287
/// Approval policy attached to a single tool name.
#[derive(Clone, Debug)]
pub struct HitlPolicy {
    /// When `true`, the tool is not gated on human approval.
    pub allow_auto: bool,
    /// Optional note included in the prompt and in the interrupt.
    pub note: Option<String>,
}

/// Middleware that gates selected tools behind human approval.
pub struct HumanInLoopMiddleware {
    // Keyed by tool name.
    policies: HashMap<String, HitlPolicy>,
}

impl HumanInLoopMiddleware {
    /// Creates the middleware from a map of tool name to policy.
    pub fn new(policies: HashMap<String, HitlPolicy>) -> Self {
        Self { policies }
    }

    /// Returns the policy for `tool_name` when that tool needs approval.
    ///
    /// Yields `None` when the tool has no policy or its policy sets
    /// `allow_auto`.
    pub fn requires_approval(&self, tool_name: &str) -> Option<&HitlPolicy> {
        self.policies
            .get(tool_name)
            .filter(|policy| !policy.allow_auto)
    }

    /// Builds the system-prompt paragraph listing every tool that needs
    /// approval, one `- tool: note` line each, or `None` when no tool does.
    ///
    /// Lines are sorted by tool name. `HashMap` iteration order is
    /// unspecified, so without sorting the prompt text would differ from run
    /// to run for the same configuration.
    fn prompt_fragment(&self) -> Option<String> {
        let mut gated: Vec<(&str, &HitlPolicy)> = self
            .policies
            .iter()
            .filter(|(_, policy)| !policy.allow_auto)
            .map(|(tool, policy)| (tool.as_str(), policy))
            .collect();
        if gated.is_empty() {
            return None;
        }
        gated.sort_unstable_by(|a, b| a.0.cmp(b.0));

        let pending: Vec<String> = gated
            .into_iter()
            .map(|(tool, policy)| match &policy.note {
                Some(note) => format!("- {tool}: {note}"),
                None => format!("- {tool}: Requires approval"),
            })
            .collect();

        Some(format!(
            "The following tools require human approval before execution:\n{}",
            pending.join("\n")
        ))
    }
}
329
330#[async_trait]
331impl AgentMiddleware for HumanInLoopMiddleware {
332    fn id(&self) -> &'static str {
333        "human-in-loop"
334    }
335
336    async fn before_tool_execution(
337        &self,
338        tool_name: &str,
339        tool_args: &serde_json::Value,
340        call_id: &str,
341    ) -> anyhow::Result<Option<agents_core::hitl::AgentInterrupt>> {
342        if let Some(policy) = self.requires_approval(tool_name) {
343            tracing::warn!(
344                tool_name = %tool_name,
345                call_id = %call_id,
346                policy_note = ?policy.note,
347                "🔒 HITL: Tool execution requires human approval"
348            );
349
350            let interrupt = agents_core::hitl::HitlInterrupt::new(
351                tool_name,
352                tool_args.clone(),
353                call_id,
354                policy.note.clone(),
355            );
356
357            return Ok(Some(agents_core::hitl::AgentInterrupt::HumanInLoop(
358                interrupt,
359            )));
360        }
361
362        Ok(None)
363    }
364
365    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
366        if let Some(fragment) = self.prompt_fragment() {
367            ctx.request.append_prompt(&fragment);
368        }
369        ctx.request.messages.push(AgentMessage {
370            role: MessageRole::System,
371            content: MessageContent::Text(
372                "Tools marked for human approval will emit interrupts requiring external resolution."
373                    .into(),
374            ),
375            metadata: None,
376        });
377        Ok(())
378    }
379}
380
381pub struct BaseSystemPromptMiddleware;
382
383#[async_trait]
384impl AgentMiddleware for BaseSystemPromptMiddleware {
385    fn id(&self) -> &'static str {
386        "base-system-prompt"
387    }
388
389    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
390        ctx.request.append_prompt(BASE_AGENT_PROMPT);
391        Ok(())
392    }
393}
394
/// Deep Agent prompt middleware that injects comprehensive tool usage instructions
/// and examples so the LLM actually calls tools instead of only describing them.
///
/// Inspired by Python's deepagents package and Claude Code's system prompt.
/// It provides:
/// - Explicit tool usage rules with imperative language
/// - JSON examples of tool calling
/// - Workflow guidance for multi-step tasks
/// - Few-shot examples for common patterns
///
/// Two modes are supported:
/// 1. **Default mode**: custom instructions are combined with the Deep Agent system prompt
/// 2. **Override mode**: a fully custom system prompt is used and the default is bypassed
pub struct DeepAgentPromptMiddleware {
    /// Instructions passed to the default prompt builder (default mode only).
    custom_instructions: String,
    /// When present, replaces the default Deep Agent system prompt entirely.
    override_system_prompt: Option<String>,
}

impl DeepAgentPromptMiddleware {
    /// Default mode: keeps the Deep Agent prompt and supplies `custom_instructions` to it.
    pub fn new(custom_instructions: impl Into<String>) -> Self {
        let custom_instructions = custom_instructions.into();
        DeepAgentPromptMiddleware {
            custom_instructions,
            override_system_prompt: None,
        }
    }

    /// Override mode: uses `system_prompt` verbatim and bypasses the default
    /// Deep Agent prompt.
    ///
    /// Use this when full control over the agent's system prompt is needed.
    pub fn with_override(system_prompt: impl Into<String>) -> Self {
        let replacement: String = system_prompt.into();
        DeepAgentPromptMiddleware {
            custom_instructions: String::new(),
            override_system_prompt: Some(replacement),
        }
    }
}
433
434#[async_trait]
435impl AgentMiddleware for DeepAgentPromptMiddleware {
436    fn id(&self) -> &'static str {
437        "deep-agent-prompt"
438    }
439
440    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
441        let prompt = if let Some(ref override_prompt) = self.override_system_prompt {
442            // Use the custom system prompt directly, bypassing the Deep Agent prompt
443            override_prompt.clone()
444        } else {
445            // Use the default Deep Agent prompt with custom instructions prepended
446            use crate::prompts::get_deep_agent_system_prompt;
447            get_deep_agent_system_prompt(&self.custom_instructions)
448        };
449        ctx.request.append_prompt(&prompt);
450        Ok(())
451    }
452}
453
/// Anthropic-specific prompt caching middleware. Marks the system prompt for
/// caching to reduce latency on subsequent requests with the same base prompt.
pub struct AnthropicPromptCachingMiddleware {
    /// Cache time-to-live string such as `"5m"`; see `should_enable_caching`.
    pub ttl: String,
    /// Configured behavior name for unsupported models (default `"ignore"`).
    pub unsupported_model_behavior: String,
}

impl AnthropicPromptCachingMiddleware {
    /// Creates the middleware from a TTL and an unsupported-model behavior.
    pub fn new(ttl: impl Into<String>, unsupported_model_behavior: impl Into<String>) -> Self {
        let ttl = ttl.into();
        let unsupported_model_behavior = unsupported_model_behavior.into();
        AnthropicPromptCachingMiddleware {
            ttl,
            unsupported_model_behavior,
        }
    }

    /// Creates the middleware with a `"5m"` TTL and `"ignore"` behavior.
    pub fn with_defaults() -> Self {
        Self::new("5m", "ignore")
    }

    /// Reports whether caching is requested.
    ///
    /// The TTL is not parsed: caching is off only when the TTL is empty,
    /// `"0"`, or `"0s"`; any other value turns it on.
    fn should_enable_caching(&self) -> bool {
        match self.ttl.as_str() {
            "" | "0" | "0s" => false,
            _ => true,
        }
    }
}
479
480#[async_trait]
481impl AgentMiddleware for AnthropicPromptCachingMiddleware {
482    fn id(&self) -> &'static str {
483        "anthropic-prompt-caching"
484    }
485
486    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
487        if !self.should_enable_caching() {
488            return Ok(());
489        }
490
491        // Mark system prompt for caching by converting it to a system message with cache control
492        if !ctx.request.system_prompt.is_empty() {
493            let system_message = AgentMessage {
494                role: MessageRole::System,
495                content: MessageContent::Text(ctx.request.system_prompt.clone()),
496                metadata: Some(MessageMetadata {
497                    tool_call_id: None,
498                    cache_control: Some(CacheControl {
499                        cache_type: "ephemeral".to_string(),
500                    }),
501                }),
502            };
503
504            // Insert system message at the beginning of the messages
505            ctx.request.messages.insert(0, system_message);
506
507            // Clear the system_prompt since it's now in messages
508            ctx.request.system_prompt.clear();
509
510            tracing::debug!(
511                ttl = %self.ttl,
512                behavior = %self.unsupported_model_behavior,
513                "Applied Anthropic prompt caching to system message"
514            );
515        }
516
517        Ok(())
518    }
519}
520
521pub struct TaskRouterTool {
522    registry: Arc<SubAgentRegistry>,
523    event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
524    delegation_depth: Arc<RwLock<u32>>,
525}
526
527impl TaskRouterTool {
528    fn new(
529        registry: Arc<SubAgentRegistry>,
530        event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
531    ) -> Self {
532        Self {
533            registry,
534            event_dispatcher,
535            delegation_depth: Arc::new(RwLock::new(0)),
536        }
537    }
538
539    fn available_subagents(&self) -> Vec<String> {
540        self.registry.available_names()
541    }
542
543    fn emit_event(&self, event: agents_core::events::AgentEvent) {
544        if let Some(dispatcher) = &self.event_dispatcher {
545            let dispatcher_clone = dispatcher.clone();
546            tokio::spawn(async move {
547                dispatcher_clone.dispatch(event).await;
548            });
549        }
550    }
551
552    fn create_event_metadata(&self) -> agents_core::events::EventMetadata {
553        agents_core::events::EventMetadata::new(
554            "default".to_string(),
555            uuid::Uuid::new_v4().to_string(),
556            None,
557        )
558    }
559
560    fn get_delegation_depth(&self) -> u32 {
561        *self.delegation_depth.read().unwrap_or_else(|_| {
562            tracing::warn!("Failed to read delegation depth, defaulting to 0");
563            panic!("RwLock poisoned")
564        })
565    }
566
567    fn increment_delegation_depth(&self) {
568        if let Ok(mut depth) = self.delegation_depth.write() {
569            *depth += 1;
570        }
571    }
572
573    fn decrement_delegation_depth(&self) {
574        if let Ok(mut depth) = self.delegation_depth.write() {
575            if *depth > 0 {
576                *depth -= 1;
577            }
578        }
579    }
580}
581
/// Arguments accepted by the `task` tool, deserialized from the tool-call JSON.
#[derive(Debug, Clone, Deserialize)]
struct TaskInvocationArgs {
    /// Instruction forwarded to the sub-agent; the key `description` is also accepted.
    #[serde(alias = "description")]
    instruction: String,
    /// Name of the target sub-agent; the key `subagent_type` is also accepted.
    #[serde(alias = "subagent_type")]
    agent: String,
}
589
590#[async_trait]
591impl Tool for TaskRouterTool {
592    fn schema(&self) -> agents_core::tools::ToolSchema {
593        use agents_core::tools::{ToolParameterSchema, ToolSchema};
594        use std::collections::HashMap;
595
596        let mut properties = HashMap::new();
597        properties.insert(
598            "agent".to_string(),
599            ToolParameterSchema::string("Name of the sub-agent to delegate to"),
600        );
601        properties.insert(
602            "instruction".to_string(),
603            ToolParameterSchema::string("Clear instruction for the sub-agent"),
604        );
605
606        ToolSchema::new(
607            "task",
608            "Delegate a task to a specialized sub-agent. Use this when you need specialized expertise or want to break down complex tasks.",
609            ToolParameterSchema::object(
610                "Task delegation parameters",
611                properties,
612                vec!["agent".to_string(), "instruction".to_string()],
613            ),
614        )
615    }
616
617    async fn execute(
618        &self,
619        args: serde_json::Value,
620        ctx: ToolContext,
621    ) -> anyhow::Result<ToolResult> {
622        let args: TaskInvocationArgs = serde_json::from_value(args)?;
623        let available = self.available_subagents();
624
625        if let Some(agent) = self.registry.get(&args.agent) {
626            // Increment delegation depth
627            self.increment_delegation_depth();
628            let current_depth = self.get_delegation_depth();
629
630            // Truncate instruction for event
631            let instruction_summary = if args.instruction.len() > 100 {
632                format!("{}...", &args.instruction[..100])
633            } else {
634                args.instruction.clone()
635            };
636
637            // Emit: SubAgentStarted event
638            self.emit_event(agents_core::events::AgentEvent::SubAgentStarted(
639                agents_core::events::SubAgentStartedEvent {
640                    metadata: self.create_event_metadata(),
641                    agent_name: args.agent.clone(),
642                    instruction_summary: instruction_summary.clone(),
643                    delegation_depth: current_depth,
644                },
645            ));
646
647            // Log delegation start
648            tracing::warn!(
649                "🎯 DELEGATING to sub-agent: {} (depth: {}) with instruction: {}",
650                args.agent,
651                current_depth,
652                args.instruction
653            );
654
655            let start_time = std::time::Instant::now();
656            let user_message = AgentMessage {
657                role: MessageRole::User,
658                content: MessageContent::Text(args.instruction.clone()),
659                metadata: None,
660            };
661
662            let response = agent
663                .handle_message(user_message, ctx.state.clone())
664                .await?;
665
666            // Calculate duration
667            let duration = start_time.elapsed();
668            let duration_ms = duration.as_millis() as u64;
669
670            // Create response preview
671            let response_preview = match &response.content {
672                MessageContent::Text(t) => {
673                    if t.len() > 100 {
674                        format!("{}...", &t[..100])
675                    } else {
676                        t.clone()
677                    }
678                }
679                MessageContent::Json(v) => {
680                    let json_str = v.to_string();
681                    if json_str.len() > 100 {
682                        format!("{}...", &json_str[..100])
683                    } else {
684                        json_str
685                    }
686                }
687            };
688
689            // Emit: SubAgentCompleted event
690            self.emit_event(agents_core::events::AgentEvent::SubAgentCompleted(
691                agents_core::events::SubAgentCompletedEvent {
692                    metadata: self.create_event_metadata(),
693                    agent_name: args.agent.clone(),
694                    duration_ms,
695                    result_summary: response_preview.clone(),
696                },
697            ));
698
699            // Log delegation completion
700            tracing::warn!(
701                "✅ SUB-AGENT {} COMPLETED in {:?} - Response: {}",
702                args.agent,
703                duration,
704                response_preview
705            );
706
707            // Decrement delegation depth
708            self.decrement_delegation_depth();
709
710            // Return sub-agent response as text content, not as a separate tool message
711            // This will be incorporated into the LLM's next response naturally
712            let result_text = match response.content {
713                MessageContent::Text(text) => text,
714                MessageContent::Json(json) => json.to_string(),
715            };
716
717            return Ok(ToolResult::text(&ctx, result_text));
718        }
719
720        tracing::error!(
721            "❌ SUB-AGENT NOT FOUND: {} - Available: {:?}",
722            args.agent,
723            available
724        );
725
726        Ok(ToolResult::text(
727            &ctx,
728            format!(
729                "Sub-agent '{}' not found. Available sub-agents: {}",
730                args.agent,
731                available.join(", ")
732            ),
733        ))
734    }
735}
736
/// Name and human-readable description of a sub-agent.
#[derive(Debug, Clone)]
pub struct SubAgentDescriptor {
    /// Name used to register and look up the sub-agent.
    pub name: String,
    /// Description rendered next to the name in the task tool prompt.
    pub description: String,
}
742
743#[cfg(test)]
744mod tests {
745    use super::*;
746    use agents_core::agent::{AgentDescriptor, AgentHandle};
747    use agents_core::messaging::{MessageContent, MessageRole};
748    use serde_json::json;
749
750    struct AppendPromptMiddleware;
751
752    #[async_trait]
753    impl AgentMiddleware for AppendPromptMiddleware {
754        fn id(&self) -> &'static str {
755            "append-prompt"
756        }
757
758        async fn modify_model_request(
759            &self,
760            ctx: &mut MiddlewareContext<'_>,
761        ) -> anyhow::Result<()> {
762            ctx.request.system_prompt.push_str("\nExtra directives.");
763            Ok(())
764        }
765    }
766
767    #[tokio::test]
768    async fn middleware_mutates_prompt() {
769        let mut request = ModelRequest::new(
770            "System",
771            vec![AgentMessage {
772                role: MessageRole::User,
773                content: MessageContent::Text("Hi".into()),
774                metadata: None,
775            }],
776        );
777        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
778        let mut ctx = MiddlewareContext::with_request(&mut request, state);
779        let middleware = AppendPromptMiddleware;
780        middleware.modify_model_request(&mut ctx).await.unwrap();
781        assert!(ctx.request.system_prompt.contains("Extra directives"));
782    }
783
784    #[tokio::test]
785    async fn planning_middleware_registers_write_todos() {
786        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
787        let middleware = PlanningMiddleware::new(state);
788        let tool_names: Vec<_> = middleware
789            .tools()
790            .iter()
791            .map(|t| t.schema().name.clone())
792            .collect();
793        assert!(tool_names.contains(&"write_todos".to_string()));
794
795        let mut request = ModelRequest::new("System", vec![]);
796        let mut ctx = MiddlewareContext::with_request(
797            &mut request,
798            Arc::new(RwLock::new(AgentStateSnapshot::default())),
799        );
800        middleware.modify_model_request(&mut ctx).await.unwrap();
801        assert!(ctx.request.system_prompt.contains("write_todos"));
802    }
803
804    #[tokio::test]
805    async fn filesystem_middleware_registers_tools() {
806        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
807        let middleware = FilesystemMiddleware::new(state);
808        let tool_names: Vec<_> = middleware
809            .tools()
810            .iter()
811            .map(|t| t.schema().name.clone())
812            .collect();
813        for expected in ["ls", "read_file", "write_file", "edit_file"] {
814            assert!(tool_names.contains(&expected.to_string()));
815        }
816    }
817
818    #[tokio::test]
819    async fn summarization_middleware_trims_messages() {
820        let middleware = SummarizationMiddleware::new(2, "Summary note");
821        let mut request = ModelRequest::new(
822            "System",
823            vec![
824                AgentMessage {
825                    role: MessageRole::User,
826                    content: MessageContent::Text("one".into()),
827                    metadata: None,
828                },
829                AgentMessage {
830                    role: MessageRole::Agent,
831                    content: MessageContent::Text("two".into()),
832                    metadata: None,
833                },
834                AgentMessage {
835                    role: MessageRole::User,
836                    content: MessageContent::Text("three".into()),
837                    metadata: None,
838                },
839            ],
840        );
841        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
842        let mut ctx = MiddlewareContext::with_request(&mut request, state);
843        middleware.modify_model_request(&mut ctx).await.unwrap();
844        assert_eq!(ctx.request.messages.len(), 3);
845        match &ctx.request.messages[0].content {
846            MessageContent::Text(text) => assert!(text.contains("Summary note")),
847            other => panic!("expected text, got {other:?}"),
848        }
849    }
850
851    struct StubAgent;
852
853    #[async_trait]
854    impl AgentHandle for StubAgent {
855        async fn describe(&self) -> AgentDescriptor {
856            AgentDescriptor {
857                name: "stub".into(),
858                version: "0.0.1".into(),
859                description: None,
860            }
861        }
862
863        async fn handle_message(
864            &self,
865            _input: AgentMessage,
866            _state: Arc<AgentStateSnapshot>,
867        ) -> anyhow::Result<AgentMessage> {
868            Ok(AgentMessage {
869                role: MessageRole::Agent,
870                content: MessageContent::Text("stub-response".into()),
871                metadata: None,
872            })
873        }
874    }
875
876    #[tokio::test]
877    async fn task_router_reports_unknown_subagent() {
878        let registry = Arc::new(SubAgentRegistry::new(vec![]));
879        let task_tool = TaskRouterTool::new(registry.clone(), None);
880        let state = Arc::new(AgentStateSnapshot::default());
881        let ctx = ToolContext::new(state);
882
883        let response = task_tool
884            .execute(
885                json!({
886                    "instruction": "Do something",
887                    "agent": "unknown"
888                }),
889                ctx,
890            )
891            .await
892            .unwrap();
893
894        match response {
895            ToolResult::Message(msg) => match msg.content {
896                MessageContent::Text(text) => {
897                    assert!(text.contains("Sub-agent 'unknown' not found"))
898                }
899                other => panic!("expected text, got {other:?}"),
900            },
901            _ => panic!("expected message"),
902        }
903    }
904
905    #[tokio::test]
906    async fn subagent_middleware_appends_prompt() {
907        let subagents = vec![SubAgentRegistration {
908            descriptor: SubAgentDescriptor {
909                name: "research-agent".into(),
910                description: "Deep research specialist".into(),
911            },
912            agent: Arc::new(StubAgent),
913        }];
914        let middleware = SubAgentMiddleware::new(subagents);
915
916        let mut request = ModelRequest::new("System", vec![]);
917        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
918        let mut ctx = MiddlewareContext::with_request(&mut request, state);
919        middleware.modify_model_request(&mut ctx).await.unwrap();
920
921        assert!(ctx.request.system_prompt.contains("research-agent"));
922        let tool_names: Vec<_> = middleware
923            .tools()
924            .iter()
925            .map(|t| t.schema().name.clone())
926            .collect();
927        assert!(tool_names.contains(&"task".to_string()));
928    }
929
930    #[tokio::test]
931    async fn task_router_invokes_registered_subagent() {
932        let registry = Arc::new(SubAgentRegistry::new(vec![SubAgentRegistration {
933            descriptor: SubAgentDescriptor {
934                name: "stub-agent".into(),
935                description: "Stub".into(),
936            },
937            agent: Arc::new(StubAgent),
938        }]));
939        let task_tool = TaskRouterTool::new(registry.clone(), None);
940        let state = Arc::new(AgentStateSnapshot::default());
941        let ctx = ToolContext::new(state).with_call_id(Some("call-42".into()));
942        let response = task_tool
943            .execute(
944                json!({
945                    "description": "do work",
946                    "subagent_type": "stub-agent"
947                }),
948                ctx,
949            )
950            .await
951            .unwrap();
952
953        match response {
954            ToolResult::Message(msg) => {
955                assert_eq!(msg.metadata.unwrap().tool_call_id.unwrap(), "call-42");
956                match msg.content {
957                    MessageContent::Text(text) => assert_eq!(text, "stub-response"),
958                    other => panic!("expected text, got {other:?}"),
959                }
960            }
961            _ => panic!("expected message"),
962        }
963    }
964
965    #[tokio::test]
966    async fn human_in_loop_appends_prompt() {
967        let middleware = HumanInLoopMiddleware::new(HashMap::from([(
968            "danger-tool".into(),
969            HitlPolicy {
970                allow_auto: false,
971                note: Some("Requires security review".into()),
972            },
973        )]));
974        let mut request = ModelRequest::new("System", vec![]);
975        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
976        let mut ctx = MiddlewareContext::with_request(&mut request, state);
977        middleware.modify_model_request(&mut ctx).await.unwrap();
978        assert!(ctx
979            .request
980            .system_prompt
981            .contains("danger-tool: Requires security review"));
982    }
983
984    #[tokio::test]
985    async fn anthropic_prompt_caching_moves_system_prompt_to_messages() {
986        let middleware = AnthropicPromptCachingMiddleware::new("5m", "ignore");
987        let mut request = ModelRequest::new(
988            "This is the system prompt",
989            vec![AgentMessage {
990                role: MessageRole::User,
991                content: MessageContent::Text("Hello".into()),
992                metadata: None,
993            }],
994        );
995        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
996        let mut ctx = MiddlewareContext::with_request(&mut request, state);
997
998        // Apply the middleware
999        middleware.modify_model_request(&mut ctx).await.unwrap();
1000
1001        // System prompt should be cleared
1002        assert!(ctx.request.system_prompt.is_empty());
1003
1004        // Should have added a system message with cache control at the beginning
1005        assert_eq!(ctx.request.messages.len(), 2);
1006
1007        let system_message = &ctx.request.messages[0];
1008        assert!(matches!(system_message.role, MessageRole::System));
1009        assert_eq!(
1010            system_message.content.as_text().unwrap(),
1011            "This is the system prompt"
1012        );
1013
1014        // Check cache control metadata
1015        let metadata = system_message.metadata.as_ref().unwrap();
1016        let cache_control = metadata.cache_control.as_ref().unwrap();
1017        assert_eq!(cache_control.cache_type, "ephemeral");
1018
1019        // Original user message should still be there
1020        let user_message = &ctx.request.messages[1];
1021        assert!(matches!(user_message.role, MessageRole::User));
1022        assert_eq!(user_message.content.as_text().unwrap(), "Hello");
1023    }
1024
1025    #[tokio::test]
1026    async fn anthropic_prompt_caching_disabled_with_zero_ttl() {
1027        let middleware = AnthropicPromptCachingMiddleware::new("0", "ignore");
1028        let mut request = ModelRequest::new("This is the system prompt", vec![]);
1029        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
1030        let mut ctx = MiddlewareContext::with_request(&mut request, state);
1031
1032        // Apply the middleware
1033        middleware.modify_model_request(&mut ctx).await.unwrap();
1034
1035        // System prompt should be unchanged
1036        assert_eq!(ctx.request.system_prompt, "This is the system prompt");
1037        assert_eq!(ctx.request.messages.len(), 0);
1038    }
1039
1040    #[tokio::test]
1041    async fn anthropic_prompt_caching_no_op_with_empty_system_prompt() {
1042        let middleware = AnthropicPromptCachingMiddleware::new("5m", "ignore");
1043        let mut request = ModelRequest::new(
1044            "",
1045            vec![AgentMessage {
1046                role: MessageRole::User,
1047                content: MessageContent::Text("Hello".into()),
1048                metadata: None,
1049            }],
1050        );
1051        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
1052        let mut ctx = MiddlewareContext::with_request(&mut request, state);
1053
1054        // Apply the middleware
1055        middleware.modify_model_request(&mut ctx).await.unwrap();
1056
1057        // System prompt should remain empty
1058        assert!(ctx.request.system_prompt.is_empty());
1059        // No system message should be added
1060        assert_eq!(ctx.request.messages.len(), 1);
1061    }
1062
1063    // ========== HITL Interrupt Creation Tests ==========
1064
1065    #[tokio::test]
1066    async fn hitl_creates_interrupt_for_disallowed_tool() {
1067        let mut policies = HashMap::new();
1068        policies.insert(
1069            "dangerous_tool".to_string(),
1070            HitlPolicy {
1071                allow_auto: false,
1072                note: Some("Requires security review".to_string()),
1073            },
1074        );
1075
1076        let middleware = HumanInLoopMiddleware::new(policies);
1077        let tool_args = json!({"action": "delete_all"});
1078
1079        let result = middleware
1080            .before_tool_execution("dangerous_tool", &tool_args, "call_123")
1081            .await
1082            .unwrap();
1083
1084        assert!(result.is_some());
1085        let interrupt = result.unwrap();
1086
1087        match interrupt {
1088            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
1089                assert_eq!(hitl.tool_name, "dangerous_tool");
1090                assert_eq!(hitl.tool_args, tool_args);
1091                assert_eq!(hitl.call_id, "call_123");
1092                assert_eq!(
1093                    hitl.policy_note,
1094                    Some("Requires security review".to_string())
1095                );
1096            }
1097        }
1098    }
1099
1100    #[tokio::test]
1101    async fn hitl_no_interrupt_for_allowed_tool() {
1102        let mut policies = HashMap::new();
1103        policies.insert(
1104            "safe_tool".to_string(),
1105            HitlPolicy {
1106                allow_auto: true,
1107                note: None,
1108            },
1109        );
1110
1111        let middleware = HumanInLoopMiddleware::new(policies);
1112        let tool_args = json!({"action": "read"});
1113
1114        let result = middleware
1115            .before_tool_execution("safe_tool", &tool_args, "call_456")
1116            .await
1117            .unwrap();
1118
1119        assert!(result.is_none());
1120    }
1121
1122    #[tokio::test]
1123    async fn hitl_no_interrupt_for_unlisted_tool() {
1124        let policies = HashMap::new();
1125        let middleware = HumanInLoopMiddleware::new(policies);
1126        let tool_args = json!({"action": "anything"});
1127
1128        let result = middleware
1129            .before_tool_execution("unlisted_tool", &tool_args, "call_789")
1130            .await
1131            .unwrap();
1132
1133        assert!(result.is_none());
1134    }
1135
1136    #[tokio::test]
1137    async fn hitl_interrupt_includes_correct_details() {
1138        let mut policies = HashMap::new();
1139        policies.insert(
1140            "critical_tool".to_string(),
1141            HitlPolicy {
1142                allow_auto: false,
1143                note: Some("Critical operation - requires approval".to_string()),
1144            },
1145        );
1146
1147        let middleware = HumanInLoopMiddleware::new(policies);
1148        let tool_args = json!({
1149            "database": "production",
1150            "operation": "drop_table"
1151        });
1152
1153        let result = middleware
1154            .before_tool_execution("critical_tool", &tool_args, "call_critical_1")
1155            .await
1156            .unwrap();
1157
1158        assert!(result.is_some());
1159        let interrupt = result.unwrap();
1160
1161        match interrupt {
1162            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
1163                assert_eq!(hitl.tool_name, "critical_tool");
1164                assert_eq!(hitl.tool_args["database"], "production");
1165                assert_eq!(hitl.tool_args["operation"], "drop_table");
1166                assert_eq!(hitl.call_id, "call_critical_1");
1167                assert!(hitl.policy_note.is_some());
1168                assert!(hitl.policy_note.unwrap().contains("Critical operation"));
1169                // Verify timestamp exists (created_at field is populated)
1170                // The actual timestamp value is tested in agents-core/hitl.rs
1171            }
1172        }
1173    }
1174
1175    #[tokio::test]
1176    async fn hitl_interrupt_without_policy_note() {
1177        let mut policies = HashMap::new();
1178        policies.insert(
1179            "tool_no_note".to_string(),
1180            HitlPolicy {
1181                allow_auto: false,
1182                note: None,
1183            },
1184        );
1185
1186        let middleware = HumanInLoopMiddleware::new(policies);
1187        let tool_args = json!({"param": "value"});
1188
1189        let result = middleware
1190            .before_tool_execution("tool_no_note", &tool_args, "call_no_note")
1191            .await
1192            .unwrap();
1193
1194        assert!(result.is_some());
1195        let interrupt = result.unwrap();
1196
1197        match interrupt {
1198            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
1199                assert_eq!(hitl.tool_name, "tool_no_note");
1200                assert_eq!(hitl.policy_note, None);
1201            }
1202        }
1203    }
1204}