1use std::path::Path;
2use std::sync::Arc;
3use std::sync::mpsc::channel;
4
5use anyhow::{Context, Result, bail};
6use cairo_lang_compiler::db::RootDatabase;
7use cairo_lang_compiler::diagnostics::DiagnosticsReporter;
8use cairo_lang_compiler::project::setup_project;
9use cairo_lang_debug::debug::DebugWithDb;
10use cairo_lang_filesystem::cfg::{Cfg, CfgSet};
11use cairo_lang_filesystem::ids::CrateInput;
12use cairo_lang_runner::casm_run::{StarknetHintProcessor, format_for_panic};
13use cairo_lang_runner::profiling::{
14 ProfilerConfig, ProfilingInfo, ProfilingInfoProcessor, ProfilingInfoProcessorParams,
15};
16use cairo_lang_runner::{
17 CairoHintProcessor, ProfilingInfoCollectionConfig, RunResultValue, RunnerError,
18 SierraCasmRunner, StarknetExecutionResources,
19};
20use cairo_lang_sierra_to_casm::metadata::MetadataComputationConfig;
21use cairo_lang_starknet::starknet_plugin_suite;
22use cairo_lang_test_plugin::test_config::{PanicExpectation, TestExpectation};
23use cairo_lang_test_plugin::{
24 TestCompilation, TestCompilationMetadata, TestConfig, TestsCompilationConfig,
25 compile_test_prepared_db, test_plugin_suite,
26};
27use cairo_lang_utils::casts::IntoOrPanic;
28use colored::Colorize;
29use itertools::Itertools;
30use num_traits::ToPrimitive;
31use rayon::prelude::{IntoParallelIterator, ParallelIterator};
32use salsa::Database;
33
34#[cfg(test)]
35mod test;
36
37type ArcCustomHintProcessorFactory = Arc<
38 dyn (for<'a> Fn(CairoHintProcessor<'a>) -> Box<dyn StarknetHintProcessor + 'a>) + Send + Sync,
39>;
40
41pub struct TestRunner<'db> {
43 compiler: TestCompiler<'db>,
44 config: TestRunConfig,
45 custom_hint_processor_factory: Option<ArcCustomHintProcessorFactory>,
46}
47
48impl<'db> TestRunner<'db> {
49 pub fn new(
59 path: &Path,
60 starknet: bool,
61 allow_warnings: bool,
62 config: TestRunConfig,
63 ) -> Result<Self> {
64 let compiler = TestCompiler::try_new(
65 path,
66 allow_warnings,
67 config.gas_enabled,
68 TestsCompilationConfig {
69 starknet,
70 add_statements_functions: config
71 .profiler_config
72 .as_ref()
73 .is_some_and(|c| c.requires_cairo_debug_info()),
74 add_statements_code_locations: false,
75 contract_declarations: None,
76 contract_crate_ids: None,
77 executable_crate_ids: None,
78 add_functions_debug_info: false,
79 },
80 )?;
81 Ok(Self { compiler, config, custom_hint_processor_factory: None })
82 }
83
84 pub fn with_custom_hint_processor(
86 &mut self,
87 factory: impl (for<'a> Fn(CairoHintProcessor<'a>) -> Box<dyn StarknetHintProcessor + 'a>)
88 + Send
89 + Sync
90 + 'static,
91 ) -> &mut Self {
92 self.custom_hint_processor_factory = Some(Arc::new(factory));
93 self
94 }
95
96 pub fn run(&self) -> Result<Option<TestsSummary>> {
98 let runner = CompiledTestRunner {
99 compiled: self.compiler.build()?,
100 config: self.config.clone(),
101 custom_hint_processor_factory: self.custom_hint_processor_factory.clone(),
102 };
103 runner.run(Some(&self.compiler.db))
104 }
105}
106
107pub struct CompiledTestRunner<'db> {
108 compiled: TestCompilation<'db>,
109 config: TestRunConfig,
110 custom_hint_processor_factory: Option<ArcCustomHintProcessorFactory>,
111}
112
113impl<'db> CompiledTestRunner<'db> {
114 pub fn new(compiled: TestCompilation<'db>, config: TestRunConfig) -> Self {
121 Self { compiled, config, custom_hint_processor_factory: None }
122 }
123
124 pub fn with_custom_hint_processor(
126 &mut self,
127 factory: impl (for<'a> Fn(CairoHintProcessor<'a>) -> Box<dyn StarknetHintProcessor + 'a>)
128 + Send
129 + Sync
130 + 'static,
131 ) -> &mut Self {
132 self.custom_hint_processor_factory = Some(Arc::new(factory));
133 self
134 }
135
136 pub fn run(self, opt_db: Option<&'db RootDatabase>) -> Result<Option<TestsSummary>> {
138 let (compiled, filtered_out) = filter_test_cases(
139 self.compiled,
140 self.config.include_ignored,
141 self.config.ignored,
142 &self.config.filter,
143 );
144
145 let TestsSummary { passed, failed, ignored, failed_run_results } = run_tests(
146 opt_db.map(|db| db as &dyn Database),
147 compiled,
148 &self.config,
149 self.custom_hint_processor_factory,
150 )?;
151
152 if failed.is_empty() {
153 println!(
154 "test result: {}. {} passed; {} failed; {} ignored; {filtered_out} filtered out;",
155 "ok".bright_green(),
156 passed.len(),
157 failed.len(),
158 ignored.len()
159 );
160 Ok(None)
161 } else {
162 println!("failures:");
163 for (failure, run_result) in failed.iter().zip_eq(failed_run_results) {
164 print!(" {failure} - ");
165 match run_result {
166 Ok(RunResultValue::Success(_)) => {
167 println!("expected panic but finished successfully.");
168 }
169 Ok(RunResultValue::Panic(values)) => {
170 println!("{}", format_for_panic(values.into_iter()));
171 }
172 Err(err) => {
173 println!("{err}");
174 }
175 }
176 }
177 println!();
178 bail!(
179 "test result: {}. {} passed; {} failed; {} ignored",
180 "FAILED".bright_red(),
181 passed.len(),
182 failed.len(),
183 ignored.len()
184 );
185 }
186 }
187}
188
189#[derive(Clone, Debug)]
191pub struct TestRunConfig {
192 pub filter: String,
193 pub include_ignored: bool,
194 pub ignored: bool,
195 pub profiler_config: Option<ProfilerConfig>,
197 pub gas_enabled: bool,
199 pub print_resource_usage: bool,
201}
202
203pub struct TestCompiler<'db> {
205 pub db: RootDatabase,
206 pub main_crate_ids: Vec<CrateInput>,
207 pub test_crate_ids: Vec<CrateInput>,
208 pub allow_warnings: bool,
209 pub config: TestsCompilationConfig<'db>,
210}
211
212impl<'db> TestCompiler<'db> {
213 pub fn try_new(
220 path: &Path,
221 allow_warnings: bool,
222 gas_enabled: bool,
223 config: TestsCompilationConfig<'db>,
224 ) -> Result<Self> {
225 let mut db = {
226 let mut b = RootDatabase::builder();
227 let mut cfg = CfgSet::from_iter([Cfg::name("test"), Cfg::kv("target", "test")]);
228 if !gas_enabled {
229 cfg.insert(Cfg::kv("gas", "disabled"));
230 b.skip_auto_withdraw_gas();
231 }
232 b.detect_corelib();
233 b.with_cfg(cfg);
234 b.with_default_plugin_suite(test_plugin_suite());
235 if config.starknet {
236 b.with_default_plugin_suite(starknet_plugin_suite());
237 }
238 b.build()?
239 };
240
241 let main_crate_inputs = setup_project(&mut db, Path::new(&path))?;
242
243 Ok(Self {
244 db: db.snapshot(),
245 test_crate_ids: main_crate_inputs.clone(),
246 main_crate_ids: main_crate_inputs,
247 allow_warnings,
248 config,
249 })
250 }
251
252 pub fn build(&self) -> Result<TestCompilation<'_>> {
254 let mut diag_reporter = DiagnosticsReporter::stderr().with_crates(&self.main_crate_ids);
255 if self.allow_warnings {
256 diag_reporter = diag_reporter.allow_warnings();
257 }
258
259 compile_test_prepared_db(
260 &self.db,
261 self.config.clone(),
262 self.test_crate_ids.clone(),
263 diag_reporter,
264 )
265 }
266}
267
268pub fn filter_test_cases<'db>(
278 compiled: TestCompilation<'db>,
279 include_ignored: bool,
280 ignored: bool,
281 filter: &str,
282) -> (TestCompilation<'db>, usize) {
283 let total_tests_count = compiled.metadata.named_tests.len();
284 let named_tests = compiled
285 .metadata
286 .named_tests
287 .into_iter()
288 .filter(|(_, test)| !ignored || test.ignored || include_ignored)
290 .map(|(func, mut test)| {
291 if include_ignored || ignored {
293 test.ignored = false;
294 }
295 (func, test)
296 })
297 .filter(|(name, _)| name.contains(filter))
298 .collect_vec();
299 let filtered_out = total_tests_count - named_tests.len();
300 let tests = TestCompilation {
301 sierra_program: compiled.sierra_program,
302 metadata: TestCompilationMetadata { named_tests, ..compiled.metadata },
303 };
304 (tests, filtered_out)
305}
306
307enum TestStatus {
309 Success,
310 Fail(RunResultValue),
311}
312
313struct TestResult {
315 status: TestStatus,
317 gas_usage: Option<i64>,
319 used_resources: StarknetExecutionResources,
321 profiling_info: Option<ProfilingInfo>,
323}
324
325pub struct TestsSummary {
327 passed: Vec<String>,
328 failed: Vec<String>,
329 ignored: Vec<String>,
330 failed_run_results: Vec<Result<RunResultValue>>,
331}
332
333pub fn run_tests(
335 opt_db: Option<&dyn Database>,
336 compiled: TestCompilation<'_>,
337 config: &TestRunConfig,
338 custom_hint_processor_factory: Option<ArcCustomHintProcessorFactory>,
339) -> Result<TestsSummary> {
340 let TestCompilation {
341 sierra_program: sierra_program_with_debug_info,
342 metadata:
343 TestCompilationMetadata {
344 named_tests,
345 function_set_costs,
346 contracts_info,
347 statements_locations,
348 },
349 } = compiled;
350 let sierra_program = sierra_program_with_debug_info.program;
351 let runner = SierraCasmRunner::new(
352 sierra_program.clone(),
353 if config.gas_enabled {
354 Some(MetadataComputationConfig {
355 function_set_costs,
356 linear_gas_solver: true,
357 linear_ap_change_solver: true,
358 skip_non_linear_solver_comparisons: false,
359 compute_runtime_costs: false,
360 })
361 } else {
362 None
363 },
364 contracts_info,
365 config.profiler_config.as_ref().map(ProfilingInfoCollectionConfig::from_profiler_config),
366 )
367 .map_err(|err| {
368 let (RunnerError::BuildError(err), Some(db), Some(statements_locations)) =
369 (&err, opt_db, &statements_locations)
370 else {
371 return err.into();
372 };
373 let mut locs = vec![];
374 for stmt_idx in err.stmt_indices() {
375 if let Some(loc) = statements_locations.statement_diagnostic_location(db, stmt_idx) {
376 locs.push(format!("#{stmt_idx} {:?}", loc.debug(db)))
377 }
378 }
379 anyhow::anyhow!("{err}\n{}", locs.join("\n"))
380 })
381 .with_context(|| "Failed setting up runner.")?;
382 let suffix = if named_tests.len() != 1 { "s" } else { "" };
383 println!("running {} test{}", named_tests.len(), suffix);
384
385 let (tx, rx) = channel::<_>();
386 rayon::spawn(move || {
387 named_tests.into_par_iter().for_each(|(name, test)| {
388 let result =
389 run_single_test(test, &name, &runner, custom_hint_processor_factory.clone());
390 tx.send((name, result)).unwrap();
391 })
392 });
393
394 let profiler_data = config.profiler_config.as_ref().and_then(|profiler_config| {
395 if !profiler_config.requires_cairo_debug_info() {
396 return None;
397 }
398
399 let db = opt_db.expect("db must be passed when doing cairo level profiling.");
400 Some((
401 ProfilingInfoProcessor::new(
402 Some(db),
403 &sierra_program,
404 statements_locations
405 .expect(
406 "statements locations must be present when doing cairo level profiling.",
407 )
408 .get_statements_functions_map_for_tests(db),
409 ),
410 ProfilingInfoProcessorParams::from_profiler_config(profiler_config),
411 ))
412 });
413
414 let mut summary = TestsSummary {
415 passed: vec![],
416 failed: vec![],
417 ignored: vec![],
418 failed_run_results: vec![],
419 };
420 while let Ok((name, result)) = rx.recv() {
421 update_summary(&mut summary, name, result, &profiler_data, config.print_resource_usage);
422 }
423
424 Ok(summary)
425}
426
427fn run_single_test(
429 test: TestConfig,
430 name: &str,
431 runner: &SierraCasmRunner,
432 custom_hint_processor_factory: Option<ArcCustomHintProcessorFactory>,
433) -> Result<Option<TestResult>> {
434 if test.ignored {
435 return Ok(None);
436 }
437 let func = runner.find_function(name)?;
438
439 let (hint_processor, ctx) =
440 runner.prepare_starknet_context(func, vec![], test.available_gas, Default::default())?;
441
442 let mut hint_processor = match custom_hint_processor_factory {
443 Some(f) => f(hint_processor),
444 None => Box::new(hint_processor),
445 };
446
447 let result =
448 runner.run_function_with_prepared_starknet_context(func, &mut *hint_processor, ctx)?;
449
450 Ok(Some(TestResult {
451 status: match &result.value {
452 RunResultValue::Success(_) => match test.expectation {
453 TestExpectation::Success => TestStatus::Success,
454 TestExpectation::Panics(_) => TestStatus::Fail(result.value),
455 },
456 RunResultValue::Panic(value) => match test.expectation {
457 TestExpectation::Success => TestStatus::Fail(result.value),
458 TestExpectation::Panics(panic_expectation) => match panic_expectation {
459 PanicExpectation::Exact(expected) if value != &expected => {
460 TestStatus::Fail(result.value)
461 }
462 _ => TestStatus::Success,
463 },
464 },
465 },
466 gas_usage: test
467 .available_gas
468 .zip(result.gas_counter)
469 .map(|(before, after)| {
470 before.into_or_panic::<i64>() - after.to_bigint().to_i64().unwrap()
471 })
472 .or_else(|| runner.initial_required_gas(func).map(|gas| gas.into_or_panic::<i64>())),
473 used_resources: result.used_resources,
474 profiling_info: result.profiling_info,
475 }))
476}
477
478fn update_summary(
480 summary: &mut TestsSummary,
481 name: String,
482 test_result: Result<Option<TestResult>>,
483 profiler_data: &Option<(ProfilingInfoProcessor<'_>, ProfilingInfoProcessorParams)>,
484 print_resource_usage: bool,
485) {
486 let (res_type, status_str, gas_usage, used_resources, profiling_info) = match test_result {
487 Ok(None) => (&mut summary.ignored, "ignored".bright_yellow(), None, None, None),
488 Err(err) => {
489 summary.failed_run_results.push(Err(err));
490 (&mut summary.failed, "failed to run".bright_magenta(), None, None, None)
491 }
492 Ok(Some(result)) => {
493 let (res_type, status_str) = match result.status {
494 TestStatus::Success => (&mut summary.passed, "ok".bright_green()),
495 TestStatus::Fail(run_result) => {
496 summary.failed_run_results.push(Ok(run_result));
497 (&mut summary.failed, "fail".bright_red())
498 }
499 };
500 (
501 res_type,
502 status_str,
503 result.gas_usage,
504 print_resource_usage.then_some(result.used_resources),
505 result.profiling_info,
506 )
507 }
508 };
509 if let Some(gas_usage) = gas_usage {
510 println!("test {name} ... {status_str} (gas usage est.: {gas_usage})");
511 } else {
512 println!("test {name} ... {status_str}");
513 }
514 if let Some(used_resources) = used_resources {
515 let filtered = used_resources.basic_resources.filter_unused_builtins();
516 println!(" steps: {}", filtered.n_steps);
530 println!(" memory holes: {}", filtered.n_memory_holes);
531
532 print_resource_map(
533 filtered.builtin_instance_counter.into_iter().map(|(k, v)| (k.to_string(), v)),
534 "builtins",
535 );
536 print_resource_map(used_resources.syscalls.into_iter(), "syscalls");
537 }
538 if let Some((profiling_processor, profiling_params)) = profiler_data {
539 let processed_profiling_info = profiling_processor.process(
540 &profiling_info.expect("profiling_info must be Some when profiler_config is Some"),
541 profiling_params,
542 );
543 println!("Profiling info:\n{processed_profiling_info}");
544 }
545 res_type.push(name);
546}
547
548fn print_resource_map(m: impl ExactSizeIterator<Item = (String, usize)>, resource_type: &str) {
552 if m.len() != 0 {
553 println!(
554 " {resource_type}: ({})",
555 m.into_iter().sorted().map(|(k, v)| format!(r#""{k}": {v}"#)).join(", ")
556 );
557 }
558}