// panproto_parse/walker.rs

1//! Generic tree-sitter AST walker that converts parse trees to panproto schemas.
2//!
3//! Because theories are auto-derived from the grammar, the walker is fully generic:
4//! one implementation works for all languages. The node's `kind()` IS the panproto
5//! vertex kind; the field name IS the edge kind. Per-language customization is limited
6//! to formatting constraints and scope detection callbacks.
7
8use panproto_schema::{Protocol, Schema, SchemaBuilder};
9use rustc_hash::FxHashSet;
10
11use crate::error::ParseError;
12use crate::id_scheme::IdGenerator;
13use crate::theory_extract::ExtractedTheoryMeta;
14
/// Default tree-sitter node kinds that the walker treats as opening a named scope.
///
/// For a node of one of these kinds the walker tries the configured name fields
/// (`name` / `identifier` unless overridden) and, when one resolves, uses that
/// child's text as the scope name handed to the ID generator.
const SCOPE_INTRODUCING_KINDS: &[&str] = &[
    // Callables.
    "function_declaration",
    "function_definition",
    "method_declaration",
    "method_definition",
    // Type-like declarations.
    "class_declaration",
    "class_definition",
    "interface_declaration",
    "struct_item",
    "enum_item",
    "enum_declaration",
    "impl_item",
    "trait_item",
    // Namespacing constructs.
    "module",
    "namespace_definition",
    "package_declaration",
];
36
/// Default tree-sitter node kinds that the walker treats as ordered statement or
/// member sequences. Each such node gets its own anonymous ID scope while its
/// children are walked.
const BLOCK_KINDS: &[&str] = &[
    // Statement sequences.
    "block",
    "statement_block",
    "compound_statement",
    // Declaration / member lists.
    "declaration_list",
    "field_declaration_list",
    // Bodies of named constructs.
    "enum_body",
    "class_body",
    "interface_body",
    "module_body",
];
49
/// Per-language knobs for [`AstWalker`].
///
/// The walker ships with built-in lists of scope-introducing and block node kinds;
/// this configuration extends those lists and toggles optional constraint capture.
#[derive(Debug, Clone)]
pub struct WalkerConfig {
    /// Node kinds, beyond the built-in defaults, that open a named scope.
    pub extra_scope_kinds: Vec<String>,
    /// Node kinds, beyond the built-in defaults, that hold ordered statement sequences.
    pub extra_block_kinds: Vec<String>,
    /// Field names tried, in order, to find the name of a scope-introducing node.
    /// The default configuration uses `["name", "identifier"]`.
    pub name_fields: Vec<String>,
    /// Whether comment nodes should be captured as constraints on the following sibling.
    // NOTE(review): this flag is not read anywhere in the visible part of this file —
    // confirm where (or whether) it is honoured.
    pub capture_comments: bool,
    /// Whether formatting constraints (indentation, blank lines before a node) are emitted.
    pub capture_formatting: bool,
}
65
66impl Default for WalkerConfig {
67    fn default() -> Self {
68        Self {
69            extra_scope_kinds: Vec::new(),
70            extra_block_kinds: Vec::new(),
71            name_fields: vec!["name".to_owned(), "identifier".to_owned()],
72            capture_comments: true,
73            capture_formatting: true,
74        }
75    }
76}
77
78/// Generic AST walker that converts a tree-sitter parse tree to a panproto [`Schema`].
79///
80/// The walker uses the auto-derived theory to determine vertex and edge kinds directly
81/// from the tree-sitter AST, requiring no manual mapping table.
82pub struct AstWalker<'a> {
83    /// The source code bytes (needed for extracting text of leaf nodes).
84    source: &'a [u8],
85    /// The auto-derived theory metadata. The `vertex_kinds` set is used to
86    /// filter anonymous/internal tree-sitter nodes that are not part of the
87    /// language's public grammar.
88    theory_meta: &'a ExtractedTheoryMeta,
89    /// The protocol definition (for `SchemaBuilder` validation).
90    protocol: &'a Protocol,
91    /// Per-language configuration.
92    config: WalkerConfig,
93    /// Known scope-introducing kinds (merged from defaults + config).
94    scope_kinds: FxHashSet<String>,
95    /// Known block kinds (merged from defaults + config).
96    block_kinds: FxHashSet<String>,
97}
98
99impl<'a> AstWalker<'a> {
100    /// Create a new walker for the given source, theory, and protocol.
101    #[must_use]
102    pub fn new(
103        source: &'a [u8],
104        theory_meta: &'a ExtractedTheoryMeta,
105        protocol: &'a Protocol,
106        config: WalkerConfig,
107    ) -> Self {
108        let mut scope_kinds: FxHashSet<String> = SCOPE_INTRODUCING_KINDS
109            .iter()
110            .map(|s| (*s).to_owned())
111            .collect();
112        for kind in &config.extra_scope_kinds {
113            scope_kinds.insert(kind.clone());
114        }
115
116        let mut block_kinds: FxHashSet<String> =
117            BLOCK_KINDS.iter().map(|s| (*s).to_owned()).collect();
118        for kind in &config.extra_block_kinds {
119            block_kinds.insert(kind.clone());
120        }
121
122        Self {
123            source,
124            theory_meta,
125            protocol,
126            config,
127            scope_kinds,
128            block_kinds,
129        }
130    }
131
132    /// Walk the entire parse tree and produce a [`Schema`].
133    ///
134    /// # Errors
135    ///
136    /// Returns [`ParseError::SchemaConstruction`] if schema building fails.
137    pub fn walk(&self, tree: &tree_sitter::Tree, file_path: &str) -> Result<Schema, ParseError> {
138        let mut id_gen = IdGenerator::new(file_path);
139        let builder = SchemaBuilder::new(self.protocol);
140        let root = tree.root_node();
141
142        let builder = self.walk_node(root, builder, &mut id_gen, None)?;
143
144        builder.build().map_err(|e| ParseError::SchemaConstruction {
145            reason: e.to_string(),
146        })
147    }
148
149    /// Recursively walk a single node, emitting vertices and edges.
150    fn walk_node(
151        &self,
152        node: tree_sitter::Node<'_>,
153        mut builder: SchemaBuilder,
154        id_gen: &mut IdGenerator,
155        parent_vertex_id: Option<&str>,
156    ) -> Result<SchemaBuilder, ParseError> {
157        // Skip anonymous tokens (punctuation, keywords like `{`, `}`, `,`, etc.).
158        if !node.is_named() {
159            return Ok(builder);
160        }
161
162        let kind = node.kind();
163
164        // Skip the root "program"/"source_file"/"module" wrapper if it just wraps children.
165        // We still process it to emit its children, but do so by iterating directly.
166        let is_root_wrapper = parent_vertex_id.is_none()
167            && (kind == "program"
168                || kind == "source_file"
169                || kind == "module"
170                || kind == "translation_unit");
171
172        // Determine vertex ID.
173        let vertex_id = if is_root_wrapper {
174            // Root wrappers get the file path as their ID.
175            id_gen.current_prefix()
176        } else if self.scope_kinds.contains(kind) {
177            // Scope-introducing nodes use their name child for the ID.
178            let name = self.extract_scope_name(&node);
179            match name {
180                Some(n) => id_gen.named_id(&n),
181                None => id_gen.anonymous_id(),
182            }
183        } else {
184            // All other nodes get positional IDs.
185            id_gen.anonymous_id()
186        };
187
188        // Determine the effective vertex kind. If the theory has extracted vertex kinds,
189        // use those for validation. If the kind is unknown to the theory AND the protocol
190        // has a closed obj_kinds list, fall back to "node".
191        let effective_kind = if self.protocol.obj_kinds.is_empty() {
192            // Open protocol: accept all kinds.
193            kind
194        } else if self.protocol.obj_kinds.iter().any(|k| k == kind) {
195            kind
196        } else if !self.theory_meta.vertex_kinds.is_empty()
197            && self.theory_meta.vertex_kinds.iter().any(|k| k == kind)
198        {
199            // Known in the auto-derived theory even if not in the protocol's obj_kinds.
200            kind
201        } else {
202            "node"
203        };
204
205        builder = builder
206            .vertex(&vertex_id, effective_kind, None)
207            .map_err(|e| ParseError::SchemaConstruction {
208                reason: format!("vertex '{vertex_id}' ({kind}): {e}"),
209            })?;
210
211        // Emit edge from parent to this node.
212        if let Some(parent_id) = parent_vertex_id {
213            // Determine edge kind: use the tree-sitter field name if this node
214            // was accessed via a field, otherwise use "child_of".
215            let edge_kind = node
216                .parent()
217                .and_then(|p| {
218                    // Find which field of the parent this node corresponds to.
219                    for i in 0..p.child_count() {
220                        if let Some(child) = p.child(i) {
221                            if child.id() == node.id() {
222                                return u32::try_from(i)
223                                    .ok()
224                                    .and_then(|idx| p.field_name_for_child(idx));
225                            }
226                        }
227                    }
228                    None
229                })
230                .unwrap_or("child_of");
231
232            builder = builder
233                .edge(parent_id, &vertex_id, edge_kind, None)
234                .map_err(|e| ParseError::SchemaConstruction {
235                    reason: format!("edge {parent_id} -> {vertex_id} ({edge_kind}): {e}"),
236                })?;
237        }
238
239        // Store byte range for position-aware emission.
240        builder = builder.constraint(&vertex_id, "start-byte", &node.start_byte().to_string());
241        builder = builder.constraint(&vertex_id, "end-byte", &node.end_byte().to_string());
242
243        // Emit constraints for leaf nodes (literals, identifiers, operators).
244        if node.named_child_count() == 0 {
245            if let Ok(text) = node.utf8_text(self.source) {
246                builder = builder.constraint(&vertex_id, "literal-value", text);
247            }
248        }
249
250        // Emit formatting constraints if enabled.
251        if self.config.capture_formatting {
252            builder = self.emit_formatting_constraints(node, &vertex_id, builder);
253        }
254
255        // Enter scope if this is a scope-introducing node.
256        let entered_scope = if self.scope_kinds.contains(kind) && !is_root_wrapper {
257            match self.extract_scope_name(&node) {
258                Some(n) => id_gen.push_named_scope(&n),
259                None => {
260                    id_gen.push_anonymous_scope();
261                }
262            }
263            true
264        } else if self.block_kinds.contains(kind) {
265            id_gen.push_anonymous_scope();
266            true
267        } else {
268            false
269        };
270
271        builder = self.walk_children_with_interstitials(node, builder, id_gen, &vertex_id)?;
272
273        if entered_scope {
274            id_gen.pop_scope();
275        }
276
277        Ok(builder)
278    }
279
280    /// Walk named children, capturing interstitial text between them.
281    fn walk_children_with_interstitials(
282        &self,
283        node: tree_sitter::Node<'_>,
284        mut builder: SchemaBuilder,
285        id_gen: &mut IdGenerator,
286        vertex_id: &str,
287    ) -> Result<SchemaBuilder, ParseError> {
288        let cursor = &mut node.walk();
289        let children: Vec<_> = node.named_children(cursor).collect();
290        let mut interstitial_idx = 0;
291        let mut prev_end = node.start_byte();
292
293        for child in &children {
294            let gap_start = prev_end;
295            let gap_end = child.start_byte();
296            builder = self.capture_interstitial(
297                builder,
298                vertex_id,
299                gap_start,
300                gap_end,
301                &mut interstitial_idx,
302            );
303            builder = self.walk_node(*child, builder, id_gen, Some(vertex_id))?;
304            prev_end = child.end_byte();
305        }
306
307        // Trailing interstitial after the last child.
308        builder = self.capture_interstitial(
309            builder,
310            vertex_id,
311            prev_end,
312            node.end_byte(),
313            &mut interstitial_idx,
314        );
315
316        Ok(builder)
317    }
318
319    /// Capture interstitial text between `gap_start` and `gap_end` as a constraint.
320    fn capture_interstitial(
321        &self,
322        mut builder: SchemaBuilder,
323        vertex_id: &str,
324        gap_start: usize,
325        gap_end: usize,
326        idx: &mut usize,
327    ) -> SchemaBuilder {
328        if gap_end > gap_start && gap_end <= self.source.len() {
329            if let Ok(gap_text) = std::str::from_utf8(&self.source[gap_start..gap_end]) {
330                if !gap_text.is_empty() {
331                    let sort = format!("interstitial-{}", *idx);
332                    builder = builder.constraint(vertex_id, &sort, gap_text);
333                    builder = builder.constraint(
334                        vertex_id,
335                        &format!("{sort}-start-byte"),
336                        &gap_start.to_string(),
337                    );
338                    *idx += 1;
339                }
340            }
341        }
342        builder
343    }
344
345    /// Extract the name of a scope-introducing node by looking for name/identifier children.
346    fn extract_scope_name(&self, node: &tree_sitter::Node<'_>) -> Option<String> {
347        for field_name in &self.config.name_fields {
348            if let Some(name_node) = node.child_by_field_name(field_name.as_bytes()) {
349                if let Ok(text) = name_node.utf8_text(self.source) {
350                    return Some(text.to_owned());
351                }
352            }
353        }
354        None
355    }
356
357    /// Emit formatting constraints for a node (indentation, position).
358    fn emit_formatting_constraints(
359        &self,
360        node: tree_sitter::Node<'_>,
361        vertex_id: &str,
362        mut builder: SchemaBuilder,
363    ) -> SchemaBuilder {
364        let start = node.start_position();
365
366        // Capture indentation (column of first character on the line).
367        if start.column > 0 {
368            // Extract the actual indentation characters from the source.
369            let line_start = node.start_byte().saturating_sub(start.column);
370            if line_start < self.source.len() {
371                let indent_end = line_start + start.column.min(self.source.len() - line_start);
372                if let Ok(indent) = std::str::from_utf8(&self.source[line_start..indent_end]) {
373                    // Only capture if the extracted region is pure whitespace.
374                    if !indent.is_empty() && indent.trim().is_empty() {
375                        builder = builder.constraint(vertex_id, "indent", indent);
376                    }
377                }
378            }
379        }
380
381        // Count blank lines before this node by looking at source between
382        // previous sibling's end and this node's start.
383        if let Some(prev) = node.prev_named_sibling() {
384            let gap_start = prev.end_byte();
385            let gap_end = node.start_byte();
386            if gap_start < gap_end && gap_end <= self.source.len() {
387                let gap = &self.source[gap_start..gap_end];
388                let blank_lines = memchr::memchr_iter(b'\n', gap).count().saturating_sub(1);
389                if blank_lines > 0 {
390                    builder = builder.constraint(
391                        vertex_id,
392                        "blank-lines-before",
393                        &blank_lines.to_string(),
394                    );
395                }
396            }
397        }
398
399        builder
400    }
401}
402
#[cfg(test)]
// Tests unwrap freely: a failed parser setup or walk should abort the test on the spot.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Protocol with an empty `obj_kinds` list, i.e. one that accepts every vertex kind.
    fn make_test_protocol() -> Protocol {
        Protocol {
            name: "test".into(),
            schema_theory: "ThTest".into(),
            instance_theory: "ThTestInst".into(),
            obj_kinds: vec![], // Empty = open protocol, accepts all kinds.
            edge_rules: vec![],
            constraint_sorts: vec![],
            has_order: true,
            has_coproducts: false,
            has_recursion: false,
            has_causal: false,
            nominal_identity: false,
            has_defaults: false,
            has_coercions: false,
            has_mergers: false,
            has_policies: false,
        }
    }

    /// Theory metadata with a single `Vertex` sort and no extracted vertex or edge kinds.
    fn make_test_meta() -> ExtractedTheoryMeta {
        use panproto_gat::{Sort, Theory};
        ExtractedTheoryMeta {
            theory: Theory::new("ThTest", vec![Sort::simple("Vertex")], vec![], vec![]),
            supertypes: FxHashSet::default(),
            subtype_map: Vec::new(),
            optional_fields: FxHashSet::default(),
            ordered_fields: FxHashSet::default(),
            vertex_kinds: Vec::new(),
            edge_kinds: Vec::new(),
        }
    }

    /// Parse `source` with `language` and walk the tree using the default configuration.
    fn walk_source(language: &tree_sitter::Language, source: &[u8], file_path: &str) -> Schema {
        let mut parser = tree_sitter::Parser::new();
        parser.set_language(language).unwrap();
        let tree = parser.parse(source, None).unwrap();

        let protocol = make_test_protocol();
        let meta = make_test_meta();
        let walker = AstWalker::new(source, &meta, &protocol, WalkerConfig::default());

        walker.walk(&tree, file_path).unwrap()
    }

    /// Every non-trivial source should yield more than just the root vertex.
    fn assert_multiple_vertices(schema: &Schema) {
        assert!(
            schema.vertices.len() > 1,
            "expected multiple vertices, got {}",
            schema.vertices.len()
        );
    }

    #[test]
    fn walk_simple_typescript() {
        let schema = walk_source(
            &tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
            b"function greet(name: string): string { return name; }",
            "test.ts",
        );

        // Should have produced some vertices.
        assert_multiple_vertices(&schema);

        // The root should be the file.
        let root_name: panproto_gat::Name = "test.ts".into();
        assert!(
            schema.vertices.contains_key(&root_name),
            "missing root vertex"
        );
    }

    #[test]
    fn walk_simple_python() {
        let schema = walk_source(
            &tree_sitter_python::LANGUAGE.into(),
            b"def add(a, b):\n    return a + b\n",
            "test.py",
        );

        assert_multiple_vertices(&schema);
    }

    #[test]
    fn walk_simple_rust() {
        let schema = walk_source(
            &tree_sitter_rust::LANGUAGE.into(),
            b"fn main() { let x = 42; println!(\"{}\", x); }",
            "test.rs",
        );

        assert_multiple_vertices(&schema);
    }
}
517}