cairo_lang_lowering/implicits/mod.rs

1use cairo_lang_defs::diagnostic_utils::StableLocation;
2use cairo_lang_defs::ids::LanguageElementId;
3use cairo_lang_diagnostics::Maybe;
4use cairo_lang_semantic as semantic;
5use cairo_lang_semantic::items::function_with_body::FunctionWithBodySemantic;
6use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
7use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
8use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
9use itertools::{Itertools, chain, zip_eq};
10use salsa::Database;
11use semantic::TypeId;
12
13use crate::blocks::Blocks;
14use crate::db::{ConcreteSCCRepresentative, LoweringGroup};
15use crate::ids::{ConcreteFunctionWithBodyId, FunctionId, FunctionLongId, LocationId};
16use crate::{
17    BlockEnd, BlockId, DependencyType, Lowered, LoweringStage, MatchArm, MatchInfo, Statement,
18    VarUsage, Variable, VariableArena,
19};
20
21struct Context<'db, 'a> {
22    db: &'db dyn Database,
23    lowered: &'a mut Lowered<'db>,
24    implicit_index: UnorderedHashMap<TypeId<'db>, usize>,
25    implicits_tys: Vec<TypeId<'db>>,
26    implicit_vars_for_block: UnorderedHashMap<BlockId, Vec<VarUsage<'db>>>,
27    visited: UnorderedHashSet<BlockId>,
28    location: LocationId<'db>,
29}
30
31/// Lowering phase that adds implicits.
32pub fn lower_implicits<'db>(
33    db: &'db dyn Database,
34    function_id: ConcreteFunctionWithBodyId<'db>,
35    lowered: &mut Lowered<'db>,
36) {
37    if let Err(diag_added) = inner_lower_implicits(db, function_id, lowered) {
38        lowered.blocks = Blocks::new_errored(diag_added);
39    }
40}
41
42/// Similar to lower_implicits, but uses Maybe<> for convenience.
43pub fn inner_lower_implicits<'db>(
44    db: &'db dyn Database,
45    function_id: ConcreteFunctionWithBodyId<'db>,
46    lowered: &mut Lowered<'db>,
47) -> Maybe<()> {
48    let semantic_function = function_id.base_semantic_function(db).function_with_body_id(db);
49    let location = LocationId::from_stable_location(
50        db,
51        StableLocation::new(semantic_function.untyped_stable_ptr(db)),
52    );
53    lowered.blocks.has_root()?;
54    let root_block_id = BlockId::root();
55
56    let implicits_tys = db.function_with_body_implicits(function_id)?;
57
58    let implicit_index = implicits_tys.iter().enumerate().map(|(i, ty)| (*ty, i)).collect();
59    let mut ctx = Context {
60        db,
61        lowered,
62        implicit_index,
63        implicits_tys,
64        implicit_vars_for_block: Default::default(),
65        visited: Default::default(),
66        location,
67    };
68
69    // Start from root block.
70    lower_function_blocks_implicits(&mut ctx, root_block_id)?;
71
72    // Introduce new input variables in the root block.
73    let implicit_vars = &ctx.implicit_vars_for_block[&root_block_id];
74    ctx.lowered.parameters.splice(0..0, implicit_vars.iter().map(|var_usage| var_usage.var_id));
75
76    Ok(())
77}
78
79/// Allocates and returns new variables with usage location for each of the current function's
80/// implicits.
81fn alloc_implicits<'db>(
82    db: &'db dyn Database,
83    variables: &mut VariableArena<'db>,
84    implicits_tys: &[TypeId<'db>],
85    location: LocationId<'db>,
86) -> Vec<VarUsage<'db>> {
87    implicits_tys
88        .iter()
89        .copied()
90        .map(|ty| VarUsage {
91            var_id: variables.alloc(Variable::with_default_context(db, ty, location)),
92            location,
93        })
94        .collect_vec()
95}
96
97/// Returns the implicits that are used in the statements of a block.
98fn block_body_implicits<'db>(
99    ctx: &mut Context<'db, '_>,
100    block_id: BlockId,
101) -> Result<Vec<VarUsage<'db>>, cairo_lang_diagnostics::DiagnosticAdded> {
102    let mut implicits = ctx
103        .implicit_vars_for_block
104        .entry(block_id)
105        .or_insert_with(|| {
106            alloc_implicits(
107                ctx.db,
108                &mut ctx.lowered.variables,
109                &ctx.implicits_tys,
110                ctx.location.with_auto_generation_note(ctx.db, "implicits"),
111            )
112        })
113        .clone();
114    let require_implicits_libfunc_id = semantic::corelib::internal_require_implicit(ctx.db);
115    let mut remove = vec![];
116    for (i, statement) in ctx.lowered.blocks[block_id].statements.iter_mut().enumerate() {
117        if let Statement::Call(stmt) = statement {
118            if matches!(
119                stmt.function.long(ctx.db),
120                FunctionLongId::Semantic(func_id)
121                    if func_id.get_concrete(ctx.db).generic_function == require_implicits_libfunc_id
122            ) {
123                remove.push(i);
124                continue;
125            }
126            let callee_implicits = ctx.db.function_implicits(stmt.function)?;
127            let location = stmt.location.with_auto_generation_note(ctx.db, "implicits");
128
129            let indices = callee_implicits.iter().map(|ty| ctx.implicit_index[ty]).collect_vec();
130
131            let implicit_input_vars = indices.iter().map(|i| implicits[*i]);
132            stmt.inputs.splice(0..0, implicit_input_vars);
133            let implicit_output_vars = callee_implicits
134                .iter()
135                .copied()
136                .map(|ty| {
137                    ctx.lowered
138                        .variables
139                        .alloc(Variable::with_default_context(ctx.db, ty, location))
140                })
141                .collect_vec();
142            for (i, var) in zip_eq(indices, implicit_output_vars.iter()) {
143                implicits[i] =
144                    VarUsage { var_id: *var, location: ctx.lowered.variables[*var].location };
145            }
146            stmt.outputs.splice(0..0, implicit_output_vars);
147        }
148    }
149    for i in remove.into_iter().rev() {
150        ctx.lowered.blocks[block_id].statements.remove(i);
151    }
152    Ok(implicits)
153}
154
155/// Finds the implicits for a function's blocks starting from the root.
///
/// Visits every block reachable from `root_block_id` exactly once, depth-first, using an
/// explicit stack. For each block, `block_body_implicits` first threads the implicits through
/// the block's statements. The resulting implicit variables are then threaded through the
/// block end:
/// * `Return`: the implicits are prepended to the returned values.
/// * `Goto`: the target block's implicit variables are prepended to the remapping, each mapped
///   from the corresponding current implicit. The target's variables are allocated on first use.
/// * `Match` on an enum or a value: every arm block starts with the current implicit variables.
/// * `Match` on an extern call: the callee's implicits are prepended to the call inputs. Every
///   arm gets fresh variables for them, prepended to the arm's `var_ids`.
///
/// # Errors
/// Propagates failures of `block_body_implicits` and of the `function_implicits` query.
///
/// # Panics
/// Panics on a `Panic` or `NotSet` block end, or when an arm block was already assigned
/// implicit variables.
156fn lower_function_blocks_implicits<'db>(
157    ctx: &mut Context<'db, '_>,
158    root_block_id: BlockId,
159) -> Maybe<()> {
160    let mut blocks_to_visit = vec![root_block_id];
161    while let Some(block_id) = blocks_to_visit.pop() {
        // A block may be pushed more than once (e.g. several `Goto`s to it); process it once.
162        if !ctx.visited.insert(block_id) {
163            continue;
164        }
165        let implicits = block_body_implicits(ctx, block_id)?;
166        // End.
167        match &mut ctx.lowered.blocks[block_id].end {
168            BlockEnd::Return(rets, _location) => {
169                rets.splice(0..0, implicits.iter().cloned());
170            }
171            BlockEnd::Panic(_) => {
172                unreachable!("Panics should have been stripped in a previous phase.")
173            }
174            BlockEnd::Goto(block_id, remapping) => {
                // Reuse the target's implicit variables if they were already allocated (e.g. by
                // another jump to it); otherwise allocate them now.
                // NOTE(review): these get `ctx.location` as is, whereas `block_body_implicits`
                // allocates with `ctx.location.with_auto_generation_note(..)`; confirm this
                // difference is intended.
175                let target_implicits = ctx
176                    .implicit_vars_for_block
177                    .entry(*block_id)
178                    .or_insert_with(|| {
179                        alloc_implicits(
180                            ctx.db,
181                            &mut ctx.lowered.variables,
182                            &ctx.implicits_tys,
183                            ctx.location,
184                        )
185                    })
186                    .clone();
                // Map each target implicit variable from the corresponding current implicit,
                // ahead of the existing remapping entries.
187                let old_remapping = std::mem::take(&mut remapping.remapping);
188                remapping.remapping = chain!(
189                    zip_eq(
190                        target_implicits.into_iter().map(|var_usage| var_usage.var_id),
191                        implicits
192                    ),
193                    old_remapping
194                )
195                .collect();
196                blocks_to_visit.push(*block_id);
197            }
198            BlockEnd::Match { info } => {
                // Pushed in reverse so that popping from the stack visits the arms in order.
199                blocks_to_visit.extend(info.arms().iter().rev().map(|a| a.block_id));
200                match info {
201                    MatchInfo::Enum(_) | MatchInfo::Value(_) => {
                        // These matches do not touch the implicits: every arm starts with the
                        // current implicit variables.
202                        for MatchArm { arm_selector: _, block_id, var_ids: _ } in info.arms() {
203                            assert!(
204                                ctx.implicit_vars_for_block
205                                    .insert(*block_id, implicits.clone())
206                                    .is_none(),
207                                "Multiple jumps to arm blocks are not allowed."
208                            );
209                        }
210                    }
211                    MatchInfo::Extern(stmt) => {
212                        let callee_implicits = ctx.db.function_implicits(stmt.function)?;
213
                        // Position of each callee implicit within this function's implicits.
214                        let indices =
215                            callee_implicits.iter().map(|ty| ctx.implicit_index[ty]).collect_vec();
216
                        // The extern call receives the current value of each of its implicits.
217                        let implicit_input_vars = indices.iter().map(|i| implicits[*i]);
218                        stmt.inputs.splice(0..0, implicit_input_vars);
219                        let location = stmt.location.with_auto_generation_note(ctx.db, "implicits");
220
221                        for MatchArm { arm_selector: _, block_id, var_ids } in stmt.arms.iter_mut()
222                        {
                            // Each arm binds a fresh variable per callee implicit (prepended to
                            // `var_ids` below). That variable replaces the implicit's current
                            // value in the arm block.
223                            let mut arm_implicits = implicits.clone();
224                            let mut implicit_input_vars = vec![];
225                            for ty in callee_implicits.iter().copied() {
226                                let var = ctx
227                                    .lowered
228                                    .variables
229                                    .alloc(Variable::with_default_context(ctx.db, ty, location));
230                                implicit_input_vars.push(var);
231                                let implicit_index = ctx.implicit_index[&ty];
232                                arm_implicits[implicit_index] = VarUsage { var_id: var, location };
233                            }
234                            assert!(
235                                ctx.implicit_vars_for_block
236                                    .insert(*block_id, arm_implicits)
237                                    .is_none(),
238                                "Multiple jumps to arm blocks are not allowed."
239                            );
240
241                            var_ids.splice(0..0, implicit_input_vars);
242                        }
243                    }
244                }
245            }
246            BlockEnd::NotSet => unreachable!(),
247        }
248    }
249    Ok(())
250}
251
252// =========== Query implementations ===========
253
254/// Query implementation of [crate::db::LoweringGroup::function_implicits].
255#[salsa::tracked]
256pub fn function_implicits<'db>(
257    db: &'db dyn Database,
258    function: FunctionId<'db>,
259) -> Maybe<Vec<TypeId<'db>>> {
260    if let Some(body) = function.body(db)? {
261        return db.function_with_body_implicits(body);
262    }
263    Ok(function.signature(db)?.implicits)
264}
265
266/// A trait to add helper methods in [LoweringGroup].
267pub trait FunctionImplicitsTrait<'db>: Database {
268    /// Returns all the implicits used by a [ConcreteFunctionWithBodyId].
269    fn function_with_body_implicits(
270        &'db self,
271        function: ConcreteFunctionWithBodyId<'db>,
272    ) -> Maybe<Vec<TypeId<'db>>> {
273        let db: &'db dyn Database = self.as_dyn_database();
274        let scc_representative = db.lowered_scc_representative(
275            function,
276            DependencyType::Call,
277            LoweringStage::PostBaseline,
278        );
279        let mut implicits = scc_implicits(db, scc_representative)?;
280
281        let precedence = db.function_declaration_implicit_precedence(
282            function.base_semantic_function(db).function_with_body_id(db),
283        )?;
284        precedence.apply(&mut implicits, db);
285
286        Ok(implicits)
287    }
288}
289impl<'db, T: Database + ?Sized> FunctionImplicitsTrait<'db> for T {}
290
291/// Returns all the implicits used by a strongly connected component of functions.
292fn scc_implicits<'db>(
293    db: &'db dyn Database,
294    scc: ConcreteSCCRepresentative<'db>,
295) -> Maybe<Vec<TypeId<'db>>> {
296    scc_implicits_tracked(db, scc.0)
297}
298
299/// Tracked implementation of [scc_implicits].
300#[salsa::tracked]
301fn scc_implicits_tracked<'db>(
302    db: &'db dyn Database,
303    rep: ConcreteFunctionWithBodyId<'db>,
304) -> Maybe<Vec<TypeId<'db>>> {
305    let scc_functions = db.lowered_scc(rep, DependencyType::Call, LoweringStage::PostBaseline);
306    let mut all_implicits = OrderedHashSet::<_>::default();
307    for function in scc_functions {
308        // Add the function's explicit implicits.
309        all_implicits.extend(function.function_id(db)?.signature(db)?.implicits);
310        // For each direct callee, add its implicits.
311        let direct_callees =
312            db.lowered_direct_callees(function, DependencyType::Call, LoweringStage::PostBaseline)?;
313        for direct_callee in direct_callees {
314            if let Some(callee_body) = direct_callee.body(db)? {
315                let callee_scc = db.lowered_scc_representative(
316                    callee_body,
317                    DependencyType::Call,
318                    LoweringStage::PostBaseline,
319                );
320                if callee_scc.0 != rep {
321                    all_implicits.extend(scc_implicits(db, callee_scc)?);
322                }
323            } else {
324                all_implicits.extend(direct_callee.signature(db)?.implicits);
325            }
326        }
327    }
328    Ok(all_implicits.into_iter().collect())
329}