// apcore-cli — JSON Schema $ref inliner (apcore_cli/ref_resolver.rs).
// Protocol spec: FE-08 (resolve_refs)
3
4use serde_json::{Map, Value};
5use std::collections::HashSet;
6use thiserror::Error;
7
/// Upper bound on how deep `$ref` resolution may recurse; a ready-made value
/// for the `max_depth` argument of `resolve_refs`.
pub const MAX_REF_DEPTH: usize = 32;
10
11// ---------------------------------------------------------------------------
12// Error type
13// ---------------------------------------------------------------------------
14
15/// Errors produced during `$ref` resolution.
16#[derive(Debug, Error)]
17pub enum RefResolverError {
18    /// A `$ref` target could not be found in the schema's `$defs`.
19    #[error("unresolvable $ref '{reference}' in module '{module_id}' (exit 45)")]
20    Unresolvable {
21        reference: String,
22        module_id: String,
23    },
24
25    /// A circular reference chain was detected (exit 48).
26    #[error("circular $ref detected in module '{module_id}' (exit 48)")]
27    Circular { module_id: String },
28
29    /// The maximum recursion depth was exceeded.
30    #[error("$ref resolution exceeded max depth {max_depth} in module '{module_id}'")]
31    MaxDepthExceeded { max_depth: usize, module_id: String },
32}
33
34// ---------------------------------------------------------------------------
35// resolve_refs
36// ---------------------------------------------------------------------------
37
38/// Inline all `$ref` pointers in a JSON Schema value.
39///
40/// Resolves `$ref` values by looking them up in `schema["$defs"]` and
41/// substituting the referenced schema in-place. Handles nested schemas
42/// recursively up to `max_depth`.
43///
44/// # Arguments
45/// * `schema`    — JSON Schema value (deep-copy is used internally)
46/// * `max_depth` — maximum recursion depth before raising `MaxDepthExceeded`
47/// * `module_id` — module identifier for error messages
48///
49/// # Errors
50/// * `RefResolverError::Unresolvable` — unknown `$ref` target (exit 45)
51/// * `RefResolverError::Circular`     — circular reference (exit 48)
52/// * `RefResolverError::MaxDepthExceeded` — depth limit reached
53pub fn resolve_refs(
54    schema: &Value,
55    max_depth: usize,
56    module_id: &str,
57) -> Result<Value, RefResolverError> {
58    // Deep-copy; do not modify the caller's value.
59    let copy = schema.clone();
60
61    // Extract $defs / definitions ($defs takes precedence).
62    let defs: Map<String, Value> = copy
63        .get("$defs")
64        .or_else(|| copy.get("definitions"))
65        .and_then(|v| v.as_object())
66        .cloned()
67        .unwrap_or_default();
68
69    let mut visiting: HashSet<String> = HashSet::new();
70    let resolved = resolve_node(copy, &defs, 0, max_depth, &mut visiting, module_id)?;
71
72    // Strip definition keys from result.
73    let mut result = resolved;
74    if let Some(obj) = result.as_object_mut() {
75        obj.remove("$defs");
76        obj.remove("definitions");
77    }
78    Ok(result)
79}
80
81// ---------------------------------------------------------------------------
82// Composition helpers
83// ---------------------------------------------------------------------------
84
85/// Merge all branches for allOf: union properties (later wins on conflict),
86/// concatenate required arrays.
87fn merge_allof(branches: Vec<Value>) -> Value {
88    let mut merged_props = Map::new();
89    let mut merged_required: Vec<Value> = Vec::new();
90
91    for branch in branches {
92        if let Some(props) = branch.get("properties").and_then(|v| v.as_object()) {
93            for (k, v) in props {
94                merged_props.insert(k.clone(), v.clone());
95            }
96        }
97        if let Some(req) = branch.get("required").and_then(|v| v.as_array()) {
98            merged_required.extend(req.iter().cloned());
99        }
100    }
101
102    let mut result = Map::new();
103    result.insert("properties".to_string(), Value::Object(merged_props));
104    result.insert("required".to_string(), Value::Array(merged_required));
105    Value::Object(result)
106}
107
108/// Compute the intersection of required field sets across branches.
109fn intersect_required_sets(sets: Vec<HashSet<String>>) -> Vec<Value> {
110    if sets.is_empty() {
111        return Vec::new();
112    }
113    let mut iter = sets.into_iter();
114    let first = iter.next().unwrap();
115    iter.fold(first, |acc, set| acc.intersection(&set).cloned().collect())
116        .into_iter()
117        .map(Value::String)
118        .collect()
119}
120
121/// Merge all branches for anyOf/oneOf: union properties, required = intersection.
122fn merge_anyof(branches: Vec<Value>) -> Value {
123    let mut merged_props = Map::new();
124    let mut all_required_sets: Vec<HashSet<String>> = Vec::new();
125
126    for branch in branches {
127        if let Some(props) = branch.get("properties").and_then(|v| v.as_object()) {
128            for (k, v) in props {
129                merged_props.insert(k.clone(), v.clone());
130            }
131        }
132        let set: HashSet<String> = branch
133            .get("required")
134            .and_then(|v| v.as_array())
135            .map(|arr| {
136                arr.iter()
137                    .filter_map(|v| v.as_str().map(str::to_string))
138                    .collect()
139            })
140            .unwrap_or_default();
141        all_required_sets.push(set);
142    }
143
144    let intersection = intersect_required_sets(all_required_sets);
145
146    let mut result = Map::new();
147    result.insert("properties".to_string(), Value::Object(merged_props));
148    result.insert("required".to_string(), Value::Array(intersection));
149    Value::Object(result)
150}
151
152// ---------------------------------------------------------------------------
153// resolve_node (private helper)
154// ---------------------------------------------------------------------------
155
156fn resolve_node(
157    node: Value,
158    defs: &Map<String, Value>,
159    depth: usize,
160    max_depth: usize,
161    visiting: &mut HashSet<String>,
162    module_id: &str,
163) -> Result<Value, RefResolverError> {
164    let obj = match node {
165        Value::Object(map) => map,
166        other => return Ok(other),
167    };
168
169    // Handle $ref substitution.
170    if let Some(ref_val) = obj.get("$ref") {
171        let ref_path = ref_val.as_str().unwrap_or("").to_string();
172
173        if depth >= max_depth {
174            return Err(RefResolverError::MaxDepthExceeded {
175                max_depth,
176                module_id: module_id.to_string(),
177            });
178        }
179
180        if visiting.contains(&ref_path) {
181            return Err(RefResolverError::Circular {
182                module_id: module_id.to_string(),
183            });
184        }
185
186        // Extract key: "#/$defs/Address" → "Address"
187        let key = ref_path.split('/').next_back().unwrap_or("").to_string();
188
189        let def = defs
190            .get(&key)
191            .cloned()
192            .ok_or_else(|| RefResolverError::Unresolvable {
193                reference: ref_path.clone(),
194                module_id: module_id.to_string(),
195            })?;
196
197        visiting.insert(ref_path.clone());
198        let result = resolve_node(def, defs, depth + 1, max_depth, visiting, module_id)?;
199        // Keep ref_path in visiting for the duration of this chain to detect cycles.
200        // It remains in visiting intentionally — siblings go through a fresh chain
201        // because we only remove entries when unwinding past the insertion point.
202        // However, for sibling $refs (two different properties referencing the same def),
203        // we must remove the entry after resolving so they don't block each other.
204        visiting.remove(&ref_path);
205        return Ok(result);
206    }
207
208    // Handle allOf: merge properties (later wins), concatenate required.
209    if obj.contains_key("allOf") {
210        let sub_schemas = obj
211            .get("allOf")
212            .and_then(|v| v.as_array())
213            .cloned()
214            .unwrap_or_default();
215
216        // Resolve each branch first (handles nested $refs).
217        let mut resolved_branches = Vec::with_capacity(sub_schemas.len());
218        for sub in sub_schemas {
219            let resolved_sub = resolve_node(sub, defs, depth + 1, max_depth, visiting, module_id)?;
220            resolved_branches.push(resolved_sub);
221        }
222
223        let merged = merge_allof(resolved_branches);
224        let merged_map = match merged {
225            Value::Object(m) => m,
226            _ => Map::new(),
227        };
228
229        // Carry over non-composition keys from the parent node.
230        let mut result_map = merged_map;
231        for (k, v) in &obj {
232            if k != "allOf" && !result_map.contains_key(k) {
233                result_map.insert(k.clone(), v.clone());
234            }
235        }
236        return Ok(Value::Object(result_map));
237    }
238
239    // Handle anyOf / oneOf (same merge logic, intersection of required).
240    for keyword in &["anyOf", "oneOf"] {
241        if obj.contains_key(*keyword) {
242            let sub_schemas = obj
243                .get(*keyword)
244                .and_then(|v| v.as_array())
245                .cloned()
246                .unwrap_or_default();
247
248            let mut resolved_branches = Vec::with_capacity(sub_schemas.len());
249            for sub in sub_schemas {
250                let resolved_sub =
251                    resolve_node(sub, defs, depth + 1, max_depth, visiting, module_id)?;
252                resolved_branches.push(resolved_sub);
253            }
254
255            let merged = merge_anyof(resolved_branches);
256            let merged_map = match merged {
257                Value::Object(m) => m,
258                _ => Map::new(),
259            };
260
261            let mut result_map = merged_map;
262            for (k, v) in &obj {
263                if k != *keyword && !result_map.contains_key(k) {
264                    result_map.insert(k.clone(), v.clone());
265                }
266            }
267            return Ok(Value::Object(result_map));
268        }
269    }
270
271    // Recursively resolve all values in the object map.
272    let mut resolved_map = Map::with_capacity(obj.len());
273    for (k, v) in obj {
274        let resolved_v = resolve_node(v, defs, depth, max_depth, visiting, module_id)?;
275        resolved_map.insert(k, resolved_v);
276    }
277
278    Ok(Value::Object(resolved_map))
279}
280
// ---------------------------------------------------------------------------
// Unit tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Collect a resolved schema's `required` array as string slices.
    fn required_names(schema: &Value) -> Vec<&str> {
        schema["required"]
            .as_array()
            .expect("`required` must be an array")
            .iter()
            .filter_map(Value::as_str)
            .collect()
    }

    #[test]
    fn test_resolve_refs_no_refs_unchanged() {
        // A schema without any $ref must be returned unchanged.
        let schema = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            }
        });
        let resolved =
            resolve_refs(&schema, MAX_REF_DEPTH, "test.module").expect("plain schema must resolve");
        assert_eq!(resolved, schema);
    }

    #[test]
    fn test_resolve_refs_scalar_root_unchanged() {
        // Boolean / scalar schemas have nothing to resolve.
        for schema in [json!(true), json!(false), json!("x"), json!(null)] {
            let resolved = resolve_refs(&schema, MAX_REF_DEPTH, "test.module")
                .expect("scalar schema must resolve");
            assert_eq!(resolved, schema);
        }
    }

    #[test]
    fn test_resolve_refs_simple_ref() {
        // A single $ref must be inlined from $defs.
        let schema = json!({
            "$defs": {
                "MyString": {"type": "string", "description": "A name"}
            },
            "type": "object",
            "properties": {
                "name": {"$ref": "#/$defs/MyString"}
            }
        });
        let resolved =
            resolve_refs(&schema, MAX_REF_DEPTH, "test.module").expect("simple ref must resolve");
        assert_eq!(resolved["properties"]["name"]["type"], "string");
        assert_eq!(resolved["properties"]["name"]["description"], "A name");
        // $defs must be stripped from result.
        assert!(resolved.get("$defs").is_none());
    }

    #[test]
    fn test_resolve_refs_definitions_key_also_supported() {
        // Some schemas use "definitions" instead of "$defs".
        let schema = json!({
            "definitions": {
                "Addr": {"type": "string"}
            },
            "properties": {
                "city": {"$ref": "#/definitions/Addr"}
            }
        });
        let resolved = resolve_refs(&schema, MAX_REF_DEPTH, "test.module")
            .expect("definitions ref must resolve");
        assert_eq!(resolved["properties"]["city"]["type"], "string");
        assert!(resolved.get("definitions").is_none());
    }

    #[test]
    fn test_resolve_refs_chained_defs() {
        // A definition that itself references another definition must be fully inlined.
        let schema = json!({
            "$defs": {
                "Address": {
                    "type": "object",
                    "properties": {"city": {"$ref": "#/$defs/City"}}
                },
                "City": {"type": "string"}
            },
            "properties": {"home": {"$ref": "#/$defs/Address"}}
        });
        let resolved =
            resolve_refs(&schema, MAX_REF_DEPTH, "test.module").expect("chained refs must resolve");
        assert_eq!(
            resolved["properties"]["home"]["properties"]["city"],
            json!({"type": "string"})
        );
    }

    #[test]
    fn test_resolve_refs_unresolvable_returns_error() {
        // An unknown $ref must yield RefResolverError::Unresolvable naming the reference.
        let schema = json!({
            "type": "object",
            "properties": {
                "x": {"$ref": "#/$defs/DoesNotExist"}
            }
        });
        let result = resolve_refs(&schema, MAX_REF_DEPTH, "test.module");
        assert!(
            matches!(
                &result,
                Err(RefResolverError::Unresolvable { reference, module_id })
                    if reference == "#/$defs/DoesNotExist" && module_id == "test.module"
            ),
            "expected Unresolvable, got: {result:?}"
        );
    }

    #[test]
    fn test_resolve_refs_circular_returns_error() {
        // A circular $ref chain must yield RefResolverError::Circular or MaxDepthExceeded.
        let schema = json!({
            "$defs": {
                "A": {"$ref": "#/$defs/B"},
                "B": {"$ref": "#/$defs/A"}
            },
            "properties": {
                "x": {"$ref": "#/$defs/A"}
            }
        });
        let result = resolve_refs(&schema, MAX_REF_DEPTH, "test.module");
        assert!(
            matches!(
                result,
                Err(RefResolverError::Circular { .. })
                    | Err(RefResolverError::MaxDepthExceeded { .. })
            ),
            "expected Circular or MaxDepthExceeded, got: {result:?}"
        );
    }

    #[test]
    fn test_resolve_refs_max_depth_exceeded() {
        // max_depth=0 means the first $ref hit immediately fails.
        let schema = json!({
            "$defs": {
                "Inner": {"type": "string"}
            },
            "properties": {
                "x": {"$ref": "#/$defs/Inner"}
            }
        });
        let result = resolve_refs(&schema, 0, "test.module");
        assert!(
            matches!(
                &result,
                Err(RefResolverError::MaxDepthExceeded { max_depth: 0, module_id })
                    if module_id == "test.module"
            ),
            "expected MaxDepthExceeded, got: {result:?}"
        );
    }

    #[test]
    fn test_error_messages_name_reference_module_and_exit_code() {
        let missing = json!({"properties": {"x": {"$ref": "#/$defs/Missing"}}});
        let err = resolve_refs(&missing, MAX_REF_DEPTH, "mod.x").unwrap_err();
        assert_eq!(
            err.to_string(),
            "unresolvable $ref '#/$defs/Missing' in module 'mod.x' (exit 45)"
        );

        let cyclic = json!({
            "$defs": {"Node": {"$ref": "#/$defs/Node"}},
            "properties": {"n": {"$ref": "#/$defs/Node"}}
        });
        let err = resolve_refs(&cyclic, MAX_REF_DEPTH, "mod.x").unwrap_err();
        assert_eq!(
            err.to_string(),
            "circular $ref detected in module 'mod.x' (exit 48)"
        );
    }

    #[test]
    fn test_resolve_refs_nested_defs() {
        // $refs inside nested object properties must all be resolved.
        let schema = json!({
            "$defs": {
                "City": {"type": "string"}
            },
            "properties": {
                "address": {
                    "type": "object",
                    "properties": {
                        "city": {"$ref": "#/$defs/City"}
                    }
                }
            }
        });
        let resolved =
            resolve_refs(&schema, MAX_REF_DEPTH, "test.module").expect("nested ref must resolve");
        assert_eq!(
            resolved["properties"]["address"]["properties"]["city"]["type"],
            "string"
        );
    }

    #[test]
    fn test_resolve_refs_does_not_mutate_input() {
        // The original schema must not be modified.
        let schema = json!({
            "$defs": {"T": {"type": "integer"}},
            "properties": {"x": {"$ref": "#/$defs/T"}}
        });
        let snapshot = schema.clone();
        let resolved =
            resolve_refs(&schema, MAX_REF_DEPTH, "test.module").expect("ref must resolve");
        assert_eq!(resolved["properties"]["x"]["type"], "integer");
        // The input still holds its $ref and $defs, exactly as before the call.
        assert_eq!(schema, snapshot);
    }

    #[test]
    fn test_resolve_refs_sibling_refs_same_def() {
        // Two different properties referencing the same $def must both resolve correctly.
        let schema = json!({
            "$defs": {
                "Str": {"type": "string"}
            },
            "properties": {
                "a": {"$ref": "#/$defs/Str"},
                "b": {"$ref": "#/$defs/Str"}
            }
        });
        let resolved =
            resolve_refs(&schema, MAX_REF_DEPTH, "test.module").expect("sibling refs must resolve");
        assert_eq!(resolved["properties"]["a"]["type"], "string");
        assert_eq!(resolved["properties"]["b"]["type"], "string");
    }

    // --- Schema composition tests ---

    #[test]
    fn test_allof_merges_properties() {
        let schema = json!({
            "allOf": [
                {
                    "properties": {"a": {"type": "string"}},
                    "required": ["a"]
                },
                {
                    "properties": {"b": {"type": "integer"}},
                    "required": ["b"]
                }
            ]
        });
        let result = resolve_refs(&schema, MAX_REF_DEPTH, "mod").unwrap();
        assert_eq!(result["properties"]["a"]["type"], "string");
        assert_eq!(result["properties"]["b"]["type"], "integer");
        let required = required_names(&result);
        assert!(required.contains(&"a"));
        assert!(required.contains(&"b"));
    }

    #[test]
    fn test_allof_later_schema_wins_on_conflict() {
        let schema = json!({
            "allOf": [
                {"properties": {"x": {"type": "string"}}},
                {"properties": {"x": {"type": "integer"}}}
            ]
        });
        let result = resolve_refs(&schema, MAX_REF_DEPTH, "mod").unwrap();
        // Later sub-schema wins: x must be integer.
        assert_eq!(result["properties"]["x"]["type"], "integer");
    }

    #[test]
    fn test_allof_copies_non_composition_keys() {
        let schema = json!({
            "description": "My type",
            "allOf": [
                {"properties": {"a": {"type": "string"}}}
            ]
        });
        let result = resolve_refs(&schema, MAX_REF_DEPTH, "mod").unwrap();
        // "description" must survive in the merged result.
        assert_eq!(result["description"], "My type");
    }

    #[test]
    fn test_anyof_unions_properties() {
        let schema = json!({
            "anyOf": [
                {"properties": {"a": {"type": "string"}}, "required": ["a"]},
                {"properties": {"b": {"type": "integer"}}, "required": ["b"]}
            ]
        });
        let result = resolve_refs(&schema, MAX_REF_DEPTH, "mod").unwrap();
        // Both properties must appear.
        assert!(result["properties"].get("a").is_some());
        assert!(result["properties"].get("b").is_some());
    }

    #[test]
    fn test_anyof_required_is_intersection() {
        let schema = json!({
            "anyOf": [
                {"properties": {"a": {"type": "string"}, "b": {"type": "string"}}, "required": ["a", "b"]},
                {"properties": {"a": {"type": "string"}, "c": {"type": "string"}}, "required": ["a", "c"]}
            ]
        });
        let result = resolve_refs(&schema, MAX_REF_DEPTH, "mod").unwrap();
        let required = required_names(&result);
        // Only "a" appears in both branches — it is the intersection.
        assert!(
            required.contains(&"a"),
            "a must be required (in both branches)"
        );
        assert!(
            !required.contains(&"b"),
            "b must not be required (only in first branch)"
        );
        assert!(
            !required.contains(&"c"),
            "c must not be required (only in second branch)"
        );
    }

    #[test]
    fn test_anyof_empty_required_when_no_overlap() {
        let schema = json!({
            "anyOf": [
                {"properties": {"a": {"type": "string"}}, "required": ["a"]},
                {"properties": {"b": {"type": "integer"}}, "required": ["b"]}
            ]
        });
        let result = resolve_refs(&schema, MAX_REF_DEPTH, "mod").unwrap();
        assert!(
            required_names(&result).is_empty(),
            "no fields are required in both branches"
        );
    }

    #[test]
    fn test_anyof_with_ref_branches() {
        // anyOf branches that are $refs must be inlined before merging.
        let schema = json!({
            "$defs": {
                "Card": {
                    "properties": {"id": {"type": "string"}, "pan": {"type": "string"}},
                    "required": ["id", "pan"]
                },
                "Bank": {
                    "properties": {"id": {"type": "string"}, "iban": {"type": "string"}},
                    "required": ["id", "iban"]
                }
            },
            "anyOf": [{"$ref": "#/$defs/Card"}, {"$ref": "#/$defs/Bank"}]
        });
        let result = resolve_refs(&schema, MAX_REF_DEPTH, "mod").unwrap();
        for name in ["id", "pan", "iban"] {
            assert!(
                result["properties"].get(name).is_some(),
                "missing property {name}"
            );
        }
        assert_eq!(result["required"], json!(["id"]));
        assert!(result.get("$defs").is_none());
    }

    #[test]
    fn test_oneof_behaves_like_anyof() {
        let schema = json!({
            "oneOf": [
                {"properties": {"x": {"type": "string"}}, "required": ["x"]},
                {"properties": {"y": {"type": "integer"}}, "required": ["y"]}
            ]
        });
        let result = resolve_refs(&schema, MAX_REF_DEPTH, "mod").unwrap();
        assert!(result["properties"].get("x").is_some());
        assert!(result["properties"].get("y").is_some());
        assert!(required_names(&result).is_empty());
    }

    #[test]
    fn test_allof_with_nested_ref() {
        // allOf sub-schema that itself contains a $ref.
        let schema = json!({
            "$defs": {
                "Base": {"properties": {"id": {"type": "integer"}}, "required": ["id"]}
            },
            "allOf": [
                {"$ref": "#/$defs/Base"},
                {"properties": {"name": {"type": "string"}}}
            ]
        });
        let result = resolve_refs(&schema, MAX_REF_DEPTH, "mod").unwrap();
        assert_eq!(result["properties"]["id"]["type"], "integer");
        assert_eq!(result["properties"]["name"]["type"], "string");
        assert!(required_names(&result).contains(&"id"));
    }
}