// oak_testing/parsing.rs

//! Parser testing utilities for the Oak ecosystem.
//!
//! This module provides testing infrastructure for parsers: file-based test
//! discovery, comparison against expected output stored on disk, timeout
//! handling, and serialization of test results.

use crate::{create_file, source_from_path};
use oak_core::{Language, Parser, errors::OakError};

#[cfg(feature = "serde")]
use crate::json_from_path;
#[cfg(feature = "serde")]
use serde::Serialize;

use std::{
    fmt::Debug,
    path::{Path, PathBuf},
    time::Duration,
};
use walkdir::WalkDir;

/// A file-based parser testing utility with timeout support.
///
/// The `ParserTester` walks a root directory, picks up every file whose
/// extension was registered via `with_extension`, parses it, and compares the
/// outcome against an expected result stored in a JSON file next to the source.
/// Files are processed one after another; each parse runs on a worker thread
/// that is watched against the configured timeout.
#[derive(Debug, Clone)]
pub struct ParserTester {
    /// Directory that is searched (recursively) for test inputs.
    root: PathBuf,
    /// File extensions (compared against `Path::extension`, i.e. without a leading dot) to test.
    extensions: Vec<String>,
    /// Maximum time a single parse may take before the test is reported as timed out.
    timeout: Duration,
}

33/// Expected parser test results for comparison.
34///
35/// This struct represents the expected output of a parser test, including
36/// success status, node count, AST structure, and any expected errors.
37#[derive(Debug, Clone, PartialEq)]
38#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
39pub struct ParserTestExpected {
40    /// Whether the parsing was expected to succeed.
41    pub success: bool,
42    /// The expected number of nodes in the AST.
43    pub node_count: usize,
44    /// The expected structure of the AST.
45    pub ast_structure: AstNodeData,
46    /// Any expected error messages.
47    pub errors: Vec<String>,
48}
49
/// AST node data structure for parser testing.
///
/// Represents a node in the abstract syntax tree with its kind, children,
/// text length, and leaf status, as used for comparing parser output.
///
/// When built by `ParserTester`, the leaf tokens directly under a node are not
/// listed individually: they are summarised by one synthetic child whose kind
/// is `Leaves(<count>)` and which is marked as a leaf. A failed parse is
/// represented by a single leaf node of kind `Error`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct AstNodeData {
    /// The kind of the AST node as a string (the `Debug` rendering of the node kind).
    pub kind: String,
    /// The child nodes of this AST node.
    pub children: Vec<AstNodeData>,
    /// The length of the text covered by this node.
    pub text_length: usize,
    /// Whether this node is a leaf node.
    pub is_leaf: bool,
}

