1use anyhow::{Context, bail};
2use cairo_lang_defs::db::DefsGroup;
3use cairo_lang_defs::ids::{
4 FreeFunctionId, LanguageElementId, LookupItemId, ModuleId, ModuleItemId,
5 NamedLanguageElementId, SubmoduleId,
6};
7use cairo_lang_diagnostics::ToOption;
8use cairo_lang_filesystem::ids::{CrateId, SmolStrId};
9use cairo_lang_semantic::diagnostic::SemanticDiagnostics;
10use cairo_lang_semantic::expr::inference::InferenceId;
11use cairo_lang_semantic::expr::inference::canonic::ResultNoErrEx;
12use cairo_lang_semantic::items::constant::ConstantSemantic;
13use cairo_lang_semantic::items::functions::{
14 ConcreteFunctionWithBodyId as SemanticConcreteFunctionWithBodyId, GenericFunctionId,
15};
16use cairo_lang_semantic::items::imp::ImplLongId;
17use cairo_lang_semantic::items::impl_alias::ImplAliasSemantic;
18use cairo_lang_semantic::items::module::ModuleSemantic;
19use cairo_lang_semantic::items::us::SemanticUseEx;
20use cairo_lang_semantic::resolve::{ResolvedGenericItem, Resolver};
21use cairo_lang_semantic::substitution::SemanticRewriter;
22use cairo_lang_sierra::ids::FunctionId;
23use cairo_lang_sierra_generator::db::SierraGenGroup;
24use cairo_lang_sierra_generator::replace_ids::SierraIdReplacer;
25use cairo_lang_starknet_classes::keccak::starknet_keccak;
26use cairo_lang_syntax::node::helpers::{GetIdentifier, PathSegmentEx, QueryAttrs};
27use cairo_lang_syntax::node::{Terminal, TypedStablePtr, TypedSyntaxNode};
28use cairo_lang_utils::extract_matches;
29use cairo_lang_utils::ordered_hash_map::{
30 OrderedHashMap, deserialize_ordered_hashmap_vec, serialize_ordered_hashmap_vec,
31};
32use itertools::chain;
33use salsa::Database;
34use serde::{Deserialize, Serialize};
35use starknet_types_core::felt::Felt as Felt252;
36use {cairo_lang_lowering as lowering, cairo_lang_semantic as semantic};
37
38use crate::aliased::Aliased;
39use crate::compile::{SemanticEntryPoints, extract_semantic_entrypoints};
40use crate::plugin::aux_data::StarknetContractAuxData;
41use crate::plugin::consts::{ABI_ATTR, ABI_ATTR_EMBED_V0_ARG};
42
43#[cfg(test)]
44#[path = "contract_test.rs"]
45mod test;
46
/// A Starknet contract declaration, as discovered by [`module_contract`].
#[derive(Clone)]
pub struct ContractDeclaration<'db> {
    /// The submodule in which the contract was found.
    pub submodule_id: SubmoduleId<'db>,
}
53
54impl<'db> ContractDeclaration<'db> {
55 pub fn module_id(&self) -> ModuleId<'db> {
56 ModuleId::Submodule(self.submodule_id)
57 }
58}
59
60pub fn module_contract<'db>(
62 db: &'db dyn Database,
63 module_id: ModuleId<'db>,
64) -> Option<ContractDeclaration<'db>> {
65 let all_aux_data = module_id.module_data(db).ok()?.generated_file_aux_data(db);
66
67 all_aux_data.values().skip(1).find_map(|aux_data| {
81 let StarknetContractAuxData { contract_name } =
82 aux_data.as_ref()?.as_any().downcast_ref()?;
83 if let ModuleId::Submodule(submodule_id) = module_id {
84 Some(ContractDeclaration { submodule_id })
85 } else {
86 unreachable!("Contract `{contract_name}` was not found.");
87 }
88 })
89}
90
91pub fn find_contracts<'db>(
94 db: &'db dyn Database,
95 crate_ids: &[CrateId<'db>],
96) -> Vec<ContractDeclaration<'db>> {
97 let mut contract_declarations = vec![];
98 for crate_id in crate_ids {
99 let modules = db.crate_modules(*crate_id);
100 for module_id in modules.iter() {
101 contract_declarations.extend(module_contract(db, *module_id));
102 }
103 }
104 contract_declarations
105}
106
107pub fn get_contract_abi_functions<'db>(
110 db: &'db dyn Database,
111 contract: &ContractDeclaration<'db>,
112 module_name: &'db str,
113) -> anyhow::Result<Vec<Aliased<semantic::ConcreteFunctionWithBodyId<'db>>>> {
114 let module_name = SmolStrId::from(db, module_name);
115 Ok(chain!(
116 get_contract_internal_module_abi_functions(db, contract, module_name)?,
117 get_impl_aliases_abi_functions(db, contract, module_name)?
118 )
119 .collect())
120}
121
122fn get_contract_internal_module_abi_functions<'db>(
124 db: &'db dyn Database,
125 contract: &ContractDeclaration<'db>,
126 module_name: SmolStrId<'db>,
127) -> anyhow::Result<Vec<Aliased<SemanticConcreteFunctionWithBodyId<'db>>>> {
128 let generated_module_id = get_generated_contract_module(db, contract)?;
129 let module_id = get_submodule_id(db, generated_module_id, module_name)?;
130 get_module_aliased_functions(db, module_id)?
131 .into_iter()
132 .map(|f| f.try_map(|f| semantic::ConcreteFunctionWithBodyId::from_no_generics_free(db, f)))
133 .collect::<Option<Vec<_>>>()
134 .with_context(|| "Generics are not allowed in wrapper functions")
135}
136
137fn get_module_aliased_functions<'db>(
141 db: &'db dyn Database,
142 module_id: ModuleId<'db>,
143) -> anyhow::Result<Vec<Aliased<FreeFunctionId<'db>>>> {
144 module_id
145 .module_data(db)
146 .map(|data| data.uses(db))
147 .to_option()
148 .with_context(|| "Failed to get external module uses.")?
149 .iter()
150 .map(|(use_id, leaf)| {
151 if let ResolvedGenericItem::GenericFunction(GenericFunctionId::Free(function_id)) = db
152 .use_resolved_item(*use_id)
153 .to_option()
154 .with_context(|| "Failed to fetch used function.")?
155 {
156 Ok(Aliased {
157 value: function_id,
158 alias: leaf.stable_ptr(db).identifier(db).to_string(db),
159 })
160 } else {
161 bail!("Expected a free function.")
162 }
163 })
164 .collect::<Result<Vec<_>, _>>()
165}
166
/// Collects the ABI wrapper functions contributed by impl aliases of the contract's generated
/// module.
///
/// Only impl aliases carrying the `ABI_ATTR` attribute with the `ABI_ATTR_EMBED_V0_ARG`
/// argument are considered. For each one:
/// 1. The alias is resolved to a concrete impl.
/// 2. The wrapper functions are looked up in the submodule named `{module_prefix}_{impl_name}`
///    of the impl's parent module.
/// 3. Each wrapper is specialized with the generic arguments written on the last segment of the
///    alias's impl path.
///
/// Wrappers whose specialization or body lookup yields no value are skipped silently.
///
/// # Errors
/// Fails if module data cannot be fetched, an alias cannot be resolved, or the resolved impl is
/// not concrete. Also fails if the wrapper submodule is missing or its uses are not free
/// functions.
///
/// # Panics
/// Panics if inference cannot be finalized after a specialization. Presumably also panics, via
/// `expect_with_db`, if any semantic diagnostics were reported; confirm.
fn get_impl_aliases_abi_functions<'db>(
    db: &'db dyn Database,
    contract: &ContractDeclaration<'db>,
    module_prefix: SmolStrId<'db>,
) -> anyhow::Result<Vec<Aliased<SemanticConcreteFunctionWithBodyId<'db>>>> {
    let generated_module_id = get_generated_contract_module(db, contract)?;
    // Shared across all aliases; checked once at the end.
    let mut diagnostics = SemanticDiagnostics::default();
    let mut all_abi_functions = vec![];
    for (impl_alias_id, impl_alias) in generated_module_id
        .module_data(db)
        .to_option()
        .with_context(|| "Failed to get external module impl aliases.")?
        .impl_aliases(db)
        .iter()
    {
        // Only aliases explicitly embedded into the ABI contribute functions.
        if !impl_alias.has_attr_with_arg(db, ABI_ATTR, ABI_ATTR_EMBED_V0_ARG) {
            continue;
        }
        let Ok(resolved_impl) = db.impl_alias_resolved_impl(*impl_alias_id) else {
            bail!("Internal error: Failed to get impl alias solution.");
        };
        let ImplLongId::Concrete(concrete) = resolved_impl.long(db) else {
            bail!("Internal error: Solved impl alias is not an impl.");
        };
        let impl_def_id = concrete.long(db).impl_def_id;
        let impl_module = impl_def_id.parent_module(db);
        let impl_name = impl_def_id.name_identifier(db).text(db).long(db);
        // The wrappers live in a sibling submodule of the impl named `{prefix}_{impl_name}`.
        let module_id = get_submodule_id(
            db,
            impl_module,
            SmolStrId::from(db, format!("{}_{impl_name}", module_prefix.long(db))),
        )?;
        // The resolver is scoped to the impl alias declaration so generic arguments written on the
        // alias can be resolved there.
        let mut resolver = Resolver::new(
            db,
            impl_alias_id.parent_module(db),
            InferenceId::LookupItemDeclaration(LookupItemId::ModuleItem(ModuleItemId::ImplAlias(
                *impl_alias_id,
            ))),
        );
        let Some(last_segment) = impl_alias.impl_path(db).segments(db).elements(db).last() else {
            unreachable!("impl_path should have at least one segment");
        };
        // Generic arguments of the aliased impl (empty if none were written).
        let generic_args = last_segment.generic_args(db).unwrap_or_default();
        for abi_function in get_module_aliased_functions(db, module_id)? {
            // `try_map` yields `None` when specialization or body lookup fails; `extend` then adds
            // nothing for this function.
            all_abi_functions.extend(abi_function.try_map(|f| {
                let concrete_wrapper = resolver
                    .specialize_function(
                        &mut diagnostics,
                        impl_alias.stable_ptr(db).untyped(),
                        GenericFunctionId::Free(f),
                        &generic_args,
                    )
                    .to_option()?
                    .get_concrete(db)
                    .body(db)
                    .to_option()??;
                let inference = &mut resolver.inference();
                assert_eq!(
                    inference.finalize_without_reporting(),
                    Ok(()),
                    "All inferences should be solved at this point."
                );
                // Substitute solved inference variables into the specialized wrapper id.
                Some(inference.rewrite(concrete_wrapper).no_err())
            }));
        }
    }
    diagnostics
        .build()
        .expect_with_db(db, "Internal error: Inference for wrappers generics failed.");
    Ok(all_abi_functions)
}
241
242fn get_generated_contract_module<'db>(
244 db: &'db dyn Database,
245 contract: &ContractDeclaration<'db>,
246) -> anyhow::Result<ModuleId<'db>> {
247 let parent_module_id = contract.submodule_id.parent_module(db);
248 let contract_name = contract.submodule_id.name(db);
249
250 match db
251 .module_item_by_name(parent_module_id, contract_name)
252 .to_option()
253 .with_context(|| "Failed to initiate a lookup in the root module.")?
254 {
255 Some(ModuleItemId::Submodule(generated_module_id)) => {
256 Ok(ModuleId::Submodule(generated_module_id))
257 }
258 _ => anyhow::bail!(format!("Failed to get generated module {}.", contract_name.long(db))),
259 }
260}
261
262fn get_submodule_id<'db>(
264 db: &'db dyn Database,
265 module_id: ModuleId<'db>,
266 submodule_name: SmolStrId<'db>,
267) -> anyhow::Result<ModuleId<'db>> {
268 match db
269 .module_item_by_name(module_id, submodule_name)
270 .to_option()
271 .with_context(|| "Failed to initiate a lookup in the {module_name} module.")?
272 {
273 Some(ModuleItemId::Submodule(submodule_id)) => Ok(ModuleId::Submodule(submodule_id)),
274 _ => anyhow::bail!(
275 "Failed to get the submodule `{}` of `{}`.",
276 submodule_name.long(db),
277 module_id.full_path(db)
278 ),
279 }
280}
281
/// Entry-point information of a single contract, as produced by `analyze_contract`.
///
/// Selectors are the Starknet keccak of each entry point's alias. Function ids are the Sierra
/// ids after the `SierraIdReplacer` has been applied.
#[derive(Clone, Serialize, Deserialize, PartialEq, Debug, Eq)]
pub struct ContractInfo {
    /// Sierra id of the contract's constructor, or `None` if it has none.
    pub constructor: Option<FunctionId>,
    /// Sierra ids of the external functions, keyed by selector.
    // NOTE(review): presumably (de)serialized as a sequence of key/value pairs by the
    // `*_ordered_hashmap_vec` helpers; confirm.
    #[serde(
        serialize_with = "serialize_ordered_hashmap_vec",
        deserialize_with = "deserialize_ordered_hashmap_vec"
    )]
    pub externals: OrderedHashMap<Felt252, FunctionId>,
    /// Sierra ids of the L1 handlers, keyed by selector.
    #[serde(
        serialize_with = "serialize_ordered_hashmap_vec",
        deserialize_with = "deserialize_ordered_hashmap_vec"
    )]
    pub l1_handlers: OrderedHashMap<Felt252, FunctionId>,
}
300
301pub fn get_contracts_info<T: SierraIdReplacer>(
303 db: &dyn Database,
304 contracts: Vec<ContractDeclaration<'_>>,
305 replacer: &T,
306) -> Result<OrderedHashMap<Felt252, ContractInfo>, anyhow::Error> {
307 let mut contracts_info = OrderedHashMap::default();
308 for contract in contracts {
309 let (class_hash, contract_info) = analyze_contract(db, &contract, replacer)?;
310 contracts_info.insert(class_hash, contract_info);
311 }
312 Ok(contracts_info)
313}
314
315fn analyze_contract<'db, T: SierraIdReplacer>(
317 db: &dyn Database,
318 contract: &ContractDeclaration<'db>,
319 replacer: &T,
320) -> anyhow::Result<(Felt252, ContractInfo)> {
321 let item = db
323 .module_item_by_name(contract.module_id(), SmolStrId::from(db, "TEST_CLASS_HASH"))
324 .unwrap()
325 .unwrap();
326 let constant_id = extract_matches!(item, ModuleItemId::Constant);
327 let class_hash =
328 Felt252::from(db.constant_const_value(constant_id).unwrap().long(db).to_int().unwrap());
329
330 let SemanticEntryPoints { external, l1_handler, constructor } =
332 extract_semantic_entrypoints(db, contract)?;
333 let externals =
334 external.into_iter().map(|f| get_selector_and_sierra_function(db, &f, replacer)).collect();
335 let l1_handlers = l1_handler
336 .into_iter()
337 .map(|f| get_selector_and_sierra_function(db, &f, replacer))
338 .collect();
339 let constructors: Vec<_> = constructor
340 .into_iter()
341 .map(|f| get_selector_and_sierra_function(db, &f, replacer))
342 .collect();
343
344 let contract_info = ContractInfo {
345 externals,
346 l1_handlers,
347 constructor: constructors.into_iter().next().map(|x| x.1),
348 };
349 Ok((class_hash, contract_info))
350}
351
352pub fn get_selector_and_sierra_function<'db, T: SierraIdReplacer>(
355 db: &dyn Database,
356 function_with_body: &Aliased<lowering::ids::ConcreteFunctionWithBodyId<'db>>,
357 replacer: &T,
358) -> (Felt252, FunctionId) {
359 let function_id = function_with_body.value.function_id(db).expect("Function error.");
360 let sierra_id = replacer.replace_function_id(&db.intern_sierra_function(function_id));
361 let selector: Felt252 = starknet_keccak(function_with_body.alias.as_bytes()).into();
362 (selector, sierra_id)
363}