1use std::default::Default;
2
3use anyhow::{Result, ensure};
4use cairo_lang_compiler::diagnostics::DiagnosticsReporter;
5use cairo_lang_compiler::{ensure_diagnostics, get_sierra_program_for_functions};
6use cairo_lang_debug::DebugWithDb;
7use cairo_lang_defs::db::DefsGroup;
8use cairo_lang_defs::ids::{FreeFunctionId, FunctionWithBodyId, ModuleItemId};
9use cairo_lang_filesystem::db::FilesGroup;
10use cairo_lang_filesystem::ids::{CrateId, CrateInput};
11use cairo_lang_lowering::ids::ConcreteFunctionWithBodyId;
12use cairo_lang_semantic::items::function_with_body::FunctionWithBodySemantic;
13use cairo_lang_semantic::items::functions::GenericFunctionId;
14use cairo_lang_semantic::plugin::PluginSuite;
15use cairo_lang_semantic::{ConcreteFunction, FunctionLongId};
16use cairo_lang_sierra::debug_info::{Annotations, DebugInfo};
17use cairo_lang_sierra::extensions::gas::{CostTokenMap, CostTokenType};
18use cairo_lang_sierra::ids::FunctionId;
19use cairo_lang_sierra::program::ProgramArtifact;
20use cairo_lang_sierra_generator::db::SierraGenGroup;
21use cairo_lang_sierra_generator::executables::{collect_executables, find_executable_function_ids};
22use cairo_lang_sierra_generator::program_generator::SierraProgramWithDebug;
23use cairo_lang_sierra_generator::replace_ids::DebugReplacer;
24use cairo_lang_sierra_generator::statements_locations::StatementsLocations;
25use cairo_lang_starknet::contract::{
26 ContractDeclaration, ContractInfo, find_contracts, get_contract_abi_functions,
27 get_contracts_info,
28};
29use cairo_lang_starknet::plugin::consts::{CONSTRUCTOR_MODULE, EXTERNAL_MODULE, L1_HANDLER_MODULE};
30use cairo_lang_starknet_classes::casm_contract_class::ENTRY_POINT_COST;
31use cairo_lang_utils::ordered_hash_map::{
32 OrderedHashMap, deserialize_ordered_hashmap_vec, serialize_ordered_hashmap_vec,
33};
34use itertools::{Itertools, chain};
35pub use plugin::TestPlugin;
36use salsa::Database;
37use serde::{Deserialize, Serialize};
38use starknet_types_core::felt::Felt as Felt252;
39pub use test_config::{TestConfig, try_extract_test_config};
40
41mod inline_macros;
42pub mod plugin;
43pub mod test_config;
44
// Attribute / argument names used to recognize test configuration. They are not referenced in
// this part of the file; presumably they are read by the `test_config` and `plugin` submodules
// via `super::` — confirm there.

