// cairo_lang_test_runner/lib.rs

1use std::path::Path;
2use std::sync::Arc;
3use std::sync::mpsc::channel;
4
5use anyhow::{Context, Result, bail};
6use cairo_lang_compiler::db::RootDatabase;
7use cairo_lang_compiler::diagnostics::DiagnosticsReporter;
8use cairo_lang_compiler::project::setup_project;
9use cairo_lang_debug::debug::DebugWithDb;
10use cairo_lang_filesystem::cfg::{Cfg, CfgSet};
11use cairo_lang_filesystem::ids::CrateInput;
12use cairo_lang_runner::casm_run::{StarknetHintProcessor, format_for_panic};
13use cairo_lang_runner::profiling::{
14    ProfilerConfig, ProfilingInfo, ProfilingInfoProcessor, ProfilingInfoProcessorParams,
15};
16use cairo_lang_runner::{
17    CairoHintProcessor, ProfilingInfoCollectionConfig, RunResultValue, RunnerError,
18    SierraCasmRunner, StarknetExecutionResources,
19};
20use cairo_lang_sierra_to_casm::metadata::MetadataComputationConfig;
21use cairo_lang_starknet::starknet_plugin_suite;
22use cairo_lang_test_plugin::test_config::{PanicExpectation, TestExpectation};
23use cairo_lang_test_plugin::{
24    TestCompilation, TestCompilationMetadata, TestConfig, TestsCompilationConfig,
25    compile_test_prepared_db, test_plugin_suite,
26};
27use cairo_lang_utils::casts::IntoOrPanic;
28use colored::Colorize;
29use itertools::Itertools;
30use num_traits::ToPrimitive;
31use rayon::prelude::{IntoParallelIterator, ParallelIterator};
32use salsa::Database;
33
34#[cfg(test)]
35mod test;
36
/// Shared factory that wraps the default [`CairoHintProcessor`] in a user-supplied
/// [`StarknetHintProcessor`]. `Send + Sync` so it can be cloned into the parallel
/// test-execution tasks (see `run_tests`).
type ArcCustomHintProcessorFactory = Arc<
    dyn (for<'a> Fn(CairoHintProcessor<'a>) -> Box<dyn StarknetHintProcessor + 'a>) + Send + Sync,
