1use crate::{create_file, source_from_path};
8use oak_core::{
9 Language, Lexer, Source, TokenType,
10 errors::{OakDiagnostics, OakError},
11};
12
13#[cfg(feature = "serde")]
14use crate::json_from_path;
15#[cfg(feature = "serde")]
16use serde::Serialize;
17
18use std::{
19 path::{Path, PathBuf},
20 sync::{Arc, Mutex},
21 thread,
22 time::{Duration, Instant},
23};
24use walkdir::WalkDir;
25
/// Drives lexer regression tests over every matching source file found
/// beneath a root directory.
///
/// Configure it with the builder-style methods (`with_extension`,
/// `with_timeout`) and then call `run_tests`.
#[derive(Debug, Clone)]
pub struct LexerTester {
    /// Directory that is walked recursively to discover test inputs.
    root: PathBuf,
    /// File extensions (as returned by `Path::extension`, i.e. without the
    /// leading dot) that mark a file as a test input.
    extensions: Vec<String>,
    /// Time budget for lexing a single file.
    timeout: Duration,
}
36
37#[derive(Debug, PartialEq)]
42#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
43pub struct LexerTestExpected {
44 pub success: bool,
46 pub count: usize,
48 pub tokens: Vec<TokenData>,
50 pub errors: Vec<String>,
52}
53
/// Serializable description of a single token produced by the lexer.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TokenData {
    /// `Debug` rendering of the token's kind.
    pub kind: String,
    /// Source text covered by the token; the span is clamped to the source
    /// length before the text is extracted.
    pub text: String,
    /// Start offset of the token span exactly as the lexer reported it.
    pub start: usize,
    /// End offset of the token span exactly as the lexer reported it.
    pub end: usize,
}
70
71impl LexerTester {
72 pub fn new<P: AsRef<Path>>(root: P) -> Self {
74 Self { root: root.as_ref().to_path_buf(), extensions: vec![], timeout: Duration::from_secs(10) }
75 }
76
77 pub fn with_extension(mut self, extension: impl ToString) -> Self {
79 self.extensions.push(extension.to_string());
80 self
81 }
82
83 pub fn with_timeout(mut self, time: Duration) -> Self {
85 self.timeout = time;
86 self
87 }
88
89 #[cfg(feature = "serde")]
91 pub fn run_tests<L, Lex>(self, lexer: &Lex) -> Result<(), OakError>
92 where
93 L: Language + Send + Sync,
94 L::TokenType: serde::Serialize + std::fmt::Debug + Send + Sync,
95 Lex: Lexer<L> + Send + Sync + Clone,
96 {
97 let test_files = self.find_test_files()?;
98 let force_regenerated = std::env::var("REGENERATE_TESTS").unwrap_or("0".to_string()) == "1";
99 let mut regenerated_any = false;
100
101 for file_path in test_files {
102 println!("Testing file: {}", file_path.display());
103 regenerated_any |= self.test_single_file::<L, Lex>(&file_path, lexer, force_regenerated)?
104 }
105
106 if regenerated_any && force_regenerated {
107 println!("Tests regenerated for: {}", self.root.display());
108 }
109
110 Ok(())
111 }
112
113 #[cfg(not(feature = "serde"))]
115 pub fn run_tests<L, Lex>(self, _lexer: &Lex) -> Result<(), OakError>
116 where
117 L: Language + Send + Sync,
118 Lex: Lexer<L> + Send + Sync + Clone,
119 {
120 Ok(())
121 }
122
123 fn find_test_files(&self) -> Result<Vec<PathBuf>, OakError> {
124 let mut files = Vec::new();
125
126 for entry in WalkDir::new(&self.root) {
127 let entry = entry.unwrap();
128 let path = entry.path();
129
130 if path.is_file() {
131 if let Some(ext) = path.extension() {
132 let ext_str = ext.to_str().unwrap_or("");
133 if self.extensions.iter().any(|e| e == ext_str) {
134 let file_name = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
136 let is_output_file = file_name.ends_with(".parsed.json") || file_name.ends_with(".lexed.json") || file_name.ends_with(".built.json") || file_name.ends_with(".expected.json");
137
138 if !is_output_file {
139 files.push(path.to_path_buf())
140 }
141 }
142 }
143 }
144 }
145
146 Ok(files)
147 }
148
149 #[cfg(feature = "serde")]
150 fn test_single_file<L, Lex>(&self, file_path: &Path, lexer: &Lex, force_regenerated: bool) -> Result<bool, OakError>
151 where
152 L: Language + Send + Sync,
153 L::TokenType: serde::Serialize + std::fmt::Debug + Send + Sync,
154 Lex: Lexer<L> + Send + Sync + Clone,
155 {
156 let source = source_from_path(file_path)?;
157
158 let result = Arc::new(Mutex::new(None));
160 let result_clone = Arc::clone(&result);
161
162 let lexer_clone = lexer.clone();
164 let source_arc = Arc::new(source);
166 let source_clone = Arc::clone(&source_arc);
167
168 std::thread::scope(|s| {
170 let handle = s.spawn(move || {
171 let mut cache = oak_core::parser::ParseSession::<L>::default();
172 let output = lexer_clone.lex(&*source_clone, &[], &mut cache);
173 let mut result = result_clone.lock().unwrap();
174 *result = Some(output)
175 });
176
177 let start_time = Instant::now();
179 let timeout_occurred = loop {
180 if handle.is_finished() {
182 break false;
183 }
184
185 if start_time.elapsed() > self.timeout {
187 break true;
188 }
189
190 thread::sleep(Duration::from_millis(10));
192 };
193
194 if timeout_occurred {
196 return Err(OakError::custom_error(&format!("Lexer test timed out after {:?} for file: {}", self.timeout, file_path.display())));
197 }
198
199 Ok(())
200 })?;
201
202 let OakDiagnostics { result: tokens_result, mut diagnostics } = {
204 let result_guard = result.lock().unwrap();
205 match result_guard.as_ref() {
206 Some(output) => output.clone(),
207 None => return Err(OakError::custom_error("Failed to get lexer result")),
208 }
209 };
210
211 let mut success = true;
213 let tokens = match tokens_result {
214 Ok(tokens) => tokens,
215 Err(e) => {
216 success = false;
217 diagnostics.push(e);
218 oak_core::Tokens::default()
219 }
220 };
221
222 if !diagnostics.is_empty() {
223 success = false;
224 }
225
226 let tokens: Vec<TokenData> = tokens
227 .iter()
228 .filter(|token| !token.kind.is_ignored())
229 .map(|token| {
230 let len = source_arc.as_ref().length();
231 let start = token.span.start.min(len);
232 let end = token.span.end.min(len).max(start);
233 let text = source_arc.as_ref().get_text_in((start..end).into()).to_string();
234 TokenData { kind: format!("{:?}", token.kind), text, start: token.span.start, end: token.span.end }
235 })
236 .collect();
237
238 let errors: Vec<String> = diagnostics.iter().map(|e| e.to_string()).collect();
239 let test_result = LexerTestExpected { success, count: tokens.len(), tokens, errors };
240
241 let expected_file = file_path.with_extension(format!("{}.lexed.json", file_path.extension().unwrap_or_default().to_str().unwrap_or("")));
243
244 if !expected_file.exists() {
246 let legacy_file = file_path.with_extension("expected.json");
247 if legacy_file.exists() {
248 let _ = std::fs::rename(&legacy_file, &expected_file);
249 }
250 }
251
252 let mut regenerated = false;
253 if expected_file.exists() && !force_regenerated {
254 let expected_json = json_from_path(&expected_file)?;
255 let expected: LexerTestExpected = serde_json::from_value(expected_json).map_err(|e| OakError::custom_error(e.to_string()))?;
256 if test_result != expected {
257 return Err(OakError::test_failure(file_path.to_path_buf(), format!("{:#?}", expected), format!("{:#?}", test_result)));
258 }
259 }
260 else {
261 use std::io::Write;
262 let mut file = create_file(&expected_file)?;
263 let mut buf = Vec::new();
264 let formatter = serde_json::ser::PrettyFormatter::with_indent(b" "); let mut ser = serde_json::Serializer::with_formatter(&mut buf, formatter);
266 test_result.serialize(&mut ser).map_err(|e| OakError::custom_error(e.to_string()))?;
267 file.write_all(&buf).map_err(|e| OakError::custom_error(e.to_string()))?;
268 regenerated = true;
269 }
270
271 Ok(regenerated)
272 }
273}