Skip to main content

cairo_lang_lowering/lower/
refs.rs

1use cairo_lang_defs::ids::MemberId;
2use cairo_lang_proc_macros::DebugWithDb;
3use cairo_lang_semantic::expr::fmt::ExprFormatter;
4use cairo_lang_semantic::usage::MemberPath;
5use cairo_lang_semantic::{self as semantic};
6use cairo_lang_utils::extract_matches;
7use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
8use itertools::chain;
9
10use crate::VariableId;
11use crate::db::LoweringGroup;
12
/// Information about the values captured by a closure and their types.
///
/// `SemanticLoweringMapping::destructure_closure` deconstructs a closure variable by the types in
/// `members` followed by the types in `snapshots`, in map iteration order.
#[derive(Clone, Debug)]
pub struct ClosureInfo {
    // TODO(TomerStarkware): unite copiable members and snapshots into a single map.
    /// The members captured by the closure (not as snapshot), with their types.
    pub members: OrderedHashMap<MemberPath, semantic::TypeId>,
    /// The members captured by the closure as snapshots, with the types of the captured
    /// snapshot variables.
    pub snapshots: OrderedHashMap<MemberPath, semantic::TypeId>,
}
22
23#[derive(Clone, Default, Debug)]
24pub struct SemanticLoweringMapping {
25    /// Maps member paths ([MemberPath]) to lowered variable ids or scattered variable ids.
26    scattered: OrderedHashMap<MemberPath, Value>,
27}
28impl SemanticLoweringMapping {
29    /// Returns the topmost mapped member path containing the given member path, or None no such
30    /// member path exists in the mapping.
31    pub fn topmost_mapped_containing_member_path(
32        &mut self,
33        mut member_path: MemberPath,
34    ) -> Option<MemberPath> {
35        let mut res = None;
36        loop {
37            if self.scattered.contains_key(&member_path) {
38                res = Some(member_path.clone());
39            }
40            let MemberPath::Member { parent, .. } = member_path else {
41                return res;
42            };
43            member_path = *parent;
44        }
45    }
46
47    /// Returns the scattered members of the given member path, or None if the member path is not
48    /// scattered.
49    pub fn get_scattered_members(&mut self, member_path: &MemberPath) -> Option<Vec<MemberPath>> {
50        let Some(Value::Scattered(scattered)) = self.scattered.get(member_path) else {
51            return None;
52        };
53        Some(
54            scattered
55                .members
56                .iter()
57                .map(|(member_id, _)| MemberPath::Member {
58                    parent: member_path.clone().into(),
59                    member_id: *member_id,
60                    concrete_struct_id: scattered.concrete_struct_id,
61                })
62                .collect(),
63        )
64    }
65
66    pub fn destructure_closure<TContext: StructRecomposer>(
67        &mut self,
68        ctx: &mut TContext,
69        closure_var: VariableId,
70        closure_info: &ClosureInfo,
71    ) -> Vec<VariableId> {
72        ctx.deconstruct_by_types(
73            closure_var,
74            chain!(closure_info.members.values(), closure_info.snapshots.values()).cloned(),
75        )
76    }
77
78    pub fn get<TContext: StructRecomposer>(
79        &mut self,
80        mut ctx: TContext,
81        path: &MemberPath,
82    ) -> Option<VariableId> {
83        let value = self.break_into_value(&mut ctx, path)?;
84        Self::assemble_value(&mut ctx, value)
85    }
86
87    pub fn introduce(&mut self, path: MemberPath, var: VariableId) {
88        self.scattered.insert(path, Value::Var(var));
89    }
90
91    pub fn update<TContext: StructRecomposer>(
92        &mut self,
93        ctx: &mut TContext,
94        path: &MemberPath,
95        var: VariableId,
96    ) -> Option<()> {
97        // TODO(TomerStarkware): check if path is captured by a closure and invalidate the closure.
98        // Right now this can only happen if we take a snapshot of the variable (as the
99        // snapshot function returns a new var).
100        // we need to ensure the borrow checker invalidates the closure when mutable capture
101        // is supported.
102
103        let value = self.break_into_value(ctx, path)?;
104        *value = Value::Var(var);
105        Some(())
106    }
107
108    fn assemble_value<TContext: StructRecomposer>(
109        ctx: &mut TContext,
110        value: &mut Value,
111    ) -> Option<VariableId> {
112        Some(match value {
113            Value::Var(var) => *var,
114            Value::Scattered(scattered) => {
115                let members = scattered
116                    .members
117                    .iter_mut()
118                    .map(|(_, value)| Self::assemble_value(ctx, value))
119                    .collect::<Option<_>>()?;
120                let var = ctx.reconstruct(scattered.concrete_struct_id, members);
121                *value = Value::Var(var);
122                var
123            }
124        })
125    }
126
127    fn break_into_value<TContext: StructRecomposer>(
128        &mut self,
129        ctx: &mut TContext,
130        path: &MemberPath,
131    ) -> Option<&mut Value> {
132        if self.scattered.contains_key(path) {
133            return self.scattered.get_mut(path);
134        }
135
136        let MemberPath::Member { parent, member_id, concrete_struct_id, .. } = path else {
137            return None;
138        };
139
140        let parent_value = self.break_into_value(ctx, parent)?;
141        match parent_value {
142            Value::Var(var) => {
143                let members = ctx.deconstruct(*concrete_struct_id, *var);
144                let members = OrderedHashMap::from_iter(
145                    members.into_iter().map(|(member_id, var)| (member_id, Value::Var(var))),
146                );
147                let scattered = Scattered { concrete_struct_id: *concrete_struct_id, members };
148                *parent_value = Value::Scattered(Box::new(scattered));
149
150                extract_matches!(parent_value, Value::Scattered).members.get_mut(member_id)
151            }
152            Value::Scattered(scattered) => scattered.members.get_mut(member_id),
153        }
154    }
155}
156
/// A trait for deconstructing and constructing structs.
///
/// `SemanticLoweringMapping` stores the output of `deconstruct` as a scattered value. It later
/// passes those members, in the same iteration order, to `reconstruct`. Implementations should
/// therefore return members from `deconstruct` in the order `reconstruct` expects.
pub trait StructRecomposer {
    /// Splits `value`, of struct type `concrete_struct_id`, into one variable per member,
    /// keyed by member id.
    fn deconstruct(
        &mut self,
        concrete_struct_id: semantic::ConcreteStructId,
        value: VariableId,
    ) -> OrderedHashMap<MemberId, VariableId>;

