Skip to main content

neuron_tool/
middleware.rs

//! Middleware types for the tool execution pipeline.
//!
//! Middleware wraps tool execution with cross-cutting concerns like
//! validation, permissions, logging, and output formatting.
//!
//! The pattern is modeled on axum's `from_fn`: each middleware
//! receives a `Next` that it can call to continue the chain, or
//! leave uncalled to short-circuit the tool.
9
10use std::sync::Arc;
11
12use neuron_types::{
13    ToolContext, ToolDyn, ToolError, ToolOutput, WasmBoxedFuture, WasmCompatSend, WasmCompatSync,
14};
15
16/// A tool call in flight through the middleware pipeline.
17#[derive(Debug, Clone)]
18pub struct ToolCall {
19    /// Unique identifier for this tool call (from the model).
20    pub id: String,
21    /// Name of the tool being called.
22    pub name: String,
23    /// JSON input arguments.
24    pub input: serde_json::Value,
25}
26
27/// Middleware that wraps tool execution.
28///
29/// Each middleware receives the call, context, and a [`Next`] to continue the chain.
30/// Middleware can:
31/// - Inspect/modify the call before passing it on
32/// - Short-circuit by returning without calling `next.run()`
33/// - Inspect/modify the result after the tool executes
34///
35/// Uses boxed futures for dyn-compatibility (heterogeneous middleware collections).
36pub trait ToolMiddleware: WasmCompatSend + WasmCompatSync {
37    /// Process a tool call, optionally delegating to the next middleware/tool.
38    fn process<'a>(
39        &'a self,
40        call: &'a ToolCall,
41        ctx: &'a ToolContext,
42        next: Next<'a>,
43    ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>>;
44}
45
46/// The remaining middleware chain plus the underlying tool.
47///
48/// Consumed on call to prevent double-invoke.
49pub struct Next<'a> {
50    tool: &'a dyn ToolDyn,
51    middleware: &'a [Arc<dyn ToolMiddleware>],
52}
53
54impl<'a> Next<'a> {
55    /// Create a new Next from a tool and middleware slice.
56    pub(crate) fn new(tool: &'a dyn ToolDyn, middleware: &'a [Arc<dyn ToolMiddleware>]) -> Self {
57        Self { tool, middleware }
58    }
59
60    /// Continue the middleware chain, eventually calling the tool.
61    pub async fn run(
62        self,
63        call: &'a ToolCall,
64        ctx: &'a ToolContext,
65    ) -> Result<ToolOutput, ToolError> {
66        if let Some((head, tail)) = self.middleware.split_first() {
67            let next = Next::new(self.tool, tail);
68            head.process(call, ctx, next).await
69        } else {
70            // End of chain — call the actual tool
71            self.tool.call_dyn(call.input.clone(), ctx).await
72        }
73    }
74}
75
76/// Wrapper that implements `ToolMiddleware` for a closure returning a boxed future.
77struct MiddlewareFn<F> {
78    f: F,
79}
80
81impl<F> ToolMiddleware for MiddlewareFn<F>
82where
83    F: for<'a> Fn(
84            &'a ToolCall,
85            &'a ToolContext,
86            Next<'a>,
87        ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>>
88        + Send
89        + Sync,
90{
91    fn process<'a>(
92        &'a self,
93        call: &'a ToolCall,
94        ctx: &'a ToolContext,
95        next: Next<'a>,
96    ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>> {
97        (self.f)(call, ctx, next)
98    }
99}
100
101/// Create middleware from a closure (like axum's `from_fn`).
102///
103/// The closure must return a `Box::pin(async move { ... })` future.
104///
105/// # Example
106///
107/// ```ignore
108/// use neuron_tool::*;
109///
110/// let logging = tool_middleware_fn(|call, ctx, next| {
111///     Box::pin(async move {
112///         println!("calling {}", call.name);
113///         let result = next.run(call, ctx).await;
114///         println!("done");
115///         result
116///     })
117/// });
118/// ```
119#[must_use]
120pub fn tool_middleware_fn<F>(f: F) -> impl ToolMiddleware
121where
122    F: for<'a> Fn(
123            &'a ToolCall,
124            &'a ToolContext,
125            Next<'a>,
126        ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>>
127        + Send
128        + Sync,
129{
130    MiddlewareFn { f }
131}