use std::{
collections::{HashMap, HashSet},
ops::Range,
path::{Path, PathBuf},
sync::Arc,
};
use super::source::{Source, SourceLoadError, SourceLoader, SourceResolver};
use super::{FileId, ParseTree, Parser, SourceList, SourceMap};
use crate::{
Diagnostic, DiagnosticSet, GlyphMap, Kind, Node,
token_tree::{
AstSink,
typed::{self, AstNode as _},
},
};
/// Bound on how deeply include statements may nest.
///
/// `IncludeGraph::validate` flags an include as too deep once its traversal
/// stack holds at least `MAX_INCLUDE_DEPTH - 1` entries.
const MAX_INCLUDE_DEPTH: usize = 50;
/// The outcome of parsing a root file and every file reachable from it through
/// include statements, before the per-file trees are assembled into one.
#[derive(Debug)]
pub(crate) struct ParseContext {
/// Id of the file that parsing started from.
root_id: FileId,
/// The sources that were loaded while parsing.
sources: Arc<SourceList>,
/// For each parsed file, its tree and the diagnostics recorded against it.
parsed_files: HashMap<FileId, (Node, Vec<Diagnostic>)>,
/// Which files include which other files, and where.
graph: IncludeGraph,
}
/// A directed graph of include relationships between files.
#[derive(Clone, Debug, Default)]
struct IncludeGraph {
/// Maps an including file to its outgoing edges, in insertion order. Each
/// edge is the included file plus the range of the include statement within
/// the including file.
nodes: HashMap<FileId, Vec<(FileId, Range<usize>)>>,
}
/// An include statement encountered while parsing a file.
pub struct IncludeStatement {
/// The typed syntax node for the statement.
pub(crate) stmt: typed::Include,
/// The `Kind` used as the parse scope when the included file is parsed.
pub(crate) scope: Kind,
}
/// A problem with one include statement, found while validating the include graph.
struct IncludeError {
/// The file containing the offending include statement.
file: FileId,
/// Index of the offending statement in that file's list of graph edges.
statement_idx: usize,
/// Range of the include statement within `file`.
range: Range<usize>,
/// What is wrong with the include.
kind: IncludeErrorKind,
}
/// The ways in which an include statement can be rejected by graph validation.
enum IncludeErrorKind {
/// The included file is already on the stack of files being traversed.
Cycle,
/// The traversal stack reached the limit derived from `MAX_INCLUDE_DEPTH`.
ToDeep,
}
impl IncludeStatement {
    /// The text of the statement's path token.
    fn path(&self) -> &str {
        let path_token = self.stmt.path();
        &path_token.text
    }

    /// The range covered by the whole include statement.
    fn stmt_range(&self) -> Range<usize> {
        let Self { stmt, .. } = self;
        stmt.range()
    }

