1use std::path::Path;
2use std::sync::Arc;
3use std::sync::mpsc::channel;
4
5use anyhow::{Context, Result, bail};
6use cairo_lang_compiler::db::RootDatabase;
7use cairo_lang_compiler::diagnostics::DiagnosticsReporter;
8use cairo_lang_compiler::project::setup_project;
9use cairo_lang_debug::debug::DebugWithDb;
10use cairo_lang_filesystem::cfg::{Cfg, CfgSet};
11use cairo_lang_filesystem::ids::CrateInput;
12use cairo_lang_runner::casm_run::{StarknetHintProcessor, format_for_panic};
13use cairo_lang_runner::profiling::{
14 ProfilerConfig, ProfilingInfo, ProfilingInfoProcessor, ProfilingInfoProcessorParams,
15};
16use cairo_lang_runner::{
17 CairoHintProcessor, ProfilingInfoCollectionConfig, RunResultValue, RunnerError,
18 SierraCasmRunner, StarknetExecutionResources,
19};
20use cairo_lang_sierra_to_casm::metadata::MetadataComputationConfig;
21use cairo_lang_starknet::starknet_plugin_suite;
22use cairo_lang_test_plugin::test_config::{PanicExpectation, TestExpectation};
23use cairo_lang_test_plugin::{
24 TestCompilation, TestCompilationMetadata, TestConfig, TestsCompilationConfig,
25 compile_test_prepared_db, test_plugin_suite,
26};
27use cairo_lang_utils::casts::IntoOrPanic;
28use colored::Colorize;
29use itertools::Itertools;
30use num_traits::ToPrimitive;
31use rayon::prelude::{IntoParallelIterator, ParallelIterator};
32use salsa::Database;
33
34#[cfg(test)]
35mod test;
36
/// Shared factory that turns the default [`CairoHintProcessor`] prepared for a test into the
/// hint processor actually used to run that test.
///
/// It is wrapped in an `Arc` and bounded by `Send + Sync` because it is cloned into, and invoked
/// from, every test executed on the parallel (rayon) test pool.
type ArcCustomHintProcessorFactory = Arc<
    dyn (for<'a> Fn(CairoHintProcessor<'a>) -> Box<dyn StarknetHintProcessor + 'a>) + Send + Sync,
