agents_runtime/
middleware.rs

1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3
4use agents_core::agent::AgentHandle;
5use agents_core::messaging::{
6    AgentMessage, CacheControl, MessageContent, MessageMetadata, MessageRole,
7};
8use agents_core::prompts::{
9    BASE_AGENT_PROMPT, FILESYSTEM_SYSTEM_PROMPT, TASK_SYSTEM_PROMPT, TASK_TOOL_DESCRIPTION,
10    WRITE_TODOS_SYSTEM_PROMPT,
11};
12use agents_core::state::AgentStateSnapshot;
13use agents_core::tools::{Tool, ToolBox, ToolContext, ToolResult};
14use agents_toolkit::create_filesystem_tools;
15use async_trait::async_trait;
16use serde::Deserialize;
17
18/// Request sent to the underlying language model. Middlewares can augment
19/// the system prompt or mutate the pending message list before the model call.
20#[derive(Debug, Clone)]
21pub struct ModelRequest {
22    pub system_prompt: String,
23    pub messages: Vec<AgentMessage>,
24}
25
26impl ModelRequest {
27    pub fn new(system_prompt: impl Into<String>, messages: Vec<AgentMessage>) -> Self {
28        Self {
29            system_prompt: system_prompt.into(),
30            messages,
31        }
32    }
33
34    pub fn append_prompt(&mut self, fragment: &str) {
35        if !fragment.is_empty() {
36            self.system_prompt.push_str("\n\n");
37            self.system_prompt.push_str(fragment);
38        }
39    }
40}
41
42/// Read/write state handle exposed to middleware implementations.
43pub struct MiddlewareContext<'a> {
44    pub request: &'a mut ModelRequest,
45    pub state: Arc<RwLock<AgentStateSnapshot>>,
46}
47
48impl<'a> MiddlewareContext<'a> {
49    pub fn with_request(
50        request: &'a mut ModelRequest,
51        state: Arc<RwLock<AgentStateSnapshot>>,
52    ) -> Self {
53        Self { request, state }
54    }
55}
56
57/// Middleware hook that can register additional tools and mutate the model request
58/// prior to execution. Mirrors the Python AgentMiddleware contracts but keeps the
59/// interface async-first for future network calls.
60#[async_trait]
61pub trait AgentMiddleware: Send + Sync {
62    /// Unique identifier for logging and diagnostics.
63    fn id(&self) -> &'static str;
64
65    /// Tools to expose when this middleware is active.
66    fn tools(&self) -> Vec<ToolBox> {
67        Vec::new()
68    }
69
70    /// Apply middleware-specific mutations to the pending model request.
71    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()>;
72
73    /// Hook called before tool execution - can return an interrupt to pause execution.
74    ///
75    /// This hook is invoked for each tool call before it executes, allowing middleware
76    /// to intercept and pause execution for human review. If an interrupt is returned,
77    /// the agent will save its state and wait for human approval before continuing.
78    ///
79    /// # Arguments
80    /// * `tool_name` - Name of the tool about to be executed
81    /// * `tool_args` - Arguments that will be passed to the tool
82    /// * `call_id` - Unique identifier for this tool call
83    ///
84    /// # Returns
85    /// * `Ok(Some(interrupt))` - Pause execution and wait for human response
86    /// * `Ok(None)` - Continue with tool execution normally
87    /// * `Err(e)` - Error occurred during interrupt check
88    async fn before_tool_execution(
89        &self,
90        _tool_name: &str,
91        _tool_args: &serde_json::Value,
92        _call_id: &str,
93    ) -> anyhow::Result<Option<agents_core::hitl::AgentInterrupt>> {
94        Ok(None)
95    }
96}
97
/// Settings for the middleware that trims long message histories.
pub struct SummarizationMiddleware {
    /// How many of the most recent messages are retained.
    pub messages_to_keep: usize,
    /// Text used as the prefix of the placeholder that replaces dropped messages.
    pub summary_note: String,
}

