Skip to main content

cairo_lang_test_runner/
lib.rs

1use std::path::Path;
2use std::sync::Arc;
3use std::sync::mpsc::channel;
4
5use anyhow::{Context, Result, bail};
6use cairo_lang_compiler::db::RootDatabase;
7use cairo_lang_compiler::diagnostics::DiagnosticsReporter;
8use cairo_lang_compiler::project::setup_project;
9use cairo_lang_debug::debug::DebugWithDb;
10use cairo_lang_filesystem::cfg::{Cfg, CfgSet};
11use cairo_lang_filesystem::ids::CrateInput;
12use cairo_lang_runner::casm_run::{StarknetHintProcessor, format_for_panic};
13use cairo_lang_runner::profiling::{
14    ProfilerConfig, ProfilingInfo, ProfilingInfoProcessor, ProfilingInfoProcessorParams,
15};
16use cairo_lang_runner::{
17    CairoHintProcessor, ProfilingInfoCollectionConfig, RunResultValue, RunnerError,
18    SierraCasmRunner, StarknetExecutionResources,
19};
20use cairo_lang_sierra_to_casm::metadata::MetadataComputationConfig;
21use cairo_lang_starknet::starknet_plugin_suite;
22use cairo_lang_test_plugin::test_config::{PanicExpectation, TestExpectation};
23use cairo_lang_test_plugin::{
24    TestCompilation, TestCompilationMetadata, TestConfig, TestsCompilationConfig,
25    compile_test_prepared_db, test_plugin_suite,
26};
27use cairo_lang_utils::casts::IntoOrPanic;
28use colored::Colorize;
29use itertools::Itertools;
30use num_traits::ToPrimitive;
31use rayon::prelude::{IntoParallelIterator, ParallelIterator};
32use salsa::Database;
33
34#[cfg(test)]
35mod test;
36
/// A thread-safe, shareable factory that wraps the default [`CairoHintProcessor`] in a custom
/// [`StarknetHintProcessor`] implementation. `Send + Sync` so it can be cloned into the
/// parallel test-execution tasks.
type ArcCustomHintProcessorFactory = Arc<
    dyn (for<'a> Fn(CairoHintProcessor<'a>) -> Box<dyn StarknetHintProcessor + 'a>) + Send + Sync,
>;
40
/// Compile and run tests.
pub struct TestRunner<'db> {
    /// Compiles the tests from source.
    compiler: TestCompiler<'db>,
    /// Configuration controlling filtering, gas, profiling and output.
    config: TestRunConfig,
    /// Optional factory wrapping the default hint processor; `None` means use the default.
    custom_hint_processor_factory: Option<ArcCustomHintProcessorFactory>,
}
47
48impl<'db> TestRunner<'db> {
49    /// Configure a new test runner
50    ///
51    /// # Arguments
52    ///
53    /// * `path` - The path to compile and run its tests
54    /// * `filter` - Run only tests containing the filter string
55    /// * `include_ignored` - Include ignored tests as well
56    /// * `ignored` - Run ignored tests only
57    /// * `starknet` - Add the Starknet plugin to run the tests
58    pub fn new(
59        path: &Path,
60        starknet: bool,
61        allow_warnings: bool,
62        config: TestRunConfig,
63    ) -> Result<Self> {
64        let compiler = TestCompiler::try_new(
65            path,
66            allow_warnings,
67            config.gas_enabled,
68            TestsCompilationConfig {
69                starknet,
70                add_statements_functions: config
71                    .profiler_config
72                    .as_ref()
73                    .is_some_and(|c| c.requires_cairo_debug_info()),
74                add_statements_code_locations: false,
75                contract_declarations: None,
76                contract_crate_ids: None,
77                executable_crate_ids: None,
78                add_functions_debug_info: false,
79                replace_ids: false,
80            },
81        )?;
82        Ok(Self { compiler, config, custom_hint_processor_factory: None })
83    }
84
85    /// Make this runner run tests using a custom hint processor.
86    pub fn with_custom_hint_processor(
87        &mut self,
88        factory: impl (for<'a> Fn(CairoHintProcessor<'a>) -> Box<dyn StarknetHintProcessor + 'a>)
89        + Send
90        + Sync
91        + 'static,
92    ) -> &mut Self {
93        self.custom_hint_processor_factory = Some(Arc::new(factory));
94        self
95    }
96
97    /// Runs the tests and processes the results for a summary.
98    pub fn run(&self) -> Result<Option<TestsSummary>> {
99        let runner = CompiledTestRunner {
100            compiled: self.compiler.build()?,
101            config: self.config.clone(),
102            custom_hint_processor_factory: self.custom_hint_processor_factory.clone(),
103        };
104        runner.run(Some(&self.compiler.db))
105    }
106}
107
/// Runs already-compiled tests according to a [`TestRunConfig`].
pub struct CompiledTestRunner<'db> {
    /// The compiled tests and their metadata.
    compiled: TestCompilation<'db>,
    /// Configuration controlling filtering, gas, profiling and output.
    config: TestRunConfig,
    /// Optional factory wrapping the default hint processor; `None` means use the default.
    custom_hint_processor_factory: Option<ArcCustomHintProcessorFactory>,
}
113
impl<'db> CompiledTestRunner<'db> {
    /// Configure a new compiled test runner
    ///
    /// # Arguments
    ///
    /// * `compiled` - The compiled tests to run
    /// * `config` - Test run configuration
    pub fn new(compiled: TestCompilation<'db>, config: TestRunConfig) -> Self {
        Self { compiled, config, custom_hint_processor_factory: None }
    }

