1use std::vec;
2
3use cairo_lang_debug::DebugWithDb;
4use cairo_lang_diagnostics::Maybe;
5use cairo_lang_semantic::helper::ModuleHelper;
6use cairo_lang_semantic::items::constant::ConstValueId;
7use cairo_lang_semantic::items::functions::GenericFunctionId;
8use cairo_lang_semantic::items::structure::StructSemantic;
9use cairo_lang_semantic::{ConcreteTypeId, ConcreteVariant, GenericArgumentId, TypeId, TypeLongId};
10use cairo_lang_utils::extract_matches;
11use itertools::{Itertools, chain, zip_eq};
12use salsa::Database;
13
14use crate::blocks::BlocksBuilder;
15use crate::db::LoweringGroup;
16use crate::ids::{self, LocationId, SemanticFunctionIdEx, SpecializedFunction};
17use crate::lower::context::{VarRequest, VariableAllocator};
18use crate::objects::StatementEnumConstruct as StatementEnumConstructObj;
19use crate::{
20 Block, BlockEnd, DependencyType, Lowered, LoweringStage, Statement, StatementCall,
21 StatementConst, StatementSnapshot, StatementStructConstruct, VarUsage, VariableId,
22};
23
/// Describes the value passed for one parameter of a specialized function.
///
/// `specialized_function_lowered` walks this tree and emits statements that rebuild the
/// described value inside the specialized function's body. `NotSpecialized` marks pieces that
/// instead remain parameters of the specialized function.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub enum SpecializationArg<'db> {
    /// A constant value. `boxed` is forwarded to `StatementConst::new` and is rendered as a
    /// `.into_box()` suffix in debug output.
    Const { value: ConstValueId<'db>, boxed: bool },
    /// A snapshot of the inner argument.
    Snapshot(Box<SpecializationArg<'db>>),
    /// An array with the given element type, built by appending the listed elements in order.
    Array(TypeId<'db>, Vec<SpecializationArg<'db>>),
    /// A struct value. Entries are matched positionally against the members returned by
    /// `concrete_struct_members` for the parameter's struct type.
    Struct(Vec<SpecializationArg<'db>>),
    /// An enum value of `variant` with the given payload.
    Enum { variant: ConcreteVariant<'db>, payload: Box<SpecializationArg<'db>> },
    /// Not specialized: the value is received as a parameter of the specialized function.
    NotSpecialized,
}
34
35impl<'a> DebugWithDb<'a> for SpecializationArg<'a> {
36 type Db = dyn Database;
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &'a dyn Database) -> std::fmt::Result {
38 match self {
39 SpecializationArg::Const { value, boxed } => {
40 write!(f, "{:?}", value.debug(db))?;
41 if *boxed {
42 write!(f, ".into_box()")?;
43 }
44 Ok(())
45 }
46 SpecializationArg::Snapshot(inner) => write!(f, "@{:?}", inner.debug(db)),
47 SpecializationArg::Struct(args) => {
48 write!(f, "{{")?;
49 let mut inner = args.iter().peekable();
50 while let Some(value) = inner.next() {
51 write!(f, " ")?;
52 value.fmt(f, db)?;
53
54 if inner.peek().is_some() {
55 write!(f, ",")?;
56 } else {
57 write!(f, " ")?;
58 }
59 }
60 write!(f, "}}")
61 }
62 SpecializationArg::Array(_ty, values) => {
63 write!(f, "array![")?;
64 let mut first = true;
65 for value in values {
66 if !first {
67 write!(f, ", ")?;
68 } else {
69 first = false;
70 }
71 write!(f, "{:?}", value.debug(db))?;
72 }
73 write!(f, "]")
74 }
75 SpecializationArg::Enum { variant, payload } => {
76 write!(f, "{:?}(", variant.debug(db))?;
77 payload.fmt(f, db)?;
78 write!(f, ")")
79 }
80 SpecializationArg::NotSpecialized => write!(f, "NotSpecialized"),
81 }
82 }
83}
84
/// A work item on the explicit stack `specialized_function_lowered` uses to turn a
/// [`SpecializationArg`] into lowering statements.
///
/// Each stack entry pairs one of these states with the variable the state's statement defines.
/// Constructing states are pushed before (i.e. below) the `Initial` states of their inputs, so
/// by the time one is popped, the variables it reads have already been built.
enum SpecializationArgBuildingState<'db, 'a> {
    /// Start building the given argument into the entry's variable.
    Initial(&'a SpecializationArg<'db>),
    /// Snapshot the given (already built) variable into the entry's variable.
    TakeSnapshot(VariableId),
    /// Construct a struct from the given (already built) member variables.
    BuildStruct(Vec<VariableId>),
    /// Call `array_append` on `in_array` with `value`, writing the entry's variable.
    PushBackArray { in_array: VariableId, value: VariableId },
    /// Construct `variant` from the (already built) `payload` variable.
    BuildEnum { variant: ConcreteVariant<'db>, payload: VariableId },
}
94
/// Builds the lowered body of a specialized function.
///
/// The generated body is a single block that:
/// 1. materializes every specialized argument (constants, snapshots, arrays, structs, enums) into
///    fresh variables with explicit statements,
/// 2. calls the base function (at the monomorphized stage) with one input per base parameter,
/// 3. returns the call's outputs: the base signature's extra returns followed by its return value.
///
/// Every `SpecializationArg::NotSpecialized` piece, whether top-level or nested inside a compound
/// argument, becomes a parameter of the specialized function, in depth-first order.
///
/// # Errors
/// Propagates failures from `lowered_body`, `VariableAllocator::new`, `concrete_struct_members`,
/// `function_id` and `signature`.
pub fn specialized_function_lowered<'db>(
    db: &'db dyn Database,
    specialized: SpecializedFunction<'db>,
) -> Maybe<Lowered<'db>> {
    let base = db.lowered_body(specialized.base, LoweringStage::Monomorphized)?;
    let base_semantic = specialized.base.base_semantic_function(db);

    // Core `array` externs used to rebuild `SpecializationArg::Array` arguments.
    let array_module = ModuleHelper::core(db).submodule("array");
    let array_new_fn = GenericFunctionId::Extern(array_module.extern_function_id("array_new"));
    let array_append = GenericFunctionId::Extern(array_module.extern_function_id("array_append"));

    let mut variables =
        VariableAllocator::new(db, base_semantic.function_with_body_id(db), Default::default())?;
    // Statements of the single generated block.
    let mut statements = vec![];
    // Parameters of the specialized function (the non-specialized pieces).
    let mut parameters = vec![];
    // Inputs of the call to the base function, one per base parameter.
    let mut inputs = vec![];
    // Explicit work stack of (target variable, building state), used instead of recursion over
    // nested arguments.
    let mut stack = vec![];

    let location = LocationId::from_stable_location(
        db,
        specialized.base.base_semantic_function(db).stable_location(db),
    );

    // `zip_eq` panics if the base has a different number of parameters than there are arguments.
    for (param, arg) in zip_eq(&base.parameters, specialized.args.iter()) {
        // A fresh variable, copied from the base parameter's variable, carries this argument.
        let var_id = variables.variables.alloc(base.variables[*param].clone());
        inputs.push(VarUsage { var_id, location });
        if SpecializationArg::NotSpecialized == *arg {
            parameters.push(var_id);
            continue;
        }
        stack.push((var_id, SpecializationArgBuildingState::Initial(arg)));
        // Note: `var_id` below shadows the parameter variable with the current stack target.
        while let Some((var_id, state)) = stack.pop() {
            match state {
                SpecializationArgBuildingState::Initial(c) => match c {
                    SpecializationArg::Const { value, boxed } => {
                        statements
                            .push(Statement::Const(StatementConst::new(*value, var_id, *boxed)));
                    }
                    SpecializationArg::Snapshot(inner) => {
                        // The target has a snapshot type: build the inner value in a variable of
                        // the underlying type, then snapshot it into the target.
                        let snap_ty = variables.variables[var_id].ty;
                        let denapped_ty = *extract_matches!(snap_ty.long(db), TypeLongId::Snapshot);
                        let desnapped_var =
                            variables.new_var(VarRequest { ty: denapped_ty, location });
                        stack.push((
                            var_id,
                            SpecializationArgBuildingState::TakeSnapshot(desnapped_var),
                        ));
                        stack.push((
                            desnapped_var,
                            SpecializationArgBuildingState::Initial(inner.as_ref()),
                        ));
                    }
                    SpecializationArg::Array(ty, values) => {
                        // Build a chain of appends. `array_new` defines the innermost (empty)
                        // array variable, each element is appended in order, and the last append
                        // writes the target. Elements are visited in reverse so their states pop
                        // in element order.
                        let mut arr_var = var_id;
                        for value in values.iter().rev() {
                            let in_arr_var =
                                variables.variables.alloc(variables.variables[var_id].clone());
                            let value_var = variables.new_var(VarRequest { ty: *ty, location });
                            stack.push((
                                arr_var,
                                SpecializationArgBuildingState::PushBackArray {
                                    in_array: in_arr_var,
                                    value: value_var,
                                },
                            ));
                            stack.push((value_var, SpecializationArgBuildingState::Initial(value)));
                            arr_var = in_arr_var;
                        }
                        // Emitted immediately, i.e. before any of the appends popped later.
                        statements.push(Statement::Call(StatementCall {
                            function: array_new_fn
                                .concretize(db, vec![GenericArgumentId::Type(*ty)])
                                .lowered(db),
                            inputs: vec![],
                            with_coupon: false,
                            outputs: vec![arr_var],
                            location: variables[var_id].location,
                        }));
                    }
                    SpecializationArg::Struct(args) => {
                        let var = &variables[var_id];
                        let TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct)) =
                            var.ty.long(db)
                        else {
                            unreachable!("Expected a concrete struct type");
                        };

                        let members = db.concrete_struct_members(*concrete_struct)?;

                        // One fresh variable per member, located at the struct variable.
                        let location = var.location;
                        let var_ids = members
                            .values()
                            .map(|member| variables.new_var(VarRequest { ty: member.ty, location }))
                            .collect_vec();

                        stack.push((
                            var_id,
                            SpecializationArgBuildingState::BuildStruct(var_ids.clone()),
                        ));

                        // Pushed in reverse so members are built in declaration order.
                        // `zip_eq` panics if the argument count differs from the member count.
                        for (var_id, arg) in zip_eq(var_ids.iter().rev(), args.iter().rev()) {
                            stack.push((*var_id, SpecializationArgBuildingState::Initial(arg)));
                        }
                    }
                    SpecializationArg::Enum { variant, payload } => {
                        let location = variables[var_id].location;
                        let payload_var =
                            variables.new_var(VarRequest { ty: variant.ty, location });
                        stack.push((
                            var_id,
                            SpecializationArgBuildingState::BuildEnum {
                                variant: *variant,
                                payload: payload_var,
                            },
                        ));
                        stack.push((
                            payload_var,
                            SpecializationArgBuildingState::Initial(payload.as_ref()),
                        ));
                    }
                    SpecializationArg::NotSpecialized => {
                        // A nested non-specialized piece becomes a parameter of the specialized
                        // function.
                        parameters.push(var_id);
                    }
                },
                SpecializationArgBuildingState::TakeSnapshot(desnapped_var) => {
                    // The snapshot statement also re-outputs the original value; that output is
                    // not used afterwards, so it goes to a throwaway variable.
                    let ignored = variables.variables.alloc(variables[desnapped_var].clone());
                    statements.push(Statement::Snapshot(StatementSnapshot::new(
                        VarUsage { var_id: desnapped_var, location },
                        ignored,
                        var_id,
                    )));
                }
                SpecializationArgBuildingState::PushBackArray { in_array, value } => {
                    // NOTE(review): uses the function-level `location`, whereas `array_new` above
                    // uses the target variable's location.
                    statements.push(Statement::Call(StatementCall {
                        function: array_append
                            .concretize(
                                db,
                                vec![GenericArgumentId::Type(variables.variables[value].ty)],
                            )
                            .lowered(db),
                        inputs: vec![
                            VarUsage { var_id: in_array, location },
                            VarUsage { var_id: value, location },
                        ],
                        with_coupon: false,
                        outputs: vec![var_id],
                        location,
                    }));
                }
                SpecializationArgBuildingState::BuildStruct(ids) => {
                    statements.push(Statement::StructConstruct(StatementStructConstruct {
                        inputs: ids
                            .iter()
                            .map(|id| VarUsage { var_id: *id, location: variables[*id].location })
                            .collect(),
                        output: var_id,
                    }));
                }
                SpecializationArgBuildingState::BuildEnum { variant, payload } => {
                    statements.push(Statement::EnumConstruct(StatementEnumConstructObj {
                        variant,
                        input: VarUsage { var_id: payload, location: variables[payload].location },
                        output: var_id,
                    }));
                }
            }
        }
    }

    // One output per extra return of the base signature, followed by its return value.
    let outputs: Vec<VariableId> =
        chain!(base.signature.extra_rets.iter().map(|ret| ret.ty()), [base.signature.return_type])
            .map(|ty| variables.new_var(VarRequest { ty, location }))
            .collect_vec();
    let mut block_builder = BlocksBuilder::new();
    let ret_usage =
        outputs.iter().map(|var_id| VarUsage { var_id: *var_id, location }).collect_vec();
    statements.push(Statement::Call(StatementCall {
        function: specialized.base.function_id(db)?,
        with_coupon: false,
        inputs,
        outputs,
        location,
    }));
    block_builder.alloc(Block { statements, end: BlockEnd::Return(ret_usage, location) });
    Ok(Lowered {
        signature: specialized.signature(db)?,
        variables: variables.variables,
        // Exactly one block was allocated above; presumably `build()` only yields `None` when no
        // block exists — TODO confirm.
        blocks: block_builder.build().unwrap(),
        parameters,
        diagnostics: Default::default(),
    })
}
287
288#[salsa::tracked]
290pub fn priv_should_specialize<'db>(
291 db: &'db dyn Database,
292 function_id: ids::ConcreteFunctionWithBodyId<'db>,
293) -> Maybe<bool> {
294 let ids::ConcreteFunctionWithBodyLongId::Specialized(SpecializedFunction { base, .. }) =
295 function_id.long(db)
296 else {
297 panic!("Expected a specialized function");
298 };
299
300 if db.concrete_in_cycle(*base, DependencyType::Call, LoweringStage::Monomorphized)? {
305 return Ok(false);
306 }
307
308 Ok(db.estimate_size(*base)?.saturating_mul(8)
310 > db.estimate_size(function_id)?.saturating_mul(10))
311}