1use serde_json::{Map, Value, json};
55use std::collections::{HashMap, HashSet};
56
/// JSON Schema keywords that the cleaner removes when the Gemini strategy is
/// selected (see `CleaningStrategy::unsupported_keywords`).
///
/// The entries are grouped by purpose below; the order of the slice itself is
/// not significant to the cleaner, which only performs membership checks.
pub const GEMINI_UNSUPPORTED_KEYWORDS: &[&str] = &[
    // References, identification and definition containers.
    "$ref",
    "$schema",
    "$id",
    "$defs",
    "definitions",
    // Object shape constraints.
    "additionalProperties",
    "patternProperties",
    // String constraints.
    "minLength",
    "maxLength",
    "pattern",
    "format",
    // Numeric constraints.
    "minimum",
    "maximum",
    "multipleOf",
    // Array constraints.
    "minItems",
    "maxItems",
    "uniqueItems",
    // Object size constraints.
    "minProperties",
    "maxProperties",
    // Annotations.
    "examples",
];
87
/// Annotation keys that `preserve_meta` copies from a source object onto the
/// schema that replaces it (for example after a `$ref` has been expanded).
const SCHEMA_META_KEYS: &[&str] = &[
    "description",
    "title",
    "default",
];
90
/// Selects which set of JSON Schema keywords the cleaner strips.
///
/// The concrete keyword list for each variant is returned by
/// `CleaningStrategy::unsupported_keywords`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CleaningStrategy {
    /// Strips every keyword listed in `GEMINI_UNSUPPORTED_KEYWORDS`.
    Gemini,
    /// Strips only `$ref`, `$defs` and `definitions`.
    Anthropic,
    /// Strips no keywords.
    OpenAI,
    /// Strips `$ref`, `$defs`, `definitions` and `additionalProperties`.
    Conservative,
}
103
104impl CleaningStrategy {
105 pub fn unsupported_keywords(self) -> &'static [&'static str] {
107 match self {
108 Self::Gemini => GEMINI_UNSUPPORTED_KEYWORDS,
109 Self::Anthropic => &["$ref", "$defs", "definitions"], Self::OpenAI => &[], Self::Conservative => &["$ref", "$defs", "definitions", "additionalProperties"],
112 }
113 }
114}
115
/// Stateless namespace for the schema-cleaning routines.
///
/// All functionality is exposed as associated functions (none take `self`),
/// so the type never needs to be instantiated.
pub struct SchemaCleanr;
118
119impl SchemaCleanr {
120 pub fn clean_for_gemini(schema: Value) -> Value {
125 Self::clean(schema, CleaningStrategy::Gemini)
126 }
127
128 pub fn clean_for_anthropic(schema: Value) -> Value {
130 Self::clean(schema, CleaningStrategy::Anthropic)
131 }
132
133 pub fn clean_for_openai(schema: Value) -> Value {
135 Self::clean(schema, CleaningStrategy::OpenAI)
136 }
137
138 pub fn clean(schema: Value, strategy: CleaningStrategy) -> Value {
140 let defs = if let Some(obj) = schema.as_object() {
142 Self::extract_defs(obj)
143 } else {
144 HashMap::new()
145 };
146
147 Self::clean_with_defs(schema, &defs, strategy, &mut HashSet::new())
148 }
149
150 pub fn validate(schema: &Value) -> anyhow::Result<()> {
154 let obj = schema
155 .as_object()
156 .ok_or_else(|| anyhow::anyhow!("Schema must be an object"))?;
157
158 if !obj.contains_key("type") {
160 anyhow::bail!("Schema missing required 'type' field");
161 }
162
163 if let Some(Value::String(t)) = obj.get("type") {
165 if t == "object" && !obj.contains_key("properties") {
166 tracing::warn!("Object schema without 'properties' field may cause issues");
167 }
168 }
169
170 Ok(())
171 }
172
173 fn extract_defs(obj: &Map<String, Value>) -> HashMap<String, Value> {
179 let mut defs = HashMap::new();
180
181 if let Some(Value::Object(defs_obj)) = obj.get("$defs") {
183 for (key, value) in defs_obj {
184 defs.insert(key.clone(), value.clone());
185 }
186 }
187
188 if let Some(Value::Object(defs_obj)) = obj.get("definitions") {
190 for (key, value) in defs_obj {
191 defs.insert(key.clone(), value.clone());
192 }
193 }
194
195 defs
196 }
197
198 fn clean_with_defs(
200 schema: Value,
201 defs: &HashMap<String, Value>,
202 strategy: CleaningStrategy,
203 ref_stack: &mut HashSet<String>,
204 ) -> Value {
205 match schema {
206 Value::Object(obj) => Self::clean_object(obj, defs, strategy, ref_stack),
207 Value::Array(arr) => Value::Array(
208 arr.into_iter()
209 .map(|v| Self::clean_with_defs(v, defs, strategy, ref_stack))
210 .collect(),
211 ),
212 other => other,
213 }
214 }
215
216 fn clean_object(
218 obj: Map<String, Value>,
219 defs: &HashMap<String, Value>,
220 strategy: CleaningStrategy,
221 ref_stack: &mut HashSet<String>,
222 ) -> Value {
223 if let Some(Value::String(ref_value)) = obj.get("$ref") {
225 return Self::resolve_ref(ref_value, &obj, defs, strategy, ref_stack);
226 }
227
228 if obj.contains_key("anyOf") || obj.contains_key("oneOf") {
230 if let Some(simplified) = Self::try_simplify_union(&obj, defs, strategy, ref_stack) {
231 return simplified;
232 }
233 }
234
235 let mut cleaned = Map::new();
237 let unsupported: HashSet<&str> = strategy.unsupported_keywords().iter().copied().collect();
238 let has_union = obj.contains_key("anyOf") || obj.contains_key("oneOf");
239
240 for (key, value) in obj {
241 if unsupported.contains(key.as_str()) {
243 continue;
244 }
245
246 match key.as_str() {
248 "const" => {
250 cleaned.insert("enum".to_string(), json!([value]));
251 }
252 "type" if has_union => {
254 }
256 "type" if matches!(value, Value::Array(_)) => {
258 let cleaned_value = Self::clean_type_array(value);
259 cleaned.insert(key, cleaned_value);
260 }
261 "properties" => {
263 let cleaned_value = Self::clean_properties(value, defs, strategy, ref_stack);
264 cleaned.insert(key, cleaned_value);
265 }
266 "items" => {
267 let cleaned_value = Self::clean_with_defs(value, defs, strategy, ref_stack);
268 cleaned.insert(key, cleaned_value);
269 }
270 "anyOf" | "oneOf" | "allOf" => {
271 let cleaned_value = Self::clean_union(value, defs, strategy, ref_stack);
272 cleaned.insert(key, cleaned_value);
273 }
274 _ => {
276 let cleaned_value = match value {
277 Value::Object(_) | Value::Array(_) => {
278 Self::clean_with_defs(value, defs, strategy, ref_stack)
279 }
280 other => other,
281 };
282 cleaned.insert(key, cleaned_value);
283 }
284 }
285 }
286
287 Value::Object(cleaned)
288 }
289
290 fn resolve_ref(
292 ref_value: &str,
293 obj: &Map<String, Value>,
294 defs: &HashMap<String, Value>,
295 strategy: CleaningStrategy,
296 ref_stack: &mut HashSet<String>,
297 ) -> Value {
298 if ref_stack.contains(ref_value) {
300 tracing::warn!("Circular $ref detected: {}", ref_value);
301 return Self::preserve_meta(obj, Value::Object(Map::new()));
302 }
303
304 if let Some(def_name) = Self::parse_local_ref(ref_value) {
306 if let Some(definition) = defs.get(def_name.as_str()) {
307 ref_stack.insert(ref_value.to_string());
308 let cleaned = Self::clean_with_defs(definition.clone(), defs, strategy, ref_stack);
309 ref_stack.remove(ref_value);
310 return Self::preserve_meta(obj, cleaned);
311 }
312 }
313
314 tracing::warn!("Cannot resolve $ref: {}", ref_value);
316 Self::preserve_meta(obj, Value::Object(Map::new()))
317 }
318
319 fn parse_local_ref(ref_value: &str) -> Option<String> {
321 ref_value
322 .strip_prefix("#/$defs/")
323 .or_else(|| ref_value.strip_prefix("#/definitions/"))
324 .map(Self::decode_json_pointer)
325 }
326
327 fn decode_json_pointer(segment: &str) -> String {
329 if !segment.contains('~') {
330 return segment.to_string();
331 }
332
333 let mut decoded = String::with_capacity(segment.len());
334 let mut chars = segment.chars().peekable();
335
336 while let Some(ch) = chars.next() {
337 if ch == '~' {
338 match chars.peek().copied() {
339 Some('0') => {
340 chars.next();
341 decoded.push('~');
342 }
343 Some('1') => {
344 chars.next();
345 decoded.push('/');
346 }
347 _ => decoded.push('~'),
348 }
349 } else {
350 decoded.push(ch);
351 }
352 }
353
354 decoded
355 }
356
357 fn try_simplify_union(
359 obj: &Map<String, Value>,
360 defs: &HashMap<String, Value>,
361 strategy: CleaningStrategy,
362 ref_stack: &mut HashSet<String>,
363 ) -> Option<Value> {
364 let union_key = if obj.contains_key("anyOf") {
365 "anyOf"
366 } else if obj.contains_key("oneOf") {
367 "oneOf"
368 } else {
369 return None;
370 };
371
372 let variants = obj.get(union_key)?.as_array()?;
373
374 let cleaned_variants: Vec<Value> = variants
376 .iter()
377 .map(|v| Self::clean_with_defs(v.clone(), defs, strategy, ref_stack))
378 .collect();
379
380 let non_null: Vec<Value> = cleaned_variants
382 .into_iter()
383 .filter(|v| !Self::is_null_schema(v))
384 .collect();
385
386 if non_null.len() == 1 {
388 return Some(Self::preserve_meta(obj, non_null[0].clone()));
389 }
390
391 if let Some(enum_value) = Self::try_flatten_literal_union(&non_null) {
393 return Some(Self::preserve_meta(obj, enum_value));
394 }
395
396 None
397 }
398
399 fn is_null_schema(value: &Value) -> bool {
401 if let Some(obj) = value.as_object() {
402 if let Some(Value::Null) = obj.get("const") {
404 return true;
405 }
406 if let Some(Value::Array(arr)) = obj.get("enum") {
408 if arr.len() == 1 && matches!(arr[0], Value::Null) {
409 return true;
410 }
411 }
412 if let Some(Value::String(t)) = obj.get("type") {
414 if t == "null" {
415 return true;
416 }
417 }
418 }
419 false
420 }
421
422 fn try_flatten_literal_union(variants: &[Value]) -> Option<Value> {
426 if variants.is_empty() {
427 return None;
428 }
429
430 let mut all_values = Vec::new();
431 let mut common_type: Option<String> = None;
432
433 for variant in variants {
434 let obj = variant.as_object()?;
435
436 let literal_value = if let Some(const_val) = obj.get("const") {
438 const_val.clone()
439 } else if let Some(Value::Array(arr)) = obj.get("enum") {
440 if arr.len() == 1 {
441 arr[0].clone()
442 } else {
443 return None;
444 }
445 } else {
446 return None;
447 };
448
449 let variant_type = obj.get("type")?.as_str()?;
451 match &common_type {
452 None => common_type = Some(variant_type.to_string()),
453 Some(t) if t != variant_type => return None,
454 _ => {}
455 }
456
457 all_values.push(literal_value);
458 }
459
460 common_type.map(|t| {
461 json!({
462 "type": t,
463 "enum": all_values
464 })
465 })
466 }
467
468 fn clean_type_array(value: Value) -> Value {
470 if let Value::Array(types) = value {
471 let non_null: Vec<Value> = types
472 .into_iter()
473 .filter(|v| v.as_str() != Some("null"))
474 .collect();
475
476 match non_null.len() {
477 0 => Value::String("null".to_string()),
478 1 => non_null
479 .into_iter()
480 .next()
481 .unwrap_or(Value::String("null".to_string())),
482 _ => Value::Array(non_null),
483 }
484 } else {
485 value
486 }
487 }
488
489 fn clean_properties(
491 value: Value,
492 defs: &HashMap<String, Value>,
493 strategy: CleaningStrategy,
494 ref_stack: &mut HashSet<String>,
495 ) -> Value {
496 if let Value::Object(props) = value {
497 let cleaned: Map<String, Value> = props
498 .into_iter()
499 .map(|(k, v)| (k, Self::clean_with_defs(v, defs, strategy, ref_stack)))
500 .collect();
501 Value::Object(cleaned)
502 } else {
503 value
504 }
505 }
506
507 fn clean_union(
509 value: Value,
510 defs: &HashMap<String, Value>,
511 strategy: CleaningStrategy,
512 ref_stack: &mut HashSet<String>,
513 ) -> Value {
514 if let Value::Array(variants) = value {
515 let cleaned: Vec<Value> = variants
516 .into_iter()
517 .map(|v| Self::clean_with_defs(v, defs, strategy, ref_stack))
518 .collect();
519 Value::Array(cleaned)
520 } else {
521 value
522 }
523 }
524
525 fn preserve_meta(source: &Map<String, Value>, mut target: Value) -> Value {
527 if let Value::Object(target_obj) = &mut target {
528 for &key in SCHEMA_META_KEYS {
529 if let Some(value) = source.get(key) {
530 target_obj.insert(key.to_string(), value.clone());
531 }
532 }
533 }
534 target
535 }
536}
537
#[cfg(test)]
mod tests {
    use super::*;

