use indexmap::IndexMap;
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use std::collections::HashMap;
/// An object-typed JSON Schema (`"type": "object"`) with named properties.
///
/// Keys not modelled by a dedicated field are collected into [`extra`](Self::extra)
/// and written back at the top level on serialization (`#[serde(flatten)]`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ObjectJsonSchema {
    /// The schema `type` keyword; [`ObjectJsonSchema::new`] sets it to `"object"`.
    #[serde(rename = "type")]
    pub schema_type: String,
    /// Property name -> property schema, kept in insertion order (`IndexMap`).
    ///
    /// `default` lets schemas that omit `properties` entirely (e.g. `{"type": "object"}`)
    /// deserialize as an empty map instead of failing.
    #[serde(default)]
    pub properties: IndexMap<String, JsonValue>,
    /// Names of required properties; omitted from the output when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub required: Vec<String>,
    /// Optional human-readable description; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// The `additionalProperties` keyword, restricted to its boolean form here;
    /// omitted from the output when `None`.
    #[serde(
        rename = "additionalProperties",
        skip_serializing_if = "Option::is_none"
    )]
    pub additional_properties: Option<bool>,
    /// Any other top-level schema keys, flattened into the surrounding object.
    #[serde(flatten)]
    pub extra: HashMap<String, JsonValue>,
}
impl ObjectJsonSchema {
    /// Creates an empty object schema: `"type": "object"`, no properties, no required
    /// names, no description, no `additionalProperties`, and no extra keys.
    #[must_use]
    pub fn new() -> Self {
        Self {
            schema_type: "object".to_string(),
            properties: IndexMap::new(),
            required: Vec::new(),
            description: None,
            additional_properties: None,
            extra: HashMap::new(),
        }
    }

    /// Builder form of [`add_property`](Self::add_property).
    ///
    /// Inserts (or replaces) property `name` with `schema`. When `required` is true,
    /// `name` is added to the `required` list unless it is already there.
    #[must_use]
    pub fn with_property(mut self, name: &str, schema: JsonValue, required: bool) -> Self {
        // Delegate so both entry points share one de-duplication rule; pushing directly
        // here used to produce `"required": ["x", "x"]` when a property was added twice.
        self.add_property(name, schema, required);
        self
    }

    /// Inserts (or replaces) property `name` with `schema`.
    ///
    /// When `required` is true, `name` is appended to the `required` list unless it is
    /// already listed. Passing `required = false` does not remove an existing entry.
    pub fn add_property(&mut self, name: &str, schema: JsonValue, required: bool) {
        self.properties.insert(name.to_string(), schema);
        if required && !self.is_required(name) {
            self.required.push(name.to_string());
        }
    }

    /// Sets the schema's `description`.
    #[must_use]
    pub fn with_description(mut self, desc: &str) -> Self {
        self.description = Some(desc.to_string());
        self
    }

    /// Sets the schema's `additionalProperties` flag.
    #[must_use]
    pub fn with_additional_properties(mut self, allowed: bool) -> Self {
        self.additional_properties = Some(allowed);
        self
    }

    /// Adds (or overwrites) an arbitrary top-level schema key stored in `extra`.
    #[must_use]
    pub fn with_extra(mut self, key: &str, value: JsonValue) -> Self {
        self.extra.insert(key.to_string(), value);
        self
    }

    /// Returns `true` if `name` appears in the `required` list.
    #[must_use]
    pub fn is_required(&self, name: &str) -> bool {
        // Compare as `&str` so lookups do not allocate a temporary `String`.
        self.required.iter().any(|listed| listed == name)
    }

    /// Returns the schema registered for property `name`, if any.
    #[must_use]
    pub fn get_property(&self, name: &str) -> Option<&JsonValue> {
        self.properties.get(name)
    }

    /// Number of declared properties.
    #[must_use]
    pub fn property_count(&self) -> usize {
        self.properties.len()
    }

