1use serde_json::{Map, Value};
5use std::collections::HashSet;
6use thiserror::Error;
7
8pub const MAX_REF_DEPTH: usize = 32;
10
11#[derive(Debug, Error)]
17pub enum RefResolverError {
18 #[error("unresolvable $ref '{reference}' in module '{module_id}' (exit 45)")]
20 Unresolvable {
21 reference: String,
22 module_id: String,
23 },
24
25 #[error("circular $ref detected in module '{module_id}' (exit 48)")]
27 Circular { module_id: String },
28
29 #[error("$ref resolution exceeded max depth {max_depth} in module '{module_id}'")]
31 MaxDepthExceeded { max_depth: usize, module_id: String },
32}
33
34pub fn resolve_refs(
54 schema: &Value,
55 max_depth: usize,
56 module_id: &str,
57) -> Result<Value, RefResolverError> {
58 let copy = schema.clone();
60
61 let defs: Map<String, Value> = copy
63 .get("$defs")
64 .or_else(|| copy.get("definitions"))
65 .and_then(|v| v.as_object())
66 .cloned()
67 .unwrap_or_default();
68
69 let mut visiting: HashSet<String> = HashSet::new();
70 let resolved = resolve_node(copy, &defs, 0, max_depth, &mut visiting, module_id)?;
71
72 let mut result = resolved;
74 if let Some(obj) = result.as_object_mut() {
75 obj.remove("$defs");
76 obj.remove("definitions");
77 }
78 Ok(result)
79}
80
81fn merge_allof(branches: Vec<Value>) -> Value {
88 let mut merged_props = Map::new();
89 let mut merged_required: Vec<Value> = Vec::new();
90
91 for branch in branches {
92 if let Some(props) = branch.get("properties").and_then(|v| v.as_object()) {
93 for (k, v) in props {
94 merged_props.insert(k.clone(), v.clone());
95 }
96 }
97 if let Some(req) = branch.get("required").and_then(|v| v.as_array()) {
98 merged_required.extend(req.iter().cloned());
99 }
100 }
101
102 let mut result = Map::new();
103 result.insert("properties".to_string(), Value::Object(merged_props));
104 result.insert("required".to_string(), Value::Array(merged_required));
105 Value::Object(result)
106}
107
108fn intersect_required_sets(sets: Vec<HashSet<String>>) -> Vec<Value> {
110 if sets.is_empty() {
111 return Vec::new();
112 }
113 let mut iter = sets.into_iter();
114 let first = iter.next().unwrap();
115 iter.fold(first, |acc, set| acc.intersection(&set).cloned().collect())
116 .into_iter()
117 .map(Value::String)
118 .collect()
119}
120
121fn merge_anyof(branches: Vec<Value>) -> Value {
123 let mut merged_props = Map::new();
124 let mut all_required_sets: Vec<HashSet<String>> = Vec::new();
125
126 for branch in branches {
127 if let Some(props) = branch.get("properties").and_then(|v| v.as_object()) {
128 for (k, v) in props {
129 merged_props.insert(k.clone(), v.clone());
130 }
131 }
132 let set: HashSet<String> = branch
133 .get("required")
134 .and_then(|v| v.as_array())
135 .map(|arr| {
136 arr.iter()
137 .filter_map(|v| v.as_str().map(str::to_string))
138 .collect()
139 })
140 .unwrap_or_default();
141 all_required_sets.push(set);
142 }
143
144 let intersection = intersect_required_sets(all_required_sets);
145
146 let mut result = Map::new();
147 result.insert("properties".to_string(), Value::Object(merged_props));
148 result.insert("required".to_string(), Value::Array(intersection));
149 Value::Object(result)
150}
151
152fn resolve_node(
157 node: Value,
158 defs: &Map<String, Value>,
159 depth: usize,
160 max_depth: usize,
161 visiting: &mut HashSet<String>,
162 module_id: &str,
163) -> Result<Value, RefResolverError> {
164 let obj = match node {
165 Value::Object(map) => map,
166 other => return Ok(other),
167 };
168
169 if let Some(ref_val) = obj.get("$ref") {
171 let ref_path = ref_val.as_str().unwrap_or("").to_string();
172
173 if depth >= max_depth {
174 return Err(RefResolverError::MaxDepthExceeded {
175 max_depth,
176 module_id: module_id.to_string(),
177 });
178 }
179
180 if visiting.contains(&ref_path) {
181 return Err(RefResolverError::Circular {
182 module_id: module_id.to_string(),
183 });
184 }
185
186 let key = ref_path.split('/').next_back().unwrap_or("").to_string();
188
189 let def = defs
190 .get(&key)
191 .cloned()
192 .ok_or_else(|| RefResolverError::Unresolvable {
193 reference: ref_path.clone(),
194 module_id: module_id.to_string(),
195 })?;
196
197 visiting.insert(ref_path.clone());
198 let result = resolve_node(def, defs, depth + 1, max_depth, visiting, module_id)?;
199 visiting.remove(&ref_path);
205 return Ok(result);
206 }
207
208 if obj.contains_key("allOf") {
210 let sub_schemas = obj
211 .get("allOf")
212 .and_then(|v| v.as_array())
213 .cloned()
214 .unwrap_or_default();
215
216 let mut resolved_branches = Vec::with_capacity(sub_schemas.len());
218 for sub in sub_schemas {
219 let resolved_sub = resolve_node(sub, defs, depth + 1, max_depth, visiting, module_id)?;
220 resolved_branches.push(resolved_sub);
221 }
222
223 let merged = merge_allof(resolved_branches);
224 let merged_map = match merged {
225 Value::Object(m) => m,
226 _ => Map::new(),
227 };
228
229 let mut result_map = merged_map;
231 for (k, v) in &obj {
232 if k != "allOf" && !result_map.contains_key(k) {
233 result_map.insert(k.clone(), v.clone());
234 }
235 }
236 return Ok(Value::Object(result_map));
237 }
238
239 for keyword in &["anyOf", "oneOf"] {
241 if obj.contains_key(*keyword) {
242 let sub_schemas = obj
243 .get(*keyword)
244 .and_then(|v| v.as_array())
245 .cloned()
246 .unwrap_or_default();
247
248 let mut resolved_branches = Vec::with_capacity(sub_schemas.len());
249 for sub in sub_schemas {
250 let resolved_sub =
251 resolve_node(sub, defs, depth + 1, max_depth, visiting, module_id)?;
252 resolved_branches.push(resolved_sub);
253 }
254
255 let merged = merge_anyof(resolved_branches);
256 let merged_map = match merged {
257 Value::Object(m) => m,
258 _ => Map::new(),
259 };
260
261 let mut result_map = merged_map;
262 for (k, v) in &obj {
263 if k != *keyword && !result_map.contains_key(k) {
264 result_map.insert(k.clone(), v.clone());
265 }
266 }
267 return Ok(Value::Object(result_map));
268 }
269 }
270
271 let mut resolved_map = Map::with_capacity(obj.len());
273 for (k, v) in obj {
274 let resolved_v = resolve_node(v, defs, depth, max_depth, visiting, module_id)?;
275 resolved_map.insert(k, resolved_v);
276 }
277
278 Ok(Value::Object(resolved_map))
279}
280
// Unit tests for `resolve_refs`: plain `$ref` inlining, the three error variants, input
// immutability, and `allOf` / `anyOf` / `oneOf` flattening.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // A schema without any `$ref` or composition keyword comes back with its content intact.
    #[test]
    fn test_resolve_refs_no_refs_unchanged() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            }
        });
        let result = resolve_refs(&schema, 32, "test.module");
        assert!(result.is_ok());
        let resolved = result.unwrap();
        assert_eq!(resolved["properties"]["name"]["type"], "string");
    }

    // A `#/$defs/...` reference is replaced by the definition, and `$defs` is stripped.
    #[test]
    fn test_resolve_refs_simple_ref() {
        let schema = json!({
            "$defs": {
                "MyString": {"type": "string", "description": "A name"}
            },
            "type": "object",
            "properties": {
                "name": {"$ref": "#/$defs/MyString"}
            }
        });
        let result = resolve_refs(&schema, 32, "test.module");
        assert!(result.is_ok());
        let resolved = result.unwrap();
        assert_eq!(resolved["properties"]["name"]["type"], "string");
        assert_eq!(resolved["properties"]["name"]["description"], "A name");
        assert!(resolved.get("$defs").is_none());
    }

    // The older `definitions` section is resolved and stripped just like `$defs`.
    #[test]
    fn test_resolve_refs_definitions_key_also_supported() {
        let schema = json!({
            "definitions": {
                "Addr": {"type": "string"}
            },
            "properties": {
                "city": {"$ref": "#/definitions/Addr"}
            }
        });
        let result = resolve_refs(&schema, 32, "test.module");
        assert!(result.is_ok());
        let resolved = result.unwrap();
        assert_eq!(resolved["properties"]["city"]["type"], "string");
        assert!(resolved.get("definitions").is_none());
    }

    // A reference to a definition that does not exist is reported as `Unresolvable`.
    #[test]
    fn test_resolve_refs_unresolvable_returns_error() {
        let schema = json!({
            "type": "object",
            "properties": {
                "x": {"$ref": "#/$defs/DoesNotExist"}
            }
        });
        let result = resolve_refs(&schema, 32, "test.module");
        assert!(
            matches!(result, Err(RefResolverError::Unresolvable { .. })),
            "expected Unresolvable, got: {result:?}"
        );
    }

    // An A -> B -> A cycle must fail. The test accepts either the cycle guard or the
    // depth guard reporting it.
    #[test]
    fn test_resolve_refs_circular_returns_error() {
        let schema = json!({
            "$defs": {
                "A": {"$ref": "#/$defs/B"},
                "B": {"$ref": "#/$defs/A"}
            },
            "properties": {
                "x": {"$ref": "#/$defs/A"}
            }
        });
        let result = resolve_refs(&schema, 32, "test.module");
        assert!(
            matches!(
                result,
                Err(RefResolverError::Circular { .. })
                    | Err(RefResolverError::MaxDepthExceeded { .. })
            ),
            "expected Circular or MaxDepthExceeded, got: {result:?}"
        );
    }

    // With `max_depth = 0`, even the first `$ref` encountered is rejected.
    #[test]
    fn test_resolve_refs_max_depth_exceeded() {
        let schema = json!({
            "$defs": {
                "Inner": {"type": "string"}
            },
            "properties": {
                "x": {"$ref": "#/$defs/Inner"}
            }
        });
        let result = resolve_refs(&schema, 0, "test.module");
        assert!(
            matches!(result, Err(RefResolverError::MaxDepthExceeded { .. })),
            "expected MaxDepthExceeded, got: {result:?}"
        );
    }

    // References nested inside property schemas are resolved too.
    #[test]
    fn test_resolve_refs_nested_defs() {
        let schema = json!({
            "$defs": {
                "City": {"type": "string"}
            },
            "properties": {
                "address": {
                    "type": "object",
                    "properties": {
                        "city": {"$ref": "#/$defs/City"}
                    }
                }
            }
        });
        let result = resolve_refs(&schema, 32, "test.module");
        assert!(result.is_ok());
        let resolved = result.unwrap();
        assert_eq!(
            resolved["properties"]["address"]["properties"]["city"]["type"],
            "string"
        );
    }

    // `resolve_refs` works on a copy; the caller's schema still contains its `$ref`.
    #[test]
    fn test_resolve_refs_does_not_mutate_input() {
        let schema = json!({
            "$defs": {"T": {"type": "integer"}},
            "properties": {"x": {"$ref": "#/$defs/T"}}
        });
        let _ = resolve_refs(&schema, 32, "test.module");
        assert_eq!(schema["properties"]["x"]["$ref"], "#/$defs/T");
    }

    // Two sibling uses of the same definition are not a cycle: the path is removed from
    // the visiting set after each resolution.
    #[test]
    fn test_resolve_refs_sibling_refs_same_def() {
        let schema = json!({
            "$defs": {
                "Str": {"type": "string"}
            },
            "properties": {
                "a": {"$ref": "#/$defs/Str"},
                "b": {"$ref": "#/$defs/Str"}
            }
        });
        let result = resolve_refs(&schema, 32, "test.module");
        assert!(result.is_ok(), "sibling refs failed: {result:?}");
        let resolved = result.unwrap();
        assert_eq!(resolved["properties"]["a"]["type"], "string");
        assert_eq!(resolved["properties"]["b"]["type"], "string");
    }

    // `allOf`: properties and required names from every branch end up in the result.
    #[test]
    fn test_allof_merges_properties() {
        let schema = json!({
            "allOf": [
                {
                    "properties": {"a": {"type": "string"}},
                    "required": ["a"]
                },
                {
                    "properties": {"b": {"type": "integer"}},
                    "required": ["b"]
                }
            ]
        });
        let result = resolve_refs(&schema, 32, "mod").unwrap();
        assert_eq!(result["properties"]["a"]["type"], "string");
        assert_eq!(result["properties"]["b"]["type"], "integer");
        let required: Vec<&str> = result["required"]
            .as_array()
            .unwrap()
            .iter()
            .filter_map(|v| v.as_str())
            .collect();
        assert!(required.contains(&"a"));
        assert!(required.contains(&"b"));
    }

    // `allOf`: when two branches define the same property, the later branch's schema is kept.
    #[test]
    fn test_allof_later_schema_wins_on_conflict() {
        let schema = json!({
            "allOf": [
                {"properties": {"x": {"type": "string"}}},
                {"properties": {"x": {"type": "integer"}}}
            ]
        });
        let result = resolve_refs(&schema, 32, "mod").unwrap();
        assert_eq!(result["properties"]["x"]["type"], "integer");
    }

    // `allOf`: keys next to the composition keyword (here `description`) survive the merge.
    #[test]
    fn test_allof_copies_non_composition_keys() {
        let schema = json!({
            "description": "My type",
            "allOf": [
                {"properties": {"a": {"type": "string"}}}
            ]
        });
        let result = resolve_refs(&schema, 32, "mod").unwrap();
        assert_eq!(result["description"], "My type");
    }

    // `anyOf`: properties from all branches are present in the result.
    #[test]
    fn test_anyof_unions_properties() {
        let schema = json!({
            "anyOf": [
                {"properties": {"a": {"type": "string"}}, "required": ["a"]},
                {"properties": {"b": {"type": "integer"}}, "required": ["b"]}
            ]
        });
        let result = resolve_refs(&schema, 32, "mod").unwrap();
        assert!(result["properties"].get("a").is_some());
        assert!(result["properties"].get("b").is_some());
    }

    // `anyOf`: only names required by every branch stay required.
    #[test]
    fn test_anyof_required_is_intersection() {
        let schema = json!({
            "anyOf": [
                {"properties": {"a": {"type": "string"}, "b": {"type": "string"}}, "required": ["a", "b"]},
                {"properties": {"a": {"type": "string"}, "c": {"type": "string"}}, "required": ["a", "c"]}
            ]
        });
        let result = resolve_refs(&schema, 32, "mod").unwrap();
        let required: Vec<&str> = result["required"]
            .as_array()
            .unwrap()
            .iter()
            .filter_map(|v| v.as_str())
            .collect();
        assert!(
            required.contains(&"a"),
            "a must be required (in both branches)"
        );
        assert!(
            !required.contains(&"b"),
            "b must not be required (only in first branch)"
        );
        assert!(
            !required.contains(&"c"),
            "c must not be required (only in second branch)"
        );
    }

    // `anyOf`: disjoint required sets intersect to an empty (but present) array.
    #[test]
    fn test_anyof_empty_required_when_no_overlap() {
        let schema = json!({
            "anyOf": [
                {"properties": {"a": {"type": "string"}}, "required": ["a"]},
                {"properties": {"b": {"type": "integer"}}, "required": ["b"]}
            ]
        });
        let result = resolve_refs(&schema, 32, "mod").unwrap();
        let required = result["required"].as_array().unwrap();
        assert!(
            required.is_empty(),
            "no fields are required in both branches"
        );
    }

    // `oneOf` is flattened with the same rules as `anyOf`.
    #[test]
    fn test_oneof_behaves_like_anyof() {
        let schema = json!({
            "oneOf": [
                {"properties": {"x": {"type": "string"}}, "required": ["x"]},
                {"properties": {"y": {"type": "integer"}}, "required": ["y"]}
            ]
        });
        let result = resolve_refs(&schema, 32, "mod").unwrap();
        assert!(result["properties"].get("x").is_some());
        assert!(result["properties"].get("y").is_some());
        assert!(result["required"].as_array().unwrap().is_empty());
    }

    // A `$ref` used as an `allOf` branch is resolved before the branches are merged.
    #[test]
    fn test_allof_with_nested_ref() {
        let schema = json!({
            "$defs": {
                "Base": {"properties": {"id": {"type": "integer"}}, "required": ["id"]}
            },
            "allOf": [
                {"$ref": "#/$defs/Base"},
                {"properties": {"name": {"type": "string"}}}
            ]
        });
        let result = resolve_refs(&schema, 32, "mod").unwrap();
        assert_eq!(result["properties"]["id"]["type"], "integer");
        assert_eq!(result["properties"]["name"]["type"], "string");
        let required: Vec<&str> = result["required"]
            .as_array()
            .unwrap()
            .iter()
            .filter_map(|v| v.as_str())
            .collect();
        assert!(required.contains(&"id"));
    }
}