neuron_tool/middleware.rs
//! Middleware types for the tool execution pipeline.
//!
//! Middleware wraps tool execution with cross-cutting concerns such as
//! validation, permissions, logging, and output formatting.
//!
//! The pattern is modeled on axum's `from_fn`: each middleware receives a
//! `Next` that it can call to continue the chain, or drop without calling
//! to short-circuit.
9
10use std::sync::Arc;
11
12use neuron_types::{
13 ToolContext, ToolDyn, ToolError, ToolOutput, WasmBoxedFuture, WasmCompatSend, WasmCompatSync,
14};
15
16/// A tool call in flight through the middleware pipeline.
17#[derive(Debug, Clone)]
18pub struct ToolCall {
19 /// Unique identifier for this tool call (from the model).
20 pub id: String,
21 /// Name of the tool being called.
22 pub name: String,
23 /// JSON input arguments.
24 pub input: serde_json::Value,
25}
26
27/// Middleware that wraps tool execution.
28///
29/// Each middleware receives the call, context, and a [`Next`] to continue the chain.
30/// Middleware can:
31/// - Inspect/modify the call before passing it on
32/// - Short-circuit by returning without calling `next.run()`
33/// - Inspect/modify the result after the tool executes
34///
35/// Uses boxed futures for dyn-compatibility (heterogeneous middleware collections).
36pub trait ToolMiddleware: WasmCompatSend + WasmCompatSync {
37 /// Process a tool call, optionally delegating to the next middleware/tool.
38 fn process<'a>(
39 &'a self,
40 call: &'a ToolCall,
41 ctx: &'a ToolContext,
42 next: Next<'a>,
43 ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>>;
44}
45
46/// The remaining middleware chain plus the underlying tool.
47///
48/// Consumed on call to prevent double-invoke.
49pub struct Next<'a> {
50 tool: &'a dyn ToolDyn,
51 middleware: &'a [Arc<dyn ToolMiddleware>],
52}
53
54impl<'a> Next<'a> {
55 /// Create a new Next from a tool and middleware slice.
56 pub(crate) fn new(tool: &'a dyn ToolDyn, middleware: &'a [Arc<dyn ToolMiddleware>]) -> Self {
57 Self { tool, middleware }
58 }
59
60 /// Continue the middleware chain, eventually calling the tool.
61 pub async fn run(self, call: &'a ToolCall, ctx: &'a ToolContext) -> Result<ToolOutput, ToolError> {
62 if let Some((head, tail)) = self.middleware.split_first() {
63 let next = Next::new(self.tool, tail);
64 head.process(call, ctx, next).await
65 } else {
66 // End of chain — call the actual tool
67 self.tool.call_dyn(call.input.clone(), ctx).await
68 }
69 }
70}
71
/// Adapter that lets a closure act as a [`ToolMiddleware`].
///
/// The closure must produce a boxed future; [`tool_middleware_fn`] is the
/// public constructor.
struct MiddlewareFn<F> {
    /// The wrapped closure, called once per `process` invocation.
    f: F,
}
76
77impl<F> ToolMiddleware for MiddlewareFn<F>
78where
79 F: for<'a> Fn(
80 &'a ToolCall,
81 &'a ToolContext,
82 Next<'a>,
83 ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>>
84 + Send
85 + Sync,
86{
87 fn process<'a>(
88 &'a self,
89 call: &'a ToolCall,
90 ctx: &'a ToolContext,
91 next: Next<'a>,
92 ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>> {
93 (self.f)(call, ctx, next)
94 }
95}
96
97/// Create middleware from a closure (like axum's `from_fn`).
98///
99/// The closure must return a `Box::pin(async move { ... })` future.
100///
101/// # Example
102///
103/// ```ignore
104/// use neuron_tool::*;
105///
106/// let logging = tool_middleware_fn(|call, ctx, next| {
107/// Box::pin(async move {
108/// println!("calling {}", call.name);
109/// let result = next.run(call, ctx).await;
110/// println!("done");
111/// result
112/// })
113/// });
114/// ```
115#[must_use]
116pub fn tool_middleware_fn<F>(f: F) -> impl ToolMiddleware
117where
118 F: for<'a> Fn(
119 &'a ToolCall,
120 &'a ToolContext,
121 Next<'a>,
122 ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>>
123 + Send
124 + Sync,
125{
126 MiddlewareFn { f }
127}