Skip to main content

neuron_tool/
builtin.rs

1//! Built-in middleware implementations.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use neuron_types::{
7    ContentItem, PermissionDecision, PermissionPolicy, ToolContext, ToolError, ToolOutput,
8    WasmBoxedFuture,
9};
10
11use crate::middleware::{Next, ToolCall, ToolMiddleware};
12use crate::registry::ToolRegistry;
13
14/// Middleware that checks tool call permissions against a [`PermissionPolicy`].
15///
16/// If the policy returns `Deny`, the tool call is rejected with `ToolError::PermissionDenied`.
17/// If the policy returns `Ask`, the tool call is rejected (external confirmation not handled here).
18pub struct PermissionChecker {
19    policy: Arc<dyn PermissionPolicy>,
20}
21
22impl PermissionChecker {
23    /// Create a new permission checker with the given policy.
24    #[must_use]
25    pub fn new(policy: impl PermissionPolicy + 'static) -> Self {
26        Self {
27            policy: Arc::new(policy),
28        }
29    }
30}
31
32impl ToolMiddleware for PermissionChecker {
33    fn process<'a>(
34        &'a self,
35        call: &'a ToolCall,
36        ctx: &'a ToolContext,
37        next: Next<'a>,
38    ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>> {
39        Box::pin(async move {
40            match self.policy.check(&call.name, &call.input) {
41                PermissionDecision::Allow => next.run(call, ctx).await,
42                PermissionDecision::Deny(reason) => {
43                    Err(ToolError::PermissionDenied(reason))
44                }
45                PermissionDecision::Ask(reason) => {
46                    Err(ToolError::PermissionDenied(format!(
47                        "requires confirmation: {reason}"
48                    )))
49                }
50            }
51        })
52    }
53}
54
55/// Middleware that truncates tool output to a maximum character length.
56///
57/// Long tool outputs can consume excessive tokens in the context window.
58/// This middleware truncates text content items that exceed the limit.
59pub struct OutputFormatter {
60    max_chars: usize,
61}
62
63impl OutputFormatter {
64    /// Create a new output formatter with the given character limit.
65    #[must_use]
66    pub fn new(max_chars: usize) -> Self {
67        Self { max_chars }
68    }
69}
70
71impl ToolMiddleware for OutputFormatter {
72    fn process<'a>(
73        &'a self,
74        call: &'a ToolCall,
75        ctx: &'a ToolContext,
76        next: Next<'a>,
77    ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>> {
78        Box::pin(async move {
79            let mut output = next.run(call, ctx).await?;
80
81            // Truncate text content items that exceed the limit
82            output.content = output
83                .content
84                .into_iter()
85                .map(|item| match item {
86                    ContentItem::Text(text) if text.len() > self.max_chars => {
87                        // Find the nearest char boundary at or before max_chars
88                        // to avoid slicing in the middle of a multi-byte UTF-8
89                        // character. This is a stable polyfill for
90                        // str::floor_char_boundary (stabilized in 1.93).
91                        let mut boundary = self.max_chars;
92                        while boundary > 0 && !text.is_char_boundary(boundary) {
93                            boundary -= 1;
94                        }
95                        ContentItem::Text(format!(
96                            "{}... [truncated, {} chars total]",
97                            &text[..boundary],
98                            text.len()
99                        ))
100                    }
101                    other => other,
102                })
103                .collect();
104
105            Ok(output)
106        })
107    }
108}
109
110/// Middleware that validates tool call input against the tool's JSON Schema.
111///
112/// Performs lightweight structural validation: checks that the input is an
113/// object, required fields are present, and property types match the schema.
114/// This catches obvious input errors before the tool executes, without
115/// depending on a full JSON Schema validation library.
116pub struct SchemaValidator {
117    /// Map of tool name to its input_schema JSON value.
118    schemas: HashMap<String, serde_json::Value>,
119}
120
121impl SchemaValidator {
122    /// Create a new schema validator from the current tool registry.
123    ///
124    /// Snapshots all tool definitions at construction time. Tools registered
125    /// after this call will not be validated.
126    #[must_use]
127    pub fn new(registry: &ToolRegistry) -> Self {
128        let schemas = registry
129            .definitions()
130            .into_iter()
131            .map(|def| (def.name, def.input_schema))
132            .collect();
133        Self { schemas }
134    }
135}
136
137impl ToolMiddleware for SchemaValidator {
138    fn process<'a>(
139        &'a self,
140        call: &'a ToolCall,
141        ctx: &'a ToolContext,
142        next: Next<'a>,
143    ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>> {
144        Box::pin(async move {
145            if let Some(schema) = self.schemas.get(&call.name) {
146                validate_input(&call.input, schema)?;
147            }
148            next.run(call, ctx).await
149        })
150    }
151}
152
153/// Validate a JSON input value against a JSON Schema object.
154///
155/// Performs lightweight structural checks:
156/// - Input must be an object (if schema says `"type": "object"`)
157/// - All `"required"` fields must be present
158/// - Property types must match the schema's `"type"` declarations
159fn validate_input(
160    input: &serde_json::Value,
161    schema: &serde_json::Value,
162) -> Result<(), ToolError> {
163    let schema_obj = match schema.as_object() {
164        Some(obj) => obj,
165        None => return Ok(()), // No schema object to validate against
166    };
167
168    // Check that the input is an object if schema declares type: "object"
169    if let Some(serde_json::Value::String(ty)) = schema_obj.get("type")
170        && ty == "object"
171        && !input.is_object()
172    {
173        return Err(ToolError::InvalidInput(
174            "expected object input".to_string(),
175        ));
176    }
177
178    let input_obj = match input.as_object() {
179        Some(obj) => obj,
180        None => return Ok(()), // Non-object input, nothing more to validate
181    };
182
183    // Check required fields
184    if let Some(serde_json::Value::Array(required)) = schema_obj.get("required") {
185        for field in required {
186            if let Some(field_name) = field.as_str()
187                && !input_obj.contains_key(field_name)
188            {
189                return Err(ToolError::InvalidInput(format!(
190                    "missing required field: {field_name}"
191                )));
192            }
193        }
194    }
195
196    // Check property types
197    if let Some(serde_json::Value::Object(properties)) = schema_obj.get("properties") {
198        for (field_name, prop_schema) in properties {
199            if let Some(value) = input_obj.get(field_name)
200                && let Some(serde_json::Value::String(expected_type)) =
201                    prop_schema.get("type")
202                && !json_type_matches(value, expected_type)
203            {
204                return Err(ToolError::InvalidInput(format!(
205                    "field '{field_name}' expected type '{expected_type}', \
206                     got {}",
207                    json_type_name(value)
208                )));
209            }
210        }
211    }
212
213    Ok(())
214}
215
216/// Check if a JSON value matches the expected JSON Schema type string.
217fn json_type_matches(value: &serde_json::Value, expected: &str) -> bool {
218    match expected {
219        "string" => value.is_string(),
220        "number" => value.is_number(),
221        "integer" => value.is_i64() || value.is_u64(),
222        "boolean" => value.is_boolean(),
223        "array" => value.is_array(),
224        "object" => value.is_object(),
225        "null" => value.is_null(),
226        _ => true, // Unknown type, pass through
227    }
228}
229
230/// Return the JSON type name for a value (for error messages).
231fn json_type_name(value: &serde_json::Value) -> &'static str {
232    match value {
233        serde_json::Value::Null => "null",
234        serde_json::Value::Bool(_) => "boolean",
235        serde_json::Value::Number(_) => "number",
236        serde_json::Value::String(_) => "string",
237        serde_json::Value::Array(_) => "array",
238        serde_json::Value::Object(_) => "object",
239    }
240}