cairo_lang_lowering/
specialization.rs

1use std::vec;
2
3use cairo_lang_debug::DebugWithDb;
4use cairo_lang_diagnostics::Maybe;
5use cairo_lang_semantic::helper::ModuleHelper;
6use cairo_lang_semantic::items::constant::ConstValueId;
7use cairo_lang_semantic::items::functions::GenericFunctionId;
8use cairo_lang_semantic::items::structure::StructSemantic;
9use cairo_lang_semantic::{ConcreteTypeId, GenericArgumentId, TypeId, TypeLongId};
10use cairo_lang_utils::extract_matches;
11use itertools::{Itertools, chain, zip_eq};
12use salsa::Database;
13
14use crate::blocks::BlocksBuilder;
15use crate::db::LoweringGroup;
16use crate::ids::{self, LocationId, SemanticFunctionIdEx, SpecializedFunction};
17use crate::lower::context::{VarRequest, VariableAllocator};
18use crate::{
19    Block, BlockEnd, DependencyType, Lowered, LoweringStage, Statement, StatementCall,
20    StatementConst, StatementSnapshot, StatementStructConstruct, VarUsage, VariableId,
21};
22
23// A const argument for a specialized function.
24#[derive(Clone, Debug, Hash, PartialEq, Eq)]
25pub enum SpecializationArg<'db> {
26    Const { value: ConstValueId<'db>, boxed: bool },
27    Snapshot(Box<SpecializationArg<'db>>),
28    Array(TypeId<'db>, Vec<SpecializationArg<'db>>),
29    Struct(Vec<SpecializationArg<'db>>),
30}
31
32impl<'a> DebugWithDb<'a> for SpecializationArg<'a> {
33    type Db = dyn Database;
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &'a dyn Database) -> std::fmt::Result {
35        match self {
36            SpecializationArg::Const { value, boxed } => {
37                write!(f, "{:?}", value.debug(db))?;
38                if *boxed {
39                    write!(f, ".into_box()")?;
40                }
41                Ok(())
42            }
43            SpecializationArg::Snapshot(inner) => write!(f, "@{:?}", inner.debug(db)),
44            SpecializationArg::Struct(inner) => {
45                write!(f, "{{")?;
46                let mut inner = inner.iter().peekable();
47                while let Some(value) = inner.next() {
48                    write!(f, " ")?;
49                    value.fmt(f, db)?;
50
51                    if inner.peek().is_some() {
52                        write!(f, ",")?;
53                    } else {
54                        write!(f, " ")?;
55                    }
56                }
57                write!(f, "}}")
58            }
59            SpecializationArg::Array(_ty, values) => {
60                write!(f, "array![")?;
61                let mut first = true;
62                for value in values {
63                    if !first {
64                        write!(f, ", ")?;
65                    } else {
66                        first = false;
67                    }
68                    write!(f, "{:?}", value.debug(db))?;
69                }
70                write!(f, "]")
71            }
72        }
73    }
74}
75
76/// The state of the specialization arg building process.
77/// currently only structs require an additional build step.
78enum SpecializationArgBuildingState<'db, 'a> {
79    Initial(&'a SpecializationArg<'db>),
80    TakeSnapshot(VariableId),
81    BuildStruct(Vec<VariableId>),
82    PushBackArray { in_array: VariableId, value: VariableId },
83}
84
85/// Returns the lowering of a specialized function.
86pub fn specialized_function_lowered<'db>(
87    db: &'db dyn Database,
88    specialized: SpecializedFunction<'db>,
89) -> Maybe<Lowered<'db>> {
90    let base = db.lowered_body(specialized.base, LoweringStage::Monomorphized)?;
91    let base_semantic = specialized.base.base_semantic_function(db);
92
93    let array_module = ModuleHelper::core(db).submodule("array");
94    let array_new_fn = GenericFunctionId::Extern(array_module.extern_function_id("array_new"));
95    let array_append = GenericFunctionId::Extern(array_module.extern_function_id("array_append"));
96
97    let mut variables =
98        VariableAllocator::new(db, base_semantic.function_with_body_id(db), Default::default())?;
99    let mut statements = vec![];
100    let mut parameters = vec![];
101    let mut inputs = vec![];
102    let mut stack = vec![];
103
104    let location = LocationId::from_stable_location(
105        db,
106        specialized.base.base_semantic_function(db).stable_location(db),
107    );
108
109    for (param, arg) in zip_eq(&base.parameters, specialized.args.iter()) {
110        let var_id = variables.variables.alloc(base.variables[*param].clone());
111        inputs.push(VarUsage { var_id, location });
112        if let Some(c) = arg {
113            stack.push((var_id, SpecializationArgBuildingState::Initial(c)));
114            continue;
115        }
116        parameters.push(var_id);
117    }
118
119    while let Some((var_id, state)) = stack.pop() {
120        match state {
121            SpecializationArgBuildingState::Initial(c) => match c {
122                SpecializationArg::Const { value, boxed } => {
123                    statements.push(Statement::Const(StatementConst::new(*value, var_id, *boxed)));
124                }
125                SpecializationArg::Snapshot(inner) => {
126                    let snap_ty = variables.variables[var_id].ty;
127                    let denapped_ty = *extract_matches!(snap_ty.long(db), TypeLongId::Snapshot);
128                    let desnapped_var = variables.new_var(VarRequest { ty: denapped_ty, location });
129                    stack.push((
130                        var_id,
131                        SpecializationArgBuildingState::TakeSnapshot(desnapped_var),
132                    ));
133                    stack.push((
134                        desnapped_var,
135                        SpecializationArgBuildingState::Initial(inner.as_ref()),
136                    ));
137                }
138                SpecializationArg::Array(ty, values) => {
139                    let mut arr_var = var_id;
140                    for value in values.iter().rev() {
141                        let in_arr_var =
142                            variables.variables.alloc(variables.variables[var_id].clone());
143                        let value_var = variables.new_var(VarRequest { ty: *ty, location });
144                        stack.push((
145                            arr_var,
146                            SpecializationArgBuildingState::PushBackArray {
147                                in_array: in_arr_var,
148                                value: value_var,
149                            },
150                        ));
151                        stack.push((value_var, SpecializationArgBuildingState::Initial(value)));
152                        arr_var = in_arr_var;
153                    }
154                    statements.push(Statement::Call(StatementCall {
155                        function: array_new_fn
156                            .concretize(db, vec![GenericArgumentId::Type(*ty)])
157                            .lowered(db),
158                        inputs: vec![],
159                        with_coupon: false,
160                        outputs: vec![arr_var],
161                        location: variables[var_id].location,
162                    }));
163                }
164                SpecializationArg::Struct(consts) => {
165                    let var = &variables[var_id];
166                    let TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct)) =
167                        var.ty.long(db)
168                    else {
169                        unreachable!("Expected a concrete struct type");
170                    };
171
172                    let members = db.concrete_struct_members(*concrete_struct)?;
173
174                    let location = var.location;
175                    let var_ids = members
176                        .values()
177                        .map(|member| variables.new_var(VarRequest { ty: member.ty, location }))
178                        .collect_vec();
179
180                    stack.push((
181                        var_id,
182                        SpecializationArgBuildingState::BuildStruct(var_ids.clone()),
183                    ));
184
185                    for (var_id, c) in zip_eq(var_ids, consts) {
186                        stack.push((var_id, SpecializationArgBuildingState::Initial(c)));
187                    }
188                }
189            },
190            SpecializationArgBuildingState::TakeSnapshot(desnapped_var) => {
191                let ignored = variables.variables.alloc(variables[desnapped_var].clone());
192                statements.push(Statement::Snapshot(StatementSnapshot::new(
193                    VarUsage { var_id: desnapped_var, location },
194                    ignored,
195                    var_id,
196                )));
197            }
198            SpecializationArgBuildingState::PushBackArray { in_array, value } => {
199                statements.push(Statement::Call(StatementCall {
200                    function: array_append
201                        .concretize(
202                            db,
203                            vec![GenericArgumentId::Type(variables.variables[value].ty)],
204                        )
205                        .lowered(db),
206                    inputs: vec![
207                        VarUsage { var_id: in_array, location },
208                        VarUsage { var_id: value, location },
209                    ],
210                    with_coupon: false,
211                    outputs: vec![var_id],
212                    location,
213                }));
214            }
215            SpecializationArgBuildingState::BuildStruct(ids) => {
216                statements.push(Statement::StructConstruct(StatementStructConstruct {
217                    inputs: ids
218                        .iter()
219                        .map(|id| VarUsage { var_id: *id, location: variables[*id].location })
220                        .collect(),
221                    output: var_id,
222                }));
223            }
224        }
225    }
226
227    let outputs: Vec<VariableId> =
228        chain!(base.signature.extra_rets.iter().map(|ret| ret.ty()), [base.signature.return_type])
229            .map(|ty| variables.new_var(VarRequest { ty, location }))
230            .collect_vec();
231    let mut block_builder = BlocksBuilder::new();
232    let ret_usage =
233        outputs.iter().map(|var_id| VarUsage { var_id: *var_id, location }).collect_vec();
234    statements.push(Statement::Call(StatementCall {
235        function: specialized.base.function_id(db)?,
236        with_coupon: false,
237        inputs,
238        outputs,
239        location,
240    }));
241    block_builder.alloc(Block { statements, end: BlockEnd::Return(ret_usage, location) });
242    Ok(Lowered {
243        signature: specialized.signature(db)?,
244        variables: variables.variables,
245        blocks: block_builder.build().unwrap(),
246        parameters,
247        diagnostics: Default::default(),
248    })
249}
250
251/// Query implementation of [LoweringGroup::priv_should_specialize].
252#[salsa::tracked]
253pub fn priv_should_specialize<'db>(
254    db: &'db dyn Database,
255    function_id: ids::ConcreteFunctionWithBodyId<'db>,
256) -> Maybe<bool> {
257    let ids::ConcreteFunctionWithBodyLongId::Specialized(SpecializedFunction { base, .. }) =
258        function_id.long(db)
259    else {
260        panic!("Expected a specialized function");
261    };
262
263    // Breaks cycles.
264    // We cannot estimate the size of functions in a cycle, since the implicits computation requires
265    // the finalized lowering of all the functions in the cycle which requires us to know the
266    // answer of the current function.
267    if db.concrete_in_cycle(*base, DependencyType::Call, LoweringStage::Monomorphized)? {
268        return Ok(false);
269    }
270
271    // The heuristic is that the size is 8/10*orig_size > specialized_size of the original size.
272    Ok(db.estimate_size(*base)?.saturating_mul(8)
273        > db.estimate_size(function_id)?.saturating_mul(10))
274}