Skip to main content

leo_test_framework/
lib.rs

1// Copyright (C) 2019-2026 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17//! This is a simple test framework for the Leo compiler.
18
19#[cfg(not(feature = "no_parallel"))]
20use rayon::prelude::*;
21
22use similar::{ChangeTag, TextDiff};
23use std::{
24    fs,
25    path::{Path, PathBuf},
26};
27use walkdir::WalkDir;
28
/// The reason a single test did not pass.
enum TestFailure {
    /// The runner panicked. Holds the panic message, or a placeholder when the
    /// panic payload was neither a `&str` nor a `String`.
    Panicked(String),
    /// The runner's output differs from the stored expectation.
    Mismatch {
        /// Contents of the expectation (`.out`) file.
        expected: String,
        /// Output produced by the runner in this run.
        got: String,
    },
}
33
34/// Print a unified diff between expected and actual content.
35fn print_diff(expected: &str, actual: &str) {
36    let diff = TextDiff::from_lines(expected, actual);
37    let has_changes = diff.iter_all_changes().any(|c| c.tag() != ChangeTag::Equal);
38    if !has_changes {
39        return;
40    }
41    for change in diff.iter_all_changes() {
42        let sign = match change.tag() {
43            ChangeTag::Delete => "-",
44            ChangeTag::Insert => "+",
45            ChangeTag::Equal => " ",
46        };
47        eprint!("{sign}{change}");
48    }
49    eprintln!();
50}
51
52/// Pulls tests from `category`, running them through the `runner` and
53/// comparing them against expectations in previous runs.
54///
55/// The tests are `.leo` files in `tests/{category}`, and the
56/// runner receives the contents of each of them as a `&str`,
57/// returning a `String` result. A test is considered to have failed
58/// if it panics or if results differ from the previous run.
59///
60///
61/// If no corresponding `.out` file is found in `expecations/{category}`,
62/// or if the environment variable `UPDATE_EXPECT` is set, no
63/// comparison to a previous result is done and the result of the current
64/// run is written to the file.
65pub fn run_tests(category: &str, runner: fn(&str) -> String) {
66    // This ensures error output doesn't try to display colors.
67    unsafe {
68        // SAFETY: Safety issues around `set_var` are surprisingly complicated.
69        // For now, I think marking tests as `serial` may be sufficient to
70        // address this, and in the future we'll try to think of an alternative for
71        // error output.
72        std::env::set_var("NOCOLOR", "x");
73    }
74
75    let base_tests_dir: PathBuf = [env!("CARGO_MANIFEST_DIR"), "..", "..", "tests"].iter().collect();
76
77    let base_tests_dir = base_tests_dir.canonicalize().unwrap();
78    let tests_dir = base_tests_dir.join("tests").join(category);
79    let expectations_dir = base_tests_dir.join("expectations").join(category);
80
81    let filter_string = std::env::var("TEST_FILTER").unwrap_or_default();
82    let rewrite_expectations = std::env::var("UPDATE_EXPECT").is_ok();
83
84    struct TestResult {
85        failure: Option<TestFailure>,
86        name: PathBuf,
87        wrote: bool,
88    }
89
90    let paths: Vec<PathBuf> = WalkDir::new(&tests_dir)
91        .into_iter()
92        .flatten()
93        .filter_map(|entry| {
94            let path = entry.path();
95
96            if path.to_str().is_none() {
97                panic!("Path not unicode: {}.", path.display());
98            };
99
100            let path_str = path.to_str().unwrap();
101
102            if !path_str.contains(&filter_string) || !path_str.ends_with(".leo") {
103                return None;
104            }
105
106            Some(path.into())
107        })
108        .collect();
109
110    let run_test = |path: &PathBuf| -> TestResult {
111        let contents =
112            fs::read_to_string(path).unwrap_or_else(|e| panic!("Failed to read file {}: {e}.", path.display()));
113        let result_output = std::panic::catch_unwind(|| runner(&contents));
114        if let Err(payload) = result_output {
115            let s1 = payload.downcast_ref::<&str>().map(|s| s.to_string());
116            let s2 = payload.downcast_ref::<String>().cloned();
117            let s = s1.or(s2).unwrap_or_else(|| "Unknown panic payload".to_string());
118
119            return TestResult { failure: Some(TestFailure::Panicked(s)), name: path.clone(), wrote: false };
120        }
121        let output = result_output.unwrap();
122
123        let mut expectation_path: PathBuf = expectations_dir.join(path.strip_prefix(&tests_dir).unwrap());
124        expectation_path.set_extension("out");
125
126        // It may not be ideal to the the IO below in parallel, but I'm thinking it likely won't matter.
127        if rewrite_expectations || !expectation_path.exists() {
128            fs::write(&expectation_path, &output)
129                .unwrap_or_else(|e| panic!("Failed to write file {}: {e}.", expectation_path.display()));
130            TestResult { failure: None, name: path.clone(), wrote: true }
131        } else {
132            let expected = fs::read_to_string(&expectation_path)
133                .unwrap_or_else(|e| panic!("Failed to read file {}: {e}.", expectation_path.display()));
134            if output == expected {
135                TestResult { failure: None, name: path.clone(), wrote: false }
136            } else {
137                TestResult {
138                    failure: Some(TestFailure::Mismatch { got: output, expected }),
139                    name: path.clone(),
140                    wrote: false,
141                }
142            }
143        }
144    };
145
146    #[cfg(feature = "no_parallel")]
147    let results: Vec<TestResult> = paths.iter().map(run_test).collect();
148
149    #[cfg(not(feature = "no_parallel"))]
150    let results: Vec<TestResult> = paths.par_iter().map(run_test).collect();
151
152    println!("Ran {} tests.", results.len());
153
154    let failure_count = results.iter().filter(|test_result| test_result.failure.is_some()).count();
155
156    if failure_count != 0 {
157        eprintln!("{failure_count}/{} tests failed.", results.len());
158    }
159
160    let writes = results.iter().filter(|test_result| test_result.wrote).count();
161
162    for test_result in results.iter() {
163        if let Some(test_failure) = &test_result.failure {
164            eprintln!("FAILURE: {}:", test_result.name.display());
165            match test_failure {
166                TestFailure::Panicked(s) => eprintln!("Rust panic:\n{s}"),
167                TestFailure::Mismatch { got, expected } => {
168                    eprintln!("Diff (expected -> got):");
169                    print_diff(expected, got);
170                }
171            }
172        }
173    }
174
175    if writes != 0 {
176        println!("Wrote {}/{} expectation files for tests:", writes, results.len());
177    }
178
179    for test_result in results.iter() {
180        if test_result.wrote {
181            println!("{}", test_result.name.display());
182        }
183    }
184
185    assert!(failure_count == 0);
186}
187
188pub fn run_single_test(category: &str, path: &Path, runner: fn(&str) -> String) {
189    use std::fs;
190
191    // Canonicalize the test file path to avoid strip_prefix issues
192    let path = path.canonicalize().unwrap();
193
194    unsafe {
195        // Disable colored output in test failures
196        std::env::set_var("NOCOLOR", "x");
197    }
198
199    // Base directories
200    let base_tests_dir: PathBuf = [env!("CARGO_MANIFEST_DIR"), "..", "..", "tests"].iter().collect();
201    let base_tests_dir = base_tests_dir.canonicalize().unwrap();
202
203    let tests_dir = base_tests_dir.join("tests").join(category);
204    let expectations_dir = base_tests_dir.join("expectations").join(category);
205
206    let rewrite_expectations = std::env::var("UPDATE_EXPECT").is_ok();
207
208    // Read the test file
209    println!("Running: {}", path.display());
210    let contents = fs::read_to_string(&path).unwrap_or_else(|e| panic!("Failed to read file {}: {e}.", path.display()));
211
212    // Run the test and catch panics
213    let result_output = std::panic::catch_unwind(|| runner(&contents));
214
215    let mut wrote = false;
216
217    match result_output {
218        Err(payload) => {
219            let s1 = payload.downcast_ref::<&str>().map(|s| s.to_string());
220            let s2 = payload.downcast_ref::<String>().cloned();
221            let s = s1.or(s2).unwrap_or_else(|| "Unknown panic payload".to_string());
222
223            eprintln!("FAILURE: {}:", path.display());
224            eprintln!("Rust panic:\n{s}");
225            panic!("Test failed: {}", path.display());
226        }
227        Ok(output) => {
228            // Expectation file
229            let mut expectation_path = expectations_dir.join(path.strip_prefix(&tests_dir).unwrap());
230            expectation_path.set_extension("out");
231
232            if rewrite_expectations || !expectation_path.exists() {
233                fs::write(&expectation_path, &output)
234                    .unwrap_or_else(|e| panic!("Failed to write file {}: {e}.", expectation_path.display()));
235                wrote = true;
236            } else {
237                let expected = fs::read_to_string(&expectation_path)
238                    .unwrap_or_else(|e| panic!("Failed to read file {}: {e}.", expectation_path.display()));
239
240                if output != expected {
241                    eprintln!("FAILURE: {}:", path.display());
242                    eprintln!("Diff (expected -> got):");
243                    print_diff(&expected, &output);
244                    panic!("Test failed: {}", path.display());
245                }
246            }
247        }
248    }
249
250    if wrote {
251        println!("Wrote expectation file for test: {}", path.display());
252    }
253}