1use cairo_lang_debug::DebugWithDb;
11use cairo_lang_defs::ids::{ExternFunctionId, NamedLanguageElementId};
12use cairo_lang_semantic::corelib::option_some_variant;
13use cairo_lang_semantic::helper::ModuleHelper;
14use cairo_lang_semantic::{ConcreteVariant, GenericArgumentId, MatchArmSelector, TypeId};
15use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
16use salsa::Database;
17
18use crate::analysis::core::Edge;
19use crate::analysis::{DataflowAnalyzer, Direction, ForwardDataflowAnalysis};
20use crate::{
21 BlockEnd, BlockId, Lowered, MatchArm, MatchExternInfo, MatchInfo, Statement, VariableId,
22};
23
/// A construction fact about a variable: the variable this relation is associated with (its
/// "output") holds the value obtained by applying the relation to the input variable(s).
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
enum Relation<'db> {
    /// `output = Box(input)`; recorded for `into_box` / `unbox` statements.
    Box(VariableId),
    /// `output = @input`; recorded for snapshot / desnap statements.
    Snapshot(VariableId),
    /// `output = Variant(input)`; recorded for enum constructions and enum match arms.
    EnumConstruct(ConcreteVariant<'db>, VariableId),
    /// `output = Type(inputs...)`. Besides structs, this also models arrays and snapshot arrays
    /// (the type plus the known element list).
    StructConstruct(TypeId<'db>, Vec<VariableId>),
}
33
34impl<'db> Relation<'db> {
35 fn referenced_vars(&self) -> impl Iterator<Item = VariableId> + '_ {
37 let (single, fields): (Option<VariableId>, &[VariableId]) = match self {
38 Relation::Box(v) | Relation::Snapshot(v) | Relation::EnumConstruct(_, v) => {
39 (Some(*v), &[])
40 }
41 Relation::StructConstruct(_, vs) => (None, vs),
42 };
43 single.into_iter().chain(fields.iter().copied())
44 }
45
46 fn with_fresh_reps(self, state: &mut EqualityState<'_>) -> Self {
48 match self {
49 Relation::Box(v) => Relation::Box(state.find(v)),
50 Relation::Snapshot(v) => Relation::Snapshot(state.find(v)),
51 Relation::EnumConstruct(variant, v) => Relation::EnumConstruct(variant, state.find(v)),
52 Relation::StructConstruct(ty, fields) => {
53 Relation::StructConstruct(ty, fields.into_iter().map(|v| state.find(v)).collect())
54 }
55 }
56 }
57
58 fn union_equal_relations(self, other: Option<Self>, uf: &mut EqualityState<'_>) -> Self {
62 let Some(other_rel) = other else { return self };
63 match (&self, &other_rel) {
64 (Relation::Box(a), Relation::Box(b)) if a != b => Relation::Box(uf.union(*a, *b)),
65 (Relation::Snapshot(a), Relation::Snapshot(b)) if a != b => {
66 Relation::Snapshot(uf.union(*a, *b))
67 }
68 (Relation::EnumConstruct(v1, a), Relation::EnumConstruct(v2, b))
69 if v1 == v2 && a != b =>
70 {
71 Relation::EnumConstruct(*v1, uf.union(*a, *b))
72 }
73 (Relation::StructConstruct(t1, a), Relation::StructConstruct(t2, b))
74 if t1 == t2 && a.len() == b.len() =>
75 {
76 Relation::StructConstruct(
77 *t1,
78 a.iter().zip(b).map(|(x1, x2)| uf.union(*x1, *x2)).collect(),
79 )
80 }
81 _ => self,
83 }
84 }
85}
86
/// Dataflow state of the equality analysis.
///
/// Combines a union-find over variables (classes of variables known to hold equal values) with
/// construction relations (box, snapshot, enum/struct construct) between those classes.
#[derive(Clone, Debug, Default)]
pub struct EqualityState<'db> {
    /// Union-find parent pointers. A variable with no entry is its own parent (a singleton root).
    union_find: OrderedHashMap<VariableId, VariableId>,

    /// Maps a relation to a variable holding its result. Inputs are representatives at the time
    /// of insertion; `Box`/`Snapshot` keys are re-keyed when their input's class is merged,
    /// other kinds may keep stale (non-representative) inputs.
    forward: OrderedHashMap<Relation<'db>, VariableId>,

    /// Maps a class representative to the (single) relation known to produce it.
    reverse: OrderedHashMap<VariableId, Relation<'db>>,
}
112
impl<'db> EqualityState<'db> {
    /// Returns the recorded union-find parent of `var`, or `var` itself when none is recorded.
    fn get_parent(&self, var: VariableId) -> VariableId {
        self.union_find.get(&var).copied().unwrap_or(var)
    }

