Skip to main content

leekscript_core/
preprocess.rs

1//! Include handling: parse files and build a tree of (path, source, AST) with circular include detection.
2//!
3//! Paths are resolved relative to the directory of the current file (or the current
4//! working directory when parsing from stdin). Circular includes are detected and
5//! reported as errors. No source expansion: each file keeps its own AST.
6
7use std::path::{Path, PathBuf};
8
9use sipha::red::SyntaxNode;
10use sipha::types::IntoSyntaxKind;
11
12use crate::parser::parse;
13use crate::syntax::Kind;
14
/// Error from the include preprocessor.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IncludeError {
    /// File could not be read (e.g. not found, permission denied). The message includes the
    /// path and the reason.
    Io(String),
    /// Circular include detected. `path` is the file that was included again; `included_from`
    /// is the file that requested it (when known).
    CircularInclude {
        /// Canonical path of the file whose inclusion would close the cycle.
        path: PathBuf,
        /// File that contained the `include(...)` leading to the cycle.
        included_from: Option<PathBuf>,
    },
    /// Invalid include path (e.g. outside an allowed base directory).
    InvalidPath(String),
}

impl std::fmt::Display for IncludeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            IncludeError::Io(msg) => write!(f, "include: {msg}"),
            IncludeError::CircularInclude {
                path,
                included_from,
            } => {
                write!(f, "circular include: {}", path.display())?;
                if let Some(from) = included_from {
                    write!(f, " (included from {})", from.display())?;
                }
                Ok(())
            }
            IncludeError::InvalidPath(msg) => write!(f, "include path: {msg}"),
        }
    }
}

