use serde_json::Value;
use super::tool::ToolError;
pub fn generate_tool_schema<T: schemars::JsonSchema>() -> Value {
let settings = schemars::generate::SchemaSettings::default().with(|s| {
s.inline_subschemas = true;
s.meta_schema = None;
});
let generator = settings.into_generator();
let schema = generator.into_root_schema_for::<T>();
let mut value = serde_json::to_value(schema).expect("schema serialization cannot fail");
if let Some(obj) = value.as_object_mut() {
obj.remove("$schema");
obj.remove("$defs");
obj.remove("definitions");
}
sanitize_for_llm(&mut value);
value
}
pub fn sanitize_for_llm(schema: &mut Value) {
simplify_any_of(schema);
fix_missing_fields(schema);
}
pub fn validate_against_schema(schema: &Value, args: &Value) -> Result<(), ToolError> {
if schema.is_null() {
return Ok(());
}
if let Some(obj) = schema.as_object() {
let has_properties = obj
.get("properties")
.and_then(|v| v.as_object())
.is_some_and(|p| !p.is_empty());
let has_required = obj
.get("required")
.and_then(|v| v.as_array())
.is_some_and(|r| !r.is_empty());
if !has_properties && !has_required {
return Ok(());
}
}
let validator = jsonschema::validator_for(schema)
.map_err(|e| ToolError::Internal(format!("invalid tool schema: {e}")))?;
let errors: Vec<String> = validator.iter_errors(args).map(|e| e.to_string()).collect();
if errors.is_empty() {
Ok(())
} else {
Err(ToolError::InvalidArguments(errors.join("; ")))
}
}
fn simplify_any_of(value: &mut Value) {
if let Some(obj) = value.as_object_mut() {
for v in obj.values_mut() {
simplify_any_of(v);
}
if let Some(type_val) = obj.get_mut("type")
&& let Some(arr) = type_val.as_array().cloned()
{
let non_null: Vec<&Value> = arr.iter().filter(|v| v.as_str() != Some("null")).collect();
if non_null.len() == 1 && non_null.len() < arr.len() {
*type_val = non_null[0].clone();
}
}
if let Some(any_of) = obj.remove("anyOf") {
if let Some(arr) = any_of.as_array()
&& arr.len() == 2
{
let (null_idx, non_null_idx) = if is_null_schema(&arr[0]) {
(0, 1)
} else if is_null_schema(&arr[1]) {
(1, 0)
} else {
obj.insert("anyOf".to_string(), any_of);
return;
};
let _ = null_idx;
if let Some(non_null_obj) = arr[non_null_idx].as_object() {
for (k, v) in non_null_obj {
obj.insert(k.clone(), v.clone());
}
}
return;
}
obj.insert("anyOf".to_string(), any_of);
}
} else if let Some(arr) = value.as_array_mut() {
for item in arr.iter_mut() {
simplify_any_of(item);
}
}
}
fn is_null_schema(value: &Value) -> bool {
value
.as_object()
.and_then(|o| o.get("type"))
.and_then(|t| t.as_str())
.is_some_and(|s| s == "null")
}
/// Recursively fills in structural keys that are absent from typed schemas.
///
/// After visiting all nested values, an object with `"type": "array"` gets an
/// empty-object `"items"` if it has none, and an object with `"type": "object"`
/// gets an empty-object `"properties"` if it has none. Existing entries are
/// never replaced.
fn fix_missing_fields(value: &mut Value) {
    match value {
        Value::Object(obj) => {
            obj.values_mut().for_each(fix_missing_fields);

            // Which companion key does this schema's type call for, if any?
            let companion = match obj.get("type").and_then(Value::as_str) {
                Some("array") => Some("items"),
                Some("object") => Some("properties"),
                _ => None,
            };
            if let Some(key) = companion {
                obj.entry(key)
                    .or_insert_with(|| Value::Object(Default::default()));
            }
        }
        Value::Array(elements) => elements.iter_mut().for_each(fix_missing_fields),
        _ => {}
    }
}
#[cfg(test)]
mod tests {
// Unit tests for schema generation, sanitizing, and argument validation.
// The fixture types below are only used as type parameters to
// `generate_tool_schema`; they are never constructed, hence `allow(dead_code)`.
use super::*;
use schemars::JsonSchema;
use serde::Deserialize;
use serde_json::json;
// Fixture: a single non-optional `String` field.
#[allow(dead_code)]
#[derive(Deserialize, JsonSchema)]
struct SimpleArgs {
query: String,
}
// Fixture: one plain field plus one `Option` field.
#[allow(dead_code)]
#[derive(Deserialize, JsonSchema)]
struct OptionalArgs {
query: String,
limit: Option<u32>,
}
// Fixture: a `Vec` field, for checking the generated array schema.
#[allow(dead_code)]
#[derive(Deserialize, JsonSchema)]
struct ListArgs {
items: Vec<String>,
}
// Fixture: struct nested inside `Outer` and `ComplexArgs`.
#[allow(dead_code)]
#[derive(Deserialize, JsonSchema)]
struct Inner {
x: i32,
}
// Fixture: wraps `Inner` to exercise nested-struct schemas.
#[allow(dead_code)]
#[derive(Deserialize, JsonSchema)]
struct Outer {
inner: Inner,
}
// Fixture: enum with unit variants only.
#[derive(Deserialize, JsonSchema)]
#[allow(dead_code)]
enum Format {
Json,
Yaml,
Xml,
}
// Fixture: struct with a non-optional enum field.
#[allow(dead_code)]
#[derive(Deserialize, JsonSchema)]
struct FormatArgs {
format: Format,
}
// Fixture: combines a string, a list, an optional integer, a nested struct
// and an optional enum in one type.
#[allow(dead_code)]
#[derive(Deserialize, JsonSchema)]
struct ComplexArgs {
name: String,
tags: Vec<String>,
count: Option<u32>,
inner: Inner,
format: Option<Format>,
}
// A plain struct yields an object schema with a typed, required property.
#[test]
fn generate_simple_struct() {
let schema = generate_tool_schema::<SimpleArgs>();
assert_eq!(schema["type"], "object");
assert_eq!(schema["properties"]["query"]["type"], "string");
let required = schema["required"].as_array().unwrap();
assert!(required.contains(&json!("query")));
}
// An `Option<u32>` field ends up with a single `integer` type, no `anyOf`,
// and is not listed in `required`.
#[test]
fn option_field_simplified() {
let schema = generate_tool_schema::<OptionalArgs>();
assert!(schema["properties"]["limit"].get("anyOf").is_none());
assert_eq!(schema["properties"]["limit"]["type"], "integer");
let required = schema["required"].as_array().unwrap();
assert!(required.contains(&json!("query")));
assert!(!required.contains(&json!("limit")));
}
// A `Vec` field is an array schema that carries an `items` entry.
#[test]
fn vec_field_has_items() {
let schema = generate_tool_schema::<ListArgs>();
assert_eq!(schema["properties"]["items"]["type"], "array");
assert!(schema["properties"]["items"].get("items").is_some());
}
// Nested structs appear inline: no `$ref` anywhere in the serialized schema.
#[test]
fn nested_struct_inlined() {
let schema = generate_tool_schema::<Outer>();
let output = serde_json::to_string(&schema).unwrap();
assert!(!output.contains("$ref"));
assert_eq!(
schema["properties"]["inner"]["properties"]["x"]["type"],
"integer"
);
}
// Variant names of an enum field show up in the serialized schema
// (only `Json` and `Yaml` are checked here).
#[test]
fn enum_field_generates_enum_values() {
let schema = generate_tool_schema::<FormatArgs>();
let output = serde_json::to_string(&schema).unwrap();
assert!(output.contains("Json"));
assert!(output.contains("Yaml"));
}
// The top level of the schema has no `$schema`, `$defs` or `definitions` key.
#[test]
fn no_meta_fields_in_output() {
let schema = generate_tool_schema::<ComplexArgs>();
assert!(schema.get("$schema").is_none());
assert!(schema.get("$defs").is_none());
assert!(schema.get("definitions").is_none());
}
// Running `sanitize_for_llm` again on an already generated schema is a no-op.
#[test]
fn sanitize_is_idempotent() {
let mut schema = generate_tool_schema::<ComplexArgs>();
let before = schema.clone();
sanitize_for_llm(&mut schema);
assert_eq!(schema, before);
}
// The combined fixture has no `$ref` / `$defs` / `$schema` / `anyOf` in its
// serialized form, and its array and object fields carry `items` / `properties`.
#[test]
fn complex_struct_is_llm_friendly() {
let schema = generate_tool_schema::<ComplexArgs>();
let output = serde_json::to_string(&schema).unwrap();
assert!(!output.contains("$ref"));
assert!(!output.contains("$defs"));
assert!(!output.contains("$schema"));
assert!(!output.contains("anyOf"));
assert!(schema["properties"]["tags"].get("items").is_some());
assert!(schema["properties"]["inner"].get("properties").is_some());
}
// No `anyOf` / `oneOf` / `allOf` combinator appears for the combined fixture.
#[test]
fn gemini_compatible_no_any_of() {
let schema = generate_tool_schema::<ComplexArgs>();
let output = serde_json::to_string(&schema).unwrap();
assert!(!output.contains("anyOf"));
assert!(!output.contains("oneOf"));
assert!(!output.contains("allOf"));
}
// The generated schema can itself be compiled into a `jsonschema` validator.
#[test]
fn openai_anthropic_valid_json_schema() {
let schema = generate_tool_schema::<ComplexArgs>();
assert!(jsonschema::validator_for(&schema).is_ok());
}
// Arguments matching the schema pass validation.
#[test]
fn validate_accepts_valid_input() {
let schema = generate_tool_schema::<SimpleArgs>();
let args = json!({"query": "hello"});
assert!(validate_against_schema(&schema, &args).is_ok());
}
// A number where a string is expected is rejected.
#[test]
fn validate_rejects_wrong_type() {
let schema = generate_tool_schema::<SimpleArgs>();
let args = json!({"query": 42});
assert!(validate_against_schema(&schema, &args).is_err());
}
// Omitting a required field is rejected.
#[test]
fn validate_rejects_missing_required() {
let schema = generate_tool_schema::<SimpleArgs>();
let args = json!({});
assert!(validate_against_schema(&schema, &args).is_err());
}
// A schema with neither `properties` nor `required` accepts any arguments.
#[test]
fn validate_skips_empty_schema() {
let schema = json!({"type": "object"});
let args = json!({"anything": true});
assert!(validate_against_schema(&schema, &args).is_ok());
}
// A `null` schema accepts any arguments.
#[test]
fn validate_skips_null_schema() {
let schema = Value::Null;
let args = json!({"anything": true});
assert!(validate_against_schema(&schema, &args).is_ok());
}
// An `Option` field may be left out of the arguments.
#[test]
fn validate_optional_field_accepts_absence() {
let schema = generate_tool_schema::<OptionalArgs>();
let args = json!({"query": "hello"});
assert!(validate_against_schema(&schema, &args).is_ok());
}
// An `Option` field may be supplied with a value of the right type.
#[test]
fn validate_optional_field_accepts_value() {
let schema = generate_tool_schema::<OptionalArgs>();
let args = json!({"query": "hello", "limit": 10});
assert!(validate_against_schema(&schema, &args).is_ok());
}
// An `Option` field supplied with the wrong type is still rejected.
#[test]
fn validate_optional_field_rejects_wrong_type() {
let schema = generate_tool_schema::<OptionalArgs>();
let args = json!({"query": "hello", "limit": "not a number"});
assert!(validate_against_schema(&schema, &args).is_err());
}
}