use panproto_schema::{Protocol, Schema, SchemaBuilder};
use rustc_hash::FxHashSet;
use crate::error::ParseError;
use crate::id_scheme::IdGenerator;
use crate::theory_extract::ExtractedTheoryMeta;
/// Built-in tree-sitter node kinds that the walker treats as introducing a
/// scope.
///
/// When the walker visits a node of one of these kinds it tries to read the
/// node's name (via the configured name fields) to derive a named vertex id
/// and pushes a matching scope on the id generator while the node's children
/// are walked. Callers can extend this list through
/// `WalkerConfig::extra_scope_kinds`.
const SCOPE_INTRODUCING_KINDS: &[&str] = &[
    // Callable declarations.
    "function_declaration",
    "function_definition",
    "method_declaration",
    "method_definition",
    // Type-like declarations.
    "class_declaration",
    "class_definition",
    "interface_declaration",
    "struct_item",
    "enum_item",
    "enum_declaration",
    "impl_item",
    "trait_item",
    // Module-level containers.
    "module",
    "namespace_definition",
    "package_declaration",
];
/// Built-in tree-sitter node kinds that the walker treats as plain blocks.
///
/// A node of one of these kinds (that is not already handled as a
/// scope-introducing kind) makes the walker push an anonymous scope on the id
/// generator while the node's children are walked. Callers can extend this
/// list through `WalkerConfig::extra_block_kinds`.
const BLOCK_KINDS: &[&str] = &[
    // Statement-level blocks.
    "block",
    "statement_block",
    "compound_statement",
    // Declaration bodies.
    "declaration_list",
    "field_declaration_list",
    "enum_body",
    "class_body",
    "interface_body",
    "module_body",
];
/// Per-language tuning knobs for [`AstWalker`].
#[derive(Debug, Clone)]
pub struct WalkerConfig {
    /// Extra node kinds to treat as scope-introducing, in addition to the
    /// built-in list merged in by `AstWalker::new`.
    pub extra_scope_kinds: Vec<String>,

    /// Extra node kinds to treat as anonymous blocks, in addition to the
    /// built-in list merged in by `AstWalker::new`.
    pub extra_block_kinds: Vec<String>,

    /// Field names probed, in order, to find the name of a scope-introducing
    /// node; the first field whose text is valid UTF-8 wins.
    pub name_fields: Vec<String>,

    /// Whether comments should be captured.
    ///
    /// NOTE(review): this flag is not read by the walker code in this file;
    /// confirm where (or whether) it is consumed.
    pub capture_comments: bool,

