cairo_lang_lowering/
specialization.rs

1use std::vec;
2
3use cairo_lang_debug::DebugWithDb;
4use cairo_lang_diagnostics::Maybe;
5use cairo_lang_proc_macros::HeapSize;
6use cairo_lang_semantic::helper::ModuleHelper;
7use cairo_lang_semantic::items::constant::ConstValueId;
8use cairo_lang_semantic::items::functions::GenericFunctionId;
9use cairo_lang_semantic::items::structure::StructSemantic;
10use cairo_lang_semantic::types::{TypeSizeInformation, TypesSemantic};
11use cairo_lang_semantic::{ConcreteTypeId, ConcreteVariant, GenericArgumentId, TypeId, TypeLongId};
12use cairo_lang_utils::extract_matches;
13use itertools::{Itertools, chain, zip_eq};
14use salsa::Database;
15
16use crate::blocks::BlocksBuilder;
17use crate::db::LoweringGroup;
18use crate::ids::{self, LocationId, SemanticFunctionIdEx, SpecializedFunction};
19use crate::lower::context::{VarRequest, VariableAllocator};
20use crate::objects::StatementEnumConstruct as StatementEnumConstructObj;
21use crate::{
22    Block, BlockEnd, Lowered, LoweringStage, Statement, StatementCall, StatementConst,
23    StatementSnapshot, StatementStructConstruct, VarUsage, VariableId,
24};
25
26// A const argument for a specialized function.
27#[derive(Clone, Debug, Hash, PartialEq, Eq, HeapSize)]
28pub enum SpecializationArg<'db> {
29    Const {
30        value: ConstValueId<'db>,
31        boxed: bool,
32    },
33    Snapshot(Box<SpecializationArg<'db>>),
34    Array(TypeId<'db>, Vec<SpecializationArg<'db>>),
35    /// Represents struct, tuple, or fixed-size array.
36    Struct(Vec<SpecializationArg<'db>>),
37    Enum {
38        variant: ConcreteVariant<'db>,
39        payload: Box<SpecializationArg<'db>>,
40    },
41    NotSpecialized,
42}
43
44impl<'a> DebugWithDb<'a> for SpecializationArg<'a> {
45    type Db = dyn Database;
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &'a dyn Database) -> std::fmt::Result {
47        match self {
48            SpecializationArg::Const { value, boxed } => {
49                write!(f, "{:?}", value.debug(db))?;
50                if *boxed {
51                    write!(f, ".into_box()")?;
52                }
53                Ok(())
54            }
55            SpecializationArg::Snapshot(inner) => write!(f, "@{:?}", inner.debug(db)),
56            SpecializationArg::Struct(args) => {
57                write!(f, "{{")?;
58                let mut inner = args.iter().peekable();
59                while let Some(value) = inner.next() {
60                    write!(f, " ")?;
61                    value.fmt(f, db)?;
62
63                    if inner.peek().is_some() {
64                        write!(f, ",")?;
65                    } else {
66                        write!(f, " ")?;
67                    }
68                }
69                write!(f, "}}")
70            }
71            SpecializationArg::Array(_ty, values) => {
72                write!(f, "array![")?;
73                let mut first = true;
74                for value in values {
75                    if !first {
76                        write!(f, ", ")?;
77                    } else {
78                        first = false;
79                    }
80                    write!(f, "{:?}", value.debug(db))?;
81                }
82                write!(f, "]")
83            }
84            SpecializationArg::Enum { variant, payload } => {
85                write!(f, "{:?}(", variant.debug(db))?;
86                payload.fmt(f, db)?;
87                write!(f, ")")
88            }
89            SpecializationArg::NotSpecialized => write!(f, "NotSpecialized"),
90        }
91    }
92}
93
94/// The state of the specialization arg building process.
95/// currently only structs require an additional build step.
96enum SpecializationArgBuildingState<'db, 'a> {
97    Initial(&'a SpecializationArg<'db>),
98    TakeSnapshot(VariableId),
99    BuildStruct(Vec<VariableId>),
100    PushBackArray { in_array: VariableId, value: VariableId },
101    BuildEnum { variant: ConcreteVariant<'db>, payload: VariableId },
102}
103
104/// Returns the lowering of a specialized function.
105pub fn specialized_function_lowered<'db>(
106    db: &'db dyn Database,
107    specialized: SpecializedFunction<'db>,
108) -> Maybe<Lowered<'db>> {
109    let base = db.lowered_body(specialized.base, LoweringStage::Monomorphized)?;
110    let base_semantic = specialized.base.base_semantic_function(db);
111
112    let array_module = ModuleHelper::core(db).submodule("array");
113    let array_new_fn = GenericFunctionId::Extern(array_module.extern_function_id("array_new"));
114    let array_append = GenericFunctionId::Extern(array_module.extern_function_id("array_append"));
115
116    let mut variables =
117        VariableAllocator::new(db, base_semantic.function_with_body_id(db), Default::default())?;
118    let mut statements = vec![];
119    let mut parameters = vec![];
120    let mut inputs = vec![];
121    let mut stack = vec![];
122
123    let location = LocationId::from_stable_location(
124        db,
125        specialized.base.base_semantic_function(db).stable_location(db),
126    );
127
128    for (param, arg) in zip_eq(&base.parameters, specialized.args.iter()) {
129        let var_id = variables.variables.alloc(base.variables[*param].clone());
130        inputs.push(VarUsage { var_id, location });
131        if SpecializationArg::NotSpecialized == *arg {
132            parameters.push(var_id);
133            continue;
134        }
135        stack.push((var_id, SpecializationArgBuildingState::Initial(arg)));
136        while let Some((var_id, state)) = stack.pop() {
137            match state {
138                SpecializationArgBuildingState::Initial(c) => match c {
139                    SpecializationArg::Const { value, boxed } => {
140                        if db.type_size_info(variables[var_id].ty)?
141                            == TypeSizeInformation::ZeroSized
142                        {
143                            assert!(
144                                !boxed,
145                                "Zero sized specialization arguments should only be part of \
146                                 consts and therefore cannot be boxed"
147                            );
148                            statements.push(Statement::StructConstruct(StatementStructConstruct {
149                                inputs: vec![],
150                                output: var_id,
151                            }));
152                        } else {
153                            statements.push(Statement::Const(StatementConst::new(
154                                *value, var_id, *boxed,
155                            )));
156                        }
157                    }
158                    SpecializationArg::Snapshot(inner) => {
159                        let snap_ty = variables.variables[var_id].ty;
160                        let denapped_ty = *extract_matches!(snap_ty.long(db), TypeLongId::Snapshot);
161                        let desnapped_var =
162                            variables.new_var(VarRequest { ty: denapped_ty, location });
163                        stack.push((
164                            var_id,
165                            SpecializationArgBuildingState::TakeSnapshot(desnapped_var),
166                        ));
167                        stack.push((
168                            desnapped_var,
169                            SpecializationArgBuildingState::Initial(inner.as_ref()),
170                        ));
171                    }
172                    SpecializationArg::Array(ty, values) => {
173                        let mut arr_var = var_id;
174                        for value in values.iter().rev() {
175                            let in_arr_var =
176                                variables.variables.alloc(variables.variables[var_id].clone());
177                            let value_var = variables.new_var(VarRequest { ty: *ty, location });
178                            stack.push((
179                                arr_var,
180                                SpecializationArgBuildingState::PushBackArray {
181                                    in_array: in_arr_var,
182                                    value: value_var,
183                                },
184                            ));
185                            stack.push((value_var, SpecializationArgBuildingState::Initial(value)));
186                            arr_var = in_arr_var;
187                        }
188                        statements.push(Statement::Call(StatementCall {
189                            function: array_new_fn
190                                .concretize(db, vec![GenericArgumentId::Type(*ty)])
191                                .lowered(db),
192                            inputs: vec![],
193                            with_coupon: false,
194                            outputs: vec![arr_var],
195                            location: variables[var_id].location,
196                            is_specialization_base_call: false,
197                        }));
198                    }
199                    SpecializationArg::Struct(args) => {
200                        let var_ty = variables[var_id].ty;
201                        let location = variables[var_id].location;
202
203                        // Get element types based on the actual type.
204                        let mut var_for_ty = |ty| variables.new_var(VarRequest { ty, location });
205                        let var_ids = match var_ty.long(db) {
206                            TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct)) => {
207                                let members = db.concrete_struct_members(*concrete_struct)?;
208                                members.values().map(|member| var_for_ty(member.ty)).collect_vec()
209                            }
210                            TypeLongId::Tuple(element_types) => {
211                                element_types.iter().cloned().map(var_for_ty).collect_vec()
212                            }
213                            TypeLongId::FixedSizeArray { type_id, .. } => {
214                                itertools::repeat_n(*type_id, args.len())
215                                    .map(var_for_ty)
216                                    .collect_vec()
217                            }
218                            _ => unreachable!("Expected a struct, tuple, or fixed-size array type"),
219                        };
220
221                        stack.push((
222                            var_id,
223                            SpecializationArgBuildingState::BuildStruct(var_ids.clone()),
224                        ));
225
226                        for (var_id, arg) in zip_eq(var_ids.iter().rev(), args.iter().rev()) {
227                            stack.push((*var_id, SpecializationArgBuildingState::Initial(arg)));
228                        }
229                    }
230                    SpecializationArg::Enum { variant, payload } => {
231                        let location = variables[var_id].location;
232                        let payload_var =
233                            variables.new_var(VarRequest { ty: variant.ty, location });
234                        stack.push((
235                            var_id,
236                            SpecializationArgBuildingState::BuildEnum {
237                                variant: *variant,
238                                payload: payload_var,
239                            },
240                        ));
241                        stack.push((
242                            payload_var,
243                            SpecializationArgBuildingState::Initial(payload.as_ref()),
244                        ));
245                    }
246                    SpecializationArg::NotSpecialized => {
247                        parameters.push(var_id);
248                    }
249                },
250                SpecializationArgBuildingState::TakeSnapshot(desnapped_var) => {
251                    let ignored = variables.variables.alloc(variables[desnapped_var].clone());
252                    statements.push(Statement::Snapshot(StatementSnapshot::new(
253                        VarUsage { var_id: desnapped_var, location },
254                        ignored,
255                        var_id,
256                    )));
257                }
258                SpecializationArgBuildingState::PushBackArray { in_array, value } => {
259                    statements.push(Statement::Call(StatementCall {
260                        function: array_append
261                            .concretize(
262                                db,
263                                vec![GenericArgumentId::Type(variables.variables[value].ty)],
264                            )
265                            .lowered(db),
266                        inputs: vec![
267                            VarUsage { var_id: in_array, location },
268                            VarUsage { var_id: value, location },
269                        ],
270                        with_coupon: false,
271                        outputs: vec![var_id],
272                        location,
273                        is_specialization_base_call: false,
274                    }));
275                }
276                SpecializationArgBuildingState::BuildStruct(ids) => {
277                    statements.push(Statement::StructConstruct(StatementStructConstruct {
278                        inputs: ids
279                            .iter()
280                            .map(|id| VarUsage { var_id: *id, location: variables[*id].location })
281                            .collect(),
282                        output: var_id,
283                    }));
284                }
285                SpecializationArgBuildingState::BuildEnum { variant, payload } => {
286                    statements.push(Statement::EnumConstruct(StatementEnumConstructObj {
287                        variant,
288                        input: VarUsage { var_id: payload, location: variables[payload].location },
289                        output: var_id,
290                    }));
291                }
292            }
293        }
294    }
295
296    let outputs: Vec<VariableId> =
297        chain!(base.signature.extra_rets.iter().map(|ret| ret.ty()), [base.signature.return_type])
298            .map(|ty| variables.new_var(VarRequest { ty, location }))
299            .collect_vec();
300    let mut block_builder = BlocksBuilder::new();
301    let ret_usage =
302        outputs.iter().map(|var_id| VarUsage { var_id: *var_id, location }).collect_vec();
303    statements.push(Statement::Call(StatementCall {
304        function: specialized.base.function_id(db)?,
305        with_coupon: false,
306        inputs,
307        outputs,
308        location,
309        is_specialization_base_call: true,
310    }));
311    block_builder.alloc(Block { statements, end: BlockEnd::Return(ret_usage, location) });
312    Ok(Lowered {
313        signature: specialized.signature(db)?,
314        variables: variables.variables,
315        blocks: block_builder.build().unwrap(),
316        parameters,
317        diagnostics: Default::default(),
318    })
319}
320
321/// Query implementation of [LoweringGroup::priv_should_specialize].
322#[salsa::tracked]
323pub fn priv_should_specialize<'db>(
324    db: &'db dyn Database,
325    function_id: ids::ConcreteFunctionWithBodyId<'db>,
326) -> Maybe<bool> {
327    let ids::ConcreteFunctionWithBodyLongId::Specialized(SpecializedFunction { base, .. }) =
328        function_id.long(db)
329    else {
330        panic!("Expected a specialized function");
331    };
332
333    // The heuristic is that the size is 8/10*orig_size > specialized_size of the original size.
334    Ok(db.estimate_size(*base)?.saturating_mul(8)
335        > db.estimate_size(function_id)?.saturating_mul(10))
336}