    /// Gemini drops validation keywords but keeps `type` and `description`.
    #[test]
    fn test_remove_unsupported_keywords() {
        let schema = json!({
            "type": "string",
            "minLength": 1,
            "maxLength": 100,
            "pattern": "^[a-z]+$",
            "description": "A lowercase string"
        });

        let cleaned = SchemaCleanr::clean_for_gemini(schema);

        assert_eq!(cleaned["type"], "string");
        assert_eq!(cleaned["description"], "A lowercase string");
        assert!(cleaned.get("minLength").is_none());
        assert!(cleaned.get("maxLength").is_none());
        assert!(cleaned.get("pattern").is_none());
    }

    /// A `$ref` into `$defs` is inlined, cleaned, and `$defs` itself removed.
    #[test]
    fn test_resolve_ref() {
        let schema = json!({
            "type": "object",
            "properties": {
                "age": {
                    "$ref": "#/$defs/Age"
                }
            },
            "$defs": {
                "Age": {
                    "type": "integer",
                    "minimum": 0
                }
            }
        });

        let cleaned = SchemaCleanr::clean_for_gemini(schema);

        assert_eq!(cleaned["properties"]["age"]["type"], "integer");
        assert!(cleaned["properties"]["age"].get("minimum").is_none());
        assert!(cleaned.get("$defs").is_none());
    }

