1use std::path::{Path, PathBuf};
8
9use sipha::red::SyntaxNode;
10use sipha::types::IntoSyntaxKind;
11
12use crate::parser::parse;
13use crate::syntax::Kind;
14
/// Failure raised while loading the files reachable through `include(...)` statements.
#[derive(Debug, Clone)]
pub enum IncludeError {
    /// An included file could not be read or canonicalized; carries a human-readable message.
    Io(String),
    /// A file was reached again through its own chain of includes.
    CircularInclude {
        /// Canonical path of the file that would have been entered a second time.
        path: PathBuf,
        /// File whose include statement closed the cycle, when known.
        included_from: Option<PathBuf>,
    },
    /// An include path was rejected; carries a human-readable message.
    /// NOTE(review): not constructed in the code visible here — confirm where it is raised.
    InvalidPath(String),
}

impl std::fmt::Display for IncludeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io(message) => write!(f, "include: {message}"),
            Self::InvalidPath(message) => write!(f, "include path: {message}"),
            Self::CircularInclude {
                path,
                included_from,
            } => {
                write!(f, "circular include: {}", path.display())?;
                // The "included from" suffix is only printed when the parent is known.
                match included_from {
                    Some(parent) => write!(f, " (included from {})", parent.display()),
                    None => Ok(()),
                }
            }
        }
    }
}

