cairo_lang_test_plugin/
lib.rs

1use anyhow::{Result, ensure};
2use cairo_lang_compiler::diagnostics::DiagnosticsReporter;
3use cairo_lang_compiler::{ensure_diagnostics, get_sierra_program_for_functions};
4use cairo_lang_debug::DebugWithDb;
5use cairo_lang_defs::db::DefsGroup;
6use cairo_lang_defs::ids::{FreeFunctionId, FunctionWithBodyId, ModuleItemId};
7use cairo_lang_filesystem::db::FilesGroup;
8use cairo_lang_filesystem::ids::{CrateId, CrateInput};
9use cairo_lang_lowering::ids::ConcreteFunctionWithBodyId;
10use cairo_lang_semantic::items::function_with_body::FunctionWithBodySemantic;
11use cairo_lang_semantic::items::functions::GenericFunctionId;
12use cairo_lang_semantic::plugin::PluginSuite;
13use cairo_lang_semantic::{ConcreteFunction, FunctionLongId};
14use cairo_lang_sierra::debug_info::{Annotations, DebugInfo};
15use cairo_lang_sierra::extensions::gas::{CostTokenMap, CostTokenType};
16use cairo_lang_sierra::ids::FunctionId;
17use cairo_lang_sierra::program::ProgramArtifact;
18use cairo_lang_sierra_generator::db::SierraGenGroup;
19use cairo_lang_sierra_generator::executables::{collect_executables, find_executable_function_ids};
20use cairo_lang_sierra_generator::program_generator::SierraProgramWithDebug;
21use cairo_lang_sierra_generator::replace_ids::DebugReplacer;
22use cairo_lang_sierra_generator::statements_locations::StatementsLocations;
23use cairo_lang_starknet::contract::{
24    ContractDeclaration, ContractInfo, find_contracts, get_contract_abi_functions,
25    get_contracts_info,
26};
27use cairo_lang_starknet::plugin::consts::{CONSTRUCTOR_MODULE, EXTERNAL_MODULE, L1_HANDLER_MODULE};
28use cairo_lang_starknet_classes::casm_contract_class::ENTRY_POINT_COST;
29use cairo_lang_utils::ordered_hash_map::{
30    OrderedHashMap, deserialize_ordered_hashmap_vec, serialize_ordered_hashmap_vec,
31};
32use itertools::{Itertools, chain};
33pub use plugin::TestPlugin;
34use salsa::Database;
35use serde::{Deserialize, Serialize};
36use starknet_types_core::felt::Felt as Felt252;
37pub use test_config::{TestConfig, try_extract_test_config};
38
39mod inline_macros;
40pub mod plugin;
41pub mod test_config;
42
/// Attribute name `test`.
const TEST_ATTR: &str = "test";
/// Attribute name `ignore`.
const IGNORE_ATTR: &str = "ignore";
/// Attribute name `should_panic`.
const SHOULD_PANIC_ATTR: &str = "should_panic";
/// Attribute name `available_gas`.
const AVAILABLE_GAS_ATTR: &str = "available_gas";
/// Argument value `static` (by its name, an argument of the gas attribute — see `test_config`).
const STATIC_GAS_ARG: &str = "static";
48
49/// Configuration for test compilation.
50#[derive(Clone)]
51pub struct TestsCompilationConfig<'db> {
52    /// Adds the Starknet contracts to the compiled tests.
53    pub starknet: bool,
54
55    /// Contracts to compile.
56    /// If defined, only these contracts will be available in tests.
57    /// If not, all contracts from `contract_crate_ids` will be compiled.
58    pub contract_declarations: Option<Vec<ContractDeclaration<'db>>>,
59
60    /// Crates to be searched for contracts.
61    /// If not defined, all crates will be searched.
62    pub contract_crate_ids: Option<&'db [CrateId<'db>]>,
63
64    /// Crates to be searched for executable attributes.
65    /// If not defined, test crates will be searched.
66    pub executable_crate_ids: Option<Vec<CrateId<'db>>>,
67
68    /// Adds a mapping used by [cairo-profiler](https://github.com/software-mansion/cairo-profiler) to
69    /// [Annotations] in [DebugInfo] in the compiled tests.
70    pub add_statements_functions: bool,
71
72    /// Adds a mapping used by [cairo-coverage](https://github.com/software-mansion/cairo-coverage) to
73    /// [Annotations] in [DebugInfo] in the compiled tests.
74    pub add_statements_code_locations: bool,
75}
76
77/// Runs Cairo compiler.
78///
79/// # Arguments
80/// * `db` - Preloaded compilation database.
81/// * `tests_compilation_config` - The compiler configuration for tests compilation.
82/// * `main_crate_ids` - [`CrateId`]s to compile. Use `CrateLongId::Real(name).intern(db)` in order
83///   to obtain [`CrateId`] from its name.
84/// * `test_crate_ids` - [`CrateId`]s to find tests cases in. Must be a subset of `main_crate_ids`.
85/// # Returns
86/// * `Ok(TestCompilation)` - The compiled test cases with metadata.
87/// * `Err(anyhow::Error)` - Compilation failed.
88pub fn compile_test_prepared_db<'db>(
89    db: &'db dyn Database,
90    tests_compilation_config: TestsCompilationConfig<'db>,
91    test_crate_ids: Vec<CrateInput>,
92    mut diagnostics_reporter: DiagnosticsReporter<'_>,
93) -> Result<TestCompilation<'db>> {
94    ensure!(
95        tests_compilation_config.starknet
96            || tests_compilation_config.contract_declarations.is_none(),
97        "Contract declarations can be provided only when starknet is enabled."
98    );
99    ensure!(
100        tests_compilation_config.starknet || tests_compilation_config.contract_crate_ids.is_none(),
101        "Contract crate ids can be provided only when starknet is enabled."
102    );
103
104    ensure_diagnostics(db, &mut diagnostics_reporter)?;
105
106    let contracts = tests_compilation_config.contract_declarations.unwrap_or_else(|| {
107        find_contracts(
108            db,
109            tests_compilation_config.contract_crate_ids.unwrap_or_else(|| db.crates()),
110        )
111    });
112    let all_entry_points = if tests_compilation_config.starknet {
113        contracts
114            .iter()
115            .flat_map(|contract| {
116                chain!(
117                    get_contract_abi_functions(db, contract, EXTERNAL_MODULE).unwrap_or_default(),
118                    get_contract_abi_functions(db, contract, CONSTRUCTOR_MODULE)
119                        .unwrap_or_default(),
120                    get_contract_abi_functions(db, contract, L1_HANDLER_MODULE).unwrap_or_default(),
121                )
122            })
123            .map(|func| ConcreteFunctionWithBodyId::from_semantic(db, func.value))
124            .collect()
125    } else {
126        vec![]
127    };
128
129    let test_crate_ids = CrateInput::into_crate_ids(db, test_crate_ids);
130    let executable_functions = find_executable_function_ids(
131        db,
132        tests_compilation_config.executable_crate_ids.unwrap_or_else(|| test_crate_ids.clone()),
133    );
134    let all_tests = find_all_tests(db, test_crate_ids);
135
136    let func_ids = chain!(
137        executable_functions.keys().cloned(),
138        all_entry_points.iter().cloned(),
139        // TODO(maciektr): Remove test entrypoints after migration to executable attr.
140        all_tests.iter().flat_map(|(func_id, _cfg)| {
141            ConcreteFunctionWithBodyId::from_no_generics_free(db, *func_id)
142        })
143    )
144    .collect();
145
146    let SierraProgramWithDebug { program: sierra_program, debug_info } =
147        get_sierra_program_for_functions(db, func_ids)?;
148
149    let function_set_costs: OrderedHashMap<FunctionId, CostTokenMap<i32>> = all_entry_points
150        .iter()
151        .map(|func_id| {
152            (
153                db.function_with_body_sierra(*func_id).unwrap().id.clone(),
154                CostTokenMap::from_iter([(CostTokenType::Const, ENTRY_POINT_COST)]),
155            )
156        })
157        .collect();
158
159    let replacer = DebugReplacer { db };
160    let mut sierra_program = sierra_program.clone();
161    replacer.enrich_function_names(&mut sierra_program);
162
163    let mut annotations = Annotations::default();
164    if tests_compilation_config.add_statements_functions {
165        annotations.extend(Annotations::from(
166            debug_info.statements_locations.extract_statements_functions(db),
167        ))
168    }
169    if tests_compilation_config.add_statements_code_locations {
170        annotations.extend(Annotations::from(
171            debug_info.statements_locations.extract_statements_source_code_locations(db),
172        ))
173    }
174
175    let executables = collect_executables(db, executable_functions, &sierra_program);
176    let named_tests = all_tests
177        .into_iter()
178        .map(|(func_id, test)| {
179            (
180                format!(
181                    "{:?}",
182                    FunctionLongId {
183                        function: ConcreteFunction {
184                            generic_function: GenericFunctionId::Free(func_id),
185                            generic_args: vec![]
186                        }
187                    }
188                    .debug(db)
189                ),
190                test,
191            )
192        })
193        .collect_vec();
194    let contracts_info = get_contracts_info(db, contracts, &replacer)?;
195    let sierra_program = ProgramArtifact::stripped(sierra_program).with_debug_info(DebugInfo {
196        executables,
197        annotations,
198        ..DebugInfo::default()
199    });
200
201    Ok(TestCompilation {
202        sierra_program,
203        metadata: TestCompilationMetadata {
204            named_tests,
205            function_set_costs,
206            contracts_info,
207            statements_locations: Some(debug_info.statements_locations.clone()),
208        },
209    })
210}
211
212/// Encapsulation of all data required to execute tests.
213///
214/// This includes the source code compiled to a Sierra program and all cairo-test-specific
215/// data extracted from it.
216/// This can be stored on the filesystem and shared externally.
217#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
218pub struct TestCompilation<'db> {
219    pub sierra_program: ProgramArtifact,
220    #[serde(flatten)]
221    pub metadata: TestCompilationMetadata<'db>,
222}
223
224/// Encapsulation of all data required to execute tests, except for the Sierra program itself.
225///
226/// This includes all cairo-test-specific data extracted from the program.
227/// This can be stored on the filesystem and shared externally.
228#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
229pub struct TestCompilationMetadata<'db> {
230    #[serde(
231        serialize_with = "serialize_ordered_hashmap_vec",
232        deserialize_with = "deserialize_ordered_hashmap_vec"
233    )]
234    pub contracts_info: OrderedHashMap<Felt252, ContractInfo>,
235    #[serde(
236        serialize_with = "serialize_ordered_hashmap_vec",
237        deserialize_with = "deserialize_ordered_hashmap_vec"
238    )]
239    pub function_set_costs: OrderedHashMap<FunctionId, CostTokenMap<i32>>,
240    pub named_tests: Vec<(String, TestConfig)>,
241    /// Optional `StatementsLocations` for the compiled tests.
242    /// See [StatementsLocations] for more information.
243    // TODO(Gil): consider serializing this field once it is stable.
244    #[serde(skip)]
245    pub statements_locations: Option<StatementsLocations<'db>>,
246}
247
248/// Finds the tests in the requested crates.
249fn find_all_tests<'db>(
250    db: &'db dyn Database,
251    main_crates: Vec<CrateId<'db>>,
252) -> Vec<(FreeFunctionId<'db>, TestConfig)> {
253    let mut tests = vec![];
254    for crate_id in main_crates {
255        let modules = db.crate_modules(crate_id);
256        for module_id in modules.iter() {
257            let Ok(module_data) = module_id.module_data(db) else {
258                continue;
259            };
260            tests.extend(module_data.items(db).iter().filter_map(|item| {
261                let ModuleItemId::FreeFunction(func_id) = item else { return None };
262                let Ok(attrs) =
263                    db.function_with_body_attributes(FunctionWithBodyId::Free(*func_id))
264                else {
265                    return None;
266                };
267                Some((*func_id, try_extract_test_config(db, attrs).ok()??))
268            }));
269        }
270    }
271    tests
272}
273
274/// The suite of plugins that implements assert macros for tests.
275pub fn test_assert_suite() -> PluginSuite {
276    let mut suite = PluginSuite::default();
277    suite
278        .add_inline_macro_plugin::<inline_macros::assert::AssertEqMacro>()
279        .add_inline_macro_plugin::<inline_macros::assert::AssertNeMacro>()
280        .add_inline_macro_plugin::<inline_macros::assert::AssertLtMacro>()
281        .add_inline_macro_plugin::<inline_macros::assert::AssertLeMacro>()
282        .add_inline_macro_plugin::<inline_macros::assert::AssertGtMacro>()
283        .add_inline_macro_plugin::<inline_macros::assert::AssertGeMacro>();
284    suite
285}
286
287/// The suite of plugins for compilation for testing.
288pub fn test_plugin_suite() -> PluginSuite {
289    let mut suite = PluginSuite::default();
290    suite.add_plugin::<TestPlugin>().add(test_assert_suite());
291    suite
292}