    /// Splits `value` into variables of the given `types`.
    ///
    /// This is used to destructure closures. Presumably it returns one variable per item of
    /// `types`, in the same order; confirm this against implementors.
    fn deconstruct_by_types(
        &mut self,
        value: VariableId,
        types: impl Iterator<Item = semantic::TypeId>,
    ) -> Vec<VariableId>;

    /// Builds a variable of struct type `concrete_struct_id` from its member variables.
    fn reconstruct(
        &mut self,
        concrete_struct_id: semantic::ConcreteStructId,
        members: Vec<VariableId>,
    ) -> VariableId;
    /// Returns the semantic type of the lowered variable `var`.
    fn var_ty(&self, var: VariableId) -> semantic::TypeId;
    /// Returns the lowering database.
    fn db(&self) -> &dyn LoweringGroup;
}
179
/// An intermediate value for a member path, as tracked by `SemanticLoweringMapping`.
#[derive(Clone, Debug, DebugWithDb)]
#[debug_db(ExprFormatter<'a>)]
enum Value {
    /// The value of the member path is stored in a single lowered variable.
    Var(VariableId),
    /// The value of the member path is not stored. It is held as per-member values and must be
    /// reconstructed from them (via `StructRecomposer::reconstruct`) when the whole value is
    /// needed.
    Scattered(Box<Scattered>),
}
190
/// A value for a non-stored member path. Recursively holds the [Value] for the members.
#[derive(Clone, Debug, DebugWithDb)]
#[debug_db(ExprFormatter<'a>)]
struct Scattered {
    /// The struct type whose members are held in `members`; used when reconstructing.
    concrete_struct_id: semantic::ConcreteStructId,
    /// The value of each member, keyed by member id. The iteration order is the order in which
    /// member variables are passed to `StructRecomposer::reconstruct`. It presumably matches the
    /// insertion order from `StructRecomposer::deconstruct`.
    members: OrderedHashMap<MemberId, Value>,
}