1use anyhow::{Result, ensure};
2use cairo_lang_compiler::diagnostics::DiagnosticsReporter;
3use cairo_lang_compiler::{ensure_diagnostics, get_sierra_program_for_functions};
4use cairo_lang_debug::DebugWithDb;
5use cairo_lang_defs::db::DefsGroup;
6use cairo_lang_defs::ids::{FreeFunctionId, FunctionWithBodyId, ModuleItemId};
7use cairo_lang_filesystem::db::FilesGroup;
8use cairo_lang_filesystem::ids::{CrateId, CrateInput};
9use cairo_lang_lowering::ids::ConcreteFunctionWithBodyId;
10use cairo_lang_semantic::items::function_with_body::FunctionWithBodySemantic;
11use cairo_lang_semantic::items::functions::GenericFunctionId;
12use cairo_lang_semantic::plugin::PluginSuite;
13use cairo_lang_semantic::{ConcreteFunction, FunctionLongId};
14use cairo_lang_sierra::debug_info::{Annotations, DebugInfo};
15use cairo_lang_sierra::extensions::gas::{CostTokenMap, CostTokenType};
16use cairo_lang_sierra::ids::FunctionId;
17use cairo_lang_sierra::program::ProgramArtifact;
18use cairo_lang_sierra_generator::db::SierraGenGroup;
19use cairo_lang_sierra_generator::debug_info::StatementsLocations;
20use cairo_lang_sierra_generator::executables::{collect_executables, find_executable_function_ids};
21use cairo_lang_sierra_generator::program_generator::SierraProgramWithDebug;
22use cairo_lang_sierra_generator::replace_ids::{DebugReplacer, SierraIdReplacer};
23use cairo_lang_starknet::contract::{
24 ContractDeclaration, ContractInfo, find_contracts, get_contract_abi_functions,
25 get_contracts_info,
26};
27use cairo_lang_starknet::plugin::consts::{CONSTRUCTOR_MODULE, EXTERNAL_MODULE, L1_HANDLER_MODULE};
28use cairo_lang_starknet_classes::casm_contract_class::ENTRY_POINT_COST;
29use cairo_lang_utils::CloneableDatabase;
30use cairo_lang_utils::ordered_hash_map::{
31 OrderedHashMap, deserialize_ordered_hashmap_vec, serialize_ordered_hashmap_vec,
32};
33use itertools::{Itertools, chain};
34pub use plugin::TestPlugin;
35use salsa::Database;
36use serde::{Deserialize, Serialize};
37use starknet_types_core::felt::Felt as Felt252;
38pub use test_config::{TestConfig, try_extract_test_config};
39
40mod inline_macros;
41pub mod plugin;
42pub mod test_config;
43
// Attribute and argument names recognized by the test plugin. They are not referenced in this
// file's visible code; presumably they are read by the `plugin` / `test_config` submodules —
// TODO confirm.

