neuron_tool/middleware.rs
1//! Middleware types for the tool execution pipeline.
2//!
3//! Middleware wraps tool execution with cross-cutting concerns like
4//! validation, permissions, logging, and output formatting.
5//!
//! The pattern mirrors axum's `from_fn`: each middleware receives a `Next`
//! that it can call to continue the chain, or drop without calling it to
//! short-circuit the rest of the chain.
9
10use std::sync::Arc;
11
12use neuron_types::{
13 ToolContext, ToolDyn, ToolError, ToolOutput, WasmBoxedFuture, WasmCompatSend, WasmCompatSync,
14};
15
16/// A tool call in flight through the middleware pipeline.
17#[derive(Debug, Clone)]
18pub struct ToolCall {
19 /// Unique identifier for this tool call (from the model).
20 pub id: String,
21 /// Name of the tool being called.
22 pub name: String,
23 /// JSON input arguments.
24 pub input: serde_json::Value,
25}
26
27/// Middleware that wraps tool execution.
28///
29/// Each middleware receives the call, context, and a [`Next`] to continue the chain.
30/// Middleware can:
31/// - Inspect/modify the call before passing it on
32/// - Short-circuit by returning without calling `next.run()`
33/// - Inspect/modify the result after the tool executes
34///
35/// Uses boxed futures for dyn-compatibility (heterogeneous middleware collections).
36pub trait ToolMiddleware: WasmCompatSend + WasmCompatSync {
37 /// Process a tool call, optionally delegating to the next middleware/tool.
38 fn process<'a>(
39 &'a self,
40 call: &'a ToolCall,
41 ctx: &'a ToolContext,
42 next: Next<'a>,
43 ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>>;
44}
45
46/// The remaining middleware chain plus the underlying tool.
47///
48/// Consumed on call to prevent double-invoke.
49pub struct Next<'a> {
50 tool: &'a dyn ToolDyn,
51 middleware: &'a [Arc<dyn ToolMiddleware>],
52}
53
54impl<'a> Next<'a> {
55 /// Create a new Next from a tool and middleware slice.
56 pub(crate) fn new(tool: &'a dyn ToolDyn, middleware: &'a [Arc<dyn ToolMiddleware>]) -> Self {
57 Self { tool, middleware }
58 }
59
60 /// Continue the middleware chain, eventually calling the tool.
61 pub async fn run(
62 self,
63 call: &'a ToolCall,
64 ctx: &'a ToolContext,
65 ) -> Result<ToolOutput, ToolError> {
66 if let Some((head, tail)) = self.middleware.split_first() {
67 let next = Next::new(self.tool, tail);
68 head.process(call, ctx, next).await
69 } else {
70 // End of chain — call the actual tool
71 self.tool.call_dyn(call.input.clone(), ctx).await
72 }
73 }
74}
75
76/// Wrapper that implements `ToolMiddleware` for a closure returning a boxed future.
77struct MiddlewareFn<F> {
78 f: F,
79}
80
81impl<F> ToolMiddleware for MiddlewareFn<F>
82where
83 F: for<'a> Fn(
84 &'a ToolCall,
85 &'a ToolContext,
86 Next<'a>,
87 ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>>
88 + Send
89 + Sync,
90{
91 fn process<'a>(
92 &'a self,
93 call: &'a ToolCall,
94 ctx: &'a ToolContext,
95 next: Next<'a>,
96 ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>> {
97 (self.f)(call, ctx, next)
98 }
99}
100
101/// Create middleware from a closure (like axum's `from_fn`).
102///
103/// The closure must return a `Box::pin(async move { ... })` future.
104///
105/// # Example
106///
107/// ```ignore
108/// use neuron_tool::*;
109///
110/// let logging = tool_middleware_fn(|call, ctx, next| {
111/// Box::pin(async move {
112/// println!("calling {}", call.name);
113/// let result = next.run(call, ctx).await;
114/// println!("done");
115/// result
116/// })
117/// });
118/// ```
119#[must_use]
120pub fn tool_middleware_fn<F>(f: F) -> impl ToolMiddleware
121where
122 F: for<'a> Fn(
123 &'a ToolCall,
124 &'a ToolContext,
125 Next<'a>,
126 ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>>
127 + Send
128 + Sync,
129{
130 MiddlewareFn { f }
131}