Skip to main content

panproto_parse/
theory_extract.rs

1//! Automated theory extraction from tree-sitter grammar metadata.
2//!
3//! Tree-sitter grammars are theory presentations: each grammar's `node-types.json`
4//! file is structurally isomorphic to a GAT. This module extracts panproto theories
5//! directly from grammar metadata, ensuring the theory is always in sync with the
6//! parser.
7//!
8//! ## Mapping
9//!
10//! | `node-types.json` | panproto GAT |
11//! |---|---|
12//! | Named node type | Sort (vertex kind) |
13//! | Field (`required: true`) | Operation (mandatory edge kind) |
14//! | Field (`required: false`) | Operation (optional edge kind) |
15//! | Field (`multiple: true`) | Operation (ordered edge kind) |
16//! | Supertype with subtypes | Abstract sort with subtype inclusions |
17
18use panproto_gat::{Operation, Sort, Theory};
19use rustc_hash::FxHashSet;
20
21use crate::error::ParseError;
22
23// ─── node-types.json schema ───────────────────────────────────────────────
24
/// One top-level entry of tree-sitter's `node-types.json` array.
///
/// Deserialized with serde. Only the keys listed below are read; `fields`,
/// `children` and `subtypes` may be absent from the JSON.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct NodeType {
    /// The node type name (e.g. `"function_declaration"`).
    ///
    /// Read from the JSON key `"type"`.
    #[serde(rename = "type")]
    pub node_type: String,
    /// Whether this is a named grammar rule (`true`) or an anonymous token
    /// (`false`). Theory extraction skips entries where this is `false`.
    pub named: bool,
    /// Named fields, keyed by field name.
    ///
    /// Values are kept as raw JSON and interpreted later by
    /// `parse_field_spec`. Defaults to an empty map when the key is absent.
    #[serde(default)]
    pub fields: serde_json::Map<String, serde_json::Value>,
    /// Specification of the node's unnamed (field-less) children, or `None`
    /// when the entry has no `"children"` key.
    #[serde(default)]
    pub children: Option<ChildSpec>,
    /// For supertype nodes, the listed subtypes; `None` for ordinary nodes.
    ///
    /// The presence of this value is what marks an entry as a supertype
    /// during theory extraction.
    #[serde(default)]
    pub subtypes: Option<Vec<SubtypeRef>>,
}
43
/// Specification for the unnamed children of a node type (the `"children"`
/// object of a `node-types.json` entry).
///
/// All three keys must be present for deserialization to succeed; none of
/// them has a serde default.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct ChildSpec {
    /// Whether multiple children are allowed.
    pub multiple: bool,
    /// Whether at least one child is required.
    pub required: bool,
    /// The allowed child node types.
    pub types: Vec<SubtypeRef>,
}
54
/// A reference to a node type, as it appears in `types` and `subtypes`
/// arrays of `node-types.json`.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct SubtypeRef {
    /// The referenced node type name, read from the JSON key `"type"`.
    #[serde(rename = "type")]
    pub node_type: String,
    /// Whether the referenced node type is named (as opposed to an anonymous
    /// token).
    pub named: bool,
}
64
/// A single field specification decoded from the raw JSON stored in
/// [`NodeType::fields`].
///
/// Unlike the other schema types in this module this struct is not
/// deserialized directly; it is built by `parse_field_spec`.
#[derive(Debug, Clone)]
pub struct FieldSpec {
    /// The field name (e.g. `"body"`, `"condition"`, `"parameters"`).
    pub name: String,
    /// Whether this field is required. `false` when the JSON does not say.
    pub required: bool,
    /// Whether this field can contain multiple children. `false` when the
    /// JSON does not say.
    pub multiple: bool,
    /// The allowed child node types. Empty when the JSON has no usable
    /// `"types"` array.
    pub types: Vec<SubtypeRef>,
}
77
/// An extracted theory together with grammar metadata that the GAT itself
/// does not carry.
///
/// When produced by [`extract_theory_from_language`], the `supertypes`,
/// `subtype_map`, `optional_fields` and `ordered_fields` members are always
/// empty; only [`extract_theory_from_entries`] populates them.
#[derive(Debug, Clone)]
pub struct ExtractedTheoryMeta {
    /// The GAT extracted from the grammar. Its sorts start with the base
    /// sorts `Vertex` and `Edge`, followed by one sort per vertex kind.
    pub theory: Theory,
    /// Names of named node types that declare subtypes (abstract sorts).
    pub supertypes: FxHashSet<String>,
    /// `(supertype, subtypes)` pairs in the order the supertypes were
    /// encountered. Only named subtypes are listed.
    pub subtype_map: Vec<(String, Vec<String>)>,
    /// Field names that are not required on at least one node type (for
    /// `ThPartial` composition).
    pub optional_fields: FxHashSet<String>,
    /// Field names that allow multiple children on at least one node type
    /// (for `ThOrder` composition). Contains the literal `"children"` when
    /// some node type allows multiple unnamed children.
    pub ordered_fields: FxHashSet<String>,
    /// All named node types (vertex kinds for the protocol), in first-seen
    /// order. The base sorts `Vertex` and `Edge` are not included.
    pub vertex_kinds: Vec<String>,
    /// All field names (edge kinds for the protocol), plus the synthetic
    /// `child_of` kind when any node type has unnamed children.
    pub edge_kinds: Vec<String>,
}
96
97// ─── extraction from node-types.json ──────────────────────────────────────
98
99/// Parse tree-sitter's `node-types.json` bytes into a vector of [`NodeType`] entries.
100///
101/// # Errors
102///
103/// Returns [`ParseError::NodeTypesJson`] if JSON deserialization fails.
104pub fn parse_node_types(json: &[u8]) -> Result<Vec<NodeType>, ParseError> {
105    serde_json::from_slice(json).map_err(|e| ParseError::NodeTypesJson { source: e })
106}
107
108/// Extract a panproto [`Theory`] from tree-sitter's `node-types.json` content.
109///
110/// The returned [`ExtractedTheoryMeta`] includes the GAT plus metadata about
111/// supertypes, optional fields, and ordered fields needed for protocol definition
112/// and colimit composition.
113///
114/// # Errors
115///
116/// Returns [`ParseError`] if JSON parsing fails or the grammar has structural
117/// issues preventing theory extraction.
118pub fn extract_theory_from_node_types(
119    theory_name: &str,
120    json: &[u8],
121) -> Result<ExtractedTheoryMeta, ParseError> {
122    let node_types = parse_node_types(json)?;
123    extract_theory_from_entries(theory_name, &node_types)
124}
125
126/// Extract a theory from already-parsed [`NodeType`] entries.
127///
128/// # Errors
129///
130/// Returns [`ParseError::TheoryExtraction`] if the grammar structure is invalid.
131pub fn extract_theory_from_entries(
132    theory_name: &str,
133    node_types: &[NodeType],
134) -> Result<ExtractedTheoryMeta, ParseError> {
135    let mut sorts: Vec<Sort> = Vec::new();
136    let mut ops: Vec<Operation> = Vec::new();
137    let mut supertypes = FxHashSet::default();
138    let mut subtype_map: Vec<(String, Vec<String>)> = Vec::new();
139    let mut optional_fields = FxHashSet::default();
140    let mut ordered_fields = FxHashSet::default();
141    let mut vertex_kinds: Vec<String> = Vec::new();
142    let mut edge_kind_set = FxHashSet::default();
143    let mut seen_sorts = FxHashSet::default();
144
145    // Always include Vertex and Edge as base sorts (shared with ThGraph via colimit).
146    sorts.push(Sort::simple("Vertex"));
147    sorts.push(Sort::simple("Edge"));
148    seen_sorts.insert("Vertex".to_owned());
149    seen_sorts.insert("Edge".to_owned());
150
151    for entry in node_types {
152        // Skip anonymous tokens (punctuation, keywords).
153        if !entry.named {
154            continue;
155        }
156
157        let sort_name = &entry.node_type;
158
159        // Supertype nodes define abstract sorts with subtype inclusions.
160        if let Some(ref subtypes) = entry.subtypes {
161            supertypes.insert(sort_name.clone());
162            let concrete: Vec<String> = subtypes
163                .iter()
164                .filter(|s| s.named)
165                .map(|s| s.node_type.clone())
166                .collect();
167            subtype_map.push((sort_name.clone(), concrete));
168
169            // Register the supertype as a sort if not already present.
170            if seen_sorts.insert(sort_name.clone()) {
171                sorts.push(Sort::simple(sort_name.as_str()));
172                vertex_kinds.push(sort_name.clone());
173            }
174            continue;
175        }
176
177        // Regular named node type: create a sort and operations for its fields.
178        if seen_sorts.insert(sort_name.clone()) {
179            sorts.push(Sort::simple(sort_name.as_str()));
180            vertex_kinds.push(sort_name.clone());
181        }
182
183        // Process fields: each field becomes an operation (edge kind).
184        for (field_name, field_value) in &entry.fields {
185            let spec = parse_field_spec(field_name, field_value)?;
186
187            // Track optional and ordered fields for later composition.
188            if !spec.required {
189                optional_fields.insert(field_name.clone());
190            }
191            if spec.multiple {
192                ordered_fields.insert(field_name.clone());
193            }
194
195            // Create an operation for this field if not already registered.
196            // The operation represents an edge kind: parent_sort --field_name--> child_sort.
197            // Since tree-sitter fields can accept multiple child types, we model
198            // the operation as mapping from the parent sort to Vertex (the abstract base).
199            if edge_kind_set.insert(field_name.clone()) {
200                ops.push(Operation::unary(
201                    field_name.as_str(),
202                    "parent",
203                    "Vertex",
204                    "Vertex",
205                ));
206            }
207        }
208
209        // Process unnamed children (if present).
210        if let Some(ref children) = entry.children {
211            if children.multiple {
212                ordered_fields.insert("children".to_owned());
213            }
214            // Unnamed children use a generic "child_of" edge.
215            if edge_kind_set.insert("child_of".to_owned()) {
216                ops.push(Operation::unary("child_of", "parent", "Vertex", "Vertex"));
217            }
218        }
219    }
220
221    let edge_kinds: Vec<String> = edge_kind_set.into_iter().collect();
222
223    let theory = Theory::new(theory_name, sorts, ops, vec![]);
224
225    Ok(ExtractedTheoryMeta {
226        theory,
227        supertypes,
228        subtype_map,
229        optional_fields,
230        ordered_fields,
231        vertex_kinds,
232        edge_kinds,
233    })
234}
235
236/// Extract a theory at runtime from a tree-sitter `Language` object.
237///
238/// This uses the Language introspection API (`node_kind_count`, `field_count`,
239/// `node_kind_is_named`, `supertypes()`) rather than parsing `node-types.json`.
240///
241/// Supertype information is available via the runtime API. However, field
242/// optionality (`required`) and multiplicity (`multiple`) are NOT exposed
243/// by the Language runtime API. For full metadata, use
244/// [`extract_theory_from_node_types`] with the `NODE_TYPES` constant
245/// embedded in each grammar crate.
246///
247/// # Errors
248///
249/// Returns [`ParseError::TheoryExtraction`] if introspection fails.
250pub fn extract_theory_from_language(
251    theory_name: &str,
252    language: &tree_sitter::Language,
253) -> Result<ExtractedTheoryMeta, ParseError> {
254    let mut sorts: Vec<Sort> = Vec::new();
255    let mut ops: Vec<Operation> = Vec::new();
256    let mut vertex_kinds: Vec<String> = Vec::new();
257    let mut edge_kind_set = FxHashSet::default();
258    let mut seen_sorts = FxHashSet::default();
259    // Base sorts.
260    sorts.push(Sort::simple("Vertex"));
261    sorts.push(Sort::simple("Edge"));
262    seen_sorts.insert("Vertex".to_owned());
263    seen_sorts.insert("Edge".to_owned());
264
265    // Enumerate all named node types as sorts.
266    let node_count = language.node_kind_count();
267    for id in 0..node_count {
268        let Ok(id_u16) = u16::try_from(id) else {
269            continue;
270        };
271        if language.node_kind_is_named(id_u16) {
272            if let Some(name) = language.node_kind_for_id(id_u16) {
273                // Skip internal hidden nodes (prefixed with _).
274                if name.starts_with('_') {
275                    continue;
276                }
277
278                if seen_sorts.insert(name.to_owned()) {
279                    sorts.push(Sort::simple(name));
280                    vertex_kinds.push(name.to_owned());
281                }
282            }
283        }
284    }
285
286    // Enumerate all field names as operations (edge kinds).
287    let field_count = language.field_count();
288    for id in 1..=field_count {
289        let Ok(id_u16) = u16::try_from(id) else {
290            continue;
291        };
292        if let Some(name) = language.field_name_for_id(id_u16) {
293            if edge_kind_set.insert(name.to_owned()) {
294                ops.push(Operation::unary(name, "parent", "Vertex", "Vertex"));
295            }
296        }
297    }
298
299    let edge_kinds: Vec<String> = edge_kind_set.into_iter().collect();
300
301    let theory = Theory::new(theory_name, sorts, ops, vec![]);
302
303    // Note: optional_fields, ordered_fields, supertypes, and subtype_map
304    // cannot be fully determined from the tree-sitter 0.24 Language runtime API.
305    // For full metadata, use extract_theory_from_node_types() with the NODE_TYPES
306    // constant from the grammar crate.
307    Ok(ExtractedTheoryMeta {
308        theory,
309        supertypes: FxHashSet::default(),
310        subtype_map: Vec::new(),
311        optional_fields: FxHashSet::default(),
312        ordered_fields: FxHashSet::default(),
313        vertex_kinds,
314        edge_kinds,
315    })
316}
317
318// ─── helpers ──────────────────────────────────────────────────────────────
319
320/// Parse a field specification from the JSON value in node-types.json.
321fn parse_field_spec(name: &str, value: &serde_json::Value) -> Result<FieldSpec, ParseError> {
322    let obj = value
323        .as_object()
324        .ok_or_else(|| ParseError::TheoryExtraction {
325            reason: format!("field '{name}' is not an object"),
326        })?;
327
328    let required = obj
329        .get("required")
330        .and_then(serde_json::Value::as_bool)
331        .unwrap_or(false);
332
333    let multiple = obj
334        .get("multiple")
335        .and_then(serde_json::Value::as_bool)
336        .unwrap_or(false);
337
338    let types: Vec<SubtypeRef> = obj
339        .get("types")
340        .and_then(|v| serde_json::from_value(v.clone()).ok())
341        .unwrap_or_default();
342
343    Ok(FieldSpec {
344        name: name.to_owned(),
345        required,
346        multiple,
347        types,
348    })
349}
350
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// A three-rule grammar with one named field and one unnamed-children
    /// spec yields the expected sorts, operations and kinds.
    #[test]
    fn extract_minimal_grammar() {
        let grammar = br#"[
            {
                "type": "program",
                "named": true,
                "fields": {},
                "children": {
                    "multiple": true,
                    "required": false,
                    "types": [{"type": "statement", "named": true}]
                }
            },
            {
                "type": "statement",
                "named": true,
                "fields": {
                    "body": {
                        "multiple": false,
                        "required": true,
                        "types": [{"type": "expression", "named": true}]
                    }
                }
            },
            {
                "type": "expression",
                "named": true,
                "fields": {}
            },
            {
                "type": ";",
                "named": false
            }
        ]"#;

        let extracted = extract_theory_from_node_types("ThTest", grammar).unwrap();

        // Vertex and Edge (base) plus program, statement, expression.
        assert_eq!(extracted.theory.sorts.len(), 5);

        // One operation for `body`, one for `child_of`.
        assert_eq!(extracted.theory.ops.len(), 2);

        // Exactly the three named rules are vertex kinds.
        assert_eq!(extracted.vertex_kinds.len(), 3);
        for kind in ["program", "statement", "expression"] {
            assert!(extracted.vertex_kinds.iter().any(|k| k == kind));
        }

        // `body` and `child_of` are the edge kinds.
        assert_eq!(extracted.edge_kinds.len(), 2);

        // The unnamed children of `program` allow multiples, hence ordered.
        assert!(extracted.ordered_fields.contains("children"));
    }

    /// An entry with a `subtypes` array is recorded as a supertype and its
    /// subtypes' fields still become edge kinds.
    #[test]
    fn extract_supertype() {
        let grammar = br#"[
            {
                "type": "_expression",
                "named": true,
                "subtypes": [
                    {"type": "binary_expression", "named": true},
                    {"type": "call_expression", "named": true}
                ]
            },
            {
                "type": "binary_expression",
                "named": true,
                "fields": {
                    "left": {
                        "multiple": false,
                        "required": true,
                        "types": [{"type": "_expression", "named": true}]
                    },
                    "right": {
                        "multiple": false,
                        "required": true,
                        "types": [{"type": "_expression", "named": true}]
                    }
                }
            },
            {
                "type": "call_expression",
                "named": true,
                "fields": {
                    "function": {
                        "multiple": false,
                        "required": true,
                        "types": [{"type": "_expression", "named": true}]
                    },
                    "arguments": {
                        "multiple": true,
                        "required": true,
                        "types": [{"type": "_expression", "named": true}]
                    }
                }
            }
        ]"#;

        let extracted = extract_theory_from_node_types("ThExprTest", grammar).unwrap();

        // `_expression` is registered as a supertype.
        assert!(extracted.supertypes.contains("_expression"));

        // Single subtype_map entry: _expression → two concrete subtypes.
        assert_eq!(extracted.subtype_map.len(), 1);
        let (supertype, members) = extracted.subtype_map.first().unwrap();
        assert_eq!(supertype, "_expression");
        assert_eq!(members.len(), 2);

        // `arguments` allows multiples, hence ordered.
        assert!(extracted.ordered_fields.contains("arguments"));

        // left, right, function, arguments.
        assert_eq!(extracted.edge_kinds.len(), 4);
    }

    /// Entries with `named: false` contribute neither sorts nor vertex kinds.
    #[test]
    fn anonymous_tokens_skipped() {
        let grammar = br#"[
            {"type": "identifier", "named": true, "fields": {}},
            {"type": "(", "named": false},
            {"type": ")", "named": false}
        ]"#;

        let extracted = extract_theory_from_node_types("ThAnon", grammar).unwrap();

        // Vertex, Edge and `identifier` only.
        assert_eq!(extracted.theory.sorts.len(), 3);
        assert_eq!(extracted.vertex_kinds.len(), 1);
    }
}