1use std::vec;
2
3use cairo_lang_debug::DebugWithDb;
4use cairo_lang_diagnostics::Maybe;
5use cairo_lang_proc_macros::HeapSize;
6use cairo_lang_semantic::helper::ModuleHelper;
7use cairo_lang_semantic::items::constant::ConstValueId;
8use cairo_lang_semantic::items::functions::GenericFunctionId;
9use cairo_lang_semantic::items::structure::StructSemantic;
10use cairo_lang_semantic::types::{TypeSizeInformation, TypesSemantic};
11use cairo_lang_semantic::{ConcreteTypeId, ConcreteVariant, GenericArgumentId, TypeId, TypeLongId};
12use cairo_lang_utils::extract_matches;
13use itertools::{Itertools, chain, zip_eq};
14use salsa::Database;
15
16use crate::blocks::BlocksBuilder;
17use crate::db::LoweringGroup;
18use crate::ids::{self, LocationId, SemanticFunctionIdEx, SpecializedFunction};
19use crate::lower::context::{VarRequest, VariableAllocator};
20use crate::objects::StatementEnumConstruct as StatementEnumConstructObj;
21use crate::{
22 Block, BlockEnd, Lowered, LoweringStage, Statement, StatementCall, StatementConst,
23 StatementSnapshot, StatementStructConstruct, VarUsage, VariableId,
24};
25
/// Describes one argument of a specialized function.
///
/// Every variant except [`NotSpecialized`] describes how the argument's value is built inside
/// the specialized function's body (see `specialized_function_lowered`). Variants may nest, so a
/// composite value can mix specialized and non-specialized parts.
///
/// [`NotSpecialized`]: SpecializationArg::NotSpecialized
#[derive(Clone, Debug, Hash, PartialEq, Eq, HeapSize)]
pub enum SpecializationArg<'db> {
    /// A constant value. `boxed` is forwarded to the emitted `StatementConst`, and the debug
    /// form of a boxed constant is suffixed with `.into_box()`.
    Const { value: ConstValueId<'db>, boxed: bool },
    /// A snapshot of the value described by the inner argument.
    Snapshot(Box<SpecializationArg<'db>>),
    /// An array whose element type is the `TypeId`, built by appending the listed values, in
    /// order, to a freshly created array.
    Array(TypeId<'db>, Vec<SpecializationArg<'db>>),
    /// A struct value: one argument per struct member, in member order.
    Struct(Vec<SpecializationArg<'db>>),
    /// An enum value of `variant` wrapping the value described by `payload`.
    Enum { variant: ConcreteVariant<'db>, payload: Box<SpecializationArg<'db>> },
    /// The argument (or nested component) is not specialized. It remains a parameter of the
    /// specialized function.
    NotSpecialized,
}
36
37impl<'a> DebugWithDb<'a> for SpecializationArg<'a> {
38 type Db = dyn Database;
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &'a dyn Database) -> std::fmt::Result {
40 match self {
41 SpecializationArg::Const { value, boxed } => {
42 write!(f, "{:?}", value.debug(db))?;
43 if *boxed {
44 write!(f, ".into_box()")?;
45 }
46 Ok(())
47 }
48 SpecializationArg::Snapshot(inner) => write!(f, "@{:?}", inner.debug(db)),
49 SpecializationArg::Struct(args) => {
50 write!(f, "{{")?;
51 let mut inner = args.iter().peekable();
52 while let Some(value) = inner.next() {
53 write!(f, " ")?;
54 value.fmt(f, db)?;
55
56 if inner.peek().is_some() {
57 write!(f, ",")?;
58 } else {
59 write!(f, " ")?;
60 }
61 }
62 write!(f, "}}")
63 }
64 SpecializationArg::Array(_ty, values) => {
65 write!(f, "array![")?;
66 let mut first = true;
67 for value in values {
68 if !first {
69 write!(f, ", ")?;
70 } else {
71 first = false;
72 }
73 write!(f, "{:?}", value.debug(db))?;
74 }
75 write!(f, "]")
76 }
77 SpecializationArg::Enum { variant, payload } => {
78 write!(f, "{:?}(", variant.debug(db))?;
79 payload.fmt(f, db)?;
80 write!(f, ")")
81 }
82 SpecializationArg::NotSpecialized => write!(f, "NotSpecialized"),
83 }
84 }
85}
86
/// A pending work item used while materializing a [`SpecializationArg`] into variables.
///
/// Items live on an explicit LIFO stack, each paired with the target `VariableId` it must define.
/// Composite arguments push their "finish" item first and the items for their components after.
/// As a result, the components are built (popped) before the statement that consumes them.
enum SpecializationArgBuildingState<'db, 'a> {
    /// Build the target variable from scratch according to the given argument.
    Initial(&'a SpecializationArg<'db>),
    /// Define the target variable as a snapshot of the given (already built) variable.
    TakeSnapshot(VariableId),
    /// Define the target variable as a struct built from the given member variables, in order.
    BuildStruct(Vec<VariableId>),
    /// Define the target variable as the result of appending `value` to the array `in_array`.
    PushBackArray { in_array: VariableId, value: VariableId },
    /// Define the target variable as the enum `variant` wrapping the `payload` variable.
    BuildEnum { variant: ConcreteVariant<'db>, payload: VariableId },
}
96
/// Builds the lowered body of a specialized function.
///
/// The result is a single block that:
/// 1. materializes every specialized argument (see [`SpecializationArg`]) into fresh variables,
///    keeping every `NotSpecialized` component (top-level or nested) as a parameter of the new
///    function, in traversal order;
/// 2. calls the base function (`specialized.base`) with one input per base parameter; and
/// 3. returns the call's outputs: the signature's `extra_rets` followed by the return value.
///
/// # Errors
/// Propagates failures from the queries it calls (`lowered_body`, `VariableAllocator::new`,
/// `type_size_info`, `concrete_struct_members`, `function_id`, `signature`).
///
/// # Panics
/// Panics when the parameter count differs from the argument count (`zip_eq`). It also panics
/// when a `Struct` argument is given for a non-struct type (`unreachable!`), or its member
/// count differs (`zip_eq`), or a zero-sized constant is marked `boxed` (`assert!`). A
/// `Snapshot` argument for a non-snapshot type goes through `extract_matches!`, which is
/// presumably also a panic — NOTE(review): confirm.
pub fn specialized_function_lowered<'db>(
    db: &'db dyn Database,
    specialized: SpecializedFunction<'db>,
) -> Maybe<Lowered<'db>> {
    let base = db.lowered_body(specialized.base, LoweringStage::Monomorphized)?;
    let base_semantic = specialized.base.base_semantic_function(db);

    // Extern functions used to build `SpecializationArg::Array` values.
    let array_module = ModuleHelper::core(db).submodule("array");
    let array_new_fn = GenericFunctionId::Extern(array_module.extern_function_id("array_new"));
    let array_append = GenericFunctionId::Extern(array_module.extern_function_id("array_append"));

    // Fresh variable arena for the new body. Base variables are re-allocated into it on demand.
    let mut variables =
        VariableAllocator::new(db, base_semantic.function_with_body_id(db), Default::default())?;
    let mut statements = vec![];
    // Parameters of the specialized function (the non-specialized components).
    let mut parameters = vec![];
    // Arguments passed to the base function call, one per base parameter.
    let mut inputs = vec![];
    // Work stack of (target variable, pending work item).
    let mut stack = vec![];

    // Default location for all synthesized usages and statements: the base function's location.
    let location = LocationId::from_stable_location(
        db,
        specialized.base.base_semantic_function(db).stable_location(db),
    );

    for (param, arg) in zip_eq(&base.parameters, specialized.args.iter()) {
        let var_id = variables.variables.alloc(base.variables[*param].clone());
        inputs.push(VarUsage { var_id, location });
        if SpecializationArg::NotSpecialized == *arg {
            parameters.push(var_id);
            continue;
        }
        stack.push((var_id, SpecializationArgBuildingState::Initial(arg)));
        // Depth-first construction. Items are popped LIFO, so components pushed last are
        // emitted first, before the statement that consumes them.
        while let Some((var_id, state)) = stack.pop() {
            match state {
                SpecializationArgBuildingState::Initial(c) => match c {
                    SpecializationArg::Const { value, boxed } => {
                        // A zero-sized type is built with an input-less struct construct
                        // rather than a `StatementConst`.
                        if db.type_size_info(variables[var_id].ty)?
                            == TypeSizeInformation::ZeroSized
                        {
                            assert!(
                                !boxed,
                                "Zero sized specialization arguments should only be part of \
                                 consts and therefore cannot be boxed"
                            );
                            statements.push(Statement::StructConstruct(StatementStructConstruct {
                                inputs: vec![],
                                output: var_id,
                            }));
                        } else {
                            statements.push(Statement::Const(StatementConst::new(
                                *value, var_id, *boxed,
                            )));
                        }
                    }
                    SpecializationArg::Snapshot(inner) => {
                        // Build the inner (non-snapshot) value first, then snapshot it into
                        // `var_id`.
                        let snap_ty = variables.variables[var_id].ty;
                        let denapped_ty = *extract_matches!(snap_ty.long(db), TypeLongId::Snapshot);
                        let desnapped_var =
                            variables.new_var(VarRequest { ty: denapped_ty, location });
                        stack.push((
                            var_id,
                            SpecializationArgBuildingState::TakeSnapshot(desnapped_var),
                        ));
                        stack.push((
                            desnapped_var,
                            SpecializationArgBuildingState::Initial(inner.as_ref()),
                        ));
                    }
                    SpecializationArg::Array(ty, values) => {
                        // Builds a chain: `array_new` -> append(values[0]) -> ... ->
                        // append(values[last]) == `var_id`. The values are visited in reverse,
                        // so after LIFO popping the appends run in the original order. Every
                        // intermediate array reuses `var_id`'s variable description.
                        let mut arr_var = var_id;
                        for value in values.iter().rev() {
                            let in_arr_var =
                                variables.variables.alloc(variables.variables[var_id].clone());
                            let value_var = variables.new_var(VarRequest { ty: *ty, location });
                            stack.push((
                                arr_var,
                                SpecializationArgBuildingState::PushBackArray {
                                    in_array: in_arr_var,
                                    value: value_var,
                                },
                            ));
                            stack.push((value_var, SpecializationArgBuildingState::Initial(value)));
                            arr_var = in_arr_var;
                        }
                        // `array_new` has no inputs, so it is emitted right away. This places it
                        // ahead of every statement produced for the elements.
                        statements.push(Statement::Call(StatementCall {
                            function: array_new_fn
                                .concretize(db, vec![GenericArgumentId::Type(*ty)])
                                .lowered(db),
                            inputs: vec![],
                            with_coupon: false,
                            outputs: vec![arr_var],
                            location: variables[var_id].location,
                            is_specialization_base_call: false,
                        }));
                    }
                    SpecializationArg::Struct(args) => {
                        let var = &variables[var_id];
                        let TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct)) =
                            var.ty.long(db)
                        else {
                            unreachable!("Expected a concrete struct type");
                        };

                        let members = db.concrete_struct_members(*concrete_struct)?;

                        // Member variables take the struct variable's location (shadows the
                        // function-level `location` in this arm).
                        let location = var.location;
                        let var_ids = members
                            .values()
                            .map(|member| variables.new_var(VarRequest { ty: member.ty, location }))
                            .collect_vec();

                        stack.push((
                            var_id,
                            SpecializationArgBuildingState::BuildStruct(var_ids.clone()),
                        ));

                        // Pushed in reverse so that members are built in declaration order.
                        for (var_id, arg) in zip_eq(var_ids.iter().rev(), args.iter().rev()) {
                            stack.push((*var_id, SpecializationArgBuildingState::Initial(arg)));
                        }
                    }
                    SpecializationArg::Enum { variant, payload } => {
                        let location = variables[var_id].location;
                        let payload_var =
                            variables.new_var(VarRequest { ty: variant.ty, location });
                        stack.push((
                            var_id,
                            SpecializationArgBuildingState::BuildEnum {
                                variant: *variant,
                                payload: payload_var,
                            },
                        ));
                        stack.push((
                            payload_var,
                            SpecializationArgBuildingState::Initial(payload.as_ref()),
                        ));
                    }
                    // A nested non-specialized component becomes a parameter of the new
                    // function, in traversal order.
                    SpecializationArg::NotSpecialized => {
                        parameters.push(var_id);
                    }
                },
                SpecializationArgBuildingState::TakeSnapshot(desnapped_var) => {
                    // The snapshot statement also yields the original value. That output goes
                    // to a fresh variable of the same type and is not used afterwards.
                    let ignored = variables.variables.alloc(variables[desnapped_var].clone());
                    statements.push(Statement::Snapshot(StatementSnapshot::new(
                        VarUsage { var_id: desnapped_var, location },
                        ignored,
                        var_id,
                    )));
                }
                SpecializationArgBuildingState::PushBackArray { in_array, value } => {
                    statements.push(Statement::Call(StatementCall {
                        function: array_append
                            .concretize(
                                db,
                                vec![GenericArgumentId::Type(variables.variables[value].ty)],
                            )
                            .lowered(db),
                        inputs: vec![
                            VarUsage { var_id: in_array, location },
                            VarUsage { var_id: value, location },
                        ],
                        with_coupon: false,
                        outputs: vec![var_id],
                        location,
                        is_specialization_base_call: false,
                    }));
                }
                SpecializationArgBuildingState::BuildStruct(ids) => {
                    statements.push(Statement::StructConstruct(StatementStructConstruct {
                        inputs: ids
                            .iter()
                            .map(|id| VarUsage { var_id: *id, location: variables[*id].location })
                            .collect(),
                        output: var_id,
                    }));
                }
                SpecializationArgBuildingState::BuildEnum { variant, payload } => {
                    statements.push(Statement::EnumConstruct(StatementEnumConstructObj {
                        variant,
                        input: VarUsage { var_id: payload, location: variables[payload].location },
                        output: var_id,
                    }));
                }
            }
        }
    }

    // Outputs mirror the base signature: extra returns first, then the return value.
    let outputs: Vec<VariableId> =
        chain!(base.signature.extra_rets.iter().map(|ret| ret.ty()), [base.signature.return_type])
            .map(|ty| variables.new_var(VarRequest { ty, location }))
            .collect_vec();
    let mut block_builder = BlocksBuilder::new();
    let ret_usage =
        outputs.iter().map(|var_id| VarUsage { var_id: *var_id, location }).collect_vec();
    // The single call to the base function, flagged as the specialization base call.
    statements.push(Statement::Call(StatementCall {
        function: specialized.base.function_id(db)?,
        with_coupon: false,
        inputs,
        outputs,
        location,
        is_specialization_base_call: true,
    }));
    block_builder.alloc(Block { statements, end: BlockEnd::Return(ret_usage, location) });
    Ok(Lowered {
        signature: specialized.signature(db)?,
        variables: variables.variables,
        blocks: block_builder.build().unwrap(),
        parameters,
        diagnostics: Default::default(),
    })
}
307
308#[salsa::tracked]
310pub fn priv_should_specialize<'db>(
311 db: &'db dyn Database,
312 function_id: ids::ConcreteFunctionWithBodyId<'db>,
313) -> Maybe<bool> {
314 let ids::ConcreteFunctionWithBodyLongId::Specialized(SpecializedFunction { base, .. }) =
315 function_id.long(db)
316 else {
317 panic!("Expected a specialized function");
318 };
319
320 Ok(db.estimate_size(*base)?.saturating_mul(8)
322 > db.estimate_size(function_id)?.saturating_mul(10))
323}