Skip to main content

cairo_lang_lowering/lower/
refs.rs

1use cairo_lang_defs::ids::{LanguageElementId, MemberId};
2use cairo_lang_proc_macros::DebugWithDb;
3use cairo_lang_semantic::expr::fmt::ExprFormatter;
4use cairo_lang_semantic::expr::inference::InferenceError;
5use cairo_lang_semantic::items::structure::StructSemantic;
6use cairo_lang_semantic::usage::MemberPath;
7use cairo_lang_semantic::{self as semantic};
8use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
9use cairo_lang_utils::{extract_matches, try_extract_matches};
10use itertools::{Itertools, chain};
11
12use super::block_builder::BlockStructRecomposer;
13use super::context::VarRequest;
14use crate::VariableId;
15use crate::ids::LocationId;
16
17/// Information about members captured by the closure and their types.
18#[derive(Clone, Debug)]
19pub struct ClosureInfo<'db> {
20    // TODO(TomerStarkware): unite copiable members and snapshots into a single map.
21    /// The members captured by the closure (not as snapshot).
22    pub members: OrderedHashMap<MemberPath<'db>, semantic::TypeId<'db>>,
23    /// The types of the captured snapshot variables.
24    pub snapshots: OrderedHashMap<MemberPath<'db>, semantic::TypeId<'db>>,
25}
26
27/// The result of `SemanticLoweringMapping::assemble_value`.
28pub enum AssembleValueError<'db> {
29    /// The variable was moved before.
30    Moved(MovedVar<'db>),
31    /// The variable is missing from `SemanticLoweringMapping::scattered`.
32    Missing,
33}
34
35#[derive(Clone, Default, Debug)]
36pub struct SemanticLoweringMapping<'db> {
37    /// Maps member paths ([MemberPath]) to lowered variable ids or scattered variable ids.
38    scattered: OrderedHashMap<MemberPath<'db>, Value<'db>>,
39}
40impl<'db> SemanticLoweringMapping<'db> {
41    /// Returns the topmost mapped member path containing the given member path, or None no such
42    /// member path exists in the mapping.
43    pub fn topmost_mapped_containing_member_path(
44        &self,
45        mut member_path: MemberPath<'db>,
46    ) -> Option<MemberPath<'db>> {
47        let mut res = None;
48        loop {
49            if self.scattered.contains_key(&member_path) {
50                res = Some(member_path.clone());
51            }
52            let MemberPath::Member { parent, .. } = member_path else {
53                return res;
54            };
55            member_path = *parent;
56        }
57    }
58
59    pub fn destructure_closure(
60        &mut self,
61        ctx: &mut BlockStructRecomposer<'_, '_, 'db>,
62        closure_var: VariableId,
63        closure_info: &ClosureInfo<'db>,
64    ) -> Vec<VariableId> {
65        ctx.deconstruct_by_types(
66            closure_var,
67            chain!(closure_info.members.values(), closure_info.snapshots.values()).cloned(),
68        )
69    }
70
71    pub fn get(
72        &mut self,
73        mut ctx: BlockStructRecomposer<'_, '_, 'db>,
74        path: &MemberPath<'db>,
75    ) -> Result<VariableId, AssembleValueError<'db>> {
76        let value = self.break_into_value(&mut ctx, path).ok_or(AssembleValueError::Missing)?;
77        let base_var = path.base_var();
78        let location_stable_ptr = match ctx.ctx.semantic_defs.get(&base_var) {
79            Some(binding) => binding.stable_ptr(ctx.ctx.db),
80            None => base_var.untyped_stable_ptr(ctx.ctx.db),
81        };
82        let location = ctx.ctx.get_location(location_stable_ptr);
83        Self::assemble_value(&mut ctx, value, location).map_err(AssembleValueError::Moved)
84    }
85
86    pub fn introduce(&mut self, path: MemberPath<'db>, var: VariableId) {
87        self.scattered.insert(path, Value::Var(var));
88    }
89
90    pub fn update(
91        &mut self,
92        ctx: &mut BlockStructRecomposer<'_, '_, 'db>,
93        path: &MemberPath<'db>,
94        var: VariableId,
95    ) -> Option<()> {
96        // TODO(TomerStarkware): check if path is captured by a closure and invalidate the closure.
97        // Right now this can only happen if we take a snapshot of the variable (as the
98        // snapshot function returns a new var).
99        // we need to ensure the borrow checker invalidates the closure when mutable capture
100        // is supported.
101
102        let value = self.break_into_value(ctx, path)?;
103        *value = Value::Var(var);
104        Some(())
105    }
106
107    /// Marks the variable at the given path as moved.
108    ///
109    /// This function should be called for non-copyable variables.
110    pub fn mark_as_used(
111        &mut self,
112        mut ctx: BlockStructRecomposer<'_, '_, 'db>,
113        path: &MemberPath<'db>,
114        moved: MovedVar<'db>,
115    ) {
116        *self.break_into_value(&mut ctx, path).unwrap() = Value::MovedVar(moved);
117    }
118
119    /// Assembles a [VariableId] from the given [Value] by recursively reconstructing it if it is
120    /// currently deconstructed.
121    ///
122    /// Returns a [MovedVar] if the variable, or any of its members, were moved before.
123    fn assemble_value(
124        ctx: &mut BlockStructRecomposer<'_, '_, 'db>,
125        value: &mut Value<'db>,
126        location: LocationId<'db>,
127    ) -> Result<VariableId, MovedVar<'db>> {
128        match value {
129            Value::Var(var) => Ok(*var),
130            Value::MovedVar(moved) => Err(moved.clone()),
131            Value::Scattered(scattered) => {
132                let mut moved_var = None;
133                let members = scattered
134                    .members
135                    .iter_mut()
136                    .map(|(_, value)| match Self::assemble_value(ctx, value, location) {
137                        Ok(var) => var,
138                        Err(moved) => {
139                            let var = moved.var_id;
140                            moved_var.get_or_insert(moved);
141                            var
142                        }
143                    })
144                    .collect_vec();
145                let var = ctx.reconstruct(scattered.concrete_struct_id, members, location);
146                *value = Value::Var(var);
147                if let Some(MovedVar { var_id: _, inference_error, last_use_location }) = moved_var
148                {
149                    Err(MovedVar { var_id: var, inference_error, last_use_location })
150                } else {
151                    Ok(var)
152                }
153            }
154        }
155    }
156
157    fn break_into_value(
158        &mut self,
159        ctx: &mut BlockStructRecomposer<'_, '_, 'db>,
160        path: &MemberPath<'db>,
161    ) -> Option<&mut Value<'db>> {
162        if self.scattered.contains_key(path) {
163            return self.scattered.get_mut(path);
164        }
165
166        let &MemberPath::Member { ref parent, member_id, concrete_struct_id, .. } = path else {
167            return None;
168        };
169
170        let parent_value = self.break_into_value(ctx, parent)?;
171        match parent_value {
172            Value::Var(var) => {
173                let members = ctx.deconstruct(concrete_struct_id, *var);
174                let members = OrderedHashMap::from_iter(
175                    members.into_iter().map(|(member_id, var)| (member_id, Value::Var(var))),
176                );
177                let scattered = Scattered { concrete_struct_id, members };
178                *parent_value = Value::Scattered(Box::new(scattered));
179            }
180            &mut Value::MovedVar(MovedVar { var_id, ref inference_error, last_use_location }) => {
181                let member_map = ctx.ctx.db.concrete_struct_members(concrete_struct_id).unwrap();
182                let location = ctx.ctx.variables[var_id].location;
183                let members = OrderedHashMap::from_iter(member_map.values().map(|member| {
184                    (
185                        member.id,
186                        Value::MovedVar(MovedVar {
187                            var_id: ctx.ctx.new_var(VarRequest { ty: member.ty, location }),
188                            inference_error: inference_error.clone(),
189                            last_use_location,
190                        }),
191                    )
192                }));
193                let scattered = Scattered { concrete_struct_id, members };
194                *parent_value = Value::Scattered(Box::new(scattered));
195            }
196            Value::Scattered(..) => {}
197        };
198        extract_matches!(parent_value, Value::Scattered).members.get_mut(&member_id)
199    }
200}
201
202impl<'db> cairo_lang_debug::debug::DebugWithDb<'db> for SemanticLoweringMapping<'db> {
203    type Db = ExprFormatter<'db>;
204
205    fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &ExprFormatter<'db>) -> std::fmt::Result {
206        for (member_path, value) in self.scattered.iter() {
207            writeln!(f, "{:?}: {value}", member_path.debug(db))?;
208        }
209        Ok(())
210    }
211}
212
213/// Merges [SemanticLoweringMapping] from multiple blocks to a single [SemanticLoweringMapping].
214///
215/// The mapping from semantic variables to lowered variables in the new block follows these rules:
216///
217/// * Variables mapped to the same lowered variable across all input blocks are kept as-is.
218/// * Local variables that appear in only a subset of the blocks are removed.
219/// * Variables with different mappings across blocks are remapped to a new lowered variable, by
220///   invoking the `remapped_callback` function.
221pub fn merge_semantics<'db, 'a>(
222    mappings: impl Iterator<Item = &'a SemanticLoweringMapping<'db>>,
223    remapped_callback: &mut impl FnMut(&MemberPath<'db>) -> VariableId,
224) -> SemanticLoweringMapping<'db>
225where
226    'db: 'a,
227{
228    // A map from [MemberPath] to its [Value] in the `mappings` where it appears.
229    // If the number of [Value]s is not the length of `mappings`, it is later dropped.
230    let mut path_to_values: OrderedHashMap<MemberPath<'_>, Vec<Value<'_>>> = Default::default();
231
232    let mut n_mappings = 0;
233    for map in mappings {
234        for (path, var) in map.scattered.iter() {
235            path_to_values.entry(path.clone()).or_default().push(var.clone());
236        }
237        n_mappings += 1;
238    }
239
240    let mut scattered: OrderedHashMap<MemberPath<'_>, Value<'_>> = Default::default();
241    for (path, values) in path_to_values {
242        // The variable is missing in one or more of the maps.
243        // It cannot be used in the merged block.
244        if values.len() != n_mappings {
245            continue;
246        }
247
248        let merged_value = compute_remapped_variables(
249            &values.iter().collect_vec(),
250            false,
251            &path,
252            remapped_callback,
253        );
254        scattered.insert(path, merged_value);
255    }
256
257    SemanticLoweringMapping { scattered }
258}
259
260/// Given a list of [Value]s that correspond to the same semantic [MemberPath] in different blocks,
261/// compute the [Value] in the merge block.
262///
263/// If all values are the same, no remapping is needed.
264/// If some of the values are [Value::Var] and some are [Value::Scattered], then all the values
265/// inside the [Value::Scattered] values need to be remapped.
266/// If all of them are [Value::Scattered], then it is possible that some of the members require
267/// remapping and some don't.
268///
269/// Pass `require_remapping=true` to indicate that during the recursion we encountered a
270/// [Value::Var], and thus we need to remap all the [Value::Scattered] values.
271/// In particular, once we have `require_remapping=true`, all the recursive calls in the subtree
272/// will have `require_remapping=true`.
273///
274/// For example, suppose `values` consists of two trees:
275/// * `A = Scattered(Scattered(v0, v1), v2)` and
276/// * `B = Scattered(Scattered(v0, v3), v4)`.
277///
278/// Then, the result will be:
279/// * `Scattered(Scattered(v0, new_var), new_var)`,
280///
281/// since `v0` is the same in both trees, but the other nodes are not.
282///
283/// If in addition to `A` and `B`, we have another tree
284/// * `C = Scattered(v5, v6)`,
285///
286/// then `v5` will need to be deconstructed, so `C` can be thought of as
287/// * `C = Scattered(Scattered(?, ?), v6)`.
288///
289/// Now, the node of `v0` also requires remapping, so the result will be:
290/// * `Scattered(Scattered(new_var, new_var), new_var)`.
291///
292/// In the recursion, when we encounter `v5`, we change `require_remapping` to `true` and drop `C`
293/// from the list of values (keeping only the scattered values).
294/// This signals that inside this subtree, all values need to be remapped (because of the children
295/// of `v5`, which are marked by `?` above).
296fn compute_remapped_variables<'db>(
297    values: &[&Value<'db>],
298    require_remapping: bool,
299    parent_path: &MemberPath<'db>,
300    remapped_callback: &mut impl FnMut(&MemberPath<'db>) -> VariableId,
301) -> Value<'db> {
302    if let Some(x) = values.iter().find(|value| matches!(value, Value::MovedVar { .. })) {
303        // If any of the values being merged is a [MovedVar], the result will be a [MovedVar].
304        // Return an arbitrary one of them.
305        return (*x).clone();
306    }
307
308    if !require_remapping {
309        // If all values are the same, no remapping is needed.
310        let first_var = values[0];
311        if values.iter().all(|x| *x == first_var) {
312            return first_var.clone();
313        }
314    }
315
316    // Collect all the `Value::Scattered` values.
317    let only_scattered: Vec<&Box<Scattered<'_>>> =
318        values.iter().filter_map(|value| try_extract_matches!(value, Value::Scattered)).collect();
319
320    if only_scattered.is_empty() {
321        let remapped_var = remapped_callback(parent_path);
322        return Value::Var(remapped_var);
323    }
324
325    // If we encountered a [Value::Var], we need to remap all the [Value::Scattered] values.
326    let require_remapping = require_remapping || only_scattered.len() < values.len();
327
328    let concrete_struct_id = only_scattered[0].concrete_struct_id;
329    let members = only_scattered[0]
330        .members
331        .keys()
332        .map(|member_id| {
333            let member_path = MemberPath::Member {
334                parent: parent_path.clone().into(),
335                member_id: *member_id,
336                concrete_struct_id,
337            };
338            // Call `compute_remapped_variables` recursively on the scattered values.
339            // If there is a [Value::Var], `require_remapping` will be set to `true` to account
340            // for it.
341            let member_values =
342                only_scattered.iter().map(|scattered| &scattered.members[member_id]).collect_vec();
343
344            (
345                *member_id,
346                compute_remapped_variables(
347                    &member_values,
348                    require_remapping,
349                    &member_path,
350                    remapped_callback,
351                ),
352            )
353        })
354        .collect();
355
356    Value::Scattered(Box::new(Scattered { concrete_struct_id, members }))
357}
358
359/// Returns an iterator to all the [MemberPath]s that appear in both mappings and have different
360/// values.
361pub fn find_changed_members<'db, 'a>(
362    semantics0: &'a SemanticLoweringMapping<'db>,
363    semantics1: &'a SemanticLoweringMapping<'db>,
364) -> impl Iterator<Item = MemberPath<'db>> + 'a {
365    semantics0.scattered.iter().filter_map(|(path, value0)| {
366        if let Some(value1) = semantics1.scattered.get(path)
367            && value0 != value1
368        {
369            return Some(path.clone());
370        }
371        None
372    })
373}
374
375/// Represents a non-copyable variable that was moved, and can no longer be used.
376#[derive(Clone, Debug, DebugWithDb, Eq, PartialEq)]
377#[debug_db(ExprFormatter<'db>)]
378pub struct MovedVar<'db> {
379    /// The type of the variable.
380    pub var_id: VariableId,
381    /// The reason it is not copyable.
382    pub inference_error: InferenceError<'db>,
383    /// The location of the last use of the moved variable. This is used to report an error.
384    pub last_use_location: LocationId<'db>,
385}
386
387/// An intermediate value for a member path.
388#[derive(Clone, Debug, DebugWithDb, Eq, PartialEq)]
389#[debug_db(ExprFormatter<'db>)]
390enum Value<'db> {
391    /// The value of member path is stored in a lowered variable.
392    Var(VariableId),
393    /// The value of the member path is not stored. If needed, it should be reconstructed from the
394    /// member values.
395    Scattered(Box<Scattered<'db>>),
396    /// Represents a non-copyable variable that was moved, and can no longer be used.
397    MovedVar(MovedVar<'db>),
398}
399
400impl<'db> std::fmt::Display for Value<'db> {
401    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
402        match self {
403            Value::Var(var) => write!(f, "v{}", var.index()),
404            Value::Scattered(scattered) => {
405                write!(
406                    f,
407                    "Scattered({})",
408                    scattered.members.values().map(|value| value.to_string()).join(", ")
409                )
410            }
411            Value::MovedVar(..) => write!(f, "MovedVar"),
412        }
413    }
414}
415
416/// A value for a non-stored member path. Recursively holds the [Value] for the members.
417#[derive(Clone, Debug, DebugWithDb, Eq, PartialEq)]
418#[debug_db(ExprFormatter<'db>)]
419struct Scattered<'db> {
420    concrete_struct_id: semantic::ConcreteStructId<'db>,
421    members: OrderedHashMap<MemberId<'db>, Value<'db>>,
422}