1use std::default::Default;
2use std::sync::Arc;
3
4use anyhow::{Result, ensure};
5use cairo_lang_compiler::db::RootDatabase;
6use cairo_lang_compiler::diagnostics::DiagnosticsReporter;
7use cairo_lang_compiler::get_sierra_program_for_functions;
8use cairo_lang_debug::DebugWithDb;
9use cairo_lang_defs::ids::{FreeFunctionId, FunctionWithBodyId, ModuleItemId};
10use cairo_lang_filesystem::db::FilesGroup;
11use cairo_lang_filesystem::ids::CrateId;
12use cairo_lang_lowering::ids::ConcreteFunctionWithBodyId;
13use cairo_lang_semantic::db::SemanticGroup;
14use cairo_lang_semantic::items::functions::GenericFunctionId;
15use cairo_lang_semantic::plugin::PluginSuite;
16use cairo_lang_semantic::{ConcreteFunction, FunctionLongId};
17use cairo_lang_sierra::debug_info::{Annotations, DebugInfo};
18use cairo_lang_sierra::extensions::gas::CostTokenType;
19use cairo_lang_sierra::ids::FunctionId;
20use cairo_lang_sierra::program::ProgramArtifact;
21use cairo_lang_sierra_generator::db::SierraGenGroup;
22use cairo_lang_sierra_generator::executables::{collect_executables, find_executable_function_ids};
23use cairo_lang_sierra_generator::program_generator::SierraProgramWithDebug;
24use cairo_lang_sierra_generator::replace_ids::DebugReplacer;
25use cairo_lang_sierra_generator::statements_locations::StatementsLocations;
26use cairo_lang_starknet::contract::{
27 ContractDeclaration, ContractInfo, find_contracts, get_contract_abi_functions,
28 get_contracts_info,
29};
30use cairo_lang_starknet::plugin::consts::{CONSTRUCTOR_MODULE, EXTERNAL_MODULE, L1_HANDLER_MODULE};
31use cairo_lang_starknet_classes::casm_contract_class::ENTRY_POINT_COST;
32use cairo_lang_utils::ordered_hash_map::{
33 OrderedHashMap, deserialize_ordered_hashmap_vec, serialize_ordered_hashmap_vec,
34};
35use itertools::{Itertools, chain};
36pub use plugin::TestPlugin;
37use serde::{Deserialize, Serialize};
38use starknet_types_core::felt::Felt as Felt252;
39pub use test_config::{TestConfig, try_extract_test_config};
40
41mod inline_macros;
42pub mod plugin;
43pub mod test_config;
44
// Attribute and attribute-argument names used by the test machinery.
// NOTE(review): none of these are referenced in the code visible in this part of the file;
// presumably the `plugin` and `test_config` submodules consume them — confirm.
const TEST_ATTR: &str = "test";
const SHOULD_PANIC_ATTR: &str = "should_panic";
const IGNORE_ATTR: &str = "ignore";
const AVAILABLE_GAS_ATTR: &str = "available_gas";
const STATIC_GAS_ARG: &str = "static";
50
/// Configuration consumed by [`compile_test_prepared_db`].
#[derive(Clone)]
pub struct TestsCompilationConfig {
    /// Whether Starknet support is enabled. When set, the external, constructor and L1-handler
    /// entry points of the contracts are compiled together with the tests.
    pub starknet: bool,

    /// Contracts to use. When `None`, the contracts are discovered with `find_contracts` in
    /// `contract_crate_ids`. Must be `None` unless `starknet` is set.
    pub contract_declarations: Option<Vec<ContractDeclaration>>,

    /// Crates searched for contracts when `contract_declarations` is `None`. When this is also
    /// `None`, all the crates of the database are searched. Must be `None` unless `starknet` is
    /// set.
    pub contract_crate_ids: Option<Vec<CrateId>>,

    /// Crates searched for executable functions. When `None`, the test crates are used.
    pub executable_crate_ids: Option<Vec<CrateId>>,

