use schemars::{
JsonSchema,
generate::SchemaSettings,
transform::{self, Transform},
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::borrow::Cow;
use thiserror::Error;
#[cfg(feature = "pegboard")]
mod pegboard;
#[cfg(feature = "pegboard")]
pub use pegboard::{PegBoard, PegBoardError};
#[cfg(feature = "pegboard")]
pub use rmcp::model::CallToolResult;
/// Errors returned by [`parse_input_schema`] and [`get_tool`] when turning a parameter type
/// into a tool input schema.
#[derive(Debug, Error)]
pub enum ToolError {
    /// `serde_json::to_value` failed while converting the generated schema into a JSON value.
    #[error("Failed to serialize JSON schema: {0}")]
    SchemaSerialization(#[from] serde_json::Error),
    /// The generated root schema has no `type` keyword.
    #[error("Schema type field is required")]
    MissingSchemaType,
    /// The generated root schema's `type` is not the string `"object"`.
    ///
    /// Holds the offending value as text: the string itself, or the `Debug` rendering of a
    /// non-string `type` value.
    #[error("Schema must be of type 'object', got: {0}")]
    InvalidSchemaType(String),
}
/// A tool definition: a name, an optional description and a JSON Schema for its input.
///
/// Built by [`get_tool`] (and by the accessor functions generated by `define_tool!`).
#[derive(Debug, Deserialize, Clone, Default, PartialEq, Serialize)]
pub struct Tool {
    /// The tool's name.
    pub name: Cow<'static, str>,
    /// Optional human-readable description of the tool.
    pub description: Option<Cow<'static, str>>,
    /// JSON Schema for the tool's input; when built via [`get_tool`] this is the value
    /// returned by [`parse_input_schema`].
    pub input_schema: serde_json::Value,
}
/// Schema [`Transform`] that strips `null` from nullable types.
///
/// Nullable values (e.g. `Option<T>` fields) are typically emitted by schemars with a `type`
/// array such as `["string", "null"]`, possibly also listing `null` in `enum`. When
/// [`remove_null_type`](Self::remove_null_type) is `true` (the default), this transform removes
/// those `null` entries; if exactly one type remains, `type` is collapsed to that single string.
/// The transform is applied recursively to every subschema.
///
/// Schemas whose `type` names nothing besides `"null"` (`"null"` or `["null"]`) are left
/// untouched. Stripping them would produce an empty (invalid) `type` array, or an `enum` that no
/// value can satisfy.
///
/// Despite its name (kept for API compatibility), this transform never inserts a `nullable`
/// keyword.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct AddNullable {
    /// Whether `"null"` is removed from `type` arrays (and `null` from `enum` lists).
    /// When `false`, the transform leaves every schema unchanged.
    pub remove_null_type: bool,
}

impl Default for AddNullable {
    fn default() -> Self {
        Self {
            remove_null_type: true,
        }
    }
}

impl AddNullable {
    /// Returns `true` if `schema`'s `type` keyword equals `ty` or is an array containing `ty`.
    fn has_type(schema: &schemars::Schema, ty: &str) -> bool {
        match schema.get("type") {
            Some(Value::Array(values)) => values.iter().any(|v| v.as_str() == Some(ty)),
            Some(Value::String(s)) => s == ty,
            _ => false,
        }
    }

    /// Removes `"null"` from `schema`'s `type` array and `null` from its `enum` list.
    ///
    /// Only acts when `type` is an array naming at least one non-null type, so the result is
    /// never an empty `type` array.
    fn strip_null(schema: &mut schemars::Schema) {
        let Some(ty) = schema.get_mut("type") else {
            return;
        };
        let Value::Array(types) = ty else {
            // A lone `"type": "null"` has no non-null alternative to fall back to.
            return;
        };
        if types.iter().all(|t| t.as_str() == Some("null")) {
            return;
        }
        types.retain(|t| t.as_str() != Some("null"));
        if types.len() == 1 {
            // Collapse `["string"]` to `"string"`; move the value out instead of cloning it.
            *ty = types.remove(0);
        }
        if let Some(Value::Array(values)) = schema.get_mut("enum") {
            values.retain(|v| !v.is_null());
        }
    }
}

impl Transform for AddNullable {
    fn transform(&mut self, schema: &mut schemars::Schema) {
        if self.remove_null_type && Self::has_type(schema, "null") {
            Self::strip_null(schema);
        }
        transform::transform_subschemas(self, schema);
    }
}
/// Builds a [`Tool`] called `name` whose input schema is generated from the type `T`.
///
/// `desc`, when present, becomes the tool's description. Both strings accept anything that
/// converts into `Cow<'static, str>` (for example `&'static str` or `String`).
///
/// # Errors
///
/// Returns whatever [`ToolError`] [`parse_input_schema`] reports for `T`; in that case `name`
/// and `desc` are dropped without being converted.
pub fn get_tool<T: JsonSchema, S1, S2>(name: S1, desc: Option<S2>) -> Result<Tool, ToolError>
where
    S1: Into<Cow<'static, str>>,
    S2: Into<Cow<'static, str>>,
{
    let input_schema = parse_input_schema::<T>()?;
    let name: Cow<'static, str> = name.into();
    let description: Option<Cow<'static, str>> = desc.map(Into::into);
    Ok(Tool {
        name,
        description,
        input_schema,
    })
}
/// Defines a lazily built, cached [`Tool`] accessor for a parameter type.
///
/// `define_tool!(NAME, "function_name", "description", ParamType)` expands to:
///
/// - a private `static NAME_ONCE_LOCK` holding the result of [`get_tool`], and
/// - `pub fn get_name() -> Result<&'static Tool, &'static ToolError>` (the function name is
///   `get_` followed by `NAME` lower-cased). It builds the tool on its first call and returns
///   the cached result, success or error, on every later call.
///
/// All standard-library paths in the expansion are fully qualified. Caller-side items such as a
/// `type Result<T> = ...` alias therefore cannot change its meaning. A trailing comma after the
/// last argument is accepted.
///
/// The expansion invokes `::paste::paste!`, so the calling crate must depend on `paste`.
#[macro_export]
macro_rules! define_tool {
    ($tool_name:ident, $function_name:expr, $description:expr, $param_type:ty $(,)?) => {
        ::paste::paste! {
            static [<$tool_name _ONCE_LOCK>]: ::std::sync::OnceLock<
                ::std::result::Result<$crate::Tool, $crate::ToolError>,
            > = ::std::sync::OnceLock::new();

            pub fn [<get_ $tool_name:lower>]() -> ::std::result::Result<
                &'static $crate::Tool,
                &'static $crate::ToolError,
            > {
                [<$tool_name _ONCE_LOCK>]
                    .get_or_init(|| {
                        $crate::get_tool::<$param_type, _, _>(
                            $function_name,
                            ::std::option::Option::Some($description),
                        )
                    })
                    .as_ref()
            }
        }
    };
}
/// Generates the JSON Schema describing the input parameters of `T`.
///
/// The schema is generated with draft 2019-09 settings, with subschemas inlined and
/// [`AddNullable`] applied, then converted to a [`serde_json::Value`]. The root must declare
/// `"type": "object"`. On success the root-level `$schema`, `title` and `definitions` keys are
/// removed before the value is returned.
///
/// # Errors
///
/// - [`ToolError::SchemaSerialization`] if converting the schema to JSON fails.
/// - [`ToolError::MissingSchemaType`] if the root schema has no `type` keyword.
/// - [`ToolError::InvalidSchemaType`] if the root `type` is anything but the string `"object"`.
pub fn parse_input_schema<T: JsonSchema>() -> Result<serde_json::Value, ToolError> {
    let generator = SchemaSettings::draft2019_09()
        .with(|cfg| cfg.inline_subschemas = true)
        .with_transform(AddNullable::default())
        .into_generator();
    let mut root = serde_json::to_value(generator.into_root_schema_for::<T>())?;

    match root.get("type") {
        None => return Err(ToolError::MissingSchemaType),
        Some(Value::String(kind)) if kind == "object" => {}
        Some(Value::String(kind)) => return Err(ToolError::InvalidSchemaType(kind.clone())),
        Some(other) => return Err(ToolError::InvalidSchemaType(format!("{:?}", other))),
    }

    // NOTE(review): draft 2019-09 stores shared subschemas under `$defs`, not `definitions`, so
    // that removal is presumably a no-op here; `$defs` is kept, which keeps any `$ref`s valid.
    if let Some(map) = root.as_object_mut() {
        for key in ["$schema", "title", "definitions"] {
            map.remove(key);
        }
    }
    Ok(root)
}
#[cfg(test)]
mod tests {
    use super::*;
    use schemars::JsonSchema;

    #[derive(JsonSchema, serde::Deserialize)]
    #[allow(dead_code)] // `arr` is never read; it only contributes to the generated schema.
    pub struct WeatherParams {
        /// The city and state, e.g. San Francisco, CA
        pub location: String,
        pub unit: Option<UnitEnum>,
        pub arr: Option<Vec<String>>,
    }

    #[derive(JsonSchema, serde::Deserialize, PartialEq, Debug)]
    #[serde(rename_all = "lowercase")]
    pub enum UnitEnum {
        Celsius,
        Fahrenheit,
    }

    #[test]
    fn test_parse_input_schema() {
        let got = parse_input_schema::<WeatherParams>().unwrap();
        let want = serde_json::json!({
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA"
                },
                "unit": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"]
                },
                "arr": {
                    "type": "array",
                    "items": {
                        "type": "string"
                    }
                }
            },
            "required": ["location"]
        });
        // Compare `Value`s rather than serialized strings so the assertion does not depend on
        // object key order (which changes when serde_json's `preserve_order` feature is on).
        assert_eq!(want, got, "Expected: {} Got: {}", want, got);
    }

    #[test]
    fn test_get_tool() {
        let tool = get_tool::<WeatherParams, _, _>(
            "get_weather",
            Some("Get the current weather in a given location"),
        )
        .unwrap();
        assert_eq!(tool.name, "get_weather");
        assert_eq!(
            tool.description,
            Some(Cow::Borrowed("Get the current weather in a given location"))
        );
        let schema = &tool.input_schema;
        assert_eq!(schema.get("type").and_then(|v| v.as_str()), Some("object"));
        assert!(schema.get("properties").is_some());
        assert!(schema.get("required").is_some());
    }

    #[test]
    fn test_deserialize_params() {
        let s = r#"{"location":"Boston, MA","unit":"celsius"}"#;
        let got: WeatherParams = serde_json::from_str(s).unwrap();
        assert_eq!(
            "Boston, MA", got.location,
            "Expected: Boston, MA Got: {}",
            got.location
        );
        assert_eq!(
            Some(UnitEnum::Celsius),
            got.unit,
            "Expected: celsius got: {:?}",
            got.unit
        );
    }

    #[test]
    fn test_define_tool_macro() {
        define_tool!(
            WEATHER_TOOL,
            "get_weather",
            "Get the current weather",
            WeatherParams
        );
        let tool = get_weather_tool().unwrap();
        assert_eq!(tool.name, "get_weather");
        assert_eq!(
            tool.description,
            Some(Cow::Borrowed("Get the current weather"))
        );
        assert!(tool.input_schema.get("properties").is_some());
        assert!(tool.input_schema.get("required").is_some());
    }

    #[test]
    fn test_nullable_handling() {
        #[derive(JsonSchema)]
        #[allow(dead_code)] // Fields exist only to shape the generated schema.
        struct TestNullable {
            required_field: String,
            optional_field: Option<String>,
        }
        let schema = parse_input_schema::<TestNullable>().unwrap();
        let props = schema.get("properties").unwrap().as_object().unwrap();
        let optional = props.get("optional_field").unwrap();
        // `["string", "null"]` must be collapsed to the plain non-null type. Checking only
        // `as_str() != Some("null")` would pass vacuously for an un-stripped array.
        assert_eq!(optional.get("type"), Some(&serde_json::json!("string")));
        let required = schema.get("required").unwrap().as_array().unwrap();
        assert_eq!(required.len(), 1);
        assert_eq!(required[0].as_str(), Some("required_field"));
    }
}