1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6
7use neuron_types::{
8 ContentItem, PermissionDecision, PermissionPolicy, ToolContext, ToolError, ToolOutput,
9 WasmBoxedFuture,
10};
11
12use crate::middleware::{Next, ToolCall, ToolMiddleware};
13use crate::registry::ToolRegistry;
14
15pub struct PermissionChecker {
20 policy: Arc<dyn PermissionPolicy>,
21}
22
23impl PermissionChecker {
24 #[must_use]
26 pub fn new(policy: impl PermissionPolicy + 'static) -> Self {
27 Self {
28 policy: Arc::new(policy),
29 }
30 }
31}
32
33impl ToolMiddleware for PermissionChecker {
34 fn process<'a>(
35 &'a self,
36 call: &'a ToolCall,
37 ctx: &'a ToolContext,
38 next: Next<'a>,
39 ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>> {
40 Box::pin(async move {
41 match self.policy.check(&call.name, &call.input) {
42 PermissionDecision::Allow => next.run(call, ctx).await,
43 PermissionDecision::Deny(reason) => Err(ToolError::PermissionDenied(reason)),
44 PermissionDecision::Ask(reason) => Err(ToolError::PermissionDenied(format!(
45 "requires confirmation: {reason}"
46 ))),
47 }
48 })
49 }
50}
51
/// Middleware that shortens oversized text items in a tool's output.
pub struct OutputFormatter {
    /// Length limit applied to each [`ContentItem::Text`] item of the output.
    max_chars: usize,
}

impl OutputFormatter {
    /// Creates a formatter whose per-item text limit is `max_chars`.
    #[must_use]
    pub fn new(max_chars: usize) -> Self {
        Self { max_chars }
    }
}
67
68impl ToolMiddleware for OutputFormatter {
69 fn process<'a>(
70 &'a self,
71 call: &'a ToolCall,
72 ctx: &'a ToolContext,
73 next: Next<'a>,
74 ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>> {
75 Box::pin(async move {
76 let mut output = next.run(call, ctx).await?;
77
78 output.content = output
80 .content
81 .into_iter()
82 .map(|item| match item {
83 ContentItem::Text(text) if text.len() > self.max_chars => {
84 let mut boundary = self.max_chars;
89 while boundary > 0 && !text.is_char_boundary(boundary) {
90 boundary -= 1;
91 }
92 ContentItem::Text(format!(
93 "{}... [truncated, {} chars total]",
94 &text[..boundary],
95 text.len()
96 ))
97 }
98 other => other,
99 })
100 .collect();
101
102 Ok(output)
103 })
104 }
105}
106
107pub struct SchemaValidator {
114 schemas: HashMap<String, serde_json::Value>,
116}
117
118impl SchemaValidator {
119 #[must_use]
124 pub fn new(registry: &ToolRegistry) -> Self {
125 let schemas = registry
126 .definitions()
127 .into_iter()
128 .map(|def| (def.name, def.input_schema))
129 .collect();
130 Self { schemas }
131 }
132}
133
134impl ToolMiddleware for SchemaValidator {
135 fn process<'a>(
136 &'a self,
137 call: &'a ToolCall,
138 ctx: &'a ToolContext,
139 next: Next<'a>,
140 ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>> {
141 Box::pin(async move {
142 if let Some(schema) = self.schemas.get(&call.name) {
143 validate_input(&call.input, schema)?;
144 }
145 next.run(call, ctx).await
146 })
147 }
148}
149
150fn validate_input(input: &serde_json::Value, schema: &serde_json::Value) -> Result<(), ToolError> {
157 let schema_obj = match schema.as_object() {
158 Some(obj) => obj,
159 None => return Ok(()), };
161
162 if let Some(serde_json::Value::String(ty)) = schema_obj.get("type")
164 && ty == "object"
165 && !input.is_object()
166 {
167 return Err(ToolError::InvalidInput("expected object input".to_string()));
168 }
169
170 let input_obj = match input.as_object() {
171 Some(obj) => obj,
172 None => return Ok(()), };
174
175 if let Some(serde_json::Value::Array(required)) = schema_obj.get("required") {
177 for field in required {
178 if let Some(field_name) = field.as_str()
179 && !input_obj.contains_key(field_name)
180 {
181 return Err(ToolError::InvalidInput(format!(
182 "missing required field: {field_name}"
183 )));
184 }
185 }
186 }
187
188 if let Some(serde_json::Value::Object(properties)) = schema_obj.get("properties") {
190 for (field_name, prop_schema) in properties {
191 if let Some(value) = input_obj.get(field_name)
192 && let Some(serde_json::Value::String(expected_type)) = prop_schema.get("type")
193 && !json_type_matches(value, expected_type)
194 {
195 return Err(ToolError::InvalidInput(format!(
196 "field '{field_name}' expected type '{expected_type}', \
197 got {}",
198 json_type_name(value)
199 )));
200 }
201 }
202 }
203
204 Ok(())
205}
206
207fn json_type_matches(value: &serde_json::Value, expected: &str) -> bool {
209 match expected {
210 "string" => value.is_string(),
211 "number" => value.is_number(),
212 "integer" => value.is_i64() || value.is_u64(),
213 "boolean" => value.is_boolean(),
214 "array" => value.is_array(),
215 "object" => value.is_object(),
216 "null" => value.is_null(),
217 _ => true, }
219}
220
221fn json_type_name(value: &serde_json::Value) -> &'static str {
223 match value {
224 serde_json::Value::Null => "null",
225 serde_json::Value::Bool(_) => "boolean",
226 serde_json::Value::Number(_) => "number",
227 serde_json::Value::String(_) => "string",
228 serde_json::Value::Array(_) => "array",
229 serde_json::Value::Object(_) => "object",
230 }
231}
232
/// Middleware that fails tool calls which do not finish within a time limit.
///
/// One default limit applies to every tool unless a per-tool override was
/// registered with [`TimeoutMiddleware::with_tool_timeout`].
pub struct TimeoutMiddleware {
    /// Limit for tools without an override.
    default_timeout: Duration,
    /// Per-tool overrides keyed by tool name.
    per_tool: HashMap<String, Duration>,
}

impl TimeoutMiddleware {
    /// Creates a middleware that applies `default_timeout` to every tool.
    #[must_use]
    pub fn new(default_timeout: Duration) -> Self {
        let per_tool = HashMap::new();
        Self {
            default_timeout,
            per_tool,
        }
    }

    /// Sets a dedicated limit for `tool_name`, replacing any earlier override
    /// for the same tool.
    #[must_use]
    pub fn with_tool_timeout(mut self, tool_name: impl Into<String>, timeout: Duration) -> Self {
        let key: String = tool_name.into();
        self.per_tool.insert(key, timeout);
        self
    }
}
264
265impl ToolMiddleware for TimeoutMiddleware {
266 fn process<'a>(
267 &'a self,
268 call: &'a ToolCall,
269 ctx: &'a ToolContext,
270 next: Next<'a>,
271 ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>> {
272 Box::pin(async move {
273 let timeout = self
274 .per_tool
275 .get(&call.name)
276 .unwrap_or(&self.default_timeout);
277 match tokio::time::timeout(*timeout, next.run(call, ctx)).await {
278 Ok(result) => result,
279 Err(_elapsed) => Err(ToolError::ExecutionFailed(Box::new(std::io::Error::new(
280 std::io::ErrorKind::TimedOut,
281 format!(
282 "tool '{}' timed out after {:.1}s",
283 call.name,
284 timeout.as_secs_f64()
285 ),
286 )))),
287 }
288 })
289 }
290}
291
292pub struct StructuredOutputValidator {
303 schema: serde_json::Value,
304 max_retries: usize,
305}
306
307impl StructuredOutputValidator {
308 #[must_use]
314 pub fn new(schema: serde_json::Value, max_retries: usize) -> Self {
315 Self {
316 schema,
317 max_retries,
318 }
319 }
320}
321
322impl ToolMiddleware for StructuredOutputValidator {
323 fn process<'a>(
324 &'a self,
325 call: &'a ToolCall,
326 ctx: &'a ToolContext,
327 next: Next<'a>,
328 ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>> {
329 Box::pin(async move {
330 if let Err(e) = validate_input(&call.input, &self.schema) {
333 return Err(ToolError::ModelRetry(format!(
335 "Output validation failed: {e}. Please fix the output to match the schema."
336 )));
337 }
338 next.run(call, ctx).await
339 })
340 }
341}
342
343pub struct RetryLimitedValidator {
349 inner: StructuredOutputValidator,
350 attempts: std::sync::atomic::AtomicUsize,
351}
352
353impl RetryLimitedValidator {
354 #[must_use]
356 pub fn new(validator: StructuredOutputValidator) -> Self {
357 Self {
358 inner: validator,
359 attempts: std::sync::atomic::AtomicUsize::new(0),
360 }
361 }
362}
363
364impl ToolMiddleware for RetryLimitedValidator {
365 fn process<'a>(
366 &'a self,
367 call: &'a ToolCall,
368 ctx: &'a ToolContext,
369 next: Next<'a>,
370 ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>> {
371 Box::pin(async move {
372 if let Err(e) = validate_input(&call.input, &self.inner.schema) {
373 let attempt = self
374 .attempts
375 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
376 if attempt >= self.inner.max_retries {
377 return Err(ToolError::InvalidInput(format!(
378 "Output validation failed after {} retries: {e}",
379 self.inner.max_retries
380 )));
381 }
382 return Err(ToolError::ModelRetry(format!(
383 "Output validation failed (attempt {}/{}): {e}. \
384 Please fix the output to match the schema.",
385 attempt + 1,
386 self.inner.max_retries
387 )));
388 }
389 self.attempts.store(0, std::sync::atomic::Ordering::Relaxed);
391 next.run(call, ctx).await
392 })
393 }
394}