Skip to main content

neuron_tool/
builtin.rs

1//! Built-in middleware implementations.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6
7use neuron_types::{
8    ContentItem, PermissionDecision, PermissionPolicy, ToolContext, ToolError, ToolOutput,
9    WasmBoxedFuture,
10};
11
12use crate::middleware::{Next, ToolCall, ToolMiddleware};
13use crate::registry::ToolRegistry;
14
15/// Middleware that checks tool call permissions against a [`PermissionPolicy`].
16///
17/// If the policy returns `Deny`, the tool call is rejected with `ToolError::PermissionDenied`.
18/// If the policy returns `Ask`, the tool call is rejected (external confirmation not handled here).
19pub struct PermissionChecker {
20    policy: Arc<dyn PermissionPolicy>,
21}
22
23impl PermissionChecker {
24    /// Create a new permission checker with the given policy.
25    #[must_use]
26    pub fn new(policy: impl PermissionPolicy + 'static) -> Self {
27        Self {
28            policy: Arc::new(policy),
29        }
30    }
31}
32
33impl ToolMiddleware for PermissionChecker {
34    fn process<'a>(
35        &'a self,
36        call: &'a ToolCall,
37        ctx: &'a ToolContext,
38        next: Next<'a>,
39    ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>> {
40        Box::pin(async move {
41            match self.policy.check(&call.name, &call.input) {
42                PermissionDecision::Allow => next.run(call, ctx).await,
43                PermissionDecision::Deny(reason) => Err(ToolError::PermissionDenied(reason)),
44                PermissionDecision::Ask(reason) => Err(ToolError::PermissionDenied(format!(
45                    "requires confirmation: {reason}"
46                ))),
47            }
48        })
49    }
50}
51
/// Tool middleware that caps the length of textual tool output.
///
/// Very long tool results eat into the model's context window. This
/// middleware shortens any text content item that is over the configured
/// limit. Non-text items are left alone.
pub struct OutputFormatter {
    /// Length limit applied to each text content item.
    max_chars: usize,
}