    /// The range covered by the path token alone.
    fn path_range(&self) -> Range<usize> {
        let path_token = self.stmt.path();
        path_token.range()
    }
}
impl ParseContext {
pub(crate) fn parse(
path: PathBuf,
glyph_map: Option<&GlyphMap>,
resolver: Box<dyn SourceResolver>,
) -> Result<Self, SourceLoadError> {
let mut sources = SourceLoader::new(resolver);
let root_id = sources.source_for_path(&path, None)?;
let mut queue = vec![(root_id, Kind::SourceFile)];
let mut parsed_files = HashMap::new();
let mut includes = IncludeGraph::default();
while let Some((id, scope)) = queue.pop() {
if parsed_files.contains_key(&id) {
continue;
}
let source = sources.get(&id).unwrap();
let (node, mut errors, include_stmts) = parse_src(source, glyph_map, scope);
errors.iter_mut().for_each(|e| e.message.file = id);
parsed_files.insert(source.id(), (node, errors));
if include_stmts.is_empty() {
continue;
}
let source_id = source.id();
for include in &include_stmts {
match sources.source_for_path(Path::new(include.path()), Some(source_id)) {
Ok(included_id) => {
includes.add_edge(id, (included_id, include.stmt_range()));
queue.push((included_id, include.scope));
}
Err(e) => {
let range = include.path_range();
parsed_files.get_mut(&id).unwrap().1.push(Diagnostic::error(
id,
range,
e.to_string(),
));
}
}
}
}
Ok(ParseContext {
root_id,
sources: sources.into_inner(),
parsed_files,
graph: includes,
})
}
pub(crate) fn root_id(&self) -> FileId {
self.root_id
}
pub(crate) fn generate_parse_tree(self) -> (ParseTree, DiagnosticSet) {
let mut all_errors = self
.parsed_files
.iter()
.flat_map(|(_, (_, errs))| errs.iter())
.cloned()
.collect::<Vec<_>>();
let include_errors = self.graph.validate(self.root_id());
for IncludeError {
file, range, kind, ..
} in &include_errors
{
let message = match kind {
IncludeErrorKind::Cycle => "cyclical include statement",
IncludeErrorKind::ToDeep => "exceded maximum include depth",
};
all_errors.push(Diagnostic::error(*file, range.clone(), message));
}
let mut map = SourceMap::default();
let mut root = self.generate_recurse(self.root_id(), &include_errors, &mut map, 0);
let needs_update_positions = self.parsed_files.len() > 1;
drop(self.parsed_files);
if needs_update_positions {
root.update_positions_from_root();
}
let diagnostics = DiagnosticSet {
messages: all_errors,
sources: self.sources.clone(),
max_to_print: usize::MAX,
};
(
ParseTree {
root,
map: Arc::new(map),
sources: self.sources,
},
diagnostics,
)
}
fn generate_recurse(
&self,
id: FileId,
skip: &[IncludeError],
source_map: &mut SourceMap,
offset: usize,
) -> Node {
let this_node = self.parsed_files[&id].0.clone();
let self_len = this_node.text_len();
let mut self_pos = 0;
let mut global_pos = offset;
let this_node = match self.graph.includes_for_file(id) {
Some(includes) => {
let mut edits = Vec::with_capacity(includes.len());
for (i, (child_id, stmt)) in includes.iter().enumerate() {
if skip
.iter()
.any(|err| err.file == id && err.statement_idx == i)
{
continue;
}
let pre_len = stmt.start - self_pos;
let pre_range = global_pos..global_pos + pre_len;
source_map.add_entry(pre_range, (id, self_pos));
self_pos = stmt.end;
global_pos += pre_len;
let child_node = self.generate_recurse(*child_id, skip, source_map, global_pos);
global_pos += child_node.text_len();
edits.push((stmt.clone(), child_node));
}
this_node.edit(edits, true)
}
None => this_node,
};
let remain_len = self_len - self_pos;
let remaining_range = global_pos..global_pos + remain_len;
source_map.add_entry(remaining_range, (id, self_pos));
this_node
}
}
impl IncludeGraph {
    /// Records that `from` includes the file `to.0` through the statement
    /// spanning `to.1`.
    fn add_edge(&mut self, from: FileId, to: (FileId, Range<usize>)) {
        let outgoing = self.nodes.entry(from).or_default();
        outgoing.push(to);
    }

    /// The include edges recorded for `file`, or `None` if it has none.
    fn includes_for_file(&self, file: FileId) -> Option<&[(FileId, Range<usize>)]> {
        let outgoing = self.nodes.get(&file)?;
        Some(outgoing.as_slice())
    }

