1use std::collections::HashMap;
4use std::sync::Arc;
5
6use neuron_types::{
7 ContentItem, PermissionDecision, PermissionPolicy, ToolContext, ToolError, ToolOutput,
8 WasmBoxedFuture,
9};
10
11use crate::middleware::{Next, ToolCall, ToolMiddleware};
12use crate::registry::ToolRegistry;
13
14pub struct PermissionChecker {
19 policy: Arc<dyn PermissionPolicy>,
20}
21
22impl PermissionChecker {
23 #[must_use]
25 pub fn new(policy: impl PermissionPolicy + 'static) -> Self {
26 Self {
27 policy: Arc::new(policy),
28 }
29 }
30}
31
32impl ToolMiddleware for PermissionChecker {
33 fn process<'a>(
34 &'a self,
35 call: &'a ToolCall,
36 ctx: &'a ToolContext,
37 next: Next<'a>,
38 ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>> {
39 Box::pin(async move {
40 match self.policy.check(&call.name, &call.input) {
41 PermissionDecision::Allow => next.run(call, ctx).await,
42 PermissionDecision::Deny(reason) => {
43 Err(ToolError::PermissionDenied(reason))
44 }
45 PermissionDecision::Ask(reason) => {
46 Err(ToolError::PermissionDenied(format!(
47 "requires confirmation: {reason}"
48 )))
49 }
50 }
51 })
52 }
53}
54
55pub struct OutputFormatter {
60 max_chars: usize,
61}
62
63impl OutputFormatter {
64 #[must_use]
66 pub fn new(max_chars: usize) -> Self {
67 Self { max_chars }
68 }
69}
70
71impl ToolMiddleware for OutputFormatter {
72 fn process<'a>(
73 &'a self,
74 call: &'a ToolCall,
75 ctx: &'a ToolContext,
76 next: Next<'a>,
77 ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>> {
78 Box::pin(async move {
79 let mut output = next.run(call, ctx).await?;
80
81 output.content = output
83 .content
84 .into_iter()
85 .map(|item| match item {
86 ContentItem::Text(text) if text.len() > self.max_chars => {
87 let mut boundary = self.max_chars;
92 while boundary > 0 && !text.is_char_boundary(boundary) {
93 boundary -= 1;
94 }
95 ContentItem::Text(format!(
96 "{}... [truncated, {} chars total]",
97 &text[..boundary],
98 text.len()
99 ))
100 }
101 other => other,
102 })
103 .collect();
104
105 Ok(output)
106 })
107 }
108}
109
110pub struct SchemaValidator {
117 schemas: HashMap<String, serde_json::Value>,
119}
120
121impl SchemaValidator {
122 #[must_use]
127 pub fn new(registry: &ToolRegistry) -> Self {
128 let schemas = registry
129 .definitions()
130 .into_iter()
131 .map(|def| (def.name, def.input_schema))
132 .collect();
133 Self { schemas }
134 }
135}
136
137impl ToolMiddleware for SchemaValidator {
138 fn process<'a>(
139 &'a self,
140 call: &'a ToolCall,
141 ctx: &'a ToolContext,
142 next: Next<'a>,
143 ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>> {
144 Box::pin(async move {
145 if let Some(schema) = self.schemas.get(&call.name) {
146 validate_input(&call.input, schema)?;
147 }
148 next.run(call, ctx).await
149 })
150 }
151}
152
153fn validate_input(
160 input: &serde_json::Value,
161 schema: &serde_json::Value,
162) -> Result<(), ToolError> {
163 let schema_obj = match schema.as_object() {
164 Some(obj) => obj,
165 None => return Ok(()), };
167
168 if let Some(serde_json::Value::String(ty)) = schema_obj.get("type")
170 && ty == "object"
171 && !input.is_object()
172 {
173 return Err(ToolError::InvalidInput(
174 "expected object input".to_string(),
175 ));
176 }
177
178 let input_obj = match input.as_object() {
179 Some(obj) => obj,
180 None => return Ok(()), };
182
183 if let Some(serde_json::Value::Array(required)) = schema_obj.get("required") {
185 for field in required {
186 if let Some(field_name) = field.as_str()
187 && !input_obj.contains_key(field_name)
188 {
189 return Err(ToolError::InvalidInput(format!(
190 "missing required field: {field_name}"
191 )));
192 }
193 }
194 }
195
196 if let Some(serde_json::Value::Object(properties)) = schema_obj.get("properties") {
198 for (field_name, prop_schema) in properties {
199 if let Some(value) = input_obj.get(field_name)
200 && let Some(serde_json::Value::String(expected_type)) =
201 prop_schema.get("type")
202 && !json_type_matches(value, expected_type)
203 {
204 return Err(ToolError::InvalidInput(format!(
205 "field '{field_name}' expected type '{expected_type}', \
206 got {}",
207 json_type_name(value)
208 )));
209 }
210 }
211 }
212
213 Ok(())
214}
215
216fn json_type_matches(value: &serde_json::Value, expected: &str) -> bool {
218 match expected {
219 "string" => value.is_string(),
220 "number" => value.is_number(),
221 "integer" => value.is_i64() || value.is_u64(),
222 "boolean" => value.is_boolean(),
223 "array" => value.is_array(),
224 "object" => value.is_object(),
225 "null" => value.is_null(),
226 _ => true, }
228}
229
230fn json_type_name(value: &serde_json::Value) -> &'static str {
232 match value {
233 serde_json::Value::Null => "null",
234 serde_json::Value::Bool(_) => "boolean",
235 serde_json::Value::Number(_) => "number",
236 serde_json::Value::String(_) => "string",
237 serde_json::Value::Array(_) => "array",
238 serde_json::Value::Object(_) => "object",
239 }
240}