    /// Returns the representative of `var`'s class.
    ///
    /// Compresses the path while walking it: every visited variable is re-pointed at its
    /// grandparent (path splitting).
    fn find(&mut self, mut var: VariableId) -> VariableId {
        let mut parent = self.get_parent(var);
        while parent != var {
            let grandparent = self.get_parent(parent);
            self.union_find.insert(var, grandparent);
            var = parent;
            parent = grandparent;
        }
        var
    }

    /// Returns the representative of `var`'s class without modifying the state (no path
    /// compression), for callers that only hold `&self`.
    pub(crate) fn find_immut(&self, mut var: VariableId) -> VariableId {
        let mut parent = self.get_parent(var);
        while parent != var {
            var = parent;
            parent = self.get_parent(var);
        }
        var
    }

    /// Merges the classes of `a` and `b` and returns the representative of the merged class.
    ///
    /// The root with the smaller variable index survives. Besides linking the roots, this:
    /// - merges the `reverse` relations of both roots; matching relations force their inputs to
    ///   be unioned as well (via `union_equal_relations`, which may recurse into `union`);
    /// - re-keys the `Box`/`Snapshot` `forward` entries of both roots onto the surviving root,
    ///   unioning their targets when both roots had one (again possibly recursive).
    fn union(&mut self, a: VariableId, b: VariableId) -> VariableId {
        let root_a = self.find(a);
        let root_b = self.find(b);

        if root_a == root_b {
            return root_a;
        }

        // Deterministic choice of the surviving root: the lower variable index wins.
        let (new_root, old_root) =
            if root_a.index() < root_b.index() { (root_a, root_b) } else { (root_b, root_a) };

        // Give the surviving root an explicit self-loop entry so it is enumerated by
        // `union_find.keys()`, which `merge_referenced_vars` iterates.
        self.union_find.entry(new_root).or_insert(new_root);
        self.union_find.insert(old_root, new_root);

        // Both roots' reverse relations now describe the same value; merge them.
        let old_reverse = self.reverse.swap_remove(&old_root);
        let new_reverse = self.reverse.swap_remove(&new_root);
        let merged_reverse = match (new_reverse, old_reverse) {
            (Some(new_rev), old) => Some(new_rev.union_equal_relations(old, self)),
            (None, old) => old,
        };

        // `Box(x)` and `@x` are functions of `x`: once the roots are equal, their boxes and
        // snapshots are equal too. Only these two kinds are re-keyed here; other forward keys
        // mentioning `old_root` are left untouched.
        let constructors = [Relation::Box, Relation::Snapshot];
        for ctor in constructors {
            let old_fwd = self.forward.swap_remove(&ctor(old_root));
            let new_fwd = self.forward.swap_remove(&ctor(new_root));
            let merged = match (new_fwd, old_fwd) {
                (Some(t1), Some(t2)) => Some(self.union(t1, t2)),
                (Some(t), None) | (None, Some(t)) => Some(t),
                (None, None) => None,
            };
            if let Some(target) = merged {
                // The nested union above may have moved the root; re-resolve both ends.
                let final_root = self.find(new_root);
                let target_rep = self.find(target);
                self.forward.insert(ctor(final_root), target_rep);
            }
        }

        let final_root = self.find(new_root);
        if let Some(merged_reverse) = merged_reverse {
            self.reverse.insert(final_root, merged_reverse);
        }

        self.find(new_root)
    }

    /// Records that `output` holds the value of `relation`.
    ///
    /// The relation's inputs are first resolved to representatives. If `forward` already maps
    /// that relation to another variable, that variable is unioned with `output`. The relation
    /// is then combined with any relation previously in `reverse` for `output`'s class: when
    /// both are the same kind their inputs are unioned, otherwise the new relation replaces the
    /// old one. Finally both `forward` and `reverse` are updated.
    fn set_relation(&mut self, relation: Relation<'db>, output: VariableId) {
        let relation = relation.with_fresh_reps(self);

        if let Some(&existing_output) = self.forward.get(&relation) {
            self.union(existing_output, output);
        }

        let output_rep = self.find(output);
        let existing = self.reverse.swap_remove(&output_rep);
        let relation = relation.union_equal_relations(existing, self);

        // The unions above may have changed `output`'s representative; resolve it again.
        let output_rep = self.find(output);
        self.forward.insert(relation.clone(), output_rep);
        self.reverse.insert(output_rep, relation);
    }

