1use std::vec;
2
3use cairo_lang_debug::DebugWithDb;
4use cairo_lang_diagnostics::Maybe;
5use cairo_lang_proc_macros::HeapSize;
6use cairo_lang_semantic::helper::ModuleHelper;
7use cairo_lang_semantic::items::constant::ConstValueId;
8use cairo_lang_semantic::items::functions::GenericFunctionId;
9use cairo_lang_semantic::items::structure::StructSemantic;
10use cairo_lang_semantic::types::{TypeSizeInformation, TypesSemantic};
11use cairo_lang_semantic::{ConcreteTypeId, ConcreteVariant, GenericArgumentId, TypeId, TypeLongId};
12use cairo_lang_utils::extract_matches;
13use itertools::{Itertools, chain, zip_eq};
14use salsa::Database;
15
16use crate::blocks::BlocksBuilder;
17use crate::db::LoweringGroup;
18use crate::ids::{self, LocationId, SemanticFunctionIdEx, SpecializedFunction};
19use crate::lower::context::{VarRequest, VariableAllocator};
20use crate::objects::StatementEnumConstruct as StatementEnumConstructObj;
21use crate::{
22 Block, BlockEnd, Lowered, LoweringStage, Statement, StatementCall, StatementConst,
23 StatementSnapshot, StatementStructConstruct, VarUsage, VariableId,
24};
25
/// A compile-time argument of a specialized function: for one parameter of the base function,
/// either the value it is fixed to, or `NotSpecialized` when it stays a runtime parameter.
#[derive(Clone, Debug, Hash, PartialEq, Eq, HeapSize)]
pub enum SpecializationArg<'db> {
    /// A constant value. `boxed` is forwarded to the lowered const statement, and the debug
    /// printer renders it as a `.into_box()` suffix.
    Const {
        value: ConstValueId<'db>,
        boxed: bool,
    },
    /// A snapshot (`@`) of the inner argument.
    Snapshot(Box<SpecializationArg<'db>>),
    /// An array with the given element type, holding the given elements in order.
    Array(TypeId<'db>, Vec<SpecializationArg<'db>>),
    /// A struct, tuple, or fixed-size array value, given member by member.
    Struct(Vec<SpecializationArg<'db>>),
    /// An enum value: the given variant wrapping the given payload.
    Enum {
        variant: ConcreteVariant<'db>,
        payload: Box<SpecializationArg<'db>>,
    },
    /// Not specialized: the value remains a runtime parameter of the specialized function.
    NotSpecialized,
}
43
44impl<'a> DebugWithDb<'a> for SpecializationArg<'a> {
45 type Db = dyn Database;
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &'a dyn Database) -> std::fmt::Result {
47 match self {
48 SpecializationArg::Const { value, boxed } => {
49 write!(f, "{:?}", value.debug(db))?;
50 if *boxed {
51 write!(f, ".into_box()")?;
52 }
53 Ok(())
54 }
55 SpecializationArg::Snapshot(inner) => write!(f, "@{:?}", inner.debug(db)),
56 SpecializationArg::Struct(args) => {
57 write!(f, "{{")?;
58 let mut inner = args.iter().peekable();
59 while let Some(value) = inner.next() {
60 write!(f, " ")?;
61 value.fmt(f, db)?;
62
63 if inner.peek().is_some() {
64 write!(f, ",")?;
65 } else {
66 write!(f, " ")?;
67 }
68 }
69 write!(f, "}}")
70 }
71 SpecializationArg::Array(_ty, values) => {
72 write!(f, "array![")?;
73 let mut first = true;
74 for value in values {
75 if !first {
76 write!(f, ", ")?;
77 } else {
78 first = false;
79 }
80 write!(f, "{:?}", value.debug(db))?;
81 }
82 write!(f, "]")
83 }
84 SpecializationArg::Enum { variant, payload } => {
85 write!(f, "{:?}(", variant.debug(db))?;
86 payload.fmt(f, db)?;
87 write!(f, ")")
88 }
89 SpecializationArg::NotSpecialized => write!(f, "NotSpecialized"),
90 }
91 }
92}
93
/// A work item used while materializing a `SpecializationArg` into lowered statements.
/// Items sit on an explicit LIFO stack, each paired with the variable it produces.
enum SpecializationArgBuildingState<'db, 'a> {
    /// The argument has not been processed yet: either emit its statement directly or schedule
    /// its sub-arguments together with a follow-up item.
    Initial(&'a SpecializationArg<'db>),
    /// Take a snapshot of the given (already built) variable.
    TakeSnapshot(VariableId),
    /// Construct a struct/tuple/fixed-size array from the given (already built) member variables.
    BuildStruct(Vec<VariableId>),
    /// Append `value` to the array held in `in_array`.
    PushBackArray { in_array: VariableId, value: VariableId },
    /// Construct `variant` from the (already built) `payload` variable.
    BuildEnum { variant: ConcreteVariant<'db>, payload: VariableId },
}
103
/// Builds the lowered body of a specialized function.
///
/// The body is a single block that:
/// 1. Materializes every specialized argument (constants, snapshots, arrays, structs, enums)
///    with statements; only the `NotSpecialized` parts become parameters of the new function.
/// 2. Calls the base function (marked with `is_specialization_base_call`) with one input per
///    base parameter.
/// 3. Returns the call's outputs: the extra returns followed by the return value.
pub fn specialized_function_lowered<'db>(
    db: &'db dyn Database,
    specialized: SpecializedFunction<'db>,
) -> Maybe<Lowered<'db>> {
    let base = db.lowered_body(specialized.base, LoweringStage::Monomorphized)?;
    let base_semantic = specialized.base.base_semantic_function(db);

    // Extern functions used to materialize `SpecializationArg::Array` values.
    let array_module = ModuleHelper::core(db).submodule("array");
    let array_new_fn = GenericFunctionId::Extern(array_module.extern_function_id("array_new"));
    let array_append = GenericFunctionId::Extern(array_module.extern_function_id("array_append"));

    let mut variables =
        VariableAllocator::new(db, base_semantic.function_with_body_id(db), Default::default())?;
    // Statements of the single block of the generated body.
    let mut statements = vec![];
    // Variables that remain runtime parameters of the specialized function.
    let mut parameters = vec![];
    // Arguments of the base function call, one per base parameter.
    let mut inputs = vec![];
    // Work stack of (output variable, building state). It is LIFO, so a follow-up state must be
    // pushed *before* the states that build its inputs.
    let mut stack = vec![];

    let location = LocationId::from_stable_location(
        db,
        specialized.base.base_semantic_function(db).stable_location(db),
    );

    for (param, arg) in zip_eq(&base.parameters, specialized.args.iter()) {
        // A fresh variable with the same definition as the base parameter; this is the value
        // passed to the base call.
        let var_id = variables.variables.alloc(base.variables[*param].clone());
        inputs.push(VarUsage { var_id, location });
        if SpecializationArg::NotSpecialized == *arg {
            parameters.push(var_id);
            continue;
        }
        stack.push((var_id, SpecializationArgBuildingState::Initial(arg)));
        while let Some((var_id, state)) = stack.pop() {
            match state {
                SpecializationArgBuildingState::Initial(c) => match c {
                    SpecializationArg::Const { value, boxed } => {
                        // Zero-sized values are produced by an empty struct construction
                        // instead of a const statement.
                        if db.type_size_info(variables[var_id].ty)?
                            == TypeSizeInformation::ZeroSized
                        {
                            assert!(
                                !boxed,
                                "Zero sized specialization arguments should only be part of \
                                 consts and therefore cannot be boxed"
                            );
                            statements.push(Statement::StructConstruct(StatementStructConstruct {
                                inputs: vec![],
                                output: var_id,
                            }));
                        } else {
                            statements.push(Statement::Const(StatementConst::new(
                                *value, var_id, *boxed,
                            )));
                        }
                    }
                    SpecializationArg::Snapshot(inner) => {
                        // Build the inner value into a fresh variable of the de-snapshotted
                        // type, then snapshot it into `var_id`.
                        let snap_ty = variables.variables[var_id].ty;
                        let denapped_ty = *extract_matches!(snap_ty.long(db), TypeLongId::Snapshot);
                        let desnapped_var =
                            variables.new_var(VarRequest { ty: denapped_ty, location });
                        stack.push((
                            var_id,
                            SpecializationArgBuildingState::TakeSnapshot(desnapped_var),
                        ));
                        stack.push((
                            desnapped_var,
                            SpecializationArgBuildingState::Initial(inner.as_ref()),
                        ));
                    }
                    SpecializationArg::Array(ty, values) => {
                        // Builds the chain
                        //   array_new -> append(values[0]) -> ... -> append(values[n-1]),
                        // where the last append writes into `var_id`. Walking the values in
                        // reverse lets each step allocate the array variable its append reads.
                        let mut arr_var = var_id;
                        for value in values.iter().rev() {
                            let in_arr_var =
                                variables.variables.alloc(variables.variables[var_id].clone());
                            let value_var = variables.new_var(VarRequest { ty: *ty, location });
                            stack.push((
                                arr_var,
                                SpecializationArgBuildingState::PushBackArray {
                                    in_array: in_arr_var,
                                    value: value_var,
                                },
                            ));
                            stack.push((value_var, SpecializationArgBuildingState::Initial(value)));
                            arr_var = in_arr_var;
                        }
                        // `array_new` has no inputs, so it is emitted right away. Its output is
                        // the head of the chain, or `var_id` itself when `values` is empty.
                        statements.push(Statement::Call(StatementCall {
                            function: array_new_fn
                                .concretize(db, vec![GenericArgumentId::Type(*ty)])
                                .lowered(db),
                            inputs: vec![],
                            with_coupon: false,
                            outputs: vec![arr_var],
                            location: variables[var_id].location,
                            is_specialization_base_call: false,
                        }));
                    }
                    SpecializationArg::Struct(args) => {
                        let var_ty = variables[var_id].ty;
                        let location = variables[var_id].location;

                        let mut var_for_ty = |ty| variables.new_var(VarRequest { ty, location });
                        // One fresh variable per member (struct), element (tuple), or
                        // repeated element (fixed-size array).
                        let var_ids = match var_ty.long(db) {
                            TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct)) => {
                                let members = db.concrete_struct_members(*concrete_struct)?;
                                members.values().map(|member| var_for_ty(member.ty)).collect_vec()
                            }
                            TypeLongId::Tuple(element_types) => {
                                element_types.iter().cloned().map(var_for_ty).collect_vec()
                            }
                            TypeLongId::FixedSizeArray { type_id, .. } => {
                                itertools::repeat_n(*type_id, args.len())
                                    .map(var_for_ty)
                                    .collect_vec()
                            }
                            _ => unreachable!("Expected a struct, tuple, or fixed-size array type"),
                        };

                        stack.push((
                            var_id,
                            SpecializationArgBuildingState::BuildStruct(var_ids.clone()),
                        ));

                        // Pushed in reverse so the members are popped (and built) first to last.
                        for (var_id, arg) in zip_eq(var_ids.iter().rev(), args.iter().rev()) {
                            stack.push((*var_id, SpecializationArgBuildingState::Initial(arg)));
                        }
                    }
                    SpecializationArg::Enum { variant, payload } => {
                        let location = variables[var_id].location;
                        let payload_var =
                            variables.new_var(VarRequest { ty: variant.ty, location });
                        stack.push((
                            var_id,
                            SpecializationArgBuildingState::BuildEnum {
                                variant: *variant,
                                payload: payload_var,
                            },
                        ));
                        stack.push((
                            payload_var,
                            SpecializationArgBuildingState::Initial(payload.as_ref()),
                        ));
                    }
                    // A nested non-specialized part becomes a parameter of the new function.
                    // NOTE(review): parameters are collected in traversal order; presumably
                    // this matches `specialized.signature(db)` — confirm.
                    SpecializationArg::NotSpecialized => {
                        parameters.push(var_id);
                    }
                },
                // Follow-up states: their inputs were pushed later and are therefore already
                // built by the time these are popped.
                SpecializationArgBuildingState::TakeSnapshot(desnapped_var) => {
                    // The snapshot statement also outputs the original value, which is unused.
                    let ignored = variables.variables.alloc(variables[desnapped_var].clone());
                    statements.push(Statement::Snapshot(StatementSnapshot::new(
                        VarUsage { var_id: desnapped_var, location },
                        ignored,
                        var_id,
                    )));
                }
                SpecializationArgBuildingState::PushBackArray { in_array, value } => {
                    statements.push(Statement::Call(StatementCall {
                        function: array_append
                            .concretize(
                                db,
                                vec![GenericArgumentId::Type(variables.variables[value].ty)],
                            )
                            .lowered(db),
                        inputs: vec![
                            VarUsage { var_id: in_array, location },
                            VarUsage { var_id: value, location },
                        ],
                        with_coupon: false,
                        outputs: vec![var_id],
                        location,
                        is_specialization_base_call: false,
                    }));
                }
                SpecializationArgBuildingState::BuildStruct(ids) => {
                    statements.push(Statement::StructConstruct(StatementStructConstruct {
                        inputs: ids
                            .iter()
                            .map(|id| VarUsage { var_id: *id, location: variables[*id].location })
                            .collect(),
                        output: var_id,
                    }));
                }
                SpecializationArgBuildingState::BuildEnum { variant, payload } => {
                    statements.push(Statement::EnumConstruct(StatementEnumConstructObj {
                        variant,
                        input: VarUsage { var_id: payload, location: variables[payload].location },
                        output: var_id,
                    }));
                }
            }
        }
    }

    // Outputs of the base call (extra returns, then the return value), returned unchanged.
    let outputs: Vec<VariableId> =
        chain!(base.signature.extra_rets.iter().map(|ret| ret.ty()), [base.signature.return_type])
            .map(|ty| variables.new_var(VarRequest { ty, location }))
            .collect_vec();
    let mut block_builder = BlocksBuilder::new();
    let ret_usage =
        outputs.iter().map(|var_id| VarUsage { var_id: *var_id, location }).collect_vec();
    statements.push(Statement::Call(StatementCall {
        function: specialized.base.function_id(db)?,
        with_coupon: false,
        inputs,
        outputs,
        location,
        is_specialization_base_call: true,
    }));
    block_builder.alloc(Block { statements, end: BlockEnd::Return(ret_usage, location) });
    Ok(Lowered {
        signature: specialized.signature(db)?,
        variables: variables.variables,
        blocks: block_builder.build().unwrap(),
        parameters,
        diagnostics: Default::default(),
    })
}
320
321#[salsa::tracked]
323pub fn priv_should_specialize<'db>(
324 db: &'db dyn Database,
325 function_id: ids::ConcreteFunctionWithBodyId<'db>,
326) -> Maybe<bool> {
327 let ids::ConcreteFunctionWithBodyLongId::Specialized(SpecializedFunction { base, .. }) =
328 function_id.long(db)
329 else {
330 panic!("Expected a specialized function");
331 };
332
333 Ok(db.estimate_size(*base)?.saturating_mul(8)
335 > db.estimate_size(function_id)?.saturating_mul(10))
336}