cairo_lang_lowering/implicits/
mod.rs1use cairo_lang_defs::diagnostic_utils::StableLocation;
2use cairo_lang_defs::ids::LanguageElementId;
3use cairo_lang_diagnostics::Maybe;
4use cairo_lang_semantic as semantic;
5use cairo_lang_semantic::items::function_with_body::FunctionWithBodySemantic;
6use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
7use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
8use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
9use itertools::{Itertools, chain, zip_eq};
10use salsa::Database;
11use semantic::TypeId;
12
13use crate::blocks::Blocks;
14use crate::db::{ConcreteSCCRepresentative, LoweringGroup};
15use crate::ids::{ConcreteFunctionWithBodyId, FunctionId, FunctionLongId, LocationId};
16use crate::{
17 BlockEnd, BlockId, DependencyType, Lowered, LoweringStage, MatchArm, MatchInfo, Statement,
18 VarUsage, Variable, VariableArena,
19};
20
21struct Context<'db, 'a> {
22 db: &'db dyn Database,
23 lowered: &'a mut Lowered<'db>,
24 implicit_index: UnorderedHashMap<TypeId<'db>, usize>,
25 implicits_tys: Vec<TypeId<'db>>,
26 implicit_vars_for_block: UnorderedHashMap<BlockId, Vec<VarUsage<'db>>>,
27 visited: UnorderedHashSet<BlockId>,
28 location: LocationId<'db>,
29}
30
31pub fn lower_implicits<'db>(
33 db: &'db dyn Database,
34 function_id: ConcreteFunctionWithBodyId<'db>,
35 lowered: &mut Lowered<'db>,
36) {
37 if let Err(diag_added) = inner_lower_implicits(db, function_id, lowered) {
38 lowered.blocks = Blocks::new_errored(diag_added);
39 }
40}
41
42pub fn inner_lower_implicits<'db>(
44 db: &'db dyn Database,
45 function_id: ConcreteFunctionWithBodyId<'db>,
46 lowered: &mut Lowered<'db>,
47) -> Maybe<()> {
48 let semantic_function = function_id.base_semantic_function(db).function_with_body_id(db);
49 let location = LocationId::from_stable_location(
50 db,
51 StableLocation::new(semantic_function.untyped_stable_ptr(db)),
52 );
53 lowered.blocks.has_root()?;
54 let root_block_id = BlockId::root();
55
56 let implicits_tys = db.function_with_body_implicits(function_id)?;
57
58 let implicit_index = implicits_tys.iter().enumerate().map(|(i, ty)| (*ty, i)).collect();
59 let mut ctx = Context {
60 db,
61 lowered,
62 implicit_index,
63 implicits_tys,
64 implicit_vars_for_block: Default::default(),
65 visited: Default::default(),
66 location,
67 };
68
69 lower_function_blocks_implicits(&mut ctx, root_block_id)?;
71
72 let implicit_vars = &ctx.implicit_vars_for_block[&root_block_id];
74 ctx.lowered.parameters.splice(0..0, implicit_vars.iter().map(|var_usage| var_usage.var_id));
75
76 Ok(())
77}
78
79fn alloc_implicits<'db>(
82 db: &'db dyn Database,
83 variables: &mut VariableArena<'db>,
84 implicits_tys: &[TypeId<'db>],
85 location: LocationId<'db>,
86) -> Vec<VarUsage<'db>> {
87 implicits_tys
88 .iter()
89 .copied()
90 .map(|ty| VarUsage {
91 var_id: variables.alloc(Variable::with_default_context(db, ty, location)),
92 location,
93 })
94 .collect_vec()
95}
96
97fn block_body_implicits<'db>(
99 ctx: &mut Context<'db, '_>,
100 block_id: BlockId,
101) -> Result<Vec<VarUsage<'db>>, cairo_lang_diagnostics::DiagnosticAdded> {
102 let mut implicits = ctx
103 .implicit_vars_for_block
104 .entry(block_id)
105 .or_insert_with(|| {
106 alloc_implicits(
107 ctx.db,
108 &mut ctx.lowered.variables,
109 &ctx.implicits_tys,
110 ctx.location.with_auto_generation_note(ctx.db, "implicits"),
111 )
112 })
113 .clone();
114 let require_implicits_libfunc_id = semantic::corelib::internal_require_implicit(ctx.db);
115 let mut remove = vec![];
116 for (i, statement) in ctx.lowered.blocks[block_id].statements.iter_mut().enumerate() {
117 if let Statement::Call(stmt) = statement {
118 if matches!(
119 stmt.function.long(ctx.db),
120 FunctionLongId::Semantic(func_id)
121 if func_id.get_concrete(ctx.db).generic_function == require_implicits_libfunc_id
122 ) {
123 remove.push(i);
124 continue;
125 }
126 let callee_implicits = ctx.db.function_implicits(stmt.function)?;
127 let location = stmt.location.with_auto_generation_note(ctx.db, "implicits");
128
129 let indices = callee_implicits.iter().map(|ty| ctx.implicit_index[ty]).collect_vec();
130
131 let implicit_input_vars = indices.iter().map(|i| implicits[*i]);
132 stmt.inputs.splice(0..0, implicit_input_vars);
133 let implicit_output_vars = callee_implicits
134 .iter()
135 .copied()
136 .map(|ty| {
137 ctx.lowered
138 .variables
139 .alloc(Variable::with_default_context(ctx.db, ty, location))
140 })
141 .collect_vec();
142 for (i, var) in zip_eq(indices, implicit_output_vars.iter()) {
143 implicits[i] =
144 VarUsage { var_id: *var, location: ctx.lowered.variables[*var].location };
145 }
146 stmt.outputs.splice(0..0, implicit_output_vars);
147 }
148 }
149 for i in remove.into_iter().rev() {
150 ctx.lowered.blocks[block_id].statements.remove(i);
151 }
152 Ok(implicits)
153}
154
155fn lower_function_blocks_implicits<'db>(
157 ctx: &mut Context<'db, '_>,
158 root_block_id: BlockId,
159) -> Maybe<()> {
160 let mut blocks_to_visit = vec![root_block_id];
161 while let Some(block_id) = blocks_to_visit.pop() {
162 if !ctx.visited.insert(block_id) {
163 continue;
164 }
165 let implicits = block_body_implicits(ctx, block_id)?;
166 match &mut ctx.lowered.blocks[block_id].end {
168 BlockEnd::Return(rets, _location) => {
169 rets.splice(0..0, implicits.iter().cloned());
170 }
171 BlockEnd::Panic(_) => {
172 unreachable!("Panics should have been stripped in a previous phase.")
173 }
174 BlockEnd::Goto(block_id, remapping) => {
175 let target_implicits = ctx
176 .implicit_vars_for_block
177 .entry(*block_id)
178 .or_insert_with(|| {
179 alloc_implicits(
180 ctx.db,
181 &mut ctx.lowered.variables,
182 &ctx.implicits_tys,
183 ctx.location,
184 )
185 })
186 .clone();
187 let old_remapping = std::mem::take(&mut remapping.remapping);
188 remapping.remapping = chain!(
189 zip_eq(
190 target_implicits.into_iter().map(|var_usage| var_usage.var_id),
191 implicits
192 ),
193 old_remapping
194 )
195 .collect();
196 blocks_to_visit.push(*block_id);
197 }
198 BlockEnd::Match { info } => {
199 blocks_to_visit.extend(info.arms().iter().rev().map(|a| a.block_id));
200 match info {
201 MatchInfo::Enum(_) | MatchInfo::Value(_) => {
202 for MatchArm { arm_selector: _, block_id, var_ids: _ } in info.arms() {
203 assert!(
204 ctx.implicit_vars_for_block
205 .insert(*block_id, implicits.clone())
206 .is_none(),
207 "Multiple jumps to arm blocks are not allowed."
208 );
209 }
210 }
211 MatchInfo::Extern(stmt) => {
212 let callee_implicits = ctx.db.function_implicits(stmt.function)?;
213
214 let indices =
215 callee_implicits.iter().map(|ty| ctx.implicit_index[ty]).collect_vec();
216
217 let implicit_input_vars = indices.iter().map(|i| implicits[*i]);
218 stmt.inputs.splice(0..0, implicit_input_vars);
219 let location = stmt.location.with_auto_generation_note(ctx.db, "implicits");
220
221 for MatchArm { arm_selector: _, block_id, var_ids } in stmt.arms.iter_mut()
222 {
223 let mut arm_implicits = implicits.clone();
224 let mut implicit_input_vars = vec![];
225 for ty in callee_implicits.iter().copied() {
226 let var = ctx
227 .lowered
228 .variables
229 .alloc(Variable::with_default_context(ctx.db, ty, location));
230 implicit_input_vars.push(var);
231 let implicit_index = ctx.implicit_index[&ty];
232 arm_implicits[implicit_index] = VarUsage { var_id: var, location };
233 }
234 assert!(
235 ctx.implicit_vars_for_block
236 .insert(*block_id, arm_implicits)
237 .is_none(),
238 "Multiple jumps to arm blocks are not allowed."
239 );
240
241 var_ids.splice(0..0, implicit_input_vars);
242 }
243 }
244 }
245 }
246 BlockEnd::NotSet => unreachable!(),
247 }
248 }
249 Ok(())
250}
251
252#[salsa::tracked]
256pub fn function_implicits<'db>(
257 db: &'db dyn Database,
258 function: FunctionId<'db>,
259) -> Maybe<Vec<TypeId<'db>>> {
260 if let Some(body) = function.body(db)? {
261 return db.function_with_body_implicits(body);
262 }
263 Ok(function.signature(db)?.implicits)
264}
265
266pub trait FunctionImplicitsTrait<'db>: Database {
268 fn function_with_body_implicits(
270 &'db self,
271 function: ConcreteFunctionWithBodyId<'db>,
272 ) -> Maybe<Vec<TypeId<'db>>> {
273 let db: &'db dyn Database = self.as_dyn_database();
274 let scc_representative = db.lowered_scc_representative(
275 function,
276 DependencyType::Call,
277 LoweringStage::PostBaseline,
278 );
279 let mut implicits = scc_implicits(db, scc_representative)?;
280
281 let precedence = db.function_declaration_implicit_precedence(
282 function.base_semantic_function(db).function_with_body_id(db),
283 )?;
284 precedence.apply(&mut implicits, db);
285
286 Ok(implicits)
287 }
288}
289impl<'db, T: Database + ?Sized> FunctionImplicitsTrait<'db> for T {}
290
291fn scc_implicits<'db>(
293 db: &'db dyn Database,
294 scc: ConcreteSCCRepresentative<'db>,
295) -> Maybe<Vec<TypeId<'db>>> {
296 scc_implicits_tracked(db, scc.0)
297}
298
299#[salsa::tracked]
301fn scc_implicits_tracked<'db>(
302 db: &'db dyn Database,
303 rep: ConcreteFunctionWithBodyId<'db>,
304) -> Maybe<Vec<TypeId<'db>>> {
305 let scc_functions = db.lowered_scc(rep, DependencyType::Call, LoweringStage::PostBaseline);
306 let mut all_implicits = OrderedHashSet::<_>::default();
307 for function in scc_functions {
308 all_implicits.extend(function.function_id(db)?.signature(db)?.implicits);
310 let direct_callees =
312 db.lowered_direct_callees(function, DependencyType::Call, LoweringStage::PostBaseline)?;
313 for direct_callee in direct_callees {
314 if let Some(callee_body) = direct_callee.body(db)? {
315 let callee_scc = db.lowered_scc_representative(
316 callee_body,
317 DependencyType::Call,
318 LoweringStage::PostBaseline,
319 );
320 if callee_scc.0 != rep {
321 all_implicits.extend(scc_implicits(db, callee_scc)?);
322 }
323 } else {
324 all_implicits.extend(direct_callee.signature(db)?.implicits);
325 }
326 }
327 }
328 Ok(all_implicits.into_iter().collect())
329}