impl std::error::Error for IncludeError {}
50
51#[derive(Debug, Clone)]
53pub struct IncludeTree {
54 pub path: PathBuf,
56 pub source: String,
58 pub root: Option<SyntaxNode>,
60 pub includes: Vec<(PathBuf, IncludeTree)>,
62}
63
64impl IncludeTree {
65 #[must_use]
67 pub fn root_for_path(&self, main_path: &Path, path: &Path) -> Option<&SyntaxNode> {
68 if path == main_path {
69 return self.root.as_ref();
70 }
71 for (p, child) in &self.includes {
72 if p.as_path() == path {
73 return child.root.as_ref();
74 }
75 }
76 None
77 }
78
79 #[must_use]
81 pub fn source_for_path(&self, main_path: &Path, path: &Path) -> Option<&str> {
82 if path == main_path {
83 return Some(self.source.as_str());
84 }
85 for (p, child) in &self.includes {
86 if p.as_path() == path {
87 return Some(child.source.as_str());
88 }
89 }
90 None
91 }
92}
93
94pub fn build_include_tree(
99 source: &str,
100 base_path: Option<&Path>,
101) -> Result<IncludeTree, IncludeError> {
102 let base_dir = base_path
103 .and_then(|p| p.parent())
104 .unwrap_or_else(|| Path::new("."));
105 let current_path = base_path.map(Path::to_path_buf).unwrap_or_default();
106 let mut visited = std::collections::HashSet::new();
107 build_include_tree_impl(source, base_dir, current_path.as_path(), &mut visited)
108}
109
110fn build_include_tree_impl(
111 source: &str,
112 base_dir: &Path,
113 current_path: &Path,
114 visited: &mut std::collections::HashSet<PathBuf>,
115) -> Result<IncludeTree, IncludeError> {
116 let root = parse(source).ok().flatten();
117
118 let include_paths = root
119 .as_ref()
120 .map(|r| collect_include_paths(r, source))
121 .unwrap_or_default();
122
123 let mut includes = Vec::with_capacity(include_paths.len());
124 for path_str in include_paths {
125 let resolved = resolve_path(base_dir, &path_str);
126 let content = std::fs::read_to_string(&resolved).map_err(|e| {
127 let msg = match e.kind() {
128 std::io::ErrorKind::NotFound => format!("file not found: {}", resolved.display()),
129 std::io::ErrorKind::PermissionDenied => {
130 format!("permission denied: {}", resolved.display())
131 }
132 _ => format!("{}: {}", resolved.display(), e),
133 };
134 IncludeError::Io(msg)
135 })?;
136 let canonical = resolved
137 .canonicalize()
138 .map_err(|e| IncludeError::Io(format!("{}: {}", resolved.display(), e)))?;
139 if !visited.insert(canonical.clone()) {
140 return Err(IncludeError::CircularInclude {
141 path: canonical,
142 included_from: Some(current_path.to_path_buf()),
143 });
144 }
145 let child_base = resolved.parent().unwrap_or(base_dir);
146 let child_tree =
147 build_include_tree_impl(&content, child_base, resolved.as_path(), visited)?;
148 visited.remove(&canonical);
149 includes.push((resolved, child_tree));
150 }
151
152 Ok(IncludeTree {
153 path: current_path.to_path_buf(),
154 source: source.to_string(),
155 root,
156 includes,
157 })
158}
159
160fn collect_include_paths(root: &SyntaxNode, source: &str) -> Vec<String> {
162 let bytes = source.as_bytes();
163 let mut out = Vec::new();
164 for node in root.find_all_nodes(Kind::NodeInclude.into_syntax_kind()) {
165 if let Some(path) = include_path_from_node(&node, bytes) {
166 out.push(path);
167 }
168 }
169 out
170}
171
172#[must_use]
175pub fn collect_include_path_ranges(root: &SyntaxNode, source: &str) -> Vec<(u32, u32, String)> {
176 let bytes = source.as_bytes();
177 let mut out = Vec::new();
178 for node in root.find_all_nodes(Kind::NodeInclude.into_syntax_kind()) {
179 let token = node
180 .descendant_tokens()
181 .into_iter()
182 .find(|t| t.kind_as::<Kind>() == Some(Kind::TokString));
183 if let (Some(t), Some(path)) = (token, include_path_from_node(&node, bytes)) {
184 let range = t.text_range();
185 out.push((range.start, range.end, path));
186 }
187 }
188 out
189}
190
191fn include_path_from_node(node: &SyntaxNode, source_bytes: &[u8]) -> Option<String> {
193 let token = node
194 .descendant_tokens()
195 .into_iter()
196 .find(|t| t.kind_as::<Kind>() == Some(Kind::TokString))?;
197 let range = token.text_range();
198 let start = range.start as usize;
199 if start >= source_bytes.len() {
200 return None;
201 }
202 parse_include_string(source_bytes, start).map(|(s, _)| s)
203}
204
/// Decodes the quoted string literal that starts at byte offset `i` of `bytes`.
///
/// The literal may be delimited by `"` or `'`. Supported escapes are `\n`, `\t`, `\r`,
/// `\"`, `\'`, `\\` and `\uXXXX` (exactly four hex digits). Any other escaped character is
/// kept as-is without its backslash. Non-ASCII text is decoded as UTF-8; invalid sequences
/// become U+FFFD.
///
/// Returns the decoded text and the offset just past the closing quote. Returns `None` when
/// `i` is out of bounds, the byte at `i` is not a quote, the literal is unterminated, or a
/// `\uXXXX` escape is malformed or does not denote a `char` (e.g. a lone surrogate).
fn parse_include_string(bytes: &[u8], i: usize) -> Option<(String, usize)> {
    let quote = *bytes.get(i)?;
    if quote != b'"' && quote != b'\'' {
        return None;
    }
    let mut out = String::new();
    let mut j = i + 1;
    while j < bytes.len() {
        if bytes[j] == b'\\' && j + 1 < bytes.len() {
            let esc = j + 1;
            let simple = match bytes[esc] {
                b'n' => Some('\n'),
                b't' => Some('\t'),
                b'r' => Some('\r'),
                b'"' => Some('"'),
                b'\'' => Some('\''),
                b'\\' => Some('\\'),
                _ => None,
            };
            if let Some(c) = simple {
                out.push(c);
                j = esc + 1;
            } else if bytes[esc] == b'u' && esc + 4 < bytes.len() {
                // `\uXXXX`: require exactly four hex digits. `from_str_radix` alone would
                // also accept a leading `+`.
                let digits = &bytes[esc + 1..esc + 5];
                if !digits.iter().all(u8::is_ascii_hexdigit) {
                    return None;
                }
                let hex = std::str::from_utf8(digits).ok()?;
                out.push(char::from_u32(u32::from_str_radix(hex, 16).ok()?)?);
                j = esc + 5;
            } else {
                // Unknown escape: keep the escaped character. Decode all of its UTF-8 bytes
                // so a multi-byte character is not split in half.
                let end = next_char_boundary(bytes, esc);
                out.push_str(&String::from_utf8_lossy(&bytes[esc..end]));
                j = end;
            }
            continue;
        }
        if bytes[j] == quote {
            return Some((out, j + 1));
        }
        // Copy one whole UTF-8 character. Pushing the lead byte `as char` would reinterpret
        // it as Latin-1 and silently drop its continuation bytes.
        let end = next_char_boundary(bytes, j);
        out.push_str(&String::from_utf8_lossy(&bytes[j..end]));
        j = end;
    }
    None
}

