cairo_lang_lowering/lower/
refs.rs1use cairo_lang_defs::ids::MemberId;
2use cairo_lang_proc_macros::DebugWithDb;
3use cairo_lang_semantic::expr::fmt::ExprFormatter;
4use cairo_lang_semantic::expr::inference::InferenceError;
5use cairo_lang_semantic::items::structure::StructSemantic;
6use cairo_lang_semantic::usage::MemberPath;
7use cairo_lang_semantic::{self as semantic, ConcreteTypeId, TypeLongId};
8use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
9use cairo_lang_utils::{Intern, extract_matches, try_extract_matches};
10use itertools::{Itertools, chain};
11
12use super::block_builder::BlockStructRecomposer;
13use crate::VariableId;
14use crate::ids::LocationId;
15
16#[derive(Clone, Debug)]
18pub struct ClosureInfo<'db> {
19 pub members: OrderedHashMap<MemberPath<'db>, semantic::TypeId<'db>>,
22 pub snapshots: OrderedHashMap<MemberPath<'db>, semantic::TypeId<'db>>,
24}
25
26pub enum AssembleValueError<'db> {
28 Moved(MovedVar<'db>),
30 Missing,
32}
33
34#[derive(Clone, Default, Debug)]
35pub struct SemanticLoweringMapping<'db> {
36 scattered: OrderedHashMap<MemberPath<'db>, Value<'db>>,
38}
39impl<'db> SemanticLoweringMapping<'db> {
40 pub fn topmost_mapped_containing_member_path(
43 &self,
44 mut member_path: MemberPath<'db>,
45 ) -> Option<MemberPath<'db>> {
46 let mut res = None;
47 loop {
48 if self.scattered.contains_key(&member_path) {
49 res = Some(member_path.clone());
50 }
51 let MemberPath::Member { parent, .. } = member_path else {
52 return res;
53 };
54 member_path = *parent;
55 }
56 }
57
58 pub fn destructure_closure(
59 &mut self,
60 ctx: &mut BlockStructRecomposer<'_, '_, 'db>,
61 closure_var: VariableId,
62 closure_info: &ClosureInfo<'db>,
63 ) -> Vec<VariableId> {
64 ctx.deconstruct_by_types(
65 closure_var,
66 chain!(closure_info.members.values(), closure_info.snapshots.values()).cloned(),
67 )
68 }
69
70 pub fn get(
71 &mut self,
72 mut ctx: BlockStructRecomposer<'_, '_, 'db>,
73 path: &MemberPath<'db>,
74 ) -> Result<VariableId, AssembleValueError<'db>> {
75 let value = self.break_into_value(&mut ctx, path).ok_or(AssembleValueError::Missing)?;
76 Self::assemble_value(&mut ctx, value).map_err(AssembleValueError::Moved)
77 }
78
79 pub fn introduce(&mut self, path: MemberPath<'db>, var: VariableId) {
80 self.scattered.insert(path, Value::Var(var));
81 }
82
83 pub fn update(
84 &mut self,
85 ctx: &mut BlockStructRecomposer<'_, '_, 'db>,
86 path: &MemberPath<'db>,
87 var: VariableId,
88 ) -> Option<()> {
89 let value = self.break_into_value(ctx, path)?;
96 *value = Value::Var(var);
97 Some(())
98 }
99
100 pub fn mark_as_used(
104 &mut self,
105 mut ctx: BlockStructRecomposer<'_, '_, 'db>,
106 path: &MemberPath<'db>,
107 moved: MovedVar<'db>,
108 ) {
109 *self.break_into_value(&mut ctx, path).unwrap() = Value::MovedVar(moved);
110 }
111
112 fn assemble_value(
117 ctx: &mut BlockStructRecomposer<'_, '_, 'db>,
118 value: &mut Value<'db>,
119 ) -> Result<VariableId, MovedVar<'db>> {
120 match value {
121 Value::Var(var) => Ok(*var),
122 Value::Scattered(scattered) => {
123 let members_res = scattered
124 .members
125 .iter_mut()
126 .map(|(_, value)| Self::assemble_value(ctx, value))
127 .collect::<Result<_, _>>();
128
129 match members_res {
130 Ok(members) => {
131 let var = ctx.reconstruct(scattered.concrete_struct_id, members);
132 *value = Value::Var(var);
133 Ok(var)
134 }
135 Err(MovedVar { ty: _, inference_error, last_use_location }) => {
136 let y = TypeLongId::<'db>::Concrete(ConcreteTypeId::Struct(
139 scattered.concrete_struct_id,
140 ));
141 let x = y.intern(ctx.ctx.db);
142 Err(MovedVar { ty: x, inference_error, last_use_location })
143 }
144 }
145 }
146 Value::MovedVar(moved) => Err(moved.clone()),
147 }
148 }
149
150 fn break_into_value(
151 &mut self,
152 ctx: &mut BlockStructRecomposer<'_, '_, 'db>,
153 path: &MemberPath<'db>,
154 ) -> Option<&mut Value<'db>> {
155 if self.scattered.contains_key(path) {
156 return self.scattered.get_mut(path);
157 }
158
159 let MemberPath::Member { parent, member_id, concrete_struct_id, .. } = path else {
160 return None;
161 };
162
163 let parent_value = self.break_into_value(ctx, parent)?;
164 match parent_value {
165 Value::Var(var) => {
166 let members = ctx.deconstruct(*concrete_struct_id, *var);
167 let members = OrderedHashMap::from_iter(
168 members.into_iter().map(|(member_id, var)| (member_id, Value::Var(var))),
169 );
170 let scattered = Scattered { concrete_struct_id: *concrete_struct_id, members };
171 *parent_value = Value::Scattered(Box::new(scattered));
172 }
173 Value::MovedVar(MovedVar { ty: _, inference_error, last_use_location }) => {
174 let member_map = ctx.ctx.db.concrete_struct_members(*concrete_struct_id).unwrap();
175 let members = OrderedHashMap::from_iter(member_map.values().map(|member| {
176 (
177 member.id,
178 Value::MovedVar(MovedVar {
179 ty: member.ty,
180 inference_error: inference_error.clone(),
181 last_use_location: *last_use_location,
182 }),
183 )
184 }));
185 let scattered = Scattered { concrete_struct_id: *concrete_struct_id, members };
186 *parent_value = Value::Scattered(Box::new(scattered));
187 }
188 Value::Scattered(..) => {}
189 };
190 extract_matches!(parent_value, Value::Scattered).members.get_mut(member_id)
191 }
192}
193
194impl<'db> cairo_lang_debug::debug::DebugWithDb<'db> for SemanticLoweringMapping<'db> {
195 type Db = ExprFormatter<'db>;
196
197 fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &ExprFormatter<'db>) -> std::fmt::Result {
198 for (member_path, value) in self.scattered.iter() {
199 writeln!(f, "{:?}: {value}", member_path.debug(db))?;
200 }
201 Ok(())
202 }
203}
204
205pub fn merge_semantics<'db, 'a>(
214 mappings: impl Iterator<Item = &'a SemanticLoweringMapping<'db>>,
215 remapped_callback: &mut impl FnMut(&MemberPath<'db>) -> VariableId,
216) -> SemanticLoweringMapping<'db>
217where
218 'db: 'a,
219{
220 let mut path_to_values: OrderedHashMap<MemberPath<'_>, Vec<Value<'_>>> = Default::default();
223
224 let mut n_mappings = 0;
225 for map in mappings {
226 for (path, var) in map.scattered.iter() {
227 path_to_values.entry(path.clone()).or_default().push(var.clone());
228 }
229 n_mappings += 1;
230 }
231
232 let mut scattered: OrderedHashMap<MemberPath<'_>, Value<'_>> = Default::default();
233 for (path, values) in path_to_values {
234 if values.len() != n_mappings {
237 continue;
238 }
239
240 let merged_value = compute_remapped_variables(
241 &values.iter().collect_vec(),
242 false,
243 &path,
244 remapped_callback,
245 );
246 scattered.insert(path, merged_value);
247 }
248
249 SemanticLoweringMapping { scattered }
250}
251
252fn compute_remapped_variables<'db>(
289 values: &[&Value<'db>],
290 require_remapping: bool,
291 parent_path: &MemberPath<'db>,
292 remapped_callback: &mut impl FnMut(&MemberPath<'db>) -> VariableId,
293) -> Value<'db> {
294 if let Some(x) = values.iter().find(|value| matches!(value, Value::MovedVar { .. })) {
295 return (*x).clone();
298 }
299
300 if !require_remapping {
301 let first_var = values[0];
303 if values.iter().all(|x| *x == first_var) {
304 return first_var.clone();
305 }
306 }
307
308 let only_scattered: Vec<&Box<Scattered<'_>>> =
310 values.iter().filter_map(|value| try_extract_matches!(value, Value::Scattered)).collect();
311
312 if only_scattered.is_empty() {
313 let remapped_var = remapped_callback(parent_path);
314 return Value::Var(remapped_var);
315 }
316
317 let require_remapping = require_remapping || only_scattered.len() < values.len();
319
320 let concrete_struct_id = only_scattered[0].concrete_struct_id;
321 let members = only_scattered[0]
322 .members
323 .keys()
324 .map(|member_id| {
325 let member_path = MemberPath::Member {
326 parent: parent_path.clone().into(),
327 member_id: *member_id,
328 concrete_struct_id,
329 };
330 let member_values =
334 only_scattered.iter().map(|scattered| &scattered.members[member_id]).collect_vec();
335
336 (
337 *member_id,
338 compute_remapped_variables(
339 &member_values,
340 require_remapping,
341 &member_path,
342 remapped_callback,
343 ),
344 )
345 })
346 .collect();
347
348 Value::Scattered(Box::new(Scattered { concrete_struct_id, members }))
349}
350
351pub fn find_changed_members<'db, 'a>(
354 semantics0: &'a SemanticLoweringMapping<'db>,
355 semantics1: &'a SemanticLoweringMapping<'db>,
356) -> impl Iterator<Item = MemberPath<'db>> + 'a {
357 semantics0.scattered.iter().filter_map(|(path, value0)| {
358 if let Some(value1) = semantics1.scattered.get(path)
359 && value0 != value1
360 {
361 return Some(path.clone());
362 }
363 None
364 })
365}
366
367#[derive(Clone, Debug, DebugWithDb, Eq, PartialEq)]
369#[debug_db(ExprFormatter<'db>)]
370pub struct MovedVar<'db> {
371 pub ty: semantic::TypeId<'db>,
373 pub inference_error: InferenceError<'db>,
375 pub last_use_location: LocationId<'db>,
377}
378
379#[derive(Clone, Debug, DebugWithDb, Eq, PartialEq)]
381#[debug_db(ExprFormatter<'db>)]
382enum Value<'db> {
383 Var(VariableId),
385 Scattered(Box<Scattered<'db>>),
388 MovedVar(MovedVar<'db>),
390}
391
392impl<'db> std::fmt::Display for Value<'db> {
393 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
394 match self {
395 Value::Var(var) => write!(f, "v{}", var.index()),
396 Value::Scattered(scattered) => {
397 write!(
398 f,
399 "Scattered({})",
400 scattered.members.values().map(|value| value.to_string()).join(", ")
401 )
402 }
403 Value::MovedVar(..) => write!(f, "MovedVar"),
404 }
405 }
406}
407
408#[derive(Clone, Debug, DebugWithDb, Eq, PartialEq)]
410#[debug_db(ExprFormatter<'db>)]
411struct Scattered<'db> {
412 concrete_struct_id: semantic::ConcreteStructId<'db>,
413 members: OrderedHashMap<MemberId<'db>, Value<'db>>,
414}