    /// A union of same-typed `const` variants collapses into one `enum`.
    #[test]
    fn test_flatten_literal_union() {
        let schema = json!({
            "anyOf": [
                { "const": "admin", "type": "string" },
                { "const": "user", "type": "string" },
                { "const": "guest", "type": "string" }
            ]
        });

        let cleaned = SchemaCleanr::clean_for_gemini(schema);

        assert_eq!(cleaned["type"], "string");
        assert!(cleaned["enum"].is_array());
        let enum_values = cleaned["enum"].as_array().unwrap();
        assert_eq!(enum_values.len(), 3);
        assert!(enum_values.contains(&json!("admin")));
        assert!(enum_values.contains(&json!("user")));
        assert!(enum_values.contains(&json!("guest")));
    }

    /// A two-variant union with a null variant collapses to the other variant.
    #[test]
    fn test_strip_null_from_union() {
        let schema = json!({
            "oneOf": [
                { "type": "string" },
                { "type": "null" }
            ]
        });

        let cleaned = SchemaCleanr::clean_for_gemini(schema);

        assert_eq!(cleaned["type"], "string");
        assert!(cleaned.get("oneOf").is_none());
    }

    /// `const` is rewritten to a single-element `enum`.
    #[test]
    fn test_const_to_enum() {
        let schema = json!({
            "const": "fixed_value",
            "description": "A constant"
        });

        let cleaned = SchemaCleanr::clean_for_gemini(schema);

        assert_eq!(cleaned["enum"], json!(["fixed_value"]));
        assert_eq!(cleaned["description"], "A constant");
        assert!(cleaned.get("const").is_none());
    }