/// Name of the `#[test]` attribute.
const TEST_ATTR: &str = "test";
/// Name of the `#[should_panic]` attribute.
const SHOULD_PANIC_ATTR: &str = "should_panic";
/// Name of the `#[ignore]` attribute.
const IGNORE_ATTR: &str = "ignore";
/// Name of the `#[available_gas]` attribute.
const AVAILABLE_GAS_ATTR: &str = "available_gas";
/// The `static` argument; presumably accepted by `#[available_gas(...)]` — TODO confirm.
const STATIC_GAS_ARG: &str = "static";
49
/// Options controlling what `compile_test_prepared_db` compiles and which debug information it
/// attaches to the resulting Sierra program.
#[derive(Clone)]
pub struct TestsCompilationConfig<'db> {
    /// Whether Starknet support is enabled. When `false`, `contract_declarations` and
    /// `contract_crate_ids` must both be `None` (enforced by `compile_test_prepared_db`), and no
    /// contract entry points are compiled.
    pub starknet: bool,

    /// Contracts to compile. When `None`, contracts are discovered with `find_contracts` in
    /// `contract_crate_ids`.
    pub contract_declarations: Option<Vec<ContractDeclaration<'db>>>,

    /// Crates searched for contracts when `contract_declarations` is `None`. When this is also
    /// `None`, all crates of the database (`db.crates()`) are searched.
    pub contract_crate_ids: Option<&'db [CrateId<'db>]>,

    /// Crates searched for executable functions. When `None`, the test crates are used.
    pub executable_crate_ids: Option<Vec<CrateId<'db>>>,

    /// Whether to add annotations produced by `extract_statements_functions` (statement to
    /// function mapping).
    pub add_statements_functions: bool,

    /// Whether to add annotations produced by `extract_statements_source_code_locations`
    /// (statement to source code location mapping).
    pub add_statements_code_locations: bool,

    /// Whether to add annotations produced by `extract_serializable_debug_info` on the functions
    /// debug info.
    pub add_functions_debug_info: bool,

    /// When `true`, Sierra ids are replaced using `DebugReplacer::apply`; otherwise the program
    /// keeps its ids and only `enrich_function_names` is applied.
    pub replace_ids: bool,
}
84
/// Compiles the tests found in `test_crate_ids` of an already-prepared database into a single
/// Sierra program, together with the metadata needed to run them.
///
/// Besides the test functions themselves, the generated program also contains:
/// * the executable functions found in `executable_crate_ids` (defaulting to `test_crate_ids`);
/// * when `starknet` is enabled, the external, constructor and L1-handler ABI functions of every
///   contract, each of which is charged `ENTRY_POINT_COST` in `function_set_costs`.
///
/// # Arguments
/// * `db` - the database holding the crates to compile.
/// * `tests_compilation_config` - options controlling what is compiled and which debug
///   annotations are attached.
/// * `test_crate_ids` - the crates searched for test functions.
/// * `diagnostics_reporter` - passed to `ensure_diagnostics` before any code is generated.
///
/// # Errors
/// Returns an error when contract declarations or contract crate ids are given while `starknet`
/// is disabled, when `ensure_diagnostics` fails, when Sierra generation fails, or when
/// `get_contracts_info` fails.
///
/// # Panics
/// Panics if the Sierra function of a contract entry point cannot be retrieved
/// (`function_with_body_sierra(..)` is unwrapped).
pub fn compile_test_prepared_db<'db>(
    db: &'db dyn CloneableDatabase,
    tests_compilation_config: TestsCompilationConfig<'db>,
    test_crate_ids: Vec<CrateInput>,
    mut diagnostics_reporter: DiagnosticsReporter<'_>,
) -> Result<TestCompilation<'db>> {
    // Contract-related inputs are only accepted when Starknet support is enabled.
    ensure!(
        tests_compilation_config.starknet
            || tests_compilation_config.contract_declarations.is_none(),
        "Contract declarations can be provided only when starknet is enabled."
    );
    ensure!(
        tests_compilation_config.starknet || tests_compilation_config.contract_crate_ids.is_none(),
        "Contract crate ids can be provided only when starknet is enabled."
    );

    // Stop before code generation if the diagnostics check fails.
    ensure_diagnostics(db, &mut diagnostics_reporter)?;

    // Explicit declarations win; otherwise search `contract_crate_ids`, or every crate in `db`.
    // NOTE(review): this search also runs when `starknet` is disabled, and the contracts found
    // are still passed to `get_contracts_info` below — confirm that is intended.
    let contracts = tests_compilation_config.contract_declarations.unwrap_or_else(|| {
        find_contracts(
            db,
            tests_compilation_config.contract_crate_ids.unwrap_or_else(|| db.crates()),
        )
    });
    // Contract entry points are only compiled with Starknet enabled. An ABI module that cannot
    // be resolved contributes no functions (`unwrap_or_default`).
    let all_entry_points = if tests_compilation_config.starknet {
        contracts
            .iter()
            .flat_map(|contract| {
                chain!(
                    get_contract_abi_functions(db, contract, EXTERNAL_MODULE).unwrap_or_default(),
                    get_contract_abi_functions(db, contract, CONSTRUCTOR_MODULE)
                        .unwrap_or_default(),
                    get_contract_abi_functions(db, contract, L1_HANDLER_MODULE).unwrap_or_default(),
                )
            })
            .map(|func| ConcreteFunctionWithBodyId::from_semantic(db, func.value))
            .collect()
    } else {
        vec![]
    };

    let test_crate_ids = CrateInput::into_crate_ids(db, test_crate_ids);
    // Executables are looked up in the test crates unless explicit crates were configured.
    let executable_functions = find_executable_function_ids(
        db,
        tests_compilation_config.executable_crate_ids.unwrap_or_else(|| test_crate_ids.clone()),
    );
    let all_tests = find_all_tests(db, test_crate_ids);

    // Every function that must be part of the Sierra program: executables, contract entry
    // points and tests. Tests for which `from_no_generics_free` yields nothing are skipped by
    // the `flat_map`.
    let func_ids = chain!(
        executable_functions.keys().cloned(),
        all_entry_points.iter().cloned(),
        all_tests.iter().flat_map(|(func_id, _cfg)| {
            ConcreteFunctionWithBodyId::from_no_generics_free(db, *func_id)
        })
    )
    .collect();

    let SierraProgramWithDebug { program: sierra_program, debug_info } =
        get_sierra_program_for_functions(db, func_ids)?;

    // Charge each contract entry point the fixed `ENTRY_POINT_COST` as a `Const` cost token.
    // The `unwrap` relies on the diagnostics check and Sierra generation above having covered
    // these same functions.
    let function_set_costs: OrderedHashMap<FunctionId, CostTokenMap<i32>> = all_entry_points
        .iter()
        .map(|func_id| {
            (
                db.function_with_body_sierra(*func_id).unwrap().id.clone(),
                CostTokenMap::from_iter([(CostTokenType::Const, ENTRY_POINT_COST)]),
            )
        })
        .collect();

    let replacer = DebugReplacer { db };

    // Either rewrite the Sierra ids with `replacer.apply`, or keep the program's ids and only
    // enrich function names on a cloned copy.
    let sierra_program = if tests_compilation_config.replace_ids {
        replacer.apply(sierra_program)
    } else {
        let mut sierra_program = sierra_program.clone();
        replacer.enrich_function_names(&mut sierra_program);
        sierra_program
    };

    // Optional debug annotations, each gated by its own config flag.
    let mut annotations = Annotations::default();
    if tests_compilation_config.add_statements_functions {
        annotations.extend(Annotations::from(
            debug_info.statements_locations.extract_statements_functions(db),
        ))
    }
    if tests_compilation_config.add_statements_code_locations {
        annotations.extend(Annotations::from(
            debug_info.statements_locations.extract_statements_source_code_locations(db),
        ))
    }

    if tests_compilation_config.add_functions_debug_info {
        annotations.extend(Annotations::from(
            debug_info.functions_info.extract_serializable_debug_info(db),
        ))
    }

    let executables = collect_executables(db, executable_functions, &sierra_program);
    // Each test is named by the debug rendering of its free function, concretized with no
    // generic arguments.
    let named_tests = all_tests
        .into_iter()
        .map(|(func_id, test)| {
            (
                format!(
                    "{:?}",
                    FunctionLongId {
                        function: ConcreteFunction {
                            generic_function: GenericFunctionId::Free(func_id),
                            generic_args: vec![]
                        }
                    }
                    .debug(db)
                ),
                test,
            )
        })
        .collect_vec();
    let contracts_info = get_contracts_info(db, contracts, &replacer)?;
    // Only `executables` and `annotations` are filled in; other debug info fields are default.
    let sierra_program = ProgramArtifact::stripped(sierra_program).with_debug_info(DebugInfo {
        executables,
        annotations,
        ..DebugInfo::default()
    });

    Ok(TestCompilation {
        sierra_program,
        metadata: TestCompilationMetadata {
            named_tests,
            function_set_costs,
            contracts_info,
            statements_locations: Some(debug_info.statements_locations.clone()),
        },
    })
}
231
/// The result of compiling tests: the Sierra program artifact plus the metadata needed to run
/// the tests it contains.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
pub struct TestCompilation<'db> {
    /// The compiled Sierra program, including any requested debug info.
    pub sierra_program: ProgramArtifact,
    /// Test metadata; `#[serde(flatten)]` serializes its fields inline, next to
    /// `sierra_program`, rather than as a nested object.
    #[serde(flatten)]
    pub metadata: TestCompilationMetadata<'db>,
}
243
/// Metadata accompanying a compiled test program.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
pub struct TestCompilationMetadata<'db> {
    /// Contract information as returned by `get_contracts_info`, keyed by a `Felt252`
    /// (presumably the contract class hash — TODO confirm). (De)serialized through
    /// `serialize_ordered_hashmap_vec` / `deserialize_ordered_hashmap_vec`.
    #[serde(
        serialize_with = "serialize_ordered_hashmap_vec",
        deserialize_with = "deserialize_ordered_hashmap_vec"
    )]
    pub contracts_info: OrderedHashMap<Felt252, ContractInfo>,
    /// Extra costs attached to specific Sierra functions. `compile_test_prepared_db` fills this
    /// with `ENTRY_POINT_COST` (as a `Const` token) for every contract entry point.
    #[serde(
        serialize_with = "serialize_ordered_hashmap_vec",
        deserialize_with = "deserialize_ordered_hashmap_vec"
    )]
    pub function_set_costs: OrderedHashMap<FunctionId, CostTokenMap<i32>>,
    /// Each test's display name paired with its configuration.
    pub named_tests: Vec<(String, TestConfig)>,
    /// Statement locations from Sierra generation. Not serialized (`#[serde(skip)]`), so it is
    /// `None` after deserialization.
    #[serde(skip)]
    pub statements_locations: Option<StatementsLocations<'db>>,
}
267
268fn find_all_tests<'db>(
270 db: &'db dyn Database,
271 main_crates: Vec<CrateId<'db>>,
272) -> Vec<(FreeFunctionId<'db>, TestConfig)> {
273 let mut tests = vec![];
274 for crate_id in main_crates {
275 let modules = db.crate_modules(crate_id);
276 for module_id in modules.iter() {
277 let Ok(module_data) = module_id.module_data(db) else {
278 continue;
279 };
280 tests.extend(module_data.items(db).iter().filter_map(|item| {
281 let ModuleItemId::FreeFunction(func_id) = item else { return None };
282 let Ok(attrs) =
283 db.function_with_body_attributes(FunctionWithBodyId::Free(*func_id))
284 else {
285 return None;
286 };
287 Some((*func_id, try_extract_test_config(db, attrs).ok()??))
288 }));
289 }
290 }
291 tests
292}
293
294pub fn test_assert_suite() -> PluginSuite {
296 let mut suite = PluginSuite::default();
297 suite
298 .add_inline_macro_plugin::<inline_macros::assert::AssertEqMacro>()
299 .add_inline_macro_plugin::<inline_macros::assert::AssertNeMacro>()
300 .add_inline_macro_plugin::<inline_macros::assert::AssertLtMacro>()
301 .add_inline_macro_plugin::<inline_macros::assert::AssertLeMacro>()
302 .add_inline_macro_plugin::<inline_macros::assert::AssertGtMacro>()
303 .add_inline_macro_plugin::<inline_macros::assert::AssertGeMacro>();
304 suite
305}
306
307pub fn test_plugin_suite() -> PluginSuite {
309 let mut suite = PluginSuite::default();
310 suite.add_plugin::<TestPlugin>().add(test_assert_suite());
311 suite
312}