Skip to main content

panproto_parse/
walker.rs

1//! Generic tree-sitter AST walker that converts parse trees to panproto schemas.
2//!
3//! Because theories are auto-derived from the grammar, the walker is fully generic:
4//! one implementation works for all languages. The node's `kind()` IS the panproto
5//! vertex kind; the field name IS the edge kind.
6//!
7//! Named-scope detection (functions, classes, methods, modules, types) is driven
8//! by the grammar's `queries/tags.scm` file via [`ScopeDetector`], not by a
9//! hardcoded node-kind list. This makes scope detection uniformly correct across
10//! every tree-sitter grammar that ships a tags query. See the [`scope_detector`]
11//! module for the full rationale.
12//!
13//! [`scope_detector`]: crate::scope_detector
14//! [`ScopeDetector`]: crate::scope_detector::ScopeDetector
15
16use std::collections::BTreeMap;
17
18use panproto_schema::{Protocol, Schema, SchemaBuilder};
19use rustc_hash::FxHashSet;
20
21use crate::error::ParseError;
22use crate::id_scheme::IdGenerator;
23use crate::scope_detector::{NamedScope, ScopeDetector};
24use crate::theory_extract::ExtractedTheoryMeta;
25
/// Built-in node kinds whose named children form an ordered sequence.
///
/// Scope detection is grammar-driven (via `tags.scm`), but block grouping is
/// purely structural. Siblings inside one of these kinds receive positional
/// IDs (`$0`, `$1`, ...) within their own anonymous scope, so inserting a
/// statement does not renumber vertices outside the block. Languages may add
/// more kinds through [`WalkerConfig::extra_block_kinds`].
const BLOCK_KINDS: &[&str] = &[
    // Statement bodies.
    "block",
    "statement_block",
    "compound_statement",
    // Declaration / member bodies.
    "declaration_list",
    "field_declaration_list",
    "enum_body",
    "class_body",
    "interface_body",
    "module_body",
];
43
/// Per-language knobs for the AST walker.
///
/// [`WalkerConfig::default`] turns every capture off; [`WalkerConfig::standard`]
/// is the usual setup with comment and formatting capture switched on.
#[derive(Debug, Clone, Default)]
pub struct WalkerConfig {
    /// Node kinds, in addition to the built-in list, whose children form an
    /// ordered statement sequence and therefore receive positional IDs.
    ///
    /// Named scopes are not configured here: [`ScopeDetector`] derives them
    /// from the grammar's `tags.scm`, so no per-language scope list exists.
    ///
    /// [`ScopeDetector`]: crate::scope_detector::ScopeDetector
    pub extra_block_kinds: Vec<String>,
    // NOTE(review): `AstWalker` in this module never reads this flag
    // (comment nodes are walked like any other named node); confirm where
    // it is consumed.
    /// Capture comment nodes as constraints on the sibling that follows them.
    pub capture_comments: bool,
    /// Capture whitespace/formatting details (indentation, blank lines) as
    /// constraints.
    pub capture_formatting: bool,
}