    /// Metadata next to a `$ref` survives the reference being inlined.
    #[test]
    fn test_preserve_metadata() {
        let schema = json!({
            "$ref": "#/$defs/Name",
            "description": "User's name",
            "title": "Name Field",
            "default": "Anonymous",
            "$defs": {
                "Name": {
                    "type": "string"
                }
            }
        });

        let cleaned = SchemaCleanr::clean_for_gemini(schema);

        assert_eq!(cleaned["type"], "string");
        assert_eq!(cleaned["description"], "User's name");
        assert_eq!(cleaned["title"], "Name Field");
        assert_eq!(cleaned["default"], "Anonymous");
    }

    /// A self-referencing definition is expanded once and then cut off.
    #[test]
    fn test_circular_ref_prevention() {
        let schema = json!({
            "type": "object",
            "properties": {
                "parent": {
                    "$ref": "#/$defs/Node"
                }
            },
            "$defs": {
                "Node": {
                    "type": "object",
                    "properties": {
                        "child": {
                            "$ref": "#/$defs/Node"
                        }
                    }
                }
            }
        });

        let cleaned = SchemaCleanr::clean_for_gemini(schema);

        assert_eq!(cleaned["properties"]["parent"]["type"], "object");
        // The recursive reference inside `Node` must be replaced by an empty
        // schema instead of being expanded again.
        assert_eq!(
            cleaned["properties"]["parent"]["properties"]["child"],
            json!({})
        );
    }