impl std::error::Error for IncludeError {}
50
51/// One file in the include tree: path, source, parsed root, and parsed included files in order.
52#[derive(Debug, Clone)]
53pub struct IncludeTree {
54    /// Resolved path of this file (empty if from stdin / no path).
55    pub path: PathBuf,
56    /// Full source of this file.
57    pub source: String,
58    /// Parsed AST (program root); None if parse failed or file is empty.
59    pub root: Option<SyntaxNode>,
60    /// Resolved path and subtree for each `include("...")` in order.
61    pub includes: Vec<(PathBuf, IncludeTree)>,
62}
63
64impl IncludeTree {
65    /// Root AST for a path within this tree (main file or an included file).
66    #[must_use]
67    pub fn root_for_path(&self, main_path: &Path, path: &Path) -> Option<&SyntaxNode> {
68        if path == main_path {
69            return self.root.as_ref();
70        }
71        for (p, child) in &self.includes {
72            if p.as_path() == path {
73                return child.root.as_ref();
74            }
75        }
76        None
77    }
78
79    /// Source for a path within this tree (main file or an included file).
80    #[must_use]
81    pub fn source_for_path(&self, main_path: &Path, path: &Path) -> Option<&str> {
82        if path == main_path {
83            return Some(self.source.as_str());
84        }
85        for (p, child) in &self.includes {
86            if p.as_path() == path {
87                return Some(child.source.as_str());
88            }
89        }
90        None
91    }
92}
93
94/// Build the include tree: parse `source` as the main file, resolve each `include("path")`,
95/// load and parse those files (with circular include detection), and return the tree.
96///
97/// If `base_path` is `None` (e.g. stdin), the current working directory is used to resolve includes.
98pub fn build_include_tree(
99    source: &str,
100    base_path: Option<&Path>,
101) -> Result<IncludeTree, IncludeError> {
102    let base_dir = base_path
103        .and_then(|p| p.parent())
104        .unwrap_or_else(|| Path::new("."));
105    let current_path = base_path.map(Path::to_path_buf).unwrap_or_default();
106    let mut visited = std::collections::HashSet::new();
107    build_include_tree_impl(source, base_dir, current_path.as_path(), &mut visited)
108}
109
110fn build_include_tree_impl(
111    source: &str,
112    base_dir: &Path,
113    current_path: &Path,
114    visited: &mut std::collections::HashSet<PathBuf>,
115) -> Result<IncludeTree, IncludeError> {
116    let root = parse(source).ok().flatten();
117
118    let include_paths = root
119        .as_ref()
120        .map(|r| collect_include_paths(r, source))
121        .unwrap_or_default();
122
123    let mut includes = Vec::with_capacity(include_paths.len());
124    for path_str in include_paths {
125        let resolved = resolve_path(base_dir, &path_str);
126        let content = std::fs::read_to_string(&resolved).map_err(|e| {
127            let msg = match e.kind() {
128                std::io::ErrorKind::NotFound => format!("file not found: {}", resolved.display()),
129                std::io::ErrorKind::PermissionDenied => {
130                    format!("permission denied: {}", resolved.display())
131                }
132                _ => format!("{}: {}", resolved.display(), e),
133            };
134            IncludeError::Io(msg)
135        })?;
136        let canonical = resolved
137            .canonicalize()
138            .map_err(|e| IncludeError::Io(format!("{}: {}", resolved.display(), e)))?;
139        if !visited.insert(canonical.clone()) {
140            return Err(IncludeError::CircularInclude {
141                path: canonical,
142                included_from: Some(current_path.to_path_buf()),
143            });
144        }
145        let child_base = resolved.parent().unwrap_or(base_dir);
146        let child_tree =
147            build_include_tree_impl(&content, child_base, resolved.as_path(), visited)?;
148        visited.remove(&canonical);
149        includes.push((resolved, child_tree));
150    }
151
152    Ok(IncludeTree {
153        path: current_path.to_path_buf(),
154        source: source.to_string(),
155        root,
156        includes,
157    })
158}
159
160/// Collect include path strings from a program root (order of `include("...")` in the file).
161fn collect_include_paths(root: &SyntaxNode, source: &str) -> Vec<String> {
162    let bytes = source.as_bytes();
163    let mut out = Vec::new();
164    for node in root.find_all_nodes(Kind::NodeInclude.into_syntax_kind()) {
165        if let Some(path) = include_path_from_node(&node, bytes) {
166            out.push(path);
167        }
168    }
169    out
170}
171
172/// Collect (`start_byte`, `end_byte`, `path_string`) for each `include("...")` in the file.
173/// Used by the LSP to provide document links and go-to-definition on include paths.
174#[must_use]
175pub fn collect_include_path_ranges(root: &SyntaxNode, source: &str) -> Vec<(u32, u32, String)> {
176    let bytes = source.as_bytes();
177    let mut out = Vec::new();
178    for node in root.find_all_nodes(Kind::NodeInclude.into_syntax_kind()) {
179        let token = node
180            .descendant_tokens()
181            .into_iter()
182            .find(|t| t.kind_as::<Kind>() == Some(Kind::TokString));
183        if let (Some(t), Some(path)) = (token, include_path_from_node(&node, bytes)) {
184            let range = t.text_range();
185            out.push((range.start, range.end, path));
186        }
187    }
188    out
189}
190
191/// Extract the path string from a `NodeInclude` (first `TokString` token, unquoted).
192fn include_path_from_node(node: &SyntaxNode, source_bytes: &[u8]) -> Option<String> {
193    let token = node
194        .descendant_tokens()
195        .into_iter()
196        .find(|t| t.kind_as::<Kind>() == Some(Kind::TokString))?;
197    let range = token.text_range();
198    let start = range.start as usize;
199    if start >= source_bytes.len() {
200        return None;
201    }
202    parse_include_string(source_bytes, start).map(|(s, _)| s)
203}
204
/// Parse a quoted string literal whose opening quote is at byte offset `i` of `bytes`.
///
/// `bytes[i]` must be `"` or `'`; the literal ends at the next unescaped occurrence of the
/// same quote. Recognised escapes are `\n`, `\t`, `\r`, `\"`, `\'`, `\\` and `\uXXXX`
/// (exactly four hex digits); any other escaped character is kept without its backslash.
/// Non-ASCII characters are decoded as whole UTF-8 sequences.
///
/// Returns the unescaped contents and the byte offset just past the closing quote, or `None`
/// if `i` is out of range, `bytes[i]` is not a quote, the literal is unterminated, a `\u`
/// escape is malformed, or the bytes are not valid UTF-8.
fn parse_include_string(bytes: &[u8], i: usize) -> Option<(String, usize)> {
    /// Decode the UTF-8 character starting at `start`; returns it and the offset just past it.
    fn char_at(bytes: &[u8], start: usize) -> Option<(char, usize)> {
        let end = next_char_boundary(bytes, start);
        let ch = std::str::from_utf8(&bytes[start..end]).ok()?.chars().next()?;
        Some((ch, end))
    }

    let &quote = bytes.get(i)?;
    if quote != b'"' && quote != b'\'' {
        return None;
    }
    let mut out = String::new();
    let mut j = i + 1;
    while j < bytes.len() {
        let b = bytes[j];
        if b == b'\\' && j + 1 < bytes.len() {
            let (ch, next) = match bytes[j + 1] {
                b'n' => ('\n', j + 2),
                b't' => ('\t', j + 2),
                b'r' => ('\r', j + 2),
                b'"' => ('"', j + 2),
                b'\'' => ('\'', j + 2),
                b'\\' => ('\\', j + 2),
                b'u' if j + 5 < bytes.len() => {
                    let hex = &bytes[j + 2..j + 6];
                    // `from_str_radix` would also accept a leading `+`; require 4 hex digits.
                    if !hex.iter().all(u8::is_ascii_hexdigit) {
                        return None;
                    }
                    let code = u32::from_str_radix(std::str::from_utf8(hex).ok()?, 16).ok()?;
                    (char::from_u32(code)?, j + 6)
                }
                // Unknown escape: keep the escaped character itself, decoded in full so a
                // multi-byte character is not split.
                _ => char_at(bytes, j + 1)?,
            };
            out.push(ch);
            j = next;
        } else if b == quote {
            return Some((out, j + 1));
        } else {
            let (ch, next) = char_at(bytes, j)?;
            out.push(ch);
            j = next;
        }
    }
    None
}

