Skip to main content

neuron_tool/
middleware.rs

//! Middleware types for the tool execution pipeline.
//!
//! Middleware wraps tool execution with cross-cutting concerns like
//! validation, permissions, logging, and output formatting.
//!
//! The pattern is modeled on axum's `from_fn`: each middleware
//! receives a `Next` that it can call to continue the chain, or
//! skip calling it to short-circuit.
9
10use std::sync::Arc;
11
12use neuron_types::{
13    ToolContext, ToolDyn, ToolError, ToolOutput, WasmBoxedFuture, WasmCompatSend, WasmCompatSync,
14};
15
16/// A tool call in flight through the middleware pipeline.
17#[derive(Debug, Clone)]
18pub struct ToolCall {
19    /// Unique identifier for this tool call (from the model).
20    pub id: String,
21    /// Name of the tool being called.
22    pub name: String,
23    /// JSON input arguments.
24    pub input: serde_json::Value,
25}
26
27/// Middleware that wraps tool execution.
28///
29/// Each middleware receives the call, context, and a [`Next`] to continue the chain.
30/// Middleware can:
31/// - Inspect/modify the call before passing it on
32/// - Short-circuit by returning without calling `next.run()`
33/// - Inspect/modify the result after the tool executes
34///
35/// Uses boxed futures for dyn-compatibility (heterogeneous middleware collections).
36pub trait ToolMiddleware: WasmCompatSend + WasmCompatSync {
37    /// Process a tool call, optionally delegating to the next middleware/tool.
38    fn process<'a>(
39        &'a self,
40        call: &'a ToolCall,
41        ctx: &'a ToolContext,
42        next: Next<'a>,
43    ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>>;
44}
45
46/// The remaining middleware chain plus the underlying tool.
47///
48/// Consumed on call to prevent double-invoke.
49pub struct Next<'a> {
50    tool: &'a dyn ToolDyn,
51    middleware: &'a [Arc<dyn ToolMiddleware>],
52}
53
54impl<'a> Next<'a> {
55    /// Create a new Next from a tool and middleware slice.
56    pub(crate) fn new(tool: &'a dyn ToolDyn, middleware: &'a [Arc<dyn ToolMiddleware>]) -> Self {
57        Self { tool, middleware }
58    }
59
60    /// Continue the middleware chain, eventually calling the tool.
61    pub async fn run(self, call: &'a ToolCall, ctx: &'a ToolContext) -> Result<ToolOutput, ToolError> {
62        if let Some((head, tail)) = self.middleware.split_first() {
63            let next = Next::new(self.tool, tail);
64            head.process(call, ctx, next).await
65        } else {
66            // End of chain — call the actual tool
67            self.tool.call_dyn(call.input.clone(), ctx).await
68        }
69    }
70}
71
72/// Wrapper that implements `ToolMiddleware` for a closure returning a boxed future.
73struct MiddlewareFn<F> {
74    f: F,
75}
76
77impl<F> ToolMiddleware for MiddlewareFn<F>
78where
79    F: for<'a> Fn(
80            &'a ToolCall,
81            &'a ToolContext,
82            Next<'a>,
83        ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>>
84        + Send
85        + Sync,
86{
87    fn process<'a>(
88        &'a self,
89        call: &'a ToolCall,
90        ctx: &'a ToolContext,
91        next: Next<'a>,
92    ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>> {
93        (self.f)(call, ctx, next)
94    }
95}
96
97/// Create middleware from a closure (like axum's `from_fn`).
98///
99/// The closure must return a `Box::pin(async move { ... })` future.
100///
101/// # Example
102///
103/// ```ignore
104/// use neuron_tool::*;
105///
106/// let logging = tool_middleware_fn(|call, ctx, next| {
107///     Box::pin(async move {
108///         println!("calling {}", call.name);
109///         let result = next.run(call, ctx).await;
110///         println!("done");
111///         result
112///     })
113/// });
114/// ```
115#[must_use]
116pub fn tool_middleware_fn<F>(f: F) -> impl ToolMiddleware
117where
118    F: for<'a> Fn(
119            &'a ToolCall,
120            &'a ToolContext,
121            Next<'a>,
122        ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>>
123        + Send
124        + Sync,
125{
126    MiddlewareFn { f }
127}