cairo_lang_test_plugin/
lib.rs

1use std::default::Default;
2use std::sync::Arc;
3
4use anyhow::{Result, ensure};
5use cairo_lang_compiler::db::RootDatabase;
6use cairo_lang_compiler::diagnostics::DiagnosticsReporter;
7use cairo_lang_compiler::{DbWarmupContext, get_sierra_program_for_functions};
8use cairo_lang_debug::DebugWithDb;
9use cairo_lang_defs::ids::{FreeFunctionId, FunctionWithBodyId, ModuleItemId};
10use cairo_lang_filesystem::db::FilesGroup;
11use cairo_lang_filesystem::ids::CrateId;
12use cairo_lang_lowering::ids::ConcreteFunctionWithBodyId;
13use cairo_lang_semantic::db::SemanticGroup;
14use cairo_lang_semantic::items::functions::GenericFunctionId;
15use cairo_lang_semantic::plugin::PluginSuite;
16use cairo_lang_semantic::{ConcreteFunction, FunctionLongId};
17use cairo_lang_sierra::debug_info::{Annotations, DebugInfo};
18use cairo_lang_sierra::extensions::gas::CostTokenType;
19use cairo_lang_sierra::ids::FunctionId;
20use cairo_lang_sierra::program::ProgramArtifact;
21use cairo_lang_sierra_generator::db::SierraGenGroup;
22use cairo_lang_sierra_generator::executables::{collect_executables, find_executable_function_ids};
23use cairo_lang_sierra_generator::program_generator::SierraProgramWithDebug;
24use cairo_lang_sierra_generator::replace_ids::DebugReplacer;
25use cairo_lang_sierra_generator::statements_locations::StatementsLocations;
26use cairo_lang_starknet::contract::{
27    ContractDeclaration, ContractInfo, find_contracts, get_contract_abi_functions,
28    get_contracts_info,
29};
30use cairo_lang_starknet::plugin::consts::{CONSTRUCTOR_MODULE, EXTERNAL_MODULE, L1_HANDLER_MODULE};
31use cairo_lang_starknet_classes::casm_contract_class::ENTRY_POINT_COST;
32use cairo_lang_utils::ordered_hash_map::{
33    OrderedHashMap, deserialize_ordered_hashmap_vec, serialize_ordered_hashmap_vec,
34};
35use itertools::{Itertools, chain};
36pub use plugin::TestPlugin;
37use serde::{Deserialize, Serialize};
38use starknet_types_core::felt::Felt as Felt252;
39pub use test_config::{TestConfig, try_extract_test_config};
40
// Inline macro plugins; `test_assert_suite` registers the assert macros from
// `inline_macros::assert`.
mod inline_macros;
// Home of `TestPlugin`, which is re-exported from this crate root.
pub mod plugin;
// Home of `TestConfig` and `try_extract_test_config`, which are re-exported from this crate root.
pub mod test_config;
44
// NOTE(review): none of these constants is referenced in this file; presumably they are used by
// the `plugin` and `test_config` submodules - confirm before changing any of the values.

/// Name of the `test` attribute; presumably marks a function as a test case.
const TEST_ATTR: &str = "test";
/// Name of the `should_panic` attribute; presumably marks a test that is expected to panic.
const SHOULD_PANIC_ATTR: &str = "should_panic";
/// Name of the `ignore` attribute; presumably marks a test that should be skipped.
const IGNORE_ATTR: &str = "ignore";
/// Name of the `available_gas` attribute; presumably sets the gas available to a test.
const AVAILABLE_GAS_ATTR: &str = "available_gas";
/// The `static` argument; presumably accepted by the `available_gas` attribute - TODO confirm.
const STATIC_GAS_ARG: &str = "static";
50
/// Configuration for test compilation.
///
/// Consumed by [`compile_test_prepared_db`].
#[derive(Clone)]
pub struct TestsCompilationConfig {
    /// Adds the starknet contracts to the compiled tests.
    pub starknet: bool,

    /// Contracts to compile.
    /// If defined, only these contracts will be available in tests.
    /// If not, all contracts from `contract_crate_ids` will be compiled.
    /// May be defined only when `starknet` is enabled; otherwise compilation fails.
    pub contract_declarations: Option<Vec<ContractDeclaration>>,

    /// Crates to be searched for contracts.
    /// If not defined, all crates will be searched.
    /// May be defined only when `starknet` is enabled; otherwise compilation fails.
    pub contract_crate_ids: Option<Vec<CrateId>>,

    /// Crates to be searched for executable attributes.
    /// If not defined, test crates will be searched.
    pub executable_crate_ids: Option<Vec<CrateId>>,

    /// Adds mapping used by [cairo-profiler](https://github.com/software-mansion/cairo-profiler) to
    /// [Annotations] in [DebugInfo] in the compiled tests.
    pub add_statements_functions: bool,

