1use schemars::Schema;
4use serde_json::Value as Json;
5use std::collections::HashMap;
6
7#[derive(Clone, Debug)]
9pub enum FieldConstraint {
10 Enum(Vec<Json>),
12
13 Range {
15 minimum: Option<Json>,
16 maximum: Option<Json>,
17 },
18
19 Pattern(String),
21
22 MergePatch(Json),
24}
25
26pub trait SchemaTransform: Send + Sync {
28 fn apply(&self, tool: &str, schema: &mut Json);
30}
31
32#[derive(Default)]
41pub struct SchemaEngine {
42 per_tool: HashMap<String, Vec<(Vec<String>, FieldConstraint)>>,
43 global_strict: bool,
44 custom_transforms: Vec<Box<dyn SchemaTransform>>,
45}
46
47impl Clone for SchemaEngine {
48 fn clone(&self) -> Self {
49 Self {
51 per_tool: self.per_tool.clone(),
52 global_strict: self.global_strict,
53 custom_transforms: Vec::new(), }
55 }
56}
57
58impl std::fmt::Debug for SchemaEngine {
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 f.debug_struct("SchemaEngine")
61 .field("per_tool", &self.per_tool)
62 .field("global_strict", &self.global_strict)
63 .field(
64 "custom_transforms",
65 &format!("[{} transforms]", self.custom_transforms.len()),
66 )
67 .finish()
68 }
69}
70
71impl SchemaEngine {
72 pub fn new() -> Self {
74 Self::default()
75 }
76
77 pub fn with_strict(mut self, strict: bool) -> Self {
79 self.global_strict = strict;
80 self
81 }
82
83 pub fn is_strict(&self) -> bool {
85 self.global_strict
86 }
87
88 pub fn constrain_field(&mut self, tool: &str, json_path: Vec<String>, c: FieldConstraint) {
93 self.per_tool
94 .entry(tool.to_string())
95 .or_default()
96 .push((json_path, c));
97 }
98
99 pub fn add_transform<T: SchemaTransform + 'static>(&mut self, transform: T) {
101 self.custom_transforms.push(Box::new(transform));
102 }
103
104 pub fn transform(&self, tool: &str, schema: Schema) -> Schema {
106 let mut v = serde_json::to_value(&schema).expect("serialize schema");
107
108 if self.global_strict
110 && let Some(obj) = v.as_object_mut()
111 {
112 obj.insert("additionalProperties".to_string(), Json::Bool(false));
113 }
114
115 if let Some(entries) = self.per_tool.get(tool) {
117 for (path, constraint) in entries {
118 Self::apply_constraint(&mut v, path, constraint);
119 }
120 }
121
122 for transform in &self.custom_transforms {
124 transform.apply(tool, &mut v);
125 }
126
127 Schema::try_from(v).expect("schema transform must produce a valid schema")
132 }
133
134 fn apply_constraint(root: &mut Json, path: &[String], constraint: &FieldConstraint) {
135 let Some(node) = Self::find_node_mut(root, path) else {
136 return;
137 };
138 let Some(obj) = node.as_object_mut() else {
139 return;
140 };
141 match constraint {
142 FieldConstraint::Enum(vals) => {
143 obj.insert("enum".into(), Json::Array(vals.clone()));
144 }
145 FieldConstraint::Range { minimum, maximum } => {
146 if let Some(m) = minimum {
147 obj.insert("minimum".into(), m.clone());
148 }
149 if let Some(m) = maximum {
150 obj.insert("maximum".into(), m.clone());
151 }
152 }
153 FieldConstraint::Pattern(p) => {
154 obj.insert("pattern".into(), Json::String(p.clone()));
155 }
156 FieldConstraint::MergePatch(patch) => {
157 json_patch::merge(node, patch);
158 }
159 }
160 }
161
162 fn find_node_mut<'a>(root: &'a mut Json, path: &[String]) -> Option<&'a mut Json> {
163 let mut cur = root;
164 for seg in path {
165 cur = cur.as_object_mut()?.get_mut(seg)?;
166 }
167 Some(cur)
168 }
169}
170
171pub mod mcp_schema {
182 use schemars::JsonSchema;
183 use schemars::Schema;
184 use schemars::generate::SchemaSettings;
185 use schemars::transform::AddNullable;
186 use schemars::transform::RestrictFormats;
187 use schemars::transform::Transform;
188 use std::any::TypeId;
189 use std::cell::RefCell;
190 use std::collections::HashMap;
191 use std::sync::Arc;
192
193 thread_local! {
194 static CACHE_FOR_TYPE: RefCell<HashMap<TypeId, Arc<Schema>>> = RefCell::new(HashMap::new());
195 static CACHE_FOR_OUTPUT: RefCell<HashMap<TypeId, Result<Arc<Schema>, String>>> = RefCell::new(HashMap::new());
196 }
197
198 #[derive(Clone, Copy, Default)]
201 struct SanitizeNullBranches;
202
203 impl Transform for SanitizeNullBranches {
204 fn transform(&mut self, schema: &mut Schema) {
205 let mut v = serde_json::to_value(&*schema).expect("serialize schema for sanitize");
207 sanitize_null_branches_recursive(&mut v);
208 *schema = Schema::try_from(v).expect("rebuild sanitized schema");
209 }
210 }
211
212 fn sanitize_null_branches_recursive(node: &mut serde_json::Value) {
213 use serde_json::Value as Json;
214 match node {
215 Json::Object(map) => {
216 let has_nullable_true = map
218 .get("nullable")
219 .and_then(|v| v.as_bool())
220 .unwrap_or(false);
221 let const_is_null = map.get("const").map(|v| v.is_null()).unwrap_or(false);
222 let has_type = map.contains_key("type");
223
224 if has_nullable_true && const_is_null && !has_type {
225 map.remove("const");
226 map.remove("nullable");
227 map.insert("type".to_string(), Json::String("null".to_string()));
228 }
229
230 for value in map.values_mut() {
232 sanitize_null_branches_recursive(value);
233 }
234 }
235 Json::Array(arr) => {
236 for elem in arr {
237 sanitize_null_branches_recursive(elem);
238 }
239 }
240 _ => {}
241 }
242 }
243
244 fn settings() -> SchemaSettings {
245 SchemaSettings::draft2020_12()
246 .with_transform(AddNullable::default())
247 .with_transform(RestrictFormats::default())
248 .with_transform(SanitizeNullBranches)
249 }
250
251 pub fn cached_schema_for<T: JsonSchema + 'static>() -> Arc<Schema> {
253 CACHE_FOR_TYPE.with(|cache| {
254 let mut cache = cache.borrow_mut();
255 if let Some(x) = cache.get(&TypeId::of::<T>()) {
256 return x.clone();
257 }
258 let generator = settings().into_generator();
259 let root = generator.into_root_schema_for::<T>();
260 let arc = Arc::new(root);
261 cache.insert(TypeId::of::<T>(), arc.clone());
262 arc
263 })
264 }
265
266 pub fn cached_output_schema_for<T: JsonSchema + 'static>() -> Result<Arc<Schema>, String> {
269 CACHE_FOR_OUTPUT.with(|cache| {
270 let mut cache = cache.borrow_mut();
271 if let Some(r) = cache.get(&TypeId::of::<T>()) {
272 return r.clone();
273 }
274 let root = cached_schema_for::<T>();
275 let json = serde_json::to_value(root.as_ref()).expect("serialize output schema");
276 let result = match json.get("type") {
277 Some(serde_json::Value::String(t)) if t == "object" => Ok(root.clone()),
278 Some(serde_json::Value::String(t)) => Err(format!(
279 "MCP requires output_schema root type 'object', found '{}'",
280 t
281 )),
282 None => {
283 if json.get("properties").is_some() {
286 Ok(root.clone())
287 } else {
288 Err(
289 "Schema missing 'type' — output_schema must have root type 'object'"
290 .to_string(),
291 )
292 }
293 }
294 Some(other) => Err(format!(
295 "Unexpected 'type' format: {:?} — expected string 'object'",
296 other
297 )),
298 };
299 cache.insert(TypeId::of::<T>(), result.clone());
300 result
301 })
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Serialize;

    // Minimal two-field input used to exercise SchemaEngine on a derived schema.
    #[derive(schemars::JsonSchema, Serialize)]
    struct TestInput {
        count: i32,
        name: String,
    }

    // Strict mode must close the root object with `additionalProperties: false`.
    #[test]
    fn test_strict_mode() {
        let engine = SchemaEngine::new().with_strict(true);
        let schema = schemars::schema_for!(TestInput);
        let transformed = engine.transform("test", schema);

        let json = serde_json::to_value(&transformed).unwrap();
        assert_eq!(json.get("additionalProperties"), Some(&Json::Bool(false)));
    }

    // `is_strict` is false by default and reflects `with_strict(true)`.
    #[test]
    fn test_is_strict_getter() {
        let e = SchemaEngine::new();
        assert!(!e.is_strict());
        let e2 = SchemaEngine::new().with_strict(true);
        assert!(e2.is_strict());
    }

    // An Enum constraint registered for a tool is written onto the addressed property.
    #[test]
    fn test_enum_constraint() {
        let mut engine = SchemaEngine::new();

        // Hand-written schema so the property path is known exactly.
        let test_schema: Json = serde_json::json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        });

        engine.constrain_field(
            "test",
            vec!["properties".into(), "name".into()],
            FieldConstraint::Enum(vec![Json::String("a".into()), Json::String("b".into())]),
        );

        let schema: Schema = Schema::try_from(test_schema.clone()).unwrap();
        let transformed = engine.transform("test", schema);

        let json = serde_json::to_value(&transformed).unwrap();
        let name_schema = &json["properties"]["name"];
        assert!(name_schema.get("enum").is_some());
    }

    // A Range constraint writes both bounds onto a property of a derived schema.
    #[test]
    fn test_range_constraint() {
        let mut engine = SchemaEngine::new();
        engine.constrain_field(
            "test",
            vec!["properties".into(), "count".into()],
            FieldConstraint::Range {
                minimum: Some(Json::Number(0.into())),
                maximum: Some(Json::Number(100.into())),
            },
        );

        let schema = schemars::schema_for!(TestInput);

        let transformed = engine.transform("test", schema);

        let json = serde_json::to_value(&transformed).unwrap();
        let count_schema = &json["properties"]["count"];

        // `as_f64` accepts any JSON number, so the check does not depend on how the
        // bound is represented (integer vs float).
        let min = count_schema.get("minimum").and_then(|v| v.as_f64());
        let max = count_schema.get("maximum").and_then(|v| v.as_f64());

        assert_eq!(min, Some(0.0), "minimum constraint should be applied");
        assert_eq!(max, Some(100.0), "maximum constraint should be applied");
    }

    // Tests for the MCP-specific generator settings and caches.
    mod mcp_schema_tests {
        use super::mcp_schema;
        use serde::Serialize;

        #[derive(schemars::JsonSchema, Serialize)]
        struct WithOption {
            a: Option<String>,
        }

        // The central generator marks `Option<T>` properties with `nullable: true`.
        #[test]
        fn test_central_generator_addnullable() {
            let root = mcp_schema::cached_schema_for::<WithOption>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            let a = &v["properties"]["a"];
            assert_eq!(
                a.get("nullable"),
                Some(&serde_json::Value::Bool(true)),
                "Option<T> fields should have nullable: true"
            );
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct OutputObj {
            x: i32,
        }

        // A struct yields an object root and is accepted as an MCP output schema.
        #[test]
        fn test_output_schema_validation_object() {
            let ok = mcp_schema::cached_output_schema_for::<OutputObj>();
            assert!(
                ok.is_ok(),
                "Object types should pass output schema validation"
            );
        }

        // A bare String does not have an object root and must be rejected.
        #[test]
        fn test_output_schema_validation_non_object() {
            let bad = mcp_schema::cached_output_schema_for::<String>();
            assert!(
                bad.is_err(),
                "Non-object types should fail output schema validation"
            );
        }

        // The root `$schema` URI must name Draft 2020-12. Despite the test name, `$defs`
        // itself is not inspected here.
        #[test]
        fn test_draft_2020_12_uses_defs() {
            let root = mcp_schema::cached_schema_for::<WithOption>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            assert!(v.is_object(), "Schema should be an object");
            assert!(
                v.get("$schema")
                    .and_then(|s| s.as_str())
                    .is_some_and(|s| s.contains("2020-12")),
                "Schema should reference Draft 2020-12"
            );
        }

        // Two lookups for the same type on one thread return the same cached Arc.
        #[test]
        fn test_caching_returns_same_arc() {
            let first = mcp_schema::cached_schema_for::<OutputObj>();
            let second = mcp_schema::cached_schema_for::<OutputObj>();
            assert!(
                std::sync::Arc::ptr_eq(&first, &second),
                "Cached schemas should return the same Arc"
            );
        }

        // The variants are never constructed; the enum exists only to generate a schema.
        #[allow(dead_code)]
        #[derive(schemars::JsonSchema, Serialize)]
        enum TestEnum {
            A,
            B,
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct HasOptEnum {
            e: Option<TestEnum>,
        }

        // `Option<Enum>` produces an `anyOf` whose null branch must carry an explicit
        // `type: "null"` rather than a bare `nullable: true`.
        #[test]
        fn test_option_enum_anyof_null_branch_has_type() {
            let root = mcp_schema::cached_schema_for::<HasOptEnum>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            let any_of = v["properties"]["e"]["anyOf"]
                .as_array()
                .expect("Option<Enum> should generate anyOf");

            // At least one branch is the explicit null branch.
            assert!(
                any_of
                    .iter()
                    .any(|b| b.get("type") == Some(&serde_json::json!("null"))),
                "anyOf for Option<Enum> must include a branch with type:\"null\""
            );

            for branch in any_of {
                // `nullable: true` is only allowed alongside a `type` or `$ref`.
                let has_nullable = branch.get("nullable") == Some(&serde_json::json!(true));
                let has_type = branch.get("type").is_some() || branch.get("$ref").is_some();
                assert!(
                    !has_nullable || has_type,
                    "No branch may contain nullable:true without a type or $ref"
                );
            }
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct Unsigneds {
            a: u32,
            b: u64,
        }

        // Unsigned integers lose their non-standard `format` but keep `minimum: 0`.
        #[test]
        fn test_strip_uint_formats() {
            let root = mcp_schema::cached_schema_for::<Unsigneds>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            let pa = &v["properties"]["a"];
            let pb = &v["properties"]["b"];

            assert!(
                pa.get("format").is_none(),
                "u32 should not include non-standard 'format'"
            );
            assert!(
                pb.get("format").is_none(),
                "u64 should not include non-standard 'format'"
            );
            assert_eq!(
                pa.get("minimum").and_then(|x| x.as_u64()),
                Some(0),
                "u32 minimum must be preserved"
            );
            assert_eq!(
                pb.get("minimum").and_then(|x| x.as_u64()),
                Some(0),
                "u64 minimum must be preserved"
            );
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct HasOptString {
            s: Option<String>,
        }

        // A simple `Option<String>` keeps `type: "string"` plus `nullable: true`, and is
        // not rewritten by the null-branch sanitizer.
        #[test]
        fn test_option_string_preserves_nullable() {
            let root = mcp_schema::cached_schema_for::<HasOptString>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            let s = &v["properties"]["s"];

            assert_eq!(
                s.get("type"),
                Some(&serde_json::json!("string")),
                "Option<String> should have type: string"
            );
            assert_eq!(
                s.get("nullable"),
                Some(&serde_json::json!(true)),
                "Option<String> should retain nullable: true"
            );
        }
    }
}