cairo_lang_lowering/
specialization.rs

1use std::vec;
2
3use cairo_lang_debug::DebugWithDb;
4use cairo_lang_diagnostics::Maybe;
5use cairo_lang_proc_macros::HeapSize;
6use cairo_lang_semantic::helper::ModuleHelper;
7use cairo_lang_semantic::items::constant::ConstValueId;
8use cairo_lang_semantic::items::functions::GenericFunctionId;
9use cairo_lang_semantic::items::structure::StructSemantic;
10use cairo_lang_semantic::types::{TypeSizeInformation, TypesSemantic};
11use cairo_lang_semantic::{ConcreteTypeId, ConcreteVariant, GenericArgumentId, TypeId, TypeLongId};
12use cairo_lang_utils::extract_matches;
13use itertools::{Itertools, chain, zip_eq};
14use salsa::Database;
15
16use crate::blocks::BlocksBuilder;
17use crate::db::LoweringGroup;
18use crate::ids::{self, LocationId, SemanticFunctionIdEx, SpecializedFunction};
19use crate::lower::context::{VarRequest, VariableAllocator};
20use crate::objects::StatementEnumConstruct as StatementEnumConstructObj;
21use crate::{
22    Block, BlockEnd, Lowered, LoweringStage, Statement, StatementCall, StatementConst,
23    StatementSnapshot, StatementStructConstruct, VarUsage, VariableId,
24};
25
26// A const argument for a specialized function.
27#[derive(Clone, Debug, Hash, PartialEq, Eq, HeapSize)]
28pub enum SpecializationArg<'db> {
29    Const { value: ConstValueId<'db>, boxed: bool },
30    Snapshot(Box<SpecializationArg<'db>>),
31    Array(TypeId<'db>, Vec<SpecializationArg<'db>>),
32    Struct(Vec<SpecializationArg<'db>>),
33    Enum { variant: ConcreteVariant<'db>, payload: Box<SpecializationArg<'db>> },
34    NotSpecialized,
35}
36
37impl<'a> DebugWithDb<'a> for SpecializationArg<'a> {
38    type Db = dyn Database;
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &'a dyn Database) -> std::fmt::Result {
40        match self {
41            SpecializationArg::Const { value, boxed } => {
42                write!(f, "{:?}", value.debug(db))?;
43                if *boxed {
44                    write!(f, ".into_box()")?;
45                }
46                Ok(())
47            }
48            SpecializationArg::Snapshot(inner) => write!(f, "@{:?}", inner.debug(db)),
49            SpecializationArg::Struct(args) => {
50                write!(f, "{{")?;
51                let mut inner = args.iter().peekable();
52                while let Some(value) = inner.next() {
53                    write!(f, " ")?;
54                    value.fmt(f, db)?;
55
56                    if inner.peek().is_some() {
57                        write!(f, ",")?;
58                    } else {
59                        write!(f, " ")?;
60                    }
61                }
62                write!(f, "}}")
63            }
64            SpecializationArg::Array(_ty, values) => {
65                write!(f, "array![")?;
66                let mut first = true;
67                for value in values {
68                    if !first {
69                        write!(f, ", ")?;
70                    } else {
71                        first = false;
72                    }
73                    write!(f, "{:?}", value.debug(db))?;
74                }
75                write!(f, "]")
76            }
77            SpecializationArg::Enum { variant, payload } => {
78                write!(f, "{:?}(", variant.debug(db))?;
79                payload.fmt(f, db)?;
80                write!(f, ")")
81            }
82            SpecializationArg::NotSpecialized => write!(f, "NotSpecialized"),
83        }
84    }
85}
86
87/// The state of the specialization arg building process.
88/// currently only structs require an additional build step.
89enum SpecializationArgBuildingState<'db, 'a> {
90    Initial(&'a SpecializationArg<'db>),
91    TakeSnapshot(VariableId),
92    BuildStruct(Vec<VariableId>),
93    PushBackArray { in_array: VariableId, value: VariableId },
94    BuildEnum { variant: ConcreteVariant<'db>, payload: VariableId },
95}
96
97/// Returns the lowering of a specialized function.
98pub fn specialized_function_lowered<'db>(
99    db: &'db dyn Database,
100    specialized: SpecializedFunction<'db>,
101) -> Maybe<Lowered<'db>> {
102    let base = db.lowered_body(specialized.base, LoweringStage::Monomorphized)?;
103    let base_semantic = specialized.base.base_semantic_function(db);
104
105    let array_module = ModuleHelper::core(db).submodule("array");
106    let array_new_fn = GenericFunctionId::Extern(array_module.extern_function_id("array_new"));
107    let array_append = GenericFunctionId::Extern(array_module.extern_function_id("array_append"));
108
109    let mut variables =
110        VariableAllocator::new(db, base_semantic.function_with_body_id(db), Default::default())?;
111    let mut statements = vec![];
112    let mut parameters = vec![];
113    let mut inputs = vec![];
114    let mut stack = vec![];
115
116    let location = LocationId::from_stable_location(
117        db,
118        specialized.base.base_semantic_function(db).stable_location(db),
119    );
120
121    for (param, arg) in zip_eq(&base.parameters, specialized.args.iter()) {
122        let var_id = variables.variables.alloc(base.variables[*param].clone());
123        inputs.push(VarUsage { var_id, location });
124        if SpecializationArg::NotSpecialized == *arg {
125            parameters.push(var_id);
126            continue;
127        }
128        stack.push((var_id, SpecializationArgBuildingState::Initial(arg)));
129        while let Some((var_id, state)) = stack.pop() {
130            match state {
131                SpecializationArgBuildingState::Initial(c) => match c {
132                    SpecializationArg::Const { value, boxed } => {
133                        if db.type_size_info(variables[var_id].ty)?
134                            == TypeSizeInformation::ZeroSized
135                        {
136                            assert!(
137                                !boxed,
138                                "Zero sized specialization arguments should only be part of \
139                                 consts and therefore cannot be boxed"
140                            );
141                            statements.push(Statement::StructConstruct(StatementStructConstruct {
142                                inputs: vec![],
143                                output: var_id,
144                            }));
145                        } else {
146                            statements.push(Statement::Const(StatementConst::new(
147                                *value, var_id, *boxed,
148                            )));
149                        }
150                    }
151                    SpecializationArg::Snapshot(inner) => {
152                        let snap_ty = variables.variables[var_id].ty;
153                        let denapped_ty = *extract_matches!(snap_ty.long(db), TypeLongId::Snapshot);
154                        let desnapped_var =
155                            variables.new_var(VarRequest { ty: denapped_ty, location });
156                        stack.push((
157                            var_id,
158                            SpecializationArgBuildingState::TakeSnapshot(desnapped_var),
159                        ));
160                        stack.push((
161                            desnapped_var,
162                            SpecializationArgBuildingState::Initial(inner.as_ref()),
163                        ));
164                    }
165                    SpecializationArg::Array(ty, values) => {
166                        let mut arr_var = var_id;
167                        for value in values.iter().rev() {
168                            let in_arr_var =
169                                variables.variables.alloc(variables.variables[var_id].clone());
170                            let value_var = variables.new_var(VarRequest { ty: *ty, location });
171                            stack.push((
172                                arr_var,
173                                SpecializationArgBuildingState::PushBackArray {
174                                    in_array: in_arr_var,
175                                    value: value_var,
176                                },
177                            ));
178                            stack.push((value_var, SpecializationArgBuildingState::Initial(value)));
179                            arr_var = in_arr_var;
180                        }
181                        statements.push(Statement::Call(StatementCall {
182                            function: array_new_fn
183                                .concretize(db, vec![GenericArgumentId::Type(*ty)])
184                                .lowered(db),
185                            inputs: vec![],
186                            with_coupon: false,
187                            outputs: vec![arr_var],
188                            location: variables[var_id].location,
189                            is_specialization_base_call: false,
190                        }));
191                    }
192                    SpecializationArg::Struct(args) => {
193                        let var = &variables[var_id];
194                        let TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct)) =
195                            var.ty.long(db)
196                        else {
197                            unreachable!("Expected a concrete struct type");
198                        };
199
200                        let members = db.concrete_struct_members(*concrete_struct)?;
201
202                        let location = var.location;
203                        let var_ids = members
204                            .values()
205                            .map(|member| variables.new_var(VarRequest { ty: member.ty, location }))
206                            .collect_vec();
207
208                        stack.push((
209                            var_id,
210                            SpecializationArgBuildingState::BuildStruct(var_ids.clone()),
211                        ));
212
213                        for (var_id, arg) in zip_eq(var_ids.iter().rev(), args.iter().rev()) {
214                            stack.push((*var_id, SpecializationArgBuildingState::Initial(arg)));
215                        }
216                    }
217                    SpecializationArg::Enum { variant, payload } => {
218                        let location = variables[var_id].location;
219                        let payload_var =
220                            variables.new_var(VarRequest { ty: variant.ty, location });
221                        stack.push((
222                            var_id,
223                            SpecializationArgBuildingState::BuildEnum {
224                                variant: *variant,
225                                payload: payload_var,
226                            },
227                        ));
228                        stack.push((
229                            payload_var,
230                            SpecializationArgBuildingState::Initial(payload.as_ref()),
231                        ));
232                    }
233                    SpecializationArg::NotSpecialized => {
234                        parameters.push(var_id);
235                    }
236                },
237                SpecializationArgBuildingState::TakeSnapshot(desnapped_var) => {
238                    let ignored = variables.variables.alloc(variables[desnapped_var].clone());
239                    statements.push(Statement::Snapshot(StatementSnapshot::new(
240                        VarUsage { var_id: desnapped_var, location },
241                        ignored,
242                        var_id,
243                    )));
244                }
245                SpecializationArgBuildingState::PushBackArray { in_array, value } => {
246                    statements.push(Statement::Call(StatementCall {
247                        function: array_append
248                            .concretize(
249                                db,
250                                vec![GenericArgumentId::Type(variables.variables[value].ty)],
251                            )
252                            .lowered(db),
253                        inputs: vec![
254                            VarUsage { var_id: in_array, location },
255                            VarUsage { var_id: value, location },
256                        ],
257                        with_coupon: false,
258                        outputs: vec![var_id],
259                        location,
260                        is_specialization_base_call: false,
261                    }));
262                }
263                SpecializationArgBuildingState::BuildStruct(ids) => {
264                    statements.push(Statement::StructConstruct(StatementStructConstruct {
265                        inputs: ids
266                            .iter()
267                            .map(|id| VarUsage { var_id: *id, location: variables[*id].location })
268                            .collect(),
269                        output: var_id,
270                    }));
271                }
272                SpecializationArgBuildingState::BuildEnum { variant, payload } => {
273                    statements.push(Statement::EnumConstruct(StatementEnumConstructObj {
274                        variant,
275                        input: VarUsage { var_id: payload, location: variables[payload].location },
276                        output: var_id,
277                    }));
278                }
279            }
280        }
281    }
282
283    let outputs: Vec<VariableId> =
284        chain!(base.signature.extra_rets.iter().map(|ret| ret.ty()), [base.signature.return_type])
285            .map(|ty| variables.new_var(VarRequest { ty, location }))
286            .collect_vec();
287    let mut block_builder = BlocksBuilder::new();
288    let ret_usage =
289        outputs.iter().map(|var_id| VarUsage { var_id: *var_id, location }).collect_vec();
290    statements.push(Statement::Call(StatementCall {
291        function: specialized.base.function_id(db)?,
292        with_coupon: false,
293        inputs,
294        outputs,
295        location,
296        is_specialization_base_call: true,
297    }));
298    block_builder.alloc(Block { statements, end: BlockEnd::Return(ret_usage, location) });
299    Ok(Lowered {
300        signature: specialized.signature(db)?,
301        variables: variables.variables,
302        blocks: block_builder.build().unwrap(),
303        parameters,
304        diagnostics: Default::default(),
305    })
306}
307
308/// Query implementation of [LoweringGroup::priv_should_specialize].
309#[salsa::tracked]
310pub fn priv_should_specialize<'db>(
311    db: &'db dyn Database,
312    function_id: ids::ConcreteFunctionWithBodyId<'db>,
313) -> Maybe<bool> {
314    let ids::ConcreteFunctionWithBodyLongId::Specialized(SpecializedFunction { base, .. }) =
315        function_id.long(db)
316    else {
317        panic!("Expected a specialized function");
318    };
319
320    // The heuristic is that the size is 8/10*orig_size > specialized_size of the original size.
321    Ok(db.estimate_size(*base)?.saturating_mul(8)
322        > db.estimate_size(function_id)?.saturating_mul(10))
323}