impl OutputFormatter {
    /// Build a formatter that limits text items to `max_chars`.
    #[must_use]
    pub fn new(max_chars: usize) -> Self {
        OutputFormatter { max_chars }
    }
}
67
68impl ToolMiddleware for OutputFormatter {
69    fn process<'a>(
70        &'a self,
71        call: &'a ToolCall,
72        ctx: &'a ToolContext,
73        next: Next<'a>,
74    ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>> {
75        Box::pin(async move {
76            let mut output = next.run(call, ctx).await?;
77
78            // Truncate text content items that exceed the limit
79            output.content = output
80                .content
81                .into_iter()
82                .map(|item| match item {
83                    ContentItem::Text(text) if text.len() > self.max_chars => {
84                        // Find the nearest char boundary at or before max_chars
85                        // to avoid slicing in the middle of a multi-byte UTF-8
86                        // character. This is a stable polyfill for
87                        // str::floor_char_boundary (stabilized in 1.93).
88                        let mut boundary = self.max_chars;
89                        while boundary > 0 && !text.is_char_boundary(boundary) {
90                            boundary -= 1;
91                        }
92                        ContentItem::Text(format!(
93                            "{}... [truncated, {} chars total]",
94                            &text[..boundary],
95                            text.len()
96                        ))
97                    }
98                    other => other,
99                })
100                .collect();
101
102            Ok(output)
103        })
104    }
105}
106
107/// Middleware that validates tool call input against the tool's JSON Schema.
108///
109/// Performs lightweight structural validation: checks that the input is an
110/// object, required fields are present, and property types match the schema.
111/// This catches obvious input errors before the tool executes, without
112/// depending on a full JSON Schema validation library.
113pub struct SchemaValidator {
114    /// Map of tool name to its input_schema JSON value.
115    schemas: HashMap<String, serde_json::Value>,
116}
117
118impl SchemaValidator {
119    /// Create a new schema validator from the current tool registry.
120    ///
121    /// Snapshots all tool definitions at construction time. Tools registered
122    /// after this call will not be validated.
123    #[must_use]
124    pub fn new(registry: &ToolRegistry) -> Self {
125        let schemas = registry
126            .definitions()
127            .into_iter()
128            .map(|def| (def.name, def.input_schema))
129            .collect();
130        Self { schemas }
131    }
132}
133
134impl ToolMiddleware for SchemaValidator {
135    fn process<'a>(
136        &'a self,
137        call: &'a ToolCall,
138        ctx: &'a ToolContext,
139        next: Next<'a>,
140    ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>> {
141        Box::pin(async move {
142            if let Some(schema) = self.schemas.get(&call.name) {
143                validate_input(&call.input, schema)?;
144            }
145            next.run(call, ctx).await
146        })
147    }
148}
149
150/// Validate a JSON input value against a JSON Schema object.
151///
152/// Performs lightweight structural checks:
153/// - Input must be an object (if schema says `"type": "object"`)
154/// - All `"required"` fields must be present
155/// - Property types must match the schema's `"type"` declarations
156fn validate_input(input: &serde_json::Value, schema: &serde_json::Value) -> Result<(), ToolError> {
157    let schema_obj = match schema.as_object() {
158        Some(obj) => obj,
159        None => return Ok(()), // No schema object to validate against
160    };
161
162    // Check that the input is an object if schema declares type: "object"
163    if let Some(serde_json::Value::String(ty)) = schema_obj.get("type")
164        && ty == "object"
165        && !input.is_object()
166    {
167        return Err(ToolError::InvalidInput("expected object input".to_string()));
168    }
169
170    let input_obj = match input.as_object() {
171        Some(obj) => obj,
172        None => return Ok(()), // Non-object input, nothing more to validate
173    };
174
175    // Check required fields
176    if let Some(serde_json::Value::Array(required)) = schema_obj.get("required") {
177        for field in required {
178            if let Some(field_name) = field.as_str()
179                && !input_obj.contains_key(field_name)
180            {
181                return Err(ToolError::InvalidInput(format!(
182                    "missing required field: {field_name}"
183                )));
184            }
185        }
186    }
187
188    // Check property types
189    if let Some(serde_json::Value::Object(properties)) = schema_obj.get("properties") {
190        for (field_name, prop_schema) in properties {
191            if let Some(value) = input_obj.get(field_name)
192                && let Some(serde_json::Value::String(expected_type)) = prop_schema.get("type")
193                && !json_type_matches(value, expected_type)
194            {
195                return Err(ToolError::InvalidInput(format!(
196                    "field '{field_name}' expected type '{expected_type}', \
197                     got {}",
198                    json_type_name(value)
199                )));
200            }
201        }
202    }
203
204    Ok(())
205}
206
207/// Check if a JSON value matches the expected JSON Schema type string.
208fn json_type_matches(value: &serde_json::Value, expected: &str) -> bool {
209    match expected {
210        "string" => value.is_string(),
211        "number" => value.is_number(),
212        "integer" => value.is_i64() || value.is_u64(),
213        "boolean" => value.is_boolean(),
214        "array" => value.is_array(),
215        "object" => value.is_object(),
216        "null" => value.is_null(),
217        _ => true, // Unknown type, pass through
218    }
219}
220
221/// Return the JSON type name for a value (for error messages).
222fn json_type_name(value: &serde_json::Value) -> &'static str {
223    match value {
224        serde_json::Value::Null => "null",
225        serde_json::Value::Bool(_) => "boolean",
226        serde_json::Value::Number(_) => "number",
227        serde_json::Value::String(_) => "string",
228        serde_json::Value::Array(_) => "array",
229        serde_json::Value::Object(_) => "object",
230    }
231}
232
/// Tool middleware that bounds how long a tool call may run.
///
/// The downstream call is raced against [`tokio::time::timeout`]. When the
/// deadline passes first, the call fails with `ToolError::ExecutionFailed`
/// carrying a human-readable message, so the model can react to it.
///
/// A per-tool override map lets slow tools (e.g. web scraping) get a longer
/// budget than quick ones (e.g. arithmetic).
pub struct TimeoutMiddleware {
    /// Budget used for any tool without an override.
    default_timeout: Duration,
    /// Tool name -> overriding budget.
    per_tool: HashMap<String, Duration>,
}

impl TimeoutMiddleware {
    /// Build a timeout middleware with `default_timeout` and no overrides.
    #[must_use]
    pub fn new(default_timeout: Duration) -> Self {
        let per_tool = HashMap::new();
        Self {
            per_tool,
            default_timeout,
        }
    }