>;
40
/// Compile and run tests.
pub struct TestRunner<'db> {
    /// The compiler used to build the tests before running them.
    compiler: TestCompiler<'db>,
    /// The test run configuration.
    config: TestRunConfig,
    /// Optional factory that wraps the default hint processor with a custom one.
    custom_hint_processor_factory: Option<ArcCustomHintProcessorFactory>,
}
47
48impl<'db> TestRunner<'db> {
49    /// Configure a new test runner
50    ///
51    /// # Arguments
52    ///
53    /// * `path` - The path to compile and run its tests
54    /// * `filter` - Run only tests containing the filter string
55    /// * `include_ignored` - Include ignored tests as well
56    /// * `ignored` - Run ignored tests only
57    /// * `starknet` - Add the starknet plugin to run the tests
58    pub fn new(
59        path: &Path,
60        starknet: bool,
61        allow_warnings: bool,
62        config: TestRunConfig,
63    ) -> Result<Self> {
64        let compiler = TestCompiler::try_new(
65            path,
66            allow_warnings,
67            config.gas_enabled,
68            TestsCompilationConfig {
69                starknet,
70                add_statements_functions: config
71                    .profiler_config
72                    .as_ref()
73                    .is_some_and(|c| c.requires_cairo_debug_info()),
74                add_statements_code_locations: false,
75                contract_declarations: None,
76                contract_crate_ids: None,
77                executable_crate_ids: None,
78            },
79        )?;
80        Ok(Self { compiler, config, custom_hint_processor_factory: None })
81    }
82
83    /// Make this runner run tests using a custom hint processor.
84    pub fn with_custom_hint_processor(
85        &mut self,
86        factory: impl (for<'a> Fn(CairoHintProcessor<'a>) -> Box<dyn StarknetHintProcessor + 'a>)
87        + Send
88        + Sync
89        + 'static,
90    ) -> &mut Self {
91        self.custom_hint_processor_factory = Some(Arc::new(factory));
92        self
93    }
94
95    /// Runs the tests and process the results for a summary.
96    pub fn run(&self) -> Result<Option<TestsSummary>> {
97        let runner = CompiledTestRunner {
98            compiled: self.compiler.build()?,
99            config: self.config.clone(),
100            custom_hint_processor_factory: self.custom_hint_processor_factory.clone(),
101        };
102        runner.run(Some(&self.compiler.db))
103    }
104}
105
/// Runner for tests that were already compiled.
pub struct CompiledTestRunner<'db> {
    /// The compiled tests to run.
    compiled: TestCompilation<'db>,
    /// The test run configuration.
    config: TestRunConfig,
    /// Optional factory that wraps the default hint processor with a custom one.
    custom_hint_processor_factory: Option<ArcCustomHintProcessorFactory>,
}
111
112impl<'db> CompiledTestRunner<'db> {
113    /// Configure a new compiled test runner
114    ///
115    /// # Arguments
116    ///
117    /// * `compiled` - The compiled tests to run
118    /// * `config` - Test run configuration
119    pub fn new(compiled: TestCompilation<'db>, config: TestRunConfig) -> Self {
120        Self { compiled, config, custom_hint_processor_factory: None }
121    }
122
123    /// Make this runner run tests using a custom hint processor.
124    pub fn with_custom_hint_processor(
125        &mut self,
126        factory: impl (for<'a> Fn(CairoHintProcessor<'a>) -> Box<dyn StarknetHintProcessor + 'a>)
127        + Send
128        + Sync
129        + 'static,
130    ) -> &mut Self {
131        self.custom_hint_processor_factory = Some(Arc::new(factory));
132        self
133    }
134
135    /// Execute preconfigured test execution.
136    pub fn run(self, opt_db: Option<&'db RootDatabase>) -> Result<Option<TestsSummary>> {
137        let (compiled, filtered_out) = filter_test_cases(
138            self.compiled,
139            self.config.include_ignored,
140            self.config.ignored,
141            &self.config.filter,
142        );
143
144        let TestsSummary { passed, failed, ignored, failed_run_results } = run_tests(
145            opt_db.map(|db| db as &dyn Database),
146            compiled,
147            &self.config,
148            self.custom_hint_processor_factory,
149        )?;
150
151        if failed.is_empty() {
152            println!(
153                "test result: {}. {} passed; {} failed; {} ignored; {filtered_out} filtered out;",
154                "ok".bright_green(),
155                passed.len(),
156                failed.len(),
157                ignored.len()
158            );
159            Ok(None)
160        } else {
161            println!("failures:");
162            for (failure, run_result) in failed.iter().zip_eq(failed_run_results) {
163                print!("   {failure} - ");
164                match run_result {
165                    Ok(RunResultValue::Success(_)) => {
166                        println!("expected panic but finished successfully.");
167                    }
168                    Ok(RunResultValue::Panic(values)) => {
169                        println!("{}", format_for_panic(values.into_iter()));
170                    }
171                    Err(err) => {
172                        println!("{err}");
173                    }
174                }
175            }
176            println!();
177            bail!(
178                "test result: {}. {} passed; {} failed; {} ignored",
179                "FAILED".bright_red(),
180                passed.len(),
181                failed.len(),
182                ignored.len()
183            );
184        }
185    }
186}
187
/// Configuration of compiled tests runner.
#[derive(Clone, Debug)]
pub struct TestRunConfig {
    /// Run only tests whose name contains this string (empty string matches all tests).
    pub filter: String,
    /// Whether to also run tests marked as ignored.
    pub include_ignored: bool,
    /// Whether to run only the tests marked as ignored.
    pub ignored: bool,
    /// Whether to run the profiler and how.
    pub profiler_config: Option<ProfilerConfig>,
    /// Whether to enable gas calculation.
    pub gas_enabled: bool,
    /// Whether to print used resources after each test.
    pub print_resource_usage: bool,
}
201
/// The test cases compiler.
pub struct TestCompiler<'db> {
    /// The database the compilation runs on.
    pub db: RootDatabase,
    /// The crates whose diagnostics are reported when building.
    pub main_crate_ids: Vec<CrateInput>,
    /// The crates whose tests are compiled. Currently identical to `main_crate_ids`.
    pub test_crate_ids: Vec<CrateInput>,
    /// Whether compilation warnings are tolerated.
    pub allow_warnings: bool,
    /// The tests compilation configuration.
    pub config: TestsCompilationConfig<'db>,
}
210
211impl<'db> TestCompiler<'db> {
212    /// Configure a new test compiler
213    ///
214    /// # Arguments
215    ///
216    /// * `path` - The path to compile and run its tests
217    /// * `starknet` - Add the starknet plugin to run the tests
218    pub fn try_new(
219        path: &Path,
220        allow_warnings: bool,
221        gas_enabled: bool,
222        config: TestsCompilationConfig<'db>,
223    ) -> Result<Self> {
224        let mut db = {
225            let mut b = RootDatabase::builder();
226            let mut cfg = CfgSet::from_iter([Cfg::name("test"), Cfg::kv("target", "test")]);
227            if !gas_enabled {
228                cfg.insert(Cfg::kv("gas", "disabled"));
229                b.skip_auto_withdraw_gas();
230            }
231            b.detect_corelib();
232            b.with_cfg(cfg);
233            b.with_default_plugin_suite(test_plugin_suite());
234            if config.starknet {
235                b.with_default_plugin_suite(starknet_plugin_suite());
236            }
237            b.build()?
238        };
239
240        let main_crate_inputs = setup_project(&mut db, Path::new(&path))?;
241
242        Ok(Self {
243            db: db.snapshot(),
244            test_crate_ids: main_crate_inputs.clone(),
245            main_crate_ids: main_crate_inputs,
246            allow_warnings,
247            config,
248        })
249    }
250
251    /// Build the tests and collect metadata.
252    pub fn build(&self) -> Result<TestCompilation<'_>> {
253        let mut diag_reporter = DiagnosticsReporter::stderr().with_crates(&self.main_crate_ids);
254        if self.allow_warnings {
255            diag_reporter = diag_reporter.allow_warnings();
256        }
257
258        compile_test_prepared_db(
259            &self.db,
260            self.config.clone(),
261            self.test_crate_ids.clone(),
262            diag_reporter,
263        )
264    }
265}
266
267/// Filter compiled test cases with user provided arguments.
268///
269/// # Arguments
270/// * `compiled` - Compiled test cases with metadata.
271/// * `include_ignored` - Include ignored tests as well.
272/// * `ignored` - Run ignored tests only.l
273/// * `filter` - Include only tests containing the filter string.
274/// # Returns
275/// * (`TestCompilation`, `usize`) - The filtered test cases and the number of filtered out cases.
276pub fn filter_test_cases<'db>(
277    compiled: TestCompilation<'db>,
278    include_ignored: bool,
279    ignored: bool,
280    filter: &str,
281) -> (TestCompilation<'db>, usize) {
282    let total_tests_count = compiled.metadata.named_tests.len();
283    let named_tests = compiled
284        .metadata
285        .named_tests
286        .into_iter()
287        // Filtering unignored tests in `ignored` mode. Keep all tests in `include-ignored` mode.
288        .filter(|(_, test)| !ignored || test.ignored || include_ignored)
289        .map(|(func, mut test)| {
290            // Un-ignoring all the tests in `include-ignored` and `ignored` mode.
291            if include_ignored || ignored {
292                test.ignored = false;
293            }
294            (func, test)
295        })
296        .filter(|(name, _)| name.contains(filter))
297        .collect_vec();
298    let filtered_out = total_tests_count - named_tests.len();
299    let tests = TestCompilation {
300        sierra_program: compiled.sierra_program,
301        metadata: TestCompilationMetadata { named_tests, ..compiled.metadata },
302    };
303    (tests, filtered_out)
304}
305
/// The status of a ran test.
enum TestStatus {
    /// The test outcome matched its expectation.
    Success,
    /// The test failed; carries the run result value for the final failure report.
    Fail(RunResultValue),
}
311
/// The result of a ran test.
struct TestResult {
    /// The status of the run.
    status: TestStatus,
    /// The gas usage of the run if relevant, `None` when it could not be determined.
    gas_usage: Option<i64>,
    /// The used resources of the run.
    used_resources: StarknetExecutionResources,
    /// The profiling info of the run, if requested.
    profiling_info: Option<ProfilingInfo>,
}
323
/// Summary data of the ran tests.
pub struct TestsSummary {
    /// Names of the tests that passed.
    passed: Vec<String>,
    /// Names of the tests that failed (or failed to run).
    failed: Vec<String>,
    /// Names of the tests that were ignored.
    ignored: Vec<String>,
    /// Run results of the failed tests, in the same order as `failed`.
    failed_run_results: Vec<Result<RunResultValue>>,
}
331
/// Runs the tests and processes the results for a summary.
///
/// Executes the tests in parallel via rayon and collects the results over a channel,
/// printing each test's status as it completes. `opt_db` is required only when
/// cairo-level profiling is requested or to enrich runner build errors with locations.
pub fn run_tests(
    opt_db: Option<&dyn Database>,
    compiled: TestCompilation<'_>,
    config: &TestRunConfig,
    custom_hint_processor_factory: Option<ArcCustomHintProcessorFactory>,
) -> Result<TestsSummary> {
    let TestCompilation {
        sierra_program: sierra_program_with_debug_info,
        metadata:
            TestCompilationMetadata {
                named_tests,
                function_set_costs,
                contracts_info,
                statements_locations,
            },
    } = compiled;
    let sierra_program = sierra_program_with_debug_info.program;
    // Build the runner; gas metadata is only computed when gas is enabled.
    let runner = SierraCasmRunner::new(
        sierra_program.clone(),
        if config.gas_enabled {
            Some(MetadataComputationConfig {
                function_set_costs,
                linear_gas_solver: true,
                linear_ap_change_solver: true,
                skip_non_linear_solver_comparisons: false,
                compute_runtime_costs: false,
            })
        } else {
            None
        },
        contracts_info,
        config.profiler_config.as_ref().map(ProfilingInfoCollectionConfig::from_profiler_config),
    )
    // On build errors, attach source locations of the offending statements when a db and
    // statement locations are available; otherwise propagate the error as-is.
    .map_err(|err| {
        let (RunnerError::BuildError(err), Some(db), Some(statements_locations)) =
            (&err, opt_db, &statements_locations)
        else {
            return err.into();
        };
        let mut locs = vec![];
        for stmt_idx in err.stmt_indices() {
            if let Some(loc) = statements_locations.statement_diagnostic_location(db, stmt_idx) {
                locs.push(format!("#{stmt_idx} {:?}", loc.debug(db)))
            }
        }
        anyhow::anyhow!("{err}\n{}", locs.join("\n"))
    })
    .with_context(|| "Failed setting up runner.")?;
    let suffix = if named_tests.len() != 1 { "s" } else { "" };
    println!("running {} test{}", named_tests.len(), suffix);

    // Fan the tests out to rayon workers; each worker sends its result back over `tx`.
    // The receiver loop below ends when all senders (clones of `tx`) are dropped.
    let (tx, rx) = channel::<_>();
    rayon::spawn(move || {
        named_tests.into_par_iter().for_each(|(name, test)| {
            let result =
                run_single_test(test, &name, &runner, custom_hint_processor_factory.clone());
            tx.send((name, result)).unwrap();
        })
    });

    // Prepare the cairo-level profiling processor, if the profiler configuration needs it.
    let profiler_data = config.profiler_config.as_ref().and_then(|profiler_config| {
        if !profiler_config.requires_cairo_debug_info() {
            return None;
        }

        let db = opt_db.expect("db must be passed when doing cairo level profiling.");
        Some((
            ProfilingInfoProcessor::new(
                Some(db),
                &sierra_program,
                statements_locations
                    .expect(
                        "statements locations must be present when doing cairo level profiling.",
                    )
                    .get_statements_functions_map_for_tests(db),
            ),
            ProfilingInfoProcessorParams::from_profiler_config(profiler_config),
        ))
    });

    let mut summary = TestsSummary {
        passed: vec![],
        failed: vec![],
        ignored: vec![],
        failed_run_results: vec![],
    };
    // Drain results as they arrive; results come in completion order, not submission order.
    while let Ok((name, result)) = rx.recv() {
        update_summary(&mut summary, name, result, &profiler_data, config.print_resource_usage);
    }

    Ok(summary)
}
425
/// Runs a single test and returns its result, or `Ok(None)` if the test is ignored.
fn run_single_test(
    test: TestConfig,
    name: &str,
    runner: &SierraCasmRunner,
    custom_hint_processor_factory: Option<ArcCustomHintProcessorFactory>,
) -> Result<Option<TestResult>> {
    // Ignored tests are not executed; `update_summary` records `None` as "ignored".
    if test.ignored {
        return Ok(None);
    }
    let func = runner.find_function(name)?;

    let (hint_processor, ctx) =
        runner.prepare_starknet_context(func, vec![], test.available_gas, Default::default())?;

    // Wrap the default hint processor with the user-provided one, when given.
    let mut hint_processor = match custom_hint_processor_factory {
        Some(f) => f(hint_processor),
        None => Box::new(hint_processor),
    };

    let result =
        runner.run_function_with_prepared_starknet_context(func, &mut *hint_processor, ctx)?;

    Ok(Some(TestResult {
        // A test passes iff the run outcome matches its declared expectation.
        status: match &result.value {
            RunResultValue::Success(_) => match test.expectation {
                TestExpectation::Success => TestStatus::Success,
                // Expected a panic but the run succeeded - that's a failure.
                TestExpectation::Panics(_) => TestStatus::Fail(result.value),
            },
            RunResultValue::Panic(value) => match test.expectation {
                TestExpectation::Success => TestStatus::Fail(result.value),
                TestExpectation::Panics(panic_expectation) => match panic_expectation {
                    // An exact panic expectation fails on any mismatching panic data.
                    PanicExpectation::Exact(expected) if value != &expected => {
                        TestStatus::Fail(result.value)
                    }
                    _ => TestStatus::Success,
                },
            },
        },
        // Gas estimate: initial available gas minus the remaining counter when both are
        // known; otherwise fall back to the function's initial required gas, if computable.
        gas_usage: test
            .available_gas
            .zip(result.gas_counter)
            .map(|(before, after)| {
                before.into_or_panic::<i64>() - after.to_bigint().to_i64().unwrap()
            })
            .or_else(|| runner.initial_required_gas(func).map(|gas| gas.into_or_panic::<i64>())),
        used_resources: result.used_resources,
        profiling_info: result.profiling_info,
    }))
}
476
/// Updates the test summary with the given test result, printing the test's status line
/// and, when requested, its resource usage and profiling info.
fn update_summary(
    summary: &mut TestsSummary,
    name: String,
    test_result: Result<Option<TestResult>>,
    profiler_data: &Option<(ProfilingInfoProcessor<'_>, ProfilingInfoProcessorParams)>,
    print_resource_usage: bool,
) {
    // Select the summary bucket to record the test in, alongside the data to print.
    let (res_type, status_str, gas_usage, used_resources, profiling_info) = match test_result {
        // `Ok(None)` marks an ignored test (see `run_single_test`).
        Ok(None) => (&mut summary.ignored, "ignored".bright_yellow(), None, None, None),
        Err(err) => {
            summary.failed_run_results.push(Err(err));
            (&mut summary.failed, "failed to run".bright_magenta(), None, None, None)
        }
        Ok(Some(result)) => {
            let (res_type, status_str) = match result.status {
                TestStatus::Success => (&mut summary.passed, "ok".bright_green()),
                TestStatus::Fail(run_result) => {
                    // Keep the failing run value for the detailed report printed at the end.
                    summary.failed_run_results.push(Ok(run_result));
                    (&mut summary.failed, "fail".bright_red())
                }
            };
            (
                res_type,
                status_str,
                result.gas_usage,
                print_resource_usage.then_some(result.used_resources),
                result.profiling_info,
            )
        }
    };
    if let Some(gas_usage) = gas_usage {
        println!("test {name} ... {status_str} (gas usage est.: {gas_usage})");
    } else {
        println!("test {name} ... {status_str}");
    }
    if let Some(used_resources) = used_resources {
        // Builtins that were never used are omitted from the report.
        let filtered = used_resources.basic_resources.filter_unused_builtins();
        // Prints the used resources per test. E.g.:
        // ```ignore
        // test cairo_level_tests::interoperability::test_contract_not_deployed ... ok (gas usage est.: 77320)
        //     steps: 42
        //     memory holes: 20
        //     builtins: ("range_check_builtin": 3)
        //     syscalls: ("CallContract": 1)
        // test cairo_level_tests::events::test_pop_log ... ok (gas usage est.: 55440)
        //     steps: 306
        //     memory holes: 35
        //     builtins: ("range_check_builtin": 24)
        //     syscalls: ("EmitEvent": 2)
        // ```
        println!("    steps: {}", filtered.n_steps);
        println!("    memory holes: {}", filtered.n_memory_holes);

        print_resource_map(
            filtered.builtin_instance_counter.into_iter().map(|(k, v)| (k.to_string(), v)),
            "builtins",
        );
        print_resource_map(used_resources.syscalls.into_iter(), "syscalls");
    }
    if let Some((profiling_processor, profiling_params)) = profiler_data {
        let processed_profiling_info = profiling_processor.process(
            &profiling_info.expect("profiling_info must be Some when profiler_config is Some"),
            profiling_params,
        );
        println!("Profiling info:\n{processed_profiling_info}");
    }
    // Finally, record the test name in the selected bucket.
    res_type.push(name);
}
546
/// Given an iterator of (String, usize) pairs, prints a usage map. E.g.:
///     syscalls: ("EmitEvent": 2)
///     syscalls: ("CallContract": 1)
///
/// Prints nothing for an empty map.
fn print_resource_map(m: impl ExactSizeIterator<Item = (String, usize)>, resource_type: &str) {
    if let Some(line) = format_resource_map(m, resource_type) {
        println!("{line}");
    }
}

/// Formats a non-empty usage map into its report line, sorted by key for deterministic
/// output. Returns `None` for an empty map.
fn format_resource_map(
    m: impl ExactSizeIterator<Item = (String, usize)>,
    resource_type: &str,
) -> Option<String> {
    if m.len() == 0 {
        return None;
    }
    // Sort entries so the printed order is stable across runs.
    // (`m` is already an iterator; the original `m.into_iter()` was a no-op.)
    let mut entries: Vec<(String, usize)> = m.collect();
    entries.sort();
    let body =
        entries.iter().map(|(k, v)| format!(r#""{k}": {v}"#)).collect::<Vec<_>>().join(", ");
    Some(format!("    {resource_type}: ({body})"))
}