1use schemars::Schema;
4use serde_json::Value as Json;
5use std::collections::HashMap;
6
7#[derive(Clone, Debug)]
9pub enum FieldConstraint {
10 Enum(Vec<Json>),
12
13 Range {
15 minimum: Option<Json>,
16 maximum: Option<Json>,
17 },
18
19 Pattern(String),
21
22 MergePatch(Json),
24}
25
/// A user-supplied hook that can rewrite a tool's schema in place.
///
/// Implementations must be `Send + Sync`; the engine stores them as trait objects.
pub trait SchemaTransform: Send + Sync {
    /// Mutates `schema` (the schema serialized as a JSON value) for the tool named `tool`.
    ///
    /// The engine converts the value back into a `Schema` after all transforms have
    /// run and panics if that conversion fails.
    fn apply(&self, tool: &str, schema: &mut Json);
}
31
32#[derive(Default)]
41pub struct SchemaEngine {
42 per_tool: HashMap<String, Vec<(Vec<String>, FieldConstraint)>>,
43 global_strict: bool,
44 custom_transforms: Vec<Box<dyn SchemaTransform>>,
45}
46
47impl Clone for SchemaEngine {
48 fn clone(&self) -> Self {
49 Self {
51 per_tool: self.per_tool.clone(),
52 global_strict: self.global_strict,
53 custom_transforms: Vec::new(), }
55 }
56}
57
58impl std::fmt::Debug for SchemaEngine {
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 f.debug_struct("SchemaEngine")
61 .field("per_tool", &self.per_tool)
62 .field("global_strict", &self.global_strict)
63 .field(
64 "custom_transforms",
65 &format!("[{} transforms]", self.custom_transforms.len()),
66 )
67 .finish()
68 }
69}
70
71impl SchemaEngine {
72 pub fn new() -> Self {
74 Self::default()
75 }
76
77 pub fn with_strict(mut self, strict: bool) -> Self {
79 self.global_strict = strict;
80 self
81 }
82
83 pub fn is_strict(&self) -> bool {
85 self.global_strict
86 }
87
88 pub fn constrain_field(&mut self, tool: &str, json_path: Vec<String>, c: FieldConstraint) {
93 self.per_tool
94 .entry(tool.to_string())
95 .or_default()
96 .push((json_path, c));
97 }
98
99 pub fn add_transform<T: SchemaTransform + 'static>(&mut self, transform: T) {
101 self.custom_transforms.push(Box::new(transform));
102 }
103
104 pub fn transform(&self, tool: &str, schema: Schema) -> Schema {
106 let mut v = serde_json::to_value(&schema).expect("serialize schema");
107
108 if self.global_strict
110 && let Some(obj) = v.as_object_mut()
111 {
112 obj.insert("additionalProperties".to_string(), Json::Bool(false));
113 }
114
115 if let Some(entries) = self.per_tool.get(tool) {
117 for (path, constraint) in entries {
118 Self::apply_constraint(&mut v, path, constraint);
119 }
120 }
121
122 for transform in &self.custom_transforms {
124 transform.apply(tool, &mut v);
125 }
126
127 Schema::try_from(v).expect("schema transform must produce a valid schema")
132 }
133
134 fn apply_constraint(root: &mut Json, path: &[String], constraint: &FieldConstraint) {
135 let Some(node) = Self::find_node_mut(root, path) else {
136 return;
137 };
138 let Some(obj) = node.as_object_mut() else {
139 return;
140 };
141 match constraint {
142 FieldConstraint::Enum(vals) => {
143 obj.insert("enum".into(), Json::Array(vals.clone()));
144 }
145 FieldConstraint::Range { minimum, maximum } => {
146 if let Some(m) = minimum {
147 obj.insert("minimum".into(), m.clone());
148 }
149 if let Some(m) = maximum {
150 obj.insert("maximum".into(), m.clone());
151 }
152 }
153 FieldConstraint::Pattern(p) => {
154 obj.insert("pattern".into(), Json::String(p.clone()));
155 }
156 FieldConstraint::MergePatch(patch) => {
157 json_patch::merge(node, patch);
158 }
159 }
160 }
161
162 fn find_node_mut<'a>(root: &'a mut Json, path: &[String]) -> Option<&'a mut Json> {
163 let mut cur = root;
164 for seg in path {
165 cur = cur.as_object_mut()?.get_mut(seg)?;
166 }
167 Some(cur)
168 }
169}
170
171pub mod mcp_schema {
182 use schemars::generate::SchemaSettings;
183 use schemars::transform::{AddNullable, RestrictFormats, Transform};
184 use schemars::{JsonSchema, Schema};
185 use std::any::TypeId;
186 use std::cell::RefCell;
187 use std::collections::HashMap;
188 use std::sync::Arc;
189
190 thread_local! {
191 static CACHE_FOR_TYPE: RefCell<HashMap<TypeId, Arc<Schema>>> = RefCell::new(HashMap::new());
192 static CACHE_FOR_OUTPUT: RefCell<HashMap<TypeId, Result<Arc<Schema>, String>>> = RefCell::new(HashMap::new());
193 }
194
195 #[derive(Clone, Copy, Default)]
198 struct SanitizeNullBranches;
199
200 impl Transform for SanitizeNullBranches {
201 fn transform(&mut self, schema: &mut Schema) {
202 let mut v = serde_json::to_value(&*schema).expect("serialize schema for sanitize");
204 sanitize_null_branches_recursive(&mut v);
205 *schema = Schema::try_from(v).expect("rebuild sanitized schema");
206 }
207 }
208
209 fn sanitize_null_branches_recursive(node: &mut serde_json::Value) {
210 use serde_json::Value as Json;
211 match node {
212 Json::Object(map) => {
213 let has_nullable_true = map
215 .get("nullable")
216 .and_then(|v| v.as_bool())
217 .unwrap_or(false);
218 let const_is_null = map.get("const").map(|v| v.is_null()).unwrap_or(false);
219 let has_type = map.contains_key("type");
220
221 if has_nullable_true && const_is_null && !has_type {
222 map.remove("const");
223 map.remove("nullable");
224 map.insert("type".to_string(), Json::String("null".to_string()));
225 }
226
227 for value in map.values_mut() {
229 sanitize_null_branches_recursive(value);
230 }
231 }
232 Json::Array(arr) => {
233 for elem in arr {
234 sanitize_null_branches_recursive(elem);
235 }
236 }
237 _ => {}
238 }
239 }
240
241 fn settings() -> SchemaSettings {
242 SchemaSettings::draft2020_12()
243 .with_transform(AddNullable::default())
244 .with_transform(RestrictFormats::default())
245 .with_transform(SanitizeNullBranches)
246 }
247
248 pub fn cached_schema_for<T: JsonSchema + 'static>() -> Arc<Schema> {
250 CACHE_FOR_TYPE.with(|cache| {
251 let mut cache = cache.borrow_mut();
252 if let Some(x) = cache.get(&TypeId::of::<T>()) {
253 return x.clone();
254 }
255 let generator = settings().into_generator();
256 let root = generator.into_root_schema_for::<T>();
257 let arc = Arc::new(root);
258 cache.insert(TypeId::of::<T>(), arc.clone());
259 arc
260 })
261 }
262
263 pub fn cached_output_schema_for<T: JsonSchema + 'static>() -> Result<Arc<Schema>, String> {
266 CACHE_FOR_OUTPUT.with(|cache| {
267 let mut cache = cache.borrow_mut();
268 if let Some(r) = cache.get(&TypeId::of::<T>()) {
269 return r.clone();
270 }
271 let root = cached_schema_for::<T>();
272 let json = serde_json::to_value(root.as_ref()).expect("serialize output schema");
273 let result = match json.get("type") {
274 Some(serde_json::Value::String(t)) if t == "object" => Ok(root.clone()),
275 Some(serde_json::Value::String(t)) => Err(format!(
276 "MCP requires output_schema root type 'object', found '{}'",
277 t
278 )),
279 None => {
280 if json.get("properties").is_some() {
283 Ok(root.clone())
284 } else {
285 Err(
286 "Schema missing 'type' — output_schema must have root type 'object'"
287 .to_string(),
288 )
289 }
290 }
291 Some(other) => Err(format!(
292 "Unexpected 'type' format: {:?} — expected string 'object'",
293 other
294 )),
295 };
296 cache.insert(TypeId::of::<T>(), result.clone());
297 result
298 })
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Serialize;

    #[derive(schemars::JsonSchema, Serialize)]
    struct TestInput {
        count: i32,
        name: String,
    }

    /// Hand-written object schema with a single string property `name`.
    fn name_only_schema() -> Json {
        serde_json::json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        })
    }

    /// Path to the `name` property inside `name_only_schema`.
    fn name_path() -> Vec<String> {
        vec!["properties".into(), "name".into()]
    }

    #[test]
    fn test_strict_mode() {
        let engine = SchemaEngine::new().with_strict(true);
        let schema = schemars::schema_for!(TestInput);
        let transformed = engine.transform("test", schema);

        let json = serde_json::to_value(&transformed).unwrap();
        assert_eq!(json.get("additionalProperties"), Some(&Json::Bool(false)));
    }

    #[test]
    fn test_is_strict_getter() {
        let e = SchemaEngine::new();
        assert!(!e.is_strict());
        let e2 = SchemaEngine::new().with_strict(true);
        assert!(e2.is_strict());
    }

    #[test]
    fn test_enum_constraint() {
        let mut engine = SchemaEngine::new();
        engine.constrain_field(
            "test",
            name_path(),
            FieldConstraint::Enum(vec![Json::String("a".into()), Json::String("b".into())]),
        );

        let schema: Schema = Schema::try_from(name_only_schema()).unwrap();
        let transformed = engine.transform("test", schema);

        let json = serde_json::to_value(&transformed).unwrap();
        // Pin the exact values, not just the presence of the keyword.
        assert_eq!(
            json["properties"]["name"].get("enum"),
            Some(&serde_json::json!(["a", "b"])),
            "enum constraint should carry the configured values"
        );
    }

    #[test]
    fn test_range_constraint() {
        let mut engine = SchemaEngine::new();
        engine.constrain_field(
            "test",
            vec!["properties".into(), "count".into()],
            FieldConstraint::Range {
                minimum: Some(Json::Number(0.into())),
                maximum: Some(Json::Number(100.into())),
            },
        );

        let schema = schemars::schema_for!(TestInput);
        let transformed = engine.transform("test", schema);

        let json = serde_json::to_value(&transformed).unwrap();
        let count_schema = &json["properties"]["count"];

        let min = count_schema.get("minimum").and_then(|v| v.as_f64());
        let max = count_schema.get("maximum").and_then(|v| v.as_f64());

        assert_eq!(min, Some(0.0), "minimum constraint should be applied");
        assert_eq!(max, Some(100.0), "maximum constraint should be applied");
    }

    #[test]
    fn test_pattern_constraint() {
        let mut engine = SchemaEngine::new();
        engine.constrain_field(
            "test",
            name_path(),
            FieldConstraint::Pattern("^[a-z]+$".into()),
        );

        let schema: Schema = Schema::try_from(name_only_schema()).unwrap();
        let json = serde_json::to_value(engine.transform("test", schema)).unwrap();

        assert_eq!(
            json["properties"]["name"].get("pattern"),
            Some(&serde_json::json!("^[a-z]+$"))
        );
    }

    #[test]
    fn test_merge_patch_constraint() {
        let mut engine = SchemaEngine::new();
        engine.constrain_field(
            "test",
            name_path(),
            FieldConstraint::MergePatch(serde_json::json!({ "minLength": 1 })),
        );

        let schema: Schema = Schema::try_from(name_only_schema()).unwrap();
        let json = serde_json::to_value(engine.transform("test", schema)).unwrap();
        let name_schema = &json["properties"]["name"];

        assert_eq!(name_schema.get("minLength"), Some(&serde_json::json!(1)));
        assert_eq!(
            name_schema.get("type"),
            Some(&serde_json::json!("string")),
            "merge patch must keep keys it does not mention"
        );
    }

    #[test]
    fn test_unresolvable_path_is_a_no_op() {
        let mut engine = SchemaEngine::new();
        engine.constrain_field(
            "test",
            vec!["properties".into(), "missing".into()],
            FieldConstraint::Pattern("x".into()),
        );

        let schema: Schema = Schema::try_from(name_only_schema()).unwrap();
        let json = serde_json::to_value(engine.transform("test", schema)).unwrap();

        assert_eq!(json, name_only_schema());
    }

    #[test]
    fn test_constraints_are_scoped_to_their_tool() {
        let mut engine = SchemaEngine::new();
        engine.constrain_field("other", name_path(), FieldConstraint::Pattern("x".into()));

        let schema: Schema = Schema::try_from(name_only_schema()).unwrap();
        let json = serde_json::to_value(engine.transform("test", schema)).unwrap();

        assert!(json["properties"]["name"].get("pattern").is_none());
    }

    /// Transform that records the tool name it was invoked for as the schema title.
    struct TitleFromTool;

    impl SchemaTransform for TitleFromTool {
        fn apply(&self, tool: &str, schema: &mut Json) {
            if let Some(obj) = schema.as_object_mut() {
                obj.insert("title".to_string(), Json::String(tool.to_string()));
            }
        }
    }

    #[test]
    fn test_custom_transform_is_applied() {
        let mut engine = SchemaEngine::new();
        engine.add_transform(TitleFromTool);

        let schema: Schema = Schema::try_from(name_only_schema()).unwrap();
        let json = serde_json::to_value(engine.transform("test", schema)).unwrap();

        assert_eq!(json.get("title"), Some(&serde_json::json!("test")));
    }

    mod mcp_schema_tests {
        use super::mcp_schema;
        use serde::Serialize;

        #[derive(schemars::JsonSchema, Serialize)]
        struct WithOption {
            a: Option<String>,
        }

        #[test]
        fn test_central_generator_addnullable() {
            let root = mcp_schema::cached_schema_for::<WithOption>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            let a = &v["properties"]["a"];
            assert_eq!(
                a.get("nullable"),
                Some(&serde_json::Value::Bool(true)),
                "Option<T> fields should have nullable: true"
            );
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct OutputObj {
            x: i32,
        }

        #[test]
        fn test_output_schema_validation_object() {
            let ok = mcp_schema::cached_output_schema_for::<OutputObj>();
            assert!(
                ok.is_ok(),
                "Object types should pass output schema validation"
            );
        }

        #[test]
        fn test_output_schema_validation_non_object() {
            let bad = mcp_schema::cached_output_schema_for::<String>();
            assert!(
                bad.is_err(),
                "Non-object types should fail output schema validation"
            );
        }

        #[test]
        fn test_draft_2020_12_uses_defs() {
            let root = mcp_schema::cached_schema_for::<WithOption>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            assert!(v.is_object(), "Schema should be an object");
            assert!(
                v.get("$schema")
                    .and_then(|s| s.as_str())
                    .is_some_and(|s| s.contains("2020-12")),
                "Schema should reference Draft 2020-12"
            );
        }

        #[test]
        fn test_caching_returns_same_arc() {
            let first = mcp_schema::cached_schema_for::<OutputObj>();
            let second = mcp_schema::cached_schema_for::<OutputObj>();
            assert!(
                std::sync::Arc::ptr_eq(&first, &second),
                "Cached schemas should return the same Arc"
            );
        }

        // The variants are only inspected through the derived schema, never constructed.
        #[allow(dead_code)]
        #[derive(schemars::JsonSchema, Serialize)]
        enum TestEnum {
            A,
            B,
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct HasOptEnum {
            e: Option<TestEnum>,
        }

        #[test]
        fn test_option_enum_anyof_null_branch_has_type() {
            let root = mcp_schema::cached_schema_for::<HasOptEnum>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            let any_of = v["properties"]["e"]["anyOf"]
                .as_array()
                .expect("Option<Enum> should generate anyOf");

            assert!(
                any_of
                    .iter()
                    .any(|b| b.get("type") == Some(&serde_json::json!("null"))),
                "anyOf for Option<Enum> must include a branch with type:\"null\""
            );

            for branch in any_of {
                let has_nullable = branch.get("nullable") == Some(&serde_json::json!(true));
                let has_type = branch.get("type").is_some() || branch.get("$ref").is_some();
                assert!(
                    !has_nullable || has_type,
                    "No branch may contain nullable:true without a type or $ref"
                );
            }
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct Unsigneds {
            a: u32,
            b: u64,
        }

        #[test]
        fn test_strip_uint_formats() {
            let root = mcp_schema::cached_schema_for::<Unsigneds>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            let pa = &v["properties"]["a"];
            let pb = &v["properties"]["b"];

            assert!(
                pa.get("format").is_none(),
                "u32 should not include non-standard 'format'"
            );
            assert!(
                pb.get("format").is_none(),
                "u64 should not include non-standard 'format'"
            );
            assert_eq!(
                pa.get("minimum").and_then(|x| x.as_u64()),
                Some(0),
                "u32 minimum must be preserved"
            );
            assert_eq!(
                pb.get("minimum").and_then(|x| x.as_u64()),
                Some(0),
                "u64 minimum must be preserved"
            );
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct HasOptString {
            s: Option<String>,
        }

        #[test]
        fn test_option_string_preserves_nullable() {
            let root = mcp_schema::cached_schema_for::<HasOptString>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            let s = &v["properties"]["s"];

            assert_eq!(
                s.get("type"),
                Some(&serde_json::json!("string")),
                "Option<String> should have type: string"
            );
            assert_eq!(
                s.get("nullable"),
                Some(&serde_json::json!(true)),
                "Option<String> should retain nullable: true"
            );
        }
    }
}