    /// Builder-style: give `tool_name` its own `timeout`, replacing any
    /// earlier override for the same name.
    #[must_use]
    pub fn with_tool_timeout(mut self, tool_name: impl Into<String>, timeout: Duration) -> Self {
        let name: String = tool_name.into();
        self.per_tool.insert(name, timeout);
        self
    }
}
264
265impl ToolMiddleware for TimeoutMiddleware {
266    fn process<'a>(
267        &'a self,
268        call: &'a ToolCall,
269        ctx: &'a ToolContext,
270        next: Next<'a>,
271    ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>> {
272        Box::pin(async move {
273            let timeout = self
274                .per_tool
275                .get(&call.name)
276                .unwrap_or(&self.default_timeout);
277            match tokio::time::timeout(*timeout, next.run(call, ctx)).await {
278                Ok(result) => result,
279                Err(_elapsed) => Err(ToolError::ExecutionFailed(Box::new(std::io::Error::new(
280                    std::io::ErrorKind::TimedOut,
281                    format!(
282                        "tool '{}' timed out after {:.1}s",
283                        call.name,
284                        timeout.as_secs_f64()
285                    ),
286                )))),
287            }
288        })
289    }
290}
291
292/// Middleware that validates structured output from a tool against a JSON Schema.
293///
294/// When attached to a "result" tool, this validates the model's JSON input
295/// against the expected schema. On validation failure, returns
296/// [`ToolError::ModelRetry`] with a description of what went wrong so the
297/// model can self-correct.
298///
299/// This implements the tool-based structured output pattern used by
300/// instructor, Pydantic AI, and Rig: inject a tool with the output schema,
301/// force the model to call it, and validate.
302pub struct StructuredOutputValidator {
303    schema: serde_json::Value,
304    max_retries: usize,
305}
306
307impl StructuredOutputValidator {
308    /// Create a new structured output validator.
309    ///
310    /// The `schema` should be a JSON Schema object describing the expected
311    /// output shape. `max_retries` limits how many times the model can
312    /// retry on validation failure (0 means fail immediately on first error).
313    #[must_use]
314    pub fn new(schema: serde_json::Value, max_retries: usize) -> Self {
315        Self {
316            schema,
317            max_retries,
318        }
319    }
320}
321
322impl ToolMiddleware for StructuredOutputValidator {
323    fn process<'a>(
324        &'a self,
325        call: &'a ToolCall,
326        ctx: &'a ToolContext,
327        next: Next<'a>,
328    ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>> {
329        Box::pin(async move {
330            // Validate the input (which IS the structured output from the model)
331            // against the schema before passing to the tool
332            if let Err(e) = validate_input(&call.input, &self.schema) {
333                // Return ModelRetry so the model can self-correct
334                return Err(ToolError::ModelRetry(format!(
335                    "Output validation failed: {e}. Please fix the output to match the schema."
336                )));
337            }
338            next.run(call, ctx).await
339        })
340    }
341}
342
343/// Tracks retry count for structured output validation.
344///
345/// Wraps [`StructuredOutputValidator`] and enforces a maximum number of
346/// retries. After `max_retries` validation failures, converts the error
347/// to `ToolError::InvalidInput` (non-retryable).
348pub struct RetryLimitedValidator {
349    inner: StructuredOutputValidator,
350    attempts: std::sync::atomic::AtomicUsize,
351}
352
353impl RetryLimitedValidator {
354    /// Create a new retry-limited validator wrapping a [`StructuredOutputValidator`].
355    #[must_use]
356    pub fn new(validator: StructuredOutputValidator) -> Self {
357        Self {
358            inner: validator,
359            attempts: std::sync::atomic::AtomicUsize::new(0),
360        }
361    }
362}
363
364impl ToolMiddleware for RetryLimitedValidator {
365    fn process<'a>(
366        &'a self,
367        call: &'a ToolCall,
368        ctx: &'a ToolContext,
369        next: Next<'a>,
370    ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>> {
371        Box::pin(async move {
372            if let Err(e) = validate_input(&call.input, &self.inner.schema) {
373                let attempt = self
374                    .attempts
375                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
376                if attempt >= self.inner.max_retries {
377                    return Err(ToolError::InvalidInput(format!(
378                        "Output validation failed after {} retries: {e}",
379                        self.inner.max_retries
380                    )));
381                }
382                return Err(ToolError::ModelRetry(format!(
383                    "Output validation failed (attempt {}/{}): {e}. \
384                     Please fix the output to match the schema.",
385                    attempt + 1,
386                    self.inner.max_retries
387                )));
388            }
389            // Reset attempt counter on success
390            self.attempts.store(0, std::sync::atomic::Ordering::Relaxed);
391            next.run(call, ctx).await
392        })
393    }
394}