    /// Whether to add to the debug info the annotations built from the result of
    /// `extract_statements_functions`.
    pub add_statements_functions: bool,

    /// Whether to add to the debug info the annotations built from the result of
    /// `extract_statements_source_code_locations`.
    pub add_statements_code_locations: bool,
}
78
79pub fn compile_test_prepared_db(
91 db: &RootDatabase,
92 tests_compilation_config: TestsCompilationConfig,
93 test_crate_ids: Vec<CrateId>,
94 diagnostics_reporter: DiagnosticsReporter<'_>,
95) -> Result<TestCompilation> {
96 ensure!(
97 tests_compilation_config.starknet
98 || tests_compilation_config.contract_declarations.is_none(),
99 "Contract declarations can be provided only when starknet is enabled."
100 );
101 ensure!(
102 tests_compilation_config.starknet || tests_compilation_config.contract_crate_ids.is_none(),
103 "Contract crate ids can be provided only when starknet is enabled."
104 );
105
106 let contracts = tests_compilation_config.contract_declarations.unwrap_or_else(|| {
107 find_contracts(
108 db,
109 &tests_compilation_config.contract_crate_ids.unwrap_or_else(|| db.crates()),
110 )
111 });
112 let all_entry_points = if tests_compilation_config.starknet {
113 contracts
114 .iter()
115 .flat_map(|contract| {
116 chain!(
117 get_contract_abi_functions(db, contract, EXTERNAL_MODULE).unwrap_or_default(),
118 get_contract_abi_functions(db, contract, CONSTRUCTOR_MODULE)
119 .unwrap_or_default(),
120 get_contract_abi_functions(db, contract, L1_HANDLER_MODULE).unwrap_or_default(),
121 )
122 })
123 .map(|func| ConcreteFunctionWithBodyId::from_semantic(db, func.value))
124 .collect()
125 } else {
126 vec![]
127 };
128
129 let executable_functions = find_executable_function_ids(
130 db,
131 tests_compilation_config.executable_crate_ids.unwrap_or_else(|| test_crate_ids.clone()),
132 );
133 let all_tests = find_all_tests(db, test_crate_ids.clone());
134
135 let func_ids = chain!(
136 executable_functions.clone().into_keys(),
137 all_entry_points.iter().cloned(),
138 all_tests.iter().flat_map(|(func_id, _cfg)| {
140 ConcreteFunctionWithBodyId::from_no_generics_free(db, *func_id)
141 })
142 )
143 .collect();
144
145 let SierraProgramWithDebug { program: mut sierra_program, debug_info } =
146 Arc::unwrap_or_clone(get_sierra_program_for_functions(db, func_ids, diagnostics_reporter)?);
147
148 let function_set_costs: OrderedHashMap<FunctionId, OrderedHashMap<CostTokenType, i32>> =
149 all_entry_points
150 .iter()
151 .map(|func_id| {
152 (
153 db.function_with_body_sierra(*func_id).unwrap().id.clone(),
154 [(CostTokenType::Const, ENTRY_POINT_COST)].into(),
155 )
156 })
157 .collect();
158
159 let replacer = DebugReplacer { db };
160 replacer.enrich_function_names(&mut sierra_program);
161
162 let mut annotations = Annotations::default();
163 if tests_compilation_config.add_statements_functions {
164 annotations.extend(Annotations::from(
165 debug_info.statements_locations.extract_statements_functions(db),
166 ))
167 }
168 if tests_compilation_config.add_statements_code_locations {
169 annotations.extend(Annotations::from(
170 debug_info.statements_locations.extract_statements_source_code_locations(db),
171 ))
172 }
173
174 let executables = collect_executables(db, executable_functions, &sierra_program);
175 let named_tests = all_tests
176 .into_iter()
177 .map(|(func_id, test)| {
178 (
179 format!(
180 "{:?}",
181 FunctionLongId {
182 function: ConcreteFunction {
183 generic_function: GenericFunctionId::Free(func_id),
184 generic_args: vec![]
185 }
186 }
187 .debug(db)
188 ),
189 test,
190 )
191 })
192 .collect_vec();
193 let contracts_info = get_contracts_info(db, contracts, &replacer)?;
194 let sierra_program = ProgramArtifact::stripped(sierra_program).with_debug_info(DebugInfo {
195 executables,
196 annotations,
197 ..DebugInfo::default()
198 });
199
200 Ok(TestCompilation {
201 sierra_program,
202 metadata: TestCompilationMetadata {
203 named_tests,
204 function_set_costs,
205 contracts_info,
206 statements_locations: Some(debug_info.statements_locations),
207 },
208 })
209}
210
/// The result of compiling tests: the Sierra program artifact and the metadata describing it.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
pub struct TestCompilation {
    /// The compiled Sierra program, including its debug info.
    pub sierra_program: ProgramArtifact,
    /// The metadata of the compilation. Flattened by serde, so its fields are (de)serialized
    /// at the same level as `sierra_program`.
    #[serde(flatten)]
    pub metadata: TestCompilationMetadata,
}
222
/// Metadata accompanying the Sierra program of a [`TestCompilation`].
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
pub struct TestCompilationMetadata {
    /// Information about the contracts, as returned by `get_contracts_info`.
    /// NOTE(review): the `Felt252` key is presumably the contract class hash — confirm.
    #[serde(
        serialize_with = "serialize_ordered_hashmap_vec",
        deserialize_with = "deserialize_ordered_hashmap_vec"
    )]
    pub contracts_info: OrderedHashMap<Felt252, ContractInfo>,
    /// Costs per token type, keyed by Sierra function id. `compile_test_prepared_db` fills it
    /// with `ENTRY_POINT_COST` of the `Const` token for every contract entry point.
    #[serde(
        serialize_with = "serialize_ordered_hashmap_vec",
        deserialize_with = "deserialize_ordered_hashmap_vec"
    )]
    pub function_set_costs: OrderedHashMap<FunctionId, OrderedHashMap<CostTokenType, i32>>,
    /// The tests, as pairs of the test function name and its configuration.
    pub named_tests: Vec<(String, TestConfig)>,
    /// The locations of the Sierra statements. Skipped by serde, so it is neither serialized
    /// nor read back on deserialization.
    #[serde(skip)]
    pub statements_locations: Option<StatementsLocations>,
}
246
247fn find_all_tests(
249 db: &dyn SemanticGroup,
250 main_crates: Vec<CrateId>,
251) -> Vec<(FreeFunctionId, TestConfig)> {
252 let mut tests = vec![];
253 for crate_id in main_crates {
254 let modules = db.crate_modules(crate_id);
255 for module_id in modules.iter() {
256 let Ok(module_items) = db.module_items(*module_id) else {
257 continue;
258 };
259 tests.extend(module_items.iter().filter_map(|item| {
260 let ModuleItemId::FreeFunction(func_id) = item else { return None };
261 let Ok(attrs) =
262 db.function_with_body_attributes(FunctionWithBodyId::Free(*func_id))
263 else {
264 return None;
265 };
266 Some((*func_id, try_extract_test_config(db.upcast(), attrs).ok()??))
267 }));
268 }
269 }
270 tests
271}
272
273pub fn test_assert_suite() -> PluginSuite {
275 let mut suite = PluginSuite::default();
276 suite
277 .add_inline_macro_plugin::<inline_macros::assert::AssertEqMacro>()
278 .add_inline_macro_plugin::<inline_macros::assert::AssertNeMacro>()
279 .add_inline_macro_plugin::<inline_macros::assert::AssertLtMacro>()
280 .add_inline_macro_plugin::<inline_macros::assert::AssertLeMacro>()
281 .add_inline_macro_plugin::<inline_macros::assert::AssertGtMacro>()
282 .add_inline_macro_plugin::<inline_macros::assert::AssertGeMacro>();
283 suite
284}
285
286pub fn test_plugin_suite() -> PluginSuite {
288 let mut suite = PluginSuite::default();
289 suite.add_plugin::<TestPlugin>().add(test_assert_suite());
290 suite
291}