1use std::sync::Arc;
16use std::collections::HashMap;
17use super::vm::{Program, Opcode};
18
/// Identifier of a single SSA value.
///
/// `SsaGraph::build` derives each id from the defining instruction's position
/// in `SsaGraph::instrs`, so within one graph the wrapped number doubles as an
/// index into that vector.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ValueId(pub u32);
21
/// One instruction of the SSA graph: an opcode together with the values it consumes.
#[derive(Debug, Clone)]
pub struct SsaInstr {
    /// Id assigned to this instruction. Every instruction gets one, including
    /// those whose opcode does not push a result.
    pub id: ValueId,
    /// The VM opcode this instruction was created from (cloned from the program).
    pub op: Opcode,
    /// Operands taken from the simulated stack, in pop order: the most
    /// recently pushed value comes first.
    pub uses: Vec<ValueId>,
}
29
/// Merge node recorded by `SsaGraph::build` for `AndOp`, `OrOp` and
/// `CoalesceOp` instructions.
#[derive(Debug, Clone)]
pub struct Phi {
    /// Same id as the instruction the phi was created for.
    pub id: ValueId,
    /// `(edge, value)` pairs: which side of the operator the value arrives
    /// from, and the value itself.
    pub incoming: Vec<(PhiEdge, ValueId)>,
}
38
/// Labels an incoming edge of a [`Phi`] with the operator (and / or / coalesce)
/// and the side it belongs to.
///
/// `SsaGraph::build` pairs the `*Lhs` variant with the operand popped from the
/// stack and the `*Rhs` variant with the instruction's own id.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PhiEdge { AndLhs, AndRhs, OrLhs, OrRhs, CoalesceLhs, CoalesceRhs }
41
/// SSA view of a program, produced by `SsaGraph::build`.
#[derive(Debug, Clone, Default)]
pub struct SsaGraph {
    /// Instructions in program order; `build` gives the instruction at index
    /// `i` the id `ValueId(i)`.
    pub instrs: Vec<SsaInstr>,
    /// Phi nodes, one per `AndOp` / `OrOp` / `CoalesceOp` instruction.
    pub phis: Vec<Phi>,
    /// Value left on top of the simulated stack after the last instruction,
    /// or `None` if the stack ended up empty.
    pub result: Option<ValueId>,
}
49
50impl SsaGraph {
51 pub fn build(program: &Program) -> SsaGraph {
52 let mut g = SsaGraph::default();
53 let mut stack: Vec<ValueId> = Vec::new();
54 for op in program.ops.iter() {
55 let arity = op_arity(op);
56 let mut uses = Vec::with_capacity(arity.pops);
57 for _ in 0..arity.pops {
58 if let Some(v) = stack.pop() { uses.push(v); }
59 }
60 let id = ValueId(g.instrs.len() as u32);
61 let phi_edge = match op {
65 Opcode::AndOp(_) => Some((PhiEdge::AndLhs, PhiEdge::AndRhs)),
66 Opcode::OrOp(_) => Some((PhiEdge::OrLhs, PhiEdge::OrRhs)),
67 Opcode::CoalesceOp(_) => Some((PhiEdge::CoalesceLhs, PhiEdge::CoalesceRhs)),
68 _ => None,
69 };
70 if let Some((lhs_edge, rhs_edge)) = phi_edge {
71 let lhs = uses.first().copied().unwrap_or(id);
74 g.phis.push(Phi {
75 id,
76 incoming: vec![(lhs_edge, lhs), (rhs_edge, id)],
77 });
78 }
79 g.instrs.push(SsaInstr { id, op: op.clone(), uses });
80 if arity.pushes { stack.push(id); }
81 }
82 g.result = stack.pop();
83 g
84 }
85
86 pub fn use_list(&self) -> HashMap<ValueId, Vec<ValueId>> {
88 let mut m: HashMap<ValueId, Vec<ValueId>> = HashMap::new();
89 for instr in &self.instrs {
90 for u in &instr.uses {
91 m.entry(*u).or_default().push(instr.id);
92 }
93 }
94 m
95 }
96
97 pub fn dead_values(&self) -> Vec<ValueId> {
99 let uses = self.use_list();
100 self.instrs.iter()
101 .filter(|i| {
102 op_arity(&i.op).pushes
103 && !uses.contains_key(&i.id)
104 && Some(i.id) != self.result
105 })
106 .map(|i| i.id)
107 .collect()
108 }
109}
110
/// Stack effect of one opcode as seen by the SSA builder: how many operands it
/// takes off the stack (`pops`) and whether it leaves a result on it (`pushes`).
#[derive(Debug, Clone, Copy)]
struct Arity { pops: usize, pushes: bool }
113
114fn op_arity(op: &Opcode) -> Arity {
115 match op {
116 Opcode::PushNull | Opcode::PushBool(_) | Opcode::PushInt(_)
117 | Opcode::PushFloat(_) | Opcode::PushStr(_)
118 | Opcode::PushRoot | Opcode::PushCurrent | Opcode::LoadIdent(_)
119 | Opcode::RootChain(_) | Opcode::GetPointer(_)
120 | Opcode::MakeObj(_) | Opcode::MakeArr(_) | Opcode::FString(_)
121 | Opcode::ListComp(_) | Opcode::DictComp(_) | Opcode::SetComp(_)
122 | Opcode::PatchEval(_) =>
123 Arity { pops: 0, pushes: true },
124
125 Opcode::GetField(_) | Opcode::OptField(_) | Opcode::GetIndex(_)
126 | Opcode::FieldChain(_)
127 | Opcode::GetSlice(..) | Opcode::Descendant(_) | Opcode::DescendAll
128 | Opcode::DynIndex(_) | Opcode::InlineFilter(_)
129 | Opcode::Quantifier(_) | Opcode::FilterCount(_)
130 | Opcode::FindFirst(_) | Opcode::FindOne(_)
131 | Opcode::FilterMap { .. } | Opcode::MapFilter { .. }
132 | Opcode::FilterMapSum { .. } | Opcode::FilterMapAvg { .. }
133 | Opcode::FilterMapFirst { .. }
134 | Opcode::FilterMapMin { .. } | Opcode::FilterMapMax { .. }
135 | Opcode::FilterLast { .. }
136 | Opcode::FilterFilter { .. }
137 | Opcode::MapMap { .. } | Opcode::MapSum(_) | Opcode::MapAvg(_)
138 | Opcode::MapToJsonJoin { .. }
139 | Opcode::StrTrimUpper | Opcode::StrTrimLower
140 | Opcode::StrUpperTrim | Opcode::StrLowerTrim
141 | Opcode::StrSplitReverseJoin { .. }
142 | Opcode::MapReplaceLit { .. }
143 | Opcode::MapUpperReplaceLit { .. }
144 | Opcode::MapLowerReplaceLit { .. }
145 | Opcode::MapStrConcat { .. }
146 | Opcode::MapSplitLenSum { .. }
147 | Opcode::MapProject { .. }
148 | Opcode::MapStrSlice { .. }
149 | Opcode::MapFString(_)
150 | Opcode::MapSplitCount { .. }
151 | Opcode::MapSplitFirst { .. }
152 | Opcode::MapSplitNth { .. }
153 | Opcode::MapSplitCountSum { .. }
154 | Opcode::MapMin(_) | Opcode::MapMax(_)
155 | Opcode::MapFieldSum(_) | Opcode::MapFieldAvg(_)
156 | Opcode::MapFieldMin(_) | Opcode::MapFieldMax(_)
157 | Opcode::MapField(_) | Opcode::MapFieldChain(_) | Opcode::MapFieldUnique(_)
158 | Opcode::MapFieldChainUnique(_)
159 | Opcode::FlatMapChain(_)
160 | Opcode::FilterFieldEqLit(_, _) | Opcode::FilterFieldCmpLit(_, _, _)
161 | Opcode::FilterCurrentCmpLit(_, _)
162 | Opcode::FilterStrVecStartsWith(_)
163 | Opcode::FilterStrVecEndsWith(_)
164 | Opcode::FilterStrVecContains(_)
165 | Opcode::MapStrVecUpper
166 | Opcode::MapStrVecLower
167 | Opcode::MapStrVecTrim
168 | Opcode::MapNumVecArith { .. }
169 | Opcode::MapNumVecNeg
170 | Opcode::FilterFieldCmpField(_, _, _)
171 | Opcode::FilterFieldEqLitMapField(_, _, _)
172 | Opcode::FilterFieldCmpLitMapField(_, _, _, _)
173 | Opcode::FilterFieldEqLitCount(_, _) | Opcode::FilterFieldCmpLitCount(_, _, _)
174 | Opcode::FilterFieldCmpFieldCount(_, _, _)
175 | Opcode::FilterFieldsAllEqLitCount(_)
176 | Opcode::FilterFieldsAllCmpLitCount(_)
177 | Opcode::GroupByField(_)
178 | Opcode::CountByField(_)
179 | Opcode::UniqueByField(_)
180 | Opcode::MapFlatten(_)
181 | Opcode::MapFirst(_) | Opcode::MapLast(_)
182 | Opcode::FilterTakeWhile { .. }
183 | Opcode::FilterDropWhile { .. } | Opcode::MapUnique(_)
184 | Opcode::EquiJoin { .. }
185 | Opcode::TopN { .. } | Opcode::UniqueCount
186 | Opcode::ArgExtreme { .. } | Opcode::KindCheck { .. }
187 | Opcode::Not | Opcode::Neg
188 | Opcode::CallMethod(_) | Opcode::CallOptMethod(_)
189 | Opcode::AndOp(_) | Opcode::OrOp(_) | Opcode::CoalesceOp(_)
190 | Opcode::IfElse { .. }
191 | Opcode::CastOp(_) =>
192 Arity { pops: 1, pushes: true },
193
194 Opcode::Add | Opcode::Sub | Opcode::Mul | Opcode::Div | Opcode::Mod
195 | Opcode::Eq | Opcode::Neq | Opcode::Lt | Opcode::Lte
196 | Opcode::Gt | Opcode::Gte | Opcode::Fuzzy =>
197 Arity { pops: 2, pushes: true },
198
199 Opcode::StoreVar(_) => Arity { pops: 1, pushes: false },
200 Opcode::SetCurrent | Opcode::BindVar(_)
201 | Opcode::BindObjDestructure(_) | Opcode::BindArrDestructure(_)
202 | Opcode::LetExpr { .. } => Arity { pops: 0, pushes: false },
203 }
204}
205
206pub fn value_number(g: &SsaGraph) -> HashMap<ValueId, ValueId> {
209 let mut canon: HashMap<ValueId, ValueId> = HashMap::new();
210 let mut seen: HashMap<(u64, Vec<ValueId>), ValueId> = HashMap::new();
211 for instr in &g.instrs {
212 let canon_uses: Vec<ValueId> = instr.uses.iter()
213 .map(|u| *canon.get(u).unwrap_or(u)).collect();
214 let key = (op_hash(&instr.op), canon_uses);
215 match seen.get(&key) {
216 Some(&existing) => { canon.insert(instr.id, existing); }
217 None => { seen.insert(key, instr.id); canon.insert(instr.id, instr.id); }
218 }
219 }
220 canon
221}
222
223fn op_hash(op: &Opcode) -> u64 {
224 use std::collections::hash_map::DefaultHasher;
225 use std::hash::{Hash, Hasher};
226 let mut h = DefaultHasher::new();
227 std::mem::discriminant(op).hash(&mut h);
228 match op {
229 Opcode::PushInt(n) => n.hash(&mut h),
230 Opcode::PushStr(s) => s.as_bytes().hash(&mut h),
231 Opcode::PushBool(b) => b.hash(&mut h),
232 Opcode::GetField(k) | Opcode::OptField(k) | Opcode::LoadIdent(k) =>
233 k.as_bytes().hash(&mut h),
234 Opcode::GetIndex(i) => i.hash(&mut h),
235 _ => {}
236 }
237 h.finish()
238}
239
// Intentionally empty and never called (hence `allow(dead_code)`).
// NOTE(review): this appears to exist only so the `Arc` import is referenced
// somewhere in this file - confirm before removing either one.
#[allow(dead_code)]
fn _use_arc<T>(_: Arc<T>) {}
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vm::Compiler;

    /// Compiles `src` and lowers the resulting program to an SSA graph.
    fn graph_of(src: &str) -> SsaGraph {
        let program = Compiler::compile_str(src).unwrap();
        SsaGraph::build(&program)
    }

    /// A binary expression yields a result, and its last instruction takes two operands.
    #[test]
    fn ssa_builds_graph() {
        let graph = graph_of("$.a + $.b");
        assert!(graph.result.is_some());
        let final_instr = graph.instrs.last().unwrap();
        assert_eq!(final_instr.uses.len(), 2);
    }

    /// The use list of a binary expression records exactly two uses in total.
    #[test]
    fn ssa_use_list() {
        let graph = graph_of("$.a + $.b");
        let total_uses: usize = graph.use_list().values().map(Vec::len).sum();
        assert_eq!(total_uses, 2);
    }

    /// Identical root loads receive the same canonical value.
    ///
    /// The check only fires when the compiled program contains at least two
    /// `RootChain` instructions at the top level; otherwise the test passes
    /// without asserting anything.
    #[test]
    fn value_numbering_dedups_identical() {
        let graph = graph_of("[$.a, $.a]");
        let canon = value_number(&graph);
        let mut load_ids: Vec<ValueId> = Vec::new();
        for instr in &graph.instrs {
            if matches!(instr.op, crate::vm::Opcode::RootChain(_)) {
                load_ids.push(canon[&instr.id]);
            }
        }
        if let [first, second, ..] = load_ids.as_slice() {
            assert_eq!(first, second);
        }
    }
}