cairo_lang_lowering/
db.rs

1use std::sync::Arc;
2
3use cairo_lang_debug::DebugWithDb;
4use cairo_lang_defs as defs;
5use cairo_lang_defs::ids::{
6    ExternFunctionId, LanguageElementId, ModuleId, ModuleItemId, NamedLanguageElementLongId,
7};
8use cairo_lang_diagnostics::{Diagnostics, DiagnosticsBuilder, Maybe};
9use cairo_lang_filesystem::ids::FileId;
10use cairo_lang_semantic::db::SemanticGroup;
11use cairo_lang_semantic::items::enm::SemanticEnumEx;
12use cairo_lang_semantic::{self as semantic, ConcreteTypeId, TypeId, TypeLongId, corelib};
13use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
14use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
15use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
16use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
17use cairo_lang_utils::{Intern, LookupIntern, Upcast};
18use defs::ids::NamedLanguageElementId;
19use itertools::{Itertools, chain};
20use num_traits::ToPrimitive;
21
22use crate::add_withdraw_gas::add_withdraw_gas;
23use crate::blocks::Blocks;
24use crate::borrow_check::{
25    PotentialDestructCalls, borrow_check, borrow_check_possible_withdraw_gas,
26};
27use crate::cache::load_cached_crate_functions;
28use crate::concretize::concretize_lowered;
29use crate::destructs::add_destructs;
30use crate::diagnostic::{LoweringDiagnostic, LoweringDiagnosticKind};
31use crate::graph_algorithms::feedback_set::flag_add_withdraw_gas;
32use crate::ids::{ConcreteFunctionWithBodyId, FunctionId, FunctionLongId, GenericOrSpecialized};
33use crate::inline::get_inline_diagnostics;
34use crate::inline::statements_weights::{ApproxCasmInlineWeight, InlineWeight};
35use crate::lower::{MultiLowering, lower_semantic_function};
36use crate::optimizations::config::OptimizationConfig;
37use crate::optimizations::scrub_units::scrub_units;
38use crate::optimizations::strategy::{OptimizationStrategy, OptimizationStrategyId};
39use crate::panic::lower_panics;
40use crate::specialization::specialized_function_lowered;
41use crate::utils::InliningStrategy;
42use crate::{
43    BlockEnd, BlockId, DependencyType, Location, Lowered, LoweringStage, MatchInfo, Statement, ids,
44};
45
/// A trait for estimation of the code size of a function.
///
/// Types implementing [`UseApproxCodeSizeEstimator`] get a blanket implementation based on
/// [`ApproxCasmInlineWeight`]; other database types may provide their own estimator.
pub trait ExternalCodeSizeEstimator {
    /// Returns estimated size of the function with the given id.
    ///
    /// Returns an error when the size cannot be estimated (the blanket implementation, for
    /// example, propagates failures to compute the function's lowering).
    fn estimate_size(&self, function_id: ConcreteFunctionWithBodyId) -> Maybe<isize>;
}
51
/// Marker trait for using ApproxCasmInlineWeight as the code size estimator.
///
/// Any type implementing this trait receives a blanket [`ExternalCodeSizeEstimator`]
/// implementation that weighs the function's `LoweringStage::PostBaseline` lowering.
pub trait UseApproxCodeSizeEstimator: Upcast<dyn LoweringGroup> {}
54
55impl<T: UseApproxCodeSizeEstimator> ExternalCodeSizeEstimator for T {
56    fn estimate_size(&self, function_id: ConcreteFunctionWithBodyId) -> Maybe<isize> {
57        let db = self.upcast();
58        let lowered = db.lowered_body(function_id, LoweringStage::PostBaseline)?;
59        Ok(ApproxCasmInlineWeight::new(db, &lowered).lowered_weight(&lowered))
60    }
61}
62
// Salsa database interface.
#[salsa::query_group(LoweringDatabase)]
pub trait LoweringGroup:
    SemanticGroup + Upcast<dyn SemanticGroup> + ExternalCodeSizeEstimator
{
    /// Interns a lowering-level function.
    #[salsa::interned]
    fn intern_lowering_function(&self, id: ids::FunctionLongId) -> ids::FunctionId;
    /// Interns a concrete function with a body.
    #[salsa::interned]
    fn intern_lowering_concrete_function_with_body(
        &self,
        id: ids::ConcreteFunctionWithBodyLongId,
    ) -> ids::ConcreteFunctionWithBodyId;
    /// Interns a (possibly generated) function with a body.
    #[salsa::interned]
    fn intern_lowering_function_with_body(
        &self,
        id: ids::FunctionWithBodyLongId,
    ) -> ids::FunctionWithBodyId;

    /// Interns a lowering location.
    #[salsa::interned]
    fn intern_location(&self, id: Location) -> ids::LocationId;

    /// Interns an optimization strategy.
    #[salsa::interned]
    fn intern_strategy(&self, id: OptimizationStrategy) -> OptimizationStrategyId;

    /// Computes the lowered representation of a function with a body, along with all the
    /// functions it generates (e.g. closures, lambdas, loops, ...).
    fn priv_function_with_body_multi_lowering(
        &self,
        function_id: defs::ids::FunctionWithBodyId,
    ) -> Maybe<Arc<MultiLowering>>;

    /// Returns a mapping from function ids to their multi-lowerings, loaded from the lowering
    /// cache of the given crate. `None` means no cached lowering is used for this crate, and its
    /// functions are lowered from source instead.
    fn cached_multi_lowerings(
        &self,
        crate_id: cairo_lang_filesystem::ids::CrateId,
    ) -> Option<Arc<OrderedHashMap<defs::ids::FunctionWithBodyId, MultiLowering>>>;

    /// Computes the lowered representation of a function with a body before borrow checking.
    fn priv_function_with_body_lowering(
        &self,
        function_id: ids::FunctionWithBodyId,
    ) -> Maybe<Arc<Lowered>>;

    /// Computes the lowered representation of a function with a body.
    /// Additionally applies borrow checking, and returns the extra potential calls the borrow
    /// checker found per block (see [PotentialDestructCalls]).
    fn function_with_body_lowering_with_borrow_check(
        &self,
        function_id: ids::FunctionWithBodyId,
    ) -> Maybe<(Arc<Lowered>, Arc<PotentialDestructCalls>)>;

    /// Computes the lowered representation of a function with a body.
    fn function_with_body_lowering(
        &self,
        function_id: ids::FunctionWithBodyId,
    ) -> Maybe<Arc<Lowered>>;

    /// Computes the lowered representation of a function at the requested lowering stage.
    fn lowered_body(
        &self,
        function_id: ids::ConcreteFunctionWithBodyId,
        stage: LoweringStage,
    ) -> Maybe<Arc<Lowered>>;

    /// Returns the direct callees of a concrete function with a body, at the given stage.
    /// Unlike [Self::lowered_direct_callees_with_body], libfunc callees are included.
    fn lowered_direct_callees(
        &self,
        function_id: ids::ConcreteFunctionWithBodyId,
        dependency_type: DependencyType,
        stage: LoweringStage,
    ) -> Maybe<Vec<ids::FunctionId>>;

    /// Returns the set of direct callees which are functions with body of a concrete function with
    /// a body (i.e. excluding libfunc callees), at the given stage.
    fn lowered_direct_callees_with_body(
        &self,
        function_id: ids::ConcreteFunctionWithBodyId,
        dependency_type: DependencyType,
        stage: LoweringStage,
    ) -> Maybe<Vec<ids::ConcreteFunctionWithBodyId>>;

    /// Aggregates function level lowering diagnostics.
    fn function_with_body_lowering_diagnostics(
        &self,
        function_id: ids::FunctionWithBodyId,
    ) -> Maybe<Diagnostics<LoweringDiagnostic>>;
    /// Aggregates semantic function level lowering diagnostics - along with all its generated
    /// functions.
    fn semantic_function_with_body_lowering_diagnostics(
        &self,
        function_id: defs::ids::FunctionWithBodyId,
    ) -> Maybe<Diagnostics<LoweringDiagnostic>>;
    /// Aggregates module level lowering diagnostics.
    fn module_lowering_diagnostics(
        &self,
        module_id: ModuleId,
    ) -> Maybe<Diagnostics<LoweringDiagnostic>>;

    /// Aggregates file level lowering diagnostics.
    fn file_lowering_diagnostics(&self, file_id: FileId) -> Maybe<Diagnostics<LoweringDiagnostic>>;

    // ### Queries related to implicits ###

    /// Returns all the implicit parameters that the function requires (according to both its
    /// signature and the functions it calls). The items in the returned vector are unique and the
    /// order is consistent, but not necessarily related to the order of the explicit implicits in
    /// the signature of the function.
    #[salsa::invoke(crate::implicits::function_implicits)]
    fn function_implicits(&self, function: ids::FunctionId) -> Maybe<Vec<TypeId>>;

    /// Returns all the implicits used by a strongly connected component of functions.
    #[salsa::invoke(crate::implicits::scc_implicits)]
    fn scc_implicits(&self, function: ConcreteSCCRepresentative) -> Maybe<Vec<TypeId>>;

    // ### Queries related to panics ###

    /// Returns whether the function may panic.
    #[salsa::invoke(crate::panic::function_may_panic)]
    fn function_may_panic(&self, function: ids::FunctionId) -> Maybe<bool>;

    /// Returns whether any function in the strongly connected component may panic.
    #[salsa::invoke(crate::panic::scc_may_panic)]
    fn scc_may_panic(&self, scc: ConcreteSCCRepresentative) -> Maybe<bool>;

    /// Checks if the function has a block that ends with panic.
    #[salsa::invoke(crate::panic::has_direct_panic)]
    fn has_direct_panic(&self, function_id: ids::ConcreteFunctionWithBodyId) -> Maybe<bool>;

    // ### cycles ###

    /// Returns the set of direct callees of a function with a body.
    #[salsa::invoke(crate::graph_algorithms::cycles::function_with_body_direct_callees)]
    fn function_with_body_direct_callees(
        &self,
        function_id: ids::FunctionWithBodyId,
        dependency_type: DependencyType,
    ) -> Maybe<OrderedHashSet<ids::FunctionId>>;
    /// Returns the set of direct callees which are functions with body of a function with a body
    /// (i.e. excluding libfunc callees).
    #[salsa::invoke(
        crate::graph_algorithms::cycles::function_with_body_direct_function_with_body_callees
    )]
    fn function_with_body_direct_function_with_body_callees(
        &self,
        function_id: ids::FunctionWithBodyId,
        dependency_type: DependencyType,
    ) -> Maybe<OrderedHashSet<ids::FunctionWithBodyId>>;

    /// Returns `true` if the function (in its final lowering representation) calls (possibly
    /// indirectly) itself, or if it calls (possibly indirectly) such a function. For example, if f0
    /// calls f1, f1 calls f2, f2 calls f3, and f3 calls f2, then [Self::final_contains_call_cycle]
    /// will return `true` for all of these functions.
    #[salsa::invoke(crate::graph_algorithms::cycles::final_contains_call_cycle)]
    #[salsa::cycle(crate::graph_algorithms::cycles::final_contains_call_cycle_handle_cycle)]
    fn final_contains_call_cycle(
        &self,
        function_id: ids::ConcreteFunctionWithBodyId,
    ) -> Maybe<bool>;

    /// Returns `true` if the function calls (possibly indirectly) itself. For example, if f0 calls
    /// f1, f1 calls f2, f2 calls f3, and f3 calls f2, then [Self::in_cycle] will return
    /// `true` for f2 and f3, but false for f0 and f1.
    #[salsa::invoke(crate::graph_algorithms::cycles::in_cycle)]
    fn in_cycle(
        &self,
        function_id: ids::FunctionWithBodyId,
        dependency_type: DependencyType,
    ) -> Maybe<bool>;

    /// A concrete version of `in_cycle`.
    #[salsa::invoke(crate::graph_algorithms::cycles::concrete_in_cycle)]
    fn concrete_in_cycle(
        &self,
        function_id: ids::ConcreteFunctionWithBodyId,
        dependency_type: DependencyType,
        stage: LoweringStage,
    ) -> Maybe<bool>;

    // ### Strongly connected components ###

    /// Returns the representative of the concrete function's strongly connected component. The
    /// representative is consistently chosen for all the concrete functions in the same SCC.
    #[salsa::invoke(
        crate::graph_algorithms::strongly_connected_components::lowered_scc_representative
    )]
    fn lowered_scc_representative(
        &self,
        function: ids::ConcreteFunctionWithBodyId,
        dependency_type: DependencyType,
        stage: LoweringStage,
    ) -> ConcreteSCCRepresentative;

    /// Returns all the concrete functions in the same strongly connected component as the given
    /// concrete function.
    #[salsa::invoke(crate::graph_algorithms::strongly_connected_components::lowered_scc)]
    fn lowered_scc(
        &self,
        function_id: ids::ConcreteFunctionWithBodyId,
        dependency_type: DependencyType,
        stage: LoweringStage,
    ) -> Vec<ids::ConcreteFunctionWithBodyId>;

    /// Returns all the functions in the same strongly connected component as the given function.
    #[salsa::invoke(crate::scc::function_with_body_scc)]
    fn function_with_body_scc(
        &self,
        function_id: ids::FunctionWithBodyId,
        dependency_type: DependencyType,
    ) -> Vec<ids::FunctionWithBodyId>;

    // ### Feedback set ###

    /// Returns the feedback-vertex-set of the given concrete function. A feedback-vertex-set is the
    /// set of vertices whose removal leaves a graph without cycles.
    #[salsa::invoke(crate::graph_algorithms::feedback_set::function_with_body_feedback_set)]
    fn function_with_body_feedback_set(
        &self,
        function: ids::ConcreteFunctionWithBodyId,
        stage: LoweringStage,
    ) -> Maybe<OrderedHashSet<ids::ConcreteFunctionWithBodyId>>;

    /// Returns whether the given function needs an additional withdraw_gas call.
    #[salsa::invoke(crate::graph_algorithms::feedback_set::needs_withdraw_gas)]
    fn needs_withdraw_gas(&self, function: ids::ConcreteFunctionWithBodyId) -> Maybe<bool>;

    /// Returns the feedback-vertex-set of the given concrete-function SCC-representative. A
    /// feedback-vertex-set is the set of vertices whose removal leaves a graph without cycles.
    #[salsa::invoke(crate::graph_algorithms::feedback_set::priv_function_with_body_feedback_set_of_representative)]
    fn priv_function_with_body_feedback_set_of_representative(
        &self,
        function: ConcreteSCCRepresentative,
        stage: LoweringStage,
    ) -> Maybe<OrderedHashSet<ids::ConcreteFunctionWithBodyId>>;

    /// Internal query for reorder_statements to cache the function ids that can be moved.
    #[salsa::invoke(crate::optimizations::config::priv_movable_function_ids)]
    fn priv_movable_function_ids(&self) -> Arc<UnorderedHashSet<ExternFunctionId>>;

    /// Internal query for the libfuncs information required for const folding.
    #[salsa::invoke(crate::optimizations::const_folding::priv_const_folding_info)]
    fn priv_const_folding_info(
        &self,
    ) -> Arc<crate::optimizations::const_folding::ConstFoldingLibfuncInfo>;

    /// Internal query for a heuristic to decide if a given `function_id` should be inlined.
    #[salsa::invoke(crate::inline::priv_should_inline)]
    fn priv_should_inline(&self, function_id: ids::ConcreteFunctionWithBodyId) -> Maybe<bool>;

    /// Internal query for whether a function is marked as `#[inline(never)]`.
    #[salsa::invoke(crate::inline::priv_never_inline)]
    fn priv_never_inline(&self, function_id: ids::ConcreteFunctionWithBodyId) -> Maybe<bool>;

    /// Returns whether a function should be specialized.
    #[salsa::invoke(crate::specialization::priv_should_specialize)]
    fn priv_should_specialize(&self, function_id: ids::ConcreteFunctionWithBodyId) -> Maybe<bool>;

    /// Returns the configuration struct that controls the behavior of the optimization passes.
    #[salsa::input]
    fn optimization_config(&self) -> Arc<OptimizationConfig>;

    /// Returns the final optimization strategy that is applied on top of
    /// inlined_function_optimization_strategy.
    #[salsa::invoke(crate::optimizations::strategy::final_optimization_strategy)]
    fn final_optimization_strategy(&self) -> OptimizationStrategyId;

    /// Returns the baseline optimization strategy.
    /// This strategy is used for inlining decision and as a starting point for the final lowering.
    #[salsa::invoke(crate::optimizations::strategy::baseline_optimization_strategy)]
    fn baseline_optimization_strategy(&self) -> OptimizationStrategyId;

    /// Returns the expected size of a type.
    ///
    /// Panics if the type is not fully concrete (e.g. a generic parameter or inference variable).
    fn type_size(&self, ty: TypeId) -> usize;
}
337
338pub fn init_lowering_group(
339    db: &mut (dyn LoweringGroup + 'static),
340    inlining_strategy: InliningStrategy,
341) {
342    let mut moveable_functions: Vec<String> = chain!(
343        ["bool_not_impl"],
344        ["felt252_add", "felt252_sub", "felt252_mul", "felt252_div"],
345        ["array::array_new", "array::array_append"],
346        ["box::unbox", "box::box_forward_snapshot", "box::into_box"],
347    )
348    .map(|s| s.to_string())
349    .collect();
350
351    for ty in ["i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64"] {
352        moveable_functions.push(format!("integer::{ty}_wide_mul"));
353    }
354
355    db.set_optimization_config(Arc::new(
356        OptimizationConfig::default()
357            .with_moveable_functions(moveable_functions)
358            .with_inlining_strategy(inlining_strategy),
359    ));
360}
361
362#[derive(Debug, Eq, PartialEq, Clone, Hash)]
363pub struct GenericSCCRepresentative(pub ids::FunctionWithBodyId);
364
365#[derive(Debug, Eq, PartialEq, Clone, Hash)]
366pub struct ConcreteSCCRepresentative(pub ids::ConcreteFunctionWithBodyId);
367
368// *** Main lowering phases in order.
369
370fn priv_function_with_body_multi_lowering(
371    db: &dyn LoweringGroup,
372    function_id: defs::ids::FunctionWithBodyId,
373) -> Maybe<Arc<MultiLowering>> {
374    let crate_id = function_id.module_file_id(db).0.owning_crate(db);
375    if let Some(map) = db.cached_multi_lowerings(crate_id) {
376        if let Some(multi_lowering) = map.get(&function_id) {
377            return Ok(Arc::new(multi_lowering.clone()));
378        } else {
379            panic!("function not found in cached lowering {:?}", function_id.debug(db));
380        }
381    };
382
383    let multi_lowering = lower_semantic_function(db, function_id)?;
384    Ok(Arc::new(multi_lowering))
385}
386
387fn cached_multi_lowerings(
388    db: &dyn LoweringGroup,
389    crate_id: cairo_lang_filesystem::ids::CrateId,
390) -> Option<Arc<OrderedHashMap<defs::ids::FunctionWithBodyId, MultiLowering>>> {
391    load_cached_crate_functions(db, crate_id)
392}
393
// * Extracting per-function lowerings, then borrow checking.
395fn priv_function_with_body_lowering(
396    db: &dyn LoweringGroup,
397    function_id: ids::FunctionWithBodyId,
398) -> Maybe<Arc<Lowered>> {
399    let semantic_function_id = function_id.base_semantic_function(db);
400    let multi_lowering = db.priv_function_with_body_multi_lowering(semantic_function_id)?;
401    let lowered = match &function_id.lookup_intern(db) {
402        ids::FunctionWithBodyLongId::Semantic(_) => multi_lowering.main_lowering.clone(),
403        ids::FunctionWithBodyLongId::Generated { key, .. } => {
404            multi_lowering.generated_lowerings[key].clone()
405        }
406    };
407    Ok(Arc::new(lowered))
408}
409
410fn function_with_body_lowering_with_borrow_check(
411    db: &dyn LoweringGroup,
412    function_id: ids::FunctionWithBodyId,
413) -> Maybe<(Arc<Lowered>, Arc<PotentialDestructCalls>)> {
414    let lowered = db.priv_function_with_body_lowering(function_id)?;
415    let borrow_check_result =
416        borrow_check(db, function_id.to_concrete(db)?.is_panic_destruct_fn(db)?, &lowered);
417
418    let lowered = match borrow_check_result.diagnostics.check_error_free() {
419        Ok(_) => lowered,
420        Err(diag_added) => Arc::new(Lowered {
421            diagnostics: borrow_check_result.diagnostics,
422            signature: lowered.signature.clone(),
423            variables: lowered.variables.clone(),
424            blocks: Blocks::new_errored(diag_added),
425            parameters: lowered.parameters.clone(),
426        }),
427    };
428
429    Ok((lowered, Arc::new(borrow_check_result.block_extra_calls)))
430}
431
432fn function_with_body_lowering(
433    db: &dyn LoweringGroup,
434    function_id: ids::FunctionWithBodyId,
435) -> Maybe<Arc<Lowered>> {
436    Ok(db.function_with_body_lowering_with_borrow_check(function_id)?.0)
437}
438
439fn lowered_body(
440    db: &dyn LoweringGroup,
441    function: ids::ConcreteFunctionWithBodyId,
442    stage: LoweringStage,
443) -> Maybe<Arc<Lowered>> {
444    let lowered = match stage {
445        LoweringStage::Monomorphized => match function.generic_or_specialized(db) {
446            GenericOrSpecialized::Generic(generic_function_id) => {
447                db.function_with_body_lowering_diagnostics(generic_function_id)?
448                    .check_error_free()?;
449                let mut lowered = (*db.function_with_body_lowering(generic_function_id)?).clone();
450                concretize_lowered(db, &mut lowered, &function.substitution(db)?)?;
451                lowered
452            }
453            GenericOrSpecialized::Specialized(specialized) => {
454                specialized_function_lowered(db, specialized)?
455            }
456        },
457        LoweringStage::PreOptimizations => {
458            let mut lowered = (*db.lowered_body(function, LoweringStage::Monomorphized)?).clone();
459            add_withdraw_gas(db, function, &mut lowered)?;
460            lower_panics(db, function, &mut lowered)?;
461            add_destructs(db, function, &mut lowered);
462            scrub_units(db, &mut lowered);
463            lowered
464        }
465        LoweringStage::PostBaseline => {
466            let mut lowered =
467                (*db.lowered_body(function, LoweringStage::PreOptimizations)?).clone();
468            db.baseline_optimization_strategy().apply_strategy(db, function, &mut lowered)?;
469            lowered
470        }
471        LoweringStage::Final => {
472            let mut lowered = (*db.lowered_body(function, LoweringStage::PostBaseline)?).clone();
473            db.final_optimization_strategy().apply_strategy(db, function, &mut lowered)?;
474            lowered
475        }
476    };
477    Ok(Arc::new(lowered))
478}
479
480/// Given the lowering of a function, returns the set of direct dependencies of that function,
481/// according to the given [DependencyType]. See [DependencyType] for more information about
482/// what is considered a dependency.
483pub(crate) fn get_direct_callees(
484    db: &dyn LoweringGroup,
485    lowered_function: &Lowered,
486    dependency_type: DependencyType,
487    block_extra_calls: &UnorderedHashMap<BlockId, Vec<FunctionId>>,
488) -> Vec<ids::FunctionId> {
489    let mut direct_callees = Vec::new();
490    if lowered_function.blocks.is_empty() {
491        return direct_callees;
492    }
493    let withdraw_gas_fns =
494        corelib::core_withdraw_gas_fns(db).map(|id| FunctionLongId::Semantic(id).intern(db));
495    let mut visited = vec![false; lowered_function.blocks.len()];
496    let mut stack = vec![BlockId(0)];
497    while let Some(block_id) = stack.pop() {
498        if visited[block_id.0] {
499            continue;
500        }
501        visited[block_id.0] = true;
502        let block = &lowered_function.blocks[block_id];
503        for statement in &block.statements {
504            if let Statement::Call(statement_call) = statement {
505                // If the dependency_type is DependencyType::Cost and this call has a coupon input,
506                // then the call statement has a constant cost and therefore there
507                // is no cost dependency in the called function.
508                if dependency_type != DependencyType::Cost || !statement_call.with_coupon {
509                    direct_callees.push(statement_call.function);
510                }
511            }
512        }
513        if let Some(extra_calls) = block_extra_calls.get(&block_id) {
514            direct_callees.extend(extra_calls.iter().copied());
515        }
516        match &block.end {
517            BlockEnd::NotSet | BlockEnd::Return(..) | BlockEnd::Panic(_) => {}
518            BlockEnd::Goto(next, _) => stack.push(*next),
519            BlockEnd::Match { info } => {
520                let mut arms = info.arms().iter();
521                if let MatchInfo::Extern(s) = info {
522                    direct_callees.push(s.function);
523                    if DependencyType::Cost == dependency_type
524                        && withdraw_gas_fns.contains(&s.function)
525                    {
526                        // Not following the option when successfully fetched gas.
527                        arms.next();
528                    }
529                }
530                stack.extend(arms.map(|arm| arm.block_id));
531            }
532        }
533    }
534    direct_callees
535}
536
537/// Given a vector of FunctionIds returns the vector of FunctionWithBodyIds of the
538/// [ids::ConcreteFunctionWithBodyId]s.
539///
540/// If `dependency_type` is `DependencyType::Cost`, returns the coupon functions when
541/// `coupon_buy` and `coupon_refund` are encountered.
542/// For example, for `coupon_buy::<foo::Coupon>()`, `foo` will be added to the list.
543fn functions_with_body_from_function_ids(
544    db: &dyn LoweringGroup,
545    function_ids: Vec<ids::FunctionId>,
546    dependency_type: DependencyType,
547) -> Maybe<Vec<ids::ConcreteFunctionWithBodyId>> {
548    Ok(function_ids
549        .into_iter()
550        .map(|concrete| {
551            if dependency_type == DependencyType::Cost {
552                if let Some(function_with_body) = extract_coupon_function(db, concrete)? {
553                    return Ok(Some(function_with_body));
554                }
555            }
556            concrete.body(db)
557        })
558        .collect::<Maybe<Vec<_>>>()?
559        .into_iter()
560        .flatten()
561        .collect_vec())
562}
563
564/// Given a [ids::FunctionId] that represents `coupon_buy` or `coupon_refund`, returns the coupon's
565/// function.
566///
567/// For example, `coupon_buy::<foo::Coupon>` will return `foo`.
568fn extract_coupon_function(
569    db: &dyn LoweringGroup,
570    concrete: ids::FunctionId,
571) -> Maybe<Option<ids::ConcreteFunctionWithBodyId>> {
572    // Check that the function is a semantic function.
573    let ids::FunctionLongId::Semantic(function_id) = concrete.lookup_intern(db) else {
574        return Ok(None);
575    };
576
577    // Check that it's an extern function named "coupon_buy" or "coupon_refund".
578    let concrete_function = function_id.get_concrete(db);
579    let generic_function = concrete_function.generic_function;
580    let semantic::items::functions::GenericFunctionId::Extern(extern_function_id) =
581        generic_function
582    else {
583        return Ok(None);
584    };
585    let name = extern_function_id.lookup_intern(db).name(db);
586    if !(name == "coupon_buy" || name == "coupon_refund") {
587        return Ok(None);
588    }
589
590    // Extract the coupon function from the generic argument.
591    let [semantic::GenericArgumentId::Type(type_id)] = concrete_function.generic_args[..] else {
592        panic!("Unexpected generic_args for coupon_buy().");
593    };
594    let semantic::TypeLongId::Coupon(coupon_function) = type_id.lookup_intern(db) else {
595        panic!("Unexpected generic_args for coupon_buy().");
596    };
597
598    // Convert [semantic::FunctionId] to [ids::ConcreteFunctionWithBodyId].
599    let Some(coupon_function_with_body_id) = coupon_function.get_concrete(db).body(db)? else {
600        panic!("Unexpected generic_args for coupon_buy().");
601    };
602
603    Ok(Some(ids::ConcreteFunctionWithBodyId::from_semantic(db, coupon_function_with_body_id)))
604}
605
606fn lowered_direct_callees(
607    db: &dyn LoweringGroup,
608    function_id: ids::ConcreteFunctionWithBodyId,
609    dependency_type: DependencyType,
610    stage: LoweringStage,
611) -> Maybe<Vec<ids::FunctionId>> {
612    let lowered_function = db.lowered_body(function_id, stage)?;
613    Ok(get_direct_callees(db, &lowered_function, dependency_type, &Default::default()))
614}
615
616fn lowered_direct_callees_with_body(
617    db: &dyn LoweringGroup,
618    function_id: ids::ConcreteFunctionWithBodyId,
619    dependency_type: DependencyType,
620    stage: LoweringStage,
621) -> Maybe<Vec<ids::ConcreteFunctionWithBodyId>> {
622    functions_with_body_from_function_ids(
623        db,
624        db.lowered_direct_callees(function_id, dependency_type, stage)?,
625        dependency_type,
626    )
627}
628
629fn function_with_body_lowering_diagnostics(
630    db: &dyn LoweringGroup,
631    function_id: ids::FunctionWithBodyId,
632) -> Maybe<Diagnostics<LoweringDiagnostic>> {
633    let mut diagnostics = DiagnosticsBuilder::default();
634
635    if let Ok(lowered) = db.function_with_body_lowering(function_id) {
636        diagnostics.extend(lowered.diagnostics.clone());
637        if flag_add_withdraw_gas(db) && db.in_cycle(function_id, DependencyType::Cost)? {
638            let location =
639                Location::new(function_id.base_semantic_function(db).stable_location(db));
640            if !lowered.signature.panicable {
641                diagnostics.add(LoweringDiagnostic {
642                    location: location.clone(),
643                    kind: LoweringDiagnosticKind::NoPanicFunctionCycle,
644                });
645            }
646            borrow_check_possible_withdraw_gas(db, location.intern(db), &lowered, &mut diagnostics)
647        }
648    }
649
650    if let Ok(diag) = get_inline_diagnostics(db, function_id) {
651        diagnostics.extend(diag);
652    }
653
654    Ok(diagnostics.build())
655}
656
657fn semantic_function_with_body_lowering_diagnostics(
658    db: &dyn LoweringGroup,
659    semantic_function_id: defs::ids::FunctionWithBodyId,
660) -> Maybe<Diagnostics<LoweringDiagnostic>> {
661    let mut diagnostics = DiagnosticsBuilder::default();
662
663    if let Ok(multi_lowering) = db.priv_function_with_body_multi_lowering(semantic_function_id) {
664        let function_id = ids::FunctionWithBodyLongId::Semantic(semantic_function_id).intern(db);
665        diagnostics
666            .extend(db.function_with_body_lowering_diagnostics(function_id).unwrap_or_default());
667        for (key, _) in multi_lowering.generated_lowerings.iter() {
668            let function_id =
669                ids::FunctionWithBodyLongId::Generated { parent: semantic_function_id, key: *key }
670                    .intern(db);
671            diagnostics.extend(
672                db.function_with_body_lowering_diagnostics(function_id).unwrap_or_default(),
673            );
674        }
675    }
676
677    Ok(diagnostics.build())
678}
679
680fn module_lowering_diagnostics(
681    db: &dyn LoweringGroup,
682    module_id: ModuleId,
683) -> Maybe<Diagnostics<LoweringDiagnostic>> {
684    let mut diagnostics = DiagnosticsBuilder::default();
685    for item in db.module_items(module_id)?.iter() {
686        match item {
687            ModuleItemId::FreeFunction(free_function) => {
688                let function_id = defs::ids::FunctionWithBodyId::Free(*free_function);
689                diagnostics
690                    .extend(db.semantic_function_with_body_lowering_diagnostics(function_id)?);
691            }
692            ModuleItemId::Constant(_) => {}
693            ModuleItemId::Submodule(_) => {}
694            ModuleItemId::Use(_) => {}
695            ModuleItemId::Struct(_) => {}
696            ModuleItemId::Enum(_) => {}
697            ModuleItemId::TypeAlias(_) => {}
698            ModuleItemId::ImplAlias(_) => {}
699            ModuleItemId::Trait(trait_id) => {
700                for trait_func in db.trait_functions(*trait_id)?.values() {
701                    if matches!(db.trait_function_body(*trait_func), Ok(Some(_))) {
702                        let function_id = defs::ids::FunctionWithBodyId::Trait(*trait_func);
703                        diagnostics.extend(
704                            db.semantic_function_with_body_lowering_diagnostics(function_id)?,
705                        );
706                    }
707                }
708            }
709            ModuleItemId::Impl(impl_def_id) => {
710                for impl_func in db.impl_functions(*impl_def_id)?.values() {
711                    let function_id = defs::ids::FunctionWithBodyId::Impl(*impl_func);
712                    diagnostics
713                        .extend(db.semantic_function_with_body_lowering_diagnostics(function_id)?);
714                }
715            }
716            ModuleItemId::ExternType(_) => {}
717            ModuleItemId::ExternFunction(_) => {}
718            ModuleItemId::MacroDeclaration(_) => {}
719        }
720    }
721    Ok(diagnostics.build())
722}
723
724fn file_lowering_diagnostics(
725    db: &dyn LoweringGroup,
726    file_id: FileId,
727) -> Maybe<Diagnostics<LoweringDiagnostic>> {
728    let mut diagnostics = DiagnosticsBuilder::default();
729    for module_id in db.file_modules(file_id)?.iter().copied() {
730        if let Ok(module_diagnostics) = db.module_lowering_diagnostics(module_id) {
731            diagnostics.extend(module_diagnostics)
732        }
733    }
734    Ok(diagnostics.build())
735}
736
737fn type_size(db: &dyn LoweringGroup, ty: TypeId) -> usize {
738    match ty.lookup_intern(db) {
739        TypeLongId::Concrete(concrete_type_id) => match concrete_type_id {
740            ConcreteTypeId::Struct(struct_id) => db
741                .concrete_struct_members(struct_id)
742                .unwrap()
743                .iter()
744                .map(|(_, member)| db.type_size(member.ty))
745                .sum::<usize>(),
746            ConcreteTypeId::Enum(enum_id) => {
747                1 + db
748                    .concrete_enum_variants(enum_id)
749                    .unwrap()
750                    .into_iter()
751                    .map(|variant| db.type_size(variant.ty))
752                    .max()
753                    .unwrap_or_default()
754            }
755            ConcreteTypeId::Extern(extern_id) => {
756                match extern_id.extern_type_id(db).name(db).as_str() {
757                    "Array" | "SquashedFelt252Dict" | "EcPoint" => 2,
758                    "EcState" => 3,
759                    "Uint128MulGuarantee" => 4,
760                    _ => 1,
761                }
762            }
763        },
764        TypeLongId::Tuple(types) => types.into_iter().map(|ty| db.type_size(ty)).sum::<usize>(),
765        TypeLongId::Snapshot(ty) => db.type_size(ty),
766        TypeLongId::FixedSizeArray { type_id, size } => {
767            db.type_size(type_id)
768                * size
769                    .lookup_intern(db)
770                    .into_int()
771                    .expect("Expected ConstValue::Int for size")
772                    .to_usize()
773                    .unwrap()
774        }
775        TypeLongId::Closure(closure_ty) => {
776            closure_ty.captured_types.iter().map(|ty| db.type_size(*ty)).sum()
777        }
778        TypeLongId::Coupon(_) => 0,
779        TypeLongId::GenericParameter(_)
780        | TypeLongId::Var(_)
781        | TypeLongId::ImplType(_)
782        | TypeLongId::Missing(_) => {
783            panic!("Function should only be called with fully concrete types")
784        }
785    }
786}