1use std::{
4 collections::{HashMap, HashSet},
5 ops::Range,
6 path::{Path, PathBuf},
7 sync::Arc,
8};
9
10use super::source::{Source, SourceLoadError, SourceLoader, SourceResolver};
11use super::{FileId, ParseTree, Parser, SourceList, SourceMap};
12use crate::{
13 Diagnostic, DiagnosticSet, GlyphMap, Kind, Node,
14 token_tree::{
15 AstSink,
16 typed::{self, AstNode as _},
17 },
18};
19
/// Maximum nesting depth of `include` statements before we report an error.
const MAX_INCLUDE_DEPTH: usize = 50;
21
/// State accumulated while parsing a root source file and all of its includes.
#[derive(Debug)]
pub(crate) struct ParseContext {
    // The id of the file where parsing started.
    root_id: FileId,
    // All sources that were loaded during parsing.
    sources: Arc<SourceList>,
    // Per-file parse results: the root node plus any diagnostics.
    parsed_files: HashMap<FileId, (Node, Vec<Diagnostic>)>,
    // Which files each file includes; used for cycle/depth validation.
    graph: IncludeGraph,
}
62
/// A directed graph of `include` relationships between files.
///
/// Each edge records the included file's id together with the source-text
/// range of the include statement that produced it, in source order.
#[derive(Clone, Debug, Default)]
struct IncludeGraph {
    nodes: HashMap<FileId, Vec<(FileId, Range<usize>)>>,
}
73
/// An `include` statement encountered while parsing a source file.
pub struct IncludeStatement {
    // The typed AST node for the statement itself.
    pub(crate) stmt: typed::Include,
    // The syntactic scope the statement appeared in (e.g. top level vs.
    // inside a feature block); determines how the included file is parsed.
    pub(crate) scope: Kind,
}
80
/// A problem (cycle or excessive depth) found while validating the include graph.
struct IncludeError {
    // The file containing the offending include statement.
    file: FileId,
    // Index of the statement within that file's include list.
    statement_idx: usize,
    // Source range of the statement, for diagnostics.
    range: Range<usize>,
    kind: IncludeErrorKind,
}
88
/// The kind of include-graph validation failure.
enum IncludeErrorKind {
    /// A file transitively includes itself.
    Cycle,
    /// The include chain exceeds `MAX_INCLUDE_DEPTH`.
    // NOTE(review): variant name is a typo for `TooDeep`; left unchanged here
    // because it is referenced elsewhere in this file.
    ToDeep,
}
93
94impl IncludeStatement {
95 fn path(&self) -> &str {
99 &self.stmt.path().text
100 }
101
102 fn stmt_range(&self) -> Range<usize> {
104 self.stmt.range()
105 }
106
107 fn path_range(&self) -> Range<usize> {
109 self.stmt.path().range()
110 }
111}
112
113impl ParseContext {
114 pub(crate) fn parse(
124 path: PathBuf,
125 glyph_map: Option<&GlyphMap>,
126 resolver: Box<dyn SourceResolver>,
127 ) -> Result<Self, SourceLoadError> {
128 let mut sources = SourceLoader::new(resolver);
129 let root_id = sources.source_for_path(&path, None)?;
130 let mut queue = vec![(root_id, Kind::SourceFile)];
131 let mut parsed_files = HashMap::new();
132 let mut includes = IncludeGraph::default();
133
134 while let Some((id, scope)) = queue.pop() {
135 if parsed_files.contains_key(&id) {
137 continue;
138 }
139 let source = sources.get(&id).unwrap();
140 let (node, mut errors, include_stmts) = parse_src(source, glyph_map, scope);
141 errors.iter_mut().for_each(|e| e.message.file = id);
142
143 parsed_files.insert(source.id(), (node, errors));
144 if include_stmts.is_empty() {
145 continue;
146 }
147
148 let source_id = source.id();
150
151 for include in &include_stmts {
152 match sources.source_for_path(Path::new(include.path()), Some(source_id)) {
153 Ok(included_id) => {
154 includes.add_edge(id, (included_id, include.stmt_range()));
155 queue.push((included_id, include.scope));
156 }
157 Err(e) => {
158 let range = include.path_range();
159 parsed_files.get_mut(&id).unwrap().1.push(Diagnostic::error(
160 id,
161 range,
162 e.to_string(),
163 ));
164 }
165 }
166 }
167 }
168
169 Ok(ParseContext {
170 root_id,
171 sources: sources.into_inner(),
172 parsed_files,
173 graph: includes,
174 })
175 }
176
177 pub(crate) fn root_id(&self) -> FileId {
178 self.root_id
179 }
180
181 pub(crate) fn generate_parse_tree(self) -> (ParseTree, DiagnosticSet) {
185 let mut all_errors = self
186 .parsed_files
187 .iter()
188 .flat_map(|(_, (_, errs))| errs.iter())
189 .cloned()
190 .collect::<Vec<_>>();
191 let include_errors = self.graph.validate(self.root_id());
192 for IncludeError {
194 file, range, kind, ..
195 } in &include_errors
196 {
197 let message = match kind {
199 IncludeErrorKind::Cycle => "cyclical include statement",
200 IncludeErrorKind::ToDeep => "exceded maximum include depth",
201 };
202 all_errors.push(Diagnostic::error(*file, range.clone(), message));
203 }
204
205 let mut map = SourceMap::default();
206 let mut root = self.generate_recurse(self.root_id(), &include_errors, &mut map, 0);
207 let needs_update_positions = self.parsed_files.len() > 1;
208 drop(self.parsed_files);
211 if needs_update_positions {
212 root.update_positions_from_root();
213 }
214
215 let diagnostics = DiagnosticSet {
216 messages: all_errors,
217 sources: self.sources.clone(),
218 max_to_print: usize::MAX,
219 };
220
221 (
222 ParseTree {
223 root,
224 map: Arc::new(map),
225 sources: self.sources,
226 },
227 diagnostics,
228 )
229 }
230
231 fn generate_recurse(
237 &self,
238 id: FileId,
239 skip: &[IncludeError],
240 source_map: &mut SourceMap,
241 offset: usize,
242 ) -> Node {
243 let this_node = self.parsed_files[&id].0.clone();
244 let self_len = this_node.text_len();
245 let mut self_pos = 0;
246 let mut global_pos = offset;
247 let this_node = match self.graph.includes_for_file(id) {
248 Some(includes) => {
249 let mut edits = Vec::with_capacity(includes.len());
250
251 for (i, (child_id, stmt)) in includes.iter().enumerate() {
252 if skip
253 .iter()
254 .any(|err| err.file == id && err.statement_idx == i)
255 {
256 continue;
257 }
258 let pre_len = stmt.start - self_pos;
260 let pre_range = global_pos..global_pos + pre_len;
261 source_map.add_entry(pre_range, (id, self_pos));
262 self_pos = stmt.end;
263 global_pos += pre_len;
264 let child_node = self.generate_recurse(*child_id, skip, source_map, global_pos);
265 global_pos += child_node.text_len();
266 edits.push((stmt.clone(), child_node));
267 }
268 this_node.edit(edits, true)
269 }
270 None => this_node,
271 };
272 let remain_len = self_len - self_pos;
274 let remaining_range = global_pos..global_pos + remain_len;
275 source_map.add_entry(remaining_range, (id, self_pos));
276 this_node
277 }
278}
279
impl IncludeGraph {
    /// Record that `from` includes the file in `to`, along with the source
    /// range of the include statement that caused it.
    fn add_edge(&mut self, from: FileId, to: (FileId, Range<usize>)) {
        self.nodes.entry(from).or_default().push(to);
    }