impl WalkerConfig {
    /// Config with both comment and formatting capture enabled and no extra
    /// block kinds. This is the common choice, as opposed to the all-off
    /// [`WalkerConfig::default`].
    #[must_use]
    pub const fn standard() -> Self {
        Self {
            capture_comments: true,
            capture_formatting: true,
            extra_block_kinds: Vec::new(),
        }
    }
}
73
74/// Generic AST walker that converts a tree-sitter parse tree to a panproto [`Schema`].
75///
76/// The walker uses the auto-derived theory to determine vertex and edge kinds directly
77/// from the tree-sitter AST, requiring no manual mapping table. Named-scope identity
78/// (the part of the vertex ID that survives insertions) is driven by [`ScopeDetector`]
79/// from the grammar's `tags.scm` query.
80///
81/// [`ScopeDetector`]: crate::scope_detector::ScopeDetector
82pub struct AstWalker<'a> {
83    /// The source code bytes (needed for extracting text of leaf nodes).
84    source: &'a [u8],
85    /// The auto-derived theory metadata. The `vertex_kinds` set is used to
86    /// filter anonymous/internal tree-sitter nodes that are not part of the
87    /// language's public grammar.
88    theory_meta: &'a ExtractedTheoryMeta,
89    /// The protocol definition (for `SchemaBuilder` validation).
90    protocol: &'a Protocol,
91    /// Per-language configuration.
92    config: WalkerConfig,
93    /// Known block kinds (merged from defaults + config).
94    block_kinds: FxHashSet<String>,
95    /// Named scopes indexed by `(start_byte, end_byte)` for O(log n) lookup
96    /// during the tree walk. Derived from a [`ScopeDetector`] run over the
97    /// full source before the walk begins.
98    scope_map: BTreeMap<(usize, usize), NamedScope>,
99}
100
101impl<'a> AstWalker<'a> {
102    /// Create a new walker for the given source, theory, and protocol.
103    ///
104    /// Runs an optional [`ScopeDetector`] over the source to build a
105    /// per-file scope map. Pass `None` to disable named-scope detection
106    /// (every non-root vertex gets a positional ID). Pass `Some(detector)`
107    /// whose [`has_query`] is `false` for the same effect; the detector
108    /// short-circuits to an empty scope list.
109    ///
110    /// [`ScopeDetector`]: crate::scope_detector::ScopeDetector
111    /// [`has_query`]: crate::scope_detector::ScopeDetector::has_query
112    #[must_use]
113    pub fn new(
114        source: &'a [u8],
115        theory_meta: &'a ExtractedTheoryMeta,
116        protocol: &'a Protocol,
117        config: WalkerConfig,
118        scope_detector: Option<&mut ScopeDetector>,
119    ) -> Self {
120        let mut block_kinds: FxHashSet<String> =
121            BLOCK_KINDS.iter().map(|s| (*s).to_owned()).collect();
122        for kind in &config.extra_block_kinds {
123            block_kinds.insert(kind.clone());
124        }
125
126        let mut scope_map: BTreeMap<(usize, usize), NamedScope> = BTreeMap::new();
127        if let Some(det) = scope_detector {
128            for scope in det.scopes(source) {
129                scope_map.insert((scope.node_range.start, scope.node_range.end), scope);
130            }
131        }
132
133        Self {
134            source,
135            theory_meta,
136            protocol,
137            config,
138            block_kinds,
139            scope_map,
140        }
141    }
142
143    /// Walk the entire parse tree and produce a [`Schema`].
144    ///
145    /// # Errors
146    ///
147    /// Returns [`ParseError::SchemaConstruction`] if schema building fails.
148    pub fn walk(&self, tree: &tree_sitter::Tree, file_path: &str) -> Result<Schema, ParseError> {
149        let mut id_gen = IdGenerator::new(file_path);
150        let builder = SchemaBuilder::new(self.protocol);
151        let root = tree.root_node();
152
153        let builder = self.walk_node(root, builder, &mut id_gen, None)?;
154
155        builder.build().map_err(|e| ParseError::SchemaConstruction {
156            reason: e.to_string(),
157        })
158    }
159
160    /// Look up a node's named-scope entry, if any.
161    fn scope_for(&self, node: tree_sitter::Node<'_>) -> Option<&NamedScope> {
162        self.scope_map.get(&(node.start_byte(), node.end_byte()))
163    }
164
165    /// Recursively walk a single node, emitting vertices and edges.
166    fn walk_node(
167        &self,
168        node: tree_sitter::Node<'_>,
169        mut builder: SchemaBuilder,
170        id_gen: &mut IdGenerator,
171        parent_vertex_id: Option<&str>,
172    ) -> Result<SchemaBuilder, ParseError> {
173        // Skip anonymous tokens (punctuation, keywords like `{`, `}`, `,`, etc.).
174        if !node.is_named() {
175            return Ok(builder);
176        }
177
178        let kind = node.kind();
179
180        // Skip the root "program"/"source_file"/"module" wrapper if it just wraps children.
181        // We still process it to emit its children, but do so by iterating directly.
182        let is_root_wrapper = parent_vertex_id.is_none()
183            && (kind == "program"
184                || kind == "source_file"
185                || kind == "module"
186                || kind == "translation_unit");
187
188        let named_scope = if is_root_wrapper {
189            None
190        } else {
191            self.scope_for(node)
192        };
193
194        // Determine vertex ID.
195        let vertex_id = if is_root_wrapper {
196            // Root wrappers get the file path as their ID.
197            id_gen.current_prefix()
198        } else if let Some(scope) = named_scope {
199            id_gen.named_id(&scope.name)
200        } else {
201            // All other nodes get positional IDs.
202            id_gen.anonymous_id()
203        };
204
205        // Determine the effective vertex kind. If the theory has extracted vertex kinds,
206        // use those for validation. If the kind is unknown to the theory AND the protocol
207        // has a closed obj_kinds list, fall back to "node".
208        let effective_kind = if self.protocol.obj_kinds.is_empty() {
209            // Open protocol: accept all kinds.
210            kind
211        } else if self.protocol.obj_kinds.iter().any(|k| k == kind) {
212            kind
213        } else if !self.theory_meta.vertex_kinds.is_empty()
214            && self.theory_meta.vertex_kinds.iter().any(|k| k == kind)
215        {
216            // Known in the auto-derived theory even if not in the protocol's obj_kinds.
217            kind
218        } else {
219            "node"
220        };
221
222        builder = builder
223            .vertex(&vertex_id, effective_kind, None)
224            .map_err(|e| ParseError::SchemaConstruction {
225                reason: format!("vertex '{vertex_id}' ({kind}): {e}"),
226            })?;
227
228        // Emit edge from parent to this node.
229        if let Some(parent_id) = parent_vertex_id {
230            // Determine edge kind: use the tree-sitter field name if this node
231            // was accessed via a field, otherwise use "child_of".
232            let edge_kind = node
233                .parent()
234                .and_then(|p| {
235                    // Find which field of the parent this node corresponds to.
236                    for i in 0..p.child_count() {
237                        if let Some(child) = p.child(i) {
238                            if child.id() == node.id() {
239                                return u32::try_from(i)
240                                    .ok()
241                                    .and_then(|idx| p.field_name_for_child(idx));
242                            }
243                        }
244                    }
245                    None
246                })
247                .unwrap_or("child_of");
248
249            builder = builder
250                .edge(parent_id, &vertex_id, edge_kind, None)
251                .map_err(|e| ParseError::SchemaConstruction {
252                    reason: format!("edge {parent_id} -> {vertex_id} ({edge_kind}): {e}"),
253                })?;
254        }
255
256        // Store byte range for position-aware emission.
257        builder = builder.constraint(&vertex_id, "start-byte", &node.start_byte().to_string());
258        builder = builder.constraint(&vertex_id, "end-byte", &node.end_byte().to_string());
259
260        // Emit constraints for leaf nodes (literals, identifiers, operators).
261        if node.named_child_count() == 0 {
262            if let Ok(text) = node.utf8_text(self.source) {
263                builder = builder.constraint(&vertex_id, "literal-value", text);
264            }
265        }
266
267        // Emit formatting constraints if enabled.
268        if self.config.capture_formatting {
269            builder = self.emit_formatting_constraints(node, &vertex_id, builder);
270        }
271
272        // Enter scope if this is a scope-introducing node.
273        let entered_scope = if let Some(scope) = named_scope {
274            id_gen.push_named_scope(&scope.name);
275            true
276        } else if !is_root_wrapper && self.block_kinds.contains(kind) {
277            id_gen.push_anonymous_scope();
278            true
279        } else {
280            false
281        };
282
283        builder = self.walk_children_with_interstitials(node, builder, id_gen, &vertex_id)?;
284
285        if entered_scope {
286            id_gen.pop_scope();
287        }
288
289        Ok(builder)
290    }
291
292    /// Walk named children, capturing interstitial text between them.
293    fn walk_children_with_interstitials(
294        &self,
295        node: tree_sitter::Node<'_>,
296        mut builder: SchemaBuilder,
297        id_gen: &mut IdGenerator,
298        vertex_id: &str,
299    ) -> Result<SchemaBuilder, ParseError> {
300        let cursor = &mut node.walk();
301        let children: Vec<_> = node.named_children(cursor).collect();
302        let mut interstitial_idx = 0;
303        let mut prev_end = node.start_byte();
304
305        for child in &children {
306            let gap_start = prev_end;
307            let gap_end = child.start_byte();
308            builder = self.capture_interstitial(
309                builder,
310                vertex_id,
311                gap_start,
312                gap_end,
313                &mut interstitial_idx,
314            );
315            builder = self.walk_node(*child, builder, id_gen, Some(vertex_id))?;
316            prev_end = child.end_byte();
317        }
318
319        // Trailing interstitial after the last child.
320        builder = self.capture_interstitial(
321            builder,
322            vertex_id,
323            prev_end,
324            node.end_byte(),
325            &mut interstitial_idx,
326        );
327
328        Ok(builder)
329    }
330
331    /// Capture interstitial text between `gap_start` and `gap_end` as a constraint.
332    fn capture_interstitial(
333        &self,
334        mut builder: SchemaBuilder,
335        vertex_id: &str,
336        gap_start: usize,
337        gap_end: usize,
338        idx: &mut usize,
339    ) -> SchemaBuilder {
340        if gap_end > gap_start && gap_end <= self.source.len() {
341            if let Ok(gap_text) = std::str::from_utf8(&self.source[gap_start..gap_end]) {
342                if !gap_text.is_empty() {
343                    let sort = format!("interstitial-{}", *idx);
344                    builder = builder.constraint(vertex_id, &sort, gap_text);
345                    builder = builder.constraint(
346                        vertex_id,
347                        &format!("{sort}-start-byte"),
348                        &gap_start.to_string(),
349                    );
350                    *idx += 1;
351                }
352            }
353        }
354        builder
355    }
356
357    /// Emit formatting constraints for a node (indentation, position).
358    fn emit_formatting_constraints(
359        &self,
360        node: tree_sitter::Node<'_>,
361        vertex_id: &str,
362        mut builder: SchemaBuilder,
363    ) -> SchemaBuilder {
364        let start = node.start_position();
365
366        // Capture indentation (column of first character on the line).
367        if start.column > 0 {
368            // Extract the actual indentation characters from the source.
369            let line_start = node.start_byte().saturating_sub(start.column);
370            if line_start < self.source.len() {
371                let indent_end = line_start + start.column.min(self.source.len() - line_start);
372                if let Ok(indent) = std::str::from_utf8(&self.source[line_start..indent_end]) {
373                    // Only capture if the extracted region is pure whitespace.
374                    if !indent.is_empty() && indent.trim().is_empty() {
375                        builder = builder.constraint(vertex_id, "indent", indent);
376                    }
377                }
378            }
379        }
380
381        // Count blank lines before this node by looking at source between
382        // previous sibling's end and this node's start.
383        if let Some(prev) = node.prev_named_sibling() {
384            let gap_start = prev.end_byte();
385            let gap_end = node.start_byte();
386            if gap_start < gap_end && gap_end <= self.source.len() {
387                let gap = &self.source[gap_start..gap_end];
388                let blank_lines = memchr::memchr_iter(b'\n', gap).count().saturating_sub(1);
389                if blank_lines > 0 {
390                    builder = builder.constraint(
391                        vertex_id,
392                        "blank-lines-before",
393                        &blank_lines.to_string(),
394                    );
395                }
396            }
397        }
398
399        builder
400    }
401}
402
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// An open protocol: empty `obj_kinds` means every node kind is accepted.
    fn make_test_protocol() -> Protocol {
        Protocol {
            name: "test".into(),
            schema_theory: "ThTest".into(),
            instance_theory: "ThTestInst".into(),
            schema_composition: None,
            instance_composition: None,
            obj_kinds: vec![],
            edge_rules: vec![],
            constraint_sorts: vec![],
            has_order: true,
            has_coproducts: false,
            has_recursion: false,
            has_causal: false,
            nominal_identity: false,
            has_defaults: false,
            has_coercions: false,
            has_mergers: false,
            has_policies: false,
        }
    }

    /// Minimal theory metadata with a single `Vertex` sort and no kinds.
    fn make_test_meta() -> ExtractedTheoryMeta {
        use panproto_gat::{Sort, Theory};

        ExtractedTheoryMeta {
            theory: Theory::new("ThTest", vec![Sort::simple("Vertex")], vec![], vec![]),
            supertypes: FxHashSet::default(),
            subtype_map: Vec::new(),
            optional_fields: FxHashSet::default(),
            ordered_fields: FxHashSet::default(),
            vertex_kinds: Vec::new(),
            edge_kinds: Vec::new(),
        }
    }

    /// Fetch an enabled grammar from panproto-grammars by name.
    #[cfg(feature = "grammars")]
    fn get_grammar(name: &str) -> panproto_grammars::Grammar {
        panproto_grammars::grammars()
            .into_iter()
            .find(|g| g.name == name)
            .unwrap_or_else(|| panic!("grammar '{name}' not enabled in features"))
    }

    /// Parse `source` with the named grammar and walk it with a
    /// `tags.scm`-driven scope detector.
    ///
    /// Returns the schema together with the detector's `has_query()`; the
    /// named-scope assertions only apply when the grammar ships a tags query.
    #[cfg(feature = "grammars")]
    fn walk_with_grammar(grammar_name: &str, source: &[u8], file_path: &str) -> (Schema, bool) {
        let grammar = get_grammar(grammar_name);

        let mut parser = tree_sitter::Parser::new();
        parser.set_language(&grammar.language).unwrap();
        let tree = parser.parse(source, None).unwrap();

        let protocol = make_test_protocol();
        let meta = make_test_meta();
        let mut detector =
            crate::scope_detector::ScopeDetector::new(&grammar.language, grammar.tags_query, None)
                .unwrap();
        let walker = AstWalker::new(
            source,
            &meta,
            &protocol,
            WalkerConfig::standard(),
            Some(&mut detector),
        );

        let schema = walker.walk(&tree, file_path).unwrap();
        (schema, detector.has_query())
    }

    /// All vertex IDs of `schema`, rendered as strings.
    #[cfg(feature = "grammars")]
    fn vertex_ids(schema: &Schema) -> Vec<String> {
        schema.vertices.keys().map(ToString::to_string).collect()
    }

    #[test]
    #[cfg(feature = "grammars")]
    fn walk_simple_typescript() {
        let (schema, has_query) = walk_with_grammar(
            "typescript",
            b"function greet(name: string): string { return name; }",
            "test.ts",
        );

        let vertex_count = schema.vertices.len();
        assert!(vertex_count > 1, "expected multiple vertices, got {vertex_count}");

        // The root wrapper is identified by the file path.
        let root_name: panproto_gat::Name = "test.ts".into();
        assert!(
            schema.vertices.contains_key(&root_name),
            "missing root vertex"
        );

        if has_query {
            let ids = vertex_ids(&schema);
            assert!(
                ids.iter().any(|id| id.ends_with("::greet")),
                "expected a vertex ID ending in ::greet, got: {ids:?}"
            );
        }
    }

    #[test]
    #[cfg(feature = "grammars")]
    fn walk_simple_python() {
        let (schema, has_query) =
            walk_with_grammar("python", b"def add(a, b):\n    return a + b\n", "test.py");

        let vertex_count = schema.vertices.len();
        assert!(vertex_count > 1, "expected multiple vertices, got {vertex_count}");

        if has_query {
            let found_add = vertex_ids(&schema).iter().any(|id| id.ends_with("::add"));
            assert!(found_add, "expected ::add vertex");
        }
    }

    #[test]
    #[cfg(feature = "grammars")]
    fn walk_simple_rust() {
        let (schema, has_query) = walk_with_grammar(
            "rust",
            b"fn verify_push() {}\nstruct Foo;\nimpl Foo { fn bar(&self) {} }\n",
            "test.rs",
        );

        let vertex_count = schema.vertices.len();
        assert!(vertex_count > 1, "expected multiple vertices, got {vertex_count}");

        if !has_query {
            return;
        }

        let vertex_ids = vertex_ids(&schema);
        // Rust's function_item — the regression from issue #34 — must be
        // detected as a named scope.
        assert!(
            vertex_ids.iter().any(|id| id.ends_with("::verify_push")),
            "expected ::verify_push named scope, got: {vertex_ids:?}"
        );
        assert!(
            vertex_ids.iter().any(|id| id.ends_with("::Foo")),
            "expected ::Foo named scope, got: {vertex_ids:?}"
        );
    }

    /// Parse `source` with a full language parser, emit it back, and require
    /// the emitted bytes to match the input exactly.
    #[cfg(feature = "group-data")]
    fn assert_roundtrip(grammar_name: &str, source: &[u8], file_path: &str) {
        use crate::registry::AstParser;

        let grammar = panproto_grammars::grammars()
            .into_iter()
            .find(|g| g.name == grammar_name)
            .unwrap_or_else(|| panic!("grammar '{grammar_name}' not enabled"));

        let walker_config = crate::languages::walker_configs::walker_config_for(grammar_name);
        let parser = crate::languages::common::LanguageParser::from_language(
            grammar_name,
            grammar.extensions.to_vec(),
            grammar.language,
            grammar.node_types,
            grammar.tags_query,
            walker_config,
        )
        .unwrap();

        let schema = parser.parse(source, file_path).unwrap();
        let emitted = parser.emit(&schema).unwrap();

        let expected_text = std::str::from_utf8(source).unwrap();
        let emitted_text = std::str::from_utf8(&emitted).unwrap();
        assert_eq!(
            expected_text, emitted_text,
            "round-trip failed for {grammar_name}: emitted bytes differ from source"
        );
    }

    #[test]
    #[cfg(feature = "group-data")]
    fn roundtrip_json_simple() {
        assert_roundtrip("json", br#"{"name": "test", "value": 42}"#, "test.json");
    }

    #[test]
    #[cfg(feature = "group-data")]
    fn roundtrip_json_formatted() {
        assert_roundtrip(
            "json",
            b"{\n  \"name\": \"test\",\n  \"value\": 42,\n  \"nested\": {\n    \"a\": true\n  }\n}",
            "test.json",
        );
    }

    #[test]
    #[cfg(feature = "group-data")]
    fn roundtrip_json_array() {
        assert_roundtrip("json", b"[\n  1,\n  2,\n  3\n]", "test.json");
    }

    #[test]
    #[cfg(feature = "group-data")]
    fn roundtrip_xml_simple() {
        assert_roundtrip(
            "xml",
            b"<root>\n  <child attr=\"val\">text</child>\n</root>",
            "test.xml",
        );
    }

    #[test]
    #[cfg(feature = "group-data")]
    fn roundtrip_yaml_simple() {
        assert_roundtrip("yaml", b"name: test\nvalue: 42\nnested:\n  a: true\n", "test.yaml");
    }

    #[test]
    #[cfg(feature = "group-data")]
    fn roundtrip_toml_simple() {
        assert_roundtrip(
            "toml",
            b"[package]\nname = \"test\"\nversion = \"0.1.0\"\n",
            "test.toml",
        );
    }
}