// fea_rs/parse/context.rs

1//! parsing and resolving includes
2
3use std::{
4    collections::{HashMap, HashSet},
5    ops::Range,
6    path::{Path, PathBuf},
7    sync::Arc,
8};
9
10use super::source::{Source, SourceLoadError, SourceLoader, SourceResolver};
11use super::{FileId, ParseTree, Parser, SourceList, SourceMap};
12use crate::{
13    Diagnostic, DiagnosticSet, GlyphMap, Kind, Node,
14    token_tree::{
15        AstSink,
16        typed::{self, AstNode as _},
17    },
18};
19
/// The maximum nesting depth of include statements.
///
/// Chains of includes deeper than this are rejected during validation
/// (see `IncludeGraph::validate`) rather than at parse time.
const MAX_INCLUDE_DEPTH: usize = 50;
21
/// Oversees parsing, following, resolving and validating input statements.
///
/// Includes are annoying. Existing tools tend to handle them as they're
/// encountered, pushing another parser onto the stack and building the tree
/// in-place. This doesn't work for us, because we want to be able to preserve
/// the original source locations for tokens so that we can provide good error
/// messages.
///
/// We handle this in a reasonably straightforward way: instead of parsing
/// includes immediately, we return a list of the includes found in each
/// source file. We use these to build a graph of include statements, and then
/// we also add these files to a queue, skipping files that have been parsed
/// already. Importantly, we don't worry about recursion or include depth
/// at parse time; we just parse every file we find, and if there's a cycle
/// we avoid it by keeping track of what we've already parsed.
///
/// Once parsing is finished, we use our `IncludeGraph` to validate that there
/// are no cycles, and that the depth limit is not exceeded.
///
/// After parsing, you use [`generate_parse_tree`] to validate and assemble
/// the parsed sources into a single parse tree. This is also where validation
/// occurs; if there are any errors in include statements, those statements
/// are ignored when the tree is built, and an error is recorded: however we will
/// always attempt to construct *some* tree.
///
/// In the future, it should be possible to augment this type to allow for use by
/// a language server or similar, where an edit could be applied to a particular
/// source, and then only that source would need to be recompiled.
///
/// But we don't take advantage of that, yet, so this is currently more like
/// an intermediate type for generating a `ParseTree`.
///
/// [`generate_parse_tree`]: ParseContext::generate_parse_tree
#[derive(Debug)]
pub(crate) struct ParseContext {
    // the id of the file passed to `ParseContext::parse`
    root_id: FileId,
    // all loaded sources, shared with the eventual `ParseTree`
    sources: Arc<SourceList>,
    // each parsed file's tree, plus diagnostics produced while parsing it
    parsed_files: HashMap<FileId, (Node, Vec<Diagnostic>)>,
    // the include graph, used to detect cycles and excessive depth
    graph: IncludeGraph,
}
62
/// A simple graph of files and their includes.
///
/// We maintain this in order to validate that the input does not contain
/// any cyclical include statements, and does not exceed the maximum include
/// depth ([`MAX_INCLUDE_DEPTH`]).
#[derive(Clone, Debug, Default)]
struct IncludeGraph {
    // source file -> (destination file, span-in-source-for-error)
    nodes: HashMap<FileId, Vec<(FileId, Range<usize>)>>,
}
73
/// An include statement in a source file.
pub struct IncludeStatement {
    /// the typed AST node for the `include(..)` statement itself
    pub(crate) stmt: typed::Include,
    /// the type of the parent node, dictates how this should be parsed.
    pub(crate) scope: Kind,
}
80
/// A problem (a cycle, or excessive depth) found while validating includes.
struct IncludeError {
    /// the file containing the offending include statement
    file: FileId,
    /// the index of the problem statement, in the list of that file's includes
    statement_idx: usize,
    /// the source range of the statement, used when reporting the error
    range: Range<usize>,
    kind: IncludeErrorKind,
}