    /// `validate` accepts a typed object schema and rejects one without `type`.
    #[test]
    fn test_validate_schema() {
        let valid = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            }
        });

        assert!(SchemaCleanr::validate(&valid).is_ok());

        let invalid = json!({
            "properties": {
                "name": { "type": "string" }
            }
        });

        assert!(SchemaCleanr::validate(&invalid).is_err());
    }

    /// Gemini strips `minLength`; OpenAI keeps it.
    #[test]
    fn test_strategy_differences() {
        let schema = json!({
            "type": "string",
            "minLength": 1,
            "description": "A string field"
        });

        let gemini = SchemaCleanr::clean_for_gemini(schema.clone());
        assert!(gemini.get("minLength").is_none());
        assert_eq!(gemini["type"], "string");
        assert_eq!(gemini["description"], "A string field");

        let openai = SchemaCleanr::clean_for_openai(schema);
        assert_eq!(openai["minLength"], 1);
        assert_eq!(openai["type"], "string");
    }

    /// Conservative strips `additionalProperties` but keeps other constraints.
    #[test]
    fn test_conservative_strategy() {
        let schema = json!({
            "type": "object",
            "additionalProperties": false,
            "minProperties": 1
        });

        let cleaned = SchemaCleanr::clean(schema, CleaningStrategy::Conservative);

        assert_eq!(cleaned["type"], "object");
        assert_eq!(cleaned["minProperties"], 1);
        assert!(cleaned.get("additionalProperties").is_none());
    }

    /// Unsupported keywords are removed at every nesting level.
    #[test]
    fn test_nested_properties() {
        let schema = json!({
            "type": "object",
            "properties": {
                "user": {
                    "type": "object",
                    "properties": {
                        "name": {
                            "type": "string",
                            "minLength": 1
                        }
                    },
                    "additionalProperties": false
                }
            }
        });

        let cleaned = SchemaCleanr::clean_for_gemini(schema);

        assert!(
            cleaned["properties"]["user"]["properties"]["name"]
                .get("minLength")
                .is_none()
        );
        assert!(
            cleaned["properties"]["user"]
                .get("additionalProperties")
                .is_none()
        );
    }

    /// `["string", "null"]` collapses to plain `"string"`.
    #[test]
    fn test_type_array_null_removal() {
        let schema = json!({
            "type": ["string", "null"]
        });

        let cleaned = SchemaCleanr::clean_for_gemini(schema);

        assert_eq!(cleaned["type"], "string");
    }

    /// A type array containing only `"null"` becomes the string `"null"`.
    #[test]
    fn test_type_array_only_null_preserved() {
        let schema = json!({
            "type": ["null"]
        });

        let cleaned = SchemaCleanr::clean_for_gemini(schema);

        assert_eq!(cleaned["type"], "null");
    }

    /// `~1` in a `$ref` is decoded to `/` before the definition lookup.
    #[test]
    fn test_ref_with_json_pointer_escape() {
        let schema = json!({
            "$ref": "#/$defs/Foo~1Bar",
            "$defs": {
                "Foo/Bar": {
                    "type": "string"
                }
            }
        });

        let cleaned = SchemaCleanr::clean_for_gemini(schema);

        assert_eq!(cleaned["type"], "string");
    }

    /// A sibling `type` is dropped when the union cannot be simplified.
    #[test]
    fn test_skip_type_when_non_simplifiable_union_exists() {
        let schema = json!({
            "type": "object",
            "oneOf": [
                {
                    "type": "object",
                    "properties": {
                        "a": { "type": "string" }
                    }
                },
                {
                    "type": "object",
                    "properties": {
                        "b": { "type": "number" }
                    }
                }
            ]
        });

        let cleaned = SchemaCleanr::clean_for_gemini(schema);

        assert!(cleaned.get("type").is_none());
        assert!(cleaned.get("oneOf").is_some());
    }

    /// Schemas nested under keywords without a dedicated arm are still cleaned.
    #[test]
    fn test_clean_nested_unknown_schema_keyword() {
        let schema = json!({
            "not": {
                "$ref": "#/$defs/Age"
            },
            "$defs": {
                "Age": {
                    "type": "integer",
                    "minimum": 0
                }
            }
        });

        let cleaned = SchemaCleanr::clean_for_gemini(schema);

        assert_eq!(cleaned["not"]["type"], "integer");
        assert!(cleaned["not"].get("minimum").is_none());
    }
}