use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use crate::error::{ToolError, ToolValidationError};
pub use crate::types::ToolResultBlock;
/// Execution-mode selector with two variants, `Parallel` and `Sequential`.
///
/// Serialized and deserialized in `snake_case`, i.e. as `"parallel"` and
/// `"sequential"`.
///
/// NOTE(review): the variant names suggest this chooses between running tool
/// calls concurrently or one at a time; the code that consumes the value is
/// not in this file, so confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ExecutionMode {
Parallel,
Sequential,
}
/// One requested tool invocation: a call id, the tool name, and its JSON
/// arguments.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolCall {
// Identifier of this call (an id of the same shape is passed to
// `AgentTool::execute` as `call_id`).
pub id: String,
// Name of the tool to invoke; `ToolRegistry` is keyed by tool name.
pub name: String,
// Arguments as an arbitrary JSON value. May be the sentinel object built by
// `arg_parse_error_value`; see `detect_arg_parse_error`.
pub arguments: Value,
}
/// Object key under which `arg_parse_error_value` stores the parse-error
/// message and from which `detect_arg_parse_error` reads it back.
pub const ARG_PARSE_ERROR_MARKER: &str = "__clark_arg_parse_error";
/// Object key under which `arg_parse_error_value` stores the raw, unparsed
/// argument text and from which `detect_arg_parse_error` reads it back.
pub const ARG_PARSE_RAW_MARKER: &str = "__clark_arg_raw";
pub fn arg_parse_error_value(error: impl Into<String>, raw: impl Into<String>) -> Value {
serde_json::json!({
ARG_PARSE_ERROR_MARKER: error.into(),
ARG_PARSE_RAW_MARKER: raw.into(),
})
}
pub fn detect_arg_parse_error(args: &Value) -> Option<(&str, &str)> {
let obj = args.as_object()?;
let err = obj.get(ARG_PARSE_ERROR_MARKER)?.as_str()?;
let raw = obj.get(ARG_PARSE_RAW_MARKER)?.as_str()?;
Some((err, raw))
}
/// Outcome of a tool execution: content blocks plus status flags and optional
/// metadata.
///
/// Every field except `content` is optional on the wire: it falls back to its
/// `Default` when absent during deserialization and is omitted during
/// serialization when it holds that default-like value (see the per-field
/// `skip_serializing_if` predicates).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
// Result payload as a list of content blocks.
pub content: Vec<ToolResultBlock>,
// True when the result reports a failure; omitted from output when false.
#[serde(default, skip_serializing_if = "is_false")]
pub is_error: bool,
// Free-form JSON attached to the result; omitted from output when null.
#[serde(default, skip_serializing_if = "Value::is_null")]
pub details: Value,
// Set by `ToolResult::terminal`; omitted from output when false.
// NOTE(review): presumably asks the caller to stop after this result — the
// code that reads this flag is not in this file.
#[serde(default, skip_serializing_if = "is_false")]
pub terminate: bool,
// Optional narration text; `with_narration` stores it trimmed and never
// stores an empty string. Omitted from output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub narration: Option<String>,
}
/// `skip_serializing_if` predicate: returns `true` exactly when the flag is
/// `false`, so unset boolean fields are left out of the serialized form.
fn is_false(b: &bool) -> bool {
    matches!(*b, false)
}
impl ToolResult {
    /// Shared constructor behind [`text`](Self::text),
    /// [`terminal`](Self::terminal) and [`error`](Self::error): a result whose
    /// content is one text block, with null `details` and no narration.
    fn single_text_block(text: String, is_error: bool, terminate: bool) -> Self {
        let block = ToolResultBlock::Text(crate::types::TextContent { text });
        Self {
            content: vec![block],
            is_error,
            details: Value::Null,
            terminate,
            narration: None,
        }
    }

    /// Successful, non-terminating result containing a single text block.
    pub fn text(text: impl Into<String>) -> Self {
        Self::single_text_block(text.into(), false, false)
    }

    /// Successful result containing a single text block, with `terminate`
    /// set to `true`.
    pub fn terminal(text: impl Into<String>) -> Self {
        Self::single_text_block(text.into(), false, true)
    }

    /// Error result (`is_error == true`) containing a single text block.
    pub fn error(text: impl Into<String>) -> Self {
        Self::single_text_block(text.into(), true, false)
    }

    /// Attaches narration text, trimmed of surrounding whitespace.
    ///
    /// If the trimmed text is empty the result is returned unchanged, so any
    /// narration already present is kept.
    pub fn with_narration(mut self, narration: impl Into<String>) -> Self {
        let supplied: String = narration.into();
        match supplied.trim() {
            "" => {}
            kept => self.narration = Some(kept.to_string()),
        }
        self
    }
}
/// Sending half of an unbounded tokio mpsc channel carrying `ToolResult`
/// values; passed to `AgentTool::execute` / `TypedAgentTool::run` as `update`.
///
/// NOTE(review): presumably used for intermediate progress results while a
/// tool runs — the receiving side is not in this file.
pub type ToolUpdateSink = mpsc::UnboundedSender<ToolResult>;
/// Per-tool history settings, returned by `AgentTool::history_policy` and
/// assembled with the `const` builder methods on this type.
///
/// NOTE(review): the field meanings below are inferred from their names; the
/// code that interprets the policy is not in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ToolHistoryPolicy {
// Name of an argument, if any; presumably the key used to de-duplicate
// repeated calls — confirm against the consumer.
pub dedup_arg: Option<&'static str>,
// Name of an argument, if any; presumably the one shown when a call is
// summarised — confirm against the consumer.
pub summary_arg: Option<&'static str>,
// Defaults to false; presumably marks results as eligible for compaction.
pub compactable_result: bool,
// Defaults to false; presumably keeps the active plan pinned in history.
pub pins_active_plan: bool,
}
impl ToolHistoryPolicy {
    /// Policy with nothing enabled: no argument names and both flags `false`.
    pub const fn new() -> Self {
        Self {
            dedup_arg: None,
            summary_arg: None,
            compactable_result: false,
            pins_active_plan: false,
        }
    }

    /// Returns a copy of this policy with `dedup_arg` set to `arg`.
    pub const fn dedup_arg(self, arg: &'static str) -> Self {
        Self {
            dedup_arg: Some(arg),
            ..self
        }
    }

    /// Returns a copy of this policy with `summary_arg` set to `arg`.
    pub const fn summary_arg(self, arg: &'static str) -> Self {
        Self {
            summary_arg: Some(arg),
            ..self
        }
    }

    /// Returns a copy of this policy with `compactable_result` switched on.
    pub const fn compactable_result(self) -> Self {
        Self {
            compactable_result: true,
            ..self
        }
    }