impl SummarizationMiddleware {
    /// Creates the middleware from a retention count and a note convertible to `String`.
    pub fn new(messages_to_keep: usize, summary_note: impl Into<String>) -> Self {
        let summary_note = summary_note.into();
        SummarizationMiddleware {
            messages_to_keep,
            summary_note,
        }
    }
}
111
112#[async_trait]
113impl AgentMiddleware for SummarizationMiddleware {
114    fn id(&self) -> &'static str {
115        "summarization"
116    }
117
118    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
119        if ctx.request.messages.len() > self.messages_to_keep {
120            let dropped = ctx.request.messages.len() - self.messages_to_keep;
121            let mut truncated = ctx
122                .request
123                .messages
124                .split_off(ctx.request.messages.len() - self.messages_to_keep);
125            truncated.insert(
126                0,
127                AgentMessage {
128                    role: MessageRole::System,
129                    content: MessageContent::Text(format!(
130                        "{} ({} earlier messages summarized)",
131                        self.summary_note, dropped
132                    )),
133                    metadata: None,
134                },
135            );
136            ctx.request.messages = truncated;
137        }
138        Ok(())
139    }
140}
141
142pub struct PlanningMiddleware {
143    _state: Arc<RwLock<AgentStateSnapshot>>,
144}
145
146impl PlanningMiddleware {
147    pub fn new(state: Arc<RwLock<AgentStateSnapshot>>) -> Self {
148        Self { _state: state }
149    }
150}
151
152#[async_trait]
153impl AgentMiddleware for PlanningMiddleware {
154    fn id(&self) -> &'static str {
155        "planning"
156    }
157
158    fn tools(&self) -> Vec<ToolBox> {
159        use agents_toolkit::create_todos_tools;
160        create_todos_tools()
161    }
162
163    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
164        ctx.request.append_prompt(WRITE_TODOS_SYSTEM_PROMPT);
165        Ok(())
166    }
167}
168
169pub struct FilesystemMiddleware {
170    _state: Arc<RwLock<AgentStateSnapshot>>,
171}
172
173impl FilesystemMiddleware {
174    pub fn new(state: Arc<RwLock<AgentStateSnapshot>>) -> Self {
175        Self { _state: state }
176    }
177}
178
179#[async_trait]
180impl AgentMiddleware for FilesystemMiddleware {
181    fn id(&self) -> &'static str {
182        "filesystem"
183    }
184
185    fn tools(&self) -> Vec<ToolBox> {
186        create_filesystem_tools()
187    }
188
189    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
190        ctx.request.append_prompt(FILESYSTEM_SYSTEM_PROMPT);
191        Ok(())
192    }
193}
194
/// Pairs a sub-agent's descriptor with the handle used to invoke it.
#[derive(Clone)]
pub struct SubAgentRegistration {
    /// Name and description of the sub-agent.
    pub descriptor: SubAgentDescriptor,
    /// Shared handle to the sub-agent implementation.
    pub agent: Arc<dyn AgentHandle>,
}
200
201struct SubAgentRegistry {
202    agents: HashMap<String, Arc<dyn AgentHandle>>,
203}
204
205impl SubAgentRegistry {
206    fn new(registrations: Vec<SubAgentRegistration>) -> Self {
207        let mut agents = HashMap::new();
208        for reg in registrations {
209            agents.insert(reg.descriptor.name.clone(), reg.agent.clone());
210        }
211        Self { agents }
212    }
213
214    fn available_names(&self) -> Vec<String> {
215        self.agents.keys().cloned().collect()
216    }
217
218    fn get(&self, name: &str) -> Option<Arc<dyn AgentHandle>> {
219        self.agents.get(name).cloned()
220    }
221}
222
223pub struct SubAgentMiddleware {
224    task_tool: ToolBox,
225    descriptors: Vec<SubAgentDescriptor>,
226    _registry: Arc<SubAgentRegistry>,
227}
228
229impl SubAgentMiddleware {
230    pub fn new(registrations: Vec<SubAgentRegistration>) -> Self {
231        let descriptors = registrations.iter().map(|r| r.descriptor.clone()).collect();
232        let registry = Arc::new(SubAgentRegistry::new(registrations));
233        let task_tool: ToolBox = Arc::new(TaskRouterTool::new(registry.clone(), None));
234        Self {
235            task_tool,
236            descriptors,
237            _registry: registry,
238        }
239    }
240
241    pub fn new_with_events(
242        registrations: Vec<SubAgentRegistration>,
243        event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
244    ) -> Self {
245        let descriptors = registrations.iter().map(|r| r.descriptor.clone()).collect();
246        let registry = Arc::new(SubAgentRegistry::new(registrations));
247        let task_tool: ToolBox = Arc::new(TaskRouterTool::new(registry.clone(), event_dispatcher));
248        Self {
249            task_tool,
250            descriptors,
251            _registry: registry,
252        }
253    }
254
255    fn prompt_fragment(&self) -> String {
256        let descriptions: Vec<String> = if self.descriptors.is_empty() {
257            vec![String::from("- general-purpose: Default reasoning agent")]
258        } else {
259            self.descriptors
260                .iter()
261                .map(|agent| format!("- {}: {}", agent.name, agent.description))
262                .collect()
263        };
264
265        TASK_TOOL_DESCRIPTION.replace("{other_agents}", &descriptions.join("\n"))
266    }
267}
268
269#[async_trait]
270impl AgentMiddleware for SubAgentMiddleware {
271    fn id(&self) -> &'static str {
272        "subagent"
273    }
274
275    fn tools(&self) -> Vec<ToolBox> {
276        vec![self.task_tool.clone()]
277    }
278
279    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
280        ctx.request.append_prompt(TASK_SYSTEM_PROMPT);
281        ctx.request.append_prompt(&self.prompt_fragment());
282        Ok(())
283    }
284}
285
/// Approval policy for a single tool.
#[derive(Clone, Debug)]
pub struct HitlPolicy {
    /// When `true`, the tool is not subject to human approval.
    pub allow_auto: bool,
    /// Optional note explaining why approval is needed.
    pub note: Option<String>,
}

/// Middleware holding per-tool approval policies, keyed by tool name.
pub struct HumanInLoopMiddleware {
    policies: HashMap<String, HitlPolicy>,
}

