1use schemars::{
2    JsonSchema,
3    generate::SchemaSettings,
4    transform::{self, Transform},
5};
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::borrow::Cow;
9use thiserror::Error;
10
11#[cfg(feature = "pegboard")]
12mod pegboard;
13#[cfg(feature = "pegboard")]
14pub use pegboard::{prefix_tool_name, PegBoard, PegBoardError};
15
16#[cfg(feature = "pegboard")]
18pub use rmcp::model::CallToolResult;
19
/// Errors returned by [`get_tool`] and [`parse_input_schema`] when building a
/// tool's input schema.
#[derive(Debug, Error)]
pub enum ToolError {
    /// The generated schema could not be converted into a `serde_json::Value`.
    #[error("Failed to serialize JSON schema: {0}")]
    SchemaSerialization(#[from] serde_json::Error),
    /// The root schema has no `type` keyword.
    #[error("Schema type field is required")]
    MissingSchemaType,
    /// The root schema's `type` is not `"object"`; carries a textual rendering
    /// of the `type` that was found.
    #[error("Schema must be of type 'object', got: {0}")]
    InvalidSchemaType(String),
}
30
/// A tool definition: a name, an optional description, and the JSON schema
/// of the tool's input parameters.
#[derive(Debug, Deserialize, Clone, Default, PartialEq, Serialize)]
pub struct Tool {
    /// The tool's name.
    pub name: Cow<'static, str>,
    /// Optional human-readable description of the tool.
    pub description: Option<Cow<'static, str>>,
    /// JSON schema for the tool's input. When built via [`get_tool`], this is
    /// the value returned by [`parse_input_schema`] for the parameter type.
    pub input_schema: serde_json::Value,
}
39
/// Schema transform that strips `null` from nullable schemas.
///
/// Despite its name, this transform does not add a `nullable` keyword. When
/// `remove_null_type` is true it removes `"null"` from `type` arrays (collapsing
/// a single remaining type to a plain string) and `null` from `enum` arrays, so
/// an `Option<T>` field is described by `T`'s schema alone.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate must start from
/// `AddNullable::default()` and then set fields.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct AddNullable {
    /// Whether to strip `"null"` from `type` and `null` from `enum`.
    /// When false, the transform leaves every schema unchanged.
    pub remove_null_type: bool,
}
46
47impl Default for AddNullable {
48    fn default() -> Self {
49        Self {
50            remove_null_type: true,
51        }
52    }
53}
54
55impl AddNullable {
56    fn has_type(schema: &schemars::Schema, ty: &str) -> bool {
57        match schema.get("type") {
58            Some(Value::Array(values)) => values.iter().any(|v| v.as_str() == Some(ty)),
59            Some(Value::String(s)) => s == ty,
60            _ => false,
61        }
62    }
63}
64
65impl Transform for AddNullable {
66    fn transform(&mut self, schema: &mut schemars::Schema) {
67        if Self::has_type(schema, "null") {
68            if let Some(ty) = schema.get_mut("type")
70                && self.remove_null_type
71            {
72                if let Value::Array(array) = ty {
74                    array.retain(|t| t.as_str() != Some("null"));
75                    if array.len() == 1 {
76                        *ty = array[0].clone();
77                    }
78                }
79
80                if let Some(Value::Array(enum_array)) = schema.get_mut("enum") {
82                    enum_array.retain(|v| !v.is_null());
83                }
84            }
85        }
86
87        transform::transform_subschemas(self, schema);
88    }
89}
90
91pub fn get_tool<T: JsonSchema, S1, S2>(name: S1, desc: Option<S2>) -> Result<Tool, ToolError>
92where
93    S1: Into<Cow<'static, str>>,
94    S2: Into<Cow<'static, str>>,
95{
96    let json_value = parse_input_schema::<T>()?;
97    Ok(Tool {
98        name: name.into(),
99        description: desc.map(Into::into),
100        input_schema: json_value,
101    })
102}
103
/// Defines a lazily built, cached tool definition.
///
/// `define_tool!(NAME, "function_name", "description", ParamType)` expands to:
/// * a private `static NAME_ONCE_LOCK`;
/// * a public accessor `get_name()` (the identifier lower-cased, e.g.
///   `WEATHER_TOOL` -> `get_weather_tool`).
///
/// On the first call, the accessor builds the [`Tool`] via [`get_tool`]. Every call
/// returns a `'static` reference to the cached `Result`, so a schema-generation
/// error is also cached and returned on later calls.
///
/// All std/core names in the expansion are fully qualified. This keeps a
/// caller's local `Result` alias, `Some` shadowing, or `std` module from
/// breaking the generated code.
///
/// The expansion invokes the `paste` crate, so the invoking crate must be able
/// to resolve `::paste`. NOTE(review): consider re-exporting `paste` from this
/// crate so callers don't need their own dependency; confirm first whether it
/// is a regular (not dev-only) dependency here.
#[macro_export]
macro_rules! define_tool {
    ($tool_name:ident, $function_name:expr, $description:expr, $param_type:ty) => {
        ::paste::paste! {
            static [<$tool_name _ONCE_LOCK>]: ::std::sync::OnceLock<
                ::core::result::Result<$crate::Tool, $crate::ToolError>,
            > = ::std::sync::OnceLock::new();

            pub fn [<get_ $tool_name:lower>]()
                -> ::core::result::Result<&'static $crate::Tool, &'static $crate::ToolError>
            {
                [<$tool_name _ONCE_LOCK>]
                    .get_or_init(|| {
                        $crate::get_tool::<$param_type, _, _>(
                            $function_name,
                            ::core::option::Option::Some($description),
                        )
                    })
                    .as_ref()
            }
        }
    };
}
121
122pub fn parse_input_schema<T: JsonSchema>() -> Result<serde_json::Value, ToolError> {
123    let settings = SchemaSettings::draft2019_09()
124        .with(|s| {
125            s.inline_subschemas = true;
128        })
129        .with_transform(AddNullable::default());
130    let schema = settings.into_generator().into_root_schema_for::<T>();
131    let mut json_value = serde_json::to_value(schema)?;
132    let schema_type = json_value.get("type").ok_or(ToolError::MissingSchemaType)?;
133
134    match schema_type {
135        Value::String(s) if s == "object" => {
136            }
138        Value::String(s) => {
139            return Err(ToolError::InvalidSchemaType(s.clone()));
140        }
141        other => {
142            return Err(ToolError::InvalidSchemaType(format!("{:?}", other)));
143        }
144    }
145
146    if let Some(obj) = json_value.as_object_mut() {
147        obj.remove("$schema");
148        obj.remove("title");
149        obj.remove("definitions");
150    }
151    Ok(json_value)
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use schemars::JsonSchema;

    // Parameter type used to drive schema generation and deserialization below.
    // `arr` is never read directly, hence the `dead_code` allowance. Note that a
    // `///` comment on the struct itself would add a root "description".
    #[derive(JsonSchema, serde::Deserialize)]
    #[allow(dead_code)]
    pub struct WeatherParams {
        /// The city and state, e.g. San Francisco, CA
        pub location: String,
        pub unit: Option<UnitEnum>,
        pub arr: Option<Vec<String>>,
    }

    #[derive(JsonSchema, serde::Deserialize, PartialEq, Debug)]
    #[serde(rename_all = "lowercase")]
    pub enum UnitEnum {
        Celsius,
        Fahrenheit,
    }

    #[test]
    fn test_parse_input_schema() {
        let got = parse_input_schema::<WeatherParams>().unwrap();

        let want = serde_json::json!({
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA"
                },
                "unit": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"]
                },
                "arr": {
                    "type": "array",
                    "items": {
                        "type": "string"
                    }
                }
            },
            "required": ["location"]
        });
        // Compare `Value`s rather than serialized strings so the assertion does not
        // depend on map key ordering (e.g. serde_json's `preserve_order` feature).
        assert_eq!(want, got, "Expected: {want} Got: {got}");
    }

    #[test]
    fn test_parse_input_schema_rejects_non_object_root() {
        match parse_input_schema::<String>() {
            Err(ToolError::InvalidSchemaType(ty)) => assert_eq!(ty, "string"),
            other => panic!("expected InvalidSchemaType(\"string\"), got {other:?}"),
        }
    }

    #[test]
    fn test_get_tool() {
        let tool = get_tool::<WeatherParams, _, _>(
            "get_weather",
            Some("Get the current weather in a given location"),
        )
        .unwrap();

        assert_eq!(tool.name, "get_weather");
        assert_eq!(
            tool.description,
            Some(Cow::Borrowed("Get the current weather in a given location"))
        );

        let schema = &tool.input_schema;
        assert_eq!(schema.get("type").and_then(|v| v.as_str()), Some("object"));
        assert!(schema.get("properties").is_some());
        assert!(schema.get("required").is_some());
    }

    #[test]
    fn test_deserialize_params() {
        let s = r#"{"location":"Boston, MA","unit":"celsius"}"#;
        let got: WeatherParams = serde_json::from_str(s).unwrap();
        assert_eq!(
            "Boston, MA", got.location,
            "Expected: Boston, MA Got: {}",
            got.location
        );
        assert_eq!(
            Some(UnitEnum::Celsius),
            got.unit,
            "Expected: celsius got: {:?}",
            got.unit
        );
    }

    #[test]
    fn test_define_tool_macro() {
        define_tool!(
            WEATHER_TOOL,
            "get_weather",
            "Get the current weather",
            WeatherParams
        );

        let tool = get_weather_tool().unwrap();
        assert_eq!(tool.name, "get_weather");
        assert_eq!(
            tool.description,
            Some(Cow::Borrowed("Get the current weather"))
        );
        assert!(tool.input_schema.get("properties").is_some());
        assert!(tool.input_schema.get("required").is_some());
        // The accessor caches its result: repeated calls return the same instance.
        assert!(std::ptr::eq(tool, get_weather_tool().unwrap()));
    }

    #[test]
    fn test_nullable_handling() {
        #[derive(JsonSchema)]
        #[allow(dead_code)]
        struct TestNullable {
            required_field: String,
            optional_field: Option<String>,
        }

        let schema = parse_input_schema::<TestNullable>().unwrap();

        let props = schema.get("properties").unwrap().as_object().unwrap();
        let optional = props.get("optional_field").unwrap();

        // `Option<String>` should come out as plain `"string"` once AddNullable has
        // stripped the `"null"` alternative; assert it unconditionally.
        assert_eq!(optional.get("type"), Some(&serde_json::json!("string")));

        let required = schema.get("required").unwrap().as_array().unwrap();
        assert_eq!(required.len(), 1);
        assert_eq!(required[0].as_str(), Some("required_field"));
    }
}