67impl ParserTester {
68    /// Creates a new parser tester with the specified root directory and default 10-second timeout.
69    pub fn new<P: AsRef<Path>>(root: P) -> Self {
70        Self { root: root.as_ref().to_path_buf(), extensions: vec![], timeout: Duration::from_secs(10) }
71    }
72
73    /// Adds a file extension to test against.
74    pub fn with_extension(mut self, extension: impl ToString) -> Self {
75        self.extensions.push(extension.to_string());
76        self
77    }
78
79    /// Sets the timeout for parsing operations.
80    pub fn with_timeout(mut self, timeout: Duration) -> Self {
81        self.timeout = timeout;
82        self
83    }
84
85    /// Run tests for the given parser against all files in the root directory with the specified extensions.
86    #[cfg(feature = "serde")]
87    pub fn run_tests<L, P>(self, parser: &P) -> Result<(), OakError>
88    where
89        P: Parser<L> + Send + Sync,
90        L: Language + Send + Sync,
91        L::ElementType: serde::Serialize + Debug + Sync + Send + Eq + From<L::TokenType>,
92    {
93        let test_files = self.find_test_files()?;
94        let force_regenerated = std::env::var("REGENERATE_TESTS").unwrap_or("0".to_string()) == "1";
95        let mut regenerated_any = false;
96
97        for file_path in test_files {
98            println!("Testing file: {}", file_path.display());
99            regenerated_any |= self.test_single_file::<L, P>(&file_path, parser, force_regenerated)?;
100        }
101
102        if regenerated_any && force_regenerated {
103            println!("Tests regenerated for: {}", self.root.display());
104            Ok(())
105        }
106        else {
107            Ok(())
108        }
109    }
110
111    /// Run tests for the given parser against all files in the root directory with the specified extensions.
112    #[cfg(not(feature = "serde"))]
113    pub fn run_tests<L, P>(self, _parser: &P) -> Result<(), OakError>
114    where
115        P: Parser<L> + Send + Sync,
116        L: Language + Send + Sync,
117        L::ElementType: Debug + Sync + Send + Eq + From<L::TokenType>,
118    {
119        Ok(())
120    }
121
122    fn find_test_files(&self) -> Result<Vec<PathBuf>, OakError> {
123        let mut files = Vec::new();
124
125        for entry in WalkDir::new(&self.root) {
126            let entry = entry.unwrap();
127            let path = entry.path();
128
129            if path.is_file() {
130                if let Some(ext) = path.extension() {
131                    let ext_str = ext.to_str().unwrap_or("");
132                    if self.extensions.iter().any(|e| e == ext_str) {
133                        // Ignore output files generated by the Tester itself to prevent recursive inclusion
134                        let file_name = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
135                        let is_output_file = file_name.ends_with(".parsed.json") || file_name.ends_with(".lexed.json") || file_name.ends_with(".built.json") || file_name.ends_with(".expected.json");
136
137                        if !is_output_file {
138                            files.push(path.to_path_buf());
139                        }
140                    }
141                }
142            }
143        }
144
145        Ok(files)
146    }
147
148    #[cfg(feature = "serde")]
149    fn test_single_file<L, P>(&self, file_path: &Path, parser: &P, force_regenerated: bool) -> Result<bool, OakError>
150    where
151        P: Parser<L> + Send + Sync,
152        L: Language + Send + Sync,
153        L::ElementType: serde::Serialize + Debug + Sync + Send + Eq + From<L::TokenType>,
154    {
155        let source = source_from_path(file_path)?;
156
157        // Perform parsing in a thread and construct test results, with main thread handling timeout control
158        use std::sync::{Arc, Mutex};
159        let result: Arc<Mutex<Option<Result<ParserTestExpected, OakError>>>> = Arc::new(Mutex::new(None));
160        let result_clone = Arc::clone(&result);
161        let timeout = self.timeout;
162
163        std::thread::scope(|s| {
164            let handle = s.spawn(move || {
165                let mut cache = oak_core::parser::ParseSession::<L>::default();
166                let parse_out = parser.parse(&source, &[], &mut cache);
167
168                // Build AST structure if parse succeeded, else create a minimal error node
169                let (success, ast_structure) = match &parse_out.result {
170                    Ok(root) => {
171                        let ast = Self::to_ast::<L>(root);
172                        (true, ast)
173                    }
174                    Err(_) => {
175                        let ast = AstNodeData { kind: "Error".to_string(), children: vec![], text_length: 0, is_leaf: true };
176                        (false, ast)
177                    }
178                };
179
180                // Collect error messages
181                let mut error_messages: Vec<String> = parse_out.diagnostics.iter().map(|e| e.to_string()).collect();
182                if let Err(e) = &parse_out.result {
183                    error_messages.push(e.to_string());
184                }
185
186                // Count nodes (including leaves)
187                let node_count = Self::count_nodes(&ast_structure);
188
189                let test_result = ParserTestExpected { success, node_count, ast_structure, errors: error_messages };
190
191                let mut result = result_clone.lock().unwrap();
192                *result = Some(Ok(test_result));
193            });
194
195            // Wait for thread completion or timeout
196            let start_time = std::time::Instant::now();
197            let timeout_occurred = loop {
198                // Check if thread has finished
199                if handle.is_finished() {
200                    break false;
201                }
202
203                // Check for timeout
204                if start_time.elapsed() > timeout {
205                    break true;
206                }
207
208                // Sleep briefly to avoid busy waiting
209                std::thread::sleep(std::time::Duration::from_millis(10));
210            };
211
212            // Return error if timed out
213            if timeout_occurred {
214                return Err(OakError::custom_error(&format!("Parser test timed out after {:?} for file: {}", timeout, file_path.display())));
215            }
216
217            // Get parsing result
218            let test_result = {
219                let result_guard = result.lock().unwrap();
220                match result_guard.as_ref() {
221                    Some(Ok(test_result)) => test_result.clone(),
222                    Some(Err(e)) => return Err(e.clone()),
223                    None => return Err(OakError::custom_error("Parser test thread disconnected unexpectedly")),
224                }
225            };
226
227            let mut regenerated = false;
228            let expected_file = file_path.with_extension(format!("{}.parsed.json", file_path.extension().unwrap_or_default().to_str().unwrap_or("")));
229
230            // Migration: If the new naming convention file doesn't exist, but the old one does, rename it
231            if !expected_file.exists() {
232                let legacy_file = file_path.with_extension("expected.json");
233                if legacy_file.exists() {
234                    let _ = std::fs::rename(&legacy_file, &expected_file);
235                }
236            }
237
238            if expected_file.exists() && !force_regenerated {
239                let expected_json = json_from_path(&expected_file)?;
240                let expected: ParserTestExpected = serde_json::from_value(expected_json).map_err(|e| OakError::custom_error(e.to_string()))?;
241                if test_result != expected {
242                    return Err(OakError::test_failure(file_path.to_path_buf(), format!("{:#?}", expected), format!("{:#?}", test_result)));
243                }
244            }
245            else {
246                use std::io::Write;
247                let mut file = create_file(&expected_file)?;
248                let mut buf = Vec::new();
249                let formatter = serde_json::ser::PrettyFormatter::with_indent(b"    "); // 4 spaces indentation
250                let mut ser = serde_json::Serializer::with_formatter(&mut buf, formatter);
251                test_result.serialize(&mut ser).map_err(|e| OakError::custom_error(e.to_string()))?;
252                file.write_all(&buf).map_err(|e| OakError::custom_error(e.to_string()))?;
253
254                if force_regenerated {
255                    regenerated = true;
256                }
257                else {
258                    return Err(OakError::test_regenerated(expected_file));
259                }
260            }
261
262            Ok(regenerated)
263        })
264    }
265
266    fn to_ast<'a, L: Language>(root: &'a oak_core::GreenNode<'a, L>) -> AstNodeData {
267        let kind_str = format!("{:?}", root.kind);
268        let mut children = Vec::new();
269        let mut leaf_count: usize = 0;
270        let mut leaf_text_length: usize = 0;
271
272        for c in root.children {
273            match c {
274                oak_core::GreenTree::Node(n) => children.push(Self::to_ast(n)),
275                oak_core::GreenTree::Leaf(l) => {
276                    leaf_count += 1;
277                    leaf_text_length += l.length as usize;
278                }
279            }
280        }
281
282        if leaf_count > 0 {
283            children.push(AstNodeData { kind: format!("Leaves({})", leaf_count), children: vec![], text_length: leaf_text_length, is_leaf: true });
284        }
285
286        AstNodeData { kind: kind_str, children, text_length: root.byte_length as usize, is_leaf: false }
287    }
288
289    fn count_nodes(node: &AstNodeData) -> usize {
290        1 + node.children.iter().map(Self::count_nodes).sum::<usize>()
291    }
292}