1use std::vec;
2
3use cairo_lang_debug::DebugWithDb;
4use cairo_lang_diagnostics::Maybe;
5use cairo_lang_semantic::helper::ModuleHelper;
6use cairo_lang_semantic::items::constant::ConstValueId;
7use cairo_lang_semantic::items::functions::GenericFunctionId;
8use cairo_lang_semantic::items::structure::StructSemantic;
9use cairo_lang_semantic::{ConcreteTypeId, GenericArgumentId, TypeId, TypeLongId};
10use cairo_lang_utils::extract_matches;
11use itertools::{Itertools, chain, zip_eq};
12use salsa::Database;
13
14use crate::blocks::BlocksBuilder;
15use crate::db::LoweringGroup;
16use crate::ids::{self, LocationId, SemanticFunctionIdEx, SpecializedFunction};
17use crate::lower::context::{VarRequest, VariableAllocator};
18use crate::{
19 Block, BlockEnd, DependencyType, Lowered, LoweringStage, Statement, StatementCall,
20 StatementConst, StatementSnapshot, StatementStructConstruct, VarUsage, VariableId,
21};
22
/// A compile-time-known argument that a parameter of a specialized function is fixed to.
///
/// When lowering a specialized function, each such argument is materialized into a variable
/// inside the function body instead of being received as a parameter.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub enum SpecializationArg<'db> {
    /// A constant value. `boxed` is forwarded to `StatementConst::new` and is rendered as a
    /// `.into_box()` suffix when debug-printed (presumably a boxed constant — confirm against
    /// `StatementConst`).
    Const { value: ConstValueId<'db>, boxed: bool },
    /// A snapshot of the inner argument: the inner value is built first and then snapshotted.
    Snapshot(Box<SpecializationArg<'db>>),
    /// An array with the given element type, holding the given elements in order.
    Array(TypeId<'db>, Vec<SpecializationArg<'db>>),
    /// A struct whose members, in the order returned by `concrete_struct_members`, are the
    /// given arguments.
    Struct(Vec<SpecializationArg<'db>>),
}
31
32impl<'a> DebugWithDb<'a> for SpecializationArg<'a> {
33 type Db = dyn Database;
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &'a dyn Database) -> std::fmt::Result {
35 match self {
36 SpecializationArg::Const { value, boxed } => {
37 write!(f, "{:?}", value.debug(db))?;
38 if *boxed {
39 write!(f, ".into_box()")?;
40 }
41 Ok(())
42 }
43 SpecializationArg::Snapshot(inner) => write!(f, "@{:?}", inner.debug(db)),
44 SpecializationArg::Struct(inner) => {
45 write!(f, "{{")?;
46 let mut inner = inner.iter().peekable();
47 while let Some(value) = inner.next() {
48 write!(f, " ")?;
49 value.fmt(f, db)?;
50
51 if inner.peek().is_some() {
52 write!(f, ",")?;
53 } else {
54 write!(f, " ")?;
55 }
56 }
57 write!(f, "}}")
58 }
59 SpecializationArg::Array(_ty, values) => {
60 write!(f, "array![")?;
61 let mut first = true;
62 for value in values {
63 if !first {
64 write!(f, ", ")?;
65 } else {
66 first = false;
67 }
68 write!(f, "{:?}", value.debug(db))?;
69 }
70 write!(f, "]")
71 }
72 }
73 }
74}
75
/// A pending work item for the explicit stack used while materializing specialization
/// arguments into variables.
///
/// On the stack, each item is paired with the variable it is responsible for defining.
enum SpecializationArgBuildingState<'db, 'a> {
    /// The argument has not been processed yet. Processing it either emits a statement
    /// directly or pushes further work items for its sub-values.
    Initial(&'a SpecializationArg<'db>),
    /// Snapshot the given (already built) variable into the paired variable.
    TakeSnapshot(VariableId),
    /// Construct the paired struct variable from the given (already built) member variables.
    BuildStruct(Vec<VariableId>),
    /// Append `value` to the array variable `in_array`, producing the paired variable.
    PushBackArray { in_array: VariableId, value: VariableId },
}
84
/// Builds the lowered body of a specialized function.
///
/// The generated body is a single block that:
/// 1. materializes every specialized argument (each `Some` entry of `specialized.args`) into a
///    variable, using const / snapshot / struct-construct / `array_new` + `array_append`
///    statements;
/// 2. calls the base function (`specialized.base`) with one input per base parameter, in
///    order; and
/// 3. returns the call's outputs (the extra returns, followed by the return value).
///
/// Only parameters whose argument is `None` remain parameters of the resulting function.
///
/// # Errors
/// Propagates failures from the queries used here, e.g. the base function's `lowered_body`,
/// `concrete_struct_members`, `function_id` and `signature`.
///
/// # Panics
/// Panics if the number of base parameters differs from the number of specialization args, or
/// if a struct argument's member count differs from its struct's member count (`zip_eq`).
pub fn specialized_function_lowered<'db>(
    db: &'db dyn Database,
    specialized: SpecializedFunction<'db>,
) -> Maybe<Lowered<'db>> {
    let base = db.lowered_body(specialized.base, LoweringStage::Monomorphized)?;
    let base_semantic = specialized.base.base_semantic_function(db);

    // `core::array` externs used to build array arguments element by element.
    let array_module = ModuleHelper::core(db).submodule("array");
    let array_new_fn = GenericFunctionId::Extern(array_module.extern_function_id("array_new"));
    let array_append = GenericFunctionId::Extern(array_module.extern_function_id("array_append"));

    let mut variables =
        VariableAllocator::new(db, base_semantic.function_with_body_id(db), Default::default())?;
    let mut statements = vec![];
    // Variables that remain parameters of the specialized function.
    let mut parameters = vec![];
    // Inputs of the call to the base function, one per base parameter.
    let mut inputs = vec![];
    // Explicit work stack of `(variable to define, pending work)`. Items are popped LIFO, so
    // work that depends on other items is pushed *before* the items it depends on.
    let mut stack = vec![];

    // Default location for the generated code: the base function's stable location.
    let location = LocationId::from_stable_location(
        db,
        specialized.base.base_semantic_function(db).stable_location(db),
    );

    // Mirror each base parameter with a fresh variable (a clone of the base variable).
    // Specialized ones are computed in the body; the others stay parameters.
    for (param, arg) in zip_eq(&base.parameters, specialized.args.iter()) {
        let var_id = variables.variables.alloc(base.variables[*param].clone());
        inputs.push(VarUsage { var_id, location });
        if let Some(c) = arg {
            stack.push((var_id, SpecializationArgBuildingState::Initial(c)));
            continue;
        }
        parameters.push(var_id);
    }

    // NOTE: allocation order here determines the resulting `VariableId`s and statement order.
    while let Some((var_id, state)) = stack.pop() {
        match state {
            SpecializationArgBuildingState::Initial(c) => match c {
                SpecializationArg::Const { value, boxed } => {
                    statements.push(Statement::Const(StatementConst::new(*value, var_id, *boxed)));
                }
                SpecializationArg::Snapshot(inner) => {
                    // `var_id` is expected to have a snapshot type; build the inner value into
                    // a variable of the desnapped type, then snapshot it into `var_id`.
                    let snap_ty = variables.variables[var_id].ty;
                    let denapped_ty = *extract_matches!(snap_ty.long(db), TypeLongId::Snapshot);
                    let desnapped_var = variables.new_var(VarRequest { ty: denapped_ty, location });
                    // Pushed first => processed after the inner value is built.
                    stack.push((
                        var_id,
                        SpecializationArgBuildingState::TakeSnapshot(desnapped_var),
                    ));
                    stack.push((
                        desnapped_var,
                        SpecializationArgBuildingState::Initial(inner.as_ref()),
                    ));
                }
                SpecializationArg::Array(ty, values) => {
                    // Build a chain `array_new -> append(v0) -> ... -> append(vn) == var_id`.
                    // Iterating in reverse links each append's output to the next-outer array
                    // variable, starting from `var_id`; after the loop, `arr_var` is the
                    // innermost (empty) array that `array_new` must produce.
                    let mut arr_var = var_id;
                    for value in values.iter().rev() {
                        let in_arr_var =
                            variables.variables.alloc(variables.variables[var_id].clone());
                        let value_var = variables.new_var(VarRequest { ty: *ty, location });
                        stack.push((
                            arr_var,
                            SpecializationArgBuildingState::PushBackArray {
                                in_array: in_arr_var,
                                value: value_var,
                            },
                        ));
                        // Pushed last => the element is built right before its append.
                        stack.push((value_var, SpecializationArgBuildingState::Initial(value)));
                        arr_var = in_arr_var;
                    }
                    // Emitted immediately, i.e. before any of the appends pushed above.
                    statements.push(Statement::Call(StatementCall {
                        function: array_new_fn
                            .concretize(db, vec![GenericArgumentId::Type(*ty)])
                            .lowered(db),
                        inputs: vec![],
                        with_coupon: false,
                        outputs: vec![arr_var],
                        location: variables[var_id].location,
                    }));
                }
                SpecializationArg::Struct(consts) => {
                    let var = &variables[var_id];
                    let TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct)) =
                        var.ty.long(db)
                    else {
                        unreachable!("Expected a concrete struct type");
                    };

                    let members = db.concrete_struct_members(*concrete_struct)?;

                    // One fresh variable per struct member, typed by the member's type.
                    let location = var.location;
                    let var_ids = members
                        .values()
                        .map(|member| variables.new_var(VarRequest { ty: member.ty, location }))
                        .collect_vec();

                    // Pushed first => the struct is constructed after all members are built.
                    stack.push((
                        var_id,
                        SpecializationArgBuildingState::BuildStruct(var_ids.clone()),
                    ));

                    for (var_id, c) in zip_eq(var_ids, consts) {
                        stack.push((var_id, SpecializationArgBuildingState::Initial(c)));
                    }
                }
            },
            SpecializationArgBuildingState::TakeSnapshot(desnapped_var) => {
                // `StatementSnapshot::new` gets a fresh variable (cloned from the desnapped
                // variable) as its first output, which is not used further here, and `var_id`
                // as its second output (presumably the original value and the snapshot,
                // respectively — confirm against `StatementSnapshot`).
                let ignored = variables.variables.alloc(variables[desnapped_var].clone());
                statements.push(Statement::Snapshot(StatementSnapshot::new(
                    VarUsage { var_id: desnapped_var, location },
                    ignored,
                    var_id,
                )));
            }
            SpecializationArgBuildingState::PushBackArray { in_array, value } => {
                // `array_append` is concretized with the element variable's type.
                statements.push(Statement::Call(StatementCall {
                    function: array_append
                        .concretize(
                            db,
                            vec![GenericArgumentId::Type(variables.variables[value].ty)],
                        )
                        .lowered(db),
                    inputs: vec![
                        VarUsage { var_id: in_array, location },
                        VarUsage { var_id: value, location },
                    ],
                    with_coupon: false,
                    outputs: vec![var_id],
                    location,
                }));
            }
            SpecializationArgBuildingState::BuildStruct(ids) => {
                statements.push(Statement::StructConstruct(StatementStructConstruct {
                    inputs: ids
                        .iter()
                        .map(|id| VarUsage { var_id: *id, location: variables[*id].location })
                        .collect(),
                    output: var_id,
                }));
            }
        }
    }

    // Outputs of the base call: the extra returns (in signature order), then the return value.
    let outputs: Vec<VariableId> =
        chain!(base.signature.extra_rets.iter().map(|ret| ret.ty()), [base.signature.return_type])
            .map(|ty| variables.new_var(VarRequest { ty, location }))
            .collect_vec();
    let mut block_builder = BlocksBuilder::new();
    // The base call's outputs are returned unchanged.
    let ret_usage =
        outputs.iter().map(|var_id| VarUsage { var_id: *var_id, location }).collect_vec();
    statements.push(Statement::Call(StatementCall {
        function: specialized.base.function_id(db)?,
        with_coupon: false,
        inputs,
        outputs,
        location,
    }));
    block_builder.alloc(Block { statements, end: BlockEnd::Return(ret_usage, location) });
    Ok(Lowered {
        signature: specialized.signature(db)?,
        variables: variables.variables,
        // NOTE(review): assumes `build` succeeds once at least one block has been allocated.
        blocks: block_builder.build().unwrap(),
        parameters,
        diagnostics: Default::default(),
    })
}
250
251#[salsa::tracked]
253pub fn priv_should_specialize<'db>(
254 db: &'db dyn Database,
255 function_id: ids::ConcreteFunctionWithBodyId<'db>,
256) -> Maybe<bool> {
257 let ids::ConcreteFunctionWithBodyLongId::Specialized(SpecializedFunction { base, .. }) =
258 function_id.long(db)
259 else {
260 panic!("Expected a specialized function");
261 };
262
263 if db.concrete_in_cycle(*base, DependencyType::Call, LoweringStage::Monomorphized)? {
268 return Ok(false);
269 }
270
271 Ok(db.estimate_size(*base)?.saturating_mul(8)
273 > db.estimate_size(function_id)?.saturating_mul(10))
274}