cairo_lang_lowering/lower/
refs.rs

1use cairo_lang_defs::ids::MemberId;
2use cairo_lang_proc_macros::DebugWithDb;
3use cairo_lang_semantic::expr::fmt::ExprFormatter;
4use cairo_lang_semantic::expr::inference::InferenceError;
5use cairo_lang_semantic::items::structure::StructSemantic;
6use cairo_lang_semantic::usage::MemberPath;
7use cairo_lang_semantic::{self as semantic, ConcreteTypeId, TypeLongId};
8use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
9use cairo_lang_utils::{Intern, extract_matches, try_extract_matches};
10use itertools::{Itertools, chain};
11
12use super::block_builder::BlockStructRecomposer;
13use crate::VariableId;
14use crate::ids::LocationId;
15
/// Information about members captured by the closure and their types.
///
/// `SemanticLoweringMapping::destructure_closure` deconstructs a closure variable by the types in
/// `members` followed by the types in `snapshots`, so the iteration order of both maps is
/// significant.
#[derive(Clone, Debug)]
pub struct ClosureInfo<'db> {
    // TODO(TomerStarkware): unite copiable members and snapshots into a single map.
    /// The members captured by the closure (not as snapshot), with their types.
    pub members: OrderedHashMap<MemberPath<'db>, semantic::TypeId<'db>>,
    /// The types of the captured snapshot variables, keyed by their member path.
    pub snapshots: OrderedHashMap<MemberPath<'db>, semantic::TypeId<'db>>,
}
25
/// The error returned by `SemanticLoweringMapping::get`.
pub enum AssembleValueError<'db> {
    /// The variable, or at least one of its members, was moved before.
    Moved(MovedVar<'db>),
    /// The requested member path could not be resolved in `SemanticLoweringMapping::scattered`
    /// (neither the path itself nor a deconstructible ancestor of it is mapped).
    Missing,
}
33
/// Tracks the current lowered representation of semantic member paths.
///
/// Each mapped [MemberPath] holds a [Value]: a single lowered variable, a scattered struct whose
/// members are tracked individually, or a marker that the value was moved.
#[derive(Clone, Default, Debug)]
pub struct SemanticLoweringMapping<'db> {
    /// Maps member paths ([MemberPath]) to lowered variable ids or scattered variable ids.
    scattered: OrderedHashMap<MemberPath<'db>, Value<'db>>,
}
39impl<'db> SemanticLoweringMapping<'db> {
40    /// Returns the topmost mapped member path containing the given member path, or None no such
41    /// member path exists in the mapping.
42    pub fn topmost_mapped_containing_member_path(
43        &self,
44        mut member_path: MemberPath<'db>,
45    ) -> Option<MemberPath<'db>> {
46        let mut res = None;
47        loop {
48            if self.scattered.contains_key(&member_path) {
49                res = Some(member_path.clone());
50            }
51            let MemberPath::Member { parent, .. } = member_path else {
52                return res;
53            };
54            member_path = *parent;
55        }
56    }
57
58    pub fn destructure_closure(
59        &mut self,
60        ctx: &mut BlockStructRecomposer<'_, '_, 'db>,
61        closure_var: VariableId,
62        closure_info: &ClosureInfo<'db>,
63    ) -> Vec<VariableId> {
64        ctx.deconstruct_by_types(
65            closure_var,
66            chain!(closure_info.members.values(), closure_info.snapshots.values()).cloned(),
67        )
68    }
69
70    pub fn get(
71        &mut self,
72        mut ctx: BlockStructRecomposer<'_, '_, 'db>,
73        path: &MemberPath<'db>,
74    ) -> Result<VariableId, AssembleValueError<'db>> {
75        let value = self.break_into_value(&mut ctx, path).ok_or(AssembleValueError::Missing)?;
76        Self::assemble_value(&mut ctx, value).map_err(AssembleValueError::Moved)
77    }
78
79    pub fn introduce(&mut self, path: MemberPath<'db>, var: VariableId) {
80        self.scattered.insert(path, Value::Var(var));
81    }
82
83    pub fn update(
84        &mut self,
85        ctx: &mut BlockStructRecomposer<'_, '_, 'db>,
86        path: &MemberPath<'db>,
87        var: VariableId,
88    ) -> Option<()> {
89        // TODO(TomerStarkware): check if path is captured by a closure and invalidate the closure.
90        // Right now this can only happen if we take a snapshot of the variable (as the
91        // snapshot function returns a new var).
92        // we need to ensure the borrow checker invalidates the closure when mutable capture
93        // is supported.
94
95        let value = self.break_into_value(ctx, path)?;
96        *value = Value::Var(var);
97        Some(())
98    }
99
100    /// Marks the variable at the given path as moved.
101    ///
102    /// This function should be called for non-copyable variables.
103    pub fn mark_as_used(
104        &mut self,
105        mut ctx: BlockStructRecomposer<'_, '_, 'db>,
106        path: &MemberPath<'db>,
107        moved: MovedVar<'db>,
108    ) {
109        *self.break_into_value(&mut ctx, path).unwrap() = Value::MovedVar(moved);
110    }
111
112    /// Assembles a [VariableId] from the given [Value] by recursively reconstructing it if it is
113    /// currently deconstructed.
114    ///
115    /// Returns a [MovedVar] if the variable, or any of its members, were moved before.
116    fn assemble_value(
117        ctx: &mut BlockStructRecomposer<'_, '_, 'db>,
118        value: &mut Value<'db>,
119    ) -> Result<VariableId, MovedVar<'db>> {
120        match value {
121            Value::Var(var) => Ok(*var),
122            Value::Scattered(scattered) => {
123                let members_res = scattered
124                    .members
125                    .iter_mut()
126                    .map(|(_, value)| Self::assemble_value(ctx, value))
127                    .collect::<Result<_, _>>();
128
129                match members_res {
130                    Ok(members) => {
131                        let var = ctx.reconstruct(scattered.concrete_struct_id, members);
132                        *value = Value::Var(var);
133                        Ok(var)
134                    }
135                    Err(MovedVar { ty: _, inference_error, last_use_location }) => {
136                        // Don't use the type of the moved member. Replace it with the type of the
137                        // variable.
138                        let y = TypeLongId::<'db>::Concrete(ConcreteTypeId::Struct(
139                            scattered.concrete_struct_id,
140                        ));
141                        let x = y.intern(ctx.ctx.db);
142                        Err(MovedVar { ty: x, inference_error, last_use_location })
143                    }
144                }
145            }
146            Value::MovedVar(moved) => Err(moved.clone()),
147        }
148    }
149
150    fn break_into_value(
151        &mut self,
152        ctx: &mut BlockStructRecomposer<'_, '_, 'db>,
153        path: &MemberPath<'db>,
154    ) -> Option<&mut Value<'db>> {
155        if self.scattered.contains_key(path) {
156            return self.scattered.get_mut(path);
157        }
158
159        let MemberPath::Member { parent, member_id, concrete_struct_id, .. } = path else {
160            return None;
161        };
162
163        let parent_value = self.break_into_value(ctx, parent)?;
164        match parent_value {
165            Value::Var(var) => {
166                let members = ctx.deconstruct(*concrete_struct_id, *var);
167                let members = OrderedHashMap::from_iter(
168                    members.into_iter().map(|(member_id, var)| (member_id, Value::Var(var))),
169                );
170                let scattered = Scattered { concrete_struct_id: *concrete_struct_id, members };
171                *parent_value = Value::Scattered(Box::new(scattered));
172            }
173            Value::MovedVar(MovedVar { ty: _, inference_error, last_use_location }) => {
174                let member_map = ctx.ctx.db.concrete_struct_members(*concrete_struct_id).unwrap();
175                let members = OrderedHashMap::from_iter(member_map.values().map(|member| {
176                    (
177                        member.id,
178                        Value::MovedVar(MovedVar {
179                            ty: member.ty,
180                            inference_error: inference_error.clone(),
181                            last_use_location: *last_use_location,
182                        }),
183                    )
184                }));
185                let scattered = Scattered { concrete_struct_id: *concrete_struct_id, members };
186                *parent_value = Value::Scattered(Box::new(scattered));
187            }
188            Value::Scattered(..) => {}
189        };
190        extract_matches!(parent_value, Value::Scattered).members.get_mut(member_id)
191    }
192}
193
194impl<'db> cairo_lang_debug::debug::DebugWithDb<'db> for SemanticLoweringMapping<'db> {
195    type Db = ExprFormatter<'db>;
196
197    fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &ExprFormatter<'db>) -> std::fmt::Result {
198        for (member_path, value) in self.scattered.iter() {
199            writeln!(f, "{:?}: {value}", member_path.debug(db))?;
200        }
201        Ok(())
202    }
203}
204
205/// Merges [SemanticLoweringMapping] from multiple blocks to a single [SemanticLoweringMapping].
206///
207/// The mapping from semantic variables to lowered variables in the new block follows these rules:
208///
209/// * Variables mapped to the same lowered variable across all input blocks are kept as-is.
210/// * Local variables that appear in only a subset of the blocks are removed.
211/// * Variables with different mappings across blocks are remapped to a new lowered variable, by
212///   invoking the `remapped_callback` function.
213pub fn merge_semantics<'db, 'a>(
214    mappings: impl Iterator<Item = &'a SemanticLoweringMapping<'db>>,
215    remapped_callback: &mut impl FnMut(&MemberPath<'db>) -> VariableId,
216) -> SemanticLoweringMapping<'db>
217where
218    'db: 'a,
219{
220    // A map from [MemberPath] to its [Value] in the `mappings` where it appears.
221    // If the number of [Value]s is not the length of `mappings`, it is later dropped.
222    let mut path_to_values: OrderedHashMap<MemberPath<'_>, Vec<Value<'_>>> = Default::default();
223
224    let mut n_mappings = 0;
225    for map in mappings {
226        for (path, var) in map.scattered.iter() {
227            path_to_values.entry(path.clone()).or_default().push(var.clone());
228        }
229        n_mappings += 1;
230    }
231
232    let mut scattered: OrderedHashMap<MemberPath<'_>, Value<'_>> = Default::default();
233    for (path, values) in path_to_values {
234        // The variable is missing in one or more of the maps.
235        // It cannot be used in the merged block.
236        if values.len() != n_mappings {
237            continue;
238        }
239
240        let merged_value = compute_remapped_variables(
241            &values.iter().collect_vec(),
242            false,
243            &path,
244            remapped_callback,
245        );
246        scattered.insert(path, merged_value);
247    }
248
249    SemanticLoweringMapping { scattered }
250}
251
252/// Given a list of [Value]s that correspond to the same semantic [MemberPath] in different blocks,
253/// compute the [Value] in the merge block.
254///
255/// If all values are the same, no remapping is needed.
256/// If some of the values are [Value::Var] and some are [Value::Scattered], then all the values
257/// inside the [Value::Scattered] values need to be remapped.
258/// If all of them are [Value::Scattered], then it is possible that some of the members require
259/// remapping and some don't.
260///
261/// Pass `require_remapping=true` to indicate that during the recursion we encountered a
262/// [Value::Var], and thus we need to remap all the [Value::Scattered] values.
263/// In particular, once we have `require_remapping=true`, all the recursive calls in the subtree
264/// will have `require_remapping=true`.
265///
266/// For example, suppose `values` consists of two trees:
267/// * `A = Scattered(Scattered(v0, v1), v2)` and
268/// * `B = Scattered(Scattered(v0, v3), v4)`.
269///
270/// Then, the result will be:
271/// * `Scattered(Scattered(v0, new_var), new_var)`,
272///
273/// since `v0` is the same in both trees, but the other nodes are not.
274///
275/// If in addition to `A` and `B`, we have another tree
276/// * `C = Scattered(v5, v6)`,
277///
278/// then `v5` will need to be deconstructed, so `C` can be thought of as
279/// * `C = Scattered(Scattered(?, ?), v6)`.
280///
281/// Now, the node of `v0` also requires remapping, so the result will be:
282/// * `Scattered(Scattered(new_var, new_var), new_var)`.
283///
284/// In the recursion, when we encounter `v5`, we change `require_remapping` to `true` and drop `C`
285/// from the list of values (keeping only the scattered values).
286/// This signals that inside this subtree, all values need to be remapped (because of the children
287/// of `v5`, which are marked by `?` above).
288fn compute_remapped_variables<'db>(
289    values: &[&Value<'db>],
290    require_remapping: bool,
291    parent_path: &MemberPath<'db>,
292    remapped_callback: &mut impl FnMut(&MemberPath<'db>) -> VariableId,
293) -> Value<'db> {
294    if let Some(x) = values.iter().find(|value| matches!(value, Value::MovedVar { .. })) {
295        // If any of the values being merged is a [MovedVar], the result will be a [MovedVar].
296        // Return an arbitrary one of them.
297        return (*x).clone();
298    }
299
300    if !require_remapping {
301        // If all values are the same, no remapping is needed.
302        let first_var = values[0];
303        if values.iter().all(|x| *x == first_var) {
304            return first_var.clone();
305        }
306    }
307
308    // Collect all the `Value::Scattered` values.
309    let only_scattered: Vec<&Box<Scattered<'_>>> =
310        values.iter().filter_map(|value| try_extract_matches!(value, Value::Scattered)).collect();
311
312    if only_scattered.is_empty() {
313        let remapped_var = remapped_callback(parent_path);
314        return Value::Var(remapped_var);
315    }
316
317    // If we encountered a [Value::Var], we need to remap all the [Value::Scattered] values.
318    let require_remapping = require_remapping || only_scattered.len() < values.len();
319
320    let concrete_struct_id = only_scattered[0].concrete_struct_id;
321    let members = only_scattered[0]
322        .members
323        .keys()
324        .map(|member_id| {
325            let member_path = MemberPath::Member {
326                parent: parent_path.clone().into(),
327                member_id: *member_id,
328                concrete_struct_id,
329            };
330            // Call `compute_remapped_variables` recursively on the scattered values.
331            // If there is a [Value::Var], `require_remapping` will be set to `true` to account
332            // for it.
333            let member_values =
334                only_scattered.iter().map(|scattered| &scattered.members[member_id]).collect_vec();
335
336            (
337                *member_id,
338                compute_remapped_variables(
339                    &member_values,
340                    require_remapping,
341                    &member_path,
342                    remapped_callback,
343                ),
344            )
345        })
346        .collect();
347
348    Value::Scattered(Box::new(Scattered { concrete_struct_id, members }))
349}
350
351/// Returns an iterator to all the [MemberPath]s that appear in both mappings and have different
352/// values.
353pub fn find_changed_members<'db, 'a>(
354    semantics0: &'a SemanticLoweringMapping<'db>,
355    semantics1: &'a SemanticLoweringMapping<'db>,
356) -> impl Iterator<Item = MemberPath<'db>> + 'a {
357    semantics0.scattered.iter().filter_map(|(path, value0)| {
358        if let Some(value1) = semantics1.scattered.get(path)
359            && value0 != value1
360        {
361            return Some(path.clone());
362        }
363        None
364    })
365}
366
/// Represents a non-copyable variable that was moved, and can no longer be used.
#[derive(Clone, Debug, DebugWithDb, Eq, PartialEq)]
#[debug_db(ExprFormatter<'db>)]
pub struct MovedVar<'db> {
    /// The type of the variable. When reported for a partially moved struct, this is the type of
    /// the whole struct rather than that of the moved member.
    pub ty: semantic::TypeId<'db>,
    /// The reason it is not copyable.
    pub inference_error: InferenceError<'db>,
    /// The location of the last use of the moved variable. This is used to report an error.
    pub last_use_location: LocationId<'db>,
}
378
/// An intermediate value for a member path.
#[derive(Clone, Debug, DebugWithDb, Eq, PartialEq)]
#[debug_db(ExprFormatter<'db>)]
enum Value<'db> {
    /// The value of the member path is stored in a lowered variable.
    Var(VariableId),
    /// The value of the member path is not stored in a single variable. If needed, it should be
    /// reconstructed from the member values.
    Scattered(Box<Scattered<'db>>),
    /// Represents a non-copyable variable that was moved, and can no longer be used.
    MovedVar(MovedVar<'db>),
}
391
392impl<'db> std::fmt::Display for Value<'db> {
393    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
394        match self {
395            Value::Var(var) => write!(f, "v{}", var.index()),
396            Value::Scattered(scattered) => {
397                write!(
398                    f,
399                    "Scattered({})",
400                    scattered.members.values().map(|value| value.to_string()).join(", ")
401                )
402            }
403            Value::MovedVar(..) => write!(f, "MovedVar"),
404        }
405    }
406}
407
/// A value for a non-stored member path. Recursively holds the [Value] for the members.
#[derive(Clone, Debug, DebugWithDb, Eq, PartialEq)]
#[debug_db(ExprFormatter<'db>)]
struct Scattered<'db> {
    /// The concrete struct whose members are tracked separately.
    concrete_struct_id: semantic::ConcreteStructId<'db>,
    /// The value of each member. Reconstruction passes the members in this map's iteration
    /// order.
    members: OrderedHashMap<MemberId<'db>, Value<'db>>,
}