1use std::any::{Any, TypeId, type_name};
16use std::collections::HashMap;
17use std::sync::{Arc, Mutex};
18
19use rmcp::model::JsonObject;
20use schemars::JsonSchema;
21use schemars::Schema;
22use schemars::generate::SchemaSettings;
23use serde_json::Value;
24
25thread_local! {
26 static SCHEMA_CACHE: Mutex<HashMap<(TypeId, bool), Arc<JsonObject>>> = Mutex::new(HashMap::new());
27}
28
29#[must_use]
44pub fn input_schema<T: JsonSchema + Any>(pinned: bool) -> Arc<JsonObject> {
45 SCHEMA_CACHE.with(|cache| {
46 cache
47 .lock()
48 .expect("schema cache poisoned")
49 .entry((TypeId::of::<T>(), pinned))
50 .or_insert_with(|| Arc::new(build::<T>(pinned)))
51 .clone()
52 })
53}
54
55#[must_use]
67pub fn output_schema<T: JsonSchema + Any>() -> Arc<JsonObject> {
68 let schema = input_schema::<T>(false);
69 match schema.get("type") {
70 Some(Value::String(t)) if t == "object" => schema,
71 other => panic!(
72 "Invalid output schema for type `{}`: root `type` must be \"object\", got {:?}",
73 type_name::<T>(),
74 other,
75 ),
76 }
77}
78
79fn build<T: JsonSchema>(pinned: bool) -> JsonObject {
88 let value = SchemaSettings::draft2020_12()
89 .with(|s| {
90 s.inline_subschemas = true;
91 s.meta_schema = None;
92 })
93 .with_transform(strip_root_metadata)
94 .with_transform(move |schema: &mut Schema| {
95 if pinned {
96 strip_root_database(schema);
97 }
98 })
99 .into_generator()
100 .into_root_schema_for::<T>()
101 .to_value();
102
103 let Value::Object(object) = value else {
104 panic!("schema for `{}` did not produce a JSON object", type_name::<T>());
105 };
106 object
107}
108
109fn strip_root_metadata(schema: &mut Schema) {
117 if let Some(object) = schema.as_object_mut() {
118 object.remove("title");
119 object.remove("description");
120 }
121}
122
123fn strip_root_database(schema: &mut Schema) {
129 let Some(object) = schema.as_object_mut() else {
130 return;
131 };
132 if let Some(Value::Object(properties)) = object.get_mut("properties") {
133 properties.remove("database");
134 }
135 if let Some(Value::Array(required)) = object.get_mut("required") {
136 required.retain(|value| value.as_str() != Some("database"));
137 if required.is_empty() {
138 object.remove("required");
139 }
140 }
141}
142
#[cfg(test)]
mod tests {
    use super::{build, input_schema, output_schema};
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use serde_json::Value;
    use std::sync::Arc;

    // Root fixture. The explicit title/description exist so the tests can
    // prove they are stripped from the generated root schema.
    #[derive(Deserialize, Serialize, JsonSchema)]
    #[schemars(title = "FixtureTitle", description = "Fixture root description.")]
    struct Fixture {
        // The doc comment below is load-bearing: schemars turns it into the
        // property's `description`, which `build_preserves_properties`
        // asserts verbatim. Do not reword or remove it.
        /// Doc-comment kept on the property — must survive schema generation.
        name: String,
        nested: Nested,
    }

    // Nested struct used to check that subschemas are inlined.
    #[derive(Deserialize, Serialize, JsonSchema)]
    struct Nested {
        value: u32,
    }

    // A second, unrelated type used to check cache separation by type.
    #[derive(Deserialize, Serialize, JsonSchema)]
    struct OtherFixture {
        value: u32,
    }

    // Returns true when `key` appears as an object key anywhere inside `value`.
    fn contains_key(value: &Value, key: &str) -> bool {
        match value {
            Value::Object(map) => {
                map.contains_key(key) || map.values().any(|v| contains_key(v, key))
            }
            Value::Array(items) => items.iter().any(|v| contains_key(v, key)),
            _ => false,
        }
    }

