1use std::path::Path;
2use std::sync::Arc;
3use std::sync::mpsc::channel;
4
5use anyhow::{Context, Result, bail};
6use cairo_lang_compiler::db::RootDatabase;
7use cairo_lang_compiler::diagnostics::DiagnosticsReporter;
8use cairo_lang_compiler::project::setup_project;
9use cairo_lang_debug::debug::DebugWithDb;
10use cairo_lang_filesystem::cfg::{Cfg, CfgSet};
11use cairo_lang_filesystem::ids::CrateInput;
12use cairo_lang_runner::casm_run::{StarknetHintProcessor, format_for_panic};
13use cairo_lang_runner::profiling::{
14 ProfilerConfig, ProfilingInfo, ProfilingInfoProcessor, ProfilingInfoProcessorParams,
15};
16use cairo_lang_runner::{
17 CairoHintProcessor, ProfilingInfoCollectionConfig, RunResultValue, RunnerError,
18 SierraCasmRunner, StarknetExecutionResources,
19};
20use cairo_lang_sierra_to_casm::metadata::MetadataComputationConfig;
21use cairo_lang_starknet::starknet_plugin_suite;
22use cairo_lang_test_plugin::test_config::{PanicExpectation, TestExpectation};
23use cairo_lang_test_plugin::{
24 TestCompilation, TestCompilationMetadata, TestConfig, TestsCompilationConfig,
25 compile_test_prepared_db, test_plugin_suite,
26};
27use cairo_lang_utils::casts::IntoOrPanic;
28use colored::Colorize;
29use itertools::Itertools;
30use num_traits::ToPrimitive;
31use rayon::prelude::{IntoParallelIterator, ParallelIterator};
32use salsa::Database;
33
34#[cfg(test)]
35mod test;
36
/// Shared factory that turns the [`CairoHintProcessor`] prepared for a test into the
/// [`StarknetHintProcessor`] actually used to run it (e.g. to wrap it with custom hints).
///
/// It is stored in an `Arc` so it can be cloned for every test, and is `Send + Sync` because
/// tests are executed on the rayon thread pool.
type ArcCustomHintProcessorFactory = Arc<
    dyn (for<'a> Fn(CairoHintProcessor<'a>) -> Box<dyn StarknetHintProcessor + 'a>) + Send + Sync,