    /// Walks the graph depth-first from `root` and returns every include
    /// statement that either nests too deeply or closes a cycle.
    ///
    /// Each stack entry is a file, its edges, and the index of the next edge
    /// to look at.
    fn validate(&self, root: FileId) -> Vec<IncludeError> {
        let mut bad_edges = Vec::new();
        let root_edges = match self.nodes.get(&root) {
            Some(found) => found,
            None => return bad_edges,
        };

        let mut seen = HashSet::new();
        let mut stack = vec![(root, root_edges, 0_usize)];

        while let Some((node, edges, cur_edge)) = stack.pop() {
            let (child, stmt) = match edges.get(cur_edge) {
                Some(edge) => edge,
                // All edges of this file have been handled.
                None => continue,
            };

            // Put the parent back, pointing at its next edge.
            stack.push((node, edges, cur_edge + 1));

            let problem = if stack.len() >= MAX_INCLUDE_DEPTH - 1 {
                Some(IncludeErrorKind::ToDeep)
            } else if seen.insert(*child) {
                // First time this file is reached: descend if it has includes.
                if let Some(child_edges) = self.nodes.get(child) {
                    stack.push((*child, child_edges, 0));
                }
                None
            } else if stack.iter().any(|(ancestor, _, _)| ancestor == child) {
                // Already reached before and currently on the stack.
                Some(IncludeErrorKind::Cycle)
            } else {
                None
            };

            if let Some(kind) = problem {
                bad_edges.push(IncludeError {
                    file: node,
                    statement_idx: cur_edge,
                    range: stmt.clone(),
                    kind,
                });
            }
        }

        bad_edges
    }
}
/// Parses the text of `src` and returns its tree, the diagnostics produced,
/// and the include statements that were found.
///
/// `scope` selects the grammar entry point: `Kind::FeatureNode` parses the text
/// as feature block items wrapped in a `Kind::SourceFile` node, and everything
/// else is parsed from the grammar root. A scope other than those two is
/// additionally logged as a warning.
fn parse_src(
    src: &Source,
    glyph_map: Option<&GlyphMap>,
    scope: Kind,
) -> (Node, Vec<Diagnostic>, Vec<IncludeStatement>) {
    let mut sink = AstSink::new(src.text(), src.id(), glyph_map);
    {
        // The parser borrows the sink mutably, so it is confined to this scope.
        let mut parser = Parser::new(src.text(), &mut sink);
        if scope == Kind::FeatureNode {
            parser.start_node(Kind::SourceFile);
            super::grammar::eat_feature_block_items(&mut parser);
            parser.eat_trivia();
            parser.finish_node();
        } else {
            if scope != Kind::SourceFile {
                let other = scope;
                log::warn!("encountered include statement in unhandled scope '{other}'");
            }
            super::grammar::root(&mut parser);
        }
    }
    sink.finish()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Kind,
        token_tree::{TreeBuilder, typed},
    };

    /// Produces `N` fresh file ids.
    fn make_ids<const N: usize>() -> [FileId; N] {
        let mut result = [FileId::CURRENT_FILE; N];
        result.iter_mut().for_each(|id| *id = FileId::next());
        result
    }

    /// A cycle that does not pass through the root is reported on the edge
    /// that closes it.
    #[test]
    fn cycle_detection() {
        let [a, b, c, d] = make_ids();
        let statement = {
            let mut builder = TreeBuilder::default();
            builder.start_node(Kind::IncludeNode);
            builder.token(Kind::IncludeKw, "include");
            builder.token(Kind::LParen, "(");
            builder.token(Kind::Path, "file.fea");
            builder.token(Kind::RParen, ")");
            builder.token(Kind::Semi, ";");
            builder.finish_node(false, None);
            builder.finish()
        };
        let statement = typed::Include::cast(&statement.into()).unwrap();

        // a -> b -> c -> d -> b
        let mut graph = IncludeGraph::default();
        graph.add_edge(a, (b, statement.range()));
        graph.add_edge(b, (c, statement.range()));
        graph.add_edge(c, (d, statement.range()));
        graph.add_edge(d, (b, statement.range()));

        let result = graph.validate(a);
        assert_eq!(result[0].file, d);
        assert_eq!(result[0].range, 0..18);
    }

    /// A cyclical include produces one error and is left unexpanded in the
    /// assembled tree.
    #[test]
    fn skip_cycle_in_build() {
        let parse = ParseContext::parse(
            "a".into(),
            None,
            Box::new(|path: &Path| match path.to_str().unwrap() {
                "a" => Ok("include(bb);".into()),
                "bb" => Ok("include(a);".into()),
                _ => Err(SourceLoadError::new(
                    path.to_owned(),
                    std::io::Error::new(std::io::ErrorKind::NotFound, "oh no"),
                )),
            }),
        )
        .unwrap();

        let (resolved, errs) = parse.generate_parse_tree();
        assert_eq!(errs.len(), 1);
        assert_eq!(resolved.root.text_len(), "include(bb);".len());
    }

    /// Included files are spliced into the root tree, and positions in the
    /// assembled tree map back to the right file and offset.
    #[test]
    fn assembly_basic() {
        let file_a = "\
        include(b);\n\
        # hmm\n\
        include(c);";
        let file_b = "languagesystem dflt DFLT;\n";
        let file_c = "feature kern {\n pos a b 20;\n } kern;";
        let b_len = file_b.len();
        let c_len = file_c.len();

        let parse = ParseContext::parse(
            "file_a".into(),
            None,
            Box::new(|path: &Path| match path.to_str().unwrap() {
                "file_a" => Ok(file_a.into()),
                "b" => Ok(file_b.into()),
                "c" => Ok(file_c.into()),
                _ => Err(SourceLoadError::new(
                    path.into(),
                    std::io::Error::new(std::io::ErrorKind::NotFound, "oh no"),
                )),
            }),
        )
        .unwrap();

        let a_id = parse.sources.id_for_path("file_a").unwrap();
        let b_id = parse.sources.id_for_path("b").unwrap();
        let c_id = parse.sources.id_for_path("c").unwrap();

        let (resolved, errs) = parse.generate_parse_tree();
        assert!(errs.is_empty(), "{errs:?}");

        let top_level_nodes = resolved
            .root
            .iter_children()
            .filter_map(|n| n.as_node())
            .collect::<Vec<_>>();
        // The text of file_a that sits between its two include statements.
        let inter_node_len = "\n# hmm\n".len();

        assert_eq!(top_level_nodes.len(), 2);
        assert_eq!(top_level_nodes[0].kind(), Kind::LanguageSystemNode);
        // The first node stops one byte short of the end of file_b's text.
        assert_eq!(top_level_nodes[0].range(), 0..b_len - 1);

        let node_2_start = b_len + inter_node_len;
        assert_eq!(
            top_level_nodes[1].range(),
            node_2_start..node_2_start + c_len,
        );
        assert_eq!(top_level_nodes[1].kind(), Kind::FeatureNode);

        assert_eq!(resolved.map.resolve_range(10..15), (b_id, 10..15));
        assert_eq!(resolved.map.resolve_range(29..33), (a_id, 14..18));
        assert_eq!(resolved.map.resolve_range(49..52), (c_id, 16..19));
    }
}