    // Fixture with a `database` argument, used for the pinned-schema tests.
    #[derive(Deserialize, Serialize, JsonSchema)]
    struct PinnedFixture {
        query: String,
        #[serde(default)]
        database: Option<String>,
    }

    #[test]
    fn input_schema_strips_dollar_schema_title_and_description() {
        let schema = input_schema::<Fixture>(false);
        assert!(!schema.contains_key("$schema"), "root $schema not stripped: {schema:?}");
        assert!(!schema.contains_key("title"), "root title not stripped: {schema:?}");
        assert!(
            !schema.contains_key("description"),
            "root description not stripped: {schema:?}"
        );
        assert_eq!(schema.get("type"), Some(&Value::String("object".into())));
    }

    #[test]
    fn input_schema_inlines_nested_subschemas() {
        let schema = input_schema::<Fixture>(false);
        let value = Value::Object((*schema).clone());
        assert!(!contains_key(&value, "$defs"), "$defs not inlined: {value}");
        assert!(!contains_key(&value, "$ref"), "$ref not inlined: {value}");
    }

    #[test]
    fn input_schema_caches_by_type_and_pinned() {
        let first = input_schema::<Fixture>(false);
        let second = input_schema::<Fixture>(false);
        assert!(
            Arc::ptr_eq(&first, &second),
            "same (type, pinned) should return cached Arc"
        );
        let other = input_schema::<OtherFixture>(false);
        assert!(
            !Arc::ptr_eq(&first, &other),
            "different types must not share cache entry"
        );
        let pinned = input_schema::<PinnedFixture>(true);
        let unpinned = input_schema::<PinnedFixture>(false);
        assert!(
            !Arc::ptr_eq(&pinned, &unpinned),
            "same type with different pinned flags must not share cache entry"
        );
    }

    #[test]
    fn output_schema_accepts_object_root() {
        let schema = output_schema::<Fixture>();
        assert_eq!(schema.get("type"), Some(&Value::String("object".into())));
        let again = output_schema::<Fixture>();
        assert!(Arc::ptr_eq(&schema, &again));
    }

    #[test]
    #[should_panic(expected = "root `type` must be \"object\"")]
    fn output_schema_panics_on_non_object_root() {
        let _ = output_schema::<u32>();
    }

    #[test]
    fn build_preserves_properties() {
        let schema = build::<Fixture>(false);
        let properties = schema
            .get("properties")
            .and_then(Value::as_object)
            .expect("properties survive generation");
        assert!(properties.contains_key("name"));
        assert!(properties.contains_key("nested"));
        let name = properties.get("name").and_then(Value::as_object).unwrap();
        assert_eq!(
            name.get("description").and_then(Value::as_str),
            Some("Doc-comment kept on the property — must survive schema generation."),
            "per-property descriptions must survive schema generation"
        );
    }

    #[test]
    fn pinned_input_schema_strips_database_property_and_required() {
        let unpinned = input_schema::<PinnedFixture>(false);
        let unpinned_props = unpinned.get("properties").and_then(Value::as_object).unwrap();
        assert!(unpinned_props.contains_key("database"));
        assert!(unpinned_props.contains_key("query"));

        let pinned = input_schema::<PinnedFixture>(true);
        let pinned_props = pinned.get("properties").and_then(Value::as_object).unwrap();
        assert!(
            !pinned_props.contains_key("database"),
            "pinned schema must drop `database`: {pinned:?}"
        );
        assert!(
            pinned_props.contains_key("query"),
            "pinned schema preserves other props"
        );

        // `required` may be absent altogether; when present it must not
        // mention `database`.
        let required = pinned.get("required").and_then(Value::as_array);
        if let Some(required) = required {
            assert!(
                !required.iter().any(|v| v.as_str() == Some("database")),
                "pinned schema must drop `database` from required: {required:?}"
            );
        }
    }
}