1use indexmap::IndexMap;
13use serde_json::Value;
14use std::sync::Arc;
15
16#[derive(Debug, Clone, PartialEq, Eq)]
17pub enum Shape {
18 Null,
19 Bool,
20 Int,
21 Float,
22 Str,
23 Array(Box<Shape>),
25 Object(IndexMap<Arc<str>, Shape>),
27 Union(Vec<Shape>),
29 Unknown,
30}
31
32impl Shape {
33 pub fn of(v: &Value) -> Shape {
35 match v {
36 Value::Null => Shape::Null,
37 Value::Bool(_) => Shape::Bool,
38 Value::Number(n) => if n.is_f64() { Shape::Float } else { Shape::Int },
39 Value::String(_) => Shape::Str,
40 Value::Array(a) => {
41 let mut elem = Shape::Unknown;
42 let mut first = true;
43 for item in a {
44 let s = Shape::of(item);
45 elem = if first { first = false; s } else { elem.merge(s) };
46 }
47 Shape::Array(Box::new(elem))
48 }
49 Value::Object(o) => {
50 let mut map = IndexMap::new();
51 for (k, v) in o {
52 map.insert(Arc::from(k.as_str()), Shape::of(v));
53 }
54 Shape::Object(map)
55 }
56 }
57 }
58
59 pub fn merge(self, other: Shape) -> Shape {
61 match (self, other) {
62 (a, b) if a == b => a,
63 (Shape::Unknown, x) | (x, Shape::Unknown) => x,
64 (Shape::Array(a), Shape::Array(b)) => Shape::Array(Box::new(a.merge(*b))),
65 (Shape::Object(mut a), Shape::Object(b)) => {
66 for (k, v) in b {
67 if let Some(existing) = a.shift_remove(&k) {
68 a.insert(k, existing.merge(v));
69 } else {
70 a.insert(k, v);
71 }
72 }
73 Shape::Object(a)
74 }
75 (Shape::Int, Shape::Float) | (Shape::Float, Shape::Int) => Shape::Float,
76 (Shape::Union(mut xs), y) | (y, Shape::Union(mut xs)) => {
77 if !xs.contains(&y) { xs.push(y); }
78 Shape::Union(xs)
79 }
80 (a, b) => Shape::Union(vec![a, b]),
81 }
82 }
83
84 pub fn has_field(&self, name: &str) -> bool {
86 match self {
87 Shape::Object(m) => m.contains_key(name),
88 Shape::Union(xs) => xs.iter().all(|s| s.has_field(name)),
89 _ => false,
90 }
91 }
92
93 pub fn field(&self, name: &str) -> Option<&Shape> {
95 match self {
96 Shape::Object(m) => m.get(name),
97 _ => None,
98 }
99 }
100
101 pub fn element(&self) -> Option<&Shape> {
103 match self {
104 Shape::Array(b) => Some(b),
105 _ => None,
106 }
107 }
108}
109
110use super::vm::{Program, Opcode};
113
114pub fn specialize(program: &Program, shape: &Shape) -> Program {
121 let new_ops: Vec<Opcode> = specialize_ops(&program.ops, shape);
122 let ics = crate::vm::fresh_ics(new_ops.len());
123 Program {
124 ops: new_ops.into(),
125 source: program.source.clone(),
126 id: program.id,
127 is_structural: program.is_structural,
128 ics,
129 }
130}
131
132fn specialize_ops(ops: &[Opcode], shape: &Shape) -> Vec<Opcode> {
133 let mut out = Vec::with_capacity(ops.len());
134 let mut cur: Shape = shape.clone();
135 for op in ops {
136 match op {
137 Opcode::PushRoot => { cur = shape.clone(); out.push(op.clone()); }
138 Opcode::OptField(k) => {
139 if cur.has_field(k) {
140 out.push(Opcode::GetField(k.clone()));
141 } else {
142 out.push(op.clone());
143 }
144 cur = cur.field(k).cloned().unwrap_or(Shape::Unknown);
145 }
146 Opcode::GetField(k) => {
147 cur = cur.field(k).cloned().unwrap_or(Shape::Unknown);
148 out.push(op.clone());
149 }
150 Opcode::RootChain(ks) => {
151 let mut c = shape.clone();
152 for k in ks.iter() {
153 c = c.field(k).cloned().unwrap_or(Shape::Unknown);
154 }
155 cur = c;
156 out.push(op.clone());
157 }
158 Opcode::GetIndex(_) | Opcode::GetSlice(..) => {
159 cur = cur.element().cloned().unwrap_or(Shape::Unknown);
160 out.push(op.clone());
161 }
162 _ => { cur = Shape::Unknown; out.push(op.clone()); }
163 }
164 }
165 out
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Each JSON scalar infers to its matching leaf variant.
    #[test]
    fn infer_scalar() {
        let cases = [
            (json!(1), Shape::Int),
            (json!(1.5), Shape::Float),
            (json!("x"), Shape::Str),
            (json!(null), Shape::Null),
        ];
        for (value, expected) in cases.iter() {
            assert_eq!(&Shape::of(value), expected);
        }
    }

    /// An object's shape reports exactly its keys and their value shapes.
    #[test]
    fn infer_object() {
        let shape = Shape::of(&json!({"a": 1, "b": "x"}));
        for name in ["a", "b"].iter() {
            assert!(shape.has_field(name));
        }
        assert!(!shape.has_field("c"));
        assert_eq!(shape.field("a"), Some(&Shape::Int));
    }

    /// Items of a single kind produce that kind as the element shape.
    #[test]
    fn infer_homogeneous_array() {
        let shape = Shape::of(&json!([1, 2, 3]));
        assert_eq!(shape.element(), Some(&Shape::Int));
    }

    /// Mixed item kinds produce a union element shape containing each kind.
    #[test]
    fn infer_heterogeneous_array() {
        let shape = Shape::of(&json!([1, "two", 3]));
        let members = match shape.element().unwrap() {
            Shape::Union(members) => members,
            other => panic!("expected union, got {:?}", other),
        };
        assert!(members.contains(&Shape::Int));
        assert!(members.contains(&Shape::Str));
    }

    /// Int and Float widen to Float.
    #[test]
    fn merge_int_float_to_float() {
        assert_eq!(Shape::Int.merge(Shape::Float), Shape::Float);
    }

    /// A field guaranteed by the shape loses its optional access.
    #[test]
    fn specialize_opt_field_to_get_field() {
        use crate::vm::{Compiler, Opcode};
        let program = Compiler::compile_str("$.a?.b").unwrap();
        let shape = Shape::of(&json!({"a": {"b": 1}}));
        let specialized = specialize(&program, &shape);
        let no_opt = specialized
            .ops
            .iter()
            .all(|op| !matches!(op, Opcode::OptField(_)));
        assert!(no_opt, "OptField should specialize to GetField");
    }

    /// A field absent from the shape keeps its optional access.
    #[test]
    fn specialize_preserves_opt_when_missing() {
        use crate::vm::{Compiler, Opcode};
        let program = Compiler::compile_str("$.a.missing?").unwrap();
        let shape = Shape::of(&json!({"a": {"b": 1}}));
        let specialized = specialize(&program, &shape);
        let kept = specialized.ops.iter().any(|op| match op {
            Opcode::OptField(k) => k.as_ref() == "missing",
            _ => false,
        });
        assert!(kept, "OptField for absent field should remain");
    }
}