impl HumanInLoopMiddleware {
    /// Creates the middleware from a map of tool name to policy.
    pub fn new(policies: HashMap<String, HitlPolicy>) -> Self {
        Self { policies }
    }

    /// Returns the policy for `tool_name` when one exists and does not allow
    /// automatic execution; returns `None` otherwise.
    pub fn requires_approval(&self, tool_name: &str) -> Option<&HitlPolicy> {
        self.policies
            .get(tool_name)
            .filter(|policy| !policy.allow_auto)
    }

    /// Builds the prompt text listing every tool that needs approval, or
    /// `None` when no tool does.
    ///
    /// Tools are listed in ascending name order so the generated text is the
    /// same on every call and every run; `HashMap` iteration order is not.
    fn prompt_fragment(&self) -> Option<String> {
        let mut gated: Vec<(&String, &HitlPolicy)> = self
            .policies
            .iter()
            .filter(|(_, policy)| !policy.allow_auto)
            .collect();
        if gated.is_empty() {
            return None;
        }
        gated.sort_unstable_by(|left, right| left.0.cmp(right.0));

        let lines: Vec<String> = gated
            .into_iter()
            .map(|(tool, policy)| {
                let reason = policy.note.as_deref().unwrap_or("Requires approval");
                format!("- {tool}: {reason}")
            })
            .collect();

        Some(format!(
            "The following tools require human approval before execution:\n{}",
            lines.join("\n")
        ))
    }
}
327
328#[async_trait]
329impl AgentMiddleware for HumanInLoopMiddleware {
330    fn id(&self) -> &'static str {
331        "human-in-loop"
332    }
333
334    async fn before_tool_execution(
335        &self,
336        tool_name: &str,
337        tool_args: &serde_json::Value,
338        call_id: &str,
339    ) -> anyhow::Result<Option<agents_core::hitl::AgentInterrupt>> {
340        if let Some(policy) = self.requires_approval(tool_name) {
341            tracing::warn!(
342                tool_name = %tool_name,
343                call_id = %call_id,
344                policy_note = ?policy.note,
345                "🔒 HITL: Tool execution requires human approval"
346            );
347
348            let interrupt = agents_core::hitl::HitlInterrupt::new(
349                tool_name,
350                tool_args.clone(),
351                call_id,
352                policy.note.clone(),
353            );
354
355            return Ok(Some(agents_core::hitl::AgentInterrupt::HumanInLoop(
356                interrupt,
357            )));
358        }
359
360        Ok(None)
361    }
362
363    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
364        if let Some(fragment) = self.prompt_fragment() {
365            ctx.request.append_prompt(&fragment);
366        }
367        ctx.request.messages.push(AgentMessage {
368            role: MessageRole::System,
369            content: MessageContent::Text(
370                "Tools marked for human approval will emit interrupts requiring external resolution."
371                    .into(),
372            ),
373            metadata: None,
374        });
375        Ok(())
376    }
377}
378
379pub struct BaseSystemPromptMiddleware;
380
381#[async_trait]
382impl AgentMiddleware for BaseSystemPromptMiddleware {
383    fn id(&self) -> &'static str {
384        "base-system-prompt"
385    }
386
387    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
388        ctx.request.append_prompt(BASE_AGENT_PROMPT);
389        Ok(())
390    }
391}
392
/// Deep Agent prompt middleware: injects detailed tool-usage instructions and
/// examples so the LLM actually calls tools instead of only describing them.
///
/// Inspired by Python's deepagents package and Claude Code's system prompt.
/// The injected prompt is meant to provide:
/// - Explicit tool usage rules with imperative language
/// - JSON examples of tool calling
/// - Workflow guidance for multi-step tasks
/// - Few-shot examples for common patterns
pub struct DeepAgentPromptMiddleware {
    /// Caller-supplied instructions passed to the prompt generator.
    custom_instructions: String,
}

impl DeepAgentPromptMiddleware {
    /// Creates the middleware from anything convertible into a `String`.
    pub fn new(custom_instructions: impl Into<String>) -> Self {
        let custom_instructions = custom_instructions.into();
        DeepAgentPromptMiddleware {
            custom_instructions,
        }
    }
}
413
414#[async_trait]
415impl AgentMiddleware for DeepAgentPromptMiddleware {
416    fn id(&self) -> &'static str {
417        "deep-agent-prompt"
418    }
419
420    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
421        use crate::prompts::get_deep_agent_system_prompt;
422        let deep_prompt = get_deep_agent_system_prompt(&self.custom_instructions);
423        ctx.request.append_prompt(&deep_prompt);
424        Ok(())
425    }
426}
427
/// Anthropic-specific prompt caching middleware. Marks the system prompt for
/// caching to reduce latency on subsequent requests with the same base prompt.
pub struct AnthropicPromptCachingMiddleware {
    /// Requested cache lifetime, e.g. `"5m"`. The code visible here uses it
    /// only to decide whether caching is enabled and for logging.
    pub ttl: String,
    /// Behavior name for models without caching support. The code visible
    /// here only logs it.
    pub unsupported_model_behavior: String,
}