/// Index just past the UTF-8 character whose first byte is at `i`.
///
/// Returns `i + 1` for an ASCII byte; otherwise skips every following continuation byte
/// (`0b10xx_xxxx`). Returns `bytes.len()` when `i` is out of range.
fn next_char_boundary(bytes: &[u8], i: usize) -> usize {
    let Some(&lead) = bytes.get(i) else {
        return bytes.len();
    };
    if lead.is_ascii() {
        return i + 1;
    }
    bytes[i + 1..]
        .iter()
        .position(|&b| (b & 0xC0) != 0x80)
        .map_or(bytes.len(), |offset| i + 1 + offset)
}
258
/// Resolve an include path: absolute paths are used verbatim, relative ones are joined onto
/// `base_dir`. No normalization or existence check is performed.
fn resolve_path(base_dir: &Path, path_str: &str) -> PathBuf {
    let requested = Path::new(path_str);
    if requested.is_relative() {
        base_dir.join(requested)
    } else {
        requested.to_path_buf()
    }
}
267
268/// Flatten the tree into (path, source) for all files (root first, then includes depth-first).
269#[must_use]
270pub fn all_files(tree: &IncludeTree) -> Vec<(PathBuf, &str)> {
271    let mut out = vec![(tree.path.clone(), tree.source.as_str())];
272    for (_, child) in &tree.includes {
273        out.extend(all_files(child));
274    }
275    out
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    /// Scratch directory that is removed on drop, so cleanup also runs when an assertion
    /// fails. The process id in its name keeps concurrent test runs from sharing files.
    struct ScratchDir(PathBuf);

    impl ScratchDir {
        fn new(name: &str) -> Self {
            let dir = std::env::temp_dir().join(format!("{name}_{}", std::process::id()));
            // Leftovers from an aborted earlier run would make results unpredictable.
            let _ = std::fs::remove_dir_all(&dir);
            std::fs::create_dir_all(&dir).expect("create scratch dir");
            ScratchDir(dir)
        }

        /// Write `contents` to `file` inside the scratch dir and return its path.
        fn write(&self, file: &str, contents: &str) -> PathBuf {
            let path = self.0.join(file);
            std::fs::write(&path, contents).expect("write scratch file");
            path
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn build_tree_inlines_nothing_but_parses_includes() {
        let dir = ScratchDir::new("leekscript_include_test");
        let lib_path = dir.write("lib.leek", "var x = 42;\n");
        let main_path = dir.write("main.leek", "include(\"lib.leek\");\nreturn 0;\n");
        let source = std::fs::read_to_string(&main_path).unwrap();
        let tree = build_include_tree(&source, Some(main_path.as_path())).unwrap();
        assert!(tree.root.is_some(), "main should parse");
        assert_eq!(tree.path, main_path);
        // The main source is kept verbatim: includes are not inlined.
        assert_eq!(tree.source, source);
        assert_eq!(tree.includes.len(), 1, "one include");
        assert_eq!(tree.includes[0].0, lib_path);
        assert!(tree.includes[0].1.source.contains("var x = 42"));
    }

    #[test]
    fn circular_include_errors() {
        let dir = ScratchDir::new("leekscript_circular_test");
        let a_path = dir.write("a.leek", "include(\"b.leek\");\n");
        dir.write("b.leek", "include(\"a.leek\");\n");
        let source = std::fs::read_to_string(&a_path).unwrap();
        let result = build_include_tree(&source, Some(a_path.as_path()));
        assert!(
            matches!(result, Err(IncludeError::CircularInclude { .. })),
            "expected CircularInclude: {:?}",
            result
        );
    }

    #[test]
    fn self_include_errors() {
        let dir = ScratchDir::new("leekscript_self_include_test");
        let a_path = dir.write("a.leek", "include(\"a.leek\");\n");
        let source = std::fs::read_to_string(&a_path).unwrap();
        let result = build_include_tree(&source, Some(a_path.as_path()));
        assert!(
            matches!(result, Err(IncludeError::CircularInclude { .. })),
            "expected CircularInclude: {:?}",
            result
        );
    }

    #[test]
    fn missing_include_is_io_error() {
        let dir = ScratchDir::new("leekscript_missing_include_test");
        let main_path = dir.write("main.leek", "include(\"missing.leek\");\n");
        let source = std::fs::read_to_string(&main_path).unwrap();
        match build_include_tree(&source, Some(main_path.as_path())) {
            Err(IncludeError::Io(msg)) => {
                assert!(msg.contains("file not found"), "unexpected message: {msg}");
            }
            other => panic!("expected Io error, got {other:?}"),
        }
    }

    #[test]
    fn shared_include_in_sibling_branches_is_not_circular() {
        let dir = ScratchDir::new("leekscript_diamond_test");
        dir.write("shared.leek", "var shared = 1;\n");
        dir.write("b.leek", "include(\"shared.leek\");\n");
        dir.write("c.leek", "include(\"shared.leek\");\n");
        let a_path = dir.write("a.leek", "include(\"b.leek\");\ninclude(\"c.leek\");\n");
        let source = std::fs::read_to_string(&a_path).unwrap();
        let tree = build_include_tree(&source, Some(a_path.as_path()))
            .expect("a diamond-shaped include graph is not a cycle");
        assert_eq!(tree.includes.len(), 2);
        assert!(tree.includes.iter().all(|(_, child)| child.includes.len() == 1));
        // a, b, shared, c, shared
        assert_eq!(all_files(&tree).len(), 5);
    }
}