    /// The includes recorded for `file`, in insertion order, if any.
    fn includes_for_file(&self, file: FileId) -> Option<&[(FileId, Range<usize>)]> {
        self.nodes.get(&file).map(|f| f.as_slice())
    }

    /// Walk the include graph from `root`, returning an error for each
    /// include statement that closes a cycle or nests too deeply.
    fn validate(&self, root: FileId) -> Vec<IncludeError> {
        let edges = match self.nodes.get(&root) {
            None => return Vec::new(),
            Some(edges) => edges,
        };

        // Iterative depth-first traversal. Each stack entry is a file, its
        // include edges, and the index of the next edge to visit; the stack
        // thus also serves as the chain of ancestors of the current node.
        let mut stack = vec![(root, edges, 0_usize)];
        let mut seen = HashSet::new();
        let mut bad_edges = Vec::new();

        while let Some((node, edges, cur_edge)) = stack.pop() {
            if let Some((child, stmt)) = edges.get(cur_edge) {
                // Re-push this node so we resume with its next edge once the
                // child subtree has been handled.
                stack.push((node, edges, cur_edge + 1));
                // NOTE(review): stack length is used as a proxy for include
                // depth; presumably `- 1` accounts for the entry just
                // re-pushed above — confirm against MAX_INCLUDE_DEPTH intent.
                if stack.len() >= MAX_INCLUDE_DEPTH - 1 {
                    bad_edges.push(IncludeError {
                        file: node,
                        statement_idx: cur_edge,
                        range: stmt.clone(),
                        kind: IncludeErrorKind::ToDeep,
                    });
                    continue;
                }

                // Only descend into files not yet visited.
                if seen.insert(*child) {
                    if let Some(child_edges) = self.nodes.get(child) {
                        stack.push((*child, child_edges, 0));
                    }
                // Previously seen child: if it is still on the stack it is an
                // ancestor of the current node, so this edge closes a cycle.
                } else if stack.iter().any(|(ancestor, _, _)| ancestor == child) {
                    bad_edges.push(IncludeError {
                        file: node,
                        statement_idx: cur_edge,
                        range: stmt.clone(),
                        kind: IncludeErrorKind::Cycle,
                    });
                }
            }
        }
        bad_edges
    }
}
337
338fn parse_src(
340 src: &Source,
341 glyph_map: Option<&GlyphMap>,
342 scope: Kind,
343) -> (Node, Vec<Diagnostic>, Vec<IncludeStatement>) {
344 let mut sink = AstSink::new(src.text(), src.id(), glyph_map);
345 {
346 let mut parser = Parser::new(src.text(), &mut sink);
347 match scope {
348 Kind::FeatureNode => {
349 parser.start_node(Kind::SourceFile);
350 super::grammar::eat_feature_block_items(&mut parser);
351 parser.eat_trivia();
352 parser.finish_node();
353 }
354 Kind::SourceFile => super::grammar::root(&mut parser),
355 other => {
356 log::warn!("encountered include statement in unhandled scope '{other}'");
357 super::grammar::root(&mut parser);
359 }
360 }
361 }
362 sink.finish()
363}
364
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Kind,
        token_tree::{TreeBuilder, typed},
    };

    /// Create `N` distinct file ids for use in tests.
    fn make_ids<const N: usize>() -> [FileId; N] {
        let mut result = [FileId::CURRENT_FILE; N];
        result.iter_mut().for_each(|id| *id = FileId::next());
        result
    }

    #[test]
    fn cycle_detection() {
        let [a, b, c, d] = make_ids();
        // Build an `include(file.fea);` statement node, reused for every edge.
        let statement = {
            let mut builder = TreeBuilder::default();
            builder.start_node(Kind::IncludeNode);
            builder.token(Kind::IncludeKw, "include");
            builder.token(Kind::LParen, "(");
            builder.token(Kind::Path, "file.fea");
            // was Kind::LParen — the closing paren should be an RParen token
            builder.token(Kind::RParen, ")");
            builder.token(Kind::Semi, ";");
            builder.finish_node(false, None);
            builder.finish()
        };
        let statement = typed::Include::cast(&statement.into()).unwrap();
        // a -> b -> c -> d -> b: the edge out of `d` closes the cycle.
        let mut graph = IncludeGraph::default();
        graph.add_edge(a, (b, statement.range()));
        graph.add_edge(b, (c, statement.range()));
        graph.add_edge(c, (d, statement.range()));
        graph.add_edge(d, (b, statement.range()));

        let result = graph.validate(a);
        assert_eq!(result[0].file, d);
        assert_eq!(result[0].range, 0..18);
    }

    #[test]
    fn skip_cycle_in_build() {
        let parse = ParseContext::parse(
            "a".into(),
            None,
            Box::new(|path: &Path| match path.to_str().unwrap() {
                "a" => Ok("include(bb);".into()),
                "bb" => Ok("include(a);".into()),
                _ => Err(SourceLoadError::new(
                    path.to_owned(),
                    std::io::Error::new(std::io::ErrorKind::NotFound, "oh no"),
                )),
            }),
        )
        .unwrap();
        let (resolved, errs) = parse.generate_parse_tree();
        // The cycle is reported exactly once, and the cycle-closing include
        // is left unexpanded in the assembled tree.
        assert_eq!(errs.len(), 1);
        assert_eq!(resolved.root.text_len(), "include(bb);".len());
    }

    #[test]
    fn assembly_basic() {
        let file_a = "\
        include(b);\n\
        # hmm\n\
        include(c);";
        let file_b = "languagesystem dflt DFLT;\n";
        let file_c = "feature kern {\n    pos a b 20;\n    } kern;";

        let b_len = file_b.len();
        let c_len = file_c.len();

        let parse = ParseContext::parse(
            "file_a".into(),
            None,
            Box::new(|path: &Path| match path.to_str().unwrap() {
                "file_a" => Ok(file_a.into()),
                "b" => Ok(file_b.into()),
                "c" => Ok(file_c.into()),
                _ => Err(SourceLoadError::new(
                    path.into(),
                    std::io::Error::new(std::io::ErrorKind::NotFound, "oh no"),
                )),
            }),
        )
        .unwrap();

        let a_id = parse.sources.id_for_path("file_a").unwrap();
        let b_id = parse.sources.id_for_path("b").unwrap();
        let c_id = parse.sources.id_for_path("c").unwrap();

        let (resolved, errs) = parse.generate_parse_tree();
        assert!(errs.is_empty(), "{errs:?}");
        let top_level_nodes = resolved
            .root
            .iter_children()
            .filter_map(|n| n.as_node())
            .collect::<Vec<_>>();
        // The text between the two spliced-in files: the newline after file_b
        // plus the comment line from file_a.
        let inter_node_len = "\n# hmm\n".len();
        assert_eq!(top_level_nodes.len(), 2);
        assert_eq!(top_level_nodes[0].kind(), Kind::LanguageSystemNode);
        assert_eq!(top_level_nodes[0].range(), 0..b_len - 1);
        let node_2_start = b_len + inter_node_len;
        assert_eq!(
            top_level_nodes[1].range(),
            node_2_start..node_2_start + c_len,
        );
        assert_eq!(top_level_nodes[1].kind(), Kind::FeatureNode);

        // Global ranges resolve back to (file, local-range) pairs.
        assert_eq!(resolved.map.resolve_range(10..15), (b_id, 10..15));
        assert_eq!(resolved.map.resolve_range(29..33), (a_id, 14..18));
        assert_eq!(resolved.map.resolve_range(49..52), (c_id, 16..19));
    }
}