//! `cairo_lang_test_runner` — compiles Cairo packages and runs their tests (crate root, `lib.rs`).

1use std::path::Path;
2use std::sync::Arc;
3use std::sync::mpsc::channel;
4
5use anyhow::{Context, Result, bail};
6use cairo_lang_compiler::db::RootDatabase;
7use cairo_lang_compiler::diagnostics::DiagnosticsReporter;
8use cairo_lang_compiler::project::setup_project;
9use cairo_lang_debug::debug::DebugWithDb;
10use cairo_lang_filesystem::cfg::{Cfg, CfgSet};
11use cairo_lang_filesystem::ids::CrateInput;
12use cairo_lang_runner::casm_run::{StarknetHintProcessor, format_for_panic};
13use cairo_lang_runner::profiling::{
14    ProfilerConfig, ProfilingInfo, ProfilingInfoProcessor, ProfilingInfoProcessorParams,
15};
16use cairo_lang_runner::{
17    CairoHintProcessor, ProfilingInfoCollectionConfig, RunResultValue, RunnerError,
18    SierraCasmRunner, StarknetExecutionResources,
19};
20use cairo_lang_sierra_to_casm::metadata::MetadataComputationConfig;
21use cairo_lang_starknet::starknet_plugin_suite;
22use cairo_lang_test_plugin::test_config::{PanicExpectation, TestExpectation};
23use cairo_lang_test_plugin::{
24    TestCompilation, TestCompilationMetadata, TestConfig, TestsCompilationConfig,
25    compile_test_prepared_db, test_plugin_suite,
26};
27use cairo_lang_utils::casts::IntoOrPanic;
28use colored::Colorize;
29use itertools::Itertools;
30use num_traits::ToPrimitive;
31use rayon::prelude::{IntoParallelIterator, ParallelIterator};
32use salsa::Database;
33
34#[cfg(test)]
35mod test;
36
/// A shared, thread-safe factory that, given the default [`CairoHintProcessor`],
/// produces the hint processor actually used to execute a test. Allows callers
/// to wrap or replace the default hint handling (see `run_single_test`).
type ArcCustomHintProcessorFactory = Arc<
    dyn (for<'a> Fn(CairoHintProcessor<'a>) -> Box<dyn StarknetHintProcessor + 'a>) + Send + Sync,