    /// Returns `true` when no properties are declared (other fields are not considered).
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.properties.is_empty()
    }

    /// Serializes this schema to a [`JsonValue`].
    ///
    /// # Errors
    /// Returns the `serde_json` error if serialization fails.
    pub fn to_json(&self) -> Result<JsonValue, serde_json::Error> {
        serde_json::to_value(self)
    }
}
impl Default for ObjectJsonSchema {
fn default() -> Self {
Self::new()
}
}
impl TryFrom<JsonValue> for ObjectJsonSchema {
    type Error = serde_json::Error;

    /// Deserializes an object schema from an already-parsed JSON value.
    ///
    /// # Errors
    /// Returns the `serde_json` error when `value` does not match this type's shape.
    fn try_from(value: JsonValue) -> Result<Self, Self::Error> {
        let parsed: Self = serde_json::from_value(value)?;
        Ok(parsed)
    }
}
impl From<ObjectJsonSchema> for JsonValue {
    /// Serializes the schema into a JSON value.
    ///
    /// A serialization failure is not surfaced; it produces `JsonValue::Null` instead.
    /// Use [`ObjectJsonSchema::to_json`] when the error needs to be observed.
    fn from(schema: ObjectJsonSchema) -> Self {
        match serde_json::to_value(schema) {
            Ok(json) => json,
            Err(_) => JsonValue::Null,
        }
    }
}
/// A provider-neutral description of a callable tool: its name, a description,
/// and a JSON Schema for its parameters.
#[derive(Debug, Clone, PartialEq)]
#[derive(Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool name exposed to the model.
    pub name: String,
    /// Human-readable description of what the tool does.
    pub description: String,
    /// JSON Schema for the tool's arguments.
    pub parameters_json_schema: JsonValue,
    /// Optional strict-mode flag; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
    /// Optional key name; omitted from serialized output when `None`.
    /// NOTE(review): not read by any method here — presumably the key that wraps the
    /// arguments object; confirm against callers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub outer_typed_dict_key: Option<String>,
}
impl ToolDefinition {
    /// Creates a tool definition with the given name and description.
    ///
    /// The parameter schema is whatever an empty `crate::schema::SchemaBuilder` builds.
    /// `strict` and `outer_typed_dict_key` start unset.
    ///
    /// # Panics
    /// Panics if `SchemaBuilder::build` fails for the empty builder.
    #[must_use]
    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            description: description.into(),
            parameters_json_schema: crate::schema::SchemaBuilder::new()
                .build()
                .expect("SchemaBuilder JSON serialization failed"),
            strict: None,
            outer_typed_dict_key: None,
        }
    }

    /// Replaces the parameter schema with anything convertible into a [`JsonValue`]
    /// (for example an `ObjectJsonSchema`).
    #[must_use]
    pub fn with_parameters(mut self, schema: impl Into<JsonValue>) -> Self {
        self.parameters_json_schema = schema.into();
        self
    }

    /// Sets the strict-mode flag.
    #[must_use]
    pub fn with_strict(mut self, strict: bool) -> Self {
        self.strict = Some(strict);
        self
    }

    /// Sets `outer_typed_dict_key`.
    #[must_use]
    pub fn with_outer_typed_dict_key(mut self, key: impl Into<String>) -> Self {
        self.outer_typed_dict_key = Some(key.into());
        self
    }

    /// The tool's name.
    #[must_use]
    pub fn name(&self) -> &str {
        &self.name
    }

    /// The tool's description.
    #[must_use]
    pub fn description(&self) -> &str {
        &self.description
    }

    /// The tool's parameter JSON Schema.
    #[must_use]
    pub fn parameters(&self) -> &JsonValue {
        &self.parameters_json_schema
    }

    /// Returns the strict flag, treating an unset flag as `false`.
    #[must_use]
    pub fn is_strict(&self) -> bool {
        self.strict.unwrap_or(false)
    }

    /// Renders the tool in OpenAI function-calling shape:
    /// `{"type": "function", "function": {name, description, parameters[, strict]}}`.
    ///
    /// `strict` is included only when it has been set.
    #[must_use]
    pub fn to_openai_function(&self) -> JsonValue {
        // `json!` serializes interpolated expressions by reference (hence `self.name`
        // works through `&self`), so no explicit `.clone()` of the schema is needed;
        // adding one would deep-copy the schema twice.
        let mut func = serde_json::json!({
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": self.parameters_json_schema
            }
        });
        if let Some(strict) = self.strict {
            func["function"]["strict"] = JsonValue::Bool(strict);
        }
        func
    }

    /// Renders the tool in Anthropic tool shape: `{name, description, input_schema}`.
    #[must_use]
    pub fn to_anthropic_tool(&self) -> JsonValue {
        serde_json::json!({
            "name": self.name,
            "description": self.description,
            "input_schema": self.parameters_json_schema
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::SchemaBuilder;

    /// Pulls the `properties` object out of a tool's parameter schema.
    ///
    /// Panics (failing the calling test) when the schema has no `properties` object.
    fn properties_of(def: &ToolDefinition) -> &serde_json::Map<String, JsonValue> {
        def.parameters()
            .get("properties")
            .and_then(JsonValue::as_object)
            .expect("parameters schema should have a `properties` object")
    }

    #[test]
    fn test_object_json_schema_new() {
        let empty = ObjectJsonSchema::new();
        assert_eq!(empty.schema_type, "object");
        assert!(empty.properties.is_empty());
        assert!(empty.required.is_empty());
    }

    #[test]
    fn test_object_json_schema_with_property() {
        let string_type = serde_json::json!({ "type": "string" });
        let integer_type = serde_json::json!({ "type": "integer" });
        let person = ObjectJsonSchema::new()
            .with_property("name", string_type, true)
            .with_property("age", integer_type, false);
        assert_eq!(person.property_count(), 2);
        assert!(person.is_required("name"));
        assert!(!person.is_required("age"));
    }

    #[test]
    fn test_tool_definition_new() {
        let weather = ToolDefinition::new("get_weather", "Get the current weather");
        assert_eq!(weather.name(), "get_weather");
        assert_eq!(weather.description(), "Get the current weather");
        assert!(properties_of(&weather).is_empty());
    }

    #[test]
    fn test_tool_definition_with_parameters() {
        let unit_choices = ["celsius", "fahrenheit"];
        let params = SchemaBuilder::new()
            .string("location", "City name", true)
            .enum_values("unit", "Temperature unit", &unit_choices, false)
            .build()
            .expect("SchemaBuilder JSON serialization failed");
        let weather = ToolDefinition::new("get_weather", "Get weather")
            .with_parameters(params)
            .with_strict(true);
        assert!(weather.is_strict());
        assert_eq!(properties_of(&weather).len(), 2);
    }

    #[test]
    fn test_to_openai_function() {
        let params = SchemaBuilder::new()
            .string("x", "A value", true)
            .build()
            .expect("SchemaBuilder JSON serialization failed");
        let func = ToolDefinition::new("test", "Test tool")
            .with_parameters(params)
            .with_strict(true)
            .to_openai_function();
        assert_eq!(func["type"], "function");
        assert_eq!(func["function"]["name"], "test");
        assert_eq!(func["function"]["strict"], true);
    }

    #[test]
    fn test_to_anthropic_tool() {
        let tool = ToolDefinition::new("test", "Test tool").to_anthropic_tool();
        assert_eq!(tool["name"], "test");
        assert!(tool.get("input_schema").is_some());
    }

    #[test]
    fn test_serde_roundtrip() {
        let original = ObjectJsonSchema::new()
            .with_property("x", serde_json::json!({ "type": "string" }), true)
            .with_description("Test schema");
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ObjectJsonSchema = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, decoded);
    }
}