/// Returns the index just past the UTF-8 character whose first byte is at `i`.
///
/// An ASCII byte advances by one. Any other byte advances past itself and every following
/// continuation byte (`0b10xx_xxxx`). An index at or beyond the end returns `bytes.len()`.
fn next_char_boundary(bytes: &[u8], i: usize) -> usize {
    match bytes.get(i) {
        None => bytes.len(),
        Some(b) if b.is_ascii() => i + 1,
        Some(_) => {
            let continuation = bytes[i + 1..]
                .iter()
                .take_while(|&&b| (b & 0xC0) == 0x80)
                .count();
            i + 1 + continuation
        }
    }
}
258
/// Resolves an include string against `base_dir`.
///
/// Relative paths are appended to `base_dir`. Absolute paths come back unchanged, because
/// `Path::join` replaces the base when its argument is absolute.
fn resolve_path(base_dir: &Path, path_str: &str) -> PathBuf {
    base_dir.join(path_str)
}
267
268#[must_use]
270pub fn all_files(tree: &IncludeTree) -> Vec<(PathBuf, &str)> {
271 let mut out = vec![(tree.path.clone(), tree.source.as_str())];
272 for (_, child) in &tree.includes {
273 out.extend(all_files(child));
274 }
275 out
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    /// Per-process scratch directory under the system temp dir, removed on drop.
    ///
    /// Removal on drop means a failing assertion does not leave stale fixtures behind. The
    /// process id in the name keeps concurrent test runs from clobbering each other's files.
    struct ScratchDir(std::path::PathBuf);

    impl ScratchDir {
        fn new(name: &str) -> Self {
            let dir = std::env::temp_dir().join(format!("{name}_{}", std::process::id()));
            let _ = std::fs::remove_dir_all(&dir);
            std::fs::create_dir_all(&dir).expect("create scratch dir");
            ScratchDir(dir)
        }

        fn path(&self) -> &std::path::Path {
            &self.0
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn build_tree_inlines_nothing_but_parses_includes() {
        let scratch = ScratchDir::new("leekscript_include_test");
        let main_path = scratch.path().join("main.leek");
        let lib_path = scratch.path().join("lib.leek");
        std::fs::write(&lib_path, "var x = 42;\n").unwrap();
        std::fs::write(&main_path, "include(\"lib.leek\");\nreturn 0;\n").unwrap();
        let source = std::fs::read_to_string(&main_path).unwrap();
        let tree = build_include_tree(&source, Some(main_path.as_path())).unwrap();
        assert!(tree.root.is_some(), "main should parse");
        assert_eq!(tree.path, main_path);
        assert_eq!(tree.includes.len(), 1, "one include");
        assert_eq!(tree.includes[0].0, lib_path);
        assert_eq!(tree.includes[0].1.path, lib_path);
        assert!(tree.includes[0].1.source.contains("var x = 42"));
        assert_eq!(
            tree.source_for_path(&main_path, &lib_path),
            Some("var x = 42;\n")
        );
        let listed: Vec<std::path::PathBuf> =
            all_files(&tree).into_iter().map(|(p, _)| p).collect();
        assert_eq!(listed, vec![main_path, lib_path]);
    }

    #[test]
    fn include_path_ranges_point_at_string_literal() {
        let source = "include(\"lib.leek\");\nreturn 0;\n";
        let root = parse(source).ok().flatten().expect("source should parse");
        let ranges = collect_include_path_ranges(&root, source);
        assert_eq!(ranges.len(), 1);
        assert_eq!(ranges[0].0, 8, "range starts at the opening quote");
        assert!(ranges[0].1 > ranges[0].0);
        assert_eq!(ranges[0].2, "lib.leek");
    }

    #[test]
    fn missing_include_reports_io_error() {
        let scratch = ScratchDir::new("leekscript_missing_include_test");
        let main_path = scratch.path().join("main.leek");
        let source = "include(\"missing.leek\");\n";
        match build_include_tree(source, Some(main_path.as_path())) {
            Err(IncludeError::Io(msg)) => assert!(msg.contains("file not found"), "{msg}"),
            other => panic!("expected Io error, got {other:?}"),
        }
    }

    #[test]
    fn circular_include_errors() {
        let scratch = ScratchDir::new("leekscript_circular_test");
        let a_path = scratch.path().join("a.leek");
        let b_path = scratch.path().join("b.leek");
        std::fs::write(&a_path, "include(\"b.leek\");\n").unwrap();
        std::fs::write(&b_path, "include(\"a.leek\");\n").unwrap();
        let source = std::fs::read_to_string(&a_path).unwrap();
        let result = build_include_tree(&source, Some(a_path.as_path()));
        assert!(
            matches!(result, Err(IncludeError::CircularInclude { .. })),
            "expected CircularInclude: {result:?}"
        );
    }

    #[test]
    fn self_include_errors() {
        let scratch = ScratchDir::new("leekscript_self_include_test");
        let a_path = scratch.path().join("a.leek");
        std::fs::write(&a_path, "include(\"a.leek\");\n").unwrap();
        let source = std::fs::read_to_string(&a_path).unwrap();
        let result = build_include_tree(&source, Some(a_path.as_path()));
        assert!(
            matches!(result, Err(IncludeError::CircularInclude { .. })),
            "expected CircularInclude: {result:?}"
        );
    }
}