1use crate::{create_file, json_from_path, source_from_path};
8use oak_core::{Language, Parser, errors::OakError};
9use serde::{Deserialize, Serialize};
10
11use std::{
12 fmt::Debug,
13 path::{Path, PathBuf},
14 time::Duration,
15};
16use walkdir::WalkDir;
17
/// Drives file-based parser tests over a directory tree.
///
/// A tester is configured with a root directory to scan, the file extensions
/// that identify test inputs, and a per-file timeout used while waiting for
/// the parser to finish.
#[derive(Debug, Clone)]
pub struct ParserTester {
    /// Directory that is walked to discover test input files.
    root: PathBuf,
    /// File extensions (without the leading dot) that mark a file as a test input.
    extensions: Vec<String>,
    /// Maximum time to wait for a single file's parse result.
    timeout: Duration,
}
28
29#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
34pub struct ParserTestExpected {
35 pub success: bool,
37 pub node_count: usize,
39 pub ast_structure: AstNodeData,
41 pub errors: Vec<String>,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
50pub struct AstNodeData {
51 pub kind: String,
53 pub children: Vec<AstNodeData>,
55 pub text_length: usize,
57 pub is_leaf: bool,
59}
60
61impl ParserTester {
62 pub fn new<P: AsRef<Path>>(root: P) -> Self {
64 Self { root: root.as_ref().to_path_buf(), extensions: vec![], timeout: Duration::from_secs(10) }
65 }
66
67 pub fn with_extension(mut self, extension: impl ToString) -> Self {
69 self.extensions.push(extension.to_string());
70 self
71 }
72
73 pub fn with_timeout(mut self, timeout: Duration) -> Self {
75 self.timeout = timeout;
76 self
77 }
78
79 pub fn run_tests<L, P>(self, parser: &P) -> Result<(), OakError>
81 where
82 P: Parser<L> + Send + Sync,
83 L: Language + Send + Sync,
84 L::ElementType: Serialize + Debug + Sync + Send + Eq + From<L::TokenType>,
85 {
86 let test_files = self.find_test_files()?;
87 let force_regenerated = std::env::var("REGENERATE_TESTS").unwrap_or("0".to_string()) == "1";
88 let mut regenerated_any = false;
89
90 for file_path in test_files {
91 println!("Testing file: {}", file_path.display());
92 regenerated_any |= self.test_single_file::<L, P>(&file_path, parser, force_regenerated)?;
93 }
94
95 if regenerated_any && force_regenerated {
96 println!("Tests regenerated for: {}", self.root.display());
97 Ok(())
98 }
99 else {
100 Ok(())
101 }
102 }
103
104 fn find_test_files(&self) -> Result<Vec<PathBuf>, OakError> {
105 let mut files = Vec::new();
106
107 for entry in WalkDir::new(&self.root) {
108 let entry = entry.unwrap();
109 let path = entry.path();
110
111 if path.is_file() {
112 if let Some(ext) = path.extension() {
113 let ext_str = ext.to_str().unwrap_or("");
114 if self.extensions.iter().any(|e| e == ext_str) {
115 let file_name = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
117 let is_output_file = file_name.ends_with(".parsed.json") || file_name.ends_with(".lexed.json") || file_name.ends_with(".built.json") || file_name.ends_with(".expected.json");
118
119 if !is_output_file {
120 files.push(path.to_path_buf());
121 }
122 }
123 }
124 }
125 }
126
127 Ok(files)
128 }
129
130 fn test_single_file<L, P>(&self, file_path: &Path, parser: &P, force_regenerated: bool) -> Result<bool, OakError>
131 where
132 P: Parser<L> + Send + Sync,
133 L: Language + Send + Sync,
134 L::ElementType: Serialize + Debug + Sync + Send + From<L::TokenType>,
135 {
136 let source = source_from_path(file_path)?;
137
138 use std::sync::mpsc;
140 let (tx, rx) = mpsc::channel();
141 let timeout = self.timeout;
142
143 std::thread::scope(|s| {
144 s.spawn(move || {
145 let mut cache = oak_core::parser::ParseSession::<L>::default();
146 let parse_out = parser.parse(&source, &[], &mut cache);
147
148 let (success, ast_structure) = match &parse_out.result {
150 Ok(root) => {
151 let ast = Self::to_ast::<L>(root);
152 (true, ast)
153 }
154 Err(_) => {
155 let ast = AstNodeData { kind: "Error".to_string(), children: vec![], text_length: 0, is_leaf: true };
156 (false, ast)
157 }
158 };
159
160 let mut error_messages: Vec<String> = parse_out.diagnostics.iter().map(|e| e.to_string()).collect();
162 if let Err(e) = &parse_out.result {
163 error_messages.push(e.to_string());
164 }
165
166 let node_count = Self::count_nodes(&ast_structure);
168
169 let test_result = ParserTestExpected { success, node_count, ast_structure, errors: error_messages };
170
171 let _ = tx.send(Ok::<ParserTestExpected, OakError>(test_result));
172 });
173
174 let mut regenerated = false;
175 match rx.recv_timeout(timeout) {
176 Ok(Ok(test_result)) => {
177 let expected_file = file_path.with_extension(format!("{}.parsed.json", file_path.extension().unwrap_or_default().to_str().unwrap_or("")));
178
179 if !expected_file.exists() {
181 let legacy_file = file_path.with_extension("expected.json");
182 if legacy_file.exists() {
183 let _ = std::fs::rename(&legacy_file, &expected_file);
184 }
185 }
186
187 if expected_file.exists() && !force_regenerated {
188 let expected_json = json_from_path(&expected_file)?;
189 let expected: ParserTestExpected = serde_json::from_value(expected_json).map_err(|e| OakError::custom_error(e.to_string()))?;
190 if test_result != expected {
191 return Err(OakError::test_failure(file_path.to_path_buf(), format!("{:#?}", expected), format!("{:#?}", test_result)));
192 }
193 }
194 else {
195 use std::io::Write;
196 let mut file = create_file(&expected_file)?;
197 let json_val = serde_json::to_string_pretty(&test_result).map_err(|e| OakError::custom_error(e.to_string()))?;
198 file.write_all(json_val.as_bytes()).map_err(|e| OakError::custom_error(e.to_string()))?;
199
200 if force_regenerated {
201 regenerated = true;
202 }
203 else {
204 return Err(OakError::test_regenerated(expected_file));
205 }
206 }
207 }
208 Ok(Err(e)) => return Err(e),
209 Err(mpsc::RecvTimeoutError::Timeout) => {
210 return Err(OakError::custom_error(&format!("Parser test timed out after {:?} for file: {}", timeout, file_path.display())));
211 }
212 Err(mpsc::RecvTimeoutError::Disconnected) => {
213 return Err(OakError::custom_error("Parser test thread disconnected unexpectedly"));
214 }
215 }
216 Ok(regenerated)
217 })
218 }
219
220 fn to_ast<'a, L: Language>(root: &'a oak_core::GreenNode<'a, L>) -> AstNodeData {
221 let kind_str = format!("{:?}", root.kind);
222 let mut children = Vec::new();
223 let mut leaf_count: usize = 0;
224 let mut leaf_text_length: usize = 0;
225
226 for c in root.children {
227 match c {
228 oak_core::GreenTree::Node(n) => children.push(Self::to_ast(n)),
229 oak_core::GreenTree::Leaf(l) => {
230 leaf_count += 1;
231 leaf_text_length += l.length as usize;
232 }
233 }
234 }
235
236 if leaf_count > 0 {
237 children.push(AstNodeData { kind: format!("Leaves({})", leaf_count), children: vec![], text_length: leaf_text_length, is_leaf: true });
238 }
239
240 AstNodeData { kind: kind_str, children, text_length: root.byte_length as usize, is_leaf: false }
241 }
242
243 fn count_nodes(node: &AstNodeData) -> usize {
244 1 + node.children.iter().map(Self::count_nodes).sum::<usize>()
245 }
246}