/// Attribute name `test` (i.e. `#[test]`).
const TEST_ATTR: &str = "test";
/// Attribute name `should_panic` (i.e. `#[should_panic]`).
const SHOULD_PANIC_ATTR: &str = "should_panic";
/// Attribute name `ignore` (i.e. `#[ignore]`).
const IGNORE_ATTR: &str = "ignore";
/// Attribute name `available_gas` (i.e. `#[available_gas(...)]`).
const AVAILABLE_GAS_ATTR: &str = "available_gas";
/// Argument value `static`; presumably accepted inside `#[available_gas(...)]` — confirm in
/// `test_config`.
const STATIC_GAS_ARG: &str = "static";
50
/// Configuration consumed by [`compile_test_prepared_db`].
#[derive(Clone)]
pub struct TestsCompilationConfig<'db> {
    /// Whether contract entry points (the `external`, `constructor` and `l1_handler` ABI
    /// functions of each contract) are compiled alongside the tests and charged
    /// `ENTRY_POINT_COST`. Must be `true` whenever `contract_declarations` or
    /// `contract_crate_ids` is set; otherwise compilation fails.
    pub starknet: bool,

    /// The contracts to compile. When `None`, contracts are discovered with `find_contracts`
    /// over `contract_crate_ids`.
    pub contract_declarations: Option<Vec<ContractDeclaration<'db>>>,

    /// The crates searched for contracts when `contract_declarations` is `None`. When this is
    /// also `None`, every crate known to the database (`db.crates()`) is searched.
    pub contract_crate_ids: Option<&'db [CrateId<'db>]>,

    /// The crates searched for executable functions. When `None`, the test crates are searched.
    pub executable_crate_ids: Option<Vec<CrateId<'db>>>,

    /// Whether to add the annotations produced by
    /// `StatementsLocations::extract_statements_functions` (statement -> function) to the
    /// program's debug info.
    pub add_statements_functions: bool,

    /// Whether to add the annotations produced by
    /// `StatementsLocations::extract_statements_source_code_locations` (statement -> source
    /// location) to the program's debug info.
    pub add_statements_code_locations: bool,
}
78
79pub fn compile_test_prepared_db<'db>(
91 db: &'db dyn Database,
92 tests_compilation_config: TestsCompilationConfig<'db>,
93 test_crate_ids: Vec<CrateInput>,
94 mut diagnostics_reporter: DiagnosticsReporter<'_>,
95) -> Result<TestCompilation<'db>> {
96 ensure!(
97 tests_compilation_config.starknet
98 || tests_compilation_config.contract_declarations.is_none(),
99 "Contract declarations can be provided only when starknet is enabled."
100 );
101 ensure!(
102 tests_compilation_config.starknet || tests_compilation_config.contract_crate_ids.is_none(),
103 "Contract crate ids can be provided only when starknet is enabled."
104 );
105
106 ensure_diagnostics(db, &mut diagnostics_reporter)?;
107
108 let contracts = tests_compilation_config.contract_declarations.unwrap_or_else(|| {
109 find_contracts(
110 db,
111 tests_compilation_config.contract_crate_ids.unwrap_or_else(|| db.crates()),
112 )
113 });
114 let all_entry_points = if tests_compilation_config.starknet {
115 contracts
116 .iter()
117 .flat_map(|contract| {
118 chain!(
119 get_contract_abi_functions(db, contract, EXTERNAL_MODULE).unwrap_or_default(),
120 get_contract_abi_functions(db, contract, CONSTRUCTOR_MODULE)
121 .unwrap_or_default(),
122 get_contract_abi_functions(db, contract, L1_HANDLER_MODULE).unwrap_or_default(),
123 )
124 })
125 .map(|func| ConcreteFunctionWithBodyId::from_semantic(db, func.value))
126 .collect()
127 } else {
128 vec![]
129 };
130
131 let test_crate_ids = CrateInput::into_crate_ids(db, test_crate_ids);
132 let executable_functions = find_executable_function_ids(
133 db,
134 tests_compilation_config.executable_crate_ids.unwrap_or_else(|| test_crate_ids.clone()),
135 );
136 let all_tests = find_all_tests(db, test_crate_ids);
137
138 let func_ids = chain!(
139 executable_functions.clone().into_keys(),
140 all_entry_points.iter().cloned(),
141 all_tests.iter().flat_map(|(func_id, _cfg)| {
143 ConcreteFunctionWithBodyId::from_no_generics_free(db, *func_id)
144 })
145 )
146 .collect();
147
148 let SierraProgramWithDebug { program: sierra_program, debug_info } =
149 get_sierra_program_for_functions(db, func_ids)?;
150
151 let function_set_costs: OrderedHashMap<FunctionId, CostTokenMap<i32>> = all_entry_points
152 .iter()
153 .map(|func_id| {
154 (
155 db.function_with_body_sierra(*func_id).unwrap().id.clone(),
156 CostTokenMap::from_iter([(CostTokenType::Const, ENTRY_POINT_COST)]),
157 )
158 })
159 .collect();
160
161 let replacer = DebugReplacer { db };
162 let mut sierra_program = sierra_program.clone();
163 replacer.enrich_function_names(&mut sierra_program);
164
165 let mut annotations = Annotations::default();
166 if tests_compilation_config.add_statements_functions {
167 annotations.extend(Annotations::from(
168 debug_info.statements_locations.extract_statements_functions(db),
169 ))
170 }
171 if tests_compilation_config.add_statements_code_locations {
172 annotations.extend(Annotations::from(
173 debug_info.statements_locations.extract_statements_source_code_locations(db),
174 ))
175 }
176
177 let executables = collect_executables(db, executable_functions, &sierra_program);
178 let named_tests = all_tests
179 .into_iter()
180 .map(|(func_id, test)| {
181 (
182 format!(
183 "{:?}",
184 FunctionLongId {
185 function: ConcreteFunction {
186 generic_function: GenericFunctionId::Free(func_id),
187 generic_args: vec![]
188 }
189 }
190 .debug(db)
191 ),
192 test,
193 )
194 })
195 .collect_vec();
196 let contracts_info = get_contracts_info(db, contracts, &replacer)?;
197 let sierra_program = ProgramArtifact::stripped(sierra_program).with_debug_info(DebugInfo {
198 executables,
199 annotations,
200 ..DebugInfo::default()
201 });
202
203 Ok(TestCompilation {
204 sierra_program,
205 metadata: TestCompilationMetadata {
206 named_tests,
207 function_set_costs,
208 contracts_info,
209 statements_locations: Some(debug_info.statements_locations.clone()),
210 },
211 })
212}
213
/// The output of [`compile_test_prepared_db`]: a Sierra program together with test metadata.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
pub struct TestCompilation<'db> {
    /// The compiled Sierra program. As built by `compile_test_prepared_db`, it is stripped and
    /// carries debug info holding the executables and the requested statement annotations.
    pub sierra_program: ProgramArtifact,
    /// Metadata about the compilation. `#[serde(flatten)]` serializes its fields inline, at the
    /// same level as `sierra_program`, rather than under a nested `metadata` key.
    #[serde(flatten)]
    pub metadata: TestCompilationMetadata<'db>,
}
225
/// Metadata accompanying the Sierra program of a [`TestCompilation`].
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
pub struct TestCompilationMetadata<'db> {
    /// Information about the compiled contracts, keyed by a felt252 (presumably the contract's
    /// class hash — confirm against `get_contracts_info`). (De)serialized with the
    /// `*_ordered_hashmap_vec` helpers, i.e. as a sequence rather than a map, per their names.
    #[serde(
        serialize_with = "serialize_ordered_hashmap_vec",
        deserialize_with = "deserialize_ordered_hashmap_vec"
    )]
    pub contracts_info: OrderedHashMap<Felt252, ContractInfo>,
    /// Additional costs per Sierra function. `compile_test_prepared_db` charges every contract
    /// entry point a constant `ENTRY_POINT_COST`. Same (de)serialization format as
    /// `contracts_info`.
    #[serde(
        serialize_with = "serialize_ordered_hashmap_vec",
        deserialize_with = "deserialize_ordered_hashmap_vec"
    )]
    pub function_set_costs: OrderedHashMap<FunctionId, CostTokenMap<i32>>,
    /// The tests, each paired with its configuration. `compile_test_prepared_db` names each test
    /// by the debug representation of its function.
    pub named_tests: Vec<(String, TestConfig)>,
    /// Locations of the Sierra statements, when available. Skipped by serde, so it is `None`
    /// (the `Option` default) after deserialization.
    #[serde(skip)]
    pub statements_locations: Option<StatementsLocations<'db>>,
}
249
250fn find_all_tests<'db>(
252 db: &'db dyn Database,
253 main_crates: Vec<CrateId<'db>>,
254) -> Vec<(FreeFunctionId<'db>, TestConfig)> {
255 let mut tests = vec![];
256 for crate_id in main_crates {
257 let modules = db.crate_modules(crate_id);
258 for module_id in modules.iter() {
259 let Ok(module_data) = module_id.module_data(db) else {
260 continue;
261 };
262 tests.extend(module_data.items(db).iter().filter_map(|item| {
263 let ModuleItemId::FreeFunction(func_id) = item else { return None };
264 let Ok(attrs) =
265 db.function_with_body_attributes(FunctionWithBodyId::Free(*func_id))
266 else {
267 return None;
268 };
269 Some((*func_id, try_extract_test_config(db, attrs).ok()??))
270 }));
271 }
272 }
273 tests
274}
275
276pub fn test_assert_suite() -> PluginSuite {
278 let mut suite = PluginSuite::default();
279 suite
280 .add_inline_macro_plugin::<inline_macros::assert::AssertEqMacro>()
281 .add_inline_macro_plugin::<inline_macros::assert::AssertNeMacro>()
282 .add_inline_macro_plugin::<inline_macros::assert::AssertLtMacro>()
283 .add_inline_macro_plugin::<inline_macros::assert::AssertLeMacro>()
284 .add_inline_macro_plugin::<inline_macros::assert::AssertGtMacro>()
285 .add_inline_macro_plugin::<inline_macros::assert::AssertGeMacro>();
286 suite
287}
288
289pub fn test_plugin_suite() -> PluginSuite {
291 let mut suite = PluginSuite::default();
292 suite.add_plugin::<TestPlugin>().add(test_assert_suite());
293 suite
294}