1use std::path::Path;
2use std::sync::Mutex;
3use std::vec::IntoIter;
4
5use anyhow::{Context, Result, bail};
6use cairo_lang_compiler::db::RootDatabase;
7use cairo_lang_compiler::diagnostics::DiagnosticsReporter;
8use cairo_lang_compiler::project::setup_project;
9use cairo_lang_filesystem::cfg::{Cfg, CfgSet};
10use cairo_lang_filesystem::ids::CrateId;
11use cairo_lang_runner::casm_run::format_next_item;
12use cairo_lang_runner::profiling::{
13 ProfilingInfo, ProfilingInfoProcessor, ProfilingInfoProcessorParams,
14};
15use cairo_lang_runner::{
16 ProfilingInfoCollectionConfig, RunResultValue, SierraCasmRunner, StarknetExecutionResources,
17};
18use cairo_lang_sierra::extensions::gas::CostTokenType;
19use cairo_lang_sierra::ids::FunctionId;
20use cairo_lang_sierra::program::{Program, StatementIdx};
21use cairo_lang_sierra_generator::db::SierraGenGroup;
22use cairo_lang_sierra_to_casm::metadata::MetadataComputationConfig;
23use cairo_lang_starknet::contract::ContractInfo;
24use cairo_lang_starknet::starknet_plugin_suite;
25use cairo_lang_test_plugin::test_config::{PanicExpectation, TestExpectation};
26use cairo_lang_test_plugin::{
27 TestCompilation, TestCompilationMetadata, TestConfig, TestsCompilationConfig,
28 compile_test_prepared_db, test_plugin_suite,
29};
30use cairo_lang_utils::casts::IntoOrPanic;
31use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
32use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
33use colored::Colorize;
34use itertools::Itertools;
35use num_traits::ToPrimitive;
36use rayon::prelude::{IntoParallelIterator, ParallelIterator};
37use starknet_types_core::felt::Felt as Felt252;
38
39#[cfg(test)]
40mod test;
41
/// Compiles the tests of a Cairo project and runs them.
///
/// Compilation is delegated to [`TestCompiler`]; running is delegated to
/// [`CompiledTestRunner`] with the stored [`TestRunConfig`].
pub struct TestRunner {
    /// Compiler holding the project database and crate ids.
    compiler: TestCompiler,
    /// Run options (filtering, gas, profiling, resource printing).
    config: TestRunConfig,
}
47
48impl TestRunner {
49 pub fn new(
59 path: &Path,
60 starknet: bool,
61 allow_warnings: bool,
62 config: TestRunConfig,
63 ) -> Result<Self> {
64 let compiler = TestCompiler::try_new(
65 path,
66 allow_warnings,
67 config.gas_enabled,
68 TestsCompilationConfig {
69 starknet,
70 add_statements_functions: config.run_profiler == RunProfilerConfig::Cairo,
71 add_statements_code_locations: false,
72 contract_declarations: None,
73 contract_crate_ids: None,
74 executable_crate_ids: None,
75 },
76 )?;
77 Ok(Self { compiler, config })
78 }
79
80 pub fn run(&self) -> Result<Option<TestsSummary>> {
82 let runner = CompiledTestRunner::new(self.compiler.build()?, self.config.clone());
83 runner.run(Some(&self.compiler.db))
84 }
85}
86
/// Runs an already-compiled set of tests.
pub struct CompiledTestRunner {
    /// The compiled tests; filtering by [`TestRunConfig`] happens when running.
    pub compiled: TestCompilation,
    /// Run options (filtering, gas, profiling, resource printing).
    pub config: TestRunConfig,
}
91
92impl CompiledTestRunner {
93 pub fn new(compiled: TestCompilation, config: TestRunConfig) -> Self {
100 Self { compiled, config }
101 }
102
103 pub fn run(self, db: Option<&RootDatabase>) -> Result<Option<TestsSummary>> {
105 let (compiled, filtered_out) = filter_test_cases(
106 self.compiled,
107 self.config.include_ignored,
108 self.config.ignored,
109 &self.config.filter,
110 );
111
112 let TestsSummary { passed, failed, ignored, failed_run_results } = run_tests(
113 if self.config.run_profiler == RunProfilerConfig::Cairo {
114 let db = db.expect("db must be passed when profiling.");
115 let statements_locations = compiled
116 .metadata
117 .statements_locations
118 .expect("statements locations must be present when profiling.");
119 Some(PorfilingAuxData {
120 db,
121 statements_functions: statements_locations
122 .get_statements_functions_map_for_tests(db),
123 })
124 } else {
125 None
126 },
127 compiled.metadata.named_tests,
128 compiled.sierra_program.program,
129 compiled.metadata.function_set_costs,
130 compiled.metadata.contracts_info,
131 &self.config,
132 )?;
133
134 if failed.is_empty() {
135 println!(
136 "test result: {}. {} passed; {} failed; {} ignored; {filtered_out} filtered out;",
137 "ok".bright_green(),
138 passed.len(),
139 failed.len(),
140 ignored.len()
141 );
142 Ok(None)
143 } else {
144 println!("failures:");
145 for (failure, run_result) in failed.iter().zip_eq(failed_run_results) {
146 print!(" {failure} - ");
147 match run_result {
148 RunResultValue::Success(_) => {
149 println!("expected panic but finished successfully.");
150 }
151 RunResultValue::Panic(values) => {
152 println!("{}", format_for_panic(values.into_iter()));
153 }
154 }
155 }
156 println!();
157 bail!(
158 "test result: {}. {} passed; {} failed; {} ignored",
159 "FAILED".bright_red(),
160 passed.len(),
161 failed.len(),
162 ignored.len()
163 );
164 }
165 }
166}
167
168fn format_for_panic(mut felts: IntoIter<Felt252>) -> String {
170 let mut items = Vec::new();
171 while let Some(item) = format_next_item(&mut felts) {
172 items.push(item.quote_if_string());
173 }
174 let panic_values_string =
175 if let [item] = &items[..] { item.clone() } else { format!("({})", items.join(", ")) };
176 format!("Panicked with {panic_values_string}.")
177}
178
/// Which kind of profiling, if any, to perform while running tests.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum RunProfilerConfig {
    /// No profiling info is collected.
    None,
    /// Profiling info is collected and processed with Cairo-level information (requires the
    /// compiler database and statement-to-function info); tests run sequentially.
    Cairo,
    /// Profiling info is collected without Cairo-level auxiliary data.
    Sierra,
}
191
/// Options controlling which tests run and how.
#[derive(Clone, Debug)]
pub struct TestRunConfig {
    /// Only tests whose name contains this substring are run.
    pub filter: String,
    /// Also run tests marked as ignored.
    pub include_ignored: bool,
    /// Run only tests marked as ignored (unless `include_ignored` is also set).
    pub ignored: bool,
    /// Profiling mode.
    pub run_profiler: RunProfilerConfig,
    /// Whether gas accounting is enabled, both when building the database and when running.
    pub gas_enabled: bool,
    /// Whether to print the resources used by each test.
    pub print_resource_usage: bool,
}
205
/// Holds a prepared compiler database and compiles the project's tests from it.
pub struct TestCompiler {
    /// Compiler database with the project loaded.
    pub db: RootDatabase,
    /// Crates whose diagnostics are reported.
    pub main_crate_ids: Vec<CrateId>,
    /// Crates whose tests are compiled.
    pub test_crate_ids: Vec<CrateId>,
    /// Whether warnings are tolerated during compilation.
    pub allow_warnings: bool,
    /// Test compilation options.
    pub config: TestsCompilationConfig,
}
214
215impl TestCompiler {
216 pub fn try_new(
223 path: &Path,
224 allow_warnings: bool,
225 gas_enabled: bool,
226 config: TestsCompilationConfig,
227 ) -> Result<Self> {
228 let db = &mut {
229 let mut b = RootDatabase::builder();
230 if !gas_enabled {
231 b.skip_auto_withdraw_gas();
232 } else {
233 b.with_add_redeposit_gas();
234 }
235 b.detect_corelib();
236 b.with_cfg(CfgSet::from_iter([Cfg::name("test"), Cfg::kv("target", "test")]));
237 b.with_plugin_suite(test_plugin_suite());
238 if config.starknet {
239 b.with_plugin_suite(starknet_plugin_suite());
240 }
241 b.build()?
242 };
243
244 let main_crate_ids = setup_project(db, Path::new(&path))?;
245
246 Ok(Self {
247 db: db.snapshot(),
248 test_crate_ids: main_crate_ids.clone(),
249 main_crate_ids,
250 allow_warnings,
251 config,
252 })
253 }
254
255 pub fn build(&self) -> Result<TestCompilation> {
257 let mut diag_reporter =
258 DiagnosticsReporter::stderr().with_crates(&self.main_crate_ids.clone());
259 if self.allow_warnings {
260 diag_reporter = diag_reporter.allow_warnings();
261 }
262
263 compile_test_prepared_db(
264 &self.db,
265 self.config.clone(),
266 self.test_crate_ids.clone(),
267 diag_reporter,
268 )
269 }
270}
271
272pub fn filter_test_cases(
282 compiled: TestCompilation,
283 include_ignored: bool,
284 ignored: bool,
285 filter: &str,
286) -> (TestCompilation, usize) {
287 let total_tests_count = compiled.metadata.named_tests.len();
288 let named_tests = compiled
289 .metadata
290 .named_tests
291 .into_iter()
292 .filter(|(_, test)| !ignored || test.ignored || include_ignored)
294 .map(|(func, mut test)| {
295 if include_ignored || ignored {
297 test.ignored = false;
298 }
299 (func, test)
300 })
301 .filter(|(name, _)| name.contains(filter))
302 .collect_vec();
303 let filtered_out = total_tests_count - named_tests.len();
304 let tests = TestCompilation {
305 sierra_program: compiled.sierra_program,
306 metadata: TestCompilationMetadata { named_tests, ..(compiled.metadata) },
307 };
308 (tests, filtered_out)
309}
310
/// Outcome of an executed test, compared against its expectation.
enum TestStatus {
    /// The result matched the test's expectation.
    Success,
    /// The result did not match; carries the run result for reporting.
    Fail(RunResultValue),
}
316
/// Result of running a single (non-ignored) test.
struct TestResult {
    /// Whether the test met its expectation.
    status: TestStatus,
    /// Estimated gas usage, when it could be determined.
    gas_usage: Option<i64>,
    /// Resources used by the run, as reported by the runner.
    used_resources: StarknetExecutionResources,
    /// Profiling info as reported by the runner, if any was collected.
    profiling_info: Option<ProfilingInfo>,
}
328
/// Aggregated outcome of a test run.
pub struct TestsSummary {
    /// Names of tests that passed.
    passed: Vec<String>,
    /// Names of tests that failed.
    failed: Vec<String>,
    /// Names of tests that were skipped as ignored.
    ignored: Vec<String>,
    /// Run results of the failed tests, in the same order as `failed`.
    failed_run_results: Vec<RunResultValue>,
}
336
/// Auxiliary data needed to process profiling info at the Cairo level.
///
/// NOTE(review): the type name is misspelled ("Porfiling"); it is public, so renaming it would
/// break callers.
pub struct PorfilingAuxData<'a> {
    /// Database used to map Sierra back to Cairo.
    pub db: &'a dyn SierraGenGroup,
    /// Maps Sierra statement indices to the name of the function containing them.
    pub statements_functions: UnorderedHashMap<StatementIdx, String>,
}
342
343pub fn run_tests(
345 profiler_data: Option<PorfilingAuxData<'_>>,
346 named_tests: Vec<(String, TestConfig)>,
347 sierra_program: Program,
348 function_set_costs: OrderedHashMap<FunctionId, OrderedHashMap<CostTokenType, i32>>,
349 contracts_info: OrderedHashMap<Felt252, ContractInfo>,
350 config: &TestRunConfig,
351) -> Result<TestsSummary> {
352 let runner = SierraCasmRunner::new(
353 sierra_program.clone(),
354 if config.gas_enabled {
355 Some(MetadataComputationConfig {
356 function_set_costs,
357 linear_gas_solver: true,
358 linear_ap_change_solver: true,
359 skip_non_linear_solver_comparisons: false,
360 compute_runtime_costs: false,
361 })
362 } else {
363 None
364 },
365 contracts_info,
366 match config.run_profiler {
367 RunProfilerConfig::None => None,
368 RunProfilerConfig::Cairo | RunProfilerConfig::Sierra => {
369 Some(ProfilingInfoCollectionConfig::default())
370 }
371 },
372 )
373 .with_context(|| "Failed setting up runner.")?;
374 let suffix = if named_tests.len() != 1 { "s" } else { "" };
375 println!("running {} test{}", named_tests.len(), suffix);
376 let wrapped_summary = Mutex::new(Ok(TestsSummary {
377 passed: vec![],
378 failed: vec![],
379 ignored: vec![],
380 failed_run_results: vec![],
381 }));
382
383 if profiler_data.is_none() {
385 named_tests
386 .into_par_iter()
387 .map(|(name, test)| run_single_test(test, name, &runner))
388 .for_each(|res| {
389 update_summary(
390 &wrapped_summary,
391 res,
392 &None,
393 &sierra_program,
394 &ProfilingInfoProcessorParams {
395 process_by_original_user_function: false,
396 process_by_cairo_function: false,
397 ..ProfilingInfoProcessorParams::default()
398 },
399 config.print_resource_usage,
400 );
401 });
402 } else {
403 eprintln!("Note: Tests don't run in parallel when running with profiling.");
404 named_tests
405 .into_iter()
406 .map(move |(name, test)| run_single_test(test, name, &runner))
407 .for_each(|test_result| {
408 update_summary(
409 &wrapped_summary,
410 test_result,
411 &profiler_data,
412 &sierra_program,
413 &ProfilingInfoProcessorParams::default(),
414 config.print_resource_usage,
415 );
416 });
417 }
418
419 wrapped_summary.into_inner().unwrap()
420}
421
422fn run_single_test(
424 test: TestConfig,
425 name: String,
426 runner: &SierraCasmRunner,
427) -> anyhow::Result<(String, Option<TestResult>)> {
428 if test.ignored {
429 return Ok((name, None));
430 }
431 let func = runner.find_function(name.as_str())?;
432 let result = runner
433 .run_function_with_starknet_context(func, vec![], test.available_gas, Default::default())
434 .with_context(|| format!("Failed to run the function `{}`.", name.as_str()))?;
435 Ok((
436 name,
437 Some(TestResult {
438 status: match &result.value {
439 RunResultValue::Success(_) => match test.expectation {
440 TestExpectation::Success => TestStatus::Success,
441 TestExpectation::Panics(_) => TestStatus::Fail(result.value),
442 },
443 RunResultValue::Panic(value) => match test.expectation {
444 TestExpectation::Success => TestStatus::Fail(result.value),
445 TestExpectation::Panics(panic_expectation) => match panic_expectation {
446 PanicExpectation::Exact(expected) if value != &expected => {
447 TestStatus::Fail(result.value)
448 }
449 _ => TestStatus::Success,
450 },
451 },
452 },
453 gas_usage: test
454 .available_gas
455 .zip(result.gas_counter)
456 .map(|(before, after)| {
457 before.into_or_panic::<i64>() - after.to_bigint().to_i64().unwrap()
458 })
459 .or_else(|| {
460 runner.initial_required_gas(func).map(|gas| gas.into_or_panic::<i64>())
461 }),
462 used_resources: result.used_resources,
463 profiling_info: result.profiling_info,
464 }),
465 ))
466}
467
468fn update_summary(
470 wrapped_summary: &Mutex<std::prelude::v1::Result<TestsSummary, anyhow::Error>>,
471 test_result: std::prelude::v1::Result<(String, Option<TestResult>), anyhow::Error>,
472 profiler_data: &Option<PorfilingAuxData<'_>>,
473 sierra_program: &Program,
474 profiling_params: &ProfilingInfoProcessorParams,
475 print_resource_usage: bool,
476) {
477 let mut wrapped_summary = wrapped_summary.lock().unwrap();
478 if wrapped_summary.is_err() {
479 return;
480 }
481 let (name, opt_result) = match test_result {
482 Ok((name, opt_result)) => (name, opt_result),
483 Err(err) => {
484 *wrapped_summary = Err(err);
485 return;
486 }
487 };
488 let summary = wrapped_summary.as_mut().unwrap();
489 let (res_type, status_str, gas_usage, used_resources, profiling_info) =
490 if let Some(result) = opt_result {
491 let (res_type, status_str) = match result.status {
492 TestStatus::Success => (&mut summary.passed, "ok".bright_green()),
493 TestStatus::Fail(run_result) => {
494 summary.failed_run_results.push(run_result);
495 (&mut summary.failed, "fail".bright_red())
496 }
497 };
498 (
499 res_type,
500 status_str,
501 result.gas_usage,
502 print_resource_usage.then_some(result.used_resources),
503 result.profiling_info,
504 )
505 } else {
506 (&mut summary.ignored, "ignored".bright_yellow(), None, None, None)
507 };
508 if let Some(gas_usage) = gas_usage {
509 println!("test {name} ... {status_str} (gas usage est.: {gas_usage})");
510 } else {
511 println!("test {name} ... {status_str}");
512 }
513 if let Some(used_resources) = used_resources {
514 let filtered = used_resources.basic_resources.filter_unused_builtins();
515 println!(" steps: {}", filtered.n_steps);
529 println!(" memory holes: {}", filtered.n_memory_holes);
530
531 print_resource_map(
532 filtered.builtin_instance_counter.into_iter().map(|(k, v)| (k.to_string(), v)),
533 "builtins",
534 );
535 print_resource_map(used_resources.syscalls.into_iter(), "syscalls");
536 }
537 if let Some(profiling_info) = profiling_info {
538 let Some(PorfilingAuxData { db, statements_functions }) = profiler_data else {
539 panic!("profiler_data is None");
540 };
541 let profiling_processor = ProfilingInfoProcessor::new(
542 Some(*db),
543 sierra_program.clone(),
544 statements_functions.clone(),
545 Default::default(),
546 );
547 let processed_profiling_info =
548 profiling_processor.process_ex(&profiling_info, profiling_params);
549 println!("Profiling info:\n{processed_profiling_info}");
550 }
551 res_type.push(name);
552}
553
554fn print_resource_map(m: impl ExactSizeIterator<Item = (String, usize)>, resource_type: &str) {
558 if m.len() != 0 {
559 println!(
560 " {resource_type}: ({})",
561 m.into_iter().sorted().map(|(k, v)| format!(r#""{k}": {v}"#)).join(", ")
562 );
563 }
564}