1use crate::{create_file, source_from_path};
8use oak_core::{Language, Parser, errors::OakError};
9
10#[cfg(feature = "serde")]
11use crate::json_from_path;
12#[cfg(feature = "serde")]
13use serde::Serialize;
14
15use std::{
16 fmt::Debug,
17 path::{Path, PathBuf},
18 time::Duration,
19};
20use walkdir::WalkDir;
21
/// Snapshot-based test driver for a [`Parser`]: finds input files under a root
/// directory, parses each one, and compares a structural summary of the result with a
/// JSON snapshot stored next to the input.
#[derive(Debug, Clone)]
pub struct ParserTester {
    /// Directory that is walked recursively for test inputs.
    root: PathBuf,
    /// File extensions (without the leading dot) that mark a file as a test input.
    extensions: Vec<String>,
    /// Maximum time allowed for parsing a single file.
    timeout: Duration,
}
32
/// Structural summary of one parse, stored as the `<file>.<ext>.parsed.json` snapshot
/// and compared field-by-field against fresh parse output.
///
/// The declaration order of the fields determines the serialized JSON layout and the
/// `Debug` output used in mismatch reports, so keep it stable.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ParserTestExpected {
    /// `true` when the parser returned a tree, `false` when it returned a fatal error.
    pub success: bool,
    /// Number of entries in `ast_structure`, counting the root and every synthetic
    /// `Leaves(..)` summary entry.
    pub node_count: usize,
    /// Condensed tree summary; a single `"Error"` placeholder when parsing failed.
    pub ast_structure: AstNodeData,
    /// Display strings of all diagnostics, followed by the fatal parse error (if any).
    pub errors: Vec<String>,
}
49
/// One entry of a condensed, comparable syntax-tree snapshot.
///
/// Interior nodes map one-to-one onto green-tree nodes. All direct leaf tokens of a node
/// are collapsed into a single synthetic entry whose `kind` is `Leaves(<count>)`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct AstNodeData {
    /// `Debug` rendering of the node kind, `Leaves(<count>)` for a collapsed leaf run,
    /// or `Error` for the placeholder recorded when parsing fails.
    pub kind: String,
    /// Child node entries in source order; a `Leaves(..)` entry, if any, comes last.
    pub children: Vec<AstNodeData>,
    /// Byte length of the node, or the summed length of the collapsed leaves.
    pub text_length: usize,
    /// `true` for `Leaves(..)` summaries and the `Error` placeholder.
    pub is_leaf: bool,
}
66
67impl ParserTester {
68 pub fn new<P: AsRef<Path>>(root: P) -> Self {
70 Self { root: root.as_ref().to_path_buf(), extensions: vec![], timeout: Duration::from_secs(10) }
71 }
72
73 pub fn with_extension(mut self, extension: impl ToString) -> Self {
75 self.extensions.push(extension.to_string());
76 self
77 }
78
79 pub fn with_timeout(mut self, timeout: Duration) -> Self {
81 self.timeout = timeout;
82 self
83 }
84
85 #[cfg(feature = "serde")]
87 pub fn run_tests<L, P>(self, parser: &P) -> Result<(), OakError>
88 where
89 P: Parser<L> + Send + Sync,
90 L: Language + Send + Sync,
91 L::ElementType: serde::Serialize + Debug + Sync + Send + Eq + From<L::TokenType>,
92 {
93 let test_files = self.find_test_files()?;
94 let force_regenerated = std::env::var("REGENERATE_TESTS").unwrap_or("0".to_string()) == "1";
95 let mut regenerated_any = false;
96
97 for file_path in test_files {
98 println!("Testing file: {}", file_path.display());
99 regenerated_any |= self.test_single_file::<L, P>(&file_path, parser, force_regenerated)?;
100 }
101
102 if regenerated_any && force_regenerated {
103 println!("Tests regenerated for: {}", self.root.display());
104 Ok(())
105 }
106 else {
107 Ok(())
108 }
109 }
110
111 #[cfg(not(feature = "serde"))]
113 pub fn run_tests<L, P>(self, _parser: &P) -> Result<(), OakError>
114 where
115 P: Parser<L> + Send + Sync,
116 L: Language + Send + Sync,
117 L::ElementType: Debug + Sync + Send + Eq + From<L::TokenType>,
118 {
119 Ok(())
120 }
121
122 fn find_test_files(&self) -> Result<Vec<PathBuf>, OakError> {
123 let mut files = Vec::new();
124
125 for entry in WalkDir::new(&self.root) {
126 let entry = entry.unwrap();
127 let path = entry.path();
128
129 if path.is_file() {
130 if let Some(ext) = path.extension() {
131 let ext_str = ext.to_str().unwrap_or("");
132 if self.extensions.iter().any(|e| e == ext_str) {
133 let file_name = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
135 let is_output_file = file_name.ends_with(".parsed.json") || file_name.ends_with(".lexed.json") || file_name.ends_with(".built.json") || file_name.ends_with(".expected.json");
136
137 if !is_output_file {
138 files.push(path.to_path_buf());
139 }
140 }
141 }
142 }
143 }
144
145 Ok(files)
146 }
147
148 #[cfg(feature = "serde")]
149 fn test_single_file<L, P>(&self, file_path: &Path, parser: &P, force_regenerated: bool) -> Result<bool, OakError>
150 where
151 P: Parser<L> + Send + Sync,
152 L: Language + Send + Sync,
153 L::ElementType: serde::Serialize + Debug + Sync + Send + Eq + From<L::TokenType>,
154 {
155 let source = source_from_path(file_path)?;
156
157 use std::sync::{Arc, Mutex};
159 let result: Arc<Mutex<Option<Result<ParserTestExpected, OakError>>>> = Arc::new(Mutex::new(None));
160 let result_clone = Arc::clone(&result);
161 let timeout = self.timeout;
162
163 std::thread::scope(|s| {
164 let handle = s.spawn(move || {
165 let mut cache = oak_core::parser::ParseSession::<L>::default();
166 let parse_out = parser.parse(&source, &[], &mut cache);
167
168 let (success, ast_structure) = match &parse_out.result {
170 Ok(root) => {
171 let ast = Self::to_ast::<L>(root);
172 (true, ast)
173 }
174 Err(_) => {
175 let ast = AstNodeData { kind: "Error".to_string(), children: vec![], text_length: 0, is_leaf: true };
176 (false, ast)
177 }
178 };
179
180 let mut error_messages: Vec<String> = parse_out.diagnostics.iter().map(|e| e.to_string()).collect();
182 if let Err(e) = &parse_out.result {
183 error_messages.push(e.to_string());
184 }
185
186 let node_count = Self::count_nodes(&ast_structure);
188
189 let test_result = ParserTestExpected { success, node_count, ast_structure, errors: error_messages };
190
191 let mut result = result_clone.lock().unwrap();
192 *result = Some(Ok(test_result));
193 });
194
195 let start_time = std::time::Instant::now();
197 let timeout_occurred = loop {
198 if handle.is_finished() {
200 break false;
201 }
202
203 if start_time.elapsed() > timeout {
205 break true;
206 }
207
208 std::thread::sleep(std::time::Duration::from_millis(10));
210 };
211
212 if timeout_occurred {
214 return Err(OakError::custom_error(&format!("Parser test timed out after {:?} for file: {}", timeout, file_path.display())));
215 }
216
217 let test_result = {
219 let result_guard = result.lock().unwrap();
220 match result_guard.as_ref() {
221 Some(Ok(test_result)) => test_result.clone(),
222 Some(Err(e)) => return Err(e.clone()),
223 None => return Err(OakError::custom_error("Parser test thread disconnected unexpectedly")),
224 }
225 };
226
227 let mut regenerated = false;
228 let expected_file = file_path.with_extension(format!("{}.parsed.json", file_path.extension().unwrap_or_default().to_str().unwrap_or("")));
229
230 if !expected_file.exists() {
232 let legacy_file = file_path.with_extension("expected.json");
233 if legacy_file.exists() {
234 let _ = std::fs::rename(&legacy_file, &expected_file);
235 }
236 }
237
238 if expected_file.exists() && !force_regenerated {
239 let expected_json = json_from_path(&expected_file)?;
240 let expected: ParserTestExpected = serde_json::from_value(expected_json).map_err(|e| OakError::custom_error(e.to_string()))?;
241 if test_result != expected {
242 return Err(OakError::test_failure(file_path.to_path_buf(), format!("{:#?}", expected), format!("{:#?}", test_result)));
243 }
244 }
245 else {
246 use std::io::Write;
247 let mut file = create_file(&expected_file)?;
248 let mut buf = Vec::new();
249 let formatter = serde_json::ser::PrettyFormatter::with_indent(b" "); let mut ser = serde_json::Serializer::with_formatter(&mut buf, formatter);
251 test_result.serialize(&mut ser).map_err(|e| OakError::custom_error(e.to_string()))?;
252 file.write_all(&buf).map_err(|e| OakError::custom_error(e.to_string()))?;
253
254 if force_regenerated {
255 regenerated = true;
256 }
257 else {
258 return Err(OakError::test_regenerated(expected_file));
259 }
260 }
261
262 Ok(regenerated)
263 })
264 }
265
266 fn to_ast<'a, L: Language>(root: &'a oak_core::GreenNode<'a, L>) -> AstNodeData {
267 let kind_str = format!("{:?}", root.kind);
268 let mut children = Vec::new();
269 let mut leaf_count: usize = 0;
270 let mut leaf_text_length: usize = 0;
271
272 for c in root.children {
273 match c {
274 oak_core::GreenTree::Node(n) => children.push(Self::to_ast(n)),
275 oak_core::GreenTree::Leaf(l) => {
276 leaf_count += 1;
277 leaf_text_length += l.length as usize;
278 }
279 }
280 }
281
282 if leaf_count > 0 {
283 children.push(AstNodeData { kind: format!("Leaves({})", leaf_count), children: vec![], text_length: leaf_text_length, is_leaf: true });
284 }
285
286 AstNodeData { kind: kind_str, children, text_length: root.byte_length as usize, is_leaf: false }
287 }
288
289 fn count_nodes(node: &AstNodeData) -> usize {
290 1 + node.children.iter().map(Self::count_nodes).sum::<usize>()
291 }
292}