Skip to main content

cairo_lang_test_runner/
lib.rs

1use std::path::Path;
2use std::sync::Mutex;
3
4use anyhow::{Context, Result, bail};
5use cairo_lang_compiler::db::RootDatabase;
6use cairo_lang_compiler::diagnostics::DiagnosticsReporter;
7use cairo_lang_compiler::project::setup_project;
8use cairo_lang_filesystem::cfg::{Cfg, CfgSet};
9use cairo_lang_filesystem::ids::CrateId;
10use cairo_lang_runner::casm_run::format_for_panic;
11use cairo_lang_runner::profiling::{
12    ProfilerConfig, ProfilingInfo, ProfilingInfoProcessor, ProfilingInfoProcessorParams,
13};
14use cairo_lang_runner::{
15    ProfilingInfoCollectionConfig, RunResultValue, SierraCasmRunner, StarknetExecutionResources,
16};
17use cairo_lang_sierra::extensions::gas::CostTokenType;
18use cairo_lang_sierra::ids::FunctionId;
19use cairo_lang_sierra::program::{Program, StatementIdx};
20use cairo_lang_sierra_generator::db::SierraGenGroup;
21use cairo_lang_sierra_to_casm::metadata::MetadataComputationConfig;
22use cairo_lang_starknet::contract::ContractInfo;
23use cairo_lang_starknet::starknet_plugin_suite;
24use cairo_lang_test_plugin::test_config::{PanicExpectation, TestExpectation};
25use cairo_lang_test_plugin::{
26    TestCompilation, TestCompilationMetadata, TestConfig, TestsCompilationConfig,
27    compile_test_prepared_db, test_plugin_suite,
28};
29use cairo_lang_utils::casts::IntoOrPanic;
30use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
31use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
32use colored::Colorize;
33use itertools::Itertools;
34use num_traits::ToPrimitive;
35use rayon::prelude::{IntoParallelIterator, ParallelIterator};
36use starknet_types_core::felt::Felt as Felt252;
37
38#[cfg(test)]
39mod test;
40
41/// Compile and run tests.
42pub struct TestRunner {
43    compiler: TestCompiler,
44    config: TestRunConfig,
45}
46
47impl TestRunner {
48    /// Configure a new test runner
49    ///
50    /// # Arguments
51    ///
52    /// * `path` - The path to compile and run its tests
53    /// * `filter` - Run only tests containing the filter string
54    /// * `include_ignored` - Include ignored tests as well
55    /// * `ignored` - Run ignored tests only
56    /// * `starknet` - Add the starknet plugin to run the tests
57    pub fn new(
58        path: &Path,
59        starknet: bool,
60        allow_warnings: bool,
61        config: TestRunConfig,
62    ) -> Result<Self> {
63        let compiler = TestCompiler::try_new(
64            path,
65            allow_warnings,
66            config.gas_enabled,
67            TestsCompilationConfig {
68                starknet,
69                add_statements_functions: config.profiler_config == Some(ProfilerConfig::Cairo),
70                add_statements_code_locations: false,
71                contract_declarations: None,
72                contract_crate_ids: None,
73                executable_crate_ids: None,
74            },
75        )?;
76        Ok(Self { compiler, config })
77    }
78
79    /// Runs the tests and process the results for a summary.
80    pub fn run(&self) -> Result<Option<TestsSummary>> {
81        let runner = CompiledTestRunner::new(self.compiler.build()?, self.config.clone());
82        runner.run(Some(&self.compiler.db))
83    }
84}
85
86pub struct CompiledTestRunner {
87    pub compiled: TestCompilation,
88    pub config: TestRunConfig,
89}
90
91impl CompiledTestRunner {
92    /// Configure a new compiled test runner
93    ///
94    /// # Arguments
95    ///
96    /// * `compiled` - The compiled tests to run
97    /// * `config` - Test run configuration
98    pub fn new(compiled: TestCompilation, config: TestRunConfig) -> Self {
99        Self { compiled, config }
100    }
101
102    /// Execute preconfigured test execution.
103    pub fn run(self, db: Option<&RootDatabase>) -> Result<Option<TestsSummary>> {
104        let (compiled, filtered_out) = filter_test_cases(
105            self.compiled,
106            self.config.include_ignored,
107            self.config.ignored,
108            &self.config.filter,
109        );
110
111        let TestsSummary { passed, failed, ignored, failed_run_results } = run_tests(
112            if self.config.profiler_config == Some(ProfilerConfig::Cairo) {
113                let db = db.expect("db must be passed when profiling.");
114                let statements_locations = compiled
115                    .metadata
116                    .statements_locations
117                    .expect("statements locations must be present when profiling.");
118                Some(PorfilingAuxData {
119                    db,
120                    statements_functions: statements_locations
121                        .get_statements_functions_map_for_tests(db),
122                })
123            } else {
124                None
125            },
126            compiled.metadata.named_tests,
127            compiled.sierra_program.program,
128            compiled.metadata.function_set_costs,
129            compiled.metadata.contracts_info,
130            &self.config,
131        )?;
132
133        if failed.is_empty() {
134            println!(
135                "test result: {}. {} passed; {} failed; {} ignored; {filtered_out} filtered out;",
136                "ok".bright_green(),
137                passed.len(),
138                failed.len(),
139                ignored.len()
140            );
141            Ok(None)
142        } else {
143            println!("failures:");
144            for (failure, run_result) in failed.iter().zip_eq(failed_run_results) {
145                print!("   {failure} - ");
146                match run_result {
147                    RunResultValue::Success(_) => {
148                        println!("expected panic but finished successfully.");
149                    }
150                    RunResultValue::Panic(values) => {
151                        println!("{}", format_for_panic(values.into_iter()));
152                    }
153                }
154            }
155            println!();
156            bail!(
157                "test result: {}. {} passed; {} failed; {} ignored",
158                "FAILED".bright_red(),
159                passed.len(),
160                failed.len(),
161                ignored.len()
162            );
163        }
164    }
165}
166
167/// Configuration of compiled tests runner.
168#[derive(Clone, Debug)]
169pub struct TestRunConfig {
170    pub filter: String,
171    pub include_ignored: bool,
172    pub ignored: bool,
173    /// Whether to run the profiler and how.
174    pub profiler_config: Option<ProfilerConfig>,
175    /// Whether to enable gas calculation.
176    pub gas_enabled: bool,
177    /// Whether to print used resources after each test.
178    pub print_resource_usage: bool,
179}
180
181/// The test cases compiler.
182pub struct TestCompiler {
183    pub db: RootDatabase,
184    pub main_crate_ids: Vec<CrateId>,
185    pub test_crate_ids: Vec<CrateId>,
186    pub allow_warnings: bool,
187    pub config: TestsCompilationConfig,
188}
189
190impl TestCompiler {
191    /// Configure a new test compiler
192    ///
193    /// # Arguments
194    ///
195    /// * `path` - The path to compile and run its tests
196    /// * `starknet` - Add the starknet plugin to run the tests
197    pub fn try_new(
198        path: &Path,
199        allow_warnings: bool,
200        gas_enabled: bool,
201        config: TestsCompilationConfig,
202    ) -> Result<Self> {
203        let db = &mut {
204            let mut b = RootDatabase::builder();
205            let mut cfg = CfgSet::from_iter([Cfg::name("test"), Cfg::kv("target", "test")]);
206            if !gas_enabled {
207                cfg.insert(Cfg::kv("gas", "disabled"));
208                b.skip_auto_withdraw_gas();
209            }
210            b.detect_corelib();
211            b.with_cfg(cfg);
212            b.with_default_plugin_suite(test_plugin_suite());
213            if config.starknet {
214                b.with_default_plugin_suite(starknet_plugin_suite());
215            }
216            b.build()?
217        };
218
219        let main_crate_ids = setup_project(db, Path::new(&path))?;
220
221        Ok(Self {
222            db: db.snapshot(),
223            test_crate_ids: main_crate_ids.clone(),
224            main_crate_ids,
225            allow_warnings,
226            config,
227        })
228    }
229
230    /// Build the tests and collect metadata.
231    pub fn build(&self) -> Result<TestCompilation> {
232        let mut diag_reporter = DiagnosticsReporter::stderr().with_crates(&self.main_crate_ids);
233        if self.allow_warnings {
234            diag_reporter = diag_reporter.allow_warnings();
235        }
236
237        compile_test_prepared_db(
238            &self.db,
239            self.config.clone(),
240            self.test_crate_ids.clone(),
241            diag_reporter,
242        )
243    }
244}
245
246/// Filter compiled test cases with user provided arguments.
247///
248/// # Arguments
249/// * `compiled` - Compiled test cases with metadata.
250/// * `include_ignored` - Include ignored tests as well.
251/// * `ignored` - Run ignored tests only.l
252/// * `filter` - Include only tests containing the filter string.
253/// # Returns
254/// * (`TestCompilation`, `usize`) - The filtered test cases and the number of filtered out cases.
255pub fn filter_test_cases(
256    compiled: TestCompilation,
257    include_ignored: bool,
258    ignored: bool,
259    filter: &str,
260) -> (TestCompilation, usize) {
261    let total_tests_count = compiled.metadata.named_tests.len();
262    let named_tests = compiled
263        .metadata
264        .named_tests
265        .into_iter()
266        // Filtering unignored tests in `ignored` mode. Keep all tests in `include-ignored` mode.
267        .filter(|(_, test)| !ignored || test.ignored || include_ignored)
268        .map(|(func, mut test)| {
269            // Un-ignoring all the tests in `include-ignored` and `ignored` mode.
270            if include_ignored || ignored {
271                test.ignored = false;
272            }
273            (func, test)
274        })
275        .filter(|(name, _)| name.contains(filter))
276        .collect_vec();
277    let filtered_out = total_tests_count - named_tests.len();
278    let tests = TestCompilation {
279        sierra_program: compiled.sierra_program,
280        metadata: TestCompilationMetadata { named_tests, ..(compiled.metadata) },
281    };
282    (tests, filtered_out)
283}
284
285/// The status of a ran test.
286enum TestStatus {
287    Success,
288    Fail(RunResultValue),
289}
290
291/// The result of a ran test.
292struct TestResult {
293    /// The status of the run.
294    status: TestStatus,
295    /// The gas usage of the run if relevant.
296    gas_usage: Option<i64>,
297    /// The used resources of the run.
298    used_resources: StarknetExecutionResources,
299    /// The profiling info of the run, if requested.
300    profiling_info: Option<ProfilingInfo>,
301}
302
303/// Summary data of the ran tests.
304pub struct TestsSummary {
305    passed: Vec<String>,
306    failed: Vec<String>,
307    ignored: Vec<String>,
308    failed_run_results: Vec<RunResultValue>,
309}
310
311/// Auxiliary data that is required when running tests with profiling.
312pub struct PorfilingAuxData<'a> {
313    pub db: &'a dyn SierraGenGroup,
314    pub statements_functions: UnorderedHashMap<StatementIdx, String>,
315}
316
317/// Runs the tests and process the results for a summary.
318pub fn run_tests(
319    profiler_data: Option<PorfilingAuxData<'_>>,
320    named_tests: Vec<(String, TestConfig)>,
321    sierra_program: Program,
322    function_set_costs: OrderedHashMap<FunctionId, OrderedHashMap<CostTokenType, i32>>,
323    contracts_info: OrderedHashMap<Felt252, ContractInfo>,
324    config: &TestRunConfig,
325) -> Result<TestsSummary> {
326    let runner = SierraCasmRunner::new(
327        sierra_program.clone(),
328        if config.gas_enabled {
329            Some(MetadataComputationConfig {
330                function_set_costs,
331                linear_gas_solver: true,
332                linear_ap_change_solver: true,
333                skip_non_linear_solver_comparisons: false,
334                compute_runtime_costs: false,
335            })
336        } else {
337            None
338        },
339        contracts_info,
340        match config.profiler_config {
341            None => None,
342            Some(ProfilerConfig::Cairo | ProfilerConfig::Sierra) => {
343                Some(ProfilingInfoCollectionConfig::default())
344            }
345        },
346    )
347    .with_context(|| "Failed setting up runner.")?;
348    let suffix = if named_tests.len() != 1 { "s" } else { "" };
349    println!("running {} test{}", named_tests.len(), suffix);
350    let wrapped_summary = Mutex::new(Ok(TestsSummary {
351        passed: vec![],
352        failed: vec![],
353        ignored: vec![],
354        failed_run_results: vec![],
355    }));
356
357    let profiling_params =
358        config.profiler_config.as_ref().map(ProfilingInfoProcessorParams::from_profiler_config);
359
360    // Run in parallel if possible. If running with db, parallelism is impossible.
361    if config.profiler_config != Some(ProfilerConfig::Cairo) {
362        named_tests
363            .into_par_iter()
364            .map(|(name, test)| run_single_test(test, name, &runner))
365            .for_each(|res| {
366                update_summary(
367                    &wrapped_summary,
368                    res,
369                    &None,
370                    &sierra_program,
371                    &profiling_params,
372                    config.print_resource_usage,
373                );
374            });
375    } else {
376        eprintln!("Note: Tests don't run in parallel when running with profiling.");
377        named_tests
378            .into_iter()
379            .map(move |(name, test)| run_single_test(test, name, &runner))
380            .for_each(|test_result| {
381                update_summary(
382                    &wrapped_summary,
383                    test_result,
384                    &profiler_data,
385                    &sierra_program,
386                    &profiling_params,
387                    config.print_resource_usage,
388                );
389            });
390    }
391
392    wrapped_summary.into_inner().unwrap()
393}
394
395/// Runs a single test and returns a tuple of its name and result.
396fn run_single_test(
397    test: TestConfig,
398    name: String,
399    runner: &SierraCasmRunner,
400) -> anyhow::Result<(String, Option<TestResult>)> {
401    if test.ignored {
402        return Ok((name, None));
403    }
404    let func = runner.find_function(name.as_str())?;
405    let result = runner
406        .run_function_with_starknet_context(func, vec![], test.available_gas, Default::default())
407        .with_context(|| format!("Failed to run the function `{}`.", name.as_str()))?;
408    Ok((
409        name,
410        Some(TestResult {
411            status: match &result.value {
412                RunResultValue::Success(_) => match test.expectation {
413                    TestExpectation::Success => TestStatus::Success,
414                    TestExpectation::Panics(_) => TestStatus::Fail(result.value),
415                },
416                RunResultValue::Panic(value) => match test.expectation {
417                    TestExpectation::Success => TestStatus::Fail(result.value),
418                    TestExpectation::Panics(panic_expectation) => match panic_expectation {
419                        PanicExpectation::Exact(expected) if value != &expected => {
420                            TestStatus::Fail(result.value)
421                        }
422                        _ => TestStatus::Success,
423                    },
424                },
425            },
426            gas_usage: test
427                .available_gas
428                .zip(result.gas_counter)
429                .map(|(before, after)| {
430                    before.into_or_panic::<i64>() - after.to_bigint().to_i64().unwrap()
431                })
432                .or_else(|| {
433                    runner.initial_required_gas(func).map(|gas| gas.into_or_panic::<i64>())
434                }),
435            used_resources: result.used_resources,
436            profiling_info: result.profiling_info,
437        }),
438    ))
439}
440
441/// Updates the test summary with the given test result.
442fn update_summary(
443    wrapped_summary: &Mutex<std::prelude::v1::Result<TestsSummary, anyhow::Error>>,
444    test_result: std::prelude::v1::Result<(String, Option<TestResult>), anyhow::Error>,
445    profiler_data: &Option<PorfilingAuxData<'_>>,
446    sierra_program: &Program,
447    profiling_params: &Option<ProfilingInfoProcessorParams>,
448    print_resource_usage: bool,
449) {
450    let mut wrapped_summary = wrapped_summary.lock().unwrap();
451    if wrapped_summary.is_err() {
452        return;
453    }
454    let (name, opt_result) = match test_result {
455        Ok((name, opt_result)) => (name, opt_result),
456        Err(err) => {
457            *wrapped_summary = Err(err);
458            return;
459        }
460    };
461    let summary = wrapped_summary.as_mut().unwrap();
462    let (res_type, status_str, gas_usage, used_resources, profiling_info) =
463        if let Some(result) = opt_result {
464            let (res_type, status_str) = match result.status {
465                TestStatus::Success => (&mut summary.passed, "ok".bright_green()),
466                TestStatus::Fail(run_result) => {
467                    summary.failed_run_results.push(run_result);
468                    (&mut summary.failed, "fail".bright_red())
469                }
470            };
471            (
472                res_type,
473                status_str,
474                result.gas_usage,
475                print_resource_usage.then_some(result.used_resources),
476                result.profiling_info,
477            )
478        } else {
479            (&mut summary.ignored, "ignored".bright_yellow(), None, None, None)
480        };
481    if let Some(gas_usage) = gas_usage {
482        println!("test {name} ... {status_str} (gas usage est.: {gas_usage})");
483    } else {
484        println!("test {name} ... {status_str}");
485    }
486    if let Some(used_resources) = used_resources {
487        let filtered = used_resources.basic_resources.filter_unused_builtins();
488        // Prints the used resources per test. E.g.:
489        // ```ignore
490        // test cairo_level_tests::interoperability::test_contract_not_deployed ... ok (gas usage est.: 77320)
491        //     steps: 42
492        //     memory holes: 20
493        //     builtins: ("range_check_builtin": 3)
494        //     syscalls: ("CallContract": 1)
495        // test cairo_level_tests::events::test_pop_log ... ok (gas usage est.: 55440)
496        //     steps: 306
497        //     memory holes: 35
498        //     builtins: ("range_check_builtin": 24)
499        //     syscalls: ("EmitEvent": 2)
500        // ```
501        println!("    steps: {}", filtered.n_steps);
502        println!("    memory holes: {}", filtered.n_memory_holes);
503
504        print_resource_map(
505            filtered.builtin_instance_counter.into_iter().map(|(k, v)| (k.to_string(), v)),
506            "builtins",
507        );
508        print_resource_map(used_resources.syscalls.into_iter(), "syscalls");
509    }
510    if let Some(profiling_params) = profiling_params {
511        let (opt_db, statements_functions) =
512            if let Some(PorfilingAuxData { db, statements_functions }) = profiler_data {
513                (Some(*db), statements_functions)
514            } else {
515                (None, &UnorderedHashMap::default())
516            };
517
518        let profiling_processor =
519            ProfilingInfoProcessor::new(opt_db, sierra_program, statements_functions);
520        let processed_profiling_info = profiling_processor.process(
521            &profiling_info.expect("profiling_info must be Some when profiler_config is Some"),
522            profiling_params,
523        );
524        println!("Profiling info:\n{processed_profiling_info}");
525    }
526    res_type.push(name);
527}
528
/// Prints a sorted usage map of `(name, count)` pairs for one resource type on a single line.
/// Prints nothing when `m` is empty. E.g.:
///     syscalls: ("CallContract": 1, "EmitEvent": 2)
fn print_resource_map(m: impl ExactSizeIterator<Item = (String, usize)>, resource_type: &str) {
    // `ExactSizeIterator::is_empty` is unstable, so compare the length instead.
    if m.len() == 0 {
        return;
    }
    // `m` is already an iterator, so there is no need for `into_iter()`.
    let mut entries: Vec<(String, usize)> = m.collect();
    entries.sort();
    let rendered: Vec<String> =
        entries.iter().map(|(k, v)| format!(r#""{k}": {v}"#)).collect();
    println!("    {resource_type}: ({})", rendered.join(", "));
}