/// The reason an include statement is invalid.
enum IncludeErrorKind {
    /// the include leads back to one of its ancestors
    Cycle,
    /// the include exceeds [`MAX_INCLUDE_DEPTH`]
    ToDeep,
}
93
94impl IncludeStatement {
95    /// The path part of the statement.
96    ///
97    /// For the statement `include(file.fea)`, this is `file.fea`.
98    fn path(&self) -> &str {
99        &self.stmt.path().text
100    }
101
102    /// The range of the entire include statement.
103    fn stmt_range(&self) -> Range<usize> {
104        self.stmt.range()
105    }
106
107    /// The range of just the path text.
108    fn path_range(&self) -> Range<usize> {
109        self.stmt.path().range()
110    }
111}
112
113impl ParseContext {
114    /// Attempt to parse the feature file at `path` and any includes.
115    ///
116    /// This will only error in unrecoverable cases, such as if `path` cannot
117    /// be read.
118    ///
119    /// After parsing, you can call [`generate_parse_tree`] in order to generate
120    /// a unified parse tree suitable for compilation.
121    ///
122    /// [`generate_parse_tree`]: ParseContext::generate_parse_tree
123    pub(crate) fn parse(
124        path: PathBuf,
125        glyph_map: Option<&GlyphMap>,
126        resolver: Box<dyn SourceResolver>,
127    ) -> Result<Self, SourceLoadError> {
128        let mut sources = SourceLoader::new(resolver);
129        let root_id = sources.source_for_path(&path, None)?;
130        let mut queue = vec![(root_id, Kind::SourceFile)];
131        let mut parsed_files = HashMap::new();
132        let mut includes = IncludeGraph::default();
133
134        while let Some((id, scope)) = queue.pop() {
135            // skip things we've already parsed.
136            if parsed_files.contains_key(&id) {
137                continue;
138            }
139            let source = sources.get(&id).unwrap();
140            let (node, mut errors, include_stmts) = parse_src(source, glyph_map, scope);
141            errors.iter_mut().for_each(|e| e.message.file = id);
142
143            parsed_files.insert(source.id(), (node, errors));
144            if include_stmts.is_empty() {
145                continue;
146            }
147
148            // we need to drop `source` so we can mutate source_map below
149            let source_id = source.id();
150
151            for include in &include_stmts {
152                match sources.source_for_path(Path::new(include.path()), Some(source_id)) {
153                    Ok(included_id) => {
154                        includes.add_edge(id, (included_id, include.stmt_range()));
155                        queue.push((included_id, include.scope));
156                    }
157                    Err(e) => {
158                        let range = include.path_range();
159                        parsed_files.get_mut(&id).unwrap().1.push(Diagnostic::error(
160                            id,
161                            range,
162                            e.to_string(),
163                        ));
164                    }
165                }
166            }
167        }
168
169        Ok(ParseContext {
170            root_id,
171            sources: sources.into_inner(),
172            parsed_files,
173            graph: includes,
174        })
175    }
176
177    pub(crate) fn root_id(&self) -> FileId {
178        self.root_id
179    }
180
181    /// Construct a `ParseTree`, and return any diagnostics.
182    ///
183    /// This method also performs validation of include statements.
184    pub(crate) fn generate_parse_tree(self) -> (ParseTree, DiagnosticSet) {
185        let mut all_errors = self
186            .parsed_files
187            .iter()
188            .flat_map(|(_, (_, errs))| errs.iter())
189            .cloned()
190            .collect::<Vec<_>>();
191        let include_errors = self.graph.validate(self.root_id());
192        // record any errors:
193        for IncludeError {
194            file, range, kind, ..
195        } in &include_errors
196        {
197            // find statement
198            let message = match kind {
199                IncludeErrorKind::Cycle => "cyclical include statement",
200                IncludeErrorKind::ToDeep => "exceded maximum include depth",
201            };
202            all_errors.push(Diagnostic::error(*file, range.clone(), message));
203        }
204
205        let mut map = SourceMap::default();
206        let mut root = self.generate_recurse(self.root_id(), &include_errors, &mut map, 0);
207        let needs_update_positions = self.parsed_files.len() > 1;
208        // we need to do this before updating positions, since it mutates and
209        // requires that there exist only one reference (via Arc) to the node
210        drop(self.parsed_files);
211        if needs_update_positions {
212            root.update_positions_from_root();
213        }
214
215        let diagnostics = DiagnosticSet {
216            messages: all_errors,
217            sources: self.sources.clone(),
218            max_to_print: usize::MAX,
219        };
220
221        (
222            ParseTree {
223                root,
224                map: Arc::new(map),
225                sources: self.sources,
226            },
227            diagnostics,
228        )
229    }
230
231    /// recursively construct the output tree.
232    ///
233    /// The final result will be a Vec of len 1, but intermediate results
234    /// can be longer. This is because an include statement can cause us to parse
235    /// from within another node, instead of always parsing a root node.
236    fn generate_recurse(
237        &self,
238        id: FileId,
239        skip: &[IncludeError],
240        source_map: &mut SourceMap,
241        offset: usize,
242    ) -> Node {
243        let this_node = self.parsed_files[&id].0.clone();
244        let self_len = this_node.text_len();
245        let mut self_pos = 0;
246        let mut global_pos = offset;
247        let this_node = match self.graph.includes_for_file(id) {
248            Some(includes) => {
249                let mut edits = Vec::with_capacity(includes.len());
250
251                for (i, (child_id, stmt)) in includes.iter().enumerate() {
252                    if skip
253                        .iter()
254                        .any(|err| err.file == id && err.statement_idx == i)
255                    {
256                        continue;
257                    }
258                    // add everything up to this attach to the sourcemap
259                    let pre_len = stmt.start - self_pos;
260                    let pre_range = global_pos..global_pos + pre_len;
261                    source_map.add_entry(pre_range, (id, self_pos));
262                    self_pos = stmt.end;
263                    global_pos += pre_len;
264                    let child_node = self.generate_recurse(*child_id, skip, source_map, global_pos);
265                    global_pos += child_node.text_len();
266                    edits.push((stmt.clone(), child_node));
267                }
268                this_node.edit(edits, true)
269            }
270            None => this_node,
271        };
272        // now add any remaining contents to source_map
273        let remain_len = self_len - self_pos;
274        let remaining_range = global_pos..global_pos + remain_len;
275        source_map.add_entry(remaining_range, (id, self_pos));
276        this_node
277    }
278}
279
280impl IncludeGraph {
281    fn add_edge(&mut self, from: FileId, to: (FileId, Range<usize>)) {
282        self.nodes.entry(from).or_default().push(to);
283    }
284
285    fn includes_for_file(&self, file: FileId) -> Option<&[(FileId, Range<usize>)]> {
286        self.nodes.get(&file).map(|f| f.as_slice())
287    }
288
289    /// Validate the graph of include statements, returning any problems.
290    ///
291    /// If the result is non-empty, each returned error should be converted to
292    /// d to diagnostics by the caller, and those statements should
293    /// not be resolved when building the final tree.
294    fn validate(&self, root: FileId) -> Vec<IncludeError> {
295        let edges = match self.nodes.get(&root) {
296            None => return Vec::new(),
297            Some(edges) => edges,
298        };
299
300        let mut stack = vec![(root, edges, 0_usize)];
301        let mut seen = HashSet::new();
302        let mut bad_edges = Vec::new();
303
304        while let Some((node, edges, cur_edge)) = stack.pop() {
305            if let Some((child, stmt)) = edges.get(cur_edge) {
306                // push parent, advancing idx
307                stack.push((node, edges, cur_edge + 1));
308                if stack.len() >= MAX_INCLUDE_DEPTH - 1 {
309                    bad_edges.push(IncludeError {
310                        file: node,
311                        statement_idx: cur_edge,
312                        range: stmt.clone(),
313                        kind: IncludeErrorKind::ToDeep,
314                    });
315                    continue;
316                }
317
318                // only recurse if we haven't seen this node yet
319                if seen.insert(*child) {
320                    if let Some(child_edges) = self.nodes.get(child) {
321                        stack.push((*child, child_edges, 0));
322                    }
323                } else if stack.iter().any(|(ancestor, _, _)| ancestor == child) {
324                    // we have a cycle
325                    bad_edges.push(IncludeError {
326                        file: node,
327                        statement_idx: cur_edge,
328                        range: stmt.clone(),
329                        kind: IncludeErrorKind::Cycle,
330                    });
331                }
332            }
333        }
334        bad_edges
335    }
336}
337
338/// Parse a single source file.
339fn parse_src(
340    src: &Source,
341    glyph_map: Option<&GlyphMap>,
342    scope: Kind,
343) -> (Node, Vec<Diagnostic>, Vec<IncludeStatement>) {
344    let mut sink = AstSink::new(src.text(), src.id(), glyph_map);
345    {
346        let mut parser = Parser::new(src.text(), &mut sink);
347        match scope {
348            Kind::FeatureNode => {
349                parser.start_node(Kind::SourceFile);
350                super::grammar::eat_feature_block_items(&mut parser);
351                parser.eat_trivia();
352                parser.finish_node();
353            }
354            Kind::SourceFile => super::grammar::root(&mut parser),
355            other => {
356                log::warn!("encountered include statement in unhandled scope '{other}'");
357                // just parse as root, like we would have originally
358                super::grammar::root(&mut parser);
359            }
360        }
361    }
362    sink.finish()
363}
364
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Kind,
        token_tree::{TreeBuilder, typed},
    };

    /// Generate `N` distinct file ids for use in tests.
    fn make_ids<const N: usize>() -> [FileId; N] {
        let mut result = [FileId::CURRENT_FILE; N];
        result.iter_mut().for_each(|id| *id = FileId::next());
        result
    }

    /// Ensure we error if there are cyclical includes
    #[test]
    fn cycle_detection() {
        let [a, b, c, d] = make_ids();
        // hand-build an `include(file.fea);` statement node
        let statement = {
            let mut builder = TreeBuilder::default();
            builder.start_node(Kind::IncludeNode);
            builder.token(Kind::IncludeKw, "include");
            builder.token(Kind::LParen, "(");
            builder.token(Kind::Path, "file.fea");
            // was `Kind::LParen`: the closing paren should be an RParen token
            builder.token(Kind::RParen, ")");
            builder.token(Kind::Semi, ";");
            builder.finish_node(false, None);
            builder.finish()
        };
        let statement = typed::Include::cast(&statement.into()).unwrap();
        let mut graph = IncludeGraph::default();
        // a -> b -> c -> d -> b forms a cycle at d's include of b
        graph.add_edge(a, (b, statement.range()));
        graph.add_edge(b, (c, statement.range()));
        graph.add_edge(c, (d, statement.range()));
        graph.add_edge(d, (b, statement.range()));

        let result = graph.validate(a);
        assert_eq!(result[0].file, d);
        assert_eq!(result[0].range, 0..18);
    }

    #[test]
    fn skip_cycle_in_build() {
        // 'a' and 'bb' include each other; the tree should still be built,
        // with one error reported
        let parse = ParseContext::parse(
            "a".into(),
            None,
            Box::new(|path: &Path| match path.to_str().unwrap() {
                "a" => Ok("include(bb);".into()),
                "bb" => Ok("include(a);".into()),
                _ => Err(SourceLoadError::new(
                    path.to_owned(),
                    std::io::Error::new(std::io::ErrorKind::NotFound, "oh no"),
                )),
            }),
        )
        .unwrap();
        let (resolved, errs) = parse.generate_parse_tree();
        assert_eq!(errs.len(), 1);
        assert_eq!(resolved.root.text_len(), "include(bb);".len());
    }

    #[test]
    fn assembly_basic() {
        let file_a = "\
        include(b);\n\
        # hmm\n\
        include(c);";
        let file_b = "languagesystem dflt DFLT;\n";
        let file_c = "feature kern {\n pos a b 20;\n } kern;";

        let b_len = file_b.len();
        let c_len = file_c.len();

        let parse = ParseContext::parse(
            "file_a".into(),
            None,
            Box::new(|path: &Path| match path.to_str().unwrap() {
                "file_a" => Ok(file_a.into()),
                "b" => Ok(file_b.into()),
                "c" => Ok(file_c.into()),
                _ => Err(SourceLoadError::new(
                    path.into(),
                    std::io::Error::new(std::io::ErrorKind::NotFound, "oh no"),
                )),
            }),
        )
        .unwrap();

        let a_id = parse.sources.id_for_path("file_a").unwrap();
        let b_id = parse.sources.id_for_path("b").unwrap();
        let c_id = parse.sources.id_for_path("c").unwrap();

        let (resolved, errs) = parse.generate_parse_tree();
        assert!(errs.is_empty(), "{errs:?}");
        let top_level_nodes = resolved
            .root
            .iter_children()
            .filter_map(|n| n.as_node())
            .collect::<Vec<_>>();
        let inter_node_len = "\n# hmm\n".len();
        assert_eq!(top_level_nodes.len(), 2);
        assert_eq!(top_level_nodes[0].kind(), Kind::LanguageSystemNode);
        assert_eq!(top_level_nodes[0].range(), 0..b_len - 1); // ignore newline
        let node_2_start = b_len + inter_node_len;
        assert_eq!(
            top_level_nodes[1].range(),
            node_2_start..node_2_start + c_len,
        );
        assert_eq!(top_level_nodes[1].kind(), Kind::FeatureNode);

        //resolved.root.debug_print_structure(true);
        // positions in the combined tree resolve back to the original files
        assert_eq!(resolved.map.resolve_range(10..15), (b_id, 10..15));
        assert_eq!(resolved.map.resolve_range(29..33), (a_id, 14..18));
        assert_eq!(resolved.map.resolve_range(49..52), (c_id, 16..19));
    }
}