impl AnthropicPromptCachingMiddleware {
    /// Creates the middleware from a TTL string and an unsupported-model behavior name.
    pub fn new(ttl: impl Into<String>, unsupported_model_behavior: impl Into<String>) -> Self {
        Self {
            ttl: ttl.into(),
            unsupported_model_behavior: unsupported_model_behavior.into(),
        }
    }

    /// Creates the middleware with a `"5m"` TTL and the `"ignore"` behavior.
    pub fn with_defaults() -> Self {
        Self::new("5m", "ignore")
    }

    /// Decides from the TTL string whether caching is requested.
    ///
    /// Caching is disabled when the TTL is empty or whitespace-only, or when
    /// its numeric part is zero regardless of the unit suffix (`"0"`, `"0s"`,
    /// `"0m"`, `"00h"`). Any other value enables ephemeral caching, including
    /// strings that do not look like a duration.
    fn should_enable_caching(&self) -> bool {
        let ttl = self.ttl.trim();
        if ttl.is_empty() {
            return false;
        }

        // Strip a trailing unit such as "s" or "m" to look at the magnitude alone.
        let magnitude = ttl.trim_end_matches(|c: char| c.is_ascii_alphabetic());
        let is_zero = !magnitude.is_empty() && magnitude.chars().all(|c| c == '0');
        !is_zero
    }
}
453
454#[async_trait]
455impl AgentMiddleware for AnthropicPromptCachingMiddleware {
456    fn id(&self) -> &'static str {
457        "anthropic-prompt-caching"
458    }
459
460    async fn modify_model_request(&self, ctx: &mut MiddlewareContext<'_>) -> anyhow::Result<()> {
461        if !self.should_enable_caching() {
462            return Ok(());
463        }
464
465        // Mark system prompt for caching by converting it to a system message with cache control
466        if !ctx.request.system_prompt.is_empty() {
467            let system_message = AgentMessage {
468                role: MessageRole::System,
469                content: MessageContent::Text(ctx.request.system_prompt.clone()),
470                metadata: Some(MessageMetadata {
471                    tool_call_id: None,
472                    cache_control: Some(CacheControl {
473                        cache_type: "ephemeral".to_string(),
474                    }),
475                }),
476            };
477
478            // Insert system message at the beginning of the messages
479            ctx.request.messages.insert(0, system_message);
480
481            // Clear the system_prompt since it's now in messages
482            ctx.request.system_prompt.clear();
483
484            tracing::debug!(
485                ttl = %self.ttl,
486                behavior = %self.unsupported_model_behavior,
487                "Applied Anthropic prompt caching to system message"
488            );
489        }
490
491        Ok(())
492    }
493}
494
495pub struct TaskRouterTool {
496    registry: Arc<SubAgentRegistry>,
497    event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
498    delegation_depth: Arc<RwLock<u32>>,
499}
500
501impl TaskRouterTool {
502    fn new(
503        registry: Arc<SubAgentRegistry>,
504        event_dispatcher: Option<Arc<agents_core::events::EventDispatcher>>,
505    ) -> Self {
506        Self {
507            registry,
508            event_dispatcher,
509            delegation_depth: Arc::new(RwLock::new(0)),
510        }
511    }
512
513    fn available_subagents(&self) -> Vec<String> {
514        self.registry.available_names()
515    }
516
517    fn emit_event(&self, event: agents_core::events::AgentEvent) {
518        if let Some(dispatcher) = &self.event_dispatcher {
519            let dispatcher_clone = dispatcher.clone();
520            tokio::spawn(async move {
521                dispatcher_clone.dispatch(event).await;
522            });
523        }
524    }
525
526    fn create_event_metadata(&self) -> agents_core::events::EventMetadata {
527        agents_core::events::EventMetadata::new(
528            "default".to_string(),
529            uuid::Uuid::new_v4().to_string(),
530            None,
531        )
532    }
533
534    fn get_delegation_depth(&self) -> u32 {
535        *self.delegation_depth.read().unwrap_or_else(|_| {
536            tracing::warn!("Failed to read delegation depth, defaulting to 0");
537            panic!("RwLock poisoned")
538        })
539    }
540
541    fn increment_delegation_depth(&self) {
542        if let Ok(mut depth) = self.delegation_depth.write() {
543            *depth += 1;
544        }
545    }
546
547    fn decrement_delegation_depth(&self) {
548        if let Ok(mut depth) = self.delegation_depth.write() {
549            if *depth > 0 {
550                *depth -= 1;
551            }
552        }
553    }
554}
555
/// Arguments accepted by the `task` tool, deserialized from the tool-call JSON.
///
/// Each field also deserializes from an alternate key via `#[serde(alias)]`.
#[derive(Debug, Clone, Deserialize)]
struct TaskInvocationArgs {
    /// Instruction text for the sub-agent; also accepted under the key `description`.
    #[serde(alias = "description")]
    instruction: String,
    /// Name of the target sub-agent; also accepted under the key `subagent_type`.
    #[serde(alias = "subagent_type")]
    agent: String,
}
563
564#[async_trait]
565impl Tool for TaskRouterTool {
566    fn schema(&self) -> agents_core::tools::ToolSchema {
567        use agents_core::tools::{ToolParameterSchema, ToolSchema};
568        use std::collections::HashMap;
569
570        let mut properties = HashMap::new();
571        properties.insert(
572            "agent".to_string(),
573            ToolParameterSchema::string("Name of the sub-agent to delegate to"),
574        );
575        properties.insert(
576            "instruction".to_string(),
577            ToolParameterSchema::string("Clear instruction for the sub-agent"),
578        );
579
580        ToolSchema::new(
581            "task",
582            "Delegate a task to a specialized sub-agent. Use this when you need specialized expertise or want to break down complex tasks.",
583            ToolParameterSchema::object(
584                "Task delegation parameters",
585                properties,
586                vec!["agent".to_string(), "instruction".to_string()],
587            ),
588        )
589    }
590
591    async fn execute(
592        &self,
593        args: serde_json::Value,
594        ctx: ToolContext,
595    ) -> anyhow::Result<ToolResult> {
596        let args: TaskInvocationArgs = serde_json::from_value(args)?;
597        let available = self.available_subagents();
598
599        if let Some(agent) = self.registry.get(&args.agent) {
600            // Increment delegation depth
601            self.increment_delegation_depth();
602            let current_depth = self.get_delegation_depth();
603
604            // Truncate instruction for event
605            let instruction_summary = if args.instruction.len() > 100 {
606                format!("{}...", &args.instruction[..100])
607            } else {
608                args.instruction.clone()
609            };
610
611            // Emit: SubAgentStarted event
612            self.emit_event(agents_core::events::AgentEvent::SubAgentStarted(
613                agents_core::events::SubAgentStartedEvent {
614                    metadata: self.create_event_metadata(),
615                    agent_name: args.agent.clone(),
616                    instruction_summary: instruction_summary.clone(),
617                    delegation_depth: current_depth,
618                },
619            ));
620
621            // Log delegation start
622            tracing::warn!(
623                "🎯 DELEGATING to sub-agent: {} (depth: {}) with instruction: {}",
624                args.agent,
625                current_depth,
626                args.instruction
627            );
628
629            let start_time = std::time::Instant::now();
630            let user_message = AgentMessage {
631                role: MessageRole::User,
632                content: MessageContent::Text(args.instruction.clone()),
633                metadata: None,
634            };
635
636            let response = agent
637                .handle_message(user_message, Arc::new(AgentStateSnapshot::default()))
638                .await?;
639
640            // Calculate duration
641            let duration = start_time.elapsed();
642            let duration_ms = duration.as_millis() as u64;
643
644            // Create response preview
645            let response_preview = match &response.content {
646                MessageContent::Text(t) => {
647                    if t.len() > 100 {
648                        format!("{}...", &t[..100])
649                    } else {
650                        t.clone()
651                    }
652                }
653                MessageContent::Json(v) => {
654                    let json_str = v.to_string();
655                    if json_str.len() > 100 {
656                        format!("{}...", &json_str[..100])
657                    } else {
658                        json_str
659                    }
660                }
661            };
662
663            // Emit: SubAgentCompleted event
664            self.emit_event(agents_core::events::AgentEvent::SubAgentCompleted(
665                agents_core::events::SubAgentCompletedEvent {
666                    metadata: self.create_event_metadata(),
667                    agent_name: args.agent.clone(),
668                    duration_ms,
669                    result_summary: response_preview.clone(),
670                },
671            ));
672
673            // Log delegation completion
674            tracing::warn!(
675                "✅ SUB-AGENT {} COMPLETED in {:?} - Response: {}",
676                args.agent,
677                duration,
678                response_preview
679            );
680
681            // Decrement delegation depth
682            self.decrement_delegation_depth();
683
684            // Return sub-agent response as text content, not as a separate tool message
685            // This will be incorporated into the LLM's next response naturally
686            let result_text = match response.content {
687                MessageContent::Text(text) => text,
688                MessageContent::Json(json) => json.to_string(),
689            };
690
691            return Ok(ToolResult::text(&ctx, result_text));
692        }
693
694        tracing::error!(
695            "❌ SUB-AGENT NOT FOUND: {} - Available: {:?}",
696            args.agent,
697            available
698        );
699
700        Ok(ToolResult::text(
701            &ctx,
702            format!(
703                "Sub-agent '{}' not found. Available sub-agents: {}",
704                args.agent,
705                available.join(", ")
706            ),
707        ))
708    }
709}
710
/// Name and description of a sub-agent.
///
/// Within this file the name keys the sub-agent registry and both fields are
/// rendered into the task-tool prompt text.
#[derive(Debug, Clone)]
pub struct SubAgentDescriptor {
    /// Name used to look the sub-agent up.
    pub name: String,
    /// Description rendered next to the name in the prompt.
    pub description: String,
}
716
717#[cfg(test)]
718mod tests {
719    use super::*;
720    use agents_core::agent::{AgentDescriptor, AgentHandle};
721    use agents_core::messaging::{MessageContent, MessageRole};
722    use serde_json::json;
723
724    struct AppendPromptMiddleware;
725
726    #[async_trait]
727    impl AgentMiddleware for AppendPromptMiddleware {
728        fn id(&self) -> &'static str {
729            "append-prompt"
730        }
731
732        async fn modify_model_request(
733            &self,
734            ctx: &mut MiddlewareContext<'_>,
735        ) -> anyhow::Result<()> {
736            ctx.request.system_prompt.push_str("\nExtra directives.");
737            Ok(())
738        }
739    }
740
741    #[tokio::test]
742    async fn middleware_mutates_prompt() {
743        let mut request = ModelRequest::new(
744            "System",
745            vec![AgentMessage {
746                role: MessageRole::User,
747                content: MessageContent::Text("Hi".into()),
748                metadata: None,
749            }],
750        );
751        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
752        let mut ctx = MiddlewareContext::with_request(&mut request, state);
753        let middleware = AppendPromptMiddleware;
754        middleware.modify_model_request(&mut ctx).await.unwrap();
755        assert!(ctx.request.system_prompt.contains("Extra directives"));
756    }
757
758    #[tokio::test]
759    async fn planning_middleware_registers_write_todos() {
760        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
761        let middleware = PlanningMiddleware::new(state);
762        let tool_names: Vec<_> = middleware
763            .tools()
764            .iter()
765            .map(|t| t.schema().name.clone())
766            .collect();
767        assert!(tool_names.contains(&"write_todos".to_string()));
768
769        let mut request = ModelRequest::new("System", vec![]);
770        let mut ctx = MiddlewareContext::with_request(
771            &mut request,
772            Arc::new(RwLock::new(AgentStateSnapshot::default())),
773        );
774        middleware.modify_model_request(&mut ctx).await.unwrap();
775        assert!(ctx.request.system_prompt.contains("write_todos"));
776    }
777
778    #[tokio::test]
779    async fn filesystem_middleware_registers_tools() {
780        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
781        let middleware = FilesystemMiddleware::new(state);
782        let tool_names: Vec<_> = middleware
783            .tools()
784            .iter()
785            .map(|t| t.schema().name.clone())
786            .collect();
787        for expected in ["ls", "read_file", "write_file", "edit_file"] {
788            assert!(tool_names.contains(&expected.to_string()));
789        }
790    }
791
792    #[tokio::test]
793    async fn summarization_middleware_trims_messages() {
794        let middleware = SummarizationMiddleware::new(2, "Summary note");
795        let mut request = ModelRequest::new(
796            "System",
797            vec![
798                AgentMessage {
799                    role: MessageRole::User,
800                    content: MessageContent::Text("one".into()),
801                    metadata: None,
802                },
803                AgentMessage {
804                    role: MessageRole::Agent,
805                    content: MessageContent::Text("two".into()),
806                    metadata: None,
807                },
808                AgentMessage {
809                    role: MessageRole::User,
810                    content: MessageContent::Text("three".into()),
811                    metadata: None,
812                },
813            ],
814        );
815        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
816        let mut ctx = MiddlewareContext::with_request(&mut request, state);
817        middleware.modify_model_request(&mut ctx).await.unwrap();
818        assert_eq!(ctx.request.messages.len(), 3);
819        match &ctx.request.messages[0].content {
820            MessageContent::Text(text) => assert!(text.contains("Summary note")),
821            other => panic!("expected text, got {other:?}"),
822        }
823    }
824
825    struct StubAgent;
826
827    #[async_trait]
828    impl AgentHandle for StubAgent {
829        async fn describe(&self) -> AgentDescriptor {
830            AgentDescriptor {
831                name: "stub".into(),
832                version: "0.0.1".into(),
833                description: None,
834            }
835        }
836
837        async fn handle_message(
838            &self,
839            _input: AgentMessage,
840            _state: Arc<AgentStateSnapshot>,
841        ) -> anyhow::Result<AgentMessage> {
842            Ok(AgentMessage {
843                role: MessageRole::Agent,
844                content: MessageContent::Text("stub-response".into()),
845                metadata: None,
846            })
847        }
848    }
849
850    #[tokio::test]
851    async fn task_router_reports_unknown_subagent() {
852        let registry = Arc::new(SubAgentRegistry::new(vec![]));
853        let task_tool = TaskRouterTool::new(registry.clone(), None);
854        let state = Arc::new(AgentStateSnapshot::default());
855        let ctx = ToolContext::new(state);
856
857        let response = task_tool
858            .execute(
859                json!({
860                    "instruction": "Do something",
861                    "agent": "unknown"
862                }),
863                ctx,
864            )
865            .await
866            .unwrap();
867
868        match response {
869            ToolResult::Message(msg) => match msg.content {
870                MessageContent::Text(text) => {
871                    assert!(text.contains("Sub-agent 'unknown' not found"))
872                }
873                other => panic!("expected text, got {other:?}"),
874            },
875            _ => panic!("expected message"),
876        }
877    }
878
879    #[tokio::test]
880    async fn subagent_middleware_appends_prompt() {
881        let subagents = vec![SubAgentRegistration {
882            descriptor: SubAgentDescriptor {
883                name: "research-agent".into(),
884                description: "Deep research specialist".into(),
885            },
886            agent: Arc::new(StubAgent),
887        }];
888        let middleware = SubAgentMiddleware::new(subagents);
889
890        let mut request = ModelRequest::new("System", vec![]);
891        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
892        let mut ctx = MiddlewareContext::with_request(&mut request, state);
893        middleware.modify_model_request(&mut ctx).await.unwrap();
894
895        assert!(ctx.request.system_prompt.contains("research-agent"));
896        let tool_names: Vec<_> = middleware
897            .tools()
898            .iter()
899            .map(|t| t.schema().name.clone())
900            .collect();
901        assert!(tool_names.contains(&"task".to_string()));
902    }
903
904    #[tokio::test]
905    async fn task_router_invokes_registered_subagent() {
906        let registry = Arc::new(SubAgentRegistry::new(vec![SubAgentRegistration {
907            descriptor: SubAgentDescriptor {
908                name: "stub-agent".into(),
909                description: "Stub".into(),
910            },
911            agent: Arc::new(StubAgent),
912        }]));
913        let task_tool = TaskRouterTool::new(registry.clone(), None);
914        let state = Arc::new(AgentStateSnapshot::default());
915        let ctx = ToolContext::new(state).with_call_id(Some("call-42".into()));
916        let response = task_tool
917            .execute(
918                json!({
919                    "description": "do work",
920                    "subagent_type": "stub-agent"
921                }),
922                ctx,
923            )
924            .await
925            .unwrap();
926
927        match response {
928            ToolResult::Message(msg) => {
929                assert_eq!(msg.metadata.unwrap().tool_call_id.unwrap(), "call-42");
930                match msg.content {
931                    MessageContent::Text(text) => assert_eq!(text, "stub-response"),
932                    other => panic!("expected text, got {other:?}"),
933                }
934            }
935            _ => panic!("expected message"),
936        }
937    }
938
939    #[tokio::test]
940    async fn human_in_loop_appends_prompt() {
941        let middleware = HumanInLoopMiddleware::new(HashMap::from([(
942            "danger-tool".into(),
943            HitlPolicy {
944                allow_auto: false,
945                note: Some("Requires security review".into()),
946            },
947        )]));
948        let mut request = ModelRequest::new("System", vec![]);
949        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
950        let mut ctx = MiddlewareContext::with_request(&mut request, state);
951        middleware.modify_model_request(&mut ctx).await.unwrap();
952        assert!(ctx
953            .request
954            .system_prompt
955            .contains("danger-tool: Requires security review"));
956    }
957
958    #[tokio::test]
959    async fn anthropic_prompt_caching_moves_system_prompt_to_messages() {
960        let middleware = AnthropicPromptCachingMiddleware::new("5m", "ignore");
961        let mut request = ModelRequest::new(
962            "This is the system prompt",
963            vec![AgentMessage {
964                role: MessageRole::User,
965                content: MessageContent::Text("Hello".into()),
966                metadata: None,
967            }],
968        );
969        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
970        let mut ctx = MiddlewareContext::with_request(&mut request, state);
971
972        // Apply the middleware
973        middleware.modify_model_request(&mut ctx).await.unwrap();
974
975        // System prompt should be cleared
976        assert!(ctx.request.system_prompt.is_empty());
977
978        // Should have added a system message with cache control at the beginning
979        assert_eq!(ctx.request.messages.len(), 2);
980
981        let system_message = &ctx.request.messages[0];
982        assert!(matches!(system_message.role, MessageRole::System));
983        assert_eq!(
984            system_message.content.as_text().unwrap(),
985            "This is the system prompt"
986        );
987
988        // Check cache control metadata
989        let metadata = system_message.metadata.as_ref().unwrap();
990        let cache_control = metadata.cache_control.as_ref().unwrap();
991        assert_eq!(cache_control.cache_type, "ephemeral");
992
993        // Original user message should still be there
994        let user_message = &ctx.request.messages[1];
995        assert!(matches!(user_message.role, MessageRole::User));
996        assert_eq!(user_message.content.as_text().unwrap(), "Hello");
997    }
998
999    #[tokio::test]
1000    async fn anthropic_prompt_caching_disabled_with_zero_ttl() {
1001        let middleware = AnthropicPromptCachingMiddleware::new("0", "ignore");
1002        let mut request = ModelRequest::new("This is the system prompt", vec![]);
1003        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
1004        let mut ctx = MiddlewareContext::with_request(&mut request, state);
1005
1006        // Apply the middleware
1007        middleware.modify_model_request(&mut ctx).await.unwrap();
1008
1009        // System prompt should be unchanged
1010        assert_eq!(ctx.request.system_prompt, "This is the system prompt");
1011        assert_eq!(ctx.request.messages.len(), 0);
1012    }
1013
1014    #[tokio::test]
1015    async fn anthropic_prompt_caching_no_op_with_empty_system_prompt() {
1016        let middleware = AnthropicPromptCachingMiddleware::new("5m", "ignore");
1017        let mut request = ModelRequest::new(
1018            "",
1019            vec![AgentMessage {
1020                role: MessageRole::User,
1021                content: MessageContent::Text("Hello".into()),
1022                metadata: None,
1023            }],
1024        );
1025        let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
1026        let mut ctx = MiddlewareContext::with_request(&mut request, state);
1027
1028        // Apply the middleware
1029        middleware.modify_model_request(&mut ctx).await.unwrap();
1030
1031        // System prompt should remain empty
1032        assert!(ctx.request.system_prompt.is_empty());
1033        // No system message should be added
1034        assert_eq!(ctx.request.messages.len(), 1);
1035    }
1036
1037    // ========== HITL Interrupt Creation Tests ==========
1038
1039    #[tokio::test]
1040    async fn hitl_creates_interrupt_for_disallowed_tool() {
1041        let mut policies = HashMap::new();
1042        policies.insert(
1043            "dangerous_tool".to_string(),
1044            HitlPolicy {
1045                allow_auto: false,
1046                note: Some("Requires security review".to_string()),
1047            },
1048        );
1049
1050        let middleware = HumanInLoopMiddleware::new(policies);
1051        let tool_args = json!({"action": "delete_all"});
1052
1053        let result = middleware
1054            .before_tool_execution("dangerous_tool", &tool_args, "call_123")
1055            .await
1056            .unwrap();
1057
1058        assert!(result.is_some());
1059        let interrupt = result.unwrap();
1060
1061        match interrupt {
1062            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
1063                assert_eq!(hitl.tool_name, "dangerous_tool");
1064                assert_eq!(hitl.tool_args, tool_args);
1065                assert_eq!(hitl.call_id, "call_123");
1066                assert_eq!(
1067                    hitl.policy_note,
1068                    Some("Requires security review".to_string())
1069                );
1070            }
1071        }
1072    }
1073
1074    #[tokio::test]
1075    async fn hitl_no_interrupt_for_allowed_tool() {
1076        let mut policies = HashMap::new();
1077        policies.insert(
1078            "safe_tool".to_string(),
1079            HitlPolicy {
1080                allow_auto: true,
1081                note: None,
1082            },
1083        );
1084
1085        let middleware = HumanInLoopMiddleware::new(policies);
1086        let tool_args = json!({"action": "read"});
1087
1088        let result = middleware
1089            .before_tool_execution("safe_tool", &tool_args, "call_456")
1090            .await
1091            .unwrap();
1092
1093        assert!(result.is_none());
1094    }
1095
1096    #[tokio::test]
1097    async fn hitl_no_interrupt_for_unlisted_tool() {
1098        let policies = HashMap::new();
1099        let middleware = HumanInLoopMiddleware::new(policies);
1100        let tool_args = json!({"action": "anything"});
1101
1102        let result = middleware
1103            .before_tool_execution("unlisted_tool", &tool_args, "call_789")
1104            .await
1105            .unwrap();
1106
1107        assert!(result.is_none());
1108    }
1109
1110    #[tokio::test]
1111    async fn hitl_interrupt_includes_correct_details() {
1112        let mut policies = HashMap::new();
1113        policies.insert(
1114            "critical_tool".to_string(),
1115            HitlPolicy {
1116                allow_auto: false,
1117                note: Some("Critical operation - requires approval".to_string()),
1118            },
1119        );
1120
1121        let middleware = HumanInLoopMiddleware::new(policies);
1122        let tool_args = json!({
1123            "database": "production",
1124            "operation": "drop_table"
1125        });
1126
1127        let result = middleware
1128            .before_tool_execution("critical_tool", &tool_args, "call_critical_1")
1129            .await
1130            .unwrap();
1131
1132        assert!(result.is_some());
1133        let interrupt = result.unwrap();
1134
1135        match interrupt {
1136            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
1137                assert_eq!(hitl.tool_name, "critical_tool");
1138                assert_eq!(hitl.tool_args["database"], "production");
1139                assert_eq!(hitl.tool_args["operation"], "drop_table");
1140                assert_eq!(hitl.call_id, "call_critical_1");
1141                assert!(hitl.policy_note.is_some());
1142                assert!(hitl.policy_note.unwrap().contains("Critical operation"));
1143                // Verify timestamp exists (created_at field is populated)
1144                // The actual timestamp value is tested in agents-core/hitl.rs
1145            }
1146        }
1147    }
1148
1149    #[tokio::test]
1150    async fn hitl_interrupt_without_policy_note() {
1151        let mut policies = HashMap::new();
1152        policies.insert(
1153            "tool_no_note".to_string(),
1154            HitlPolicy {
1155                allow_auto: false,
1156                note: None,
1157            },
1158        );
1159
1160        let middleware = HumanInLoopMiddleware::new(policies);
1161        let tool_args = json!({"param": "value"});
1162
1163        let result = middleware
1164            .before_tool_execution("tool_no_note", &tool_args, "call_no_note")
1165            .await
1166            .unwrap();
1167
1168        assert!(result.is_some());
1169        let interrupt = result.unwrap();
1170
1171        match interrupt {
1172            agents_core::hitl::AgentInterrupt::HumanInLoop(hitl) => {
1173                assert_eq!(hitl.tool_name, "tool_no_note");
1174                assert_eq!(hitl.policy_note, None);
1175            }
1176        }
1177    }
1178}