1use schemars::{
2 JsonSchema,
3 generate::SchemaSettings,
4 transform::{self, Transform},
5};
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::borrow::Cow;
9use thiserror::Error;
10
11#[cfg(feature = "pegboard")]
12mod pegboard;
13#[cfg(feature = "pegboard")]
14pub use pegboard::{prefix_tool_name, PegBoard, PegBoardError};
15
16#[cfg(feature = "pegboard")]
18pub use rmcp::model::CallToolResult;
19
20#[derive(Debug, Error)]
22pub enum ToolError {
23 #[error("Failed to serialize JSON schema: {0}")]
24 SchemaSerialization(#[from] serde_json::Error),
25 #[error("Schema type field is required")]
26 MissingSchemaType,
27 #[error("Schema must be of type 'object', got: {0}")]
28 InvalidSchemaType(String),
29}
30
31#[derive(Debug, Deserialize, Clone, Default, PartialEq, Serialize)]
34pub struct Tool {
35 pub name: Cow<'static, str>,
36 pub description: Option<Cow<'static, str>>,
37 pub input_schema: serde_json::Value,
38}
39
/// Schema [`Transform`] that removes `"null"` from a schema's `type` and `null` from its `enum`
/// values, then recurses into subschemas.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct AddNullable {
    /// When `true`, `"null"` entries are stripped; when `false` the transform only recurses.
    pub remove_null_type: bool,
}
46
47impl Default for AddNullable {
48 fn default() -> Self {
49 Self {
50 remove_null_type: true,
51 }
52 }
53}
54
55impl AddNullable {
56 fn has_type(schema: &schemars::Schema, ty: &str) -> bool {
57 match schema.get("type") {
58 Some(Value::Array(values)) => values.iter().any(|v| v.as_str() == Some(ty)),
59 Some(Value::String(s)) => s == ty,
60 _ => false,
61 }
62 }
63}
64
65impl Transform for AddNullable {
66 fn transform(&mut self, schema: &mut schemars::Schema) {
67 if Self::has_type(schema, "null") {
68 if let Some(ty) = schema.get_mut("type")
70 && self.remove_null_type
71 {
72 if let Value::Array(array) = ty {
74 array.retain(|t| t.as_str() != Some("null"));
75 if array.len() == 1 {
76 *ty = array[0].clone();
77 }
78 }
79
80 if let Some(Value::Array(enum_array)) = schema.get_mut("enum") {
82 enum_array.retain(|v| !v.is_null());
83 }
84 }
85 }
86
87 transform::transform_subschemas(self, schema);
88 }
89}
90
91pub fn get_tool<T: JsonSchema, S1, S2>(name: S1, desc: Option<S2>) -> Result<Tool, ToolError>
92where
93 S1: Into<Cow<'static, str>>,
94 S2: Into<Cow<'static, str>>,
95{
96 let json_value = parse_input_schema::<T>()?;
97 Ok(Tool {
98 name: name.into(),
99 description: desc.map(Into::into),
100 input_schema: json_value,
101 })
102}
103
/// Defines a lazily built, process-wide cached [`Tool`] for parameter type `$param_type`.
///
/// Expands to:
/// - a `static <$tool_name>_ONCE_LOCK` holding the result of
///   `get_tool::<$param_type, _, _>($function_name, Some($description))`;
/// - a `pub fn get_<$tool_name lowercased>()` that initialises that static on first call and
///   returns a `'static` reference to the cached `Ok`/`Err` value.
///
/// Every path in the expansion is fully qualified (`::std::...`, `$crate::...`). Caller-side
/// items such as `type Result<T> = ...` or a local `std` module therefore cannot break it.
///
/// The expansion invokes `::paste::paste!`, so the invoking crate must depend on `paste`.
#[macro_export]
macro_rules! define_tool {
    ($tool_name:ident, $function_name:expr, $description:expr, $param_type:ty) => {
        ::paste::paste! {
            static [<$tool_name _ONCE_LOCK>]: ::std::sync::OnceLock<
                ::std::result::Result<$crate::Tool, $crate::ToolError>,
            > = ::std::sync::OnceLock::new();

            pub fn [<get_ $tool_name:lower>]()
                -> ::std::result::Result<&'static $crate::Tool, &'static $crate::ToolError>
            {
                [<$tool_name _ONCE_LOCK>]
                    .get_or_init(|| {
                        $crate::get_tool::<$param_type, _, _>(
                            $function_name,
                            ::std::option::Option::Some($description),
                        )
                    })
                    .as_ref()
            }
        }
    };
}
121
122pub fn parse_input_schema<T: JsonSchema>() -> Result<serde_json::Value, ToolError> {
123 let settings = SchemaSettings::draft2019_09()
124 .with(|s| {
125 s.inline_subschemas = true;
128 })
129 .with_transform(AddNullable::default());
130 let schema = settings.into_generator().into_root_schema_for::<T>();
131 let mut json_value = serde_json::to_value(schema)?;
132 let schema_type = json_value.get("type").ok_or(ToolError::MissingSchemaType)?;
133
134 match schema_type {
135 Value::String(s) if s == "object" => {
136 }
138 Value::String(s) => {
139 return Err(ToolError::InvalidSchemaType(s.clone()));
140 }
141 other => {
142 return Err(ToolError::InvalidSchemaType(format!("{:?}", other)));
143 }
144 }
145
146 if let Some(obj) = json_value.as_object_mut() {
147 obj.remove("$schema");
148 obj.remove("title");
149 obj.remove("definitions");
150 }
151 Ok(json_value)
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use schemars::JsonSchema;

    #[derive(JsonSchema, serde::Deserialize)]
    #[allow(dead_code)] // the type exists to shape the generated schema
    pub struct WeatherParams {
        /// The city and state, e.g. San Francisco, CA
        pub location: String,
        pub unit: Option<UnitEnum>,
        pub arr: Option<Vec<String>>,
    }

    #[derive(JsonSchema, serde::Deserialize, PartialEq, Debug)]
    #[serde(rename_all = "lowercase")]
    pub enum UnitEnum {
        Celsius,
        Fahrenheit,
    }

    #[test]
    fn test_parse_input_schema() {
        let got = parse_input_schema::<WeatherParams>().unwrap();

        let want = serde_json::json!({
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA"
                },
                "unit": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"]
                },
                "arr": {
                    "type": "array",
                    "items": {
                        "type": "string"
                    }
                }
            },
            "required": ["location"]
        });
        // Compare `Value`s rather than serialized strings. Object equality ignores key order, so
        // the result does not depend on serde_json's `preserve_order` feature.
        assert_eq!(want, got, "Expected: {} Got: {}", want, got);
    }

    #[test]
    fn test_get_tool() {
        let tool = get_tool::<WeatherParams, _, _>(
            "get_weather",
            Some("Get the current weather in a given location"),
        )
        .unwrap();

        assert_eq!(tool.name, "get_weather");
        assert_eq!(
            tool.description,
            Some(Cow::Borrowed("Get the current weather in a given location"))
        );

        let schema = &tool.input_schema;
        assert_eq!(schema.get("type").and_then(|v| v.as_str()), Some("object"));
        assert!(schema.get("properties").is_some());
        assert!(schema.get("required").is_some());
    }

    #[test]
    fn test_deserialize_params() {
        let s = r#"{"location":"Boston, MA","unit":"celsius"}"#;
        let got: WeatherParams = serde_json::from_str(s).unwrap();
        assert_eq!(
            "Boston, MA", got.location,
            "Expected: Boston, MA Got: {}",
            got.location
        );
        assert_eq!(
            Some(UnitEnum::Celsius),
            got.unit,
            "Expected: celsius got: {:?}",
            got.unit
        );
        assert!(got.arr.is_none(), "absent optional field should deserialize to None");
    }

    #[test]
    fn test_define_tool_macro() {
        define_tool!(
            WEATHER_TOOL,
            "get_weather",
            "Get the current weather",
            WeatherParams
        );

        let tool = get_weather_tool().unwrap();
        assert_eq!(tool.name, "get_weather");
        assert_eq!(
            tool.description,
            Some(Cow::Borrowed("Get the current weather"))
        );
        assert!(tool.input_schema.get("properties").is_some());
        assert!(tool.input_schema.get("required").is_some());
    }

    #[test]
    fn test_nullable_handling() {
        #[derive(JsonSchema)]
        #[allow(dead_code)] // never constructed; only its schema is inspected
        struct TestNullable {
            required_field: String,
            optional_field: Option<String>,
        }

        let schema = parse_input_schema::<TestNullable>().unwrap();

        let props = schema.get("properties").unwrap().as_object().unwrap();
        let optional = props.get("optional_field").unwrap();

        // The null type must be stripped and the remaining single type collapsed to a string.
        // A `["string", "null"]` array would slip past a mere `!= "null"` check.
        assert_eq!(optional.get("type"), Some(&serde_json::json!("string")));

        let required = schema.get("required").unwrap().as_array().unwrap();
        assert_eq!(required.len(), 1);
        assert_eq!(required[0].as_str(), Some("required_field"));
    }
}