1use schemars::Schema;
4use serde_json::Value as Json;
5use std::collections::HashMap;
6
7#[derive(Clone, Debug)]
9pub enum FieldConstraint {
10 Enum(Vec<Json>),
12
13 Range {
15 minimum: Option<Json>,
16 maximum: Option<Json>,
17 },
18
19 Pattern(String),
21
22 MergePatch(Json),
24}
25
26pub trait SchemaTransform: Send + Sync {
28 fn apply(&self, tool: &str, schema: &mut Json);
30}
31
32#[derive(Default)]
41pub struct SchemaEngine {
42 per_tool: HashMap<String, Vec<(Vec<String>, FieldConstraint)>>,
43 global_strict: bool,
44 custom_transforms: Vec<Box<dyn SchemaTransform>>,
45}
46
47impl Clone for SchemaEngine {
48 fn clone(&self) -> Self {
49 Self {
51 per_tool: self.per_tool.clone(),
52 global_strict: self.global_strict,
53 custom_transforms: Vec::new(), }
55 }
56}
57
58impl std::fmt::Debug for SchemaEngine {
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 f.debug_struct("SchemaEngine")
61 .field("per_tool", &self.per_tool)
62 .field("global_strict", &self.global_strict)
63 .field(
64 "custom_transforms",
65 &format!("[{} transforms]", self.custom_transforms.len()),
66 )
67 .finish()
68 }
69}
70
71impl SchemaEngine {
72 pub fn new() -> Self {
74 Self::default()
75 }
76
77 pub fn with_strict(mut self, strict: bool) -> Self {
79 self.global_strict = strict;
80 self
81 }
82
83 pub fn is_strict(&self) -> bool {
85 self.global_strict
86 }
87
88 pub fn constrain_field(&mut self, tool: &str, json_path: Vec<String>, c: FieldConstraint) {
93 self.per_tool
94 .entry(tool.to_string())
95 .or_default()
96 .push((json_path, c));
97 }
98
99 pub fn add_transform<T: SchemaTransform + 'static>(&mut self, transform: T) {
101 self.custom_transforms.push(Box::new(transform));
102 }
103
104 pub fn transform(&self, tool: &str, schema: Schema) -> Schema {
106 let mut v = serde_json::to_value(&schema).expect("serialize schema");
107
108 if self.global_strict
110 && let Some(obj) = v.as_object_mut()
111 {
112 obj.insert("additionalProperties".to_string(), Json::Bool(false));
113 }
114
115 if let Some(entries) = self.per_tool.get(tool) {
117 for (path, constraint) in entries {
118 Self::apply_constraint(&mut v, path, constraint);
119 }
120 }
121
122 for transform in &self.custom_transforms {
124 transform.apply(tool, &mut v);
125 }
126
127 Schema::try_from(v).expect("schema transform must produce a valid schema")
132 }
133
134 fn apply_constraint(root: &mut Json, path: &[String], constraint: &FieldConstraint) {
135 let Some(node) = Self::find_node_mut(root, path) else {
136 return;
137 };
138 let Some(obj) = node.as_object_mut() else {
139 return;
140 };
141 match constraint {
142 FieldConstraint::Enum(vals) => {
143 obj.insert("enum".into(), Json::Array(vals.clone()));
144 }
145 FieldConstraint::Range { minimum, maximum } => {
146 if let Some(m) = minimum {
147 obj.insert("minimum".into(), m.clone());
148 }
149 if let Some(m) = maximum {
150 obj.insert("maximum".into(), m.clone());
151 }
152 }
153 FieldConstraint::Pattern(p) => {
154 obj.insert("pattern".into(), Json::String(p.clone()));
155 }
156 FieldConstraint::MergePatch(patch) => {
157 json_patch::merge(node, patch);
158 }
159 }
160 }
161
162 fn find_node_mut<'a>(root: &'a mut Json, path: &[String]) -> Option<&'a mut Json> {
163 let mut cur = root;
164 for seg in path {
165 cur = cur.as_object_mut()?.get_mut(seg)?;
166 }
167 Some(cur)
168 }
169}
170
171pub mod mcp_schema {
183 use schemars::JsonSchema;
184 use schemars::Schema;
185 use schemars::generate::SchemaSettings;
186 use schemars::transform::RestrictFormats;
187 use std::any::TypeId;
188 use std::cell::RefCell;
189 use std::collections::HashMap;
190 use std::sync::Arc;
191
192 thread_local! {
193 static CACHE_FOR_TYPE: RefCell<HashMap<TypeId, Arc<Schema>>> = RefCell::new(HashMap::new());
194 static CACHE_FOR_OUTPUT: RefCell<HashMap<TypeId, Result<Arc<Schema>, String>>> = RefCell::new(HashMap::new());
195 }
196
197 fn settings() -> SchemaSettings {
198 SchemaSettings::draft2020_12().with_transform(RestrictFormats::default())
199 }
200
201 pub fn cached_schema_for<T: JsonSchema + 'static>() -> Arc<Schema> {
203 CACHE_FOR_TYPE.with(|cache| {
204 let mut cache = cache.borrow_mut();
205 if let Some(x) = cache.get(&TypeId::of::<T>()) {
206 return x.clone();
207 }
208 let generator = settings().into_generator();
209 let root = generator.into_root_schema_for::<T>();
210 let arc = Arc::new(root);
211 cache.insert(TypeId::of::<T>(), arc.clone());
212 arc
213 })
214 }
215
216 pub fn cached_output_schema_for<T: JsonSchema + 'static>() -> Result<Arc<Schema>, String> {
219 CACHE_FOR_OUTPUT.with(|cache| {
220 let mut cache = cache.borrow_mut();
221 if let Some(r) = cache.get(&TypeId::of::<T>()) {
222 return r.clone();
223 }
224 let root = cached_schema_for::<T>();
225 let json = serde_json::to_value(root.as_ref()).expect("serialize output schema");
226 let result = match json.get("type") {
227 Some(serde_json::Value::String(t)) if t == "object" => Ok(root.clone()),
228 Some(serde_json::Value::String(t)) => Err(format!(
229 "MCP requires output_schema root type 'object', found '{}'",
230 t
231 )),
232 None => {
233 if json.get("properties").is_some() {
236 Ok(root.clone())
237 } else {
238 Err(
239 "Schema missing 'type' — output_schema must have root type 'object'"
240 .to_string(),
241 )
242 }
243 }
244 Some(other) => Err(format!(
245 "Unexpected 'type' format: {:?} — expected string 'object'",
246 other
247 )),
248 };
249 cache.insert(TypeId::of::<T>(), result.clone());
250 result
251 })
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Serialize;

    #[derive(schemars::JsonSchema, Serialize)]
    struct TestInput {
        count: i32,
        name: String,
    }

    /// A small object schema with a single string property `name`.
    fn name_only_schema() -> Schema {
        Schema::try_from(serde_json::json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            }
        }))
        .expect("a JSON object is always a valid schema")
    }

    /// Serialize a schema to JSON for assertions.
    fn to_json(schema: &Schema) -> Json {
        serde_json::to_value(schema).expect("schemas always serialize")
    }

    /// Constrain `properties.name` of tool `tool` with `c`.
    fn engine_with_name_constraint(tool: &str, c: FieldConstraint) -> SchemaEngine {
        let mut engine = SchemaEngine::new();
        engine.constrain_field(tool, vec!["properties".into(), "name".into()], c);
        engine
    }

    #[test]
    fn test_strict_mode() {
        let engine = SchemaEngine::new().with_strict(true);
        let transformed = engine.transform("test", schemars::schema_for!(TestInput));
        assert_eq!(
            to_json(&transformed).get("additionalProperties"),
            Some(&Json::Bool(false))
        );
    }

    #[test]
    fn test_non_strict_mode_leaves_root_open() {
        let json = to_json(&SchemaEngine::new().transform("test", name_only_schema()));
        assert!(json.get("additionalProperties").is_none());
    }

    #[test]
    fn test_is_strict_getter() {
        assert!(!SchemaEngine::new().is_strict());
        assert!(SchemaEngine::new().with_strict(true).is_strict());
    }

    #[test]
    fn test_enum_constraint() {
        let engine = engine_with_name_constraint(
            "test",
            FieldConstraint::Enum(vec![Json::String("a".into()), Json::String("b".into())]),
        );

        let json = to_json(&engine.transform("test", name_only_schema()));
        assert_eq!(
            json["properties"]["name"]["enum"],
            serde_json::json!(["a", "b"])
        );
        // Existing keywords on the node are preserved.
        assert_eq!(json["properties"]["name"]["type"], "string");
    }

    #[test]
    fn test_range_constraint() {
        let mut engine = SchemaEngine::new();
        engine.constrain_field(
            "test",
            vec!["properties".into(), "count".into()],
            FieldConstraint::Range {
                minimum: Some(Json::Number(0.into())),
                maximum: Some(Json::Number(100.into())),
            },
        );

        let json = to_json(&engine.transform("test", schemars::schema_for!(TestInput)));
        let count_schema = &json["properties"]["count"];

        let min = count_schema.get("minimum").and_then(|v| v.as_f64());
        let max = count_schema.get("maximum").and_then(|v| v.as_f64());

        assert_eq!(min, Some(0.0), "minimum constraint should be applied");
        assert_eq!(max, Some(100.0), "maximum constraint should be applied");
    }

    #[test]
    fn test_pattern_constraint() {
        let engine =
            engine_with_name_constraint("test", FieldConstraint::Pattern("^[a-z]+$".into()));
        let json = to_json(&engine.transform("test", name_only_schema()));
        assert_eq!(json["properties"]["name"]["pattern"], "^[a-z]+$");
    }

    #[test]
    fn test_merge_patch_constraint() {
        // A `null` member removes the key; other members are added or overwritten.
        let engine = engine_with_name_constraint(
            "test",
            FieldConstraint::MergePatch(serde_json::json!({
                "description": "Display name",
                "type": null
            })),
        );
        let json = to_json(&engine.transform("test", name_only_schema()));
        assert_eq!(
            json["properties"]["name"],
            serde_json::json!({ "description": "Display name" })
        );
    }

    #[test]
    fn test_constraints_only_apply_to_their_tool() {
        let engine =
            engine_with_name_constraint("other", FieldConstraint::Pattern("^x$".into()));
        let transformed = engine.transform("test", name_only_schema());
        assert_eq!(to_json(&transformed), to_json(&name_only_schema()));
    }

    #[test]
    fn test_unresolvable_path_is_ignored() {
        let mut engine = SchemaEngine::new();
        engine.constrain_field(
            "test",
            vec!["properties".into(), "missing".into()],
            FieldConstraint::Enum(vec![Json::String("a".into())]),
        );
        let transformed = engine.transform("test", name_only_schema());
        assert_eq!(to_json(&transformed), to_json(&name_only_schema()));
    }

    mod mcp_schema_tests {
        use super::mcp_schema;
        use serde::Serialize;

        #[derive(schemars::JsonSchema, Serialize)]
        struct WithOption {
            a: Option<String>,
        }

        #[test]
        fn test_option_generates_type_array() {
            let root = mcp_schema::cached_schema_for::<WithOption>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            let ty = v["properties"]["a"]
                .get("type")
                .and_then(|v| v.as_array())
                .expect("Option<T> should emit a type array");
            assert!(ty.contains(&serde_json::json!("string")));
            assert!(ty.contains(&serde_json::json!("null")));
            assert_eq!(ty.len(), 2, "Option<T> should contain only string|null");
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct OutputObj {
            x: i32,
        }

        #[test]
        fn test_output_schema_validation_object() {
            let ok = mcp_schema::cached_output_schema_for::<OutputObj>();
            assert!(
                ok.is_ok(),
                "Object types should pass output schema validation"
            );
        }

        #[test]
        fn test_output_schema_validation_non_object() {
            let bad = mcp_schema::cached_output_schema_for::<String>();
            assert!(
                bad.is_err(),
                "Non-object types should fail output schema validation"
            );
        }

        #[test]
        fn test_draft_2020_12_uses_defs() {
            let root = mcp_schema::cached_schema_for::<WithOption>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            assert!(v.is_object(), "Schema should be an object");
            assert!(
                v.get("$schema")
                    .and_then(|s| s.as_str())
                    .is_some_and(|s| s.contains("2020-12")),
                "Schema should reference Draft 2020-12"
            );
        }

        #[test]
        fn test_caching_returns_same_arc() {
            let first = mcp_schema::cached_schema_for::<OutputObj>();
            let second = mcp_schema::cached_schema_for::<OutputObj>();
            assert!(
                std::sync::Arc::ptr_eq(&first, &second),
                "Cached schemas should return the same Arc"
            );
        }

        // The variants are never constructed; the enum exists only for schema generation.
        #[allow(dead_code)]
        #[derive(schemars::JsonSchema, Serialize)]
        enum TestEnum {
            A,
            B,
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct HasOptEnum {
            e: Option<TestEnum>,
        }

        #[test]
        fn test_option_enum_anyof_null_branch_has_type() {
            let root = mcp_schema::cached_schema_for::<HasOptEnum>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            let any_of = v["properties"]["e"]["anyOf"]
                .as_array()
                .expect("Option<Enum> should generate anyOf");

            assert!(
                any_of
                    .iter()
                    .any(|b| b.get("type") == Some(&serde_json::json!("null"))),
                "anyOf for Option<Enum> must include a branch with type:\"null\""
            );

            for branch in any_of {
                let has_nullable = branch.get("nullable") == Some(&serde_json::json!(true));
                let has_type = branch.get("type").is_some() || branch.get("$ref").is_some();
                assert!(
                    !has_nullable || has_type,
                    "No branch may contain nullable:true without a type or $ref"
                );
            }
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct Unsigneds {
            a: u32,
            b: u64,
        }

        #[test]
        fn test_strip_uint_formats() {
            let root = mcp_schema::cached_schema_for::<Unsigneds>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            let pa = &v["properties"]["a"];
            let pb = &v["properties"]["b"];

            assert!(
                pa.get("format").is_none(),
                "u32 should not include non-standard 'format'"
            );
            assert!(
                pb.get("format").is_none(),
                "u64 should not include non-standard 'format'"
            );
            assert_eq!(
                pa.get("minimum").and_then(|x| x.as_u64()),
                Some(0),
                "u32 minimum must be preserved"
            );
            assert_eq!(
                pb.get("minimum").and_then(|x| x.as_u64()),
                Some(0),
                "u64 minimum must be preserved"
            );
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct HasOptString {
            s: Option<String>,
        }

        #[test]
        fn test_option_string_uses_type_array() {
            let root = mcp_schema::cached_schema_for::<HasOptString>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            let s = &v["properties"]["s"];

            let ty = s
                .get("type")
                .and_then(|v| v.as_array())
                .expect("Option<String> should emit a type array");
            assert!(ty.contains(&serde_json::json!("string")));
            assert!(ty.contains(&serde_json::json!("null")));
            assert_eq!(
                ty.len(),
                2,
                "Option<String> should contain only string|null"
            );
            assert!(
                s.get("nullable").is_none(),
                "Option<String> should not have nullable keyword"
            );
        }
    }
}