Skip to main content

cairo_lang_lowering/
specialization.rs

1use std::vec;
2
3use cairo_lang_debug::DebugWithDb;
4use cairo_lang_diagnostics::Maybe;
5use cairo_lang_semantic::helper::ModuleHelper;
6use cairo_lang_semantic::items::constant::ConstValue;
7use cairo_lang_semantic::items::functions::GenericFunctionId;
8use cairo_lang_semantic::{ConcreteTypeId, GenericArgumentId, TypeId, TypeLongId};
9use cairo_lang_utils::LookupIntern;
10use itertools::{Itertools, chain, zip_eq};
11
12use crate::blocks::BlocksBuilder;
13use crate::db::LoweringGroup;
14use crate::ids::{self, LocationId, SemanticFunctionIdEx, SpecializedFunction};
15use crate::lower::context::{VarRequest, VariableAllocator};
16use crate::{
17    Block, BlockEnd, DependencyType, Lowered, LoweringStage, Statement, StatementCall,
18    StatementConst, StatementStructConstruct, VarUsage, VariableId,
19};
20
21// A const argument for a specialized function.
22#[derive(Clone, Debug, Hash, PartialEq, Eq)]
23pub enum SpecializationArg {
24    Const(ConstValue),
25    EmptyArray(TypeId),
26    Struct(Vec<SpecializationArg>),
27}
28
29impl<'a> DebugWithDb<dyn LoweringGroup + 'a> for SpecializationArg {
30    fn fmt(
31        &self,
32        f: &mut std::fmt::Formatter<'_>,
33        db: &(dyn LoweringGroup + 'a),
34    ) -> std::fmt::Result {
35        match self {
36            SpecializationArg::Const(value) => write!(f, "{:?}", value.debug(db)),
37            SpecializationArg::Struct(inner) => {
38                write!(f, "{{")?;
39                let mut inner = inner.iter().peekable();
40                while let Some(value) = inner.next() {
41                    write!(f, " ")?;
42                    value.fmt(f, db)?;
43
44                    if inner.peek().is_some() {
45                        write!(f, ",")?;
46                    } else {
47                        write!(f, " ")?;
48                    }
49                }
50                write!(f, "}}")
51            }
52            SpecializationArg::EmptyArray(_) => write!(f, "array![]"),
53        }
54    }
55}
56
57/// The state of the specialization arg building process.
58/// currently only structs require an additional build step.
59enum SpecializationArgBuildingState<'a> {
60    Initial(&'a SpecializationArg),
61    BuildStruct(Vec<VariableId>),
62}
63
64/// Returns the lowering of a specialized function.
65pub fn specialized_function_lowered(
66    db: &dyn LoweringGroup,
67    specialized: SpecializedFunction,
68) -> Maybe<Lowered> {
69    let base = db.lowered_body(specialized.base, LoweringStage::Monomorphized)?;
70    let base_semantic = specialized.base.base_semantic_function(db);
71
72    let array_new_fn = GenericFunctionId::Extern(
73        ModuleHelper::core(db).submodule("array").extern_function_id("array_new"),
74    );
75
76    let mut variables =
77        VariableAllocator::new(db, base_semantic.function_with_body_id(db), Default::default())?;
78    let mut statements = vec![];
79    let mut parameters = vec![];
80    let mut inputs = vec![];
81    let mut stack = vec![];
82
83    let location = LocationId::from_stable_location(
84        db,
85        specialized.base.base_semantic_function(db).stable_location(db),
86    );
87
88    for (param, arg) in zip_eq(&base.parameters, specialized.args.iter()) {
89        let var_id = variables.variables.alloc(base.variables[*param].clone());
90        inputs.push(VarUsage { var_id, location });
91        if let Some(c) = arg {
92            stack.push((var_id, SpecializationArgBuildingState::Initial(c)));
93            continue;
94        }
95        parameters.push(var_id);
96    }
97
98    while let Some((var_id, state)) = stack.pop() {
99        match state {
100            SpecializationArgBuildingState::Initial(c) => match c {
101                SpecializationArg::Const(value) => {
102                    statements.push(Statement::Const(StatementConst {
103                        value: value.clone(),
104                        output: var_id,
105                    }));
106                }
107                SpecializationArg::EmptyArray(ty) => {
108                    statements.push(Statement::Call(StatementCall {
109                        function: array_new_fn
110                            .concretize(db, vec![GenericArgumentId::Type(*ty)])
111                            .lowered(db),
112                        inputs: vec![],
113                        with_coupon: false,
114                        outputs: vec![var_id],
115                        location: variables[var_id].location,
116                    }));
117                }
118                SpecializationArg::Struct(consts) => {
119                    let var = &variables[var_id];
120                    let TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct)) =
121                        var.ty.lookup_intern(db)
122                    else {
123                        unreachable!("Expected a concrete struct type");
124                    };
125
126                    let members = db.concrete_struct_members(concrete_struct)?;
127
128                    let location = var.location;
129                    let var_ids = members
130                        .values()
131                        .map(|member| variables.new_var(VarRequest { ty: member.ty, location }))
132                        .collect_vec();
133
134                    stack.push((
135                        var_id,
136                        SpecializationArgBuildingState::BuildStruct(var_ids.clone()),
137                    ));
138
139                    for (var_id, c) in zip_eq(var_ids, consts) {
140                        stack.push((var_id, SpecializationArgBuildingState::Initial(c)));
141                    }
142                }
143            },
144            SpecializationArgBuildingState::BuildStruct(ids) => {
145                statements.push(Statement::StructConstruct(StatementStructConstruct {
146                    inputs: ids
147                        .iter()
148                        .map(|id| VarUsage { var_id: *id, location: variables[*id].location })
149                        .collect(),
150                    output: var_id,
151                }));
152            }
153        }
154    }
155
156    let outputs: Vec<VariableId> =
157        chain!(base.signature.extra_rets.iter().map(|ret| ret.ty()), [base.signature.return_type])
158            .map(|ty| variables.new_var(VarRequest { ty, location }))
159            .collect_vec();
160    let mut block_builder = BlocksBuilder::new();
161    let ret_usage =
162        outputs.iter().map(|var_id| VarUsage { var_id: *var_id, location }).collect_vec();
163    statements.push(Statement::Call(StatementCall {
164        function: specialized.base.function_id(db)?,
165        with_coupon: false,
166        inputs,
167        outputs,
168        location,
169    }));
170    block_builder.alloc(Block { statements, end: BlockEnd::Return(ret_usage, location) });
171    Ok(Lowered {
172        signature: specialized.signature(db)?,
173        variables: variables.variables,
174        blocks: block_builder.build().unwrap(),
175        parameters,
176        diagnostics: Default::default(),
177    })
178}
179
180/// Query implementation of [LoweringGroup::priv_should_specialize].
181pub fn priv_should_specialize(
182    db: &dyn LoweringGroup,
183    function_id: ids::ConcreteFunctionWithBodyId,
184) -> Maybe<bool> {
185    let ids::ConcreteFunctionWithBodyLongId::Specialized(SpecializedFunction { base, .. }) =
186        function_id.lookup_intern(db)
187    else {
188        panic!("Expected a specialized function");
189    };
190
191    // Breaks cycles.
192    // We cannot estimate the size of functions in a cycle, since the implicits computation requires
193    // the finalized lowering of all the functions in the cycle which requires us to know the
194    // answer of the current function.
195    if db.concrete_in_cycle(base, DependencyType::Call, LoweringStage::Monomorphized)? {
196        return Ok(false);
197    }
198
199    // The heuristic is that the size is 8/10*orig_size > specialized_size of the original size.
200    Ok(db.estimate_size(base)?.saturating_mul(8)
201        > db.estimate_size(function_id)?.saturating_mul(10))
202}