>;
40
/// Compile and run tests.
pub struct TestRunner<'db> {
    /// Compiles the tests from the configured path (see [`TestRunner::new`]).
    compiler: TestCompiler<'db>,
    /// Run configuration: filtering, gas, profiling and resource printing.
    config: TestRunConfig,
    /// Optional factory used to wrap the default hint processor for each test run.
    custom_hint_processor_factory: Option<ArcCustomHintProcessorFactory>,
}
47
48impl<'db> TestRunner<'db> {
49    /// Configure a new test runner
50    ///
51    /// # Arguments
52    ///
53    /// * `path` - The path to compile and run its tests
54    /// * `filter` - Run only tests containing the filter string
55    /// * `include_ignored` - Include ignored tests as well
56    /// * `ignored` - Run ignored tests only
57    /// * `starknet` - Add the Starknet plugin to run the tests
58    pub fn new(
59        path: &Path,
60        starknet: bool,
61        allow_warnings: bool,
62        config: TestRunConfig,
63    ) -> Result<Self> {
64        let compiler = TestCompiler::try_new(
65            path,
66            allow_warnings,
67            config.gas_enabled,
68            TestsCompilationConfig {
69                starknet,
70                add_statements_functions: config
71                    .profiler_config
72                    .as_ref()
73                    .is_some_and(|c| c.requires_cairo_debug_info()),
74                add_statements_code_locations: false,
75                contract_declarations: None,
76                contract_crate_ids: None,
77                executable_crate_ids: None,
78                add_functions_debug_info: false,
79            },
80        )?;
81        Ok(Self { compiler, config, custom_hint_processor_factory: None })
82    }
83
84    /// Make this runner run tests using a custom hint processor.
85    pub fn with_custom_hint_processor(
86        &mut self,
87        factory: impl (for<'a> Fn(CairoHintProcessor<'a>) -> Box<dyn StarknetHintProcessor + 'a>)
88        + Send
89        + Sync
90        + 'static,
91    ) -> &mut Self {
92        self.custom_hint_processor_factory = Some(Arc::new(factory));
93        self
94    }
95
96    /// Runs the tests and processes the results for a summary.
97    pub fn run(&self) -> Result<Option<TestsSummary>> {
98        let runner = CompiledTestRunner {
99            compiled: self.compiler.build()?,
100            config: self.config.clone(),
101            custom_hint_processor_factory: self.custom_hint_processor_factory.clone(),
102        };
103        runner.run(Some(&self.compiler.db))
104    }
105}
106
/// Runner for tests that were already compiled.
pub struct CompiledTestRunner<'db> {
    /// The compiled tests (Sierra program + per-test metadata) to run.
    compiled: TestCompilation<'db>,
    /// Run configuration: filtering, gas, profiling and resource printing.
    config: TestRunConfig,
    /// Optional factory used to wrap the default hint processor for each test run.
    custom_hint_processor_factory: Option<ArcCustomHintProcessorFactory>,
}
112
impl<'db> CompiledTestRunner<'db> {
    /// Configure a new compiled test runner
    ///
    /// # Arguments
    ///
    /// * `compiled` - The compiled tests to run
    /// * `config` - Test run configuration
    pub fn new(compiled: TestCompilation<'db>, config: TestRunConfig) -> Self {
        Self { compiled, config, custom_hint_processor_factory: None }
    }

    /// Make this runner run tests using a custom hint processor.
    pub fn with_custom_hint_processor(
        &mut self,
        factory: impl (for<'a> Fn(CairoHintProcessor<'a>) -> Box<dyn StarknetHintProcessor + 'a>)
        + Send
        + Sync
        + 'static,
    ) -> &mut Self {
        self.custom_hint_processor_factory = Some(Arc::new(factory));
        self
    }