    /// Make this runner run tests using a custom hint processor.
    pub fn with_custom_hint_processor(
        &mut self,
        factory: impl (for<'a> Fn(CairoHintProcessor<'a>) -> Box<dyn StarknetHintProcessor + 'a>)
        + Send
        + Sync
        + 'static,
    ) -> &mut Self {
        self.custom_hint_processor_factory = Some(Arc::new(factory));
        self
    }

    /// Execute preconfigured test execution.
    ///
    /// `opt_db` is used to map build errors and profiling data back to Cairo source; it is
    /// required when cairo-level profiling is configured (see `run_tests`).
    ///
    /// Returns `Ok(None)` when no test failed, and an error summarizing the failures otherwise.
    pub fn run(self, opt_db: Option<&'db RootDatabase>) -> Result<Option<TestsSummary>> {
        // Drop the tests excluded by the filter/ignored flags, keeping the count for reporting.
        let (compiled, filtered_out) = filter_test_cases(
            self.compiled,
            self.config.include_ignored,
            self.config.ignored,
            &self.config.filter,
        );

        let TestsSummary { passed, failed, ignored, failed_run_results } = run_tests(
            opt_db.map(|db| db as &dyn Database),
            compiled,
            &self.config,
            self.custom_hint_processor_factory,
        )?;

        if failed.is_empty() {
            println!(
                "test result: {}. {} passed; {} failed; {} ignored; {filtered_out} filtered out;",
                "ok".bright_green(),
                passed.len(),
                failed.len(),
                ignored.len()
            );
            Ok(None)
        } else {
            println!("failures:");
            // `failed` and `failed_run_results` are pushed in lockstep by `update_summary`,
            // so `zip_eq` (which panics on a length mismatch) is safe here.
            for (failure, run_result) in failed.iter().zip_eq(failed_run_results) {
                print!("   {failure} - ");
                match run_result {
                    Ok(RunResultValue::Success(_)) => {
                        println!("expected panic but finished successfully.");
                    }
                    Ok(RunResultValue::Panic(values)) => {
                        println!("{}", format_for_panic(values.into_iter()));
                    }
                    Err(err) => {
                        println!("{err}");
                    }
                }
            }
            println!();
            bail!(
                "test result: {}. {} passed; {} failed; {} ignored",
                "FAILED".bright_red(),
                passed.len(),
                failed.len(),
                ignored.len()
            );
        }
    }
}
189
/// Configuration of compiled tests runner.
#[derive(Clone, Debug)]
pub struct TestRunConfig {
    /// Run only tests whose full name contains this string (the empty string matches all).
    pub filter: String,
    /// Whether to also run tests marked as ignored (re-enabling them).
    pub include_ignored: bool,
    /// Whether to run only the tests marked as ignored (re-enabling them).
    pub ignored: bool,
    /// Whether to run the profiler and how.
    pub profiler_config: Option<ProfilerConfig>,
    /// Whether to enable gas calculation.
    pub gas_enabled: bool,
    /// Whether to print used resources after each test.
    pub print_resource_usage: bool,
}
203
/// The test cases compiler.
pub struct TestCompiler<'db> {
    /// The prepared database snapshot used for compilation.
    pub db: RootDatabase,
    /// The crates whose diagnostics are reported during `build`.
    pub main_crate_ids: Vec<CrateInput>,
    /// The crates whose tests are compiled.
    pub test_crate_ids: Vec<CrateInput>,
    /// Whether compilation warnings are tolerated.
    pub allow_warnings: bool,
    /// Tests compilation configuration.
    pub config: TestsCompilationConfig<'db>,
}
212
213impl<'db> TestCompiler<'db> {
214    /// Configure a new test compiler
215    ///
216    /// # Arguments
217    ///
218    /// * `path` - The path to compile and run its tests
219    /// * `starknet` - Add the Starknet plugin to run the tests
220    pub fn try_new(
221        path: &Path,
222        allow_warnings: bool,
223        gas_enabled: bool,
224        config: TestsCompilationConfig<'db>,
225    ) -> Result<Self> {
226        let mut db = {
227            let mut b = RootDatabase::builder();
228            let mut cfg = CfgSet::from_iter([Cfg::name("test"), Cfg::kv("target", "test")]);
229            if !gas_enabled {
230                cfg.insert(Cfg::kv("gas", "disabled"));
231                b.skip_auto_withdraw_gas();
232            }
233            b.detect_corelib();
234            b.with_cfg(cfg);
235            b.with_default_plugin_suite(test_plugin_suite());
236            if config.starknet {
237                b.with_default_plugin_suite(starknet_plugin_suite());
238            }
239            b.build()?
240        };
241
242        let main_crate_inputs = setup_project(&mut db, path)?;
243
244        Ok(Self {
245            db: db.snapshot(),
246            test_crate_ids: main_crate_inputs.clone(),
247            main_crate_ids: main_crate_inputs,
248            allow_warnings,
249            config,
250        })
251    }
252
253    /// Build the tests and collect metadata.
254    pub fn build(&self) -> Result<TestCompilation<'_>> {
255        let mut diag_reporter = DiagnosticsReporter::stderr().with_crates(&self.main_crate_ids);
256        if self.allow_warnings {
257            diag_reporter = diag_reporter.allow_warnings();
258        }
259
260        compile_test_prepared_db(
261            &self.db,
262            self.config.clone(),
263            self.test_crate_ids.clone(),
264            diag_reporter,
265        )
266    }
267}
268
269/// Filter compiled test cases with user provided arguments.
270///
271/// # Arguments
272/// * `compiled` - Compiled test cases with metadata.
273/// * `include_ignored` - Include ignored tests as well.
274/// * `ignored` - Run ignored tests only
275/// * `filter` - Include only tests containing the filter string.
276/// # Returns
277/// * (`TestCompilation`, `usize`) - The filtered test cases and the number of filtered out cases.
278pub fn filter_test_cases<'db>(
279    compiled: TestCompilation<'db>,
280    include_ignored: bool,
281    ignored: bool,
282    filter: &str,
283) -> (TestCompilation<'db>, usize) {
284    let total_tests_count = compiled.metadata.named_tests.len();
285    let named_tests = compiled
286        .metadata
287        .named_tests
288        .into_iter()
289        // Filtering unignored tests in `ignored` mode. Keep all tests in `include-ignored` mode.
290        .filter(|(_, test)| !ignored || test.ignored || include_ignored)
291        .map(|(func, mut test)| {
292            // Un-ignoring all the tests in `include-ignored` and `ignored` mode.
293            if include_ignored || ignored {
294                test.ignored = false;
295            }
296            (func, test)
297        })
298        .filter(|(name, _)| name.contains(filter))
299        .collect_vec();
300    let filtered_out = total_tests_count - named_tests.len();
301    let tests = TestCompilation {
302        sierra_program: compiled.sierra_program,
303        metadata: TestCompilationMetadata { named_tests, ..compiled.metadata },
304    };
305    (tests, filtered_out)
306}
307
/// The status of a test run.
enum TestStatus {
    /// The run met the test's expectation (plain success or an expected panic).
    Success,
    /// The run did not meet the expectation; carries the run value for failure reporting.
    Fail(RunResultValue),
}
313
/// The result of a single (non-ignored) test run.
struct TestResult {
    /// The status of the run.
    status: TestStatus,
    /// The estimated gas usage of the run if relevant (gas enabled and computable).
    gas_usage: Option<i64>,
    /// The used resources of the run (steps, memory holes, builtins, syscalls).
    used_resources: StarknetExecutionResources,
    /// The profiling info of the run, if requested.
    profiling_info: Option<ProfilingInfo>,
}
325
/// Summary data of the tests run.
pub struct TestsSummary {
    /// Names of the tests that passed.
    passed: Vec<String>,
    /// Names of the tests that failed (or failed to run).
    failed: Vec<String>,
    /// Names of the tests that were skipped as ignored.
    ignored: Vec<String>,
    /// Run results of the failed tests, kept parallel to `failed`.
    failed_run_results: Vec<Result<RunResultValue>>,
}
333
/// Runs the tests and processes the results for a summary.
///
/// # Arguments
/// * `opt_db` - Database used to attach Cairo source locations to build errors and to drive
///   cairo-level profiling; must be `Some` when `config.profiler_config` requires debug info.
/// * `compiled` - The compiled tests with their metadata.
/// * `config` - Test run configuration.
/// * `custom_hint_processor_factory` - Optional factory wrapping the default hint processor.
pub fn run_tests(
    opt_db: Option<&dyn Database>,
    compiled: TestCompilation<'_>,
    config: &TestRunConfig,
    custom_hint_processor_factory: Option<ArcCustomHintProcessorFactory>,
) -> Result<TestsSummary> {
    let TestCompilation {
        sierra_program: sierra_program_with_debug_info,
        metadata:
            TestCompilationMetadata {
                named_tests,
                function_set_costs,
                contracts_info,
                statements_locations,
            },
    } = compiled;
    let sierra_program = sierra_program_with_debug_info.program;
    let runner = SierraCasmRunner::new(
        sierra_program.clone(),
        // Metadata computation is only needed when gas accounting is enabled.
        if config.gas_enabled {
            Some(MetadataComputationConfig {
                function_set_costs,
                linear_gas_solver: true,
                linear_ap_change_solver: true,
                skip_non_linear_solver_comparisons: false,
                compute_runtime_costs: false,
            })
        } else {
            None
        },
        contracts_info,
        config.profiler_config.as_ref().map(ProfilingInfoCollectionConfig::from_profiler_config),
    )
    // Enrich build errors with the source locations of the offending Sierra statements, when
    // both a database and statement locations are available.
    .map_err(|err| {
        let (RunnerError::BuildError(err), Some(db), Some(statements_locations)) =
            (&err, opt_db, &statements_locations)
        else {
            return err.into();
        };
        let mut locs = vec![];
        for stmt_idx in err.stmt_indices() {
            if let Some(loc) = statements_locations.statement_diagnostic_location(db, stmt_idx) {
                locs.push(format!("#{stmt_idx} {:?}", loc.debug(db)))
            }
        }
        anyhow::anyhow!("{err}\n{}", locs.join("\n"))
    })
    .with_context(|| "Failed setting up runner.")?;
    let suffix = if named_tests.len() != 1 { "s" } else { "" };
    println!("running {} test{}", named_tests.len(), suffix);

    // Run the tests in parallel on the rayon pool, streaming each (name, result) back over a
    // channel. The receive loop below ends once all tests finish and `tx` is dropped.
    let (tx, rx) = channel::<_>();
    rayon::spawn(move || {
        named_tests.into_par_iter().for_each(|(name, test)| {
            let result =
                run_single_test(test, &name, &runner, custom_hint_processor_factory.clone());
            tx.send((name, result)).unwrap();
        })
    });

    // Set up the cairo-level profiling processor, only when the configured profiler needs
    // cairo debug info (both the db and statement locations must then be present).
    let profiler_data = config.profiler_config.as_ref().and_then(|profiler_config| {
        if !profiler_config.requires_cairo_debug_info() {
            return None;
        }

        let db = opt_db.expect("db must be passed when doing cairo level profiling.");
        Some((
            ProfilingInfoProcessor::new(
                Some(db),
                &sierra_program,
                statements_locations
                    .expect(
                        "statements locations must be present when doing cairo level profiling.",
                    )
                    .get_statements_functions_map_for_tests(db),
            ),
            ProfilingInfoProcessorParams::from_profiler_config(profiler_config),
        ))
    });

    let mut summary = TestsSummary {
        passed: vec![],
        failed: vec![],
        ignored: vec![],
        failed_run_results: vec![],
    };
    // Print and accumulate results as they arrive; `recv` errs (ending the loop) once all
    // senders have been dropped.
    while let Ok((name, result)) = rx.recv() {
        update_summary(&mut summary, name, result, &profiler_data, config.print_resource_usage);
    }

    Ok(summary)
}
427
/// Runs a single test and returns its result, or `None` if the test is ignored.
///
/// # Arguments
/// * `test` - The test configuration (ignored flag, available gas, panic expectation).
/// * `name` - The fully qualified name of the test function to look up in the program.
/// * `runner` - The runner used to execute the compiled Sierra program.
/// * `custom_hint_processor_factory` - Optional factory wrapping the default hint processor.
fn run_single_test(
    test: TestConfig,
    name: &str,
    runner: &SierraCasmRunner,
    custom_hint_processor_factory: Option<ArcCustomHintProcessorFactory>,
) -> Result<Option<TestResult>> {
    if test.ignored {
        return Ok(None);
    }
    let func = runner.find_function(name)?;

    let (hint_processor, ctx) =
        runner.prepare_starknet_context(func, vec![], test.available_gas, Default::default())?;

    // Wrap the default hint processor with the custom one, if a factory was provided.
    let mut hint_processor = match custom_hint_processor_factory {
        Some(f) => f(hint_processor),
        None => Box::new(hint_processor),
    };

    let result =
        runner.run_function_with_prepared_starknet_context(func, &mut *hint_processor, ctx)?;

    Ok(Some(TestResult {
        // A test passes when the outcome matches its expectation: plain success, or a panic
        // that an exact-values expectation does not contradict.
        status: match &result.value {
            RunResultValue::Success(_) => match test.expectation {
                TestExpectation::Success => TestStatus::Success,
                TestExpectation::Panics(_) => TestStatus::Fail(result.value),
            },
            RunResultValue::Panic(value) => match test.expectation {
                TestExpectation::Success => TestStatus::Fail(result.value),
                TestExpectation::Panics(panic_expectation) => match panic_expectation {
                    // Only a mismatching exact expectation fails; any other panic expectation
                    // accepts the panic.
                    PanicExpectation::Exact(expected) if value != &expected => {
                        TestStatus::Fail(result.value)
                    }
                    _ => TestStatus::Success,
                },
            },
        },
        // Gas used = gas available before the run minus the remaining counter; when either is
        // missing, fall back to the function's initial required gas.
        // NOTE(review): `to_i64().unwrap()` assumes the remaining counter fits in i64 — the
        // run would panic otherwise; confirm this invariant holds upstream.
        gas_usage: test
            .available_gas
            .zip(result.gas_counter)
            .map(|(before, after)| {
                before.into_or_panic::<i64>() - after.to_bigint().to_i64().unwrap()
            })
            .or_else(|| runner.initial_required_gas(func).map(|gas| gas.into_or_panic::<i64>())),
        used_resources: result.used_resources,
        profiling_info: result.profiling_info,
    }))
}
478
/// Updates the test summary with the given test result, printing a status line for the test
/// (plus optional resource usage and profiling details).
///
/// # Arguments
/// * `summary` - The summary accumulating passed/failed/ignored test names.
/// * `name` - The fully qualified name of the test.
/// * `test_result` - The run outcome: `Ok(None)` for ignored, `Ok(Some(_))` for executed,
///   `Err(_)` when the run itself failed.
/// * `profiler_data` - Profiling processor and params, when cairo-level profiling is on.
/// * `print_resource_usage` - Whether to print the resources used by the test.
fn update_summary(
    summary: &mut TestsSummary,
    name: String,
    test_result: Result<Option<TestResult>>,
    profiler_data: &Option<(ProfilingInfoProcessor<'_>, ProfilingInfoProcessorParams)>,
    print_resource_usage: bool,
) {
    // Select the summary vector to push the name into, alongside the data to print.
    let (res_type, status_str, gas_usage, used_resources, profiling_info) = match test_result {
        Ok(None) => (&mut summary.ignored, "ignored".bright_yellow(), None, None, None),
        Err(err) => {
            // Pushed together with the name going to `failed` below, keeping them parallel.
            summary.failed_run_results.push(Err(err));
            (&mut summary.failed, "failed to run".bright_magenta(), None, None, None)
        }
        Ok(Some(result)) => {
            let (res_type, status_str) = match result.status {
                TestStatus::Success => (&mut summary.passed, "ok".bright_green()),
                TestStatus::Fail(run_result) => {
                    summary.failed_run_results.push(Ok(run_result));
                    (&mut summary.failed, "fail".bright_red())
                }
            };
            (
                res_type,
                status_str,
                result.gas_usage,
                print_resource_usage.then_some(result.used_resources),
                result.profiling_info,
            )
        }
    };
    if let Some(gas_usage) = gas_usage {
        println!("test {name} ... {status_str} (gas usage est.: {gas_usage})");
    } else {
        println!("test {name} ... {status_str}");
    }
    if let Some(used_resources) = used_resources {
        // Only builtins that were actually used are reported.
        let filtered = used_resources.basic_resources.filter_unused_builtins();
        // Prints the used resources per test. E.g.:
        // ```ignore
        // test cairo_level_tests::interoperability::test_contract_not_deployed ... ok (gas usage est.: 77320)
        //     steps: 42
        //     memory holes: 20
        //     builtins: ("range_check_builtin": 3)
        //     syscalls: ("CallContract": 1)
        // test cairo_level_tests::events::test_pop_log ... ok (gas usage est.: 55440)
        //     steps: 306
        //     memory holes: 35
        //     builtins: ("range_check_builtin": 24)
        //     syscalls: ("EmitEvent": 2)
        // ```
        println!("    steps: {}", filtered.n_steps);
        println!("    memory holes: {}", filtered.n_memory_holes);

        print_resource_map(
            filtered.builtin_instance_counter.into_iter().map(|(k, v)| (k.to_string(), v)),
            "builtins",
        );
        print_resource_map(used_resources.syscalls.into_iter(), "syscalls");
    }
    if let Some((profiling_processor, profiling_params)) = profiler_data {
        let processed_profiling_info = profiling_processor.process(
            &profiling_info.expect("profiling_info must be Some when profiler_config is Some"),
            profiling_params,
        );
        println!("Profiling info:\n{processed_profiling_info}");
    }
    res_type.push(name);
}
548
549/// Given an iterator of (String, usize) pairs, prints a usage map. E.g.:
550///     syscalls: ("EmitEvent": 2)
551///     syscalls: ("CallContract": 1)
552fn print_resource_map(m: impl ExactSizeIterator<Item = (String, usize)>, resource_type: &str) {
553    if m.len() != 0 {
554        println!(
555            "    {resource_type}: ({})",
556            m.into_iter().sorted().map(|(k, v)| format!(r#""{k}": {v}"#)).join(", ")
557        );
558    }
559}