1use std::vec;
2
3use cairo_lang_debug::DebugWithDb;
4use cairo_lang_diagnostics::Maybe;
5use cairo_lang_semantic::helper::ModuleHelper;
6use cairo_lang_semantic::items::constant::ConstValue;
7use cairo_lang_semantic::items::functions::GenericFunctionId;
8use cairo_lang_semantic::{ConcreteTypeId, GenericArgumentId, TypeId, TypeLongId};
9use cairo_lang_utils::LookupIntern;
10use itertools::{Itertools, chain, zip_eq};
11
12use crate::blocks::BlocksBuilder;
13use crate::db::LoweringGroup;
14use crate::ids::{self, LocationId, SemanticFunctionIdEx, SpecializedFunction};
15use crate::lower::context::{VarRequest, VariableAllocator};
16use crate::{
17 Block, BlockEnd, DependencyType, Lowered, LoweringStage, Statement, StatementCall,
18 StatementConst, StatementStructConstruct, VarUsage, VariableId,
19};
20
/// A value that a parameter of a function is specialized with.
///
/// The specialized function does not receive such a parameter as an argument; instead its body
/// materializes the value directly (see `specialized_function_lowered`).
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub enum SpecializationArg {
    /// A constant value, materialized with a `Statement::Const`.
    Const(ConstValue),
    /// An empty array of the given element type, materialized with a call to the core
    /// `array_new` extern function.
    EmptyArray(TypeId),
    /// A struct whose members are themselves specialization args, given in the same order as
    /// the struct's members.
    Struct(Vec<SpecializationArg>),
}
28
29impl<'a> DebugWithDb<dyn LoweringGroup + 'a> for SpecializationArg {
30 fn fmt(
31 &self,
32 f: &mut std::fmt::Formatter<'_>,
33 db: &(dyn LoweringGroup + 'a),
34 ) -> std::fmt::Result {
35 match self {
36 SpecializationArg::Const(value) => write!(f, "{:?}", value.debug(db)),
37 SpecializationArg::Struct(inner) => {
38 write!(f, "{{")?;
39 let mut inner = inner.iter().peekable();
40 while let Some(value) = inner.next() {
41 write!(f, " ")?;
42 value.fmt(f, db)?;
43
44 if inner.peek().is_some() {
45 write!(f, ",")?;
46 } else {
47 write!(f, " ")?;
48 }
49 }
50 write!(f, "}}")
51 }
52 SpecializationArg::EmptyArray(_) => write!(f, "array![]"),
53 }
54 }
55}
56
/// A work item used by `specialized_function_lowered` while materializing a
/// [`SpecializationArg`] into statements.
///
/// Items are processed from an explicit stack (`Vec::pop`) instead of by recursion, so nested
/// structs are handled iteratively.
enum SpecializationArgBuildingState<'a> {
    /// The associated variable still has to be assigned this argument's value.
    Initial(&'a SpecializationArg),
    /// Construct the associated struct variable from these member variables.
    ///
    /// This item is pushed before the members' own `Initial` items, so it is popped (and its
    /// `StructConstruct` emitted) only after all members have been assigned.
    BuildStruct(Vec<VariableId>),
}
63
64pub fn specialized_function_lowered(
66 db: &dyn LoweringGroup,
67 specialized: SpecializedFunction,
68) -> Maybe<Lowered> {
69 let base = db.lowered_body(specialized.base, LoweringStage::Monomorphized)?;
70 let base_semantic = specialized.base.base_semantic_function(db);
71
72 let array_new_fn = GenericFunctionId::Extern(
73 ModuleHelper::core(db).submodule("array").extern_function_id("array_new"),
74 );
75
76 let mut variables =
77 VariableAllocator::new(db, base_semantic.function_with_body_id(db), Default::default())?;
78 let mut statements = vec![];
79 let mut parameters = vec![];
80 let mut inputs = vec![];
81 let mut stack = vec![];
82
83 let location = LocationId::from_stable_location(
84 db,
85 specialized.base.base_semantic_function(db).stable_location(db),
86 );
87
88 for (param, arg) in zip_eq(&base.parameters, specialized.args.iter()) {
89 let var_id = variables.variables.alloc(base.variables[*param].clone());
90 inputs.push(VarUsage { var_id, location });
91 if let Some(c) = arg {
92 stack.push((var_id, SpecializationArgBuildingState::Initial(c)));
93 continue;
94 }
95 parameters.push(var_id);
96 }
97
98 while let Some((var_id, state)) = stack.pop() {
99 match state {
100 SpecializationArgBuildingState::Initial(c) => match c {
101 SpecializationArg::Const(value) => {
102 statements.push(Statement::Const(StatementConst {
103 value: value.clone(),
104 output: var_id,
105 }));
106 }
107 SpecializationArg::EmptyArray(ty) => {
108 statements.push(Statement::Call(StatementCall {
109 function: array_new_fn
110 .concretize(db, vec![GenericArgumentId::Type(*ty)])
111 .lowered(db),
112 inputs: vec![],
113 with_coupon: false,
114 outputs: vec![var_id],
115 location: variables[var_id].location,
116 }));
117 }
118 SpecializationArg::Struct(consts) => {
119 let var = &variables[var_id];
120 let TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct)) =
121 var.ty.lookup_intern(db)
122 else {
123 unreachable!("Expected a concrete struct type");
124 };
125
126 let members = db.concrete_struct_members(concrete_struct)?;
127
128 let location = var.location;
129 let var_ids = members
130 .values()
131 .map(|member| variables.new_var(VarRequest { ty: member.ty, location }))
132 .collect_vec();
133
134 stack.push((
135 var_id,
136 SpecializationArgBuildingState::BuildStruct(var_ids.clone()),
137 ));
138
139 for (var_id, c) in zip_eq(var_ids, consts) {
140 stack.push((var_id, SpecializationArgBuildingState::Initial(c)));
141 }
142 }
143 },
144 SpecializationArgBuildingState::BuildStruct(ids) => {
145 statements.push(Statement::StructConstruct(StatementStructConstruct {
146 inputs: ids
147 .iter()
148 .map(|id| VarUsage { var_id: *id, location: variables[*id].location })
149 .collect(),
150 output: var_id,
151 }));
152 }
153 }
154 }
155
156 let outputs: Vec<VariableId> =
157 chain!(base.signature.extra_rets.iter().map(|ret| ret.ty()), [base.signature.return_type])
158 .map(|ty| variables.new_var(VarRequest { ty, location }))
159 .collect_vec();
160 let mut block_builder = BlocksBuilder::new();
161 let ret_usage =
162 outputs.iter().map(|var_id| VarUsage { var_id: *var_id, location }).collect_vec();
163 statements.push(Statement::Call(StatementCall {
164 function: specialized.base.function_id(db)?,
165 with_coupon: false,
166 inputs,
167 outputs,
168 location,
169 }));
170 block_builder.alloc(Block { statements, end: BlockEnd::Return(ret_usage, location) });
171 Ok(Lowered {
172 signature: specialized.signature(db)?,
173 variables: variables.variables,
174 blocks: block_builder.build().unwrap(),
175 parameters,
176 diagnostics: Default::default(),
177 })
178}
179
180pub fn priv_should_specialize(
182 db: &dyn LoweringGroup,
183 function_id: ids::ConcreteFunctionWithBodyId,
184) -> Maybe<bool> {
185 let ids::ConcreteFunctionWithBodyLongId::Specialized(SpecializedFunction { base, .. }) =
186 function_id.lookup_intern(db)
187 else {
188 panic!("Expected a specialized function");
189 };
190
191 if db.concrete_in_cycle(base, DependencyType::Call, LoweringStage::Monomorphized)? {
196 return Ok(false);
197 }
198
199 Ok(db.estimate_size(base)?.saturating_mul(8)
201 > db.estimate_size(function_id)?.saturating_mul(10))
202}