cairo_lang_lowering/
specialization.rs

1use std::vec;
2
3use cairo_lang_debug::DebugWithDb;
4use cairo_lang_diagnostics::Maybe;
5use cairo_lang_semantic::helper::ModuleHelper;
6use cairo_lang_semantic::items::constant::ConstValueId;
7use cairo_lang_semantic::items::functions::GenericFunctionId;
8use cairo_lang_semantic::items::structure::StructSemantic;
9use cairo_lang_semantic::{ConcreteTypeId, ConcreteVariant, GenericArgumentId, TypeId, TypeLongId};
10use cairo_lang_utils::extract_matches;
11use itertools::{Itertools, chain, zip_eq};
12use salsa::Database;
13
14use crate::blocks::BlocksBuilder;
15use crate::db::LoweringGroup;
16use crate::ids::{self, LocationId, SemanticFunctionIdEx, SpecializedFunction};
17use crate::lower::context::{VarRequest, VariableAllocator};
18use crate::objects::StatementEnumConstruct as StatementEnumConstructObj;
19use crate::{
20    Block, BlockEnd, DependencyType, Lowered, LoweringStage, Statement, StatementCall,
21    StatementConst, StatementSnapshot, StatementStructConstruct, VarUsage, VariableId,
22};
23
/// An argument of a specialized function.
///
/// Describes a (possibly only partially) known value that is materialized inside the
/// specialized function's body instead of being passed in by the caller.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub enum SpecializationArg<'db> {
    /// A constant value. When `boxed` is set the constant is materialized boxed
    /// (rendered with an `.into_box()` suffix in debug output).
    Const { value: ConstValueId<'db>, boxed: bool },
    /// A snapshot of the inner argument.
    Snapshot(Box<SpecializationArg<'db>>),
    /// An array with the given element type, built from the given elements in order.
    Array(TypeId<'db>, Vec<SpecializationArg<'db>>),
    /// A struct, with one argument per member in member order.
    Struct(Vec<SpecializationArg<'db>>),
    /// An enum value of `variant` wrapping `payload`.
    Enum { variant: ConcreteVariant<'db>, payload: Box<SpecializationArg<'db>> },
    /// An unknown value; it remains a parameter of the specialized function.
    NotSpecialized,
}
34
35impl<'a> DebugWithDb<'a> for SpecializationArg<'a> {
36    type Db = dyn Database;
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &'a dyn Database) -> std::fmt::Result {
38        match self {
39            SpecializationArg::Const { value, boxed } => {
40                write!(f, "{:?}", value.debug(db))?;
41                if *boxed {
42                    write!(f, ".into_box()")?;
43                }
44                Ok(())
45            }
46            SpecializationArg::Snapshot(inner) => write!(f, "@{:?}", inner.debug(db)),
47            SpecializationArg::Struct(args) => {
48                write!(f, "{{")?;
49                let mut inner = args.iter().peekable();
50                while let Some(value) = inner.next() {
51                    write!(f, " ")?;
52                    value.fmt(f, db)?;
53
54                    if inner.peek().is_some() {
55                        write!(f, ",")?;
56                    } else {
57                        write!(f, " ")?;
58                    }
59                }
60                write!(f, "}}")
61            }
62            SpecializationArg::Array(_ty, values) => {
63                write!(f, "array![")?;
64                let mut first = true;
65                for value in values {
66                    if !first {
67                        write!(f, ", ")?;
68                    } else {
69                        first = false;
70                    }
71                    write!(f, "{:?}", value.debug(db))?;
72                }
73                write!(f, "]")
74            }
75            SpecializationArg::Enum { variant, payload } => {
76                write!(f, "{:?}(", variant.debug(db))?;
77                payload.fmt(f, db)?;
78                write!(f, ")")
79            }
80            SpecializationArg::NotSpecialized => write!(f, "NotSpecialized"),
81        }
82    }
83}
84
/// A pending step of the specialization-arg building process.
///
/// Items live on an explicit work stack, each paired with the variable it produces.
/// `Initial` expands an argument: it either emits its statement directly or schedules
/// sub-steps. Every other variant is a deferred step. It is pushed before the steps that
/// build its inputs, so it is popped and emitted only after those inputs have been assigned.
enum SpecializationArgBuildingState<'db, 'a> {
    /// An argument that has not been expanded yet.
    Initial(&'a SpecializationArg<'db>),
    /// Take a snapshot of the given (desnapped) variable into the paired variable.
    TakeSnapshot(VariableId),
    /// Construct the paired struct variable from the given member variables.
    BuildStruct(Vec<VariableId>),
    /// Append `value` to `in_array`, producing the paired array variable.
    PushBackArray { in_array: VariableId, value: VariableId },
    /// Construct the paired enum variable as `variant` wrapping `payload`.
    BuildEnum { variant: ConcreteVariant<'db>, payload: VariableId },
}
94
95/// Returns the lowering of a specialized function.
96pub fn specialized_function_lowered<'db>(
97    db: &'db dyn Database,
98    specialized: SpecializedFunction<'db>,
99) -> Maybe<Lowered<'db>> {
100    let base = db.lowered_body(specialized.base, LoweringStage::Monomorphized)?;
101    let base_semantic = specialized.base.base_semantic_function(db);
102
103    let array_module = ModuleHelper::core(db).submodule("array");
104    let array_new_fn = GenericFunctionId::Extern(array_module.extern_function_id("array_new"));
105    let array_append = GenericFunctionId::Extern(array_module.extern_function_id("array_append"));
106
107    let mut variables =
108        VariableAllocator::new(db, base_semantic.function_with_body_id(db), Default::default())?;
109    let mut statements = vec![];
110    let mut parameters = vec![];
111    let mut inputs = vec![];
112    let mut stack = vec![];
113
114    let location = LocationId::from_stable_location(
115        db,
116        specialized.base.base_semantic_function(db).stable_location(db),
117    );
118
119    for (param, arg) in zip_eq(&base.parameters, specialized.args.iter()) {
120        let var_id = variables.variables.alloc(base.variables[*param].clone());
121        inputs.push(VarUsage { var_id, location });
122        if SpecializationArg::NotSpecialized == *arg {
123            parameters.push(var_id);
124            continue;
125        }
126        stack.push((var_id, SpecializationArgBuildingState::Initial(arg)));
127        while let Some((var_id, state)) = stack.pop() {
128            match state {
129                SpecializationArgBuildingState::Initial(c) => match c {
130                    SpecializationArg::Const { value, boxed } => {
131                        statements
132                            .push(Statement::Const(StatementConst::new(*value, var_id, *boxed)));
133                    }
134                    SpecializationArg::Snapshot(inner) => {
135                        let snap_ty = variables.variables[var_id].ty;
136                        let denapped_ty = *extract_matches!(snap_ty.long(db), TypeLongId::Snapshot);
137                        let desnapped_var =
138                            variables.new_var(VarRequest { ty: denapped_ty, location });
139                        stack.push((
140                            var_id,
141                            SpecializationArgBuildingState::TakeSnapshot(desnapped_var),
142                        ));
143                        stack.push((
144                            desnapped_var,
145                            SpecializationArgBuildingState::Initial(inner.as_ref()),
146                        ));
147                    }
148                    SpecializationArg::Array(ty, values) => {
149                        let mut arr_var = var_id;
150                        for value in values.iter().rev() {
151                            let in_arr_var =
152                                variables.variables.alloc(variables.variables[var_id].clone());
153                            let value_var = variables.new_var(VarRequest { ty: *ty, location });
154                            stack.push((
155                                arr_var,
156                                SpecializationArgBuildingState::PushBackArray {
157                                    in_array: in_arr_var,
158                                    value: value_var,
159                                },
160                            ));
161                            stack.push((value_var, SpecializationArgBuildingState::Initial(value)));
162                            arr_var = in_arr_var;
163                        }
164                        statements.push(Statement::Call(StatementCall {
165                            function: array_new_fn
166                                .concretize(db, vec![GenericArgumentId::Type(*ty)])
167                                .lowered(db),
168                            inputs: vec![],
169                            with_coupon: false,
170                            outputs: vec![arr_var],
171                            location: variables[var_id].location,
172                        }));
173                    }
174                    SpecializationArg::Struct(args) => {
175                        let var = &variables[var_id];
176                        let TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct)) =
177                            var.ty.long(db)
178                        else {
179                            unreachable!("Expected a concrete struct type");
180                        };
181
182                        let members = db.concrete_struct_members(*concrete_struct)?;
183
184                        let location = var.location;
185                        let var_ids = members
186                            .values()
187                            .map(|member| variables.new_var(VarRequest { ty: member.ty, location }))
188                            .collect_vec();
189
190                        stack.push((
191                            var_id,
192                            SpecializationArgBuildingState::BuildStruct(var_ids.clone()),
193                        ));
194
195                        for (var_id, arg) in zip_eq(var_ids.iter().rev(), args.iter().rev()) {
196                            stack.push((*var_id, SpecializationArgBuildingState::Initial(arg)));
197                        }
198                    }
199                    SpecializationArg::Enum { variant, payload } => {
200                        let location = variables[var_id].location;
201                        let payload_var =
202                            variables.new_var(VarRequest { ty: variant.ty, location });
203                        stack.push((
204                            var_id,
205                            SpecializationArgBuildingState::BuildEnum {
206                                variant: *variant,
207                                payload: payload_var,
208                            },
209                        ));
210                        stack.push((
211                            payload_var,
212                            SpecializationArgBuildingState::Initial(payload.as_ref()),
213                        ));
214                    }
215                    SpecializationArg::NotSpecialized => {
216                        parameters.push(var_id);
217                    }
218                },
219                SpecializationArgBuildingState::TakeSnapshot(desnapped_var) => {
220                    let ignored = variables.variables.alloc(variables[desnapped_var].clone());
221                    statements.push(Statement::Snapshot(StatementSnapshot::new(
222                        VarUsage { var_id: desnapped_var, location },
223                        ignored,
224                        var_id,
225                    )));
226                }
227                SpecializationArgBuildingState::PushBackArray { in_array, value } => {
228                    statements.push(Statement::Call(StatementCall {
229                        function: array_append
230                            .concretize(
231                                db,
232                                vec![GenericArgumentId::Type(variables.variables[value].ty)],
233                            )
234                            .lowered(db),
235                        inputs: vec![
236                            VarUsage { var_id: in_array, location },
237                            VarUsage { var_id: value, location },
238                        ],
239                        with_coupon: false,
240                        outputs: vec![var_id],
241                        location,
242                    }));
243                }
244                SpecializationArgBuildingState::BuildStruct(ids) => {
245                    statements.push(Statement::StructConstruct(StatementStructConstruct {
246                        inputs: ids
247                            .iter()
248                            .map(|id| VarUsage { var_id: *id, location: variables[*id].location })
249                            .collect(),
250                        output: var_id,
251                    }));
252                }
253                SpecializationArgBuildingState::BuildEnum { variant, payload } => {
254                    statements.push(Statement::EnumConstruct(StatementEnumConstructObj {
255                        variant,
256                        input: VarUsage { var_id: payload, location: variables[payload].location },
257                        output: var_id,
258                    }));
259                }
260            }
261        }
262    }
263
264    let outputs: Vec<VariableId> =
265        chain!(base.signature.extra_rets.iter().map(|ret| ret.ty()), [base.signature.return_type])
266            .map(|ty| variables.new_var(VarRequest { ty, location }))
267            .collect_vec();
268    let mut block_builder = BlocksBuilder::new();
269    let ret_usage =
270        outputs.iter().map(|var_id| VarUsage { var_id: *var_id, location }).collect_vec();
271    statements.push(Statement::Call(StatementCall {
272        function: specialized.base.function_id(db)?,
273        with_coupon: false,
274        inputs,
275        outputs,
276        location,
277    }));
278    block_builder.alloc(Block { statements, end: BlockEnd::Return(ret_usage, location) });
279    Ok(Lowered {
280        signature: specialized.signature(db)?,
281        variables: variables.variables,
282        blocks: block_builder.build().unwrap(),
283        parameters,
284        diagnostics: Default::default(),
285    })
286}
287
288/// Query implementation of [LoweringGroup::priv_should_specialize].
289#[salsa::tracked]
290pub fn priv_should_specialize<'db>(
291    db: &'db dyn Database,
292    function_id: ids::ConcreteFunctionWithBodyId<'db>,
293) -> Maybe<bool> {
294    let ids::ConcreteFunctionWithBodyLongId::Specialized(SpecializedFunction { base, .. }) =
295        function_id.long(db)
296    else {
297        panic!("Expected a specialized function");
298    };
299
300    // Breaks cycles.
301    // We cannot estimate the size of functions in a cycle, since the implicits computation requires
302    // the finalized lowering of all the functions in the cycle which requires us to know the
303    // answer of the current function.
304    if db.concrete_in_cycle(*base, DependencyType::Call, LoweringStage::Monomorphized)? {
305        return Ok(false);
306    }
307
308    // The heuristic is that the size is 8/10*orig_size > specialized_size of the original size.
309    Ok(db.estimate_size(*base)?.saturating_mul(8)
310        > db.estimate_size(function_id)?.saturating_mul(10))
311}