    /// Execute preconfigured test execution.
    ///
    /// Prints a per-test status report and a final summary to stdout. Returns
    /// `Ok(None)` when all tests passed, and an error when any test failed.
    /// `opt_db` is forwarded to `run_tests` (needed for error-location lookups
    /// and Cairo-level profiling).
    pub fn run(self, opt_db: Option<&'db RootDatabase>) -> Result<Option<TestsSummary>> {
        // Drop tests excluded by the ignore flags and the name filter.
        let (compiled, filtered_out) = filter_test_cases(
            self.compiled,
            self.config.include_ignored,
            self.config.ignored,
            &self.config.filter,
        );

        let TestsSummary { passed, failed, ignored, failed_run_results } = run_tests(
            opt_db.map(|db| db as &dyn Database),
            compiled,
            &self.config,
            self.custom_hint_processor_factory,
        )?;

        if failed.is_empty() {
            println!(
                "test result: {}. {} passed; {} failed; {} ignored; {filtered_out} filtered out;",
                "ok".bright_green(),
                passed.len(),
                failed.len(),
                ignored.len()
            );
            Ok(None)
        } else {
            println!("failures:");
            // `failed` and `failed_run_results` are pushed in lockstep by
            // `update_summary`, so `zip_eq` cannot panic here.
            for (failure, run_result) in failed.iter().zip_eq(failed_run_results) {
                print!("   {failure} - ");
                match run_result {
                    Ok(RunResultValue::Success(_)) => {
                        println!("expected panic but finished successfully.");
                    }
                    Ok(RunResultValue::Panic(values)) => {
                        println!("{}", format_for_panic(values.into_iter()));
                    }
                    Err(err) => {
                        println!("{err}");
                    }
                }
            }
            println!();
            bail!(
                "test result: {}. {} passed; {} failed; {} ignored",
                "FAILED".bright_red(),
                passed.len(),
                failed.len(),
                ignored.len()
            );
        }
    }
}
188
/// Configuration of compiled tests runner.
#[derive(Clone, Debug)]
pub struct TestRunConfig {
    /// Run only tests whose name contains this substring (empty matches all).
    pub filter: String,
    /// Whether to also run tests marked as ignored (un-ignoring them).
    pub include_ignored: bool,
    /// Whether to run only the tests marked as ignored.
    pub ignored: bool,
    /// Whether to run the profiler and how.
    pub profiler_config: Option<ProfilerConfig>,
    /// Whether to enable gas calculation.
    pub gas_enabled: bool,
    /// Whether to print used resources after each test.
    pub print_resource_usage: bool,
}
202
/// The test cases compiler.
pub struct TestCompiler<'db> {
    /// The prepared database snapshot (corelib, cfg set and plugins configured).
    pub db: RootDatabase,
    /// Crates whose diagnostics are reported when building.
    pub main_crate_ids: Vec<CrateInput>,
    /// Crates whose tests are compiled (set equal to `main_crate_ids` by `try_new`).
    pub test_crate_ids: Vec<CrateInput>,
    /// Whether compilation warnings should not fail the build.
    pub allow_warnings: bool,
    /// The tests compilation configuration.
    pub config: TestsCompilationConfig<'db>,
}
211
212impl<'db> TestCompiler<'db> {
213    /// Configure a new test compiler
214    ///
215    /// # Arguments
216    ///
217    /// * `path` - The path to compile and run its tests
218    /// * `starknet` - Add the Starknet plugin to run the tests
219    pub fn try_new(
220        path: &Path,
221        allow_warnings: bool,
222        gas_enabled: bool,
223        config: TestsCompilationConfig<'db>,
224    ) -> Result<Self> {
225        let mut db = {
226            let mut b = RootDatabase::builder();
227            let mut cfg = CfgSet::from_iter([Cfg::name("test"), Cfg::kv("target", "test")]);
228            if !gas_enabled {
229                cfg.insert(Cfg::kv("gas", "disabled"));
230                b.skip_auto_withdraw_gas();
231            }
232            b.detect_corelib();
233            b.with_cfg(cfg);
234            b.with_default_plugin_suite(test_plugin_suite());
235            if config.starknet {
236                b.with_default_plugin_suite(starknet_plugin_suite());
237            }
238            b.build()?
239        };
240
241        let main_crate_inputs = setup_project(&mut db, Path::new(&path))?;
242
243        Ok(Self {
244            db: db.snapshot(),
245            test_crate_ids: main_crate_inputs.clone(),
246            main_crate_ids: main_crate_inputs,
247            allow_warnings,
248            config,
249        })
250    }
251
252    /// Build the tests and collect metadata.
253    pub fn build(&self) -> Result<TestCompilation<'_>> {
254        let mut diag_reporter = DiagnosticsReporter::stderr().with_crates(&self.main_crate_ids);
255        if self.allow_warnings {
256            diag_reporter = diag_reporter.allow_warnings();
257        }
258
259        compile_test_prepared_db(
260            &self.db,
261            self.config.clone(),
262            self.test_crate_ids.clone(),
263            diag_reporter,
264        )
265    }
266}
267
268/// Filter compiled test cases with user provided arguments.
269///
270/// # Arguments
271/// * `compiled` - Compiled test cases with metadata.
272/// * `include_ignored` - Include ignored tests as well.
273/// * `ignored` - Run ignored tests only
274/// * `filter` - Include only tests containing the filter string.
275/// # Returns
276/// * (`TestCompilation`, `usize`) - The filtered test cases and the number of filtered out cases.
277pub fn filter_test_cases<'db>(
278    compiled: TestCompilation<'db>,
279    include_ignored: bool,
280    ignored: bool,
281    filter: &str,
282) -> (TestCompilation<'db>, usize) {
283    let total_tests_count = compiled.metadata.named_tests.len();
284    let named_tests = compiled
285        .metadata
286        .named_tests
287        .into_iter()
288        // Filtering unignored tests in `ignored` mode. Keep all tests in `include-ignored` mode.
289        .filter(|(_, test)| !ignored || test.ignored || include_ignored)
290        .map(|(func, mut test)| {
291            // Un-ignoring all the tests in `include-ignored` and `ignored` mode.
292            if include_ignored || ignored {
293                test.ignored = false;
294            }
295            (func, test)
296        })
297        .filter(|(name, _)| name.contains(filter))
298        .collect_vec();
299    let filtered_out = total_tests_count - named_tests.len();
300    let tests = TestCompilation {
301        sierra_program: compiled.sierra_program,
302        metadata: TestCompilationMetadata { named_tests, ..compiled.metadata },
303    };
304    (tests, filtered_out)
305}
306
/// The status of a test run.
enum TestStatus {
    /// The run outcome matched the test's expectation.
    Success,
    /// The run outcome did not match the expectation; carries the offending run
    /// value (a success value when a panic was expected, or the panic data).
    Fail(RunResultValue),
}
312
/// The result of a single test run.
struct TestResult {
    /// The status of the run (pass/fail relative to the test's expectation).
    status: TestStatus,
    /// The gas usage of the run if relevant, otherwise the initial required gas
    /// estimate when available.
    gas_usage: Option<i64>,
    /// The used resources of the run.
    used_resources: StarknetExecutionResources,
    /// The profiling info of the run, if requested.
    profiling_info: Option<ProfilingInfo>,
}
324
/// Summary data of the tests run.
pub struct TestsSummary {
    /// Names of the tests that passed.
    passed: Vec<String>,
    /// Names of the tests that failed.
    failed: Vec<String>,
    /// Names of the tests that were ignored.
    ignored: Vec<String>,
    /// Run results of the failed tests, kept in lockstep with `failed`
    /// (consumers zip the two together).
    failed_run_results: Vec<Result<RunResultValue>>,
}
332
/// Runs the tests and processes the results for a summary.
///
/// `opt_db` is used to resolve source locations for build errors and for
/// Cairo-level profiling; it must be `Some` when the profiler configuration
/// requires Cairo debug info.
pub fn run_tests(
    opt_db: Option<&dyn Database>,
    compiled: TestCompilation<'_>,
    config: &TestRunConfig,
    custom_hint_processor_factory: Option<ArcCustomHintProcessorFactory>,
) -> Result<TestsSummary> {
    let TestCompilation {
        sierra_program: sierra_program_with_debug_info,
        metadata:
            TestCompilationMetadata {
                named_tests,
                function_set_costs,
                contracts_info,
                statements_locations,
            },
    } = compiled;
    let sierra_program = sierra_program_with_debug_info.program;
    let runner = SierraCasmRunner::new(
        sierra_program.clone(),
        // Gas metadata is only computed when gas calculation is enabled.
        if config.gas_enabled {
            Some(MetadataComputationConfig {
                function_set_costs,
                linear_gas_solver: true,
                linear_ap_change_solver: true,
                skip_non_linear_solver_comparisons: false,
                compute_runtime_costs: false,
            })
        } else {
            None
        },
        contracts_info,
        config.profiler_config.as_ref().map(ProfilingInfoCollectionConfig::from_profiler_config),
    )
    .map_err(|err| {
        // For build errors, when a db and statement locations are available,
        // enrich the error with the source locations of the offending statements.
        let (RunnerError::BuildError(err), Some(db), Some(statements_locations)) =
            (&err, opt_db, &statements_locations)
        else {
            return err.into();
        };
        let mut locs = vec![];
        for stmt_idx in err.stmt_indices() {
            if let Some(loc) = statements_locations.statement_diagnostic_location(db, stmt_idx) {
                locs.push(format!("#{stmt_idx} {:?}", loc.debug(db)))
            }
        }
        anyhow::anyhow!("{err}\n{}", locs.join("\n"))
    })
    .with_context(|| "Failed setting up runner.")?;
    let suffix = if named_tests.len() != 1 { "s" } else { "" };
    println!("running {} test{}", named_tests.len(), suffix);

    // Run the tests in parallel, streaming results back over a channel.
    // `tx` is moved into the spawned task, so once every test has reported the
    // sender is dropped and the `rx.recv()` loop below terminates.
    let (tx, rx) = channel::<_>();
    rayon::spawn(move || {
        named_tests.into_par_iter().for_each(|(name, test)| {
            let result =
                run_single_test(test, &name, &runner, custom_hint_processor_factory.clone());
            tx.send((name, result)).unwrap();
        })
    });

    // Prepared while the tests run; only needed for Cairo-level profiling.
    let profiler_data = config.profiler_config.as_ref().and_then(|profiler_config| {
        if !profiler_config.requires_cairo_debug_info() {
            return None;
        }

        let db = opt_db.expect("db must be passed when doing cairo level profiling.");
        Some((
            ProfilingInfoProcessor::new(
                Some(db),
                &sierra_program,
                statements_locations
                    .expect(
                        "statements locations must be present when doing cairo level profiling.",
                    )
                    .get_statements_functions_map_for_tests(db),
            ),
            ProfilingInfoProcessorParams::from_profiler_config(profiler_config),
        ))
    });

    let mut summary = TestsSummary {
        passed: vec![],
        failed: vec![],
        ignored: vec![],
        failed_run_results: vec![],
    };
    // Collect results as they arrive; exits once the sender side is dropped.
    while let Ok((name, result)) = rx.recv() {
        update_summary(&mut summary, name, result, &profiler_data, config.print_resource_usage);
    }

    Ok(summary)
}
426
427/// Runs a single test and returns a tuple of its name and result.
428fn run_single_test(
429    test: TestConfig,
430    name: &str,
431    runner: &SierraCasmRunner,
432    custom_hint_processor_factory: Option<ArcCustomHintProcessorFactory>,
433) -> Result<Option<TestResult>> {
434    if test.ignored {
435        return Ok(None);
436    }
437    let func = runner.find_function(name)?;
438
439    let (hint_processor, ctx) =
440        runner.prepare_starknet_context(func, vec![], test.available_gas, Default::default())?;
441
442    let mut hint_processor = match custom_hint_processor_factory {
443        Some(f) => f(hint_processor),
444        None => Box::new(hint_processor),
445    };
446
447    let result =
448        runner.run_function_with_prepared_starknet_context(func, &mut *hint_processor, ctx)?;
449
450    Ok(Some(TestResult {
451        status: match &result.value {
452            RunResultValue::Success(_) => match test.expectation {
453                TestExpectation::Success => TestStatus::Success,
454                TestExpectation::Panics(_) => TestStatus::Fail(result.value),
455            },
456            RunResultValue::Panic(value) => match test.expectation {
457                TestExpectation::Success => TestStatus::Fail(result.value),
458                TestExpectation::Panics(panic_expectation) => match panic_expectation {
459                    PanicExpectation::Exact(expected) if value != &expected => {
460                        TestStatus::Fail(result.value)
461                    }
462                    _ => TestStatus::Success,
463                },
464            },
465        },
466        gas_usage: test
467            .available_gas
468            .zip(result.gas_counter)
469            .map(|(before, after)| {
470                before.into_or_panic::<i64>() - after.to_bigint().to_i64().unwrap()
471            })
472            .or_else(|| runner.initial_required_gas(func).map(|gas| gas.into_or_panic::<i64>())),
473        used_resources: result.used_resources,
474        profiling_info: result.profiling_info,
475    }))
476}
477
/// Updates the test summary with the given test result.
///
/// Prints the per-test status line (and, when enabled, resource usage and
/// profiling info) and records the test name in the matching summary bucket.
fn update_summary(
    summary: &mut TestsSummary,
    name: String,
    test_result: Result<Option<TestResult>>,
    profiler_data: &Option<(ProfilingInfoProcessor<'_>, ProfilingInfoProcessorParams)>,
    print_resource_usage: bool,
) {
    let (res_type, status_str, gas_usage, used_resources, profiling_info) = match test_result {
        // `Ok(None)` marks an ignored test (see `run_single_test`).
        Ok(None) => (&mut summary.ignored, "ignored".bright_yellow(), None, None, None),
        Err(err) => {
            summary.failed_run_results.push(Err(err));
            (&mut summary.failed, "failed to run".bright_magenta(), None, None, None)
        }
        Ok(Some(result)) => {
            let (res_type, status_str) = match result.status {
                TestStatus::Success => (&mut summary.passed, "ok".bright_green()),
                TestStatus::Fail(run_result) => {
                    // Keep the run result aligned with the `failed` names list.
                    summary.failed_run_results.push(Ok(run_result));
                    (&mut summary.failed, "fail".bright_red())
                }
            };
            (
                res_type,
                status_str,
                result.gas_usage,
                print_resource_usage.then_some(result.used_resources),
                result.profiling_info,
            )
        }
    };
    if let Some(gas_usage) = gas_usage {
        println!("test {name} ... {status_str} (gas usage est.: {gas_usage})");
    } else {
        println!("test {name} ... {status_str}");
    }
    if let Some(used_resources) = used_resources {
        let filtered = used_resources.basic_resources.filter_unused_builtins();
        // Prints the used resources per test. E.g.:
        // ```ignore
        // test cairo_level_tests::interoperability::test_contract_not_deployed ... ok (gas usage est.: 77320)
        //     steps: 42
        //     memory holes: 20
        //     builtins: ("range_check_builtin": 3)
        //     syscalls: ("CallContract": 1)
        // test cairo_level_tests::events::test_pop_log ... ok (gas usage est.: 55440)
        //     steps: 306
        //     memory holes: 35
        //     builtins: ("range_check_builtin": 24)
        //     syscalls: ("EmitEvent": 2)
        // ```
        println!("    steps: {}", filtered.n_steps);
        println!("    memory holes: {}", filtered.n_memory_holes);

        print_resource_map(
            filtered.builtin_instance_counter.into_iter().map(|(k, v)| (k.to_string(), v)),
            "builtins",
        );
        print_resource_map(used_resources.syscalls.into_iter(), "syscalls");
    }
    if let Some((profiling_processor, profiling_params)) = profiler_data {
        let processed_profiling_info = profiling_processor.process(
            &profiling_info.expect("profiling_info must be Some when profiler_config is Some"),
            profiling_params,
        );
        println!("Profiling info:\n{processed_profiling_info}");
    }
    res_type.push(name);
}
547
/// Given an iterator of (String, usize) pairs, prints a usage map. E.g.:
///     syscalls: ("EmitEvent": 2)
///     syscalls: ("CallContract": 1)
///
/// Entries are printed sorted by key; nothing is printed for an empty iterator.
fn print_resource_map(m: impl ExactSizeIterator<Item = (String, usize)>, resource_type: &str) {
    if m.len() != 0 {
        // `m` is already an iterator — the previous `m.into_iter()` was a no-op.
        // Sort via a collected Vec (stable sort, same order as itertools' `sorted`).
        let mut entries: Vec<(String, usize)> = m.collect();
        entries.sort();
        let formatted = entries
            .into_iter()
            .map(|(k, v)| format!(r#""{k}": {v}"#))
            .collect::<Vec<_>>()
            .join(", ");
        println!("    {resource_type}: ({formatted})");
    }
}