    /// Adds mapping used by [cairo-coverage](https://github.com/software-mansion/cairo-coverage) to
    /// [Annotations] in [DebugInfo] in the compiled tests.
    pub add_statements_code_locations: bool,
}
78
79/// Runs Cairo compiler.
80///
81/// # Arguments
82/// * `db` - Preloaded compilation database.
83/// * `tests_compilation_config` - The compiler configuration for tests compilation.
84/// * `main_crate_ids` - [`CrateId`]s to compile. Use `CrateLongId::Real(name).intern(db)` in order
85///   to obtain [`CrateId`] from its name.
86/// * `test_crate_ids` - [`CrateId`]s to find tests cases in. Must be a subset of `main_crate_ids`.
87/// # Returns
88/// * `Ok(TestCompilation)` - The compiled test cases with metadata.
89/// * `Err(anyhow::Error)` - Compilation failed.
90pub fn compile_test_prepared_db(
91    db: &RootDatabase,
92    tests_compilation_config: TestsCompilationConfig,
93    test_crate_ids: Vec<CrateId>,
94    mut diagnostics_reporter: DiagnosticsReporter<'_>,
95) -> Result<TestCompilation> {
96    ensure!(
97        tests_compilation_config.starknet
98            || tests_compilation_config.contract_declarations.is_none(),
99        "Contract declarations can be provided only when starknet is enabled."
100    );
101    ensure!(
102        tests_compilation_config.starknet || tests_compilation_config.contract_crate_ids.is_none(),
103        "Contract crate ids can be provided only when starknet is enabled."
104    );
105
106    let context = DbWarmupContext::new();
107    context.ensure_diagnostics(db, &mut diagnostics_reporter)?;
108
109    let contracts = tests_compilation_config.contract_declarations.unwrap_or_else(|| {
110        find_contracts(
111            db,
112            &tests_compilation_config.contract_crate_ids.unwrap_or_else(|| db.crates()),
113        )
114    });
115    let all_entry_points = if tests_compilation_config.starknet {
116        contracts
117            .iter()
118            .flat_map(|contract| {
119                chain!(
120                    get_contract_abi_functions(db, contract, EXTERNAL_MODULE).unwrap_or_default(),
121                    get_contract_abi_functions(db, contract, CONSTRUCTOR_MODULE)
122                        .unwrap_or_default(),
123                    get_contract_abi_functions(db, contract, L1_HANDLER_MODULE).unwrap_or_default(),
124                )
125            })
126            .map(|func| ConcreteFunctionWithBodyId::from_semantic(db, func.value))
127            .collect()
128    } else {
129        vec![]
130    };
131
132    let executable_functions = find_executable_function_ids(
133        db,
134        tests_compilation_config.executable_crate_ids.unwrap_or_else(|| test_crate_ids.clone()),
135    );
136    let all_tests = find_all_tests(db, test_crate_ids);
137
138    let func_ids = chain!(
139        executable_functions.clone().into_keys(),
140        all_entry_points.iter().cloned(),
141        // TODO(maciektr): Remove test entrypoints after migration to executable attr.
142        all_tests.iter().flat_map(|(func_id, _cfg)| {
143            ConcreteFunctionWithBodyId::from_no_generics_free(db, *func_id)
144        })
145    )
146    .collect();
147
148    let SierraProgramWithDebug { program: mut sierra_program, debug_info } =
149        Arc::unwrap_or_clone(get_sierra_program_for_functions(db, func_ids, context)?);
150
151    let function_set_costs: OrderedHashMap<FunctionId, OrderedHashMap<CostTokenType, i32>> =
152        all_entry_points
153            .iter()
154            .map(|func_id| {
155                (
156                    db.function_with_body_sierra(*func_id).unwrap().id.clone(),
157                    [(CostTokenType::Const, ENTRY_POINT_COST)].into(),
158                )
159            })
160            .collect();
161
162    let replacer = DebugReplacer { db };
163    replacer.enrich_function_names(&mut sierra_program);
164
165    let mut annotations = Annotations::default();
166    if tests_compilation_config.add_statements_functions {
167        annotations.extend(Annotations::from(
168            debug_info.statements_locations.extract_statements_functions(db),
169        ))
170    }
171    if tests_compilation_config.add_statements_code_locations {
172        annotations.extend(Annotations::from(
173            debug_info.statements_locations.extract_statements_source_code_locations(db),
174        ))
175    }
176
177    let executables = collect_executables(db, executable_functions, &sierra_program);
178    let named_tests = all_tests
179        .into_iter()
180        .map(|(func_id, test)| {
181            (
182                format!(
183                    "{:?}",
184                    FunctionLongId {
185                        function: ConcreteFunction {
186                            generic_function: GenericFunctionId::Free(func_id),
187                            generic_args: vec![]
188                        }
189                    }
190                    .debug(db)
191                ),
192                test,
193            )
194        })
195        .collect_vec();
196    let contracts_info = get_contracts_info(db, contracts, &replacer)?;
197    let sierra_program = ProgramArtifact::stripped(sierra_program).with_debug_info(DebugInfo {
198        executables,
199        annotations,
200        ..DebugInfo::default()
201    });
202
203    Ok(TestCompilation {
204        sierra_program,
205        metadata: TestCompilationMetadata {
206            named_tests,
207            function_set_costs,
208            contracts_info,
209            statements_locations: Some(debug_info.statements_locations),
210        },
211    })
212}
213
/// Encapsulation of all data required to execute tests.
///
/// This includes the source code compiled to a Sierra program and all cairo-test specific
/// data extracted from it.
/// This can be stored on the filesystem and shared externally.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
pub struct TestCompilation {
    /// The compiled Sierra program artifact.
    pub sierra_program: ProgramArtifact,
    /// The cairo-test specific metadata. Flattened by serde, so its fields are (de)serialized at
    /// the same level as `sierra_program`.
    #[serde(flatten)]
    pub metadata: TestCompilationMetadata,
}
225
/// Encapsulation of all data required to execute tests, except for the Sierra program itself.
///
/// This includes all cairo-test specific data extracted from the program.
/// This can be stored on the filesystem and shared externally.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
pub struct TestCompilationMetadata {
    /// Information about the contracts, as produced by `get_contracts_info`.
    // NOTE(review): the `Felt252` key is presumably the contract class hash - confirm.
    #[serde(
        serialize_with = "serialize_ordered_hashmap_vec",
        deserialize_with = "deserialize_ordered_hashmap_vec"
    )]
    pub contracts_info: OrderedHashMap<Felt252, ContractInfo>,
    /// Costs per Sierra function id, split by cost token type.
    /// `compile_test_prepared_db` fills it with `ENTRY_POINT_COST` of the `Const` token type for
    /// every contract entry point.
    #[serde(
        serialize_with = "serialize_ordered_hashmap_vec",
        deserialize_with = "deserialize_ordered_hashmap_vec"
    )]
    pub function_set_costs: OrderedHashMap<FunctionId, OrderedHashMap<CostTokenType, i32>>,
    /// The found tests: pairs of the test function name and its [`TestConfig`].
    pub named_tests: Vec<(String, TestConfig)>,
    /// Optional `StatementsLocations` for the compiled tests.
    /// See [StatementsLocations] for more information.
    /// Skipped by serde, so it is not part of the stored representation.
    // TODO(Gil): consider serializing this field once it is stable.
    #[serde(skip)]
    pub statements_locations: Option<StatementsLocations>,
}
249
250/// Finds the tests in the requested crates.
251fn find_all_tests(
252    db: &dyn SemanticGroup,
253    main_crates: Vec<CrateId>,
254) -> Vec<(FreeFunctionId, TestConfig)> {
255    let mut tests = vec![];
256    for crate_id in main_crates {
257        let modules = db.crate_modules(crate_id);
258        for module_id in modules.iter() {
259            let Ok(module_items) = db.module_items(*module_id) else {
260                continue;
261            };
262            tests.extend(module_items.iter().filter_map(|item| {
263                let ModuleItemId::FreeFunction(func_id) = item else { return None };
264                let Ok(attrs) =
265                    db.function_with_body_attributes(FunctionWithBodyId::Free(*func_id))
266                else {
267                    return None;
268                };
269                Some((*func_id, try_extract_test_config(db, attrs).ok()??))
270            }));
271        }
272    }
273    tests
274}
275
276/// The suite of plugins that implements assert macros for tests.
277pub fn test_assert_suite() -> PluginSuite {
278    let mut suite = PluginSuite::default();
279    suite
280        .add_inline_macro_plugin::<inline_macros::assert::AssertEqMacro>()
281        .add_inline_macro_plugin::<inline_macros::assert::AssertNeMacro>()
282        .add_inline_macro_plugin::<inline_macros::assert::AssertLtMacro>()
283        .add_inline_macro_plugin::<inline_macros::assert::AssertLeMacro>()
284        .add_inline_macro_plugin::<inline_macros::assert::AssertGtMacro>()
285        .add_inline_macro_plugin::<inline_macros::assert::AssertGeMacro>();
286    suite
287}
288
289/// The suite of plugins for compilation for testing.
290pub fn test_plugin_suite() -> PluginSuite {
291    let mut suite = PluginSuite::default();
292    suite.add_plugin::<TestPlugin>().add(test_assert_suite());
293    suite
294}