    /// Whether the walker emits formatting constraints (`indent`,
    /// `blank-lines-before`) for each vertex.
    pub capture_formatting: bool,
}
impl Default for WalkerConfig {
    /// Returns the baseline configuration: no extra scope or block kinds,
    /// scope names looked up in the `name` and then `identifier` fields, and
    /// both capture flags switched on.
    fn default() -> Self {
        let name_fields: Vec<String> = ["name", "identifier"]
            .iter()
            .map(|field| (*field).to_owned())
            .collect();

        Self {
            extra_scope_kinds: Vec::new(),
            extra_block_kinds: Vec::new(),
            name_fields,
            capture_comments: true,
            capture_formatting: true,
        }
    }
}
/// Walks a tree-sitter syntax tree and records it as a panproto [`Schema`].
///
/// The walker borrows the source bytes, the extracted theory metadata and the
/// target protocol for its whole lifetime `'a`; it owns its configuration and
/// the merged kind sets derived from it.
pub struct AstWalker<'a> {
    /// Raw bytes of the file the syntax tree was parsed from.
    source: &'a [u8],

    /// Theory metadata; its `vertex_kinds` are consulted when deciding
    /// whether a node kind may be used verbatim as a vertex kind.
    theory_meta: &'a ExtractedTheoryMeta,

    /// Protocol the schema is built against; its `obj_kinds` are consulted
    /// when deciding whether a node kind may be used verbatim.
    protocol: &'a Protocol,

    /// Caller-supplied configuration.
    config: WalkerConfig,

    /// Built-in scope-introducing kinds plus `config.extra_scope_kinds`.
    scope_kinds: FxHashSet<String>,

    /// Built-in block kinds plus `config.extra_block_kinds`.
    block_kinds: FxHashSet<String>,
}
impl<'a> AstWalker<'a> {
#[must_use]
pub fn new(
source: &'a [u8],
theory_meta: &'a ExtractedTheoryMeta,
protocol: &'a Protocol,
config: WalkerConfig,
) -> Self {
let mut scope_kinds: FxHashSet<String> = SCOPE_INTRODUCING_KINDS
.iter()
.map(|s| (*s).to_owned())
.collect();
for kind in &config.extra_scope_kinds {
scope_kinds.insert(kind.clone());
}
let mut block_kinds: FxHashSet<String> =
BLOCK_KINDS.iter().map(|s| (*s).to_owned()).collect();
for kind in &config.extra_block_kinds {
block_kinds.insert(kind.clone());
}
Self {
source,
theory_meta,
protocol,
config,
scope_kinds,
block_kinds,
}
}
pub fn walk(&self, tree: &tree_sitter::Tree, file_path: &str) -> Result<Schema, ParseError> {
let mut id_gen = IdGenerator::new(file_path);
let builder = SchemaBuilder::new(self.protocol);
let root = tree.root_node();
let builder = self.walk_node(root, builder, &mut id_gen, None)?;
builder.build().map_err(|e| ParseError::SchemaConstruction {
reason: e.to_string(),
})
}
fn walk_node(
&self,
node: tree_sitter::Node<'_>,
mut builder: SchemaBuilder,
id_gen: &mut IdGenerator,
parent_vertex_id: Option<&str>,
) -> Result<SchemaBuilder, ParseError> {
if !node.is_named() {
return Ok(builder);
}
let kind = node.kind();
let is_root_wrapper = parent_vertex_id.is_none()
&& (kind == "program"
|| kind == "source_file"
|| kind == "module"
|| kind == "translation_unit");
let vertex_id = if is_root_wrapper {
id_gen.current_prefix()
} else if self.scope_kinds.contains(kind) {
let name = self.extract_scope_name(&node);
match name {
Some(n) => id_gen.named_id(&n),
None => id_gen.anonymous_id(),
}
} else {
id_gen.anonymous_id()
};
let effective_kind = if self.protocol.obj_kinds.is_empty() {
kind
} else if self.protocol.obj_kinds.iter().any(|k| k == kind) {
kind
} else if !self.theory_meta.vertex_kinds.is_empty()
&& self.theory_meta.vertex_kinds.iter().any(|k| k == kind)
{
kind
} else {
"node"
};
builder = builder
.vertex(&vertex_id, effective_kind, None)
.map_err(|e| ParseError::SchemaConstruction {
reason: format!("vertex '{vertex_id}' ({kind}): {e}"),
})?;
if let Some(parent_id) = parent_vertex_id {
let edge_kind = node
.parent()
.and_then(|p| {
for i in 0..p.child_count() {
if let Some(child) = p.child(i) {
if child.id() == node.id() {
return u32::try_from(i)
.ok()
.and_then(|idx| p.field_name_for_child(idx));
}
}
}
None
})
.unwrap_or("child_of");
builder = builder
.edge(parent_id, &vertex_id, edge_kind, None)
.map_err(|e| ParseError::SchemaConstruction {
reason: format!("edge {parent_id} -> {vertex_id} ({edge_kind}): {e}"),
})?;
}
builder = builder.constraint(&vertex_id, "start-byte", &node.start_byte().to_string());
builder = builder.constraint(&vertex_id, "end-byte", &node.end_byte().to_string());
if node.named_child_count() == 0 {
if let Ok(text) = node.utf8_text(self.source) {
builder = builder.constraint(&vertex_id, "literal-value", text);
}
}
if self.config.capture_formatting {
builder = self.emit_formatting_constraints(node, &vertex_id, builder);
}
let entered_scope = if self.scope_kinds.contains(kind) && !is_root_wrapper {
match self.extract_scope_name(&node) {
Some(n) => id_gen.push_named_scope(&n),
None => {
id_gen.push_anonymous_scope();
}
}
true
} else if self.block_kinds.contains(kind) {
id_gen.push_anonymous_scope();
true
} else {
false
};
builder = self.walk_children_with_interstitials(node, builder, id_gen, &vertex_id)?;
if entered_scope {
id_gen.pop_scope();
}
Ok(builder)
}
fn walk_children_with_interstitials(
&self,
node: tree_sitter::Node<'_>,
mut builder: SchemaBuilder,
id_gen: &mut IdGenerator,
vertex_id: &str,
) -> Result<SchemaBuilder, ParseError> {
let cursor = &mut node.walk();
let children: Vec<_> = node.named_children(cursor).collect();
let mut interstitial_idx = 0;
let mut prev_end = node.start_byte();
for child in &children {
let gap_start = prev_end;
let gap_end = child.start_byte();
builder = self.capture_interstitial(
builder,
vertex_id,
gap_start,
gap_end,
&mut interstitial_idx,
);
builder = self.walk_node(*child, builder, id_gen, Some(vertex_id))?;
prev_end = child.end_byte();
}
builder = self.capture_interstitial(
builder,
vertex_id,
prev_end,
node.end_byte(),
&mut interstitial_idx,
);
Ok(builder)
}
fn capture_interstitial(
&self,
mut builder: SchemaBuilder,
vertex_id: &str,
gap_start: usize,
gap_end: usize,
idx: &mut usize,
) -> SchemaBuilder {
if gap_end > gap_start && gap_end <= self.source.len() {
if let Ok(gap_text) = std::str::from_utf8(&self.source[gap_start..gap_end]) {
if !gap_text.is_empty() {
let sort = format!("interstitial-{}", *idx);
builder = builder.constraint(vertex_id, &sort, gap_text);
builder = builder.constraint(
vertex_id,
&format!("{sort}-start-byte"),
&gap_start.to_string(),
);
*idx += 1;
}
}
}
builder
}
fn extract_scope_name(&self, node: &tree_sitter::Node<'_>) -> Option<String> {
for field_name in &self.config.name_fields {
if let Some(name_node) = node.child_by_field_name(field_name.as_bytes()) {
if let Ok(text) = name_node.utf8_text(self.source) {
return Some(text.to_owned());
}
}
}
None
}
fn emit_formatting_constraints(
&self,
node: tree_sitter::Node<'_>,
vertex_id: &str,
mut builder: SchemaBuilder,
) -> SchemaBuilder {
let start = node.start_position();
if start.column > 0 {
let line_start = node.start_byte().saturating_sub(start.column);
if line_start < self.source.len() {
let indent_end = line_start + start.column.min(self.source.len() - line_start);
if let Ok(indent) = std::str::from_utf8(&self.source[line_start..indent_end]) {
if !indent.is_empty() && indent.trim().is_empty() {
builder = builder.constraint(vertex_id, "indent", indent);
}
}
}
}
if let Some(prev) = node.prev_named_sibling() {
let gap_start = prev.end_byte();
let gap_end = node.start_byte();
if gap_start < gap_end && gap_end <= self.source.len() {
let gap = &self.source[gap_start..gap_end];
let blank_lines = memchr::memchr_iter(b'\n', gap).count().saturating_sub(1);
if blank_lines > 0 {
builder = builder.constraint(
vertex_id,
"blank-lines-before",
&blank_lines.to_string(),
);
}
}
}
builder
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a minimal protocol for the walker tests.
    ///
    /// `obj_kinds` is left empty, which makes the walker keep every
    /// tree-sitter node kind verbatim as the vertex kind.
    fn make_test_protocol() -> Protocol {
        Protocol {
            name: "test".into(),
            schema_theory: "ThTest".into(),
            instance_theory: "ThTestInst".into(),
            schema_composition: None,
            instance_composition: None,
            obj_kinds: vec![],
            edge_rules: vec![],
            constraint_sorts: vec![],
            has_order: true,
            has_coproducts: false,
            has_recursion: false,
            has_causal: false,
            nominal_identity: false,
            has_defaults: false,
            has_coercions: false,
            has_mergers: false,
            has_policies: false,
        }
    }

    /// Builds theory metadata with a single `Vertex` sort and no extracted
    /// kinds, fields or subtype information.
    fn make_test_meta() -> ExtractedTheoryMeta {
        use panproto_gat::{Sort, Theory};

        ExtractedTheoryMeta {
            theory: Theory::new("ThTest", vec![Sort::simple("Vertex")], vec![], vec![]),
            supertypes: FxHashSet::default(),
            subtype_map: Vec::new(),
            optional_fields: FxHashSet::default(),
            ordered_fields: FxHashSet::default(),
            vertex_kinds: Vec::new(),
            edge_kinds: Vec::new(),
        }
    }

    /// Looks up the tree-sitter language registered under `name`.
    ///
    /// Panics when the grammar is not compiled in.
    #[cfg(feature = "grammars")]
    fn get_language(name: &str) -> tree_sitter::Language {
        panproto_grammars::grammars()
            .into_iter()
            .find(|g| g.name == name)
            .unwrap_or_else(|| panic!("grammar '{name}' not enabled in features"))
            .language
    }

    /// Parses `source` with the named grammar and walks the resulting tree
    /// with the default walker configuration, using the test protocol and
    /// test theory metadata.
    #[cfg(feature = "grammars")]
    fn walk_with_grammar(grammar: &str, source: &[u8], file_path: &str) -> Schema {
        let mut parser = tree_sitter::Parser::new();
        parser.set_language(&get_language(grammar)).unwrap();
        let tree = parser.parse(source, None).unwrap();
        let protocol = make_test_protocol();
        let meta = make_test_meta();
        let walker = AstWalker::new(source, &meta, &protocol, WalkerConfig::default());
        walker.walk(&tree, file_path).unwrap()
    }

    /// A one-line TypeScript function yields several vertices, and the root
    /// vertex is keyed by the file path.
    #[test]
    #[cfg(feature = "grammars")]
    fn walk_simple_typescript() {
        let schema = walk_with_grammar(
            "typescript",
            b"function greet(name: string): string { return name; }",
            "test.ts",
        );
        assert!(
            schema.vertices.len() > 1,
            "expected multiple vertices, got {}",
            schema.vertices.len()
        );
        let root_name: panproto_gat::Name = "test.ts".into();
        assert!(
            schema.vertices.contains_key(&root_name),
            "missing root vertex"
        );
    }

    /// A two-line Python function yields several vertices.
    #[test]
    #[cfg(feature = "grammars")]
    fn walk_simple_python() {
        let schema = walk_with_grammar("python", b"def add(a, b):\n return a + b\n", "test.py");
        assert!(
            schema.vertices.len() > 1,
            "expected multiple vertices, got {}",
            schema.vertices.len()
        );
    }

    /// A one-line Rust `main` yields several vertices.
    #[test]
    #[cfg(feature = "grammars")]
    fn walk_simple_rust() {
        let schema = walk_with_grammar(
            "rust",
            b"fn main() { let x = 42; println!(\"{}\", x); }",
            "test.rs",
        );
        assert!(
            schema.vertices.len() > 1,
            "expected multiple vertices, got {}",
            schema.vertices.len()
        );
    }

    /// Parses `source` with the language parser for `grammar_name`, emits the
    /// resulting schema back to bytes and asserts the output equals the
    /// input.
    #[cfg(feature = "group-data")]
    fn assert_roundtrip(grammar_name: &str, source: &[u8], file_path: &str) {
        use crate::registry::AstParser;

        let grammar = panproto_grammars::grammars()
            .into_iter()
            .find(|g| g.name == grammar_name)
            .unwrap_or_else(|| panic!("grammar '{grammar_name}' not enabled"));
        let config = crate::languages::walker_configs::walker_config_for(grammar_name);
        let lang_parser = crate::languages::common::LanguageParser::from_language(
            grammar_name,
            grammar.extensions.to_vec(),
            grammar.language,
            grammar.node_types,
            config,
        )
        .unwrap();

        let schema = lang_parser.parse(source, file_path).unwrap();
        let emitted = lang_parser.emit(&schema).unwrap();

        // Compared as strings so a failure prints readable text.
        assert_eq!(
            std::str::from_utf8(source).unwrap(),
            std::str::from_utf8(&emitted).unwrap(),
            "round-trip failed for {grammar_name}: emitted bytes differ from source"
        );
    }

    /// Single-line JSON object round-trips unchanged.
    #[test]
    #[cfg(feature = "group-data")]
    fn roundtrip_json_simple() {
        assert_roundtrip("json", br#"{"name": "test", "value": 42}"#, "test.json");
    }

    /// Multi-line JSON with a nested object round-trips unchanged.
    #[test]
    #[cfg(feature = "group-data")]
    fn roundtrip_json_formatted() {
        let source =
            b"{\n \"name\": \"test\",\n \"value\": 42,\n \"nested\": {\n \"a\": true\n }\n}";
        assert_roundtrip("json", source, "test.json");
    }

    /// Multi-line JSON array round-trips unchanged.
    #[test]
    #[cfg(feature = "group-data")]
    fn roundtrip_json_array() {
        let source = b"[\n 1,\n 2,\n 3\n]";
        assert_roundtrip("json", source, "test.json");
    }

    /// Small XML document with an attribute and text round-trips unchanged.
    #[test]
    #[cfg(feature = "group-data")]
    fn roundtrip_xml_simple() {
        let source = b"<root>\n <child attr=\"val\">text</child>\n</root>";
        assert_roundtrip("xml", source, "test.xml");
    }

    /// Small YAML mapping with a nested key round-trips unchanged.
    #[test]
    #[cfg(feature = "group-data")]
    fn roundtrip_yaml_simple() {
        let source = b"name: test\nvalue: 42\nnested:\n a: true\n";
        assert_roundtrip("yaml", source, "test.yaml");
    }

    /// Small TOML table round-trips unchanged.
    #[test]
    #[cfg(feature = "group-data")]
    fn roundtrip_toml_simple() {
        let source = b"[package]\nname = \"test\"\nversion = \"0.1.0\"\n";
        assert_roundtrip("toml", source, "test.toml");
    }
}