1use anyhow::{Result, ensure};
2use cairo_lang_compiler::diagnostics::DiagnosticsReporter;
3use cairo_lang_compiler::{ensure_diagnostics, get_sierra_program_for_functions};
4use cairo_lang_debug::DebugWithDb;
5use cairo_lang_defs::db::DefsGroup;
6use cairo_lang_defs::ids::{FreeFunctionId, FunctionWithBodyId, ModuleItemId};
7use cairo_lang_filesystem::db::FilesGroup;
8use cairo_lang_filesystem::ids::{CrateId, CrateInput};
9use cairo_lang_lowering::ids::ConcreteFunctionWithBodyId;
10use cairo_lang_semantic::items::function_with_body::FunctionWithBodySemantic;
11use cairo_lang_semantic::items::functions::GenericFunctionId;
12use cairo_lang_semantic::plugin::PluginSuite;
13use cairo_lang_semantic::{ConcreteFunction, FunctionLongId};
14use cairo_lang_sierra::debug_info::{Annotations, DebugInfo};
15use cairo_lang_sierra::extensions::gas::{CostTokenMap, CostTokenType};
16use cairo_lang_sierra::ids::FunctionId;
17use cairo_lang_sierra::program::ProgramArtifact;
18use cairo_lang_sierra_generator::db::SierraGenGroup;
19use cairo_lang_sierra_generator::debug_info::StatementsLocations;
20use cairo_lang_sierra_generator::executables::{collect_executables, find_executable_function_ids};
21use cairo_lang_sierra_generator::program_generator::SierraProgramWithDebug;
22use cairo_lang_sierra_generator::replace_ids::DebugReplacer;
23use cairo_lang_starknet::contract::{
24 ContractDeclaration, ContractInfo, find_contracts, get_contract_abi_functions,
25 get_contracts_info,
26};
27use cairo_lang_starknet::plugin::consts::{CONSTRUCTOR_MODULE, EXTERNAL_MODULE, L1_HANDLER_MODULE};
28use cairo_lang_starknet_classes::casm_contract_class::ENTRY_POINT_COST;
29use cairo_lang_utils::ordered_hash_map::{
30 OrderedHashMap, deserialize_ordered_hashmap_vec, serialize_ordered_hashmap_vec,
31};
32use itertools::{Itertools, chain};
33pub use plugin::TestPlugin;
34use salsa::Database;
35use serde::{Deserialize, Serialize};
36use starknet_types_core::felt::Felt as Felt252;
37pub use test_config::{TestConfig, try_extract_test_config};
38
39mod inline_macros;
40pub mod plugin;
41pub mod test_config;
42
/// Attribute name marking a function as a test.
const TEST_ATTR: &str = "test";
/// Attribute name marking a test as expected to panic.
const SHOULD_PANIC_ATTR: &str = "should_panic";
/// Attribute name marking a test as ignored.
const IGNORE_ATTR: &str = "ignore";
/// Attribute name configuring the gas available to a test.
const AVAILABLE_GAS_ATTR: &str = "available_gas";
/// Argument value of the gas attribute — presumably `#[available_gas(static)]`; confirm in
/// `test_config`.
const STATIC_GAS_ARG: &str = "static";
48
/// Configuration consumed by [`compile_test_prepared_db`].
#[derive(Clone)]
pub struct TestsCompilationConfig<'db> {
    /// Whether the external, constructor and L1 handler entry points of the contracts in scope
    /// are compiled into the test program. `contract_declarations` and `contract_crate_ids` may
    /// only be set when this is `true`.
    pub starknet: bool,

    /// Explicit list of contracts. When `None`, contracts are searched for with `find_contracts`
    /// in `contract_crate_ids`.
    pub contract_declarations: Option<Vec<ContractDeclaration<'db>>>,

    /// Crates to search for contracts when `contract_declarations` is `None`. When this is also
    /// `None`, every crate in the database is searched.
    pub contract_crate_ids: Option<&'db [CrateId<'db>]>,

    /// Crates to search for executable functions. When `None`, the test crates are used.
    pub executable_crate_ids: Option<Vec<CrateId<'db>>>,

    /// Whether to add, as annotations, the functions each Sierra statement belongs to.
    pub add_statements_functions: bool,

    /// Whether to add, as annotations, the source code locations of each Sierra statement.
    pub add_statements_code_locations: bool,

    /// Whether to add the serializable functions debug info as annotations.
    pub add_functions_debug_info: bool,
}
80
81pub fn compile_test_prepared_db<'db>(
93 db: &'db dyn Database,
94 tests_compilation_config: TestsCompilationConfig<'db>,
95 test_crate_ids: Vec<CrateInput>,
96 mut diagnostics_reporter: DiagnosticsReporter<'_>,
97) -> Result<TestCompilation<'db>> {
98 ensure!(
99 tests_compilation_config.starknet
100 || tests_compilation_config.contract_declarations.is_none(),
101 "Contract declarations can be provided only when starknet is enabled."
102 );
103 ensure!(
104 tests_compilation_config.starknet || tests_compilation_config.contract_crate_ids.is_none(),
105 "Contract crate ids can be provided only when starknet is enabled."
106 );
107
108 ensure_diagnostics(db, &mut diagnostics_reporter)?;
109
110 let contracts = tests_compilation_config.contract_declarations.unwrap_or_else(|| {
111 find_contracts(
112 db,
113 tests_compilation_config.contract_crate_ids.unwrap_or_else(|| db.crates()),
114 )
115 });
116 let all_entry_points = if tests_compilation_config.starknet {
117 contracts
118 .iter()
119 .flat_map(|contract| {
120 chain!(
121 get_contract_abi_functions(db, contract, EXTERNAL_MODULE).unwrap_or_default(),
122 get_contract_abi_functions(db, contract, CONSTRUCTOR_MODULE)
123 .unwrap_or_default(),
124 get_contract_abi_functions(db, contract, L1_HANDLER_MODULE).unwrap_or_default(),
125 )
126 })
127 .map(|func| ConcreteFunctionWithBodyId::from_semantic(db, func.value))
128 .collect()
129 } else {
130 vec![]
131 };
132
133 let test_crate_ids = CrateInput::into_crate_ids(db, test_crate_ids);
134 let executable_functions = find_executable_function_ids(
135 db,
136 tests_compilation_config.executable_crate_ids.unwrap_or_else(|| test_crate_ids.clone()),
137 );
138 let all_tests = find_all_tests(db, test_crate_ids);
139
140 let func_ids = chain!(
141 executable_functions.keys().cloned(),
142 all_entry_points.iter().cloned(),
143 all_tests.iter().flat_map(|(func_id, _cfg)| {
145 ConcreteFunctionWithBodyId::from_no_generics_free(db, *func_id)
146 })
147 )
148 .collect();
149
150 let SierraProgramWithDebug { program: sierra_program, debug_info } =
151 get_sierra_program_for_functions(db, func_ids)?;
152
153 let function_set_costs: OrderedHashMap<FunctionId, CostTokenMap<i32>> = all_entry_points
154 .iter()
155 .map(|func_id| {
156 (
157 db.function_with_body_sierra(*func_id).unwrap().id.clone(),
158 CostTokenMap::from_iter([(CostTokenType::Const, ENTRY_POINT_COST)]),
159 )
160 })
161 .collect();
162
163 let replacer = DebugReplacer { db };
164 let mut sierra_program = sierra_program.clone();
165 replacer.enrich_function_names(&mut sierra_program);
166
167 let mut annotations = Annotations::default();
168 if tests_compilation_config.add_statements_functions {
169 annotations.extend(Annotations::from(
170 debug_info.statements_locations.extract_statements_functions(db),
171 ))
172 }
173 if tests_compilation_config.add_statements_code_locations {
174 annotations.extend(Annotations::from(
175 debug_info.statements_locations.extract_statements_source_code_locations(db),
176 ))
177 }
178
179 if tests_compilation_config.add_functions_debug_info {
180 annotations.extend(Annotations::from(
181 debug_info.functions_info.extract_serializable_debug_info(db),
182 ))
183 }
184
185 let executables = collect_executables(db, executable_functions, &sierra_program);
186 let named_tests = all_tests
187 .into_iter()
188 .map(|(func_id, test)| {
189 (
190 format!(
191 "{:?}",
192 FunctionLongId {
193 function: ConcreteFunction {
194 generic_function: GenericFunctionId::Free(func_id),
195 generic_args: vec![]
196 }
197 }
198 .debug(db)
199 ),
200 test,
201 )
202 })
203 .collect_vec();
204 let contracts_info = get_contracts_info(db, contracts, &replacer)?;
205 let sierra_program = ProgramArtifact::stripped(sierra_program).with_debug_info(DebugInfo {
206 executables,
207 annotations,
208 ..DebugInfo::default()
209 });
210
211 Ok(TestCompilation {
212 sierra_program,
213 metadata: TestCompilationMetadata {
214 named_tests,
215 function_set_costs,
216 contracts_info,
217 statements_locations: Some(debug_info.statements_locations.clone()),
218 },
219 })
220}
221
/// Result of [`compile_test_prepared_db`]: the compiled Sierra program and the metadata needed to
/// run the tests it contains.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
pub struct TestCompilation<'db> {
    /// The compiled program, including the requested debug annotations and executables.
    pub sierra_program: ProgramArtifact,
    /// Test metadata; flattened, so its fields serialize alongside `sierra_program`.
    #[serde(flatten)]
    pub metadata: TestCompilationMetadata<'db>,
}
233
/// Metadata accompanying a compiled test program.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
pub struct TestCompilationMetadata<'db> {
    /// Information on the contracts in scope, keyed by a `Felt252` (presumably the class hash —
    /// confirm against `get_contracts_info`).
    #[serde(
        serialize_with = "serialize_ordered_hashmap_vec",
        deserialize_with = "deserialize_ordered_hashmap_vec"
    )]
    pub contracts_info: OrderedHashMap<Felt252, ContractInfo>,
    /// Extra costs per Sierra function; holds the constant entry-point cost of each contract
    /// entry point.
    #[serde(
        serialize_with = "serialize_ordered_hashmap_vec",
        deserialize_with = "deserialize_ordered_hashmap_vec"
    )]
    pub function_set_costs: OrderedHashMap<FunctionId, CostTokenMap<i32>>,
    /// Each test's debug-formatted function name paired with its configuration.
    pub named_tests: Vec<(String, TestConfig)>,
    /// Sierra statements locations. Not serialized, so this is `None` after deserialization.
    #[serde(skip)]
    pub statements_locations: Option<StatementsLocations<'db>>,
}
257
258fn find_all_tests<'db>(
260 db: &'db dyn Database,
261 main_crates: Vec<CrateId<'db>>,
262) -> Vec<(FreeFunctionId<'db>, TestConfig)> {
263 let mut tests = vec![];
264 for crate_id in main_crates {
265 let modules = db.crate_modules(crate_id);
266 for module_id in modules.iter() {
267 let Ok(module_data) = module_id.module_data(db) else {
268 continue;
269 };
270 tests.extend(module_data.items(db).iter().filter_map(|item| {
271 let ModuleItemId::FreeFunction(func_id) = item else { return None };
272 let Ok(attrs) =
273 db.function_with_body_attributes(FunctionWithBodyId::Free(*func_id))
274 else {
275 return None;
276 };
277 Some((*func_id, try_extract_test_config(db, attrs).ok()??))
278 }));
279 }
280 }
281 tests
282}
283
284pub fn test_assert_suite() -> PluginSuite {
286 let mut suite = PluginSuite::default();
287 suite
288 .add_inline_macro_plugin::<inline_macros::assert::AssertEqMacro>()
289 .add_inline_macro_plugin::<inline_macros::assert::AssertNeMacro>()
290 .add_inline_macro_plugin::<inline_macros::assert::AssertLtMacro>()
291 .add_inline_macro_plugin::<inline_macros::assert::AssertLeMacro>()
292 .add_inline_macro_plugin::<inline_macros::assert::AssertGtMacro>()
293 .add_inline_macro_plugin::<inline_macros::assert::AssertGeMacro>();
294 suite
295}
296
297pub fn test_plugin_suite() -> PluginSuite {
299 let mut suite = PluginSuite::default();
300 suite.add_plugin::<TestPlugin>().add(test_assert_suite());
301 suite
302}