Skip to main content

jetro_core/
ssa.rs

1//! SSA-style numbering + data-flow graph for v2 programs.
2//!
3//! v2's stack machine has no named intermediate values.  To express a
4//! data-flow graph, we assign each stack push a fresh `ValueId` and
5//! record which earlier values each opcode consumes.
6//!
7//! Limitations:
8//! - Sub-programs (branches) are opaque to this pass — they are walked
9//!   separately but their values live in their own namespace.
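//! - Phi nodes are synthesised only for the short-circuit / coalesce
//!   opcodes (`AndOp`, `OrOp`, `CoalesceOp`); their right-hand operand comes
//!   from an opaque sub-program and is represented by the opcode's own value
//!   id.  There are no other merge points within a single block.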
12//!
13//! Useful for: def-use queries, live-range analysis, value-numbering CSE.
14
15use std::sync::Arc;
16use std::collections::HashMap;
17use super::vm::{Program, Opcode};
18
/// Identifier of one SSA value.
///
/// `SsaGraph::build` hands these out densely: an instruction's id is its
/// index in `SsaGraph::instrs`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct ValueId(pub u32);
21
22#[derive(Debug, Clone)]
23pub struct SsaInstr {
24    pub id:   ValueId,
25    pub op:   Opcode,
26    /// Value ids consumed from the stack (pops) in order.
27    pub uses: Vec<ValueId>,
28}
29
30/// A phi node: at a merge point, value takes one of several incoming
31/// values depending on which predecessor path was taken.
32#[derive(Debug, Clone)]
33pub struct Phi {
34    pub id:       ValueId,
35    /// Incoming (predecessor label, value) pairs.
36    pub incoming: Vec<(PhiEdge, ValueId)>,
37}
38
/// Predecessor path along which a `Phi` operand arrives.
///
/// `*Lhs` edges carry the left operand that was already on the stack; `*Rhs`
/// edges carry the value produced when the right-hand sub-program runs.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum PhiEdge {
    /// Left operand of `AndOp`.
    AndLhs,
    /// Right-hand sub-program of `AndOp`.
    AndRhs,
    /// Left operand of `OrOp`.
    OrLhs,
    /// Right-hand sub-program of `OrOp`.
    OrRhs,
    /// Left operand of `CoalesceOp`.
    CoalesceLhs,
    /// Right-hand sub-program of `CoalesceOp`.
    CoalesceRhs,
}
41
42#[derive(Debug, Clone, Default)]
43pub struct SsaGraph {
44    pub instrs: Vec<SsaInstr>,
45    pub phis:   Vec<Phi>,
46    /// Last pushed value id — the program's result.
47    pub result: Option<ValueId>,
48}
49
50impl SsaGraph {
51    pub fn build(program: &Program) -> SsaGraph {
52        let mut g = SsaGraph::default();
53        let mut stack: Vec<ValueId> = Vec::new();
54        for op in program.ops.iter() {
55            let arity = op_arity(op);
56            let mut uses = Vec::with_capacity(arity.pops);
57            for _ in 0..arity.pops {
58                if let Some(v) = stack.pop() { uses.push(v); }
59            }
60            let id = ValueId(g.instrs.len() as u32);
61            // Synthesise phi at short-circuit / coalesce merge points.
62            // At an AndOp, the lhs is on the stack; rhs sub-program produces
63            // a new value conditionally.  Result is a phi(lhs, rhs).
64            let phi_edge = match op {
65                Opcode::AndOp(_)      => Some((PhiEdge::AndLhs, PhiEdge::AndRhs)),
66                Opcode::OrOp(_)       => Some((PhiEdge::OrLhs, PhiEdge::OrRhs)),
67                Opcode::CoalesceOp(_) => Some((PhiEdge::CoalesceLhs, PhiEdge::CoalesceRhs)),
68                _ => None,
69            };
70            if let Some((lhs_edge, rhs_edge)) = phi_edge {
71                // uses[0] is the lhs consumed; the "rhs" value is synthetic — we
72                // model it as the instruction's own id (sub-program result).
73                let lhs = uses.first().copied().unwrap_or(id);
74                g.phis.push(Phi {
75                    id,
76                    incoming: vec![(lhs_edge, lhs), (rhs_edge, id)],
77                });
78            }
79            g.instrs.push(SsaInstr { id, op: op.clone(), uses });
80            if arity.pushes { stack.push(id); }
81        }
82        g.result = stack.pop();
83        g
84    }
85
86    /// Use-list (consumers) for each value id.
87    pub fn use_list(&self) -> HashMap<ValueId, Vec<ValueId>> {
88        let mut m: HashMap<ValueId, Vec<ValueId>> = HashMap::new();
89        for instr in &self.instrs {
90            for u in &instr.uses {
91                m.entry(*u).or_default().push(instr.id);
92            }
93        }
94        m
95    }
96
97    /// Dead values: pushed but never consumed and not the final result.
98    pub fn dead_values(&self) -> Vec<ValueId> {
99        let uses = self.use_list();
100        self.instrs.iter()
101            .filter(|i| {
102                op_arity(&i.op).pushes
103                && !uses.contains_key(&i.id)
104                && Some(i.id) != self.result
105            })
106            .map(|i| i.id)
107            .collect()
108    }
109}
110
/// Stack effect of one opcode, as reported by `op_arity`.
#[derive(Clone, Copy, Debug)]
struct Arity {
    /// Number of operands popped from the stack.
    pops: usize,
    /// Whether a result value is pushed afterwards.
    pushes: bool,
}
113
/// Stack effect of one opcode: how many operands it pops from the enclosing
/// program's stack and whether it pushes a result.
///
/// `SsaGraph::build` uses this to simulate the operand stack. The match is
/// deliberately exhaustive (no `_` arm), so adding an `Opcode` variant forces
/// a decision here. Work done inside an opcode's own sub-programs (such as the
/// right-hand side of `AndOp`) is not counted; only values exchanged with the
/// enclosing stack are.
fn op_arity(op: &Opcode) -> Arity {
    match op {
        // Producers: push a value without popping anything.
        Opcode::PushNull | Opcode::PushBool(_) | Opcode::PushInt(_)
            | Opcode::PushFloat(_) | Opcode::PushStr(_)
            | Opcode::PushRoot | Opcode::PushCurrent | Opcode::LoadIdent(_)
            | Opcode::RootChain(_) | Opcode::GetPointer(_)
            | Opcode::MakeObj(_) | Opcode::MakeArr(_) | Opcode::FString(_)
            | Opcode::ListComp(_) | Opcode::DictComp(_) | Opcode::SetComp(_)
            | Opcode::PatchEval(_) =>
            Arity { pops: 0, pushes: true },

        // Unary transforms: pop one operand, push one result. For `AndOp` /
        // `OrOp` / `CoalesceOp` the popped operand is the left-hand side; the
        // right-hand side lives in the opcode's sub-program.
        Opcode::GetField(_) | Opcode::OptField(_) | Opcode::GetIndex(_)
            | Opcode::FieldChain(_)
            | Opcode::GetSlice(..) | Opcode::Descendant(_) | Opcode::DescendAll
            | Opcode::DynIndex(_) | Opcode::InlineFilter(_)
            | Opcode::Quantifier(_) | Opcode::FilterCount(_)
            | Opcode::FindFirst(_) | Opcode::FindOne(_)
            | Opcode::FilterMap { .. } | Opcode::MapFilter { .. }
            | Opcode::FilterMapSum { .. } | Opcode::FilterMapAvg { .. }
            | Opcode::FilterMapFirst { .. }
            | Opcode::FilterMapMin { .. } | Opcode::FilterMapMax { .. }
            | Opcode::FilterLast { .. }
            | Opcode::FilterFilter { .. }
            | Opcode::MapMap { .. } | Opcode::MapSum(_) | Opcode::MapAvg(_)
            | Opcode::MapToJsonJoin { .. }
            | Opcode::StrTrimUpper | Opcode::StrTrimLower
            | Opcode::StrUpperTrim | Opcode::StrLowerTrim
            | Opcode::StrSplitReverseJoin { .. }
            | Opcode::MapReplaceLit { .. }
            | Opcode::MapUpperReplaceLit { .. }
            | Opcode::MapLowerReplaceLit { .. }
            | Opcode::MapStrConcat { .. }
            | Opcode::MapSplitLenSum { .. }
            | Opcode::MapProject { .. }
            | Opcode::MapStrSlice { .. }
            | Opcode::MapFString(_)
            | Opcode::MapSplitCount { .. }
            | Opcode::MapSplitFirst { .. }
            | Opcode::MapSplitNth   { .. }
            | Opcode::MapSplitCountSum { .. }
            | Opcode::MapMin(_) | Opcode::MapMax(_)
            | Opcode::MapFieldSum(_) | Opcode::MapFieldAvg(_)
            | Opcode::MapFieldMin(_) | Opcode::MapFieldMax(_)
            | Opcode::MapField(_) | Opcode::MapFieldChain(_) | Opcode::MapFieldUnique(_)
            | Opcode::MapFieldChainUnique(_)
            | Opcode::FlatMapChain(_)
            | Opcode::FilterFieldEqLit(_, _) | Opcode::FilterFieldCmpLit(_, _, _)
            | Opcode::FilterCurrentCmpLit(_, _)
            | Opcode::FilterStrVecStartsWith(_)
            | Opcode::FilterStrVecEndsWith(_)
            | Opcode::FilterStrVecContains(_)
            | Opcode::MapStrVecUpper
            | Opcode::MapStrVecLower
            | Opcode::MapStrVecTrim
            | Opcode::MapNumVecArith { .. }
            | Opcode::MapNumVecNeg
            | Opcode::FilterFieldCmpField(_, _, _)
            | Opcode::FilterFieldEqLitMapField(_, _, _)
            | Opcode::FilterFieldCmpLitMapField(_, _, _, _)
            | Opcode::FilterFieldEqLitCount(_, _) | Opcode::FilterFieldCmpLitCount(_, _, _)
            | Opcode::FilterFieldCmpFieldCount(_, _, _)
            | Opcode::FilterFieldsAllEqLitCount(_)
            | Opcode::FilterFieldsAllCmpLitCount(_)
            | Opcode::GroupByField(_)
            | Opcode::CountByField(_)
            | Opcode::UniqueByField(_)
            | Opcode::MapFlatten(_)
            | Opcode::MapFirst(_) | Opcode::MapLast(_)
            | Opcode::FilterTakeWhile { .. }
            | Opcode::FilterDropWhile { .. } | Opcode::MapUnique(_)
            | Opcode::EquiJoin { .. }
            | Opcode::TopN { .. } | Opcode::UniqueCount
            | Opcode::ArgExtreme { .. } | Opcode::KindCheck { .. }
            | Opcode::Not | Opcode::Neg
            | Opcode::CallMethod(_) | Opcode::CallOptMethod(_)
            | Opcode::AndOp(_) | Opcode::OrOp(_) | Opcode::CoalesceOp(_)
            | Opcode::IfElse { .. }
            | Opcode::CastOp(_) =>
            Arity { pops: 1, pushes: true },

        // Binary operators: pop two operands, push one result.
        Opcode::Add | Opcode::Sub | Opcode::Mul | Opcode::Div | Opcode::Mod
            | Opcode::Eq | Opcode::Neq | Opcode::Lt | Opcode::Lte
            | Opcode::Gt | Opcode::Gte | Opcode::Fuzzy =>
            Arity { pops: 2, pushes: true },

        // Effect-only opcodes: push nothing. `StoreVar` consumes one value
        // (presumably the value being stored — confirm against the VM).
        // NOTE(review): `LetExpr` is classified as leaving nothing on the
        // stack; confirm against the VM that it really yields no value.
        Opcode::StoreVar(_) => Arity { pops: 1, pushes: false },
        Opcode::SetCurrent | Opcode::BindVar(_)
            | Opcode::BindObjDestructure(_) | Opcode::BindArrDestructure(_)
            | Opcode::LetExpr { .. } => Arity { pops: 0, pushes: false },
    }
}
205
206/// Value-numbering CSE on top of SSA: two instructions with identical
207/// opcode and matching `uses` are mapped to the same canonical id.
208pub fn value_number(g: &SsaGraph) -> HashMap<ValueId, ValueId> {
209    let mut canon: HashMap<ValueId, ValueId> = HashMap::new();
210    let mut seen: HashMap<(u64, Vec<ValueId>), ValueId> = HashMap::new();
211    for instr in &g.instrs {
212        let canon_uses: Vec<ValueId> = instr.uses.iter()
213            .map(|u| *canon.get(u).unwrap_or(u)).collect();
214        let key = (op_hash(&instr.op), canon_uses);
215        match seen.get(&key) {
216            Some(&existing) => { canon.insert(instr.id, existing); }
217            None            => { seen.insert(key, instr.id); canon.insert(instr.id, instr.id); }
218        }
219    }
220    canon
221}
222
223fn op_hash(op: &Opcode) -> u64 {
224    use std::collections::hash_map::DefaultHasher;
225    use std::hash::{Hash, Hasher};
226    let mut h = DefaultHasher::new();
227    std::mem::discriminant(op).hash(&mut h);
228    match op {
229        Opcode::PushInt(n) => n.hash(&mut h),
230        Opcode::PushStr(s) => s.as_bytes().hash(&mut h),
231        Opcode::PushBool(b) => b.hash(&mut h),
232        Opcode::GetField(k) | Opcode::OptField(k) | Opcode::LoadIdent(k) =>
233            k.as_bytes().hash(&mut h),
234        Opcode::GetIndex(i) => i.hash(&mut h),
235        _ => {}
236    }
237    h.finish()
238}
239
// `Arc` is imported at the top of this module, but no other item visible in
// this file references it. This never-called helper mentions it so the import
// does not raise an `unused_imports` warning.
#[allow(dead_code)]
fn _use_arc<T>(_arc: Arc<T>) {}
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vm::Compiler;

    /// Compiles `src` with the v2 compiler and builds its SSA graph.
    fn graph_for(src: &str) -> SsaGraph {
        let program = Compiler::compile_str(src).unwrap();
        SsaGraph::build(&program)
    }

    #[test]
    fn ssa_builds_graph() {
        let graph = graph_for("$.a + $.b");
        assert!(graph.result.is_some());
        // The last instruction should be the addition, consuming both loads.
        let add = graph.instrs.last().unwrap();
        assert_eq!(add.uses.len(), 2);
    }

    #[test]
    fn ssa_use_list() {
        let graph = graph_for("$.a + $.b");
        let total_uses: usize = graph.use_list().values().map(Vec::len).sum();
        assert_eq!(total_uses, 2);
    }

    #[test]
    fn value_numbering_dedups_identical() {
        // Use GetField sequences that don't const-fold.
        let graph = graph_for("[$.a, $.a]");
        let canon = value_number(&graph);
        // Canonical ids of every top-level root-chain load.
        let load_ids: Vec<ValueId> = graph
            .instrs
            .iter()
            .filter(|instr| matches!(instr.op, crate::vm::Opcode::RootChain(_)))
            .map(|instr| canon[&instr.id])
            .collect();
        // NOTE(review): if the loads are compiled into `MakeArr`'s
        // sub-programs they are not top-level instructions; `load_ids` is
        // then empty and this test asserts nothing.
        if let [first, second, ..] = load_ids.as_slice() {
            assert_eq!(first, second);
        }
    }
}