    /// Returns the `(type, fields)` of the struct/array construction recorded in `reverse` for
    /// `rep`, if any. `rep` is looked up as-is, so callers should pass a representative.
    fn get_struct_construct_immut(
        &self,
        rep: VariableId,
    ) -> Option<(TypeId<'db>, Vec<VariableId>)> {
        match self.reverse.get(&rep)? {
            Relation::StructConstruct(ty, fields) => Some((*ty, fields.clone())),
            _ => None,
        }
    }

    /// Like `get_struct_construct_immut`, but resolves `var` to its representative first.
    /// The returned field variables are as stored and may no longer be representatives.
    fn get_struct_construct(&mut self, var: VariableId) -> Option<(TypeId<'db>, Vec<VariableId>)> {
        let rep = self.find(var);
        self.get_struct_construct_immut(rep)
    }

    /// Returns the `(variant, payload)` of the enum construction recorded in `reverse` for
    /// `rep`, if any. `rep` is looked up as-is, so callers should pass a representative.
    fn get_enum_construct_immut(
        &self,
        rep: VariableId,
    ) -> Option<(ConcreteVariant<'db>, VariableId)> {
        match self.reverse.get(&rep)? {
            Relation::EnumConstruct(variant, input) => Some((*variant, *input)),
            _ => None,
        }
    }
}
253
254impl<'db> DebugWithDb<'db> for EqualityState<'db> {
255 type Db = dyn Database;
256
257 fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &'db Self::Db) -> std::fmt::Result {
258 let v = |id: VariableId| format!("v{}", self.find_immut(id).index());
259 let mut lines = Vec::<String>::new();
260 for (relation, &output) in self.forward.iter() {
261 match relation {
262 Relation::Snapshot(source) => {
263 lines.push(format!("@{} = {}", v(*source), v(output)));
264 }
265 Relation::Box(source) => {
266 lines.push(format!("Box({}) = {}", v(*source), v(output)));
267 }
268 Relation::EnumConstruct(variant, input) => {
269 let name = variant.id.name(db).to_string(db);
270 lines.push(format!("{name}({}) = {}", v(*input), v(output)));
271 }
272 Relation::StructConstruct(ty, inputs) => {
273 let type_name = ty.format(db);
274 let fields = inputs.iter().map(|&id| v(id)).collect::<Vec<_>>().join(", ");
275 lines.push(format!("{type_name}({fields}) = {}", v(output)));
276 }
277 }
278 }
279 for &var in self.union_find.keys() {
280 let rep = self.find_immut(var);
281 if var != rep {
282 lines.push(format!("v{} = v{}", rep.index(), var.index()));
283 }
284 }
285 lines.sort();
286 if lines.is_empty() { write!(f, "(empty)") } else { write!(f, "{}", lines.join(", ")) }
287 }
288}
289
/// Forward dataflow analysis that tracks which variables of a lowered function are known to be
/// equal, together with box/snapshot/enum/struct/array construction relations between them.
pub struct EqualityAnalysis<'a, 'db> {
    db: &'db dyn Database,
    /// The function being analyzed; used to look up variable types.
    lowered: &'a Lowered<'db>,
    /// `core::array::array_new`.
    array_new: ExternFunctionId<'db>,
    /// `core::array::array_append`.
    array_append: ExternFunctionId<'db>,
    /// `core::array::array_pop_front`.
    array_pop_front: ExternFunctionId<'db>,
    /// `core::array::array_pop_front_consume`.
    array_pop_front_consume: ExternFunctionId<'db>,
    /// `core::array::array_snapshot_pop_front`.
    array_snapshot_pop_front: ExternFunctionId<'db>,
    /// `core::array::array_snapshot_pop_back`.
    array_snapshot_pop_back: ExternFunctionId<'db>,
}
311
impl<'a, 'db> EqualityAnalysis<'a, 'db> {
    /// Creates the analyzer for `lowered`, resolving the `core::array` extern functions whose
    /// effects it models.
    pub fn new(db: &'db dyn Database, lowered: &'a Lowered<'db>) -> Self {
        let array_module = ModuleHelper::core(db).submodule("array");
        Self {
            db,
            lowered,
            array_new: array_module.extern_function_id("array_new"),
            array_append: array_module.extern_function_id("array_append"),
            array_pop_front: array_module.extern_function_id("array_pop_front"),
            array_pop_front_consume: array_module.extern_function_id("array_pop_front_consume"),
            array_snapshot_pop_front: array_module.extern_function_id("array_snapshot_pop_front"),
            array_snapshot_pop_back: array_module.extern_function_id("array_snapshot_pop_back"),
        }
    }