>;
40
/// Compiles the tests of a Cairo project and runs them.
pub struct TestRunner<'db> {
    /// Compiler holding the prepared database and the crate selection.
    compiler: TestCompiler<'db>,
    /// Filtering, profiling, gas and reporting options for the run.
    config: TestRunConfig,
    /// Optional factory used to build the hint processor of every test instead of the default.
    custom_hint_processor_factory: Option<ArcCustomHintProcessorFactory>,
}
47
48impl<'db> TestRunner<'db> {
49 pub fn new(
59 path: &Path,
60 starknet: bool,
61 allow_warnings: bool,
62 config: TestRunConfig,
63 ) -> Result<Self> {
64 let compiler = TestCompiler::try_new(
65 path,
66 allow_warnings,
67 config.gas_enabled,
68 TestsCompilationConfig {
69 starknet,
70 add_statements_functions: config
71 .profiler_config
72 .as_ref()
73 .is_some_and(|c| c.requires_cairo_debug_info()),
74 add_statements_code_locations: false,
75 contract_declarations: None,
76 contract_crate_ids: None,
77 executable_crate_ids: None,
78 add_functions_debug_info: false,
79 replace_ids: false,
80 },
81 )?;
82 Ok(Self { compiler, config, custom_hint_processor_factory: None })
83 }
84
85 pub fn with_custom_hint_processor(
87 &mut self,
88 factory: impl (for<'a> Fn(CairoHintProcessor<'a>) -> Box<dyn StarknetHintProcessor + 'a>)
89 + Send
90 + Sync
91 + 'static,
92 ) -> &mut Self {
93 self.custom_hint_processor_factory = Some(Arc::new(factory));
94 self
95 }
96
97 pub fn run(&self) -> Result<Option<TestsSummary>> {
99 let runner = CompiledTestRunner {
100 compiled: self.compiler.build()?,
101 config: self.config.clone(),
102 custom_hint_processor_factory: self.custom_hint_processor_factory.clone(),
103 };
104 runner.run(Some(&self.compiler.db))
105 }
106}
107
/// Runs an already compiled set of tests.
pub struct CompiledTestRunner<'db> {
    /// The compiled tests to filter and run.
    compiled: TestCompilation<'db>,
    /// Filtering, profiling, gas and reporting options for the run.
    config: TestRunConfig,
    /// Optional factory used to build the hint processor of every test instead of the default.
    custom_hint_processor_factory: Option<ArcCustomHintProcessorFactory>,
}
113
114impl<'db> CompiledTestRunner<'db> {
115 pub fn new(compiled: TestCompilation<'db>, config: TestRunConfig) -> Self {
122 Self { compiled, config, custom_hint_processor_factory: None }
123 }
124
125 pub fn with_custom_hint_processor(
127 &mut self,
128 factory: impl (for<'a> Fn(CairoHintProcessor<'a>) -> Box<dyn StarknetHintProcessor + 'a>)
129 + Send
130 + Sync
131 + 'static,
132 ) -> &mut Self {
133 self.custom_hint_processor_factory = Some(Arc::new(factory));
134 self
135 }
136
137 pub fn run(self, opt_db: Option<&'db RootDatabase>) -> Result<Option<TestsSummary>> {
139 let (compiled, filtered_out) = filter_test_cases(
140 self.compiled,
141 self.config.include_ignored,
142 self.config.ignored,
143 &self.config.filter,
144 );
145
146 let TestsSummary { passed, failed, ignored, failed_run_results } = run_tests(
147 opt_db.map(|db| db as &dyn Database),
148 compiled,
149 &self.config,
150 self.custom_hint_processor_factory,
151 )?;
152
153 if failed.is_empty() {
154 println!(
155 "test result: {}. {} passed; {} failed; {} ignored; {filtered_out} filtered out;",
156 "ok".bright_green(),
157 passed.len(),
158 failed.len(),
159 ignored.len()
160 );
161 Ok(None)
162 } else {
163 println!("failures:");
164 for (failure, run_result) in failed.iter().zip_eq(failed_run_results) {
165 print!(" {failure} - ");
166 match run_result {
167 Ok(RunResultValue::Success(_)) => {
168 println!("expected panic but finished successfully.");
169 }
170 Ok(RunResultValue::Panic(values)) => {
171 println!("{}", format_for_panic(values.into_iter()));
172 }
173 Err(err) => {
174 println!("{err}");
175 }
176 }
177 }
178 println!();
179 bail!(
180 "test result: {}. {} passed; {} failed; {} ignored",
181 "FAILED".bright_red(),
182 passed.len(),
183 failed.len(),
184 ignored.len()
185 );
186 }
187 }
188}
189
/// Configuration of a test run.
#[derive(Clone, Debug)]
pub struct TestRunConfig {
    /// Only tests whose name contains this string are run.
    pub filter: String,
    /// Whether ignored tests are run as well.
    pub include_ignored: bool,
    /// Whether only the ignored tests are run (all tests are kept if `include_ignored` is set).
    pub ignored: bool,
    /// Profiler configuration. `None` disables profiling info collection.
    pub profiler_config: Option<ProfilerConfig>,
    /// Whether gas is enabled. When it is not, the `gas` cfg is set to `disabled` and automatic
    /// gas withdrawal is skipped during compilation.
    pub gas_enabled: bool,
    /// Whether the resources used by each test are printed after its result line.
    pub print_resource_usage: bool,
}
203
/// Compiles the tests of a project from a prepared compilation database.
pub struct TestCompiler<'db> {
    /// Database the project was set up in.
    pub db: RootDatabase,
    /// Crates whose diagnostics are reported when compiling.
    pub main_crate_ids: Vec<CrateInput>,
    /// Crates whose tests are compiled.
    pub test_crate_ids: Vec<CrateInput>,
    /// Whether compilation warnings are tolerated.
    pub allow_warnings: bool,
    /// Options forwarded to the test compilation.
    pub config: TestsCompilationConfig<'db>,
}
212
213impl<'db> TestCompiler<'db> {
214 pub fn try_new(
221 path: &Path,
222 allow_warnings: bool,
223 gas_enabled: bool,
224 config: TestsCompilationConfig<'db>,
225 ) -> Result<Self> {
226 let mut db = {
227 let mut b = RootDatabase::builder();
228 let mut cfg = CfgSet::from_iter([Cfg::name("test"), Cfg::kv("target", "test")]);
229 if !gas_enabled {
230 cfg.insert(Cfg::kv("gas", "disabled"));
231 b.skip_auto_withdraw_gas();
232 }
233 b.detect_corelib();
234 b.with_cfg(cfg);
235 b.with_default_plugin_suite(test_plugin_suite());
236 if config.starknet {
237 b.with_default_plugin_suite(starknet_plugin_suite());
238 }
239 b.build()?
240 };
241
242 let main_crate_inputs = setup_project(&mut db, path)?;
243
244 Ok(Self {
245 db: db.snapshot(),
246 test_crate_ids: main_crate_inputs.clone(),
247 main_crate_ids: main_crate_inputs,
248 allow_warnings,
249 config,
250 })
251 }
252
253 pub fn build(&self) -> Result<TestCompilation<'_>> {
255 let mut diag_reporter = DiagnosticsReporter::stderr().with_crates(&self.main_crate_ids);
256 if self.allow_warnings {
257 diag_reporter = diag_reporter.allow_warnings();
258 }
259
260 compile_test_prepared_db(
261 &self.db,
262 self.config.clone(),
263 self.test_crate_ids.clone(),
264 diag_reporter,
265 )
266 }
267}
268
269pub fn filter_test_cases<'db>(
279 compiled: TestCompilation<'db>,
280 include_ignored: bool,
281 ignored: bool,
282 filter: &str,
283) -> (TestCompilation<'db>, usize) {
284 let total_tests_count = compiled.metadata.named_tests.len();
285 let named_tests = compiled
286 .metadata
287 .named_tests
288 .into_iter()
289 .filter(|(_, test)| !ignored || test.ignored || include_ignored)
291 .map(|(func, mut test)| {
292 if include_ignored || ignored {
294 test.ignored = false;
295 }
296 (func, test)
297 })
298 .filter(|(name, _)| name.contains(filter))
299 .collect_vec();
300 let filtered_out = total_tests_count - named_tests.len();
301 let tests = TestCompilation {
302 sierra_program: compiled.sierra_program,
303 metadata: TestCompilationMetadata { named_tests, ..compiled.metadata },
304 };
305 (tests, filtered_out)
306}
307
/// Outcome of a test run compared with the test's expectation.
enum TestStatus {
    /// The test behaved as expected.
    Success,
    /// The test did not behave as expected; holds the value the run produced.
    Fail(RunResultValue),
}
313
/// Result of running a single, non-ignored test.
struct TestResult {
    /// Whether the test matched its expectation.
    status: TestStatus,
    /// Estimated gas usage of the test, when it could be computed.
    gas_usage: Option<i64>,
    /// Resources used by the run.
    used_resources: StarknetExecutionResources,
    /// Profiling info collected during the run, if any was collected.
    profiling_info: Option<ProfilingInfo>,
}
325
/// Summary of a test run, grouping test names by outcome.
pub struct TestsSummary {
    /// Names of the tests that passed.
    passed: Vec<String>,
    /// Names of the tests that failed or failed to run.
    failed: Vec<String>,
    /// Names of the tests that were skipped because they are ignored.
    ignored: Vec<String>,
    /// One entry per name in `failed`, in the same order. `Ok` holds the unexpected run value;
    /// `Err` holds the error that prevented the test from running.
    failed_run_results: Vec<Result<RunResultValue>>,
}
333
334pub fn run_tests(
336 opt_db: Option<&dyn Database>,
337 compiled: TestCompilation<'_>,
338 config: &TestRunConfig,
339 custom_hint_processor_factory: Option<ArcCustomHintProcessorFactory>,
340) -> Result<TestsSummary> {
341 let TestCompilation {
342 sierra_program: sierra_program_with_debug_info,
343 metadata:
344 TestCompilationMetadata {
345 named_tests,
346 function_set_costs,
347 contracts_info,
348 statements_locations,
349 },
350 } = compiled;
351 let sierra_program = sierra_program_with_debug_info.program;
352 let runner = SierraCasmRunner::new(
353 sierra_program.clone(),
354 if config.gas_enabled {
355 Some(MetadataComputationConfig {
356 function_set_costs,
357 linear_gas_solver: true,
358 linear_ap_change_solver: true,
359 skip_non_linear_solver_comparisons: false,
360 compute_runtime_costs: false,
361 })
362 } else {
363 None
364 },
365 contracts_info,
366 config.profiler_config.as_ref().map(ProfilingInfoCollectionConfig::from_profiler_config),
367 )
368 .map_err(|err| {
369 let (RunnerError::BuildError(err), Some(db), Some(statements_locations)) =
370 (&err, opt_db, &statements_locations)
371 else {
372 return err.into();
373 };
374 let mut locs = vec![];
375 for stmt_idx in err.stmt_indices() {
376 if let Some(loc) = statements_locations.statement_diagnostic_location(db, stmt_idx) {
377 locs.push(format!("#{stmt_idx} {:?}", loc.debug(db)))
378 }
379 }
380 anyhow::anyhow!("{err}\n{}", locs.join("\n"))
381 })
382 .with_context(|| "Failed setting up runner.")?;
383 let suffix = if named_tests.len() != 1 { "s" } else { "" };
384 println!("running {} test{}", named_tests.len(), suffix);
385
386 let (tx, rx) = channel::<_>();
387 rayon::spawn(move || {
388 named_tests.into_par_iter().for_each(|(name, test)| {
389 let result =
390 run_single_test(test, &name, &runner, custom_hint_processor_factory.clone());
391 tx.send((name, result)).unwrap();
392 })
393 });
394
395 let profiler_data = config.profiler_config.as_ref().and_then(|profiler_config| {
396 if !profiler_config.requires_cairo_debug_info() {
397 return None;
398 }
399
400 let db = opt_db.expect("db must be passed when doing cairo level profiling.");
401 Some((
402 ProfilingInfoProcessor::new(
403 Some(db),
404 &sierra_program,
405 statements_locations
406 .expect(
407 "statements locations must be present when doing cairo level profiling.",
408 )
409 .get_statements_functions_map_for_tests(db),
410 ),
411 ProfilingInfoProcessorParams::from_profiler_config(profiler_config),
412 ))
413 });
414
415 let mut summary = TestsSummary {
416 passed: vec![],
417 failed: vec![],
418 ignored: vec![],
419 failed_run_results: vec![],
420 };
421 while let Ok((name, result)) = rx.recv() {
422 update_summary(&mut summary, name, result, &profiler_data, config.print_resource_usage);
423 }
424
425 Ok(summary)
426}
427
428fn run_single_test(
430 test: TestConfig,
431 name: &str,
432 runner: &SierraCasmRunner,
433 custom_hint_processor_factory: Option<ArcCustomHintProcessorFactory>,
434) -> Result<Option<TestResult>> {
435 if test.ignored {
436 return Ok(None);
437 }
438 let func = runner.find_function(name)?;
439
440 let (hint_processor, ctx) =
441 runner.prepare_starknet_context(func, vec![], test.available_gas, Default::default())?;
442
443 let mut hint_processor = match custom_hint_processor_factory {
444 Some(f) => f(hint_processor),
445 None => Box::new(hint_processor),
446 };
447
448 let result =
449 runner.run_function_with_prepared_starknet_context(func, &mut *hint_processor, ctx)?;
450
451 Ok(Some(TestResult {
452 status: match &result.value {
453 RunResultValue::Success(_) => match test.expectation {
454 TestExpectation::Success => TestStatus::Success,
455 TestExpectation::Panics(_) => TestStatus::Fail(result.value),
456 },
457 RunResultValue::Panic(value) => match test.expectation {
458 TestExpectation::Success => TestStatus::Fail(result.value),
459 TestExpectation::Panics(panic_expectation) => match panic_expectation {
460 PanicExpectation::Exact(expected) if value != &expected => {
461 TestStatus::Fail(result.value)
462 }
463 _ => TestStatus::Success,
464 },
465 },
466 },
467 gas_usage: test
468 .available_gas
469 .zip(result.gas_counter)
470 .map(|(before, after)| {
471 before.into_or_panic::<i64>() - after.to_bigint().to_i64().unwrap()
472 })
473 .or_else(|| runner.initial_required_gas(func).map(|gas| gas.into_or_panic::<i64>())),
474 used_resources: result.used_resources,
475 profiling_info: result.profiling_info,
476 }))
477}
478
479fn update_summary(
481 summary: &mut TestsSummary,
482 name: String,
483 test_result: Result<Option<TestResult>>,
484 profiler_data: &Option<(ProfilingInfoProcessor<'_>, ProfilingInfoProcessorParams)>,
485 print_resource_usage: bool,
486) {
487 let (res_type, status_str, gas_usage, used_resources, profiling_info) = match test_result {
488 Ok(None) => (&mut summary.ignored, "ignored".bright_yellow(), None, None, None),
489 Err(err) => {
490 summary.failed_run_results.push(Err(err));
491 (&mut summary.failed, "failed to run".bright_magenta(), None, None, None)
492 }
493 Ok(Some(result)) => {
494 let (res_type, status_str) = match result.status {
495 TestStatus::Success => (&mut summary.passed, "ok".bright_green()),
496 TestStatus::Fail(run_result) => {
497 summary.failed_run_results.push(Ok(run_result));
498 (&mut summary.failed, "fail".bright_red())
499 }
500 };
501 (
502 res_type,
503 status_str,
504 result.gas_usage,
505 print_resource_usage.then_some(result.used_resources),
506 result.profiling_info,
507 )
508 }
509 };
510 if let Some(gas_usage) = gas_usage {
511 println!("test {name} ... {status_str} (gas usage est.: {gas_usage})");
512 } else {
513 println!("test {name} ... {status_str}");
514 }
515 if let Some(used_resources) = used_resources {
516 let filtered = used_resources.basic_resources.filter_unused_builtins();
517 println!(" steps: {}", filtered.n_steps);
531 println!(" memory holes: {}", filtered.n_memory_holes);
532
533 print_resource_map(
534 filtered.builtin_instance_counter.into_iter().map(|(k, v)| (k.to_string(), v)),
535 "builtins",
536 );
537 print_resource_map(used_resources.syscalls.into_iter(), "syscalls");
538 }
539 if let Some((profiling_processor, profiling_params)) = profiler_data {
540 let processed_profiling_info = profiling_processor.process(
541 &profiling_info.expect("profiling_info must be Some when profiler_config is Some"),
542 profiling_params,
543 );
544 println!("Profiling info:\n{processed_profiling_info}");
545 }
546 res_type.push(name);
547}
548
/// Prints one resource category as ` <resource_type>: ("name": count, ...)`.
///
/// Entries are sorted by name, then by count. Nothing is printed when `m` is empty.
fn print_resource_map(m: impl ExactSizeIterator<Item = (String, usize)>, resource_type: &str) {
    if m.len() == 0 {
        return;
    }
    let mut entries: Vec<(String, usize)> = m.collect();
    entries.sort();
    let rendered: Vec<String> = entries.iter().map(|(k, v)| format!(r#""{k}": {v}"#)).collect();
    println!(" {resource_type}: ({})", rendered.join(", "));
}