1use cairo_lang_debug::DebugWithDb;
2use cairo_lang_diagnostics::Maybe;
3use cairo_lang_proc_macros::HeapSize;
4use cairo_lang_semantic::helper::ModuleHelper;
5use cairo_lang_semantic::items::constant::ConstValueId;
6use cairo_lang_semantic::items::functions::GenericFunctionId;
7use cairo_lang_semantic::items::structure::StructSemantic;
8use cairo_lang_semantic::types::{TypeSizeInformation, TypesSemantic};
9use cairo_lang_semantic::{ConcreteTypeId, ConcreteVariant, GenericArgumentId, TypeId, TypeLongId};
10use cairo_lang_utils::extract_matches;
11use itertools::{Itertools, chain, zip_eq};
12use salsa::Database;
13
14use crate::blocks::BlocksBuilder;
15use crate::db::LoweringGroup;
16use crate::ids::{self, LocationId, SemanticFunctionIdEx, SpecializedFunctionId};
17use crate::lower::context::{VarRequest, VariableAllocator};
18use crate::objects::StatementEnumConstruct as StatementEnumConstructObj;
19use crate::{
20 Block, BlockEnd, Lowered, LoweringStage, Statement, StatementCall, StatementConst,
21 StatementSnapshot, StatementStructConstruct, VarUsage, VariableId,
22};
23
/// Describes how a single argument of a specialized function is specialized.
///
/// `specialized_function_lowered` materializes each variant into a variable of the specialized
/// function's body before calling the base function.
#[derive(Clone, Debug, Hash, PartialEq, Eq, salsa::Update, HeapSize)]
pub enum SpecializationArg<'db> {
    /// A constant value. `boxed` is forwarded to the generated `StatementConst`, and is rendered
    /// as a `.into_box()` suffix in debug output.
    Const {
        value: ConstValueId<'db>,
        boxed: bool,
    },
    /// A snapshot of another specialized value.
    Snapshot(Box<SpecializationArg<'db>>),
    /// An array with the given element type, whose elements are the given specialized values
    /// (in order).
    Array(TypeId<'db>, Vec<SpecializationArg<'db>>),
    /// The members, in order, of a struct-like value: a struct, a tuple, or a fixed-size array.
    Struct(Vec<SpecializationArg<'db>>),
    /// An enum value: the concrete variant together with its specialized payload.
    Enum {
        variant: ConcreteVariant<'db>,
        payload: Box<SpecializationArg<'db>>,
    },
    /// The value is not specialized; it is received as a parameter of the specialized function.
    NotSpecialized,
}
41
42impl<'a> DebugWithDb<'a> for SpecializationArg<'a> {
43 type Db = dyn Database;
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &'a dyn Database) -> std::fmt::Result {
45 match self {
46 SpecializationArg::Const { value, boxed } => {
47 write!(f, "{:?}", value.debug(db))?;
48 if *boxed {
49 write!(f, ".into_box()")?;
50 }
51 Ok(())
52 }
53 SpecializationArg::Snapshot(inner) => write!(f, "@{:?}", inner.debug(db)),
54 SpecializationArg::Struct(args) => {
55 write!(f, "{{")?;
56 let mut inner = args.iter().peekable();
57 while let Some(value) = inner.next() {
58 write!(f, " ")?;
59 value.fmt(f, db)?;
60
61 if inner.peek().is_some() {
62 write!(f, ",")?;
63 } else {
64 write!(f, " ")?;
65 }
66 }
67 write!(f, "}}")
68 }
69 SpecializationArg::Array(_ty, values) => {
70 write!(f, "array![")?;
71 let mut first = true;
72 for value in values {
73 if !first {
74 write!(f, ", ")?;
75 } else {
76 first = false;
77 }
78 write!(f, "{:?}", value.debug(db))?;
79 }
80 write!(f, "]")
81 }
82 SpecializationArg::Enum { variant, payload } => {
83 write!(f, "{:?}(", variant.debug(db))?;
84 payload.fmt(f, db)?;
85 write!(f, ")")
86 }
87 SpecializationArg::NotSpecialized => write!(f, "NotSpecialized"),
88 }
89 }
90}
91
/// Work item used by `specialized_function_lowered` to materialize a `SpecializationArg`.
///
/// Items live on a LIFO stack, each paired with the variable it produces. A state that combines
/// already-built values is pushed *before* the states that build those values, so the inputs
/// are materialized first.
enum SpecializationArgBuildingState<'db, 'a> {
    /// Start materializing the given argument into the paired variable.
    Initial(&'a SpecializationArg<'db>),
    /// Snapshot the given (already built) variable into the paired variable.
    TakeSnapshot(VariableId),
    /// Construct the paired variable from the given (already built) member variables.
    BuildStruct(Vec<VariableId>),
    /// Append `value` to `in_array`, writing the resulting array into the paired variable.
    PushBackArray { in_array: VariableId, value: VariableId },
    /// Construct the paired variable as `variant` wrapping the (already built) `payload`.
    BuildEnum { variant: ConcreteVariant<'db>, payload: VariableId },
}
101
102pub fn specialized_function_lowered<'db>(
104 db: &'db dyn Database,
105 specialized: SpecializedFunctionId<'db>,
106) -> Maybe<Lowered<'db>> {
107 let specialized = specialized.long(db);
108 let base = db.lowered_body(specialized.base, LoweringStage::Monomorphized)?;
109 let base_semantic = specialized.base.base_semantic_function(db);
110
111 let array_module = ModuleHelper::core(db).submodule("array");
112 let array_new_fn = GenericFunctionId::Extern(array_module.extern_function_id("array_new"));
113 let array_append = GenericFunctionId::Extern(array_module.extern_function_id("array_append"));
114
115 let mut variables =
116 VariableAllocator::new(db, base_semantic.function_with_body_id(db), Default::default())?;
117 let mut statements = vec![];
118 let mut parameters = vec![];
119 let mut inputs = vec![];
120 let mut stack = vec![];
121
122 let location = LocationId::from_stable_location(
123 db,
124 specialized.base.base_semantic_function(db).stable_location(db),
125 );
126
127 for (param, arg) in zip_eq(&base.parameters, specialized.args.iter()) {
128 let var_id = variables.variables.alloc(base.variables[*param].clone());
129 inputs.push(VarUsage { var_id, location });
130 if SpecializationArg::NotSpecialized == *arg {
131 parameters.push(var_id);
132 continue;
133 }
134 stack.push((var_id, SpecializationArgBuildingState::Initial(arg)));
135 while let Some((var_id, state)) = stack.pop() {
136 match state {
137 SpecializationArgBuildingState::Initial(c) => match c {
138 SpecializationArg::Const { value, boxed } => {
139 if db.type_size_info(variables[var_id].ty)?
140 == TypeSizeInformation::ZeroSized
141 {
142 assert!(
143 !boxed,
144 "Zero sized specialization arguments should only be part of \
145 consts and therefore cannot be boxed"
146 );
147 statements.push(Statement::StructConstruct(StatementStructConstruct {
148 inputs: vec![],
149 output: var_id,
150 }));
151 } else {
152 statements.push(Statement::Const(StatementConst::new(
153 *value, var_id, *boxed,
154 )));
155 }
156 }
157 SpecializationArg::Snapshot(inner) => {
158 let snap_ty = variables.variables[var_id].ty;
159 let denapped_ty = *extract_matches!(snap_ty.long(db), TypeLongId::Snapshot);
160 let desnapped_var =
161 variables.new_var(VarRequest { ty: denapped_ty, location });
162 stack.push((
163 var_id,
164 SpecializationArgBuildingState::TakeSnapshot(desnapped_var),
165 ));
166 stack.push((
167 desnapped_var,
168 SpecializationArgBuildingState::Initial(inner.as_ref()),
169 ));
170 }
171 SpecializationArg::Array(ty, values) => {
172 let mut arr_var = var_id;
173 for value in values.iter().rev() {
174 let in_arr_var =
175 variables.variables.alloc(variables.variables[var_id].clone());
176 let value_var = variables.new_var(VarRequest { ty: *ty, location });
177 stack.push((
178 arr_var,
179 SpecializationArgBuildingState::PushBackArray {
180 in_array: in_arr_var,
181 value: value_var,
182 },
183 ));
184 stack.push((value_var, SpecializationArgBuildingState::Initial(value)));
185 arr_var = in_arr_var;
186 }
187 statements.push(Statement::Call(StatementCall {
188 function: array_new_fn
189 .concretize(db, vec![GenericArgumentId::Type(*ty)])
190 .lowered(db),
191 inputs: vec![],
192 with_coupon: false,
193 outputs: vec![arr_var],
194 location: variables[var_id].location,
195 is_specialization_base_call: false,
196 }));
197 }
198 SpecializationArg::Struct(args) => {
199 let var_ty = variables[var_id].ty;
200 let location = variables[var_id].location;
201
202 let mut var_for_ty = |ty| variables.new_var(VarRequest { ty, location });
204 let var_ids = match var_ty.long(db) {
205 TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct)) => {
206 let members = db.concrete_struct_members(*concrete_struct)?;
207 members.values().map(|member| var_for_ty(member.ty)).collect_vec()
208 }
209 TypeLongId::Tuple(element_types) => {
210 element_types.iter().cloned().map(var_for_ty).collect_vec()
211 }
212 TypeLongId::FixedSizeArray { type_id, .. } => {
213 itertools::repeat_n(*type_id, args.len())
214 .map(var_for_ty)
215 .collect_vec()
216 }
217 _ => unreachable!("Expected a struct, tuple, or fixed-size array type"),
218 };
219
220 stack.push((
221 var_id,
222 SpecializationArgBuildingState::BuildStruct(var_ids.clone()),
223 ));
224
225 for (var_id, arg) in zip_eq(var_ids.iter().rev(), args.iter().rev()) {
226 stack.push((*var_id, SpecializationArgBuildingState::Initial(arg)));
227 }
228 }
229 SpecializationArg::Enum { variant, payload } => {
230 let location = variables[var_id].location;
231 let payload_var =
232 variables.new_var(VarRequest { ty: variant.ty, location });
233 stack.push((
234 var_id,
235 SpecializationArgBuildingState::BuildEnum {
236 variant: *variant,
237 payload: payload_var,
238 },
239 ));
240 stack.push((
241 payload_var,
242 SpecializationArgBuildingState::Initial(payload.as_ref()),
243 ));
244 }
245 SpecializationArg::NotSpecialized => {
246 parameters.push(var_id);
247 }
248 },
249 SpecializationArgBuildingState::TakeSnapshot(desnapped_var) => {
250 let ignored = variables.variables.alloc(variables[desnapped_var].clone());
251 statements.push(Statement::Snapshot(StatementSnapshot::new(
252 VarUsage { var_id: desnapped_var, location },
253 ignored,
254 var_id,
255 )));
256 }
257 SpecializationArgBuildingState::PushBackArray { in_array, value } => {
258 statements.push(Statement::Call(StatementCall {
259 function: array_append
260 .concretize(
261 db,
262 vec![GenericArgumentId::Type(variables.variables[value].ty)],
263 )
264 .lowered(db),
265 inputs: vec![
266 VarUsage { var_id: in_array, location },
267 VarUsage { var_id: value, location },
268 ],
269 with_coupon: false,
270 outputs: vec![var_id],
271 location,
272 is_specialization_base_call: false,
273 }));
274 }
275 SpecializationArgBuildingState::BuildStruct(ids) => {
276 statements.push(Statement::StructConstruct(StatementStructConstruct {
277 inputs: ids
278 .iter()
279 .map(|id| VarUsage { var_id: *id, location: variables[*id].location })
280 .collect(),
281 output: var_id,
282 }));
283 }
284 SpecializationArgBuildingState::BuildEnum { variant, payload } => {
285 statements.push(Statement::EnumConstruct(StatementEnumConstructObj {
286 variant,
287 input: VarUsage { var_id: payload, location: variables[payload].location },
288 output: var_id,
289 }));
290 }
291 }
292 }
293 }
294
295 let outputs: Vec<VariableId> =
296 chain!(base.signature.extra_rets.iter().map(|ret| ret.ty()), [base.signature.return_type])
297 .map(|ty| variables.new_var(VarRequest { ty, location }))
298 .collect_vec();
299 let mut block_builder = BlocksBuilder::new();
300 let ret_usage =
301 outputs.iter().map(|var_id| VarUsage { var_id: *var_id, location }).collect_vec();
302 statements.push(Statement::Call(StatementCall {
303 function: specialized.base.function_id(db)?,
304 with_coupon: false,
305 inputs,
306 outputs,
307 location,
308 is_specialization_base_call: true,
309 }));
310 block_builder.alloc(Block { statements, end: BlockEnd::Return(ret_usage, location) });
311 Ok(Lowered {
312 signature: specialized.signature(db)?,
313 variables: variables.variables,
314 blocks: block_builder.build().unwrap(),
315 parameters,
316 diagnostics: Default::default(),
317 })
318}
319
320#[salsa::tracked]
322pub fn priv_should_specialize<'db>(
323 db: &'db dyn Database,
324 function_id: ids::ConcreteFunctionWithBodyId<'db>,
325) -> Maybe<bool> {
326 let ids::ConcreteFunctionWithBodyLongId::Specialized(specialized) = function_id.long(db) else {
327 panic!("Expected a specialized function");
328 };
329
330 Ok(db.estimate_size(specialized.long(db).base)?.saturating_mul(8)
332 > db.estimate_size(function_id)?.saturating_mul(10))
333}