    /// Runs the equality analysis over `lowered` and returns whatever
    /// `ForwardDataflowAnalysis::run` produces (presumably one optional state per block —
    /// TODO confirm against `ForwardDataflowAnalysis`).
    pub fn analyze(
        db: &'db dyn Database,
        lowered: &'a Lowered<'db>,
    ) -> Vec<Option<EqualityState<'db>>> {
        ForwardDataflowAnalysis::new(lowered, EqualityAnalysis::new(db, lowered)).run()
    }

    /// Applies the effect of entering `arm` of a match on an extern call.
    ///
    /// Only the array pop externs are modeled; other externs, and arms not selected by an enum
    /// variant, leave `info` untouched.
    ///
    /// # Panics
    /// If a pop extern's match is not on an `Option` enum with a single generic type argument.
    fn transfer_extern_match_arm(
        &self,
        info: &mut EqualityState<'db>,
        extern_info: &MatchExternInfo<'db>,
        arm: &MatchArm<'db>,
    ) {
        let Some((id, _)) = extern_info.function.get_extern(self.db) else { return };
        let MatchArmSelector::VariantId(variant) = arm.arm_selector else { return };
        if id == self.array_pop_front
            || id == self.array_pop_front_consume
            || id == self.array_snapshot_pop_front
            || id == self.array_snapshot_pop_back
        {
            // NOTE: despite its name, `option_ty` is the enum's single generic type argument
            // (the `Option` payload type), not the `Option` type itself.
            let [GenericArgumentId::Type(option_ty)] =
                variant.concrete_enum_id.long(self.db).generic_args[..]
            else {
                panic!("Expected Option<T> with a single type argument");
            };
            let some_variant = option_some_variant(self.db, option_ty);
            assert_eq!(
                variant.concrete_enum_id.enum_id(self.db),
                some_variant.concrete_enum_id.enum_id(self.db),
                "Expected match to be on an Option<T>"
            );
            self.transfer_array_pop_arm(info, extern_info, arm, id, variant == some_variant);
        }
    }

    /// Applies the effect of the `Some` (`is_some`) or `None` arm of an array pop extern `id`.
    ///
    /// - Owned pops (`array_pop_front[_consume]`), `Some` arm: if the input array's contents are
    ///   known, the boxed element is `Box(first)` and the remaining array holds the rest.
    /// - Owned pops, `None` arm: the input array was empty; if the arm binds exactly one
    ///   variable (the array after the pop), it is unioned with the input.
    /// - Snapshot pops, `Some` arm: the elements come from the array the input is a snapshot
    ///   of, falling back to a construction recorded on the snapshot itself (e.g. the
    ///   remainder of an earlier snapshot pop). The boxed element is linked only when a
    ///   snapshot of the popped element is already known.
    /// - Snapshot pops, `None` arm: like the owned `None` arm, over the snapshot array.
    fn transfer_array_pop_arm(
        &self,
        info: &mut EqualityState<'db>,
        extern_info: &MatchExternInfo<'db>,
        arm: &MatchArm<'db>,
        id: ExternFunctionId<'db>,
        is_some: bool,
    ) {
        if id == self.array_pop_front || id == self.array_pop_front_consume {
            if is_some {
                let input_arr = extern_info.inputs[0].var_id;
                let remaining_arr = arm.var_ids[0];
                let boxed_elem = arm.var_ids[1];
                if let Some((ty, elems)) = info.get_struct_construct(input_arr)
                    && let Some((&first, rest)) = elems.split_first()
                {
                    info.set_relation(Relation::Box(first), boxed_elem);
                    let rest_reps: Vec<_> = rest.iter().map(|&v| info.find(v)).collect();
                    info.set_relation(Relation::StructConstruct(ty, rest_reps), remaining_arr);
                }
            } else {
                let old_array_var = extern_info.inputs[0].var_id;
                // Reaching the `None` arm means the input array was empty.
                let ty = self.lowered.variables[old_array_var].ty;
                info.set_relation(Relation::StructConstruct(ty, vec![]), old_array_var);
                // When the arm rebinds the array, popping an empty array left it unchanged.
                if let [original_arr] = arm.var_ids[..] {
                    info.union(original_arr, old_array_var);
                }
            }
        } else if id == self.array_snapshot_pop_front || id == self.array_snapshot_pop_back {
            if is_some {
                let input_snap_arr = extern_info.inputs[0].var_id;
                let remaining_snap_arr = arm.var_ids[0];
                let boxed_snap_elem = arm.var_ids[1];

                let snap_rep = info.find(input_snap_arr);
                // If the input is known to be `@orig`, prefer `orig`'s recorded contents.
                let original_rep = match info.reverse.get(&snap_rep) {
                    Some(Relation::Snapshot(v)) => Some(*v),
                    _ => None,
                };
                let elems_opt = original_rep
                    .and_then(|orig| {
                        let orig = info.find_immut(orig);
                        info.get_struct_construct_immut(orig)
                    })
                    .or_else(|| info.get_struct_construct_immut(snap_rep));

                if let Some((_orig_ty, elems)) = elems_opt {
                    let pop_front = id == self.array_snapshot_pop_front;
                    let (elem, rest) = if pop_front {
                        let Some((&first, tail)) = elems.split_first() else { return };
                        (first, tail)
                    } else {
                        let Some((&last, init)) = elems.split_last() else { return };
                        (last, init)
                    };

                    // The boxed value is `Box(@elem)`; it can only be linked if a variable
                    // holding `@elem` is already known.
                    let elem_rep = info.find(elem);
                    if let Some(&snap_of_elem) = info.forward.get(&Relation::Snapshot(elem_rep)) {
                        info.set_relation(Relation::Box(snap_of_elem), boxed_snap_elem);
                    }

                    // The remainder is keyed by the snapshot-array type, so in `forward` it
                    // cannot collide with a construction of the underlying array type that
                    // holds the same element variables.
                    let snap_ty = self.lowered.variables[remaining_snap_arr].ty;
                    let rest_reps: Vec<_> = rest.iter().map(|&v| info.find(v)).collect();
                    info.set_relation(
                        Relation::StructConstruct(snap_ty, rest_reps),
                        remaining_snap_arr,
                    );
                }
            } else {
                let old_snap_arr = extern_info.inputs[0].var_id;
                // Reaching the `None` arm means the snapshot array was empty.
                let snap_ty = self.lowered.variables[old_snap_arr].ty;
                info.set_relation(Relation::StructConstruct(snap_ty, vec![]), old_snap_arr);
                if let [original_snap_arr] = arm.var_ids[..] {
                    info.union(original_snap_arr, old_snap_arr);
                }
            }
        }
    }
}
477
478fn merge_referenced_vars<'db, 'a>(
481 info1: &'a EqualityState<'db>,
482 info2: &'a EqualityState<'db>,
483) -> impl Iterator<Item = VariableId> + 'a {
484 let union_find_vars = info1.union_find.keys().chain(info2.union_find.keys()).copied();
485
486 let forward_vars =
487 info1.forward.iter().chain(info2.forward.iter()).flat_map(|(relation, &output)| {
488 relation.referenced_vars().chain(std::iter::once(output))
489 });
490
491 let reverse_vars = info1
492 .reverse
493 .iter()
494 .chain(info2.reverse.iter())
495 .flat_map(|(&rep, relation)| std::iter::once(rep).chain(relation.referenced_vars()));
496
497 union_find_vars.chain(forward_vars).chain(reverse_vars)
498}
499
500fn find_intersection_rep(
503 intersections: &OrderedHashMap<VariableId, Vec<(VariableId, VariableId)>>,
504 rep1: VariableId,
505 rep2: VariableId,
506) -> Option<VariableId> {
507 intersections.get(&rep1)?.iter().find_map(|(intersection_r2, intersection_rep)| {
508 (*intersection_r2 == rep2).then_some(*intersection_rep)
509 })
510}
511
512fn merge_relations<'db>(
515 info1: &EqualityState<'db>,
516 info2: &EqualityState<'db>,
517 intersections: &OrderedHashMap<VariableId, Vec<(VariableId, VariableId)>>,
518 result: &mut EqualityState<'db>,
519) {
520 for (relation, &output1) in info1.forward.iter() {
524 match relation {
525 Relation::Box(source1) | Relation::Snapshot(source1) => {
526 for &(source2, intersection_var) in intersections.get(source1).unwrap_or(&vec![]) {
527 let relation2 = match relation {
528 Relation::Box(_) => Relation::Box(source2),
529 Relation::Snapshot(_) => Relation::Snapshot(source2),
530 _ => unreachable!(),
531 };
532 let Some(&output2) = info2.forward.get(&relation2) else { continue };
533 if let Some(output_intersection) = find_intersection_rep(
534 intersections,
535 info1.find_immut(output1),
536 info2.find_immut(output2),
537 ) {
538 let result_relation = match relation {
539 Relation::Box(_) => Relation::Box(result.find(intersection_var)),
540 Relation::Snapshot(_) => {
541 Relation::Snapshot(result.find(intersection_var))
542 }
543 _ => unreachable!(),
544 };
545 result.set_relation(result_relation, output_intersection);
546 }
547 }
548 }
549 Relation::EnumConstruct(variant, input1) => {
550 for &(input2, input_intersection) in
551 intersections.get(&info1.find_immut(*input1)).unwrap_or(&vec![])
552 {
553 let relation2 = Relation::EnumConstruct(*variant, input2);
554 let Some(&output2) = info2.forward.get(&relation2) else { continue };
555 if let Some(output_intersection) = find_intersection_rep(
556 intersections,
557 info1.find_immut(output1),
558 info2.find_immut(output2),
559 ) {
560 result.set_relation(
561 Relation::EnumConstruct(*variant, input_intersection),
562 output_intersection,
563 );
564 }
565 }
566 }
567 Relation::StructConstruct(ty, fields1) => {
568 let fields2: Vec<_> = fields1.iter().map(|&v| info2.find_immut(v)).collect();
569 let Some(&output2) =
570 info2.forward.get(&Relation::StructConstruct(*ty, fields2.clone()))
571 else {
572 continue;
573 };
574 let result_fields: Option<Vec<_>> = fields1
575 .iter()
576 .zip(&fields2)
577 .map(|(&v1, &v2)| {
578 find_intersection_rep(intersections, info1.find_immut(v1), v2)
579 })
580 .collect();
581 let Some(result_fields) = result_fields else { continue };
582 if let Some(output_intersection) = find_intersection_rep(
583 intersections,
584 info1.find_immut(output1),
585 info2.find_immut(output2),
586 ) {
587 result.set_relation(
588 Relation::StructConstruct(*ty, result_fields),
589 output_intersection,
590 );
591 }
592 }
593 }
594 }
595}
596
597impl<'db, 'a> DataflowAnalyzer<'db, 'a> for EqualityAnalysis<'a, 'db> {
598 type Info = EqualityState<'db>;
599
600 const DIRECTION: Direction = Direction::Forward;
601
602 fn initial_info(&mut self, _block_id: BlockId, _block_end: &'a BlockEnd<'db>) -> Self::Info {
603 EqualityState::default()
604 }
605
606 fn merge(
607 &mut self,
608 _lowered: &Lowered<'db>,
609 _statement_location: super::StatementLocation,
610 info1: Self::Info,
611 info2: Self::Info,
612 ) -> Self::Info {
613 let mut result = EqualityState::default();
615
616 let mut groups: OrderedHashMap<(VariableId, VariableId), Vec<VariableId>> =
618 OrderedHashMap::default();
619
620 for var in merge_referenced_vars(&info1, &info2) {
622 let key = (info1.find_immut(var), info2.find_immut(var));
623 groups.entry(key).or_default().push(var);
624 }
625
626 for members in groups.values() {
628 if members.len() > 1 {
629 let first = members[0];
630 for &var in &members[1..] {
631 result.union(first, var);
632 }
633 }
634 }
635
636 let mut intersections: OrderedHashMap<VariableId, Vec<(VariableId, VariableId)>> =
644 OrderedHashMap::default();
645 for (&(rep1, rep2), vars) in groups.iter() {
646 intersections.entry(rep1).or_default().push((rep2, result.find(vars[0])));
647 }
648
649 merge_relations(&info1, &info2, &intersections, &mut result);
650
651 result
652 }
653
654 fn transfer_stmt(
655 &mut self,
656 info: &mut Self::Info,
657 _statement_location: super::StatementLocation,
658 stmt: &'a Statement<'db>,
659 ) {
660 match stmt {
661 Statement::Snapshot(snapshot_stmt) => {
662 info.union(snapshot_stmt.original(), snapshot_stmt.input.var_id);
663 info.set_relation(
664 Relation::Snapshot(snapshot_stmt.input.var_id),
665 snapshot_stmt.snapshot(),
666 );
667 }
668
669 Statement::Desnap(desnap_stmt) => {
670 info.set_relation(Relation::Snapshot(desnap_stmt.output), desnap_stmt.input.var_id);
671 }
672
673 Statement::IntoBox(into_box_stmt) => {
674 info.set_relation(Relation::Box(into_box_stmt.input.var_id), into_box_stmt.output);
675 }
676
677 Statement::Unbox(unbox_stmt) => {
678 info.set_relation(Relation::Box(unbox_stmt.output), unbox_stmt.input.var_id);
679 }
680
681 Statement::EnumConstruct(enum_stmt) => {
682 info.set_relation(
686 Relation::EnumConstruct(enum_stmt.variant, enum_stmt.input.var_id),
687 enum_stmt.output,
688 );
689 }
690
691 Statement::StructConstruct(struct_stmt) => {
692 let ty = self.lowered.variables[struct_stmt.output].ty;
696 let input_reps = struct_stmt.inputs.iter().map(|i| info.find(i.var_id)).collect();
697 info.set_relation(Relation::StructConstruct(ty, input_reps), struct_stmt.output);
698 }
699
700 Statement::StructDestructure(struct_stmt) => {
701 if let Some((_, field_reps)) = info.get_struct_construct(struct_stmt.input.var_id) {
704 for (&output, &field_rep) in struct_stmt.outputs.iter().zip(field_reps.iter()) {
705 info.union(output, field_rep);
706 }
707 }
708 let ty = self.lowered.variables[struct_stmt.input.var_id].ty;
710 let output_reps = struct_stmt.outputs.iter().map(|&v| info.find(v)).collect();
711 info.set_relation(
712 Relation::StructConstruct(ty, output_reps),
713 struct_stmt.input.var_id,
714 );
715 }
716
717 Statement::Call(call_stmt) => {
718 let Some((id, _)) = call_stmt.function.get_extern(self.db) else { return };
719 if id == self.array_new {
720 let ty = self.lowered.variables[call_stmt.outputs[0]].ty;
721 info.set_relation(Relation::StructConstruct(ty, vec![]), call_stmt.outputs[0]);
722 } else if id == self.array_append
723 && let Some((ty, elems)) = info.get_struct_construct(call_stmt.inputs[0].var_id)
724 {
725 let mut new_elems = elems;
728 new_elems.push(info.find(call_stmt.inputs[1].var_id));
729 info.set_relation(
730 Relation::StructConstruct(ty, new_elems),
731 call_stmt.outputs[0],
732 );
733 }
734 }
735
736 Statement::Const(_) => {}
737 }
738 }
739
740 fn transfer_edge(&mut self, info: &Self::Info, edge: &Edge<'db, 'a>) -> Self::Info {
741 let mut new_info = info.clone();
742 match edge {
743 Edge::Goto { remapping, .. } => {
744 for (dst, src_usage) in remapping.iter() {
746 new_info.union(*dst, src_usage.var_id);
747 }
748 }
749 Edge::MatchArm { arm, match_info } => {
750 if let MatchInfo::Enum(enum_info) = match_info
752 && let MatchArmSelector::VariantId(variant) = arm.arm_selector
753 && let [arm_var] = arm.var_ids[..]
754 {
755 let matched_var = enum_info.input.var_id;
756
757 let output_rep = new_info.find(matched_var);
761 if let Some((old_variant, input)) =
762 new_info.get_enum_construct_immut(output_rep)
763 && variant == old_variant
764 {
765 new_info.union(arm_var, input);
766 }
767
768 new_info.set_relation(Relation::EnumConstruct(variant, arm_var), matched_var);
770 }
771
772 if let MatchInfo::Extern(extern_info) = match_info {
774 self.transfer_extern_match_arm(&mut new_info, extern_info, arm);
775 }
776 }
777 Edge::Return { .. } | Edge::Panic { .. } => {}
778 }
779 new_info
780 }
781}