1use schemars::Schema;
4use serde_json::Value as Json;
5use std::collections::HashMap;
6use std::collections::HashSet;
7
/// A constraint that [`SchemaEngine`] writes into one node of a tool's JSON schema.
///
/// The node is addressed by the JSON path given to [`SchemaEngine::constrain_field`]. When that
/// path does not resolve, or the node is not a JSON object, the constraint is skipped.
#[derive(Clone, Debug)]
pub enum FieldConstraint {
    /// Sets the node's `enum` keyword to these values, overwriting any existing `enum`.
    Enum(Vec<Json>),

    /// Sets `minimum` and/or `maximum`; a `None` bound leaves any existing value untouched.
    Range {
        minimum: Option<Json>,
        maximum: Option<Json>,
    },

    /// Sets the node's `pattern` keyword (a regular expression in JSON Schema) to this string.
    Pattern(String),

    /// Merges this value into the node with `json_patch::merge` (JSON Merge Patch).
    MergePatch(Json),
}
26
/// A user-supplied rewrite that [`SchemaEngine::transform`] runs after strict mode and the
/// per-field constraints have been applied.
pub trait SchemaTransform: Send + Sync {
    /// Mutates `schema` (the JSON form of the schema for `tool`) in place.
    ///
    /// The value left behind must still be accepted by `Schema::try_from`; otherwise
    /// `SchemaEngine::transform` panics when converting it back.
    fn apply(&self, tool: &str, schema: &mut Json);
}
32
/// Rewrites tool JSON schemas by applying an optional strict mode, per-tool field
/// constraints, and user-registered custom transforms.
///
/// Note: `Clone` copies the constraints and the strict flag but starts with no custom
/// transforms, because boxed transforms are not cloneable.
#[derive(Default)]
pub struct SchemaEngine {
    // Constraints keyed by tool name; each entry pairs a JSON path with the constraint to apply there.
    per_tool: HashMap<String, Vec<(Vec<String>, FieldConstraint)>>,
    // When true, `transform` sets `additionalProperties: false` on the root schema object.
    global_strict: bool,
    // User transforms, run in registration order after the built-in steps.
    custom_transforms: Vec<Box<dyn SchemaTransform>>,
}
47
48impl Clone for SchemaEngine {
49 fn clone(&self) -> Self {
50 Self {
52 per_tool: self.per_tool.clone(),
53 global_strict: self.global_strict,
54 custom_transforms: Vec::new(), }
56 }
57}
58
59impl std::fmt::Debug for SchemaEngine {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 f.debug_struct("SchemaEngine")
62 .field("per_tool", &self.per_tool)
63 .field("global_strict", &self.global_strict)
64 .field(
65 "custom_transforms",
66 &format!("[{} transforms]", self.custom_transforms.len()),
67 )
68 .finish()
69 }
70}
71
72impl SchemaEngine {
73 pub fn new() -> Self {
75 Self::default()
76 }
77
78 pub fn with_strict(mut self, strict: bool) -> Self {
80 self.global_strict = strict;
81 self
82 }
83
84 pub fn is_strict(&self) -> bool {
86 self.global_strict
87 }
88
89 pub fn constrain_field(&mut self, tool: &str, json_path: Vec<String>, c: FieldConstraint) {
94 self.per_tool
95 .entry(tool.to_string())
96 .or_default()
97 .push((json_path, c));
98 }
99
100 pub fn add_transform<T: SchemaTransform + 'static>(&mut self, transform: T) {
102 self.custom_transforms.push(Box::new(transform));
103 }
104
105 pub fn transform(&self, tool: &str, schema: Schema) -> Schema {
107 let mut v = serde_json::to_value(&schema).expect("serialize schema");
108
109 if self.global_strict
111 && let Some(obj) = v.as_object_mut()
112 {
113 obj.insert("additionalProperties".to_string(), Json::Bool(false));
114 }
115
116 if let Some(entries) = self.per_tool.get(tool) {
118 for (path, constraint) in entries {
119 Self::apply_constraint(&mut v, path, constraint);
120 }
121 }
122
123 for transform in &self.custom_transforms {
125 transform.apply(tool, &mut v);
126 }
127
128 Schema::try_from(v).expect("schema transform must produce a valid schema")
133 }
134
135 fn apply_constraint(root: &mut Json, path: &[String], constraint: &FieldConstraint) {
136 let Some(node) = Self::find_node_mut(root, path) else {
137 return;
138 };
139 let Some(obj) = node.as_object_mut() else {
140 return;
141 };
142 match constraint {
143 FieldConstraint::Enum(vals) => {
144 obj.insert("enum".into(), Json::Array(vals.clone()));
145 }
146 FieldConstraint::Range { minimum, maximum } => {
147 if let Some(m) = minimum {
148 obj.insert("minimum".into(), m.clone());
149 }
150 if let Some(m) = maximum {
151 obj.insert("maximum".into(), m.clone());
152 }
153 }
154 FieldConstraint::Pattern(p) => {
155 obj.insert("pattern".into(), Json::String(p.clone()));
156 }
157 FieldConstraint::MergePatch(patch) => {
158 json_patch::merge(node, patch);
159 }
160 }
161 }
162
163 fn find_node_mut<'a>(root: &'a mut Json, path: &[String]) -> Option<&'a mut Json> {
164 let mut cur = root;
165 for seg in path {
166 cur = cur.as_object_mut()?.get_mut(seg)?;
167 }
168 Some(cur)
169 }
170}
171
/// Sentence added to the `description` of non-required properties whose schema accepts `null`.
const OPTIONAL_PROPERTY_GUIDANCE: &str = "Optional; omit or use null.";
173
174#[derive(Clone, Default)]
175struct NullFirstOptional;
176
177impl schemars::transform::Transform for NullFirstOptional {
178 fn transform(&mut self, schema: &mut Schema) {
179 let mut value = serde_json::to_value(&*schema).expect("serialize schema");
180 normalize_optional_properties(&mut value);
181 *schema = Schema::try_from(value).expect("NullFirstOptional must preserve schema validity");
182 }
183}
184
185fn normalize_optional_properties(node: &mut Json) {
186 let Some(obj) = node.as_object_mut() else {
187 return;
188 };
189
190 recurse_object_entries(obj, "$defs");
191 recurse_object_entries(obj, "definitions");
192
193 let required = required_property_names(obj.get("required"));
194 if let Some(properties) = obj.get_mut("properties").and_then(Json::as_object_mut) {
195 for (property_name, property_schema) in properties {
196 if !required.contains(property_name.as_str()) {
197 normalize_known_nullable_shapes(property_schema);
198 if explicitly_allows_null(property_schema) {
199 annotate_optional_property(property_schema);
200 }
201 }
202 normalize_optional_properties(property_schema);
203 }
204 }
205
206 recurse_object_entries(obj, "dependentSchemas");
207 recurse_object_entries(obj, "patternProperties");
208 recurse_schema_entry(obj, "additionalProperties");
209 recurse_schema_entry(obj, "propertyNames");
210 recurse_schema_entry(obj, "unevaluatedProperties");
211 recurse_schema_entry(obj, "items");
212 recurse_schema_entry(obj, "unevaluatedItems");
213 recurse_schema_entry(obj, "contains");
214 recurse_schema_array_entry(obj, "prefixItems");
215 recurse_schema_array_entry(obj, "allOf");
216 recurse_schema_array_entry(obj, "anyOf");
217 recurse_schema_array_entry(obj, "oneOf");
218 recurse_schema_entry(obj, "if");
219 recurse_schema_entry(obj, "then");
220 recurse_schema_entry(obj, "else");
221 recurse_schema_entry(obj, "not");
222}
223
224fn recurse_object_entries(obj: &mut serde_json::Map<String, Json>, key: &str) {
225 let Some(entries) = obj.get_mut(key).and_then(Json::as_object_mut) else {
226 return;
227 };
228
229 for value in entries.values_mut() {
230 normalize_optional_properties(value);
231 }
232}
233
234fn recurse_schema_entry(obj: &mut serde_json::Map<String, Json>, key: &str) {
235 let Some(value) = obj.get_mut(key) else {
236 return;
237 };
238
239 normalize_optional_properties(value);
240}
241
242fn recurse_schema_array_entry(obj: &mut serde_json::Map<String, Json>, key: &str) {
243 let Some(values) = obj.get_mut(key).and_then(Json::as_array_mut) else {
244 return;
245 };
246
247 for value in values {
248 normalize_optional_properties(value);
249 }
250}
251
252fn required_property_names(required: Option<&Json>) -> HashSet<String> {
253 required
254 .and_then(Json::as_array)
255 .into_iter()
256 .flatten()
257 .filter_map(Json::as_str)
258 .map(str::to_owned)
259 .collect()
260}
261
262fn normalize_known_nullable_shapes(node: &mut Json) {
263 move_null_to_front_in_type_array(node);
264 move_null_to_front_in_enum_values(node);
265 move_null_to_front_in_any_of(node);
266}
267
268fn explicitly_allows_null(node: &Json) -> bool {
269 type_array_contains_null(node)
270 || enum_values_contain_null(node)
271 || any_of_contains_explicit_null_branch(node)
272}
273
274fn type_array_contains_null(node: &Json) -> bool {
275 node.as_object()
276 .and_then(|obj| obj.get("type"))
277 .and_then(Json::as_array)
278 .is_some_and(|type_values| {
279 type_values
280 .iter()
281 .any(|value| value == &Json::String("null".into()))
282 })
283}
284
285fn enum_values_contain_null(node: &Json) -> bool {
286 node.as_object()
287 .and_then(|obj| obj.get("enum"))
288 .and_then(Json::as_array)
289 .is_some_and(|enum_values| enum_values.iter().any(Json::is_null))
290}
291
292fn any_of_contains_explicit_null_branch(node: &Json) -> bool {
293 node.as_object()
294 .and_then(|obj| obj.get("anyOf"))
295 .and_then(Json::as_array)
296 .is_some_and(|any_of| any_of.iter().any(is_explicit_null_branch))
297}
298
299fn move_null_to_front_in_type_array(node: &mut Json) {
300 let Some(obj) = node.as_object_mut() else {
301 return;
302 };
303
304 let Some(type_values) = obj.get_mut("type").and_then(Json::as_array_mut) else {
305 return;
306 };
307
308 move_values_to_front(type_values, |value| value == &Json::String("null".into()));
309}
310
311fn move_null_to_front_in_enum_values(node: &mut Json) {
312 let Some(obj) = node.as_object_mut() else {
313 return;
314 };
315
316 let Some(enum_values) = obj.get_mut("enum").and_then(Json::as_array_mut) else {
317 return;
318 };
319
320 move_values_to_front(enum_values, Json::is_null);
321}
322
323fn move_null_to_front_in_any_of(node: &mut Json) {
324 let Some(obj) = node.as_object_mut() else {
325 return;
326 };
327
328 let Some(any_of) = obj.get_mut("anyOf").and_then(Json::as_array_mut) else {
329 return;
330 };
331
332 move_values_to_front(any_of, is_explicit_null_branch);
333}
334
335fn annotate_optional_property(node: &mut Json) {
336 let Some(obj) = node.as_object_mut() else {
337 return;
338 };
339
340 match obj.get_mut("description") {
341 Some(Json::String(description)) => {
342 if !description.contains(OPTIONAL_PROPERTY_GUIDANCE) {
343 description.push_str("\n\n");
344 description.push_str(OPTIONAL_PROPERTY_GUIDANCE);
345 }
346 }
347 Some(_) => {
348 }
350 None => {
351 obj.insert(
352 "description".to_string(),
353 Json::String(OPTIONAL_PROPERTY_GUIDANCE.to_string()),
354 );
355 }
356 }
357}
358
359fn move_values_to_front<F>(values: &mut Vec<Json>, predicate: F)
360where
361 F: Fn(&Json) -> bool,
362{
363 let mut matching = Vec::new();
364 let mut non_matching = Vec::new();
365
366 for value in values.drain(..) {
367 if predicate(&value) {
368 matching.push(value);
369 } else {
370 non_matching.push(value);
371 }
372 }
373
374 if matching.is_empty() {
375 *values = non_matching;
376 return;
377 }
378
379 matching.extend(non_matching);
380 *values = matching;
381}
382
383fn is_explicit_null_branch(node: &Json) -> bool {
384 matches!(
385 node,
386 Json::Object(obj) if obj.get("type") == Some(&Json::String("null".into()))
387 )
388}
389
390pub mod mcp_schema {
402 use super::NullFirstOptional;
403 use schemars::JsonSchema;
404 use schemars::Schema;
405 use schemars::generate::SchemaSettings;
406 use schemars::transform::RestrictFormats;
407 use std::any::TypeId;
408 use std::cell::RefCell;
409 use std::collections::HashMap;
410 use std::sync::Arc;
411
412 thread_local! {
413 static CACHE_FOR_TYPE: RefCell<HashMap<TypeId, Arc<Schema>>> = RefCell::new(HashMap::new());
414 static CACHE_FOR_OUTPUT: RefCell<HashMap<TypeId, Result<Arc<Schema>, String>>> = RefCell::new(HashMap::new());
415 }
416
417 fn settings() -> SchemaSettings {
418 SchemaSettings::draft2020_12()
419 .with_transform(RestrictFormats::default())
420 .with_transform(NullFirstOptional)
421 }
422
423 pub fn cached_schema_for<T: JsonSchema + 'static>() -> Arc<Schema> {
425 CACHE_FOR_TYPE.with(|cache| {
426 let mut cache = cache.borrow_mut();
427 if let Some(x) = cache.get(&TypeId::of::<T>()) {
428 return x.clone();
429 }
430 let generator = settings().into_generator();
431 let root = generator.into_root_schema_for::<T>();
432 let arc = Arc::new(root);
433 cache.insert(TypeId::of::<T>(), arc.clone());
434 arc
435 })
436 }
437
438 pub fn cached_output_schema_for<T: JsonSchema + 'static>() -> Result<Arc<Schema>, String> {
441 CACHE_FOR_OUTPUT.with(|cache| {
442 let mut cache = cache.borrow_mut();
443 if let Some(r) = cache.get(&TypeId::of::<T>()) {
444 return r.clone();
445 }
446 let root = cached_schema_for::<T>();
447 let json = serde_json::to_value(root.as_ref()).expect("serialize output schema");
448 let result = match json.get("type") {
449 Some(serde_json::Value::String(t)) if t == "object" => Ok(root.clone()),
450 Some(serde_json::Value::String(t)) => Err(format!(
451 "MCP requires output_schema root type 'object', found '{}'",
452 t
453 )),
454 None => {
455 if json.get("properties").is_some() {
458 Ok(root.clone())
459 } else {
460 Err(
461 "Schema missing 'type' — output_schema must have root type 'object'"
462 .to_string(),
463 )
464 }
465 }
466 Some(other) => Err(format!(
467 "Unexpected 'type' format: {:?} — expected string 'object'",
468 other
469 )),
470 };
471 cache.insert(TypeId::of::<T>(), result.clone());
472 result
473 })
474 }
475}
476
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Serialize;

    #[derive(schemars::JsonSchema, Serialize)]
    struct TestInput {
        count: i32,
        name: String,
    }

    /// Custom transform that records which tool it ran for in the root `title`.
    struct AddToolTitle;

    impl SchemaTransform for AddToolTitle {
        fn apply(&self, tool: &str, schema: &mut Json) {
            if let Some(obj) = schema.as_object_mut() {
                obj.insert("title".to_string(), Json::String(format!("tool:{tool}")));
            }
        }
    }

    #[test]
    fn test_strict_mode() {
        let engine = SchemaEngine::new().with_strict(true);
        let schema = schemars::schema_for!(TestInput);
        let transformed = engine.transform("test", schema);

        let json = serde_json::to_value(&transformed).unwrap();
        assert_eq!(json.get("additionalProperties"), Some(&Json::Bool(false)));
    }

    #[test]
    fn test_is_strict_getter() {
        let e = SchemaEngine::new();
        assert!(!e.is_strict());
        let e2 = SchemaEngine::new().with_strict(true);
        assert!(e2.is_strict());
    }

    #[test]
    fn test_enum_constraint() {
        let mut engine = SchemaEngine::new();

        let test_schema: Json = serde_json::json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string"
                }
            }
        });

        engine.constrain_field(
            "test",
            vec!["properties".into(), "name".into()],
            FieldConstraint::Enum(vec![Json::String("a".into()), Json::String("b".into())]),
        );

        let schema: Schema = Schema::try_from(test_schema.clone()).unwrap();
        let transformed = engine.transform("test", schema);

        let json = serde_json::to_value(&transformed).unwrap();
        let name_schema = &json["properties"]["name"];
        assert!(name_schema.get("enum").is_some());
    }

    #[test]
    fn test_range_constraint() {
        let mut engine = SchemaEngine::new();
        engine.constrain_field(
            "test",
            vec!["properties".into(), "count".into()],
            FieldConstraint::Range {
                minimum: Some(Json::Number(0.into())),
                maximum: Some(Json::Number(100.into())),
            },
        );

        let schema = schemars::schema_for!(TestInput);

        let transformed = engine.transform("test", schema);

        let json = serde_json::to_value(&transformed).unwrap();
        let count_schema = &json["properties"]["count"];

        let min = count_schema.get("minimum").and_then(|v| v.as_f64());
        let max = count_schema.get("maximum").and_then(|v| v.as_f64());

        assert_eq!(min, Some(0.0), "minimum constraint should be applied");
        assert_eq!(max, Some(100.0), "maximum constraint should be applied");
    }

    #[test]
    fn test_pattern_constraint() {
        let mut engine = SchemaEngine::new();
        engine.constrain_field(
            "test",
            vec!["properties".into(), "name".into()],
            FieldConstraint::Pattern("^[a-z]+$".into()),
        );

        let transformed = engine.transform("test", schemars::schema_for!(TestInput));

        let json = serde_json::to_value(&transformed).unwrap();
        assert_eq!(
            json["properties"]["name"]["pattern"],
            serde_json::json!("^[a-z]+$")
        );
    }

    #[test]
    fn test_merge_patch_constraint_adds_and_removes_keys() {
        let mut engine = SchemaEngine::new();
        engine.constrain_field(
            "test",
            vec!["properties".into(), "count".into()],
            FieldConstraint::MergePatch(serde_json::json!({
                "description": "How many",
                "format": null
            })),
        );

        let transformed = engine.transform("test", schemars::schema_for!(TestInput));

        let json = serde_json::to_value(&transformed).unwrap();
        let count_schema = &json["properties"]["count"];
        assert_eq!(count_schema["description"], serde_json::json!("How many"));
        assert!(
            count_schema.get("format").is_none(),
            "a null in a merge patch removes the key"
        );
        assert_eq!(count_schema["type"], serde_json::json!("integer"));
    }

    #[test]
    fn test_constraints_only_apply_to_their_tool() {
        let mut engine = SchemaEngine::new();
        engine.constrain_field(
            "other_tool",
            vec!["properties".into(), "name".into()],
            FieldConstraint::Pattern("^x$".into()),
        );

        let transformed = engine.transform("test", schemars::schema_for!(TestInput));

        let json = serde_json::to_value(&transformed).unwrap();
        assert!(json["properties"]["name"].get("pattern").is_none());
    }

    #[test]
    fn test_unresolvable_path_is_ignored() {
        let mut engine = SchemaEngine::new();
        engine.constrain_field(
            "test",
            vec!["properties".into(), "missing".into()],
            FieldConstraint::Pattern("^x$".into()),
        );

        let schema = schemars::schema_for!(TestInput);
        let expected = serde_json::to_value(&schema).unwrap();
        let transformed = engine.transform("test", schema);

        assert_eq!(serde_json::to_value(&transformed).unwrap(), expected);
    }

    #[test]
    fn test_custom_transform_is_applied_with_tool_name() {
        let mut engine = SchemaEngine::new();
        engine.add_transform(AddToolTitle);

        let transformed = engine.transform("test", schemars::schema_for!(TestInput));

        let json = serde_json::to_value(&transformed).unwrap();
        assert_eq!(json["title"], serde_json::json!("tool:test"));
    }

    #[test]
    fn test_clone_keeps_constraints_but_drops_custom_transforms() {
        let mut engine = SchemaEngine::new().with_strict(true);
        engine.constrain_field(
            "test",
            vec!["properties".into(), "name".into()],
            FieldConstraint::Pattern("^a".into()),
        );
        engine.add_transform(AddToolTitle);

        let cloned = engine.clone();
        assert!(cloned.is_strict());
        assert!(format!("{engine:?}").contains("[1 transforms]"));
        assert!(format!("{cloned:?}").contains("[0 transforms]"));

        let json =
            serde_json::to_value(cloned.transform("test", schemars::schema_for!(TestInput)))
                .unwrap();
        assert_eq!(json["properties"]["name"]["pattern"], serde_json::json!("^a"));
        assert_ne!(json.get("title"), Some(&serde_json::json!("tool:test")));
    }

    mod mcp_schema_tests {
        use super::Json;
        use super::NullFirstOptional;
        use super::OPTIONAL_PROPERTY_GUIDANCE;
        use super::Schema;
        use super::mcp_schema;
        use schemars::transform::Transform;
        use serde::Serialize;

        fn property<'a>(schema: &'a Json, name: &str) -> &'a Json {
            &schema["properties"][name]
        }

        fn required_names(schema: &Json) -> Vec<&str> {
            schema["required"]
                .as_array()
                .into_iter()
                .flatten()
                .filter_map(Json::as_str)
                .collect()
        }

        fn assert_optional_guidance(schema: &Json, name: &str) {
            assert_eq!(
                property(schema, name).get("description"),
                Some(&Json::String(OPTIONAL_PROPERTY_GUIDANCE.to_string()))
            );
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct WithOption {
            a: Option<String>,
        }

        #[test]
        fn test_option_string_is_optional_nullable_with_null_first() {
            let root = mcp_schema::cached_schema_for::<WithOption>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            let a = property(&v, "a");

            assert_eq!(a.get("type"), Some(&serde_json::json!(["null", "string"])));
            assert!(a.get("nullable").is_none());
            assert!(required_names(&v).is_empty());
            assert_optional_guidance(&v, "a");
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct OutputObj {
            x: i32,
        }

        #[test]
        fn test_output_schema_validation_object() {
            let ok = mcp_schema::cached_output_schema_for::<OutputObj>();
            assert!(
                ok.is_ok(),
                "Object types should pass output schema validation"
            );
        }

        #[test]
        fn test_output_schema_validation_non_object() {
            let bad = mcp_schema::cached_output_schema_for::<String>();
            assert!(
                bad.is_err(),
                "Non-object types should fail output schema validation"
            );
        }

        #[test]
        fn test_draft_2020_12_uses_defs() {
            let root = mcp_schema::cached_schema_for::<WithOption>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            assert!(v.is_object(), "Schema should be an object");
            assert!(
                v.get("$schema")
                    .and_then(|s| s.as_str())
                    .is_some_and(|s| s.contains("2020-12")),
                "Schema should reference Draft 2020-12"
            );
        }

        #[test]
        fn test_caching_returns_same_arc() {
            let first = mcp_schema::cached_schema_for::<OutputObj>();
            let second = mcp_schema::cached_schema_for::<OutputObj>();
            assert!(
                std::sync::Arc::ptr_eq(&first, &second),
                "Cached schemas should return the same Arc"
            );
        }

        // The variants are never constructed: the enum exists only so that its schema can be
        // generated, which would otherwise trigger the dead-code lint.
        #[allow(dead_code)]
        #[derive(schemars::JsonSchema, Serialize)]
        enum TestEnum {
            A,
            B,
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct HasOptEnum {
            e: Option<TestEnum>,
        }

        #[test]
        fn test_option_enum_keeps_any_of_with_null_first() {
            let root = mcp_schema::cached_schema_for::<HasOptEnum>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            let e = property(&v, "e");
            let any_of = e["anyOf"].as_array().expect("Option enum should use anyOf");

            assert_eq!(any_of.len(), 2);
            assert_eq!(any_of[0], serde_json::json!({ "type": "null" }));
            assert!(any_of[1].get("$ref").is_some());
            assert_optional_guidance(&v, "e");
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct Unsigneds {
            a: u32,
            b: u64,
        }

        #[test]
        fn test_strip_uint_formats() {
            let root = mcp_schema::cached_schema_for::<Unsigneds>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            let pa = &v["properties"]["a"];
            let pb = &v["properties"]["b"];

            assert!(
                pa.get("format").is_none(),
                "u32 should not include non-standard 'format'"
            );
            assert!(
                pb.get("format").is_none(),
                "u64 should not include non-standard 'format'"
            );
            assert_eq!(
                pa.get("minimum").and_then(|x| x.as_u64()),
                Some(0),
                "u32 minimum must be preserved"
            );
            assert_eq!(
                pb.get("minimum").and_then(|x| x.as_u64()),
                Some(0),
                "u64 minimum must be preserved"
            );
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct HasOptString {
            s: Option<String>,
        }

        #[test]
        fn test_option_string_uses_null_first_without_nullable_keyword() {
            let root = mcp_schema::cached_schema_for::<HasOptString>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            let s = property(&v, "s");

            assert_eq!(s.get("type"), Some(&serde_json::json!(["null", "string"])));
            assert!(
                s.get("nullable").is_none(),
                "Option<String> should not have nullable keyword"
            );
            assert_optional_guidance(&v, "s");
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct NestedInner {
            leaf: Option<String>,
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct NestedOuter {
            nested: Option<NestedInner>,
        }

        #[test]
        fn test_nested_optional_properties_are_normalized_recursively() {
            let root = mcp_schema::cached_schema_for::<NestedOuter>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            let nested = property(&v, "nested");
            let nested_any_of = nested["anyOf"]
                .as_array()
                .expect("Nested option should keep anyOf branches");

            assert_eq!(nested_any_of[0], serde_json::json!({ "type": "null" }));
            assert!(nested_any_of[1].get("$ref").is_some());
            assert_optional_guidance(&v, "nested");

            let defs = v["$defs"]
                .as_object()
                .expect("Nested type should use $defs");
            let inner = defs
                .values()
                .find(|schema| schema["properties"].get("leaf").is_some())
                .expect("NestedInner schema should exist in $defs");

            assert_eq!(
                inner["properties"]["leaf"]["type"],
                serde_json::json!(["null", "string"])
            );
            assert_eq!(
                inner["properties"]["leaf"]["description"],
                serde_json::json!(OPTIONAL_PROPERTY_GUIDANCE)
            );
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct HasOptVec {
            values: Option<Vec<String>>,
        }

        #[test]
        fn test_option_vec_property_keeps_outer_nullability_with_null_first() {
            let root = mcp_schema::cached_schema_for::<HasOptVec>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            let values = property(&v, "values");

            assert_eq!(
                values.get("type"),
                Some(&serde_json::json!(["null", "array"]))
            );
            assert_eq!(values["items"]["type"], serde_json::json!("string"));
            assert_optional_guidance(&v, "values");
        }

        #[derive(schemars::JsonSchema, Serialize)]
        struct HasNestedOptionalItems {
            values: Option<Vec<Option<String>>>,
        }

        #[test]
        fn test_inner_nullability_is_preserved() {
            let root = mcp_schema::cached_schema_for::<HasNestedOptionalItems>();
            let v = serde_json::to_value(root.as_ref()).unwrap();
            let values = property(&v, "values");
            let item_type = values["items"]["type"]
                .as_array()
                .expect("Inner Option<String> should remain nullable");

            assert_eq!(
                values.get("type"),
                Some(&serde_json::json!(["null", "array"]))
            );
            assert!(item_type.contains(&serde_json::json!("string")));
            assert!(item_type.contains(&serde_json::json!("null")));
            assert_optional_guidance(&v, "values");
        }

        #[test]
        fn test_required_fields_remain_unchanged() {
            let mut schema = Schema::try_from(serde_json::json!({
                "type": "object",
                "properties": {
                    "required_field": { "type": ["string", "null"] },
                    "optional_field": { "type": ["string", "null"] }
                },
                "required": ["required_field"]
            }))
            .unwrap();

            NullFirstOptional.transform(&mut schema);

            let v = serde_json::to_value(&schema).unwrap();
            let required_type = v["properties"]["required_field"]["type"]
                .as_array()
                .expect("Required field should keep nullable type array");

            assert!(required_type.contains(&serde_json::json!("string")));
            assert!(required_type.contains(&serde_json::json!("null")));
            assert_eq!(
                v["properties"]["optional_field"]["type"],
                serde_json::json!(["null", "string"])
            );
            assert_eq!(
                v["properties"]["optional_field"]["description"],
                serde_json::json!(OPTIONAL_PROPERTY_GUIDANCE)
            );
            assert!(
                v["properties"]["required_field"]
                    .get("description")
                    .is_none()
            );
        }

        #[test]
        fn test_manual_any_of_null_branch_moves_to_front() {
            let mut schema = Schema::try_from(serde_json::json!({
                "type": "object",
                "properties": {
                    "optional_field": {
                        "anyOf": [
                            { "type": "string" },
                            { "type": "integer" },
                            { "type": "null" }
                        ]
                    }
                }
            }))
            .unwrap();

            NullFirstOptional.transform(&mut schema);

            let v = serde_json::to_value(&schema).unwrap();
            assert_eq!(
                v["properties"]["optional_field"]["anyOf"],
                serde_json::json!([
                    { "type": "null" },
                    { "type": "string" },
                    { "type": "integer" }
                ])
            );
            assert_eq!(
                v["properties"]["optional_field"]["description"],
                serde_json::json!(OPTIONAL_PROPERTY_GUIDANCE)
            );
        }

        #[test]
        fn test_manual_enum_null_moves_to_front() {
            let mut schema = Schema::try_from(serde_json::json!({
                "type": "object",
                "properties": {
                    "optional_field": {
                        "enum": ["alpha", null, "beta"]
                    }
                }
            }))
            .unwrap();

            NullFirstOptional.transform(&mut schema);

            let v = serde_json::to_value(&schema).unwrap();
            assert_eq!(
                v["properties"]["optional_field"]["enum"],
                serde_json::json!([null, "alpha", "beta"])
            );
            assert_eq!(
                v["properties"]["optional_field"]["description"],
                serde_json::json!(OPTIONAL_PROPERTY_GUIDANCE)
            );
        }

        #[test]
        fn test_existing_description_appends_guidance_once() {
            let mut schema = Schema::try_from(serde_json::json!({
                "type": "object",
                "properties": {
                    "optional_field": {
                        "description": "Existing description.",
                        "type": ["string", "null"]
                    }
                }
            }))
            .unwrap();

            NullFirstOptional.transform(&mut schema);
            NullFirstOptional.transform(&mut schema);

            let v = serde_json::to_value(&schema).unwrap();
            assert_eq!(
                v["properties"]["optional_field"]["description"],
                serde_json::json!("Existing description.\n\nOptional; omit or use null.")
            );
            assert_eq!(
                v["properties"]["optional_field"]["type"],
                serde_json::json!(["null", "string"])
            );
        }

        #[test]
        fn test_non_nullable_optional_property_does_not_get_null_guidance() {
            let mut schema = Schema::try_from(serde_json::json!({
                "type": "object",
                "properties": {
                    "optional_field": {
                        "type": "string"
                    }
                }
            }))
            .unwrap();

            NullFirstOptional.transform(&mut schema);

            let v = serde_json::to_value(&schema).unwrap();
            assert!(
                v["properties"]["optional_field"]
                    .get("description")
                    .is_none()
            );
        }

        #[test]
        fn test_dependent_schemas_are_normalized_recursively() {
            let mut schema = Schema::try_from(serde_json::json!({
                "type": "object",
                "properties": {
                    "trigger": { "type": "boolean" }
                },
                "dependentSchemas": {
                    "trigger": {
                        "type": "object",
                        "properties": {
                            "nested_optional": {
                                "type": ["string", "null"]
                            }
                        }
                    }
                }
            }))
            .unwrap();

            NullFirstOptional.transform(&mut schema);

            let v = serde_json::to_value(&schema).unwrap();
            assert_eq!(
                v["dependentSchemas"]["trigger"]["properties"]["nested_optional"]["type"],
                serde_json::json!(["null", "string"])
            );
            assert_eq!(
                v["dependentSchemas"]["trigger"]["properties"]["nested_optional"]["description"],
                serde_json::json!(OPTIONAL_PROPERTY_GUIDANCE)
            );
        }

        #[test]
        fn test_unevaluated_items_are_normalized_recursively() {
            let mut schema = Schema::try_from(serde_json::json!({
                "type": "array",
                "unevaluatedItems": {
                    "type": "object",
                    "properties": {
                        "nested_optional": {
                            "type": ["string", "null"]
                        }
                    }
                }
            }))
            .unwrap();

            NullFirstOptional.transform(&mut schema);

            let v = serde_json::to_value(&schema).unwrap();
            assert_eq!(
                v["unevaluatedItems"]["properties"]["nested_optional"]["type"],
                serde_json::json!(["null", "string"])
            );
            assert_eq!(
                v["unevaluatedItems"]["properties"]["nested_optional"]["description"],
                serde_json::json!(OPTIONAL_PROPERTY_GUIDANCE)
            );
        }
    }
}