    /// Returns a copy of this policy with `pins_active_plan` switched on.
    pub const fn pins_active_plan(self) -> Self {
        Self {
            pins_active_plan: true,
            ..self
        }
    }
}
/// `Default` is the same as `ToolHistoryPolicy::new()`: no argument names and
/// both flags `false`.
impl Default for ToolHistoryPolicy {
fn default() -> Self {
Self::new()
}
}
/// A tool that takes untyped JSON arguments.
///
/// Implementors must provide `name`, `description`, `parameters_schema` and
/// `execute`; every other method has a default. Types implementing
/// `TypedAgentTool` get this trait through a blanket impl.
///
/// NOTE(review): the policy hooks below are read by code outside this file.
/// Their documented defaults are exact; their described purposes are inferred
/// from the method names and should be confirmed against the callers.
#[async_trait]
pub trait AgentTool: Send + Sync + 'static {
/// Tool name; `ToolRegistry` uses it as the lookup key.
fn name(&self) -> &str;
/// Description text for the tool.
fn description(&self) -> &str;
/// JSON value describing the accepted arguments (the typed blanket impl
/// returns a JSON Schema document here).
fn parameters_schema(&self) -> Value;
/// Defaults to `false`. Presumably: the tool needs sole use of a sandbox.
fn requires_exclusive_sandbox(&self) -> bool {
false
}
/// Defaults to `None`. Presumably: an upper bound on result size in chars.
fn max_result_chars(&self) -> Option<usize> {
None
}
/// History policy for this tool; defaults to `ToolHistoryPolicy::default()`.
fn history_policy(&self) -> ToolHistoryPolicy {
ToolHistoryPolicy::default()
}
/// Identity policy for this tool; defaults to `ToolIdentityPolicy::default()`.
fn identity_policy(&self) -> crate::tool_identity::ToolIdentityPolicy {
crate::tool_identity::ToolIdentityPolicy::default()
}
/// Defaults to `false`. Presumably: an error here cancels sibling calls.
fn aborts_siblings_on_error(&self) -> bool {
false
}
/// Defaults to `true`. Presumably: calls count against a tool-call budget.
fn counts_toward_tool_call_limit(&self) -> bool {
true
}
/// Defaults to `false`. Presumably: multiple calls in one turn may run in
/// parallel.
fn parallel_safe_per_turn(&self) -> bool {
false
}
/// Defaults to `true`. Presumably: the result's `terminate` flag takes part
/// in deciding whether to stop.
fn counts_toward_termination_vote(&self) -> bool {
true
}
/// Hook to rewrite raw arguments before use; the default returns them
/// unchanged. The typed blanket impl calls this at the start of `execute`.
fn prepare_arguments(&self, args: Value) -> Value {
args
}
/// Hook to validate arguments; the default accepts everything.
fn validate(&self, _args: &Value) -> Result<(), ToolValidationError> {
Ok(())
}
/// Runs the tool.
///
/// `call_id` identifies the invocation, `args` carries the JSON arguments,
/// `signal` is a cancellation token and `update` is a sender of
/// `ToolResult` values. Returns the final `ToolResult` or a `ToolError`.
async fn execute(
&self,
call_id: &str,
args: Value,
signal: CancellationToken,
update: ToolUpdateSink,
) -> Result<ToolResult, ToolError>;
}
/// A tool whose arguments are a concrete Rust type instead of raw JSON.
///
/// Implement this trait and the blanket `impl<T: TypedAgentTool> AgentTool for T`
/// supplies `AgentTool`: it derives the parameter schema from `Args` and
/// deserializes incoming JSON into `Args` before calling `run`. The policy
/// hooks mirror those on `AgentTool` with the same defaults, and the blanket
/// impl forwards each one. Unlike `AgentTool`, this trait has no `validate`
/// hook.
#[async_trait]
pub trait TypedAgentTool: Send + Sync + 'static {
/// Argument type: must be deserializable and able to describe itself as a
/// JSON Schema.
type Args: serde::de::DeserializeOwned + schemars::JsonSchema + Send + 'static;
/// Tool name.
fn name(&self) -> &str;
/// Description text for the tool.
fn description(&self) -> &str;
/// Defaults to `false`.
fn requires_exclusive_sandbox(&self) -> bool {
false
}
/// Defaults to `None`.
fn max_result_chars(&self) -> Option<usize> {
None
}
/// Defaults to `ToolHistoryPolicy::default()`.
fn history_policy(&self) -> ToolHistoryPolicy {
ToolHistoryPolicy::default()
}
/// Defaults to `ToolIdentityPolicy::default()`.
fn identity_policy(&self) -> crate::tool_identity::ToolIdentityPolicy {
crate::tool_identity::ToolIdentityPolicy::default()
}
/// Defaults to `false`.
fn aborts_siblings_on_error(&self) -> bool {
false
}
/// Defaults to `true`.
fn counts_toward_tool_call_limit(&self) -> bool {
true
}
/// Defaults to `false`.
fn parallel_safe_per_turn(&self) -> bool {
false
}
/// Defaults to `true`.
fn counts_toward_termination_vote(&self) -> bool {
true
}
/// Hook to rewrite the raw JSON arguments before they are deserialized into
/// `Args`; the default returns them unchanged.
fn prepare_arguments(&self, args: Value) -> Value {
args
}
/// Runs the tool with already-deserialized arguments.
///
/// Parameters other than `args` are passed through unchanged from
/// `AgentTool::execute`.
async fn run(
&self,
call_id: &str,
args: Self::Args,
signal: CancellationToken,
update: ToolUpdateSink,
) -> Result<ToolResult, ToolError>;
}
/// Blanket adapter: every `TypedAgentTool` is an `AgentTool`.
///
/// All hooks forward to the `TypedAgentTool` method of the same name. The two
/// methods with real logic are `parameters_schema` (schema generation) and
/// `execute` (argument clean-up, deserialization, then `run`). `validate` is
/// not overridden, so the `AgentTool` default applies.
#[async_trait]
impl<T: TypedAgentTool> AgentTool for T {
fn name(&self) -> &str {
TypedAgentTool::name(self)
}
fn description(&self) -> &str {
TypedAgentTool::description(self)
}
/// Generates the JSON Schema for `T::Args`.
///
/// Uses schemars draft-07 settings with subschemas inlined, then flattens a
/// top-level tagged `oneOf` (`flatten_tagged_oneof_schema`) and rewrites
/// `items: true` to `items: {}` (`normalize_strict_validator_quirks`).
///
/// # Panics
/// Panics if the generated schema cannot be converted to a `serde_json::Value`.
fn parameters_schema(&self) -> Value {
let settings = schemars::gen::SchemaSettings::draft07().with(|s| {
s.inline_subschemas = true;
});
let generator = settings.into_generator();
let schema = generator.into_root_schema_for::<T::Args>();
let value = serde_json::to_value(schema).expect("typed-tool schema serializes");
let mut value = flatten_tagged_oneof_schema(value);
normalize_strict_validator_quirks(&mut value);
value
}
fn requires_exclusive_sandbox(&self) -> bool {
TypedAgentTool::requires_exclusive_sandbox(self)
}
fn max_result_chars(&self) -> Option<usize> {
TypedAgentTool::max_result_chars(self)
}
fn history_policy(&self) -> ToolHistoryPolicy {
TypedAgentTool::history_policy(self)
}
fn identity_policy(&self) -> crate::tool_identity::ToolIdentityPolicy {
TypedAgentTool::identity_policy(self)
}
fn aborts_siblings_on_error(&self) -> bool {
TypedAgentTool::aborts_siblings_on_error(self)
}
fn counts_toward_tool_call_limit(&self) -> bool {
TypedAgentTool::counts_toward_tool_call_limit(self)
}
fn parallel_safe_per_turn(&self) -> bool {
TypedAgentTool::parallel_safe_per_turn(self)
}
fn counts_toward_termination_vote(&self) -> bool {
TypedAgentTool::counts_toward_termination_vote(self)
}
fn prepare_arguments(&self, args: Value) -> Value {
TypedAgentTool::prepare_arguments(self, args)
}
/// Cleans up the raw JSON arguments, deserializes them into `T::Args`, and
/// calls `TypedAgentTool::run`.
///
/// Invalid arguments are reported as `Ok(ToolResult::error(..))`, not as
/// `Err`, with the message `"<tool name>: invalid arguments: <detail>"`.
async fn execute(
&self,
call_id: &str,
args: Value,
signal: CancellationToken,
update: ToolUpdateSink,
) -> Result<ToolResult, ToolError> {
// Pipeline order matters: tool-specific rewrite first, then drop
// top-level nulls, then turn string-encoded scalars into the type the
// schema declares.
let prepared = AgentTool::prepare_arguments(self, args);
let stripped = strip_top_level_nulls(prepared);
// The schema is regenerated on every call; it is only used to look up
// the declared type of each top-level property for coercion.
let schema = AgentTool::parameters_schema(self);
let coerced = coerce_string_scalars_at_top_level(stripped, &schema);
let parsed: T::Args = match serde_json::from_value(coerced) {
Ok(v) => v,
Err(e) => {
// Deserialization failure becomes an error *result* whose text may
// carry a corrective hint (see `enrich_arg_parse_error_message`).
return Ok(ToolResult::error(format!(
"{}: invalid arguments: {}",
TypedAgentTool::name(self),
enrich_arg_parse_error_message(&e),
)));
}
};
TypedAgentTool::run(self, call_id, parsed, signal, update).await
}
}
/// Converts string-encoded scalars among the top-level fields of `value` to
/// the scalar type that `schema` declares for them.
///
/// Only fields listed under `schema["properties"]` are considered; nested
/// values are never visited. A non-object `value`, or a schema without an
/// object-valued `properties` entry, is returned unchanged.
fn coerce_string_scalars_at_top_level(value: Value, schema: &Value) -> Value {
    match value {
        Value::Object(mut fields) => {
            if let Some(declared) = schema.get("properties").and_then(Value::as_object) {
                for (name, slot) in fields.iter_mut() {
                    if let Some(field_schema) = declared.get(name) {
                        coerce_one_scalar_in_place(slot, field_schema);
                    }
                }
            }
            Value::Object(fields)
        }
        other => other,
    }
}
/// Replaces `value` with a typed scalar when it is a string that parses as
/// the scalar type named by `prop_schema`.
///
/// * integer: the trimmed text must parse as `i64`, or failing that as `u64`
/// * number: the trimmed text must parse as `f64` and be representable as a
///   JSON number (`Number::from_f64` rejects NaN and infinities)
/// * boolean: the trimmed text must be one of `true`/`True`/`TRUE` or
///   `false`/`False`/`FALSE`
///
/// In every other case (not a string, no scalar target in the schema, or text
/// that does not parse) `value` is left untouched.
fn coerce_one_scalar_in_place(value: &mut Value, prop_schema: &Value) {
    // Work out the replacement while `value` is only borrowed immutably, then
    // assign once at the end.
    let replacement = {
        let Some(raw) = value.as_str() else {
            return;
        };
        let Some(target) = scalar_target_from_schema(prop_schema) else {
            return;
        };
        let literal = raw.trim();
        match target {
            ScalarTarget::Integer => literal
                .parse::<i64>()
                .map(serde_json::Number::from)
                .or_else(|_| literal.parse::<u64>().map(serde_json::Number::from))
                .ok()
                .map(Value::Number),
            ScalarTarget::Number => literal
                .parse::<f64>()
                .ok()
                .and_then(serde_json::Number::from_f64)
                .map(Value::Number),
            ScalarTarget::Boolean => match literal {
                "true" | "True" | "TRUE" => Some(Value::Bool(true)),
                "false" | "False" | "FALSE" => Some(Value::Bool(false)),
                _ => None,
            },
        }
    };
    if let Some(typed) = replacement {
        *value = typed;
    }
}
/// Scalar JSON Schema types that string-encoded arguments can be coerced to.
/// Produced by `scalar_target_from_schema` from the schema `type` names
/// `"integer"`, `"number"` and `"boolean"` respectively.
#[derive(Debug, Clone, Copy)]
enum ScalarTarget {
Integer,
Number,
Boolean,
}
/// Determines which scalar type, if any, a property schema asks for.
///
/// `type` may be a single string or an array of strings. For an array,
/// `"null"` entries and non-string entries are ignored, and exactly one type
/// name must remain — so `["integer", "null"]` counts as integer while
/// `["integer", "string"]` is ambiguous and yields `None`. Any type other
/// than `integer`, `number` or `boolean` also yields `None`.
fn scalar_target_from_schema(prop_schema: &Value) -> Option<ScalarTarget> {
    let type_name = match prop_schema.get("type")? {
        Value::String(name) => name.as_str(),
        Value::Array(alternatives) => {
            let mut concrete = alternatives
                .iter()
                .filter_map(Value::as_str)
                .filter(|name| *name != "null");
            // Accept only when there is a first candidate and no second one.
            match (concrete.next(), concrete.next()) {
                (Some(only), None) => only,
                _ => return None,
            }
        }
        _ => return None,
    };
    match type_name {
        "integer" => Some(ScalarTarget::Integer),
        "number" => Some(ScalarTarget::Number),
        "boolean" => Some(ScalarTarget::Boolean),
        _ => None,
    }
}
/// Renders a deserialization error as text and, when [`arg_parse_hint`]
/// recognises the failure, appends `". "` followed by the corrective hint.
/// Unrecognised errors are returned exactly as `err` displays them.
fn enrich_arg_parse_error_message(err: &serde_json::Error) -> String {
    let mut message = err.to_string();
    if let Some(hint) = arg_parse_hint(&message) {
        message.push_str(". ");
        message.push_str(&hint);
    }
    message
}
/// Derives a corrective hint from the text of a deserialization error.
///
/// Handles errors of the form `invalid type: string "<value>", expected <T>`:
/// a string was supplied where another type was required. The expected type
/// is classified by the `expects_*` helpers, checked in this order:
///
/// 1. integer — hint only if the quoted value parses as an integer
/// 2. number — hint only if the quoted value parses as a *finite* float
/// 3. boolean — hint only for `true`/`True`/`TRUE`/`false`/`False`/`FALSE`
/// 4. sequence — always hints that a JSON array is required
///
/// Returns `None` when the message contains no quoted string value, when the
/// expected type is not recognised, or when the quoted value could not
/// plausibly be the intended literal. The first matching category decides:
/// a failed parse in that category ends the search.
fn arg_parse_hint(raw: &str) -> Option<String> {
    let value = extract_invalid_string_value(raw)?.trim();
    if expects_integer(raw) {
        let parsed: i128 = value.parse().ok()?;
        return Some(format!(
            "Did you mean the integer {parsed}? Resend without quotes."
        ));
    }
    if expects_number(raw) {
        let parsed: f64 = value.parse().ok()?;
        // `f64::from_str` accepts "NaN", "inf" and "infinity". None of those
        // is a valid JSON number, so advising the caller to resend them
        // unquoted would only trigger another failure.
        if !parsed.is_finite() {
            return None;
        }
        return Some(format!(
            "Did you mean the number {parsed}? Resend without quotes."
        ));
    }
    if expects_boolean(raw) {
        return match value {
            "true" | "True" | "TRUE" => Some(
                "Did you mean true? Resend as a boolean literal (lowercase, no quotes)."
                    .to_string(),
            ),
            "false" | "False" | "FALSE" => Some(
                "Did you mean false? Resend as a boolean literal (lowercase, no quotes)."
                    .to_string(),
            ),
            _ => None,
        };
    }
    if expects_sequence(raw) {
        return Some(
            "Expected a JSON array (e.g. `[{...}, {...}]`); the field cannot be a string. \
Resend the value as an array of structured items, not a string of XML-like markup."
                .to_string(),
        );
    }
    None
}
/// Extracts the quoted value from an error message containing
/// `string "<value>"`.
///
/// The value is returned exactly as it appears in the message, i.e. still in
/// its escaped form. The closing quote is the first `"` that is not preceded
/// by an escaping backslash, so values with embedded quotes (rendered as
/// `\"` by serde's `{:?}` formatting) are returned whole instead of being cut
/// off at the first embedded quote.
///
/// Returns `None` when the message has no `string "` marker or the quoted
/// value is never closed.
fn extract_invalid_string_value(raw: &str) -> Option<&str> {
    const MARKER: &str = "string \"";
    let start = raw.find(MARKER)? + MARKER.len();
    let rest = &raw[start..];
    let mut escaped = false;
    // Byte-wise scan is safe for slicing: `\` and `"` are ASCII, so any index
    // at which we stop is a UTF-8 character boundary.
    for (idx, byte) in rest.bytes().enumerate() {
        if escaped {
            escaped = false;
            continue;
        }
        match byte {
            b'\\' => escaped = true,
            b'"' => return Some(&rest[..idx]),
            _ => {}
        }
    }
    None
}
/// True when the error text says an integer type was expected, i.e. it
/// contains one of the phrases serde uses for Rust's integer primitives
/// (up to 64 bits) or the generic `expected integer`.
fn expects_integer(raw: &str) -> bool {
    const INTEGER_PHRASES: [&str; 11] = [
        "expected usize",
        "expected isize",
        "expected u8",
        "expected u16",
        "expected u32",
        "expected u64",
        "expected i8",
        "expected i16",
        "expected i32",
        "expected i64",
        "expected integer",
    ];
    INTEGER_PHRASES.iter().any(|phrase| raw.contains(*phrase))
}
/// True when the error text says a floating-point value was expected
/// (`expected f32`, `expected f64` or `expected floating point`).
fn expects_number(raw: &str) -> bool {
    const FLOAT_PHRASES: [&str; 3] = ["expected f32", "expected f64", "expected floating point"];
    FLOAT_PHRASES.iter().any(|phrase| raw.contains(*phrase))
}
/// True when the error text says a boolean was expected
/// (`expected a boolean` or `expected bool`).
fn expects_boolean(raw: &str) -> bool {
    const BOOLEAN_PHRASES: [&str; 2] = ["expected a boolean", "expected bool"];
    BOOLEAN_PHRASES.iter().any(|phrase| raw.contains(*phrase))
}
/// True when the error text says a list was expected
/// (`expected a sequence` or `expected an array`).
fn expects_sequence(raw: &str) -> bool {
    const SEQUENCE_PHRASES: [&str; 2] = ["expected a sequence", "expected an array"];
    SEQUENCE_PHRASES.iter().any(|phrase| raw.contains(*phrase))
}
fn strip_top_level_nulls(value: Value) -> Value {
match value {
Value::Object(map) => {
Value::Object(map.into_iter().filter(|(_, v)| !v.is_null()).collect())
}
other => other,
}
}
/// Collapses a root schema of the form `{ "oneOf": [variant, ...] }` into a
/// single flat `object` schema when the variants form a tagged union.
///
/// A variant qualifies when it is an object schema with `properties`, one of
/// which has an `enum` with exactly one value (the tag). All variants must
/// use the same tag property name. The flattened result has:
///
/// * the tag property first, typed `"string"`, whose `enum` lists every
///   variant's tag value in variant order;
/// * the union of all other properties, the first variant to declare a name
///   supplying its schema;
/// * for properties not present in every variant (and only when all tag
///   values are strings and there is more than one variant), a description
///   suffix ` (applies when <tag> in: [..])`;
/// * `required` containing only the tag, and only when every variant lists
///   the tag as required;
/// * the root's `description` and `$schema` if present. Other root keys are
///   not carried over.
///
/// If any variant does not qualify, the original schema is returned with its
/// `oneOf` re-inserted. Non-object input is returned unchanged. An object
/// without an array-valued `oneOf` is returned as-is, except that an empty
/// object becomes `Value::Null`.
///
/// NOTE(review): a `oneOf` key whose value is not an array is removed and not
/// restored on the pass-through path — confirm whether that is intended.
fn flatten_tagged_oneof_schema(schema: Value) -> Value {
let Value::Object(mut root) = schema else {
return schema;
};
let Some(Value::Array(variants)) = root.remove("oneOf") else {
if !root.is_empty() {
return Value::Object(root);
}
return Value::Null;
};
// Per-variant record used later to work out which fields belong to which
// tag values.
struct VariantSpec {
tag_value_str: Option<String>,
own_field_names: Vec<String>,
}
let mut discriminator: Option<String> = None;
let mut variant_specs: Vec<VariantSpec> = Vec::with_capacity(variants.len());
let mut merged_props = serde_json::Map::new();
let mut required_set: std::collections::BTreeSet<String> = std::collections::BTreeSet::new();
// Stays true only if every variant lists the tag under `required`.
let mut tag_in_required = true;
let mut tag_values: Vec<Value> = Vec::with_capacity(variants.len());
// Pass 1: validate each variant, find its tag, and merge its properties.
// Every early return hands the untouched variants back via
// `reassemble_oneof`, so a partial merge is never observable.
for variant in &variants {
let Some(obj) = variant.as_object() else {
return reassemble_oneof(root, variants);
};
let Some(Value::Object(props)) = obj.get("properties").cloned() else {
return reassemble_oneof(root, variants);
};
// The tag is the first property, in map iteration order, whose schema has
// a single-element `enum`.
let mut variant_tag: Option<(String, Value)> = None;
for (name, prop) in props.iter() {
let Some(prop_obj) = prop.as_object() else {
continue;
};
let Some(Value::Array(enum_values)) = prop_obj.get("enum").cloned() else {
continue;
};
if enum_values.len() == 1 {
variant_tag = Some((name.clone(), enum_values.into_iter().next().unwrap()));
break;
}
}
let Some((tag_name, tag_value)) = variant_tag else {
return reassemble_oneof(root, variants);
};
// All variants must agree on the tag property name.
match &discriminator {
None => discriminator = Some(tag_name.clone()),
Some(existing) if existing == &tag_name => {}
Some(_) => return reassemble_oneof(root, variants),
}
tag_values.push(tag_value.clone());
// Merge non-tag properties; `or_insert_with` means the first variant to
// declare a property name decides its schema.
let mut own_field_names = Vec::new();
for (name, prop_schema) in props.iter() {
if name == &tag_name {
continue;
}
merged_props
.entry(name.clone())
.or_insert_with(|| prop_schema.clone());
own_field_names.push(name.clone());
}
let mut tag_required_here = false;
if let Some(Value::Array(req)) = obj.get("required") {
for r in req {
if let Some(s) = r.as_str() {
if s == tag_name {
tag_required_here = true;
}
}
}
}
if !tag_required_here {
tag_in_required = false;
}
variant_specs.push(VariantSpec {
tag_value_str: tag_value.as_str().map(str::to_string),
own_field_names,
});
}
// No discriminator means the `oneOf` array was empty.
let Some(discriminator) = discriminator else {
return reassemble_oneof(root, variants);
};
// Pass 2: annotate properties that exist in only some variants with the
// tag values they apply to. Skipped unless every tag value is a string and
// there are at least two variants.
let total_variants = variant_specs.len();
let all_tags_are_strings = variant_specs.iter().all(|s| s.tag_value_str.is_some());
if all_tags_are_strings && total_variants > 1 {
// field name -> tag values of the variants that declare it
let mut owners: std::collections::BTreeMap<String, Vec<String>> =
std::collections::BTreeMap::new();
for spec in &variant_specs {
let tag_label = spec.tag_value_str.clone().unwrap_or_default();
for field in &spec.own_field_names {
owners
.entry(field.clone())
.or_default()
.push(tag_label.clone());
}
}
for (field, mut variant_tags) in owners {
// Declared by every variant: shared field, no annotation.
if variant_tags.len() == total_variants {
continue;
}
variant_tags.sort();
variant_tags.dedup();
let suffix = format!(
" (applies when {discriminator} in: [{}])",
variant_tags.join(", ")
);
if let Some(Value::Object(prop_map)) = merged_props.get_mut(&field) {
// Append to an existing non-empty description; otherwise the suffix
// (without its leading space) becomes the description.
let new_desc = match prop_map.get("description") {
Some(Value::String(existing)) if !existing.is_empty() => {
format!("{existing}{suffix}")
}
_ => suffix.trim_start().to_string(),
};
prop_map.insert("description".to_string(), Value::String(new_desc));
}
}
}
// Build the tag property. Its `type` is always emitted as "string",
// whatever the JSON type of the collected tag values.
let mut tag_prop = serde_json::Map::new();
tag_prop.insert("type".to_string(), Value::String("string".to_string()));
tag_prop.insert("enum".to_string(), Value::Array(tag_values));
// The tag is inserted first so it precedes the payload properties.
let mut ordered_props = serde_json::Map::new();
ordered_props.insert(discriminator.clone(), Value::Object(tag_prop));
for (name, schema) in merged_props {
ordered_props.insert(name, schema);
}
// `required_set` only ever receives the discriminator; per-variant
// required lists are not merged.
if tag_in_required {
required_set.insert(discriminator);
}
// Assemble the output from a fixed set of keys; anything else left in
// `root` is dropped here.
let mut out = serde_json::Map::new();
if let Some(desc) = root.remove("description") {
out.insert("description".to_string(), desc);
}
if let Some(schema) = root.remove("$schema") {
out.insert("$schema".to_string(), schema);
}
out.insert("type".to_string(), Value::String("object".to_string()));
out.insert("properties".to_string(), Value::Object(ordered_props));
if !required_set.is_empty() {
out.insert(
"required".to_string(),
Value::Array(required_set.into_iter().map(Value::String).collect()),
);
}
Value::Object(out)
}
/// Puts a previously removed `oneOf` array back into `root` and returns the
/// result as a JSON object. Used by [`flatten_tagged_oneof_schema`] to bail
/// out without altering the schema.
fn reassemble_oneof(mut root: serde_json::Map<String, Value>, variants: Vec<Value>) -> Value {
    let restored_variants = Value::Array(variants);
    root.insert(String::from("oneOf"), restored_variants);
    Value::Object(root)
}
/// Recursively rewrites every `"items": true` in a schema to `"items": {}`.
///
/// Walks all object values and array elements; scalars are left alone. The
/// replacement is made before descending, so the new empty object is visited
/// too (a no-op).
fn normalize_strict_validator_quirks(value: &mut Value) {
    match value {
        Value::Array(elements) => {
            for element in elements.iter_mut() {
                normalize_strict_validator_quirks(element);
            }
        }
        Value::Object(fields) => {
            for (key, child) in fields.iter_mut() {
                if key.as_str() == "items" && matches!(child, Value::Bool(true)) {
                    *child = Value::Object(serde_json::Map::new());
                }
                normalize_strict_validator_quirks(child);
            }
        }
        _ => {}
    }
}
/// Collection of tools keyed by name that also remembers the order in which
/// names were first registered.
///
/// Cloning copies the map and the order list; the tools themselves are shared
/// through `Arc`.
#[derive(Default, Clone)]
pub struct ToolRegistry {
// Name -> tool lookup.
tools: HashMap<String, Arc<dyn AgentTool>>,
// Tool names in first-registration order; drives `names()` and `iter()`.
order: Vec<String>,
}
impl ToolRegistry {
pub fn new() -> Self {
Self::default()
}
pub fn with(mut self, tool: Arc<dyn AgentTool>) -> Self {
self.register(tool);
self
}
pub fn register(&mut self, tool: Arc<dyn AgentTool>) {
let name = tool.name().to_string();
if !self.tools.contains_key(&name) {
self.order.push(name.clone());
}
self.tools.insert(name, tool);
}
pub fn get(&self, name: &str) -> Option<Arc<dyn AgentTool>> {
self.tools.get(name).cloned()
}
pub fn history_policy(&self, name: &str) -> ToolHistoryPolicy {
self.tools
.get(name)
.map(|tool| tool.history_policy())
.unwrap_or_default()
}
pub fn identity_policy(&self, name: &str) -> crate::tool_identity::ToolIdentityPolicy {
self.tools
.get(name)
.map(|tool| tool.identity_policy())
.unwrap_or_default()
}
pub fn identity_policies(
&self,
) -> std::collections::HashMap<String, crate::tool_identity::ToolIdentityPolicy> {
self.tools
.iter()
.map(|(name, tool)| (name.clone(), tool.identity_policy()))
.collect()
}
pub fn names(&self) -> Vec<&str> {
self.order.iter().map(String::as_str).collect()
}
pub fn iter(&self) -> impl Iterator<Item = &Arc<dyn AgentTool>> {
self.order.iter().filter_map(|name| self.tools.get(name))
}
pub fn is_empty(&self) -> bool {
self.tools.is_empty()
}
pub fn len(&self) -> usize {
self.tools.len()
}
}
/// Debug output shows only the tool names, in registration order, under the
/// field label `tools`; the tool objects themselves are not formatted.
impl std::fmt::Debug for ToolRegistry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut rendering = f.debug_struct("ToolRegistry");
        rendering.field("tools", &self.order);
        rendering.finish()
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::TextContent;
use schemars::JsonSchema;
use serde::Deserialize;
#[derive(Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
#[allow(dead_code)]
struct DocVariantArgs {
filename: String,
#[serde(default)]
title: Option<String>,
}
#[derive(Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
#[allow(dead_code)]
struct ExcelVariantArgs {
filename: String,
#[serde(default)]
rows: Vec<Vec<serde_json::Value>>,
}
#[derive(Deserialize, JsonSchema)]
#[serde(tag = "kind", rename_all = "snake_case", deny_unknown_fields)]
#[allow(dead_code)]
enum ExampleArgs {
Document(DocVariantArgs),
Excel(ExcelVariantArgs),
}
fn build_example_schema() -> Value {
let settings = schemars::gen::SchemaSettings::draft07().with(|s| {
s.inline_subschemas = true;
});
let g = settings.into_generator();
let s = g.into_root_schema_for::<ExampleArgs>();
let raw = serde_json::to_value(s).unwrap();
flatten_tagged_oneof_schema(raw)
}
#[derive(Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
#[allow(dead_code)]
struct NonAlphabeticOrderCanaryArgs {
zeta_selector: String,
alpha_payload: String,
middle_payload: String,
}
#[test]
fn schema_runtime_preserves_insertion_order_for_tool_objects() {
let mut object = serde_json::Map::new();
object.insert("zeta_selector".to_string(), Value::String("z".to_string()));
object.insert("alpha_payload".to_string(), Value::String("a".to_string()));
object.insert("middle_payload".to_string(), Value::String("m".to_string()));
let keys = object.keys().map(String::as_str).collect::<Vec<_>>();
assert_eq!(
keys,
["zeta_selector", "alpha_payload", "middle_payload"],
"serde_json::Map must keep insertion order; losing this breaks \
model-facing tool-schema property order"
);
let serialized = serde_json::to_string(&Value::Object(object)).unwrap();
assert_eq!(
serialized, r#"{"zeta_selector":"z","alpha_payload":"a","middle_payload":"m"}"#,
"schema JSON serialization must preserve object insertion order"
);
}
#[test]
fn schemars_preserves_declared_struct_order_for_tool_args() {
let settings = schemars::gen::SchemaSettings::draft07().with(|s| {
s.inline_subschemas = true;
});
let schema = serde_json::to_value(
settings
.into_generator()
.into_root_schema_for::<NonAlphabeticOrderCanaryArgs>(),
)
.expect("schema serializes");
let props = schema
.get("properties")
.and_then(Value::as_object)
.expect("schema must expose properties");
let order = props.keys().map(String::as_str).collect::<Vec<_>>();
assert_eq!(
order,
["zeta_selector", "alpha_payload", "middle_payload"],
"schemars must emit Args fields in declaration order for \
autoregressive tool-call conditioning"
);
}
#[test]
fn flatten_tagged_oneof_produces_flat_object_schema() {
let s = build_example_schema();
assert_eq!(s.get("type").and_then(Value::as_str), Some("object"));
assert!(s.get("oneOf").is_none());
let kind_prop = s.pointer("/properties/kind").expect("kind property");
assert_eq!(
kind_prop.get("type").and_then(Value::as_str),
Some("string")
);
let kind_enum = kind_prop
.get("enum")
.and_then(Value::as_array)
.expect("enum");
let mut kinds: Vec<&str> = kind_enum.iter().filter_map(Value::as_str).collect();
kinds.sort();
assert_eq!(kinds, vec!["document", "excel"]);
let props = s
.get("properties")
.and_then(Value::as_object)
.expect("properties");
let order: Vec<&str> = props.keys().map(String::as_str).collect();
assert_eq!(
order.first().copied(),
Some("kind"),
"discriminator must be emitted before payload fields so \
variant-specific keys are conditioned on the selected kind"
);
assert!(s.pointer("/properties/filename").is_some());
assert!(s.pointer("/properties/title").is_some());
assert!(s.pointer("/properties/rows").is_some());
let req = s
.get("required")
.and_then(Value::as_array)
.expect("required");
assert!(req.iter().any(|v| v.as_str() == Some("kind")));
}
#[test]
fn flatten_tagged_oneof_annotates_variant_specific_property_descriptions() {
let s = build_example_schema();
let filename_desc = s
.pointer("/properties/filename/description")
.and_then(Value::as_str)
.unwrap_or_default();
assert!(
!filename_desc.contains("applies when kind in"),
"shared property `filename` must NOT carry a narrowing \
suffix; got: {filename_desc:?}"
);
let title_desc = s
.pointer("/properties/title/description")
.and_then(Value::as_str)
.expect("title description present");
assert!(
title_desc.contains("applies when kind in: [document]"),
"Document-only `title` must declare its variant scope; \
got: {title_desc:?}"
);
let rows_desc = s
.pointer("/properties/rows/description")
.and_then(Value::as_str)
.expect("rows description present");
assert!(
rows_desc.contains("applies when kind in: [excel]"),
"Excel-only `rows` must declare its variant scope; \
got: {rows_desc:?}"
);
assert!(
s.get("allOf").is_none(),
"top-level allOf would be rejected by Azure's tool validator"
);
assert!(s.get("oneOf").is_none());
assert!(s.get("anyOf").is_none());
}
#[test]
fn normalize_strict_quirks_rewrites_items_true_to_empty_object() {
let mut schema = serde_json::json!({
"type": "object",
"properties": {
"rows": {
"type": "array",
"items": {
"type": "array",
"items": true
}
}
}
});
normalize_strict_validator_quirks(&mut schema);
assert_eq!(
schema.pointer("/properties/rows/items/items"),
Some(&serde_json::json!({})),
);
}
#[test]
fn strip_top_level_nulls_removes_inapplicable_variant_fields() {
let model_payload = serde_json::json!({
"action": "run",
"command": "echo hi",
"workdir": "/home/user/workspace",
"code": null,
"interpreter": null,
"ext": null,
"exec_dir": null,
"max_token": null,
"truncate_from": null,
"run_id": null,
"after_seq": null,
"max_events": null,
"timeout_s": null,
"timeout_ms": null,
"terminal": null,
"force": null,
"timeout_secs": 60,
});
let stripped = strip_top_level_nulls(model_payload);
let obj = stripped.as_object().expect("object");
assert!(!obj.contains_key("code"));
assert!(!obj.contains_key("ext"));
assert!(!obj.contains_key("max_token"));
assert!(!obj.contains_key("force"));
assert_eq!(obj.get("action").and_then(Value::as_str), Some("run"));
assert_eq!(obj.get("command").and_then(Value::as_str), Some("echo hi"));
assert_eq!(obj.get("timeout_secs").and_then(Value::as_i64), Some(60));
}
#[test]
fn strip_top_level_nulls_passes_through_non_object_values() {
assert_eq!(
strip_top_level_nulls(serde_json::json!("text")),
serde_json::json!("text")
);
assert_eq!(strip_top_level_nulls(Value::Null), Value::Null);
}
fn make_schema(properties: Value) -> Value {
serde_json::json!({
"type": "object",
"properties": properties,
})
}
#[test]
fn coerce_string_to_integer_when_schema_says_integer() {
let schema = make_schema(serde_json::json!({
"max_iterations": {"type": "integer"},
}));
let coerced = coerce_string_scalars_at_top_level(
serde_json::json!({"max_iterations": "50"}),
&schema,
);
assert_eq!(coerced, serde_json::json!({"max_iterations": 50}));
}
#[test]
fn coerce_string_to_integer_handles_negative_and_whitespace() {
let schema = make_schema(serde_json::json!({
"offset": {"type": "integer"},
"limit": {"type": "integer"},
}));
let coerced = coerce_string_scalars_at_top_level(
serde_json::json!({"offset": "-7", "limit": " 42 "}),
&schema,
);
assert_eq!(coerced, serde_json::json!({"offset": -7, "limit": 42}));
}
#[test]
fn coerce_string_to_boolean_for_each_case_variant() {
let schema = make_schema(serde_json::json!({
"full_page": {"type": "boolean"},
"headless": {"type": "boolean"},
"verbose": {"type": "boolean"},
"untouched": {"type": "boolean"},
}));
let coerced = coerce_string_scalars_at_top_level(
serde_json::json!({
"full_page": "true",
"headless": "True",
"verbose": "FALSE",
"untouched": "maybe",
}),
&schema,
);
assert_eq!(coerced["full_page"], serde_json::json!(true));
assert_eq!(coerced["headless"], serde_json::json!(true));
assert_eq!(coerced["verbose"], serde_json::json!(false));
assert_eq!(coerced["untouched"], serde_json::json!("maybe"));
}
#[test]
fn coerce_string_to_number_for_float_schema() {
let schema = make_schema(serde_json::json!({
"temperature": {"type": "number"},
}));
let coerced =
coerce_string_scalars_at_top_level(serde_json::json!({"temperature": "0.7"}), &schema);
let n = coerced["temperature"].as_f64().expect("number");
assert!((n - 0.7).abs() < 1e-9);
}
#[test]
fn coerce_leaves_string_fields_alone() {
let schema = make_schema(serde_json::json!({
"query": {"type": "string"},
"count": {"type": "integer"},
}));
let coerced = coerce_string_scalars_at_top_level(
serde_json::json!({"query": "50", "count": "50"}),
&schema,
);
assert_eq!(coerced["query"], serde_json::json!("50"));
assert_eq!(coerced["count"], serde_json::json!(50));
}
#[test]
fn coerce_leaves_unparseable_strings_alone() {
let schema = make_schema(serde_json::json!({
"max_iterations": {"type": "integer"},
}));
let coerced = coerce_string_scalars_at_top_level(
serde_json::json!({"max_iterations": "fifty"}),
&schema,
);
assert_eq!(coerced, serde_json::json!({"max_iterations": "fifty"}));
}
#[test]
fn coerce_treats_nullable_integer_as_integer() {
let schema = make_schema(serde_json::json!({
"max_iterations": {"type": ["integer", "null"]},
}));
let coerced = coerce_string_scalars_at_top_level(
serde_json::json!({"max_iterations": "20"}),
&schema,
);
assert_eq!(coerced, serde_json::json!({"max_iterations": 20}));
}
#[test]
fn coerce_skips_ambiguous_multi_type_schemas() {
let schema = make_schema(serde_json::json!({
"value": {"type": ["integer", "string"]},
}));
let coerced =
coerce_string_scalars_at_top_level(serde_json::json!({"value": "42"}), &schema);
assert_eq!(coerced, serde_json::json!({"value": "42"}));
}
#[test]
fn coerce_passes_through_object_without_properties() {
let schema = serde_json::json!({"type": "object"});
let coerced = coerce_string_scalars_at_top_level(serde_json::json!({"x": "50"}), &schema);
assert_eq!(coerced, serde_json::json!({"x": "50"}));
}
fn hint_for(json: Value, expected_target: &str) -> Option<String> {
#[derive(Debug, Deserialize, JsonSchema)]
#[allow(dead_code)]
struct UsizeField {
n: usize,
}
#[derive(Debug, Deserialize, JsonSchema)]
#[allow(dead_code)]
struct BoolField {
b: bool,
}
#[derive(Debug, Deserialize, JsonSchema)]
#[allow(dead_code)]
struct VecField {
items: Vec<serde_json::Value>,
}
let raw = match expected_target {
"usize" => serde_json::from_value::<UsizeField>(json).unwrap_err(),
"bool" => serde_json::from_value::<BoolField>(json).unwrap_err(),
"sequence" => serde_json::from_value::<VecField>(json).unwrap_err(),
_ => panic!("unknown target {expected_target}"),
};
Some(enrich_arg_parse_error_message(&raw))
}
#[test]
fn enrich_appends_integer_hint_for_string_encoded_int() {
let msg = hint_for(serde_json::json!({"n": "50"}), "usize").unwrap();
assert!(
msg.contains("Did you mean the integer 50"),
"expected integer hint, got: {msg}"
);
assert!(msg.contains("Resend without quotes"));
}
#[test]
fn enrich_appends_boolean_hint_for_string_encoded_bool() {
let msg = hint_for(serde_json::json!({"b": "True"}), "bool").unwrap();
assert!(
msg.contains("Did you mean true"),
"expected boolean hint, got: {msg}"
);
}
#[test]
fn enrich_appends_sequence_hint_for_string_in_array_slot() {
let xml_soup = "\n<ref>{\"kind\":\"file\",\"path\":\"x.md\"}</ref></artifact></file_write>";
let msg = hint_for(serde_json::json!({"items": xml_soup}), "sequence").unwrap();
assert!(
msg.contains("Expected a JSON array"),
"expected sequence hint, got: {msg}"
);
}
#[test]
fn enrich_passes_through_unrecognised_errors_unchanged() {
#[derive(Debug, Deserialize, JsonSchema)]
#[allow(dead_code)]
struct R {
n: usize,
}
let err = serde_json::from_value::<R>(serde_json::json!({})).unwrap_err();
let raw = err.to_string();
let enriched = enrich_arg_parse_error_message(&err);
assert_eq!(enriched, raw);
}
#[test]
fn flatten_tagged_oneof_passes_through_single_struct_schemas() {
let raw = serde_json::json!({
"type": "object",
"properties": {"text": {"type": "string"}},
"required": ["text"],
});
let out = flatten_tagged_oneof_schema(raw.clone());
assert_eq!(out, raw);
}
struct EchoTool;
#[async_trait]
impl AgentTool for EchoTool {
fn name(&self) -> &str {
"echo"
}
fn description(&self) -> &str {
"Echo arguments back as text"
}
fn parameters_schema(&self) -> Value {
serde_json::json!({
"type": "object",
"properties": {"text": {"type": "string"}},
"required": ["text"]
})
}
async fn execute(
&self,
_call_id: &str,
args: Value,
_signal: CancellationToken,
_update: ToolUpdateSink,
) -> Result<ToolResult, ToolError> {
let text = args
.get("text")
.and_then(Value::as_str)
.unwrap_or("")
.to_string();
Ok(ToolResult {
content: vec![ToolResultBlock::Text(TextContent { text })],
is_error: false,
details: Value::Null,
terminate: false,
narration: None,
})
}
}
#[test]
fn registry_lookup() {
let registry = ToolRegistry::new().with(Arc::new(EchoTool));
assert!(registry.get("echo").is_some());
assert!(registry.get("missing").is_none());
assert_eq!(registry.len(), 1);
}
struct NamedTool(&'static str);
#[async_trait]
impl AgentTool for NamedTool {
fn name(&self) -> &str {
self.0
}
fn description(&self) -> &str {
"named"
}
fn parameters_schema(&self) -> Value {
serde_json::json!({"type": "object", "properties": {}})
}
async fn execute(
&self,
_call_id: &str,
_args: Value,
_signal: CancellationToken,
_update: ToolUpdateSink,
) -> Result<ToolResult, ToolError> {
Ok(ToolResult::text("ok"))
}
}
#[test]
fn registry_preserves_registration_order() {
let mut registry = ToolRegistry::new()
.with(Arc::new(NamedTool("message_result")))
.with(Arc::new(NamedTool("message_ask")))
.with(Arc::new(NamedTool("plan")));
registry.register(Arc::new(NamedTool("message_result")));
assert_eq!(
registry.names(),
vec!["message_result", "message_ask", "plan"]
);
assert_eq!(
registry.iter().map(|tool| tool.name()).collect::<Vec<_>>(),
vec!["message_result", "message_ask", "plan"]
);
}
/// Calling `EchoTool::execute` directly returns the `text` argument as the
/// first content block.
#[tokio::test]
async fn echo_tool_executes() {
    // The receiver stays bound so the update channel remains open for the call.
    let (sink, _updates) = mpsc::unbounded_channel();
    let tool = EchoTool;
    let args = serde_json::json!({"text": "hi"});

    let result = tool
        .execute("call_1", args, CancellationToken::new(), sink)
        .await
        .unwrap();

    match &result.content[0] {
        ToolResultBlock::Text(t) => assert_eq!(t.text, "hi"),
        _ => panic!("expected text"),
    }
}
// Typed-argument fixture with one integer, one boolean, one float and one
// string field. `deny_unknown_fields` makes deserialization reject any key
// that is not listed here.
// NOTE(review): plain `//` comments are used on purpose; if `JsonSchema` is the
// schemars derive, `///` doc comments would be copied into the generated schema.
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
struct CoercibleArgs {
max_iterations: usize,
full_page: bool,
temperature: f32,
label: String,
}
struct CoercibleTool;
#[async_trait]
impl TypedAgentTool for CoercibleTool {
type Args = CoercibleArgs;
fn name(&self) -> &str {
"coercible"
}
fn description(&self) -> &str {
"fixture"
}
async fn run(
&self,
_call_id: &str,
args: Self::Args,
_signal: CancellationToken,
_update: ToolUpdateSink,
) -> Result<ToolResult, ToolError> {
Ok(ToolResult::text(format!(
"max_iterations={} full_page={} temperature={} label={}",
args.max_iterations, args.full_page, args.temperature, args.label
)))
}
}
/// Passing string-encoded integer, boolean and float values through
/// `AgentTool::execute` must reach `run` as typed values, while a real string
/// field stays unchanged.
#[tokio::test]
async fn execute_coerces_string_encoded_scalars_end_to_end() {
    let (sink, _updates) = mpsc::unbounded_channel();
    let tool = CoercibleTool;
    // Every scalar is deliberately supplied as a JSON string.
    let stringly_args = serde_json::json!({
        "max_iterations": "50",
        "full_page": "True",
        "temperature": "0.7",
        "label": "actual string",
    });

    let result = AgentTool::execute(
        &tool,
        "call_1",
        stringly_args,
        CancellationToken::new(),
        sink,
    )
    .await
    .unwrap();

    let text = match &result.content[0] {
        ToolResultBlock::Text(t) => &t.text,
        _ => panic!("expected text result"),
    };

    assert!(
        text.contains("max_iterations=50"),
        "integer coercion missing: {}",
        text
    );
    assert!(
        text.contains("full_page=true"),
        "boolean coercion missing: {}",
        text
    );
    assert!(
        text.contains("temperature=0.7"),
        "float coercion missing: {}",
        text
    );
    assert!(
        text.contains("label=actual string"),
        "string field must NOT be coerced: {}",
        text
    );
    assert!(!result.is_error, "execute must succeed after coercion");
}
/// A string that cannot be parsed as an integer must yield an error result
/// that keeps the canonical prefix and does not suggest a made-up integer.
#[tokio::test]
async fn execute_appends_self_correcting_hint_on_unrecoverable_string_int() {
    let (sink, _updates) = mpsc::unbounded_channel();
    let tool = CoercibleTool;
    // Only `max_iterations` is malformed; the other fields are well-typed.
    let bad_args = serde_json::json!({
        "max_iterations": "fifty",
        "full_page": true,
        "temperature": 0.1,
        "label": "x",
    });

    let result = AgentTool::execute(
        &tool,
        "call_2",
        bad_args,
        CancellationToken::new(),
        sink,
    )
    .await
    .unwrap();

    assert!(result.is_error, "expected validator rejection");

    let text = match &result.content[0] {
        ToolResultBlock::Text(t) => &t.text,
        _ => panic!("expected text result"),
    };

    assert!(
        text.starts_with("coercible: invalid arguments:"),
        "preserve canonical error prefix: {}",
        text
    );
    assert!(
        !text.contains("Did you mean the integer fifty"),
        "must not invent a hint when the value cannot parse: {}",
        text
    );
}
// Argument fixture that serde deserializes as an internally tagged enum: the
// `action` key selects the variant, and `rename_all = "snake_case"` makes the
// accepted tag values `open` and `reload`.
// NOTE(review): plain `//` comments are used on purpose; if `JsonSchema` is the
// schemars derive, `///` doc comments would be copied into the generated schema.
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(tag = "action", rename_all = "snake_case")]
enum TaggedArgs {
// Carries the URL to open.
Open { url: String },
// Takes no fields besides the tag.
Reload {},
}
/// Typed tool fixture whose arguments are an `action`-tagged enum.
struct TaggedTool;