>;
40
41pub struct TestRunner<'db> {
43 compiler: TestCompiler<'db>,
44 config: TestRunConfig,
45 custom_hint_processor_factory: Option<ArcCustomHintProcessorFactory>,
46}
47
48impl<'db> TestRunner<'db> {
49 pub fn new(
59 path: &Path,
60 starknet: bool,
61 allow_warnings: bool,
62 config: TestRunConfig,
63 ) -> Result<Self> {
64 let compiler = TestCompiler::try_new(
65 path,
66 allow_warnings,
67 config.gas_enabled,
68 TestsCompilationConfig {
69 starknet,
70 add_statements_functions: config
71 .profiler_config
72 .as_ref()
73 .is_some_and(|c| c.requires_cairo_debug_info()),
74 add_statements_code_locations: false,
75 contract_declarations: None,
76 contract_crate_ids: None,
77 executable_crate_ids: None,
78 },
79 )?;
80 Ok(Self { compiler, config, custom_hint_processor_factory: None })
81 }
82
83 pub fn with_custom_hint_processor(
85 &mut self,
86 factory: impl (for<'a> Fn(CairoHintProcessor<'a>) -> Box<dyn StarknetHintProcessor + 'a>)
87 + Send
88 + Sync
89 + 'static,
90 ) -> &mut Self {
91 self.custom_hint_processor_factory = Some(Arc::new(factory));
92 self
93 }
94
95 pub fn run(&self) -> Result<Option<TestsSummary>> {
97 let runner = CompiledTestRunner {
98 compiled: self.compiler.build()?,
99 config: self.config.clone(),
100 custom_hint_processor_factory: self.custom_hint_processor_factory.clone(),
101 };
102 runner.run(Some(&self.compiler.db))
103 }
104}
105
106pub struct CompiledTestRunner<'db> {
107 compiled: TestCompilation<'db>,
108 config: TestRunConfig,
109 custom_hint_processor_factory: Option<ArcCustomHintProcessorFactory>,
110}
111
112impl<'db> CompiledTestRunner<'db> {
113 pub fn new(compiled: TestCompilation<'db>, config: TestRunConfig) -> Self {
120 Self { compiled, config, custom_hint_processor_factory: None }
121 }
122
123 pub fn with_custom_hint_processor(
125 &mut self,
126 factory: impl (for<'a> Fn(CairoHintProcessor<'a>) -> Box<dyn StarknetHintProcessor + 'a>)
127 + Send
128 + Sync
129 + 'static,
130 ) -> &mut Self {
131 self.custom_hint_processor_factory = Some(Arc::new(factory));
132 self
133 }
134
135 pub fn run(self, opt_db: Option<&'db RootDatabase>) -> Result<Option<TestsSummary>> {
137 let (compiled, filtered_out) = filter_test_cases(
138 self.compiled,
139 self.config.include_ignored,
140 self.config.ignored,
141 &self.config.filter,
142 );
143
144 let TestsSummary { passed, failed, ignored, failed_run_results } = run_tests(
145 opt_db.map(|db| db as &dyn Database),
146 compiled,
147 &self.config,
148 self.custom_hint_processor_factory,
149 )?;
150
151 if failed.is_empty() {
152 println!(
153 "test result: {}. {} passed; {} failed; {} ignored; {filtered_out} filtered out;",
154 "ok".bright_green(),
155 passed.len(),
156 failed.len(),
157 ignored.len()
158 );
159 Ok(None)
160 } else {
161 println!("failures:");
162 for (failure, run_result) in failed.iter().zip_eq(failed_run_results) {
163 print!(" {failure} - ");
164 match run_result {
165 Ok(RunResultValue::Success(_)) => {
166 println!("expected panic but finished successfully.");
167 }
168 Ok(RunResultValue::Panic(values)) => {
169 println!("{}", format_for_panic(values.into_iter()));
170 }
171 Err(err) => {
172 println!("{err}");
173 }
174 }
175 }
176 println!();
177 bail!(
178 "test result: {}. {} passed; {} failed; {} ignored",
179 "FAILED".bright_red(),
180 passed.len(),
181 failed.len(),
182 ignored.len()
183 );
184 }
185 }
186}
187
/// Configuration controlling which tests are run and what is reported for them.
#[derive(Clone, Debug)]
pub struct TestRunConfig {
    /// Only tests whose name contains this substring are run.
    pub filter: String,
    /// Run ignored tests in addition to the non-ignored ones.
    pub include_ignored: bool,
    /// Run only the ignored tests (has no narrowing effect when `include_ignored` is set).
    pub ignored: bool,
    /// Profiler configuration; `None` disables profiling-info collection.
    pub profiler_config: Option<ProfilerConfig>,
    /// Whether gas is tracked. When `false`, tests are compiled with the `gas: disabled` cfg
    /// and without automatic gas withdrawal, and no gas metadata is computed for the runner.
    pub gas_enabled: bool,
    /// Whether to print each test's resource usage (steps, memory holes, builtins, syscalls).
    pub print_resource_usage: bool,
}
201
202pub struct TestCompiler<'db> {
204 pub db: RootDatabase,
205 pub main_crate_ids: Vec<CrateInput>,
206 pub test_crate_ids: Vec<CrateInput>,
207 pub allow_warnings: bool,
208 pub config: TestsCompilationConfig<'db>,
209}
210
211impl<'db> TestCompiler<'db> {
212 pub fn try_new(
219 path: &Path,
220 allow_warnings: bool,
221 gas_enabled: bool,
222 config: TestsCompilationConfig<'db>,
223 ) -> Result<Self> {
224 let mut db = {
225 let mut b = RootDatabase::builder();
226 let mut cfg = CfgSet::from_iter([Cfg::name("test"), Cfg::kv("target", "test")]);
227 if !gas_enabled {
228 cfg.insert(Cfg::kv("gas", "disabled"));
229 b.skip_auto_withdraw_gas();
230 }
231 b.detect_corelib();
232 b.with_cfg(cfg);
233 b.with_default_plugin_suite(test_plugin_suite());
234 if config.starknet {
235 b.with_default_plugin_suite(starknet_plugin_suite());
236 }
237 b.build()?
238 };
239
240 let main_crate_inputs = setup_project(&mut db, Path::new(&path))?;
241
242 Ok(Self {
243 db: db.snapshot(),
244 test_crate_ids: main_crate_inputs.clone(),
245 main_crate_ids: main_crate_inputs,
246 allow_warnings,
247 config,
248 })
249 }
250
251 pub fn build(&self) -> Result<TestCompilation<'_>> {
253 let mut diag_reporter = DiagnosticsReporter::stderr().with_crates(&self.main_crate_ids);
254 if self.allow_warnings {
255 diag_reporter = diag_reporter.allow_warnings();
256 }
257
258 compile_test_prepared_db(
259 &self.db,
260 self.config.clone(),
261 self.test_crate_ids.clone(),
262 diag_reporter,
263 )
264 }
265}
266
267pub fn filter_test_cases<'db>(
277 compiled: TestCompilation<'db>,
278 include_ignored: bool,
279 ignored: bool,
280 filter: &str,
281) -> (TestCompilation<'db>, usize) {
282 let total_tests_count = compiled.metadata.named_tests.len();
283 let named_tests = compiled
284 .metadata
285 .named_tests
286 .into_iter()
287 .filter(|(_, test)| !ignored || test.ignored || include_ignored)
289 .map(|(func, mut test)| {
290 if include_ignored || ignored {
292 test.ignored = false;
293 }
294 (func, test)
295 })
296 .filter(|(name, _)| name.contains(filter))
297 .collect_vec();
298 let filtered_out = total_tests_count - named_tests.len();
299 let tests = TestCompilation {
300 sierra_program: compiled.sierra_program,
301 metadata: TestCompilationMetadata { named_tests, ..compiled.metadata },
302 };
303 (tests, filtered_out)
304}
305
/// Outcome of a test that was actually executed.
enum TestStatus {
    /// The run matched the test's expectation.
    Success,
    /// The run did not match the expectation; holds the value the run produced.
    Fail(RunResultValue),
}
311
/// Result of executing a single test.
struct TestResult {
    /// Whether the run matched the test's expectation.
    status: TestStatus,
    /// Estimated gas usage: available gas minus the remaining gas counter when both are known,
    /// otherwise the function's initial required gas (if any).
    gas_usage: Option<i64>,
    /// Execution resources used by the run.
    used_resources: StarknetExecutionResources,
    /// Profiling info reported by the runner for this run, if any.
    profiling_info: Option<ProfilingInfo>,
}
323
/// Aggregated outcome of a test run.
pub struct TestsSummary {
    /// Names of the tests that passed.
    passed: Vec<String>,
    /// Names of the tests that failed or could not be run.
    failed: Vec<String>,
    /// Names of the tests that were skipped as ignored.
    ignored: Vec<String>,
    /// One entry per `failed` test, in the same order: `Ok(value)` when the test ran but did not
    /// meet its expectation, `Err` when it could not be run.
    failed_run_results: Vec<Result<RunResultValue>>,
}
331
/// Runs the compiled tests in parallel and collects their results into a [`TestsSummary`].
///
/// A single [`SierraCasmRunner`] is built for the whole program (with gas metadata only when
/// `config.gas_enabled`), then every named test is run on the rayon thread pool. Results are
/// sent back over a channel and printed/recorded on the calling thread as they arrive, so the
/// output order is the completion order rather than the declaration order.
///
/// `opt_db` is used to map failing statement indices to source locations when the runner fails
/// to build. When the profiler configuration requires Cairo debug info, both `opt_db` and the
/// statement locations in `compiled` must be present, or this function panics.
///
/// # Errors
/// Returns an error if the runner could not be set up.
pub fn run_tests(
    opt_db: Option<&dyn Database>,
    compiled: TestCompilation<'_>,
    config: &TestRunConfig,
    custom_hint_processor_factory: Option<ArcCustomHintProcessorFactory>,
) -> Result<TestsSummary> {
    let TestCompilation {
        sierra_program: sierra_program_with_debug_info,
        metadata:
            TestCompilationMetadata {
                named_tests,
                function_set_costs,
                contracts_info,
                statements_locations,
            },
    } = compiled;
    let sierra_program = sierra_program_with_debug_info.program;
    let runner = SierraCasmRunner::new(
        sierra_program.clone(),
        if config.gas_enabled {
            Some(MetadataComputationConfig {
                function_set_costs,
                linear_gas_solver: true,
                linear_ap_change_solver: true,
                skip_non_linear_solver_comparisons: false,
                compute_runtime_costs: false,
            })
        } else {
            None
        },
        contracts_info,
        config.profiler_config.as_ref().map(ProfilingInfoCollectionConfig::from_profiler_config),
    )
    .map_err(|err| {
        // Only build errors can be enriched with source locations, and only when both a db and
        // statement locations are available; anything else is converted as-is.
        let (RunnerError::BuildError(err), Some(db), Some(statements_locations)) =
            (&err, opt_db, &statements_locations)
        else {
            return err.into();
        };
        let mut locs = vec![];
        for stmt_idx in err.stmt_indices() {
            if let Some(loc) = statements_locations.statement_diagnostic_location(db, stmt_idx) {
                locs.push(format!("#{stmt_idx} {:?}", loc.debug(db)))
            }
        }
        anyhow::anyhow!("{err}\n{}", locs.join("\n"))
    })
    .with_context(|| "Failed setting up runner.")?;
    let suffix = if named_tests.len() != 1 { "s" } else { "" };
    println!("running {} test{}", named_tests.len(), suffix);

    // Tests run on the rayon pool; each result is sent back as soon as it is ready. The
    // receiving loop below ends once every sender clone (owned by the spawned task) is dropped.
    let (tx, rx) = channel::<_>();
    rayon::spawn(move || {
        named_tests.into_par_iter().for_each(|(name, test)| {
            let result =
                run_single_test(test, &name, &runner, custom_hint_processor_factory.clone());
            tx.send((name, result)).unwrap();
        })
    });

    // Built while the tests are already running. Only needed for profiling modes that require
    // Cairo-level debug info; other modes skip processing entirely.
    let profiler_data = config.profiler_config.as_ref().and_then(|profiler_config| {
        if !profiler_config.requires_cairo_debug_info() {
            return None;
        }

        let db = opt_db.expect("db must be passed when doing cairo level profiling.");
        Some((
            ProfilingInfoProcessor::new(
                Some(db),
                &sierra_program,
                statements_locations
                    .expect(
                        "statements locations must be present when doing cairo level profiling.",
                    )
                    .get_statements_functions_map_for_tests(db),
            ),
            ProfilingInfoProcessorParams::from_profiler_config(profiler_config),
        ))
    });

    let mut summary = TestsSummary {
        passed: vec![],
        failed: vec![],
        ignored: vec![],
        failed_run_results: vec![],
    };
    while let Ok((name, result)) = rx.recv() {
        update_summary(&mut summary, name, result, &profiler_data, config.print_resource_usage);
    }

    Ok(summary)
}
425
426fn run_single_test(
428 test: TestConfig,
429 name: &str,
430 runner: &SierraCasmRunner,
431 custom_hint_processor_factory: Option<ArcCustomHintProcessorFactory>,
432) -> Result<Option<TestResult>> {
433 if test.ignored {
434 return Ok(None);
435 }
436 let func = runner.find_function(name)?;
437
438 let (hint_processor, ctx) =
439 runner.prepare_starknet_context(func, vec![], test.available_gas, Default::default())?;
440
441 let mut hint_processor = match custom_hint_processor_factory {
442 Some(f) => f(hint_processor),
443 None => Box::new(hint_processor),
444 };
445
446 let result =
447 runner.run_function_with_prepared_starknet_context(func, &mut *hint_processor, ctx)?;
448
449 Ok(Some(TestResult {
450 status: match &result.value {
451 RunResultValue::Success(_) => match test.expectation {
452 TestExpectation::Success => TestStatus::Success,
453 TestExpectation::Panics(_) => TestStatus::Fail(result.value),
454 },
455 RunResultValue::Panic(value) => match test.expectation {
456 TestExpectation::Success => TestStatus::Fail(result.value),
457 TestExpectation::Panics(panic_expectation) => match panic_expectation {
458 PanicExpectation::Exact(expected) if value != &expected => {
459 TestStatus::Fail(result.value)
460 }
461 _ => TestStatus::Success,
462 },
463 },
464 },
465 gas_usage: test
466 .available_gas
467 .zip(result.gas_counter)
468 .map(|(before, after)| {
469 before.into_or_panic::<i64>() - after.to_bigint().to_i64().unwrap()
470 })
471 .or_else(|| runner.initial_required_gas(func).map(|gas| gas.into_or_panic::<i64>())),
472 used_resources: result.used_resources,
473 profiling_info: result.profiling_info,
474 }))
475}
476
477fn update_summary(
479 summary: &mut TestsSummary,
480 name: String,
481 test_result: Result<Option<TestResult>>,
482 profiler_data: &Option<(ProfilingInfoProcessor<'_>, ProfilingInfoProcessorParams)>,
483 print_resource_usage: bool,
484) {
485 let (res_type, status_str, gas_usage, used_resources, profiling_info) = match test_result {
486 Ok(None) => (&mut summary.ignored, "ignored".bright_yellow(), None, None, None),
487 Err(err) => {
488 summary.failed_run_results.push(Err(err));
489 (&mut summary.failed, "failed to run".bright_magenta(), None, None, None)
490 }
491 Ok(Some(result)) => {
492 let (res_type, status_str) = match result.status {
493 TestStatus::Success => (&mut summary.passed, "ok".bright_green()),
494 TestStatus::Fail(run_result) => {
495 summary.failed_run_results.push(Ok(run_result));
496 (&mut summary.failed, "fail".bright_red())
497 }
498 };
499 (
500 res_type,
501 status_str,
502 result.gas_usage,
503 print_resource_usage.then_some(result.used_resources),
504 result.profiling_info,
505 )
506 }
507 };
508 if let Some(gas_usage) = gas_usage {
509 println!("test {name} ... {status_str} (gas usage est.: {gas_usage})");
510 } else {
511 println!("test {name} ... {status_str}");
512 }
513 if let Some(used_resources) = used_resources {
514 let filtered = used_resources.basic_resources.filter_unused_builtins();
515 println!(" steps: {}", filtered.n_steps);
529 println!(" memory holes: {}", filtered.n_memory_holes);
530
531 print_resource_map(
532 filtered.builtin_instance_counter.into_iter().map(|(k, v)| (k.to_string(), v)),
533 "builtins",
534 );
535 print_resource_map(used_resources.syscalls.into_iter(), "syscalls");
536 }
537 if let Some((profiling_processor, profiling_params)) = profiler_data {
538 let processed_profiling_info = profiling_processor.process(
539 &profiling_info.expect("profiling_info must be Some when profiler_config is Some"),
540 profiling_params,
541 );
542 println!("Profiling info:\n{processed_profiling_info}");
543 }
544 res_type.push(name);
545}
546
/// Prints ` <resource_type>: ("k1": v1, "k2": v2, ...)` with the entries sorted, or nothing
/// at all when there are no entries.
fn print_resource_map(m: impl ExactSizeIterator<Item = (String, usize)>, resource_type: &str) {
    if m.len() == 0 {
        return;
    }
    let mut entries: Vec<(String, usize)> = m.collect();
    entries.sort();
    let rendered: Vec<String> =
        entries.iter().map(|(k, v)| format!(r#""{k}": {v}"#)).collect();
    println!(" {resource_type}: ({})", rendered.join(", "));
}