// react_compiler_ssa/enter_ssa.rs

1use std::collections::{HashMap, HashSet};
2
3use indexmap::IndexMap;
4use react_compiler_diagnostics::{CompilerDiagnostic, CompilerDiagnosticDetail, ErrorCategory};
5use react_compiler_hir::environment::Environment;
6use react_compiler_hir::*;
7use react_compiler_hir::visitors;
8
9// =============================================================================
10// SSABuilder
11// =============================================================================
12
/// A phi created while its block still had unvisited predecessors, so its
/// operands could not be computed yet. `fix_incomplete_phis` turns it into a
/// real phi once every predecessor of the block has been processed.
struct IncompletePhi {
    /// The place that was looked up (carries the pre-SSA identifier).
    old_place: Place,
    /// The fresh place that becomes the phi's result.
    new_place: Place,
}
17
/// Renaming state for a single block.
struct State {
    /// Maps an original identifier to the SSA identifier that reads of it
    /// resolve to in this block.
    defs: HashMap<IdentifierId, IdentifierId>,
    /// Phis created while this block still had unvisited predecessors.
    incomplete_phis: Vec<IncompletePhi>,
}
22
/// Mutable state for the SSA renaming walk. One builder is shared by the
/// outer function and every nested function expression / object method.
struct SSABuilder {
    /// Per-block renaming state, (re)created by `start_block`.
    states: HashMap<BlockId, State>,
    /// Block currently being processed; `None` until the first `start_block`.
    current: Option<BlockId>,
    /// For each successor block, how many of its predecessors are still
    /// unprocessed. Lookups in a block with a positive count create
    /// incomplete phis.
    unsealed_preds: HashMap<BlockId, u32>,
    /// Predecessor lists for the outer function's blocks and for the blocks
    /// of every inner function registered through `define_function`.
    block_preds: HashMap<BlockId, Vec<BlockId>>,
    /// Identifiers that were read before any definition was found. Defining
    /// one of them afterwards is reported as an error by `define_place`
    /// (unless it was removed again by `unmark_unknown`).
    unknown: HashSet<IdentifierId>,
    /// Identifiers that `define_place` does not rename; populated by
    /// `define_context`.
    context: HashSet<IdentifierId>,
    /// Phis created during the walk, keyed by target block. They are written
    /// into the blocks by `apply_pending_phis` once the walk has finished.
    pending_phis: HashMap<BlockId, Vec<Phi>>,
    /// Inner functions visited during the walk, so their blocks can receive
    /// pending phis.
    processed_functions: Vec<FunctionId>,
}
33
34impl SSABuilder {
35    fn new(blocks: &IndexMap<BlockId, BasicBlock>) -> Self {
36        let mut block_preds = HashMap::new();
37        for (id, block) in blocks {
38            block_preds.insert(*id, block.preds.iter().copied().collect());
39        }
40        SSABuilder {
41            states: HashMap::new(),
42            current: None,
43            unsealed_preds: HashMap::new(),
44            block_preds,
45            unknown: HashSet::new(),
46            context: HashSet::new(),
47            pending_phis: HashMap::new(),
48            processed_functions: Vec::new(),
49        }
50    }
51
52    fn define_function(&mut self, func: &HirFunction) {
53        for (id, block) in &func.body.blocks {
54            self.block_preds
55                .insert(*id, block.preds.iter().copied().collect());
56        }
57    }
58
59    fn state_mut(&mut self) -> &mut State {
60        let current = self.current.expect("we need to be in a block to access state!");
61        self.states
62            .get_mut(&current)
63            .expect("state not found for current block")
64    }
65
66    fn make_id(&mut self, old_id: IdentifierId, env: &mut Environment) -> IdentifierId {
67        let new_id = env.next_identifier_id();
68        let old = &env.identifiers[old_id.0 as usize];
69        let declaration_id = old.declaration_id;
70        let name = old.name.clone();
71        let loc = old.loc;
72        let new_ident = &mut env.identifiers[new_id.0 as usize];
73        new_ident.declaration_id = declaration_id;
74        new_ident.name = name;
75        new_ident.loc = loc;
76        new_id
77    }
78
79    fn define_place(&mut self, old_place: &Place, env: &mut Environment) -> Result<Place, CompilerDiagnostic> {
80        let old_id = old_place.identifier;
81
82        if self.unknown.contains(&old_id) {
83            let ident = &env.identifiers[old_id.0 as usize];
84            let name = match &ident.name {
85                Some(name) => format!("{}${}", name.value(), old_id.0),
86                None => format!("${}", old_id.0),
87            };
88            return Err(CompilerDiagnostic::new(
89                ErrorCategory::Todo,
90                "[hoisting] EnterSSA: Expected identifier to be defined before being used",
91                Some(format!("Identifier {} is undefined", name)),
92            ).with_detail(CompilerDiagnosticDetail::Error {
93                loc: old_place.loc,
94                message: None,
95                identifier_name: None,
96            }));
97        }
98
99        // Do not redefine context references.
100        if self.context.contains(&old_id) {
101            return Ok(self.get_place(old_place, env));
102        }
103
104        let new_id = self.make_id(old_id, env);
105        self.state_mut().defs.insert(old_id, new_id);
106        Ok(Place {
107            identifier: new_id,
108            effect: old_place.effect,
109            reactive: old_place.reactive,
110            loc: old_place.loc,
111        })
112    }
113
114    #[allow(dead_code)]
115    fn define_context(&mut self, old_place: &Place, env: &mut Environment) -> Result<Place, CompilerDiagnostic> {
116        let old_id = old_place.identifier;
117        let new_place = self.define_place(old_place, env)?;
118        self.context.insert(old_id);
119        Ok(new_place)
120    }
121
122    /// A function's context places capture a *binding*, not a value: the
123    /// variable is only read when the function is later called, so a context
124    /// place may reference a binding that is declared after the function
125    /// expression itself (eg `const colgroup = useMemo(() => <colgroup>...)`,
126    /// where the JSX tag name resolves to the variable being assigned). Unmark
127    /// such identifiers so the later declaration doesn't error; if the function
128    /// body actually *reads* the variable before it is defined, visiting the
129    /// body re-marks it and the hoisting bailout in define_place still applies.
130    fn unmark_unknown(&mut self, id: IdentifierId) {
131        self.unknown.remove(&id);
132    }
133
134    fn get_place(&mut self, old_place: &Place, env: &mut Environment) -> Place {
135        let current_id = self.current.expect("must be in a block");
136        let new_id = self.get_id_at(old_place, current_id, env);
137        Place {
138            identifier: new_id,
139            effect: old_place.effect,
140            reactive: old_place.reactive,
141            loc: old_place.loc,
142        }
143    }
144
145    fn get_id_at(
146        &mut self,
147        old_place: &Place,
148        block_id: BlockId,
149        env: &mut Environment,
150    ) -> IdentifierId {
151        if let Some(state) = self.states.get(&block_id) {
152            if let Some(&new_id) = state.defs.get(&old_place.identifier) {
153                return new_id;
154            }
155        }
156
157        let preds = self
158            .block_preds
159            .get(&block_id)
160            .cloned()
161            .unwrap_or_default();
162
163        if preds.is_empty() {
164            self.unknown.insert(old_place.identifier);
165            return old_place.identifier;
166        }
167
168        let unsealed = self.unsealed_preds.get(&block_id).copied().unwrap_or(0);
169        if unsealed > 0 {
170            let new_id = self.make_id(old_place.identifier, env);
171            let new_place = Place {
172                identifier: new_id,
173                effect: old_place.effect,
174                reactive: old_place.reactive,
175                loc: old_place.loc,
176            };
177            let state = self.states.get_mut(&block_id).unwrap();
178            state.incomplete_phis.push(IncompletePhi {
179                old_place: old_place.clone(),
180                new_place,
181            });
182            state.defs.insert(old_place.identifier, new_id);
183            return new_id;
184        }
185
186        if preds.len() == 1 {
187            let pred = preds[0];
188            let new_id = self.get_id_at(old_place, pred, env);
189            self.states
190                .get_mut(&block_id)
191                .unwrap()
192                .defs
193                .insert(old_place.identifier, new_id);
194            return new_id;
195        }
196
197        let new_id = self.make_id(old_place.identifier, env);
198        self.states
199            .get_mut(&block_id)
200            .unwrap()
201            .defs
202            .insert(old_place.identifier, new_id);
203        let new_place = Place {
204            identifier: new_id,
205            effect: old_place.effect,
206            reactive: old_place.reactive,
207            loc: old_place.loc,
208        };
209        self.add_phi(block_id, old_place, &new_place, env);
210        new_id
211    }
212
213    fn add_phi(
214        &mut self,
215        block_id: BlockId,
216        old_place: &Place,
217        new_place: &Place,
218        env: &mut Environment,
219    ) {
220        let preds = self
221            .block_preds
222            .get(&block_id)
223            .cloned()
224            .unwrap_or_default();
225
226        let mut pred_defs: IndexMap<BlockId, Place> = IndexMap::new();
227        for pred_block_id in &preds {
228            let pred_id = self.get_id_at(old_place, *pred_block_id, env);
229            pred_defs.insert(
230                *pred_block_id,
231                Place {
232                    identifier: pred_id,
233                    effect: old_place.effect,
234                    reactive: old_place.reactive,
235                    loc: old_place.loc,
236                },
237            );
238        }
239
240        let phi = Phi {
241            place: new_place.clone(),
242            operands: pred_defs,
243        };
244
245        self.pending_phis
246            .entry(block_id)
247            .or_default()
248            .push(phi);
249    }
250
251    fn fix_incomplete_phis(&mut self, block_id: BlockId, env: &mut Environment) {
252        let incomplete_phis: Vec<IncompletePhi> = self
253            .states
254            .get_mut(&block_id)
255            .unwrap()
256            .incomplete_phis
257            .drain(..)
258            .collect();
259        for phi in &incomplete_phis {
260            self.add_phi(block_id, &phi.old_place, &phi.new_place, env);
261        }
262    }
263
264    fn start_block(&mut self, block_id: BlockId) {
265        self.current = Some(block_id);
266        self.states.insert(
267            block_id,
268            State {
269                defs: HashMap::new(),
270                incomplete_phis: Vec::new(),
271            },
272        );
273    }
274}
275
276// =============================================================================
277// Public entry point
278// =============================================================================
279
280pub fn enter_ssa(
281    func: &mut HirFunction,
282    env: &mut Environment,
283) -> Result<(), CompilerDiagnostic> {
284    let mut builder = SSABuilder::new(&func.body.blocks);
285    let root_entry = func.body.entry;
286    enter_ssa_impl(func, &mut builder, env, root_entry)?;
287
288    // Apply all pending phis to the actual blocks
289    apply_pending_phis(func, env, &mut builder);
290
291    Ok(())
292}
293
294fn apply_pending_phis(
295    func: &mut HirFunction,
296    env: &mut Environment,
297    builder: &mut SSABuilder,
298) {
299    for (block_id, block) in func.body.blocks.iter_mut() {
300        if let Some(phis) = builder.pending_phis.remove(block_id) {
301            block.phis.extend(phis);
302        }
303    }
304    for fid in &builder.processed_functions.clone() {
305        let inner_func = &mut env.functions[fid.0 as usize];
306        for (block_id, block) in inner_func.body.blocks.iter_mut() {
307            if let Some(phis) = builder.pending_phis.remove(block_id) {
308                block.phis.extend(phis);
309            }
310        }
311    }
312}
313
314fn enter_ssa_impl(
315    func: &mut HirFunction,
316    builder: &mut SSABuilder,
317    env: &mut Environment,
318    root_entry: BlockId,
319) -> Result<(), CompilerDiagnostic> {
320    let mut visited_blocks: HashSet<BlockId> = HashSet::new();
321    let block_ids: Vec<BlockId> = func.body.blocks.keys().copied().collect();
322
323    for block_id in &block_ids {
324        let block_id = *block_id;
325
326        if visited_blocks.contains(&block_id) {
327            return Err(CompilerDiagnostic::new(
328                ErrorCategory::Invariant,
329                format!("found a cycle! visiting bb{} again", block_id.0),
330                None,
331            ));
332        }
333
334        visited_blocks.insert(block_id);
335        builder.start_block(block_id);
336
337        // Handle params at the root entry
338        if block_id == root_entry {
339            if !func.context.is_empty() {
340                return Err(CompilerDiagnostic::new(
341                    ErrorCategory::Invariant,
342                    "Expected function context to be empty for outer function declarations",
343                    None,
344                ));
345            }
346            let params = std::mem::take(&mut func.params);
347            let mut new_params = Vec::with_capacity(params.len());
348            for param in params {
349                new_params.push(match param {
350                    ParamPattern::Place(p) => ParamPattern::Place(builder.define_place(&p, env)?),
351                    ParamPattern::Spread(s) => ParamPattern::Spread(SpreadPattern {
352                        place: builder.define_place(&s.place, env)?,
353                    }),
354                });
355            }
356            func.params = new_params;
357        }
358
359        // Process instructions
360        let instruction_ids: Vec<InstructionId> = func
361            .body
362            .blocks
363            .get(&block_id)
364            .unwrap()
365            .instructions
366            .clone();
367
368        for instr_id in &instruction_ids {
369            let instr_idx = instr_id.0 as usize;
370            let instr = &mut func.instructions[instr_idx];
371
372            // For FunctionExpression/ObjectMethod, we need to handle context
373            // mapping specially because env.functions is borrowed by the closure.
374            // First, check if this is a FunctionExpression/ObjectMethod and handle
375            // context mapping separately.
376            let func_expr_id = match &instr.value {
377                InstructionValue::FunctionExpression { lowered_func, .. }
378                | InstructionValue::ObjectMethod { lowered_func, .. } => Some(lowered_func.func),
379                _ => None,
380            };
381
382            // Map context places for function expressions before other operands
383            if let Some(fid) = func_expr_id {
384                let context = std::mem::take(&mut env.functions[fid.0 as usize].context);
385                env.functions[fid.0 as usize].context = context
386                    .into_iter()
387                    .map(|place| builder.get_place(&place, env))
388                    .collect();
389            }
390
391            // Map non-context operands
392            visitors::for_each_instruction_value_operand_mut(&mut instr.value, &mut |place| {
393                *place = builder.get_place(place, env);
394            });
395
396            // Map lvalues (skip DeclareContext/StoreContext — context variables
397            // don't participate in SSA renaming)
398            let instr = &mut func.instructions[instr_idx];
399            let mut lvalue_err: Option<CompilerDiagnostic> = None;
400            visitors::for_each_instruction_lvalue_mut(instr, &mut |place| {
401                if lvalue_err.is_none() {
402                    match builder.define_place(place, env) {
403                        Ok(new_place) => *place = new_place,
404                        Err(e) => lvalue_err = Some(e),
405                    }
406                }
407            });
408            if let Some(e) = lvalue_err {
409                return Err(e);
410            }
411
412            // Handle inner function SSA
413            if let Some(fid) = func_expr_id {
414                let context_ids: Vec<IdentifierId> = env.functions[fid.0 as usize]
415                    .context
416                    .iter()
417                    .map(|place| place.identifier)
418                    .collect();
419                for id in context_ids {
420                    builder.unmark_unknown(id);
421                }
422                builder.processed_functions.push(fid);
423                let inner_func = &mut env.functions[fid.0 as usize];
424                let inner_entry = inner_func.body.entry;
425                let entry_block = inner_func.body.blocks.get_mut(&inner_entry).unwrap();
426
427                if !entry_block.preds.is_empty() {
428                    return Err(CompilerDiagnostic::new(
429                        ErrorCategory::Invariant,
430                        "Expected function expression entry block to have zero predecessors",
431                        None,
432                    ));
433                }
434                entry_block.preds.insert(block_id);
435
436                builder.define_function(inner_func);
437
438                let saved_current = builder.current;
439
440                // Map inner function params
441                let inner_params = std::mem::take(&mut env.functions[fid.0 as usize].params);
442                let mut new_inner_params = Vec::with_capacity(inner_params.len());
443                for param in inner_params {
444                    new_inner_params.push(match param {
445                        ParamPattern::Place(p) => ParamPattern::Place(builder.define_place(&p, env)?),
446                        ParamPattern::Spread(s) => ParamPattern::Spread(SpreadPattern {
447                            place: builder.define_place(&s.place, env)?,
448                        }),
449                    });
450                }
451                env.functions[fid.0 as usize].params = new_inner_params;
452
453                // Take the inner function out of the arena to process it
454                let mut inner_func = std::mem::replace(
455                    &mut env.functions[fid.0 as usize],
456                    placeholder_function(),
457                );
458
459                enter_ssa_impl(&mut inner_func, builder, env, root_entry)?;
460
461                // Put it back
462                env.functions[fid.0 as usize] = inner_func;
463
464                builder.current = saved_current;
465
466                // Clear entry preds
467                env.functions[fid.0 as usize]
468                    .body
469                    .blocks
470                    .get_mut(&inner_entry)
471                    .unwrap()
472                    .preds
473                    .clear();
474                builder.block_preds.insert(inner_entry, Vec::new());
475            }
476        }
477
478        // Map terminal operands
479        let terminal = &mut func.body.blocks.get_mut(&block_id).unwrap().terminal;
480        visitors::for_each_terminal_operand_mut(terminal, &mut |place| {
481            *place = builder.get_place(place, env);
482        });
483
484        // Handle successors
485        let terminal_ref = &func.body.blocks.get(&block_id).unwrap().terminal;
486        let successors = visitors::each_terminal_successor(terminal_ref);
487        for output_id in successors {
488            let output_preds_len = builder
489                .block_preds
490                .get(&output_id)
491                .map(|p| p.len() as u32)
492                .unwrap_or(0);
493
494            let count = if builder.unsealed_preds.contains_key(&output_id) {
495                builder.unsealed_preds[&output_id] - 1
496            } else {
497                output_preds_len - 1
498            };
499            builder.unsealed_preds.insert(output_id, count);
500
501            if count == 0 && visited_blocks.contains(&output_id) {
502                builder.fix_incomplete_phis(output_id, env);
503            }
504        }
505    }
506
507    Ok(())
508}
509
510/// Create a placeholder HirFunction for temporarily swapping an inner function
511/// out of `env.functions` via `std::mem::replace`. The placeholder is never
512/// read — the real function is swapped back immediately after processing.
513pub fn placeholder_function() -> HirFunction {
514    HirFunction {
515        loc: None,
516        id: None,
517        name_hint: None,
518        fn_type: ReactFunctionType::Other,
519        params: Vec::new(),
520        return_type_annotation: None,
521        returns: Place {
522            identifier: IdentifierId(0),
523            effect: Effect::Unknown,
524            reactive: false,
525            loc: None,
526        },
527        context: Vec::new(),
528        body: HIR {
529            entry: BlockId(0),
530            blocks: IndexMap::new(),
531        },
532        instructions: Vec::new(),
533        generator: false,
534        is_async: false,
535        directives: Vec::new(),
536        aliasing_effects: None,
537    }
538}