#[async_trait]
impl TypedAgentTool for TaggedTool {
    type Args = TaggedArgs;

    /// Returns the fixed tool name `tagged_fixture`.
    fn name(&self) -> &str {
        "tagged_fixture"
    }

    /// Returns a fixed placeholder description.
    fn description(&self) -> &str {
        "fixture"
    }

    /// Fills in `action: "open"` when the caller supplied a `url` but no
    /// `action`. Non-object values and objects that already carry an `action`
    /// are returned unchanged.
    fn prepare_arguments(&self, args: Value) -> Value {
        match args {
            Value::Object(mut fields) => {
                let action_missing = !fields.contains_key("action");
                if action_missing && fields.contains_key("url") {
                    fields.insert("action".to_string(), Value::String("open".to_string()));
                }
                Value::Object(fields)
            }
            other => other,
        }
    }

    /// Returns a text result naming the variant that was deserialized.
    async fn run(
        &self,
        _call_id: &str,
        args: Self::Args,
        _signal: CancellationToken,
        _update: ToolUpdateSink,
    ) -> Result<ToolResult, ToolError> {
        let summary = match args {
            TaggedArgs::Reload {} => "reload".to_string(),
            TaggedArgs::Open { url } => format!("open:{url}"),
        };
        Ok(ToolResult::text(summary))
    }
}
/// Arguments with a `url` but no `action` must still deserialize as `Open`,
/// which only works if `prepare_arguments` runs before typed deserialization.
#[tokio::test]
async fn execute_runs_prepare_arguments_before_typed_deser() {
    let (sink, _updates) = mpsc::unbounded_channel();
    let tool = TaggedTool;
    // No `action` key: the tool is expected to infer it from `url`.
    let untagged_args = serde_json::json!({"url": "https://example.com"});

    let result = AgentTool::execute(
        &tool,
        "call_1",
        untagged_args,
        CancellationToken::new(),
        sink,
    )
    .await
    .unwrap();

    let text = match &result.content[0] {
        ToolResultBlock::Text(t) => &t.text,
        _ => panic!("expected text result"),
    };

    assert!(
        !result.is_error,
        "execute must succeed after action inference"
    );
    assert_eq!(text, "open:https://example.com");
}
/// An explicitly supplied `action` must reach `run` untouched.
#[tokio::test]
async fn execute_prepare_arguments_does_not_override_explicit_action() {
    let (sink, _updates) = mpsc::unbounded_channel();
    let tool = TaggedTool;
    let explicit_args = serde_json::json!({"action": "reload"});

    let result = AgentTool::execute(
        &tool,
        "call_2",
        explicit_args,
        CancellationToken::new(),
        sink,
    )
    .await
    .unwrap();

    let text = match &result.content[0] {
        ToolResultBlock::Text(t) => &t.text,
        _ => panic!("expected text result"),
    };

    assert!(!result.is_error);
    assert_eq!(text, "reload");
}
}