Skip to main content

cairo_lang_lowering/analysis/
equality_analysis.rs

1//! Equality analysis for lowered IR.
2//!
3//! This module tracks semantic equivalence between variables as information flows through the
4//! program. Two variables are equivalent if they hold the same value. Additionally, the analysis
5//! tracks `Box`/unbox, snapshot/desnap, and struct/array construct relationships between
6//! equivalence classes via unified forward/reverse maps. Arrays reuse the struct construct
7//! representation since both map `(TypeId, Vec<VariableId>)` — array pop operations act as
8//! destructures.
9
10use cairo_lang_debug::DebugWithDb;
11use cairo_lang_defs::ids::{ExternFunctionId, NamedLanguageElementId};
12use cairo_lang_semantic::corelib::option_some_variant;
13use cairo_lang_semantic::helper::ModuleHelper;
14use cairo_lang_semantic::{ConcreteVariant, GenericArgumentId, MatchArmSelector, TypeId};
15use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
16use salsa::Database;
17
18use crate::analysis::core::Edge;
19use crate::analysis::{DataflowAnalyzer, Direction, ForwardDataflowAnalysis};
20use crate::{
21    BlockEnd, BlockId, Lowered, MatchArm, MatchExternInfo, MatchInfo, Statement, VariableId,
22};
23
24/// A relationship between equivalence classes, carrying its payload data.
25/// Hashable so it can be used as a forward map key.
26#[derive(Clone, Debug, Hash, PartialEq, Eq)]
27enum Relation<'db> {
28    Box(VariableId),
29    Snapshot(VariableId),
30    EnumConstruct(ConcreteVariant<'db>, VariableId),
31    StructConstruct(TypeId<'db>, Vec<VariableId>),
32}
33
34impl<'db> Relation<'db> {
35    /// Returns an iterator over all variables referenced by this relation.
36    fn referenced_vars(&self) -> impl Iterator<Item = VariableId> + '_ {
37        let (single, fields): (Option<VariableId>, &[VariableId]) = match self {
38            Relation::Box(v) | Relation::Snapshot(v) | Relation::EnumConstruct(_, v) => {
39                (Some(*v), &[])
40            }
41            Relation::StructConstruct(_, vs) => (None, vs),
42        };
43        single.into_iter().chain(fields.iter().copied())
44    }
45
46    /// Returns a new Relation with all input variables resolved to their current representatives.
47    fn with_fresh_reps(self, state: &mut EqualityState<'_>) -> Self {
48        match self {
49            Relation::Box(v) => Relation::Box(state.find(v)),
50            Relation::Snapshot(v) => Relation::Snapshot(state.find(v)),
51            Relation::EnumConstruct(variant, v) => Relation::EnumConstruct(variant, state.find(v)),
52            Relation::StructConstruct(ty, fields) => {
53                Relation::StructConstruct(ty, fields.into_iter().map(|v| state.find(v)).collect())
54            }
55        }
56    }
57
58    /// Merges two reverse relationships (self and other are proven equal).
59    /// When both exist with the same kind, propagates equality through inputs.
60    /// Always returns a valid relation (self if no merge needed).
61    fn union_equal_relations(self, other: Option<Self>, uf: &mut EqualityState<'_>) -> Self {
62        let Some(other_rel) = other else { return self };
63        match (&self, &other_rel) {
64            (Relation::Box(a), Relation::Box(b)) if a != b => Relation::Box(uf.union(*a, *b)),
65            (Relation::Snapshot(a), Relation::Snapshot(b)) if a != b => {
66                Relation::Snapshot(uf.union(*a, *b))
67            }
68            (Relation::EnumConstruct(v1, a), Relation::EnumConstruct(v2, b))
69                if v1 == v2 && a != b =>
70            {
71                Relation::EnumConstruct(*v1, uf.union(*a, *b))
72            }
73            (Relation::StructConstruct(t1, a), Relation::StructConstruct(t2, b))
74                if t1 == t2 && a.len() == b.len() =>
75            {
76                Relation::StructConstruct(
77                    *t1,
78                    a.iter().zip(b).map(|(x1, x2)| uf.union(*x1, *x2)).collect(),
79                )
80            }
81            // Same values or different kinds: keep self.
82            _ => self,
83        }
84    }
85}
86
87/// State for the equality analysis, tracking variable equivalences.
88///
89/// This is the `Info` type for the dataflow analysis. Each block gets its own
90/// `EqualityState` representing what we know at that point in the program.
91///
92/// Uses sparse HashMaps for efficiency - only variables that have been touched
93/// by the analysis are stored.
94#[derive(Clone, Debug, Default)]
95pub struct EqualityState<'db> {
96    /// Union-find parent map. If a variable is not in the map, it is its own representative.
97    union_find: OrderedHashMap<VariableId, VariableId>,
98
99    /// Forward map: Relation(...) -> output representative.
100    ///
101    /// Keys use representatives at insertion time. In SSA form, representatives are generally
102    /// stable within a block, so keys stay valid without migration during `union`. A union
103    /// *can* change a representative to a lower ID, which may cause a subsequent identical
104    /// construct to miss the earlier entry — this is a known imprecision (conservative, not
105    /// unsound). At merge points the maps are rebuilt from scratch.
106    forward: OrderedHashMap<Relation<'db>, VariableId>,
107
108    /// Reverse map: output representative -> Relation.
109    /// Records how each class was produced. A class has at most one reverse relationship.
110    reverse: OrderedHashMap<VariableId, Relation<'db>>,
111}
112
113impl<'db> EqualityState<'db> {
114    /// Gets the parent of a variable, defaulting to itself (root) if not in the map.
115    fn get_parent(&self, var: VariableId) -> VariableId {
116        self.union_find.get(&var).copied().unwrap_or(var)
117    }
118
119    /// Finds the representative of a variable's equivalence class.
120    /// Uses path splitting for efficiency: each node is redirected to its grandparent.
121    fn find(&mut self, mut var: VariableId) -> VariableId {
122        let mut parent = self.get_parent(var);
123        while parent != var {
124            let grandparent = self.get_parent(parent);
125            self.union_find.insert(var, grandparent);
126            var = parent;
127            parent = grandparent;
128        }
129        var
130    }
131
132    /// Finds the representative without modifying the structure.
133    pub(crate) fn find_immut(&self, mut var: VariableId) -> VariableId {
134        let mut parent = self.get_parent(var);
135        while parent != var {
136            var = parent;
137            parent = self.get_parent(var);
138        }
139        var
140    }
141
142    /// Unions two variables into the same equivalence class.
143    /// Returns the representative of the merged class.
144    /// Always chooses the lower ID as the representative to maintain canonical form.
145    fn union(&mut self, a: VariableId, b: VariableId) -> VariableId {
146        let root_a = self.find(a);
147        let root_b = self.find(b);
148
149        if root_a == root_b {
150            return root_a;
151        }
152
153        // Always choose the lower ID as the new root to maintain canonical form.
154        // This ensures forward map keys remain valid since lower IDs are defined earlier.
155        let (new_root, old_root) =
156            if root_a.index() < root_b.index() { (root_a, root_b) } else { (root_b, root_a) };
157
158        // Ensure new_root is in the map (as its own parent).
159        self.union_find.entry(new_root).or_insert(new_root);
160        // Update old_root to point to new_root.
161        self.union_find.insert(old_root, new_root);
162
163        // Merge reverse entries for both roots.
164        let old_reverse = self.reverse.swap_remove(&old_root);
165        let new_reverse = self.reverse.swap_remove(&new_root);
166        let merged_reverse = match (new_reverse, old_reverse) {
167            (Some(new_rev), old) => Some(new_rev.union_equal_relations(old, self)),
168            (None, old) => old,
169        };
170
171        // TODO(eytan-starkware): Forward Struct/Enum entries aren't re-keyed during union.
172        //    Doing so would require congruence closure (updating all entries whose inputs
173        //    changed). The consequence is missed hashcons hits — conservative, not unsound.
174        //    We also don't invalidate forward entries that become inconsistent (e.g., two
175        //    different enum variants mapping to the now-merged class). These stale entries
176        //    are harmless as this code should be unreachable.
177        // Merge forward Box/Snapshot entries for both roots.
178        let constructors = [Relation::Box, Relation::Snapshot];
179        for ctor in constructors {
180            let old_fwd = self.forward.swap_remove(&ctor(old_root));
181            let new_fwd = self.forward.swap_remove(&ctor(new_root));
182            let merged = match (new_fwd, old_fwd) {
183                (Some(t1), Some(t2)) => Some(self.union(t1, t2)),
184                (Some(t), None) | (None, Some(t)) => Some(t),
185                (None, None) => None,
186            };
187            if let Some(target) = merged {
188                let final_root = self.find(new_root);
189                let target_rep = self.find(target);
190                self.forward.insert(ctor(final_root), target_rep);
191            }
192        }
193
194        let final_root = self.find(new_root);
195        if let Some(merged_reverse) = merged_reverse {
196            self.reverse.insert(final_root, merged_reverse);
197        }
198
199        self.find(new_root)
200    }
201
202    /// Records a relation: forward maps `relation -> output`, reverse maps `output -> relation`.
203    /// If the same relation already maps to an existing output, unions them.
204    /// If the output already has a reverse, merges inputs via `union_equal_relations`.
205    fn set_relation(&mut self, relation: Relation<'db>, output: VariableId) {
206        // Refresh reps — callers may pass stale IDs, and this maximizes forward hits.
207        let relation = relation.with_fresh_reps(self);
208
209        // Forward dedup: if this exact relation already maps to an output, union them.
210        if let Some(&existing_output) = self.forward.get(&relation) {
211            self.union(existing_output, output);
212        }
213
214        // Reverse merge: if output already has a reverse, merge inputs and use the result.
215        let output_rep = self.find(output);
216        let existing = self.reverse.swap_remove(&output_rep);
217        let relation = relation.union_equal_relations(existing, self);
218
219        // Insert with current reps (may be slightly stale after unions above).
220        let output_rep = self.find(output);
221        self.forward.insert(relation.clone(), output_rep);
222        self.reverse.insert(output_rep, relation);
223    }
224
225    /// Looks up the struct construct info for a representative (immutable).
226    fn get_struct_construct_immut(
227        &self,
228        rep: VariableId,
229    ) -> Option<(TypeId<'db>, Vec<VariableId>)> {
230        match self.reverse.get(&rep)? {
231            Relation::StructConstruct(ty, fields) => Some((*ty, fields.clone())),
232            _ => None,
233        }
234    }
235
236    /// Looks up the struct construct info for a variable (mutable, uses find for path compression).
237    fn get_struct_construct(&mut self, var: VariableId) -> Option<(TypeId<'db>, Vec<VariableId>)> {
238        let rep = self.find(var);
239        self.get_struct_construct_immut(rep)
240    }
241
242    /// Looks up the enum construct info for a representative (immutable).
243    fn get_enum_construct_immut(
244        &self,
245        rep: VariableId,
246    ) -> Option<(ConcreteVariant<'db>, VariableId)> {
247        match self.reverse.get(&rep)? {
248            Relation::EnumConstruct(variant, input) => Some((*variant, *input)),
249            _ => None,
250        }
251    }
252}
253
254impl<'db> DebugWithDb<'db> for EqualityState<'db> {
255    type Db = dyn Database;
256
257    fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &'db Self::Db) -> std::fmt::Result {
258        let v = |id: VariableId| format!("v{}", self.find_immut(id).index());
259        let mut lines = Vec::<String>::new();
260        for (relation, &output) in self.forward.iter() {
261            match relation {
262                Relation::Snapshot(source) => {
263                    lines.push(format!("@{} = {}", v(*source), v(output)));
264                }
265                Relation::Box(source) => {
266                    lines.push(format!("Box({}) = {}", v(*source), v(output)));
267                }
268                Relation::EnumConstruct(variant, input) => {
269                    let name = variant.id.name(db).to_string(db);
270                    lines.push(format!("{name}({}) = {}", v(*input), v(output)));
271                }
272                Relation::StructConstruct(ty, inputs) => {
273                    let type_name = ty.format(db);
274                    let fields = inputs.iter().map(|&id| v(id)).collect::<Vec<_>>().join(", ");
275                    lines.push(format!("{type_name}({fields}) = {}", v(output)));
276                }
277            }
278        }
279        for &var in self.union_find.keys() {
280            let rep = self.find_immut(var);
281            if var != rep {
282                lines.push(format!("v{} = v{}", rep.index(), var.index()));
283            }
284        }
285        lines.sort();
286        if lines.is_empty() { write!(f, "(empty)") } else { write!(f, "{}", lines.join(", ")) }
287    }
288}
289
290/// Variable equality analysis.
291///
292/// This analyzer tracks snapshot/desnap, box/unbox, and array construct relationships as data
293/// flows through the program. At merge points (after match arms converge), we conservatively
294/// intersect the equivalence classes, keeping only equalities that hold on all paths.
295pub struct EqualityAnalysis<'a, 'db> {
296    db: &'db dyn Database,
297    lowered: &'a Lowered<'db>,
298    /// The `array_new` extern function id.
299    array_new: ExternFunctionId<'db>,
300    /// The `array_append` extern function id.
301    array_append: ExternFunctionId<'db>,
302    /// The `array_pop_front` extern function id.
303    array_pop_front: ExternFunctionId<'db>,
304    /// The `array_pop_front_consume` extern function id.
305    array_pop_front_consume: ExternFunctionId<'db>,
306    /// The `array_snapshot_pop_front` extern function id.
307    array_snapshot_pop_front: ExternFunctionId<'db>,
308    /// The `array_snapshot_pop_back` extern function id.
309    array_snapshot_pop_back: ExternFunctionId<'db>,
310}
311
312impl<'a, 'db> EqualityAnalysis<'a, 'db> {
313    /// Creates a new equality analysis instance.
314    pub fn new(db: &'db dyn Database, lowered: &'a Lowered<'db>) -> Self {
315        let array_module = ModuleHelper::core(db).submodule("array");
316        Self {
317            db,
318            lowered,
319            array_new: array_module.extern_function_id("array_new"),
320            array_append: array_module.extern_function_id("array_append"),
321            array_pop_front: array_module.extern_function_id("array_pop_front"),
322            array_pop_front_consume: array_module.extern_function_id("array_pop_front_consume"),
323            array_snapshot_pop_front: array_module.extern_function_id("array_snapshot_pop_front"),
324            array_snapshot_pop_back: array_module.extern_function_id("array_snapshot_pop_back"),
325        }
326    }
327
328    /// Runs equality analysis on a lowered function.
329    /// Returns the equality state at the exit of each block.
330    pub fn analyze(
331        db: &'db dyn Database,
332        lowered: &'a Lowered<'db>,
333    ) -> Vec<Option<EqualityState<'db>>> {
334        ForwardDataflowAnalysis::new(lowered, EqualityAnalysis::new(db, lowered)).run()
335    }
336
337    /// Handles extern match arms for array operations.
338    ///
339    /// Array pop operations act as "destructures" on the struct construct representation:
340    /// - `array_pop_front` / `array_pop_front_consume`: On the Some arm, if the input array was
341    ///   tracked as `[e0, e1, ..., eN]`, the popped element (boxed) is `Box(e0)` and the remaining
342    ///   array is `[e1, ..., eN]`.
343    /// - `array_snapshot_pop_front`: Same as above but through snapshot/box-of-snapshot wrappers.
344    /// - `array_snapshot_pop_back`: Like pop_front but pops from the back: element is `Box(eN)`,
345    ///   remaining is `[e0, ..., eN-1]`.
346    fn transfer_extern_match_arm(
347        &self,
348        info: &mut EqualityState<'db>,
349        extern_info: &MatchExternInfo<'db>,
350        arm: &MatchArm<'db>,
351    ) {
352        let Some((id, _)) = extern_info.function.get_extern(self.db) else { return };
353        // TODO(eytan-starkware): Add support for multipop.
354        let MatchArmSelector::VariantId(variant) = arm.arm_selector else { return };
355        if id == self.array_pop_front
356            || id == self.array_pop_front_consume
357            || id == self.array_snapshot_pop_front
358            || id == self.array_snapshot_pop_back
359        {
360            let [GenericArgumentId::Type(option_ty)] =
361                variant.concrete_enum_id.long(self.db).generic_args[..]
362            else {
363                panic!("Expected Option<T> with a single type argument");
364            };
365            let some_variant = option_some_variant(self.db, option_ty);
366            assert_eq!(
367                variant.concrete_enum_id.enum_id(self.db),
368                some_variant.concrete_enum_id.enum_id(self.db),
369                "Expected match to be on an Option<T>"
370            );
371            self.transfer_array_pop_arm(info, extern_info, arm, id, variant == some_variant);
372        }
373    }
374
375    /// Handles the actual array pop arm transfer after validating the Option variant.
376    fn transfer_array_pop_arm(
377        &self,
378        info: &mut EqualityState<'db>,
379        extern_info: &MatchExternInfo<'db>,
380        arm: &MatchArm<'db>,
381        id: ExternFunctionId<'db>,
382        is_some: bool,
383    ) {
384        if id == self.array_pop_front || id == self.array_pop_front_consume {
385            if is_some {
386                // Some arm: var_ids = [remaining_arr, boxed_elem]
387                let input_arr = extern_info.inputs[0].var_id;
388                let remaining_arr = arm.var_ids[0];
389                let boxed_elem = arm.var_ids[1];
390                if let Some((ty, elems)) = info.get_struct_construct(input_arr)
391                    && let Some((&first, rest)) = elems.split_first()
392                {
393                    info.set_relation(Relation::Box(first), boxed_elem);
394                    let rest_reps: Vec<_> = rest.iter().map(|&v| info.find(v)).collect();
395                    info.set_relation(Relation::StructConstruct(ty, rest_reps), remaining_arr);
396                }
397            } else {
398                // None arm for array_pop_front: var_ids = [original_arr]. Union with input.
399                // None arm for array_pop_front_consume: var_ids = [].
400                let old_array_var = extern_info.inputs[0].var_id;
401                let ty = self.lowered.variables[old_array_var].ty;
402                // TODO(eytan-starkware): This introduces a backedge to our forward map updates,
403                //    so we might need to support updating the structures accordingly.
404                //    For example, if this is empty after a pop, then we know previous array
405                //    was a singleton.
406                info.set_relation(Relation::StructConstruct(ty, vec![]), old_array_var);
407                if let [original_arr] = arm.var_ids[..] {
408                    info.union(original_arr, old_array_var);
409                }
410            }
411        } else if id == self.array_snapshot_pop_front || id == self.array_snapshot_pop_back {
412            if is_some {
413                // Some arm: var_ids = [remaining_snap_arr, boxed_snap_elem]
414                let input_snap_arr = extern_info.inputs[0].var_id;
415                let remaining_snap_arr = arm.var_ids[0];
416                let boxed_snap_elem = arm.var_ids[1];
417
418                // Look up tracked elements via snapshot reverse relationship or direct lookup.
419                let snap_rep = info.find(input_snap_arr);
420                let original_rep = match info.reverse.get(&snap_rep) {
421                    Some(Relation::Snapshot(v)) => Some(*v),
422                    _ => None,
423                };
424                let elems_opt = original_rep
425                    .and_then(|orig| {
426                        let orig = info.find_immut(orig);
427                        info.get_struct_construct_immut(orig)
428                    })
429                    .or_else(|| info.get_struct_construct_immut(snap_rep));
430
431                if let Some((_orig_ty, elems)) = elems_opt {
432                    let pop_front = id == self.array_snapshot_pop_front;
433                    let (elem, rest) = if pop_front {
434                        let Some((&first, tail)) = elems.split_first() else { return };
435                        (first, tail)
436                    } else {
437                        let Some((&last, init)) = elems.split_last() else { return };
438                        (last, init)
439                    };
440
441                    // The popped element is `Box<@T>`. Record the box relationship against
442                    // the snapshot class of `elem` if it exists.
443                    // TODO(eytan-starkware): Support relationships even when no variable
444                    // represents `@elem` yet.
445                    let elem_rep = info.find(elem);
446                    if let Some(&snap_of_elem) = info.forward.get(&Relation::Snapshot(elem_rep)) {
447                        info.set_relation(Relation::Box(snap_of_elem), boxed_snap_elem);
448                    }
449
450                    // Record the remaining snapshot array under its snapshot type
451                    // (`@Array<T>`). This is a hack: `@Array<T>` is not a struct, but we
452                    // reuse struct construct tracking to store element info for it. This
453                    // also requires the two-path lookup above (path 2).
454                    // TODO(eytan-starkware): Once placeholder vars are supported, store
455                    //    this as a proper `Array<T>` linked via the reverse map instead,
456                    //    since we may not have a var representing the non-snapshot array
457                    //    at the moment.
458                    let snap_ty = self.lowered.variables[remaining_snap_arr].ty;
459                    let rest_reps: Vec<_> = rest.iter().map(|&v| info.find(v)).collect();
460                    info.set_relation(
461                        Relation::StructConstruct(snap_ty, rest_reps),
462                        remaining_snap_arr,
463                    );
464                }
465            } else {
466                // None arm: record empty snapshot array and union output with input.
467                let old_snap_arr = extern_info.inputs[0].var_id;
468                let snap_ty = self.lowered.variables[old_snap_arr].ty;
469                info.set_relation(Relation::StructConstruct(snap_ty, vec![]), old_snap_arr);
470                if let [original_snap_arr] = arm.var_ids[..] {
471                    info.union(original_snap_arr, old_snap_arr);
472                }
473            }
474        }
475    }
476}
477
478/// Returns an iterator over all variables with equality or relationship information in the equality
479/// states.
480fn merge_referenced_vars<'db, 'a>(
481    info1: &'a EqualityState<'db>,
482    info2: &'a EqualityState<'db>,
483) -> impl Iterator<Item = VariableId> + 'a {
484    let union_find_vars = info1.union_find.keys().chain(info2.union_find.keys()).copied();
485
486    let forward_vars =
487        info1.forward.iter().chain(info2.forward.iter()).flat_map(|(relation, &output)| {
488            relation.referenced_vars().chain(std::iter::once(output))
489        });
490
491    let reverse_vars = info1
492        .reverse
493        .iter()
494        .chain(info2.reverse.iter())
495        .flat_map(|(&rep, relation)| std::iter::once(rep).chain(relation.referenced_vars()));
496
497    union_find_vars.chain(forward_vars).chain(reverse_vars)
498}
499
500/// Finds an intersection representative: given a rep in info1 and a rep in info2,
501/// returns the intersection representative in the result if one exists.
502fn find_intersection_rep(
503    intersections: &OrderedHashMap<VariableId, Vec<(VariableId, VariableId)>>,
504    rep1: VariableId,
505    rep2: VariableId,
506) -> Option<VariableId> {
507    intersections.get(&rep1)?.iter().find_map(|(intersection_r2, intersection_rep)| {
508        (*intersection_r2 == rep2).then_some(*intersection_rep)
509    })
510}
511
512/// Preserves relations that exist in both branches.
513/// All relation types are looked up via the forward map, which retains all entries.
514fn merge_relations<'db>(
515    info1: &EqualityState<'db>,
516    info2: &EqualityState<'db>,
517    intersections: &OrderedHashMap<VariableId, Vec<(VariableId, VariableId)>>,
518    result: &mut EqualityState<'db>,
519) {
520    // Iterate all forward entries from info1 and find matching entries in info2.
521    // We use forward (not reverse) because reverse holds only the latest relation per output,
522    // while forward retains all entries and is the authoritative source.
523    for (relation, &output1) in info1.forward.iter() {
524        match relation {
525            Relation::Box(source1) | Relation::Snapshot(source1) => {
526                for &(source2, intersection_var) in intersections.get(source1).unwrap_or(&vec![]) {
527                    let relation2 = match relation {
528                        Relation::Box(_) => Relation::Box(source2),
529                        Relation::Snapshot(_) => Relation::Snapshot(source2),
530                        _ => unreachable!(),
531                    };
532                    let Some(&output2) = info2.forward.get(&relation2) else { continue };
533                    if let Some(output_intersection) = find_intersection_rep(
534                        intersections,
535                        info1.find_immut(output1),
536                        info2.find_immut(output2),
537                    ) {
538                        let result_relation = match relation {
539                            Relation::Box(_) => Relation::Box(result.find(intersection_var)),
540                            Relation::Snapshot(_) => {
541                                Relation::Snapshot(result.find(intersection_var))
542                            }
543                            _ => unreachable!(),
544                        };
545                        result.set_relation(result_relation, output_intersection);
546                    }
547                }
548            }
549            Relation::EnumConstruct(variant, input1) => {
550                for &(input2, input_intersection) in
551                    intersections.get(&info1.find_immut(*input1)).unwrap_or(&vec![])
552                {
553                    let relation2 = Relation::EnumConstruct(*variant, input2);
554                    let Some(&output2) = info2.forward.get(&relation2) else { continue };
555                    if let Some(output_intersection) = find_intersection_rep(
556                        intersections,
557                        info1.find_immut(output1),
558                        info2.find_immut(output2),
559                    ) {
560                        result.set_relation(
561                            Relation::EnumConstruct(*variant, input_intersection),
562                            output_intersection,
563                        );
564                    }
565                }
566            }
567            Relation::StructConstruct(ty, fields1) => {
568                let fields2: Vec<_> = fields1.iter().map(|&v| info2.find_immut(v)).collect();
569                let Some(&output2) =
570                    info2.forward.get(&Relation::StructConstruct(*ty, fields2.clone()))
571                else {
572                    continue;
573                };
574                let result_fields: Option<Vec<_>> = fields1
575                    .iter()
576                    .zip(&fields2)
577                    .map(|(&v1, &v2)| {
578                        find_intersection_rep(intersections, info1.find_immut(v1), v2)
579                    })
580                    .collect();
581                let Some(result_fields) = result_fields else { continue };
582                if let Some(output_intersection) = find_intersection_rep(
583                    intersections,
584                    info1.find_immut(output1),
585                    info2.find_immut(output2),
586                ) {
587                    result.set_relation(
588                        Relation::StructConstruct(*ty, result_fields),
589                        output_intersection,
590                    );
591                }
592            }
593        }
594    }
595}
596
597impl<'db, 'a> DataflowAnalyzer<'db, 'a> for EqualityAnalysis<'a, 'db> {
598    type Info = EqualityState<'db>;
599
600    const DIRECTION: Direction = Direction::Forward;
601
602    fn initial_info(&mut self, _block_id: BlockId, _block_end: &'a BlockEnd<'db>) -> Self::Info {
603        EqualityState::default()
604    }
605
606    fn merge(
607        &mut self,
608        _lowered: &Lowered<'db>,
609        _statement_location: super::StatementLocation,
610        info1: Self::Info,
611        info2: Self::Info,
612    ) -> Self::Info {
613        // Intersection-based merge: keep only equalities that hold in BOTH branches.
614        let mut result = EqualityState::default();
615
616        // Group variables by (rep1, rep2) - for variables present in either state.
617        let mut groups: OrderedHashMap<(VariableId, VariableId), Vec<VariableId>> =
618            OrderedHashMap::default();
619
620        // Group by (rep1, rep2). Duplicates are fine - they'll just be added to the same group.
621        for var in merge_referenced_vars(&info1, &info2) {
622            let key = (info1.find_immut(var), info2.find_immut(var));
623            groups.entry(key).or_default().push(var);
624        }
625
626        // Union all variables within each group
627        for members in groups.values() {
628            if members.len() > 1 {
629                let first = members[0];
630                for &var in &members[1..] {
631                    result.union(first, var);
632                }
633            }
634        }
635
636        // An important point in this merge is to retain relationships.
637        // Consider:
638        //  info1 [equality class[1] = 1, 2, 4] and 6 is Box(1).
639        //  info2 [equality class[2] = 3, 5, 4] and 6 is Box(3).
640        // To detect we can keep 6 is Box(4), as it is true in all branches, we need intersection of
641        // eclasses (a single eclass can split in the result into multiple eclasses).
642        // Build secondary index: rep1 -> Vec<(rep2, intersection_rep)>.
643        let mut intersections: OrderedHashMap<VariableId, Vec<(VariableId, VariableId)>> =
644            OrderedHashMap::default();
645        for (&(rep1, rep2), vars) in groups.iter() {
646            intersections.entry(rep1).or_default().push((rep2, result.find(vars[0])));
647        }
648
649        merge_relations(&info1, &info2, &intersections, &mut result);
650
651        result
652    }
653
654    fn transfer_stmt(
655        &mut self,
656        info: &mut Self::Info,
657        _statement_location: super::StatementLocation,
658        stmt: &'a Statement<'db>,
659    ) {
660        match stmt {
661            Statement::Snapshot(snapshot_stmt) => {
662                info.union(snapshot_stmt.original(), snapshot_stmt.input.var_id);
663                info.set_relation(
664                    Relation::Snapshot(snapshot_stmt.input.var_id),
665                    snapshot_stmt.snapshot(),
666                );
667            }
668
669            Statement::Desnap(desnap_stmt) => {
670                info.set_relation(Relation::Snapshot(desnap_stmt.output), desnap_stmt.input.var_id);
671            }
672
673            Statement::IntoBox(into_box_stmt) => {
674                info.set_relation(Relation::Box(into_box_stmt.input.var_id), into_box_stmt.output);
675            }
676
677            Statement::Unbox(unbox_stmt) => {
678                info.set_relation(Relation::Box(unbox_stmt.output), unbox_stmt.input.var_id);
679            }
680
681            Statement::EnumConstruct(enum_stmt) => {
682                // output = Variant(input): track via forward map
683                // If we've already seen this variant with an equivalent input, the outputs are
684                // equal.
685                info.set_relation(
686                    Relation::EnumConstruct(enum_stmt.variant, enum_stmt.input.var_id),
687                    enum_stmt.output,
688                );
689            }
690
691            Statement::StructConstruct(struct_stmt) => {
692                // output = StructType(inputs...): track via forward map
693                // If we've already seen the same struct type with equivalent inputs, the outputs
694                // are equal.
695                let ty = self.lowered.variables[struct_stmt.output].ty;
696                let input_reps = struct_stmt.inputs.iter().map(|i| info.find(i.var_id)).collect();
697                info.set_relation(Relation::StructConstruct(ty, input_reps), struct_stmt.output);
698            }
699
700            Statement::StructDestructure(struct_stmt) => {
701                // (outputs...) = struct_destructure(input)
702                // 1. If input was previously constructed, union outputs with original fields.
703                if let Some((_, field_reps)) = info.get_struct_construct(struct_stmt.input.var_id) {
704                    for (&output, &field_rep) in struct_stmt.outputs.iter().zip(field_reps.iter()) {
705                        info.union(output, field_rep);
706                    }
707                }
708                // 2. Record: struct_construct(outputs) == input (for future constructs).
709                let ty = self.lowered.variables[struct_stmt.input.var_id].ty;
710                let output_reps = struct_stmt.outputs.iter().map(|&v| info.find(v)).collect();
711                info.set_relation(
712                    Relation::StructConstruct(ty, output_reps),
713                    struct_stmt.input.var_id,
714                );
715            }
716
717            Statement::Call(call_stmt) => {
718                let Some((id, _)) = call_stmt.function.get_extern(self.db) else { return };
719                if id == self.array_new {
720                    let ty = self.lowered.variables[call_stmt.outputs[0]].ty;
721                    info.set_relation(Relation::StructConstruct(ty, vec![]), call_stmt.outputs[0]);
722                } else if id == self.array_append
723                    && let Some((ty, elems)) = info.get_struct_construct(call_stmt.inputs[0].var_id)
724                {
725                    // Only track append if the input array is already tracked. Arrays from
726                    // function parameters or external calls are conservatively ignored.
727                    let mut new_elems = elems;
728                    new_elems.push(info.find(call_stmt.inputs[1].var_id));
729                    info.set_relation(
730                        Relation::StructConstruct(ty, new_elems),
731                        call_stmt.outputs[0],
732                    );
733                }
734            }
735
736            Statement::Const(_) => {}
737        }
738    }
739
740    fn transfer_edge(&mut self, info: &Self::Info, edge: &Edge<'db, 'a>) -> Self::Info {
741        let mut new_info = info.clone();
742        match edge {
743            Edge::Goto { remapping, .. } => {
744                // Union remapped variables: dst and src should be in the same equivalence class
745                for (dst, src_usage) in remapping.iter() {
746                    new_info.union(*dst, src_usage.var_id);
747                }
748            }
749            Edge::MatchArm { arm, match_info } => {
750                // For enum matches, track that matched_var = Variant(arm_var).
751                if let MatchInfo::Enum(enum_info) = match_info
752                    && let MatchArmSelector::VariantId(variant) = arm.arm_selector
753                    && let [arm_var] = arm.var_ids[..]
754                {
755                    let matched_var = enum_info.input.var_id;
756
757                    // If we previously saw this enum constructed with the same variant,
758                    // union with the original input. Skip if variants differ — this can
759                    // happen after optimizations merge states from different branches.
760                    let output_rep = new_info.find(matched_var);
761                    if let Some((old_variant, input)) =
762                        new_info.get_enum_construct_immut(output_rep)
763                        && variant == old_variant
764                    {
765                        new_info.union(arm_var, input);
766                    }
767
768                    // Record the relationship: matched_var = Variant(arm_var)
769                    new_info.set_relation(Relation::EnumConstruct(variant, arm_var), matched_var);
770                }
771
772                // For extern matches on array operations, track pop/destructure relationships.
773                if let MatchInfo::Extern(extern_info) = match_info {
774                    self.transfer_extern_match_arm(&mut new_info, extern_info, arm);
775                }
776            }
777            Edge::Return { .. } | Edge::Panic { .. } => {}
778        }
779        new_info
780    }
781}