1use cairo_lang_defs::ids::{LanguageElementId, MemberId};
2use cairo_lang_proc_macros::DebugWithDb;
3use cairo_lang_semantic::expr::fmt::ExprFormatter;
4use cairo_lang_semantic::expr::inference::InferenceError;
5use cairo_lang_semantic::items::structure::StructSemantic;
6use cairo_lang_semantic::usage::MemberPath;
7use cairo_lang_semantic::{self as semantic};
8use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
9use cairo_lang_utils::{extract_matches, try_extract_matches};
10use itertools::{Itertools, chain};
11
12use super::block_builder::BlockStructRecomposer;
13use super::context::VarRequest;
14use crate::VariableId;
15use crate::ids::LocationId;
16
17#[derive(Clone, Debug)]
19pub struct ClosureInfo<'db> {
20 pub members: OrderedHashMap<MemberPath<'db>, semantic::TypeId<'db>>,
23 pub snapshots: OrderedHashMap<MemberPath<'db>, semantic::TypeId<'db>>,
25}
26
27pub enum AssembleValueError<'db> {
29 Moved(MovedVar<'db>),
31 Missing,
33}
34
35#[derive(Clone, Default, Debug)]
36pub struct SemanticLoweringMapping<'db> {
37 scattered: OrderedHashMap<MemberPath<'db>, Value<'db>>,
39}
40impl<'db> SemanticLoweringMapping<'db> {
41 pub fn topmost_mapped_containing_member_path(
44 &self,
45 mut member_path: MemberPath<'db>,
46 ) -> Option<MemberPath<'db>> {
47 let mut res = None;
48 loop {
49 if self.scattered.contains_key(&member_path) {
50 res = Some(member_path.clone());
51 }
52 let MemberPath::Member { parent, .. } = member_path else {
53 return res;
54 };
55 member_path = *parent;
56 }
57 }
58
59 pub fn destructure_closure(
60 &mut self,
61 ctx: &mut BlockStructRecomposer<'_, '_, 'db>,
62 closure_var: VariableId,
63 closure_info: &ClosureInfo<'db>,
64 ) -> Vec<VariableId> {
65 ctx.deconstruct_by_types(
66 closure_var,
67 chain!(closure_info.members.values(), closure_info.snapshots.values()).cloned(),
68 )
69 }
70
71 pub fn get(
72 &mut self,
73 mut ctx: BlockStructRecomposer<'_, '_, 'db>,
74 path: &MemberPath<'db>,
75 ) -> Result<VariableId, AssembleValueError<'db>> {
76 let value = self.break_into_value(&mut ctx, path).ok_or(AssembleValueError::Missing)?;
77 let base_var = path.base_var();
78 let location_stable_ptr = match ctx.ctx.semantic_defs.get(&base_var) {
79 Some(binding) => binding.stable_ptr(ctx.ctx.db),
80 None => base_var.untyped_stable_ptr(ctx.ctx.db),
81 };
82 let location = ctx.ctx.get_location(location_stable_ptr);
83 Self::assemble_value(&mut ctx, value, location).map_err(AssembleValueError::Moved)
84 }
85
86 pub fn introduce(&mut self, path: MemberPath<'db>, var: VariableId) {
87 self.scattered.insert(path, Value::Var(var));
88 }
89
90 pub fn update(
91 &mut self,
92 ctx: &mut BlockStructRecomposer<'_, '_, 'db>,
93 path: &MemberPath<'db>,
94 var: VariableId,
95 ) -> Option<()> {
96 let value = self.break_into_value(ctx, path)?;
103 *value = Value::Var(var);
104 Some(())
105 }
106
107 pub fn mark_as_used(
111 &mut self,
112 mut ctx: BlockStructRecomposer<'_, '_, 'db>,
113 path: &MemberPath<'db>,
114 moved: MovedVar<'db>,
115 ) {
116 *self.break_into_value(&mut ctx, path).unwrap() = Value::MovedVar(moved);
117 }
118
119 fn assemble_value(
124 ctx: &mut BlockStructRecomposer<'_, '_, 'db>,
125 value: &mut Value<'db>,
126 location: LocationId<'db>,
127 ) -> Result<VariableId, MovedVar<'db>> {
128 match value {
129 Value::Var(var) => Ok(*var),
130 Value::MovedVar(moved) => Err(moved.clone()),
131 Value::Scattered(scattered) => {
132 let mut moved_var = None;
133 let members = scattered
134 .members
135 .iter_mut()
136 .map(|(_, value)| match Self::assemble_value(ctx, value, location) {
137 Ok(var) => var,
138 Err(moved) => {
139 let var = moved.var_id;
140 moved_var.get_or_insert(moved);
141 var
142 }
143 })
144 .collect_vec();
145 let var = ctx.reconstruct(scattered.concrete_struct_id, members, location);
146 *value = Value::Var(var);
147 if let Some(MovedVar { var_id: _, inference_error, last_use_location }) = moved_var
148 {
149 Err(MovedVar { var_id: var, inference_error, last_use_location })
150 } else {
151 Ok(var)
152 }
153 }
154 }
155 }
156
157 fn break_into_value(
158 &mut self,
159 ctx: &mut BlockStructRecomposer<'_, '_, 'db>,
160 path: &MemberPath<'db>,
161 ) -> Option<&mut Value<'db>> {
162 if self.scattered.contains_key(path) {
163 return self.scattered.get_mut(path);
164 }
165
166 let &MemberPath::Member { ref parent, member_id, concrete_struct_id, .. } = path else {
167 return None;
168 };
169
170 let parent_value = self.break_into_value(ctx, parent)?;
171 match parent_value {
172 Value::Var(var) => {
173 let members = ctx.deconstruct(concrete_struct_id, *var);
174 let members = OrderedHashMap::from_iter(
175 members.into_iter().map(|(member_id, var)| (member_id, Value::Var(var))),
176 );
177 let scattered = Scattered { concrete_struct_id, members };
178 *parent_value = Value::Scattered(Box::new(scattered));
179 }
180 &mut Value::MovedVar(MovedVar { var_id, ref inference_error, last_use_location }) => {
181 let member_map = ctx.ctx.db.concrete_struct_members(concrete_struct_id).unwrap();
182 let location = ctx.ctx.variables[var_id].location;
183 let members = OrderedHashMap::from_iter(member_map.values().map(|member| {
184 (
185 member.id,
186 Value::MovedVar(MovedVar {
187 var_id: ctx.ctx.new_var(VarRequest { ty: member.ty, location }),
188 inference_error: inference_error.clone(),
189 last_use_location,
190 }),
191 )
192 }));
193 let scattered = Scattered { concrete_struct_id, members };
194 *parent_value = Value::Scattered(Box::new(scattered));
195 }
196 Value::Scattered(..) => {}
197 };
198 extract_matches!(parent_value, Value::Scattered).members.get_mut(&member_id)
199 }
200}
201
202impl<'db> cairo_lang_debug::debug::DebugWithDb<'db> for SemanticLoweringMapping<'db> {
203 type Db = ExprFormatter<'db>;
204
205 fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &ExprFormatter<'db>) -> std::fmt::Result {
206 for (member_path, value) in self.scattered.iter() {
207 writeln!(f, "{:?}: {value}", member_path.debug(db))?;
208 }
209 Ok(())
210 }
211}
212
213pub fn merge_semantics<'db, 'a>(
222 mappings: impl Iterator<Item = &'a SemanticLoweringMapping<'db>>,
223 remapped_callback: &mut impl FnMut(&MemberPath<'db>) -> VariableId,
224) -> SemanticLoweringMapping<'db>
225where
226 'db: 'a,
227{
228 let mut path_to_values: OrderedHashMap<MemberPath<'_>, Vec<Value<'_>>> = Default::default();
231
232 let mut n_mappings = 0;
233 for map in mappings {
234 for (path, var) in map.scattered.iter() {
235 path_to_values.entry(path.clone()).or_default().push(var.clone());
236 }
237 n_mappings += 1;
238 }
239
240 let mut scattered: OrderedHashMap<MemberPath<'_>, Value<'_>> = Default::default();
241 for (path, values) in path_to_values {
242 if values.len() != n_mappings {
245 continue;
246 }
247
248 let merged_value = compute_remapped_variables(
249 &values.iter().collect_vec(),
250 false,
251 &path,
252 remapped_callback,
253 );
254 scattered.insert(path, merged_value);
255 }
256
257 SemanticLoweringMapping { scattered }
258}
259
260fn compute_remapped_variables<'db>(
297 values: &[&Value<'db>],
298 require_remapping: bool,
299 parent_path: &MemberPath<'db>,
300 remapped_callback: &mut impl FnMut(&MemberPath<'db>) -> VariableId,
301) -> Value<'db> {
302 if let Some(x) = values.iter().find(|value| matches!(value, Value::MovedVar { .. })) {
303 return (*x).clone();
306 }
307
308 if !require_remapping {
309 let first_var = values[0];
311 if values.iter().all(|x| *x == first_var) {
312 return first_var.clone();
313 }
314 }
315
316 let only_scattered: Vec<&Box<Scattered<'_>>> =
318 values.iter().filter_map(|value| try_extract_matches!(value, Value::Scattered)).collect();
319
320 if only_scattered.is_empty() {
321 let remapped_var = remapped_callback(parent_path);
322 return Value::Var(remapped_var);
323 }
324
325 let require_remapping = require_remapping || only_scattered.len() < values.len();
327
328 let concrete_struct_id = only_scattered[0].concrete_struct_id;
329 let members = only_scattered[0]
330 .members
331 .keys()
332 .map(|member_id| {
333 let member_path = MemberPath::Member {
334 parent: parent_path.clone().into(),
335 member_id: *member_id,
336 concrete_struct_id,
337 };
338 let member_values =
342 only_scattered.iter().map(|scattered| &scattered.members[member_id]).collect_vec();
343
344 (
345 *member_id,
346 compute_remapped_variables(
347 &member_values,
348 require_remapping,
349 &member_path,
350 remapped_callback,
351 ),
352 )
353 })
354 .collect();
355
356 Value::Scattered(Box::new(Scattered { concrete_struct_id, members }))
357}
358
359pub fn find_changed_members<'db, 'a>(
362 semantics0: &'a SemanticLoweringMapping<'db>,
363 semantics1: &'a SemanticLoweringMapping<'db>,
364) -> impl Iterator<Item = MemberPath<'db>> + 'a {
365 semantics0.scattered.iter().filter_map(|(path, value0)| {
366 if let Some(value1) = semantics1.scattered.get(path)
367 && value0 != value1
368 {
369 return Some(path.clone());
370 }
371 None
372 })
373}
374
375#[derive(Clone, Debug, DebugWithDb, Eq, PartialEq)]
377#[debug_db(ExprFormatter<'db>)]
378pub struct MovedVar<'db> {
379 pub var_id: VariableId,
381 pub inference_error: InferenceError<'db>,
383 pub last_use_location: LocationId<'db>,
385}
386
387#[derive(Clone, Debug, DebugWithDb, Eq, PartialEq)]
389#[debug_db(ExprFormatter<'db>)]
390enum Value<'db> {
391 Var(VariableId),
393 Scattered(Box<Scattered<'db>>),
396 MovedVar(MovedVar<'db>),
398}
399
400impl<'db> std::fmt::Display for Value<'db> {
401 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
402 match self {
403 Value::Var(var) => write!(f, "v{}", var.index()),
404 Value::Scattered(scattered) => {
405 write!(
406 f,
407 "Scattered({})",
408 scattered.members.values().map(|value| value.to_string()).join(", ")
409 )
410 }
411 Value::MovedVar(..) => write!(f, "MovedVar"),
412 }
413 }
414}
415
416#[derive(Clone, Debug, DebugWithDb, Eq, PartialEq)]
418#[debug_db(ExprFormatter<'db>)]
419struct Scattered<'db> {
420 concrete_struct_id: semantic::ConcreteStructId<'db>,
421 members: OrderedHashMap<MemberId<'db>, Value<'db>>,
422}