1use indexmap::IndexMap;
13use serde_json::Value;
14use std::sync::Arc;
15
16#[derive(Debug, Clone, PartialEq, Eq)]
17pub enum Shape {
18 Null,
19 Bool,
20 Int,
21 Float,
22 Str,
23 Array(Box<Shape>),
25 Object(IndexMap<Arc<str>, Shape>),
27 Union(Vec<Shape>),
29 Unknown,
30}
31
32impl Shape {
33 pub fn of(v: &Value) -> Shape {
35 match v {
36 Value::Null => Shape::Null,
37 Value::Bool(_) => Shape::Bool,
38 Value::Number(n) => if n.is_f64() { Shape::Float } else { Shape::Int },
39 Value::String(_) => Shape::Str,
40 Value::Array(a) => {
41 let mut elem = Shape::Unknown;
42 let mut first = true;
43 for item in a {
44 let s = Shape::of(item);
45 elem = if first { first = false; s } else { elem.merge(s) };
46 }
47 Shape::Array(Box::new(elem))
48 }
49 Value::Object(o) => {
50 let mut map = IndexMap::new();
51 for (k, v) in o {
52 map.insert(Arc::from(k.as_str()), Shape::of(v));
53 }
54 Shape::Object(map)
55 }
56 }
57 }
58
59 pub fn merge(self, other: Shape) -> Shape {
61 match (self, other) {
62 (a, b) if a == b => a,
63 (Shape::Unknown, x) | (x, Shape::Unknown) => x,
64 (Shape::Array(a), Shape::Array(b)) => Shape::Array(Box::new(a.merge(*b))),
65 (Shape::Object(mut a), Shape::Object(b)) => {
66 for (k, v) in b {
67 if let Some(existing) = a.shift_remove(&k) {
68 a.insert(k, existing.merge(v));
69 } else {
70 a.insert(k, v);
71 }
72 }
73 Shape::Object(a)
74 }
75 (Shape::Int, Shape::Float) | (Shape::Float, Shape::Int) => Shape::Float,
76 (Shape::Union(mut xs), y) | (y, Shape::Union(mut xs)) => {
77 if !xs.contains(&y) { xs.push(y); }
78 Shape::Union(xs)
79 }
80 (a, b) => Shape::Union(vec![a, b]),
81 }
82 }
83
84 pub fn has_field(&self, name: &str) -> bool {
86 match self {
87 Shape::Object(m) => m.contains_key(name),
88 Shape::Union(xs) => xs.iter().all(|s| s.has_field(name)),
89 _ => false,
90 }
91 }
92
93 pub fn field(&self, name: &str) -> Option<&Shape> {
95 match self {
96 Shape::Object(m) => m.get(name),
97 _ => None,
98 }
99 }
100
101 pub fn element(&self) -> Option<&Shape> {
103 match self {
104 Shape::Array(b) => Some(b),
105 _ => None,
106 }
107 }
108}
109
110use super::vm::{Program, Opcode};
113
114pub fn specialize(program: &Program, shape: &Shape) -> Program {
121 let new_ops: Vec<Opcode> = specialize_ops(&program.ops, shape);
122 Program {
123 ops: new_ops.into(),
124 source: program.source.clone(),
125 id: program.id,
126 is_structural: program.is_structural,
127 }
128}
129
130fn specialize_ops(ops: &[Opcode], shape: &Shape) -> Vec<Opcode> {
131 let mut out = Vec::with_capacity(ops.len());
132 let mut cur: Shape = shape.clone();
133 for op in ops {
134 match op {
135 Opcode::PushRoot => { cur = shape.clone(); out.push(op.clone()); }
136 Opcode::OptField(k) => {
137 if cur.has_field(k) {
138 out.push(Opcode::GetField(k.clone()));
139 } else {
140 out.push(op.clone());
141 }
142 cur = cur.field(k).cloned().unwrap_or(Shape::Unknown);
143 }
144 Opcode::GetField(k) => {
145 cur = cur.field(k).cloned().unwrap_or(Shape::Unknown);
146 out.push(op.clone());
147 }
148 Opcode::RootChain(ks) => {
149 let mut c = shape.clone();
150 for k in ks.iter() {
151 c = c.field(k).cloned().unwrap_or(Shape::Unknown);
152 }
153 cur = c;
154 out.push(op.clone());
155 }
156 Opcode::GetIndex(_) | Opcode::GetSlice(..) => {
157 cur = cur.element().cloned().unwrap_or(Shape::Unknown);
158 out.push(op.clone());
159 }
160 _ => { cur = Shape::Unknown; out.push(op.clone()); }
161 }
162 }
163 out
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vm::{Compiler, Opcode};
    use serde_json::json;

    /// JSON scalars map to the matching scalar shape.
    #[test]
    fn infer_scalar() {
        let cases = vec![
            (json!(1), Shape::Int),
            (json!(1.5), Shape::Float),
            (json!("x"), Shape::Str),
            (json!(null), Shape::Null),
        ];
        for (value, expected) in cases {
            assert_eq!(Shape::of(&value), expected);
        }
    }

    /// Object inference records exactly the keys that are present.
    #[test]
    fn infer_object() {
        let shape = Shape::of(&json!({"a": 1, "b": "x"}));
        assert!(shape.has_field("a"));
        assert!(shape.has_field("b"));
        assert!(!shape.has_field("c"));
        assert_eq!(shape.field("a"), Some(&Shape::Int));
    }

    /// An array of integers has `Int` as its element shape.
    #[test]
    fn infer_homogeneous_array() {
        let shape = Shape::of(&json!([1, 2, 3]));
        let elem = shape.element();
        assert_eq!(elem, Some(&Shape::Int));
    }

    /// Mixed element types collapse into a union containing each of them.
    #[test]
    fn infer_heterogeneous_array() {
        let shape = Shape::of(&json!([1, "two", 3]));
        let members = match shape.element().unwrap() {
            Shape::Union(xs) => xs,
            other => panic!("expected union, got {:?}", other),
        };
        assert!(members.contains(&Shape::Int));
        assert!(members.contains(&Shape::Str));
    }

    /// Merging `Int` with `Float` widens to `Float`.
    #[test]
    fn merge_int_float_to_float() {
        assert_eq!(Shape::Int.merge(Shape::Float), Shape::Float);
    }

    /// An optional access to a field the shape contains loses its `OptField`.
    #[test]
    fn specialize_opt_field_to_get_field() {
        let prog = Compiler::compile_str("$.a?.b").unwrap();
        let shape = Shape::of(&json!({"a": {"b": 1}}));
        let spec = specialize(&prog, &shape);
        let opt_free = spec
            .ops
            .iter()
            .all(|o| !matches!(o, Opcode::OptField(_)));
        assert!(opt_free, "OptField should specialize to GetField");
    }

    /// An optional access to a field the shape lacks keeps its `OptField`.
    #[test]
    fn specialize_preserves_opt_when_missing() {
        let prog = Compiler::compile_str("$.a?.missing").unwrap();
        let shape = Shape::of(&json!({"a": {"b": 1}}));
        let spec = specialize(&prog, &shape);
        let kept = spec
            .ops
            .iter()
            .position(|o| matches!(o, Opcode::OptField(k) if k.as_ref() == "missing"))
            .is_some();
        assert!(kept, "OptField for absent field should remain");
    }
}