1use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11
12use async_trait::async_trait;
13use serde::{Deserialize, Serialize};
14use serde_json::{Map, Value};
15
16use crate::registry::KernelError;
17
/// Name used to identify a tool; an alias for [`String`].
pub type ToolName = String;
21
/// Describes a tool: its name, a description, and JSON values for its argument and result
/// schemas.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolSchema {
    /// Name of the tool.
    pub name: ToolName,
    /// Description text for the tool.
    pub description: String,
    /// Schema of the arguments, stored as an arbitrary JSON value.
    // NOTE(review): presumably a JSON Schema document (the tests use `{"type": "object"}`) — confirm.
    pub args_schema: Value,
    /// Schema of the result, stored as an arbitrary JSON value.
    pub result_schema: Value,
}
34
/// Limits and optional redaction applied when a tool result is wrapped in a
/// [`ToolResultEnvelope`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolResultEnvelopeConfig {
    /// Maximum number of characters (Unicode scalar values) kept in any single string.
    pub max_string_chars: usize,
    /// Maximum number of items kept in any single array.
    pub max_array_items: usize,
    /// Maximum size, in bytes of the compact JSON serialization, allowed for the payload.
    pub max_total_bytes: usize,
    /// Optional redaction policy; it is consulted for every value before truncation.
    /// Left out of the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub redaction: Option<RedactionPolicy>,
}
49
50impl Default for ToolResultEnvelopeConfig {
51 fn default() -> Self {
52 Self {
53 max_string_chars: 4_000,
54 max_array_items: 64,
55 max_total_bytes: 256_000,
56 redaction: None,
57 }
58 }
59}
60
61impl ToolResultEnvelopeConfig {
62 #[must_use]
65 pub fn new(max_string_chars: usize) -> Self {
66 Self {
67 max_string_chars,
68 ..Self::default()
69 }
70 }
71
72 #[must_use]
74 pub fn with_max_array_items(mut self, max_array_items: usize) -> Self {
75 self.max_array_items = max_array_items;
76 self
77 }
78
79 #[must_use]
81 pub fn with_max_total_bytes(mut self, max_total_bytes: usize) -> Self {
82 self.max_total_bytes = max_total_bytes;
83 self
84 }
85
86 #[must_use]
88 pub fn with_redaction_policy(mut self, redaction: RedactionPolicy) -> Self {
89 self.redaction = Some(redaction);
90 self
91 }
92}
93
/// A single deny rule: the value at `pointer` is replaced by a string.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RedactionRule {
    /// Pointer of the value to redact. It is compared for exact equality with the
    /// slash-separated pointer built while walking the payload (for example `/secret`),
    /// where `~` and `/` inside a segment are written as `~0` and `~1`.
    pub pointer: String,
    /// Replacement text for this rule; when `None` the policy's default replacement is used.
    /// Left out of the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub replacement: Option<String>,
}
103
/// Redaction policy made of deny rules and an optional allow-list.
///
/// Deny rules are checked first. If an allow-list is present, any value whose pointer is not
/// equal to, an ancestor of, or a descendant of an allowed pointer is replaced as well.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RedactionPolicy {
    /// Rules naming the pointers whose values are always replaced.
    pub deny: Vec<RedactionRule>,
    /// Optional allow-list of pointers; `None` means only the deny rules apply.
    /// Left out of the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub allow: Option<Vec<String>>,
    /// Replacement text used when a rule has no replacement of its own, and for values
    /// rejected by the allow-list.
    pub default_replacement: String,
}
116
117impl RedactionPolicy {
118 #[must_use]
120 pub fn deny_pointers<I, S>(pointers: I) -> Self
121 where
122 I: IntoIterator<Item = S>,
123 S: Into<String>,
124 {
125 Self {
126 deny: pointers
127 .into_iter()
128 .map(|pointer| RedactionRule {
129 pointer: pointer.into(),
130 replacement: None,
131 })
132 .collect(),
133 allow: None,
134 default_replacement: "[redacted]".to_string(),
135 }
136 }
137
138 #[must_use]
140 pub fn allow_pointers<I, S>(pointers: I) -> Self
141 where
142 I: IntoIterator<Item = S>,
143 S: Into<String>,
144 {
145 Self {
146 deny: Vec::new(),
147 allow: Some(pointers.into_iter().map(Into::into).collect()),
148 default_replacement: "[redacted]".to_string(),
149 }
150 }
151
152 #[must_use]
154 pub fn with_default_replacement(mut self, replacement: impl Into<String>) -> Self {
155 self.default_replacement = replacement.into();
156 self
157 }
158
159 #[must_use]
161 pub fn with_replacement(
162 mut self,
163 pointer: impl Into<String>,
164 replacement: impl Into<String>,
165 ) -> Self {
166 let pointer = pointer.into();
167 let replacement = replacement.into();
168 for rule in &mut self.deny {
169 if rule.pointer == pointer {
170 rule.replacement = Some(replacement);
171 return self;
172 }
173 }
174 self.deny.push(RedactionRule {
175 pointer,
176 replacement: Some(replacement),
177 });
178 self
179 }
180}
181
/// Which limit caused part of a tool result to be omitted.
///
/// Serialized in snake case (`string_chars`, `array_items`, `total_bytes`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolResultOmissionReason {
    /// A string was longer than the per-string character limit.
    StringChars,
    /// An array had more items than the per-array item limit.
    ArrayItems,
    /// The serialized payload was larger than the total byte budget.
    TotalBytes,
}
193
/// Record of one place in the payload where content was omitted.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OmittedSegment {
    /// Pointer to the value that was truncated or dropped.
    pub pointer: String,
    /// Limit that triggered the omission.
    pub reason: ToolResultOmissionReason,
    /// Encoded token describing this omission; see [`decode_tool_result_page_token`].
    pub page_token: String,
}
204
/// Decoded contents of a page token.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolResultPageToken {
    /// Pointer to the value the omission applies to.
    pub pointer: String,
    /// Limit that triggered the omission.
    pub reason: ToolResultOmissionReason,
    /// Value of the limit that was in effect when the omission was recorded.
    pub limit: usize,
}
215
216#[must_use]
218pub fn decode_tool_result_page_token(token: &str) -> Option<ToolResultPageToken> {
219 let payload = token.strip_prefix("v1:")?;
220 let bytes = decode_hex(payload)?;
221 let text = String::from_utf8(bytes).ok()?;
222 serde_json::from_str(&text).ok()
223}
224
/// A tool result after redaction and size limits were applied, together with bookkeeping
/// about what was removed.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolResultEnvelope {
    /// The bounded payload.
    pub payload: Value,
    /// `true` when any characters, items, or values were omitted.
    pub truncated: bool,
    /// Number of characters removed from strings.
    pub omitted_chars: usize,
    /// Number of array items removed.
    pub omitted_items: usize,
    /// Number of values counted as omitted by the total byte budget.
    /// Defaults to zero when absent and is left out of the serialized form when zero.
    #[serde(default, skip_serializing_if = "is_zero")]
    pub omitted_values: usize,
    /// Number of values replaced by the redaction policy.
    /// Defaults to zero when absent and is left out of the serialized form when zero.
    #[serde(default, skip_serializing_if = "is_zero")]
    pub redacted_values: usize,
    /// One entry per recorded omission, in the order they were recorded.
    /// Defaults to empty when absent and is left out of the serialized form when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub omitted_segments: Vec<OmittedSegment>,
    /// Page token of the first recorded omission, if any.
    /// Left out of the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub page_token: Option<String>,
}
249
250impl ToolResultEnvelope {
251 #[must_use]
253 pub fn bound(payload: Value, config: &ToolResultEnvelopeConfig) -> Self {
254 let mut state = ToolResultEnvelopeState::default();
255 let payload = bound_value(payload, config, &mut state, "");
256 let payload = bound_total_bytes(payload, config, &mut state, "");
257 Self {
258 payload,
259 truncated: state.omitted_chars > 0
260 || state.omitted_items > 0
261 || state.omitted_values > 0,
262 omitted_chars: state.omitted_chars,
263 omitted_items: state.omitted_items,
264 omitted_values: state.omitted_values,
265 redacted_values: state.redacted_values,
266 omitted_segments: state.omitted_segments,
267 page_token: state.page_token,
268 }
269 }
270}
271
/// Returns `true` when `value` is zero.
///
/// Takes a reference because it is used as a serde `skip_serializing_if` predicate.
fn is_zero(value: &usize) -> bool {
    matches!(*value, 0)
}
275
276#[must_use]
278pub fn bound_tool_result(payload: Value) -> ToolResultEnvelope {
279 ToolResultEnvelope::bound(payload, &ToolResultEnvelopeConfig::default())
280}
281
/// Counters and omission records accumulated while a payload is being bounded.
#[derive(Default)]
struct ToolResultEnvelopeState {
    /// Characters removed from strings so far.
    omitted_chars: usize,
    /// Array items removed so far.
    omitted_items: usize,
    /// Values counted as omitted by the total byte budget so far.
    omitted_values: usize,
    /// Values replaced by the redaction policy so far.
    redacted_values: usize,
    /// One record per omission, in the order they were recorded.
    omitted_segments: Vec<OmittedSegment>,
    /// Page token of the first recorded omission, if any.
    page_token: Option<String>,
}
291
292impl ToolResultEnvelopeState {
293 fn record_omission(&mut self, pointer: &str, reason: ToolResultOmissionReason, limit: usize) {
294 let page_token = page_token(pointer, reason, limit);
295 if self.page_token.is_none() {
296 self.page_token = Some(page_token.clone());
297 }
298 self.omitted_segments.push(OmittedSegment {
299 pointer: pointer.to_string(),
300 reason,
301 page_token,
302 });
303 }
304}
305
306fn bound_value(
307 value: Value,
308 config: &ToolResultEnvelopeConfig,
309 state: &mut ToolResultEnvelopeState,
310 pointer: &str,
311) -> Value {
312 if let Some(redaction) = &config.redaction
313 && let Some(replacement) = redaction.replacement_for(pointer)
314 {
315 state.redacted_values = state.redacted_values.saturating_add(1);
316 return Value::String(replacement);
317 }
318
319 match value {
320 Value::String(text) => bound_string(text, config, state, pointer),
321 Value::Array(items) => bound_array(items, config, state, pointer),
322 Value::Object(fields) => bound_object(fields, config, state, pointer),
323 scalar => scalar,
324 }
325}
326
327fn bound_string(
328 text: String,
329 config: &ToolResultEnvelopeConfig,
330 state: &mut ToolResultEnvelopeState,
331 pointer: &str,
332) -> Value {
333 let total_chars = text.chars().count();
334 if total_chars <= config.max_string_chars {
335 return Value::String(text);
336 }
337 state.omitted_chars = state
338 .omitted_chars
339 .saturating_add(total_chars.saturating_sub(config.max_string_chars));
340 state.record_omission(
341 pointer,
342 ToolResultOmissionReason::StringChars,
343 config.max_string_chars,
344 );
345 Value::String(text.chars().take(config.max_string_chars).collect())
346}
347
348fn bound_array(
349 items: Vec<Value>,
350 config: &ToolResultEnvelopeConfig,
351 state: &mut ToolResultEnvelopeState,
352 pointer: &str,
353) -> Value {
354 let total_items = items.len();
355 if total_items > config.max_array_items {
356 state.omitted_items = state
357 .omitted_items
358 .saturating_add(total_items.saturating_sub(config.max_array_items));
359 state.record_omission(
360 pointer,
361 ToolResultOmissionReason::ArrayItems,
362 config.max_array_items,
363 );
364 }
365 Value::Array(
366 items
367 .into_iter()
368 .enumerate()
369 .take(config.max_array_items)
370 .map(|(index, item)| {
371 let child = child_pointer(pointer, &index.to_string());
372 bound_value(item, config, state, &child)
373 })
374 .collect(),
375 )
376}
377
378fn bound_object(
379 fields: Map<String, Value>,
380 config: &ToolResultEnvelopeConfig,
381 state: &mut ToolResultEnvelopeState,
382 pointer: &str,
383) -> Value {
384 Value::Object(
385 fields
386 .into_iter()
387 .map(|(key, value)| {
388 let child = child_pointer(pointer, &key);
389 (key, bound_value(value, config, state, &child))
390 })
391 .collect(),
392 )
393}
394
395impl RedactionPolicy {
396 fn replacement_for(&self, pointer: &str) -> Option<String> {
397 for rule in &self.deny {
398 if rule.pointer == pointer {
399 return Some(
400 rule.replacement
401 .clone()
402 .unwrap_or_else(|| self.default_replacement.clone()),
403 );
404 }
405 }
406
407 let Some(allow) = &self.allow else {
408 return None;
409 };
410 if allow
411 .iter()
412 .any(|allowed| pointer_matches(pointer, allowed))
413 {
414 return None;
415 }
416 Some(self.default_replacement.clone())
417 }
418}
419
420fn pointer_matches(pointer: &str, allowed: &str) -> bool {
421 pointer == allowed || is_descendant(pointer, allowed) || is_descendant(allowed, pointer)
422}
423
/// Returns `true` when `pointer` lies strictly below `ancestor`.
///
/// The empty pointer is an ancestor of every non-empty pointer. Otherwise `pointer` must
/// start with `ancestor` followed by a `/`.
fn is_descendant(pointer: &str, ancestor: &str) -> bool {
    match ancestor {
        "" => !pointer.is_empty(),
        _ => pointer
            .strip_prefix(ancestor)
            .is_some_and(|rest| rest.starts_with('/')),
    }
}
431
432fn child_pointer(parent: &str, child: &str) -> String {
433 let escaped = escape_pointer_segment(child);
434 if parent.is_empty() {
435 format!("/{escaped}")
436 } else {
437 format!("{parent}/{escaped}")
438 }
439}
440
/// Escapes a single pointer segment: `~` becomes `~0` and `/` becomes `~1`.
fn escape_pointer_segment(segment: &str) -> String {
    let mut escaped = String::with_capacity(segment.len());
    for character in segment.chars() {
        match character {
            '~' => escaped.push_str("~0"),
            '/' => escaped.push_str("~1"),
            other => escaped.push(other),
        }
    }
    escaped
}
444
445fn page_token(pointer: &str, reason: ToolResultOmissionReason, limit: usize) -> String {
446 let payload = ToolResultPageToken {
447 pointer: pointer.to_string(),
448 reason,
449 limit,
450 };
451 match serde_json::to_string(&payload) {
452 Ok(serialized) => format!("v1:{}", encode_hex(serialized.as_bytes())),
453 Err(_) => "v1:".to_string(),
454 }
455}
456
/// Encodes bytes as lowercase hexadecimal, two digits per byte.
fn encode_hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut encoded = String::with_capacity(bytes.len().saturating_mul(2));
    for &byte in bytes {
        // High nibble first, then low nibble; both indices are below 16.
        encoded.push(char::from(DIGITS[usize::from(byte >> 4)]));
        encoded.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    encoded
}
460
461fn decode_hex(input: &str) -> Option<Vec<u8>> {
462 let mut chars = input.chars();
463 let mut bytes = Vec::new();
464 loop {
465 let Some(high) = chars.next() else {
466 return Some(bytes);
467 };
468 let low = chars.next()?;
469 let high = hex_value(high)?;
470 let low = hex_value(low)?;
471 bytes.push(high.saturating_mul(16).saturating_add(low));
472 }
473}
474
/// Returns the numeric value (0–15) of an ASCII hexadecimal digit in either case, or `None`
/// for any other character.
fn hex_value(character: char) -> Option<u8> {
    character
        .to_digit(16)
        .and_then(|digit| u8::try_from(digit).ok())
}
483
484fn bound_total_bytes(
485 value: Value,
486 config: &ToolResultEnvelopeConfig,
487 state: &mut ToolResultEnvelopeState,
488 pointer: &str,
489) -> Value {
490 if serialized_len(&value) <= config.max_total_bytes {
491 return value;
492 }
493
494 match value {
495 Value::Object(fields) => bound_object_total_bytes(fields, config, state, pointer),
496 Value::Array(items) => bound_array_total_bytes(items, config, state, pointer),
497 Value::String(text) => bound_string_total_bytes(text, config, state, pointer),
498 scalar => {
499 state.omitted_values = state.omitted_values.saturating_add(1);
500 state.record_omission(
501 pointer,
502 ToolResultOmissionReason::TotalBytes,
503 config.max_total_bytes,
504 );
505 scalar
506 }
507 }
508}
509
510fn bound_object_total_bytes(
511 fields: Map<String, Value>,
512 config: &ToolResultEnvelopeConfig,
513 state: &mut ToolResultEnvelopeState,
514 pointer: &str,
515) -> Value {
516 let mut retained = Map::new();
517 for (key, value) in fields {
518 let child = child_pointer(pointer, &key);
519 let mut candidate = retained.clone();
520 candidate.insert(key.clone(), value.clone());
521 if serialized_len(&Value::Object(candidate)) <= config.max_total_bytes {
522 retained.insert(key, value);
523 } else {
524 state.omitted_values = state.omitted_values.saturating_add(1);
525 state.record_omission(
526 &child,
527 ToolResultOmissionReason::TotalBytes,
528 config.max_total_bytes,
529 );
530 }
531 }
532 Value::Object(retained)
533}
534
535fn bound_array_total_bytes(
536 items: Vec<Value>,
537 config: &ToolResultEnvelopeConfig,
538 state: &mut ToolResultEnvelopeState,
539 pointer: &str,
540) -> Value {
541 let mut retained = Vec::new();
542 for (index, item) in items.into_iter().enumerate() {
543 let mut candidate = retained.clone();
544 candidate.push(item.clone());
545 if serialized_len(&Value::Array(candidate)) <= config.max_total_bytes {
546 retained.push(item);
547 } else {
548 state.omitted_items = state.omitted_items.saturating_add(1);
549 let child = child_pointer(pointer, &index.to_string());
550 state.record_omission(
551 &child,
552 ToolResultOmissionReason::TotalBytes,
553 config.max_total_bytes,
554 );
555 }
556 }
557 Value::Array(retained)
558}
559
560fn bound_string_total_bytes(
561 text: String,
562 config: &ToolResultEnvelopeConfig,
563 state: &mut ToolResultEnvelopeState,
564 pointer: &str,
565) -> Value {
566 let mut retained = String::new();
567 for character in text.chars() {
568 let mut candidate = retained.clone();
569 candidate.push(character);
570 if serialized_len(&Value::String(candidate)) <= config.max_total_bytes {
571 retained.push(character);
572 } else {
573 state.omitted_chars = state.omitted_chars.saturating_add(1);
574 }
575 }
576 if retained.chars().count() < text.chars().count() {
577 state.record_omission(
578 pointer,
579 ToolResultOmissionReason::TotalBytes,
580 config.max_total_bytes,
581 );
582 }
583 Value::String(retained)
584}
585
586fn serialized_len(value: &Value) -> usize {
587 match serde_json::to_string(value) {
588 Ok(serialized) => serialized.len(),
589 Err(_) => usize::MAX,
590 }
591}
592
/// A tool that can describe itself and be invoked asynchronously with JSON arguments.
///
/// Implementations must be `Send + Sync`.
#[async_trait]
pub trait Tool: Send + Sync {
    /// Returns the schema describing this tool.
    fn schema(&self) -> ToolSchema;

    /// Returns the tool's name.
    ///
    /// The default implementation builds the full schema and takes its `name`.
    fn name(&self) -> ToolName {
        self.schema().name
    }

    /// Invokes the tool with `args`, producing a JSON result or a [`KernelError`].
    async fn invoke(&self, args: Value) -> Result<Value, KernelError>;
}
611
/// A [`Tool`] backed by a closure stored alongside its schema.
pub struct LocalTool {
    /// Schema returned by `schema()` and used for `name()`.
    schema: ToolSchema,
    /// Type-erased async callback: takes the JSON arguments and returns a boxed, `Send`
    /// future resolving to the result.
    // The boxed-future closure type is long by nature, hence the lint exemption.
    #[allow(clippy::type_complexity)]
    f: Arc<
        dyn Fn(Value) -> Pin<Box<dyn Future<Output = Result<Value, KernelError>> + Send>>
            + Send
            + Sync,
    >,
}
624
625impl LocalTool {
626 pub fn new<F, Fut>(schema: ToolSchema, f: F) -> Self
627 where
628 F: Fn(Value) -> Fut + Send + Sync + 'static,
629 Fut: Future<Output = Result<Value, KernelError>> + Send + 'static,
630 {
631 Self {
632 schema,
633 f: Arc::new(move |v| Box::pin(f(v))),
634 }
635 }
636}
637
638#[async_trait]
639impl Tool for LocalTool {
640 fn schema(&self) -> ToolSchema {
641 self.schema.clone()
642 }
643
644 fn name(&self) -> ToolName {
645 self.schema.name.clone()
646 }
647
648 async fn invoke(&self, args: Value) -> Result<Value, KernelError> {
649 (self.f)(args).await
650 }
651}
652
// Unit tests for `LocalTool` and for the tool-result envelope limits.
#[cfg(test)]
mod tests {
    // NOTE(review): this pulls names from the crate root, so it relies on the crate
    // re-exporting this module's items — confirm.
    use crate::*;
    use serde_json::json;

    // A `LocalTool` built from an echoing closure returns its input and reports the
    // schema's name.
    #[tokio::test]
    async fn local_tool_roundtrip() {
        let schema = ToolSchema {
            name: "test.echo".into(),
            description: "echoes the input".into(),
            args_schema: json!({"type": "object"}),
            result_schema: json!({"type": "object"}),
        };
        let tool = LocalTool::new(schema, |v| async move { Ok(v) });
        let out = tool.invoke(json!({"hello": "world"})).await.unwrap();
        assert_eq!(out, json!({"hello": "world"}));
        assert_eq!(tool.name(), "test.echo");
    }

    // A six-character string with a limit of three keeps the first three characters and
    // reports the omission through the counters, the page token, and the segment list.
    #[test]
    fn tool_result_envelope_bounds_large_strings() {
        let envelope =
            ToolResultEnvelope::bound(json!({"body": "abcdef"}), &ToolResultEnvelopeConfig::new(3));

        assert_eq!(envelope.payload, json!({"body": "abc"}));
        assert!(envelope.truncated);
        assert_eq!(envelope.omitted_chars, 3);
        assert_eq!(
            decode_tool_result_page_token(envelope.page_token.as_deref().unwrap()),
            Some(ToolResultPageToken {
                pointer: "/body".to_string(),
                reason: ToolResultOmissionReason::StringChars,
                limit: 3,
            })
        );
        assert_eq!(envelope.omitted_segments.len(), 1);
        assert!(
            envelope
                .omitted_segments
                .iter()
                .any(|segment| segment.pointer == "/body")
        );
    }

    // A four-item array with a limit of two keeps the first two items and reports the
    // omission at the array's pointer.
    #[test]
    fn tool_result_envelope_bounds_arrays() {
        let envelope = ToolResultEnvelope::bound(
            json!({"rows": [1, 2, 3, 4]}),
            &ToolResultEnvelopeConfig::new(100).with_max_array_items(2),
        );

        assert_eq!(envelope.payload, json!({"rows": [1, 2]}));
        assert!(envelope.truncated);
        assert_eq!(envelope.omitted_items, 2);
        assert_eq!(
            decode_tool_result_page_token(envelope.page_token.as_deref().unwrap()),
            Some(ToolResultPageToken {
                pointer: "/rows".to_string(),
                reason: ToolResultOmissionReason::ArrayItems,
                limit: 2,
            })
        );
        assert_eq!(envelope.omitted_segments.len(), 1);
        assert!(
            envelope
                .omitted_segments
                .iter()
                .any(|segment| segment.pointer == "/rows")
        );
    }

    // A payload within every limit passes through untouched with no page token.
    #[test]
    fn tool_result_envelope_leaves_small_payloads_unchanged() {
        let payload = json!({"ok": true, "rows": ["a"]});
        let envelope = ToolResultEnvelope::bound(
            payload.clone(),
            &ToolResultEnvelopeConfig::new(100).with_max_array_items(10),
        );

        assert_eq!(envelope.payload, payload);
        assert!(!envelope.truncated);
        assert_eq!(envelope.omitted_chars, 0);
        assert_eq!(envelope.omitted_items, 0);
        assert_eq!(envelope.page_token, None);
    }

    // The redacted field is replaced wholesale, so the secret never reaches truncation;
    // only the public field is truncated and reported.
    #[test]
    fn tool_result_envelope_redacts_before_truncation() {
        let config = ToolResultEnvelopeConfig::new(4).with_redaction_policy(
            RedactionPolicy::deny_pointers(["/secret"]).with_replacement("/secret", "safe"),
        );
        let envelope = ToolResultEnvelope::bound(
            json!({"public": "abcdef", "secret": "should-not-leak"}),
            &config,
        );

        assert_eq!(
            envelope.payload,
            json!({"public": "abcd", "secret": "safe"})
        );
        assert_eq!(envelope.redacted_values, 1);
        assert_eq!(envelope.omitted_chars, 2);
        assert_eq!(
            decode_tool_result_page_token(envelope.page_token.as_deref().unwrap()),
            Some(ToolResultPageToken {
                pointer: "/public".to_string(),
                reason: ToolResultOmissionReason::StringChars,
                limit: 4,
            })
        );
    }

    // With a 24-byte budget the object cannot hold all three fields: "a" is kept and at
    // least one dropped field is reported with a `TotalBytes` token carrying the limit.
    #[test]
    fn tool_result_envelope_total_budget_drops_fields_with_path_tokens() {
        let config = ToolResultEnvelopeConfig::new(100).with_max_total_bytes(24);
        let envelope = ToolResultEnvelope::bound(
            json!({"a": "small", "b": "also-small", "c": "extra"}),
            &config,
        );

        assert!(envelope.truncated);
        assert!(envelope.omitted_values > 0);
        assert!(envelope.payload.get("a").is_some());
        assert!(envelope.omitted_segments.iter().any(|segment| {
            segment.reason == ToolResultOmissionReason::TotalBytes
                && decode_tool_result_page_token(&segment.page_token).is_some_and(|token| {
                    token.reason == ToolResultOmissionReason::TotalBytes && token.limit == 24
                })
        }));
    }
}