Skip to main content

cairo_lang_lowering/
specialization.rs

1use cairo_lang_debug::DebugWithDb;
2use cairo_lang_diagnostics::Maybe;
3use cairo_lang_proc_macros::HeapSize;
4use cairo_lang_semantic::helper::ModuleHelper;
5use cairo_lang_semantic::items::constant::ConstValueId;
6use cairo_lang_semantic::items::functions::GenericFunctionId;
7use cairo_lang_semantic::items::structure::StructSemantic;
8use cairo_lang_semantic::types::{TypeSizeInformation, TypesSemantic};
9use cairo_lang_semantic::{ConcreteTypeId, ConcreteVariant, GenericArgumentId, TypeId, TypeLongId};
10use cairo_lang_utils::extract_matches;
11use itertools::{Itertools, chain, zip_eq};
12use salsa::Database;
13
14use crate::blocks::BlocksBuilder;
15use crate::db::LoweringGroup;
16use crate::ids::{self, LocationId, SemanticFunctionIdEx, SpecializedFunctionId};
17use crate::lower::context::{VarRequest, VariableAllocator};
18use crate::objects::StatementEnumConstruct as StatementEnumConstructObj;
19use crate::{
20    Block, BlockEnd, Lowered, LoweringStage, Statement, StatementCall, StatementConst,
21    StatementSnapshot, StatementStructConstruct, VarUsage, VariableId,
22};
23
/// An argument of a specialized function.
///
/// Either a compile-time-known value (possibly nested inside snapshots, arrays, structs or
/// enums) or [`SpecializationArg::NotSpecialized`], meaning the value is not fixed and stays a
/// parameter of the specialized function.
#[derive(Clone, Debug, Hash, PartialEq, Eq, salsa::Update, HeapSize)]
pub enum SpecializationArg<'db> {
    /// A constant value.
    Const {
        /// The constant value itself.
        value: ConstValueId<'db>,
        /// Whether the constant is boxed (rendered with a `.into_box()` suffix in debug output).
        boxed: bool,
    },
    /// A snapshot of the inner argument.
    Snapshot(Box<SpecializationArg<'db>>),
    /// An array with the given element type and element values (in order).
    Array(TypeId<'db>, Vec<SpecializationArg<'db>>),
    /// Represents struct, tuple, or fixed-size array.
    Struct(Vec<SpecializationArg<'db>>),
    /// A specific enum variant together with its payload.
    Enum {
        /// The variant being constructed.
        variant: ConcreteVariant<'db>,
        /// The payload of the variant.
        payload: Box<SpecializationArg<'db>>,
    },
    /// The value is not specialized and remains a parameter of the specialized function.
    NotSpecialized,
}
41
42impl<'a> DebugWithDb<'a> for SpecializationArg<'a> {
43    type Db = dyn Database;
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &'a dyn Database) -> std::fmt::Result {
45        match self {
46            SpecializationArg::Const { value, boxed } => {
47                write!(f, "{:?}", value.debug(db))?;
48                if *boxed {
49                    write!(f, ".into_box()")?;
50                }
51                Ok(())
52            }
53            SpecializationArg::Snapshot(inner) => write!(f, "@{:?}", inner.debug(db)),
54            SpecializationArg::Struct(args) => {
55                write!(f, "{{")?;
56                let mut inner = args.iter().peekable();
57                while let Some(value) = inner.next() {
58                    write!(f, " ")?;
59                    value.fmt(f, db)?;
60
61                    if inner.peek().is_some() {
62                        write!(f, ",")?;
63                    } else {
64                        write!(f, " ")?;
65                    }
66                }
67                write!(f, "}}")
68            }
69            SpecializationArg::Array(_ty, values) => {
70                write!(f, "array![")?;
71                let mut first = true;
72                for value in values {
73                    if !first {
74                        write!(f, ", ")?;
75                    } else {
76                        first = false;
77                    }
78                    write!(f, "{:?}", value.debug(db))?;
79                }
80                write!(f, "]")
81            }
82            SpecializationArg::Enum { variant, payload } => {
83                write!(f, "{:?}(", variant.debug(db))?;
84                payload.fmt(f, db)?;
85                write!(f, ")")
86            }
87            SpecializationArg::NotSpecialized => write!(f, "NotSpecialized"),
88        }
89    }
90}
91
/// A pending step in building the value of a specialization arg into a variable.
///
/// Composite args (snapshots, arrays, structs and enums) are built in two phases. First their
/// components are pushed as `Initial` entries onto a work stack. A follow-up step, pushed
/// below them, then combines the already-built components into the target variable.
enum SpecializationArgBuildingState<'db, 'a> {
    /// The arg still needs to be examined and lowered into statements.
    Initial(&'a SpecializationArg<'db>),
    /// Take a snapshot of the given, already built, variable.
    TakeSnapshot(VariableId),
    /// Construct a struct, tuple or fixed-size array from the given, already built, members.
    BuildStruct(Vec<VariableId>),
    /// Append the already built `value` to the array held in `in_array`.
    PushBackArray { in_array: VariableId, value: VariableId },
    /// Construct `variant` of an enum from the given, already built, payload variable.
    BuildEnum { variant: ConcreteVariant<'db>, payload: VariableId },
}
101
102/// Returns the lowering of a specialized function.
103pub fn specialized_function_lowered<'db>(
104    db: &'db dyn Database,
105    specialized: SpecializedFunctionId<'db>,
106) -> Maybe<Lowered<'db>> {
107    let specialized = specialized.long(db);
108    let base = db.lowered_body(specialized.base, LoweringStage::Monomorphized)?;
109    let base_semantic = specialized.base.base_semantic_function(db);
110
111    let array_module = ModuleHelper::core(db).submodule("array");
112    let array_new_fn = GenericFunctionId::Extern(array_module.extern_function_id("array_new"));
113    let array_append = GenericFunctionId::Extern(array_module.extern_function_id("array_append"));
114
115    let mut variables =
116        VariableAllocator::new(db, base_semantic.function_with_body_id(db), Default::default())?;
117    let mut statements = vec![];
118    let mut parameters = vec![];
119    let mut inputs = vec![];
120    let mut stack = vec![];
121
122    let location = LocationId::from_stable_location(
123        db,
124        specialized.base.base_semantic_function(db).stable_location(db),
125    );
126
127    for (param, arg) in zip_eq(&base.parameters, specialized.args.iter()) {
128        let var_id = variables.variables.alloc(base.variables[*param].clone());
129        inputs.push(VarUsage { var_id, location });
130        if SpecializationArg::NotSpecialized == *arg {
131            parameters.push(var_id);
132            continue;
133        }
134        stack.push((var_id, SpecializationArgBuildingState::Initial(arg)));
135        while let Some((var_id, state)) = stack.pop() {
136            match state {
137                SpecializationArgBuildingState::Initial(c) => match c {
138                    SpecializationArg::Const { value, boxed } => {
139                        if db.type_size_info(variables[var_id].ty)?
140                            == TypeSizeInformation::ZeroSized
141                        {
142                            assert!(
143                                !boxed,
144                                "Zero sized specialization arguments should only be part of \
145                                 consts and therefore cannot be boxed"
146                            );
147                            statements.push(Statement::StructConstruct(StatementStructConstruct {
148                                inputs: vec![],
149                                output: var_id,
150                            }));
151                        } else {
152                            statements.push(Statement::Const(StatementConst::new(
153                                *value, var_id, *boxed,
154                            )));
155                        }
156                    }
157                    SpecializationArg::Snapshot(inner) => {
158                        let snap_ty = variables.variables[var_id].ty;
159                        let denapped_ty = *extract_matches!(snap_ty.long(db), TypeLongId::Snapshot);
160                        let desnapped_var =
161                            variables.new_var(VarRequest { ty: denapped_ty, location });
162                        stack.push((
163                            var_id,
164                            SpecializationArgBuildingState::TakeSnapshot(desnapped_var),
165                        ));
166                        stack.push((
167                            desnapped_var,
168                            SpecializationArgBuildingState::Initial(inner.as_ref()),
169                        ));
170                    }
171                    SpecializationArg::Array(ty, values) => {
172                        let mut arr_var = var_id;
173                        for value in values.iter().rev() {
174                            let in_arr_var =
175                                variables.variables.alloc(variables.variables[var_id].clone());
176                            let value_var = variables.new_var(VarRequest { ty: *ty, location });
177                            stack.push((
178                                arr_var,
179                                SpecializationArgBuildingState::PushBackArray {
180                                    in_array: in_arr_var,
181                                    value: value_var,
182                                },
183                            ));
184                            stack.push((value_var, SpecializationArgBuildingState::Initial(value)));
185                            arr_var = in_arr_var;
186                        }
187                        statements.push(Statement::Call(StatementCall {
188                            function: array_new_fn
189                                .concretize(db, vec![GenericArgumentId::Type(*ty)])
190                                .lowered(db),
191                            inputs: vec![],
192                            with_coupon: false,
193                            outputs: vec![arr_var],
194                            location: variables[var_id].location,
195                            is_specialization_base_call: false,
196                        }));
197                    }
198                    SpecializationArg::Struct(args) => {
199                        let var_ty = variables[var_id].ty;
200                        let location = variables[var_id].location;
201
202                        // Get element types based on the actual type.
203                        let mut var_for_ty = |ty| variables.new_var(VarRequest { ty, location });
204                        let var_ids = match var_ty.long(db) {
205                            TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct)) => {
206                                let members = db.concrete_struct_members(*concrete_struct)?;
207                                members.values().map(|member| var_for_ty(member.ty)).collect_vec()
208                            }
209                            TypeLongId::Tuple(element_types) => {
210                                element_types.iter().cloned().map(var_for_ty).collect_vec()
211                            }
212                            TypeLongId::FixedSizeArray { type_id, .. } => {
213                                itertools::repeat_n(*type_id, args.len())
214                                    .map(var_for_ty)
215                                    .collect_vec()
216                            }
217                            _ => unreachable!("Expected a struct, tuple, or fixed-size array type"),
218                        };
219
220                        stack.push((
221                            var_id,
222                            SpecializationArgBuildingState::BuildStruct(var_ids.clone()),
223                        ));
224
225                        for (var_id, arg) in zip_eq(var_ids.iter().rev(), args.iter().rev()) {
226                            stack.push((*var_id, SpecializationArgBuildingState::Initial(arg)));
227                        }
228                    }
229                    SpecializationArg::Enum { variant, payload } => {
230                        let location = variables[var_id].location;
231                        let payload_var =
232                            variables.new_var(VarRequest { ty: variant.ty, location });
233                        stack.push((
234                            var_id,
235                            SpecializationArgBuildingState::BuildEnum {
236                                variant: *variant,
237                                payload: payload_var,
238                            },
239                        ));
240                        stack.push((
241                            payload_var,
242                            SpecializationArgBuildingState::Initial(payload.as_ref()),
243                        ));
244                    }
245                    SpecializationArg::NotSpecialized => {
246                        parameters.push(var_id);
247                    }
248                },
249                SpecializationArgBuildingState::TakeSnapshot(desnapped_var) => {
250                    let ignored = variables.variables.alloc(variables[desnapped_var].clone());
251                    statements.push(Statement::Snapshot(StatementSnapshot::new(
252                        VarUsage { var_id: desnapped_var, location },
253                        ignored,
254                        var_id,
255                    )));
256                }
257                SpecializationArgBuildingState::PushBackArray { in_array, value } => {
258                    statements.push(Statement::Call(StatementCall {
259                        function: array_append
260                            .concretize(
261                                db,
262                                vec![GenericArgumentId::Type(variables.variables[value].ty)],
263                            )
264                            .lowered(db),
265                        inputs: vec![
266                            VarUsage { var_id: in_array, location },
267                            VarUsage { var_id: value, location },
268                        ],
269                        with_coupon: false,
270                        outputs: vec![var_id],
271                        location,
272                        is_specialization_base_call: false,
273                    }));
274                }
275                SpecializationArgBuildingState::BuildStruct(ids) => {
276                    statements.push(Statement::StructConstruct(StatementStructConstruct {
277                        inputs: ids
278                            .iter()
279                            .map(|id| VarUsage { var_id: *id, location: variables[*id].location })
280                            .collect(),
281                        output: var_id,
282                    }));
283                }
284                SpecializationArgBuildingState::BuildEnum { variant, payload } => {
285                    statements.push(Statement::EnumConstruct(StatementEnumConstructObj {
286                        variant,
287                        input: VarUsage { var_id: payload, location: variables[payload].location },
288                        output: var_id,
289                    }));
290                }
291            }
292        }
293    }
294
295    let outputs: Vec<VariableId> =
296        chain!(base.signature.extra_rets.iter().map(|ret| ret.ty()), [base.signature.return_type])
297            .map(|ty| variables.new_var(VarRequest { ty, location }))
298            .collect_vec();
299    let mut block_builder = BlocksBuilder::new();
300    let ret_usage =
301        outputs.iter().map(|var_id| VarUsage { var_id: *var_id, location }).collect_vec();
302    statements.push(Statement::Call(StatementCall {
303        function: specialized.base.function_id(db)?,
304        with_coupon: false,
305        inputs,
306        outputs,
307        location,
308        is_specialization_base_call: true,
309    }));
310    block_builder.alloc(Block { statements, end: BlockEnd::Return(ret_usage, location) });
311    Ok(Lowered {
312        signature: specialized.signature(db)?,
313        variables: variables.variables,
314        blocks: block_builder.build().unwrap(),
315        parameters,
316        diagnostics: Default::default(),
317    })
318}
319
320/// Query implementation of [LoweringGroup::priv_should_specialize].
321#[salsa::tracked]
322pub fn priv_should_specialize<'db>(
323    db: &'db dyn Database,
324    function_id: ids::ConcreteFunctionWithBodyId<'db>,
325) -> Maybe<bool> {
326    let ids::ConcreteFunctionWithBodyLongId::Specialized(specialized) = function_id.long(db) else {
327        panic!("Expected a specialized function");
328    };
329
330    // The heuristic is that the size is 8/10*orig_size > specialized_size of the original size.
331    Ok(db.estimate_size(specialized.long(db).base)?.saturating_mul(8)
332        > db.estimate_size(function_id)?.saturating_mul(10))
333}