Skip to main content

panproto_parse/
theory_extract.rs

//! Automated theory extraction from tree-sitter grammar metadata.
//!
//! Tree-sitter grammars are theory presentations: each grammar's `node-types.json`
//! file is structurally isomorphic to a GAT. This module extracts panproto theories
//! directly from grammar metadata, ensuring the theory is always in sync with the
//! parser.
//!
//! ## Mapping
//!
//! | `node-types.json` | panproto GAT |
//! |---|---|
//! | Named node type | Sort (vertex kind) |
//! | Field (`required: true`) | Operation (mandatory edge kind) |
//! | Field (`required: false`) | Operation (optional edge kind) |
//! | Field (`multiple: true`) | Operation (ordered edge kind) |
//! | Unnamed `children` | Operation `child_of` (generic edge kind) |
//! | Supertype with subtypes | Abstract sort with subtype inclusions |
//! | Anonymous token (`named: false`) | Skipped |

18use panproto_gat::{Operation, Sort, Theory};
19use rustc_hash::FxHashSet;
20
21use crate::error::ParseError;
22
// ─── node-types.json schema ───────────────────────────────────────────────

25/// A node type entry from tree-sitter's `node-types.json`.
26#[derive(Debug, Clone, serde::Deserialize)]
27pub struct NodeType {
28    /// The node type name (e.g. `"function_declaration"`).
29    #[serde(rename = "type")]
30    pub node_type: String,
31    /// Whether this is a named grammar rule (true) or anonymous token (false).
32    pub named: bool,
33    /// Named fields and their specifications.
34    #[serde(default)]
35    pub fields: serde_json::Map<String, serde_json::Value>,
36    /// Unnamed children specification.
37    #[serde(default)]
38    pub children: Option<ChildSpec>,
39    /// For supertype nodes, the concrete subtypes.
40    #[serde(default)]
41    pub subtypes: Option<Vec<SubtypeRef>>,
42}
43
44/// Specification for unnamed children of a node type.
45#[derive(Debug, Clone, serde::Deserialize)]
46pub struct ChildSpec {
47    /// Whether multiple children are allowed.
48    pub multiple: bool,
49    /// Whether at least one child is required.
50    pub required: bool,
51    /// The allowed child node types.
52    pub types: Vec<SubtypeRef>,
53}
54
55/// A reference to a node type (used in field types and subtype arrays).
56#[derive(Debug, Clone, serde::Deserialize)]
57pub struct SubtypeRef {
58    /// The node type name.
59    #[serde(rename = "type")]
60    pub node_type: String,
61    /// Whether this is a named node type.
62    pub named: bool,
63}
64
65/// A parsed field specification from `node-types.json`.
66#[derive(Debug, Clone)]
67pub struct FieldSpec {
68    /// The field name (e.g. `"body"`, `"condition"`, `"parameters"`).
69    pub name: String,
70    /// Whether this field is required.
71    pub required: bool,
72    /// Whether this field can contain multiple children.
73    pub multiple: bool,
74    /// The allowed child node types.
75    pub types: Vec<SubtypeRef>,
76}
77
78/// Metadata about an extracted theory, capturing information beyond the GAT itself.
79#[derive(Debug, Clone)]
80pub struct ExtractedTheoryMeta {
81    /// The GAT extracted from the grammar.
82    pub theory: Theory,
83    /// Named node types that are supertypes (abstract sorts).
84    pub supertypes: FxHashSet<String>,
85    /// Mapping from supertype name to its concrete subtypes.
86    pub subtype_map: Vec<(String, Vec<String>)>,
87    /// Fields that are optional (for `ThPartial` composition).
88    pub optional_fields: FxHashSet<String>,
89    /// Fields that are ordered (for `ThOrder` composition).
90    pub ordered_fields: FxHashSet<String>,
91    /// All named node types (vertex kinds for the protocol).
92    pub vertex_kinds: Vec<String>,
93    /// All field names (edge kinds for the protocol).
94    pub edge_kinds: Vec<String>,
95}
96
// ─── extraction from node-types.json ──────────────────────────────────────

99/// Parse tree-sitter's `node-types.json` bytes into a vector of [`NodeType`] entries.
100///
101/// # Errors
102///
103/// Returns [`ParseError::NodeTypesJson`] if JSON deserialization fails.
104pub fn parse_node_types(json: &[u8]) -> Result<Vec<NodeType>, ParseError> {
105    // `node-types.json` is an array of node-type entries, but recent
106    // tree-sitter releases append non-node metadata markers (e.g.
107    // `{"@generated": true}`, as in the Erlang grammar) that carry no
108    // `type` field. Deserialize loosely and skip any entry without a
109    // string `type` rather than failing the whole grammar.
110    let raw: Vec<serde_json::Value> =
111        serde_json::from_slice(json).map_err(|e| ParseError::NodeTypesJson { source: e })?;
112    raw.into_iter()
113        .filter(|entry| {
114            entry
115                .get("type")
116                .and_then(serde_json::Value::as_str)
117                .is_some()
118        })
119        .map(|entry| {
120            serde_json::from_value(entry).map_err(|e| ParseError::NodeTypesJson { source: e })
121        })
122        .collect()
123}
124
125/// Extract a panproto [`Theory`] from tree-sitter's `node-types.json` content.
126///
127/// The returned [`ExtractedTheoryMeta`] includes the GAT plus metadata about
128/// supertypes, optional fields, and ordered fields needed for protocol definition
129/// and colimit composition.
130///
131/// # Errors
132///
133/// Returns [`ParseError`] if JSON parsing fails or the grammar has structural
134/// issues preventing theory extraction.
135pub fn extract_theory_from_node_types(
136    theory_name: &str,
137    json: &[u8],
138) -> Result<ExtractedTheoryMeta, ParseError> {
139    let node_types = parse_node_types(json)?;
140    extract_theory_from_entries(theory_name, &node_types)
141}
142
143/// Extract a theory from already-parsed [`NodeType`] entries.
144///
145/// # Errors
146///
147/// Returns [`ParseError::TheoryExtraction`] if the grammar structure is invalid.
148pub fn extract_theory_from_entries(
149    theory_name: &str,
150    node_types: &[NodeType],
151) -> Result<ExtractedTheoryMeta, ParseError> {
152    let mut sorts: Vec<Sort> = Vec::new();
153    let mut ops: Vec<Operation> = Vec::new();
154    let mut supertypes = FxHashSet::default();
155    let mut subtype_map: Vec<(String, Vec<String>)> = Vec::new();
156    let mut optional_fields = FxHashSet::default();
157    let mut ordered_fields = FxHashSet::default();
158    let mut vertex_kinds: Vec<String> = Vec::new();
159    let mut edge_kind_set = FxHashSet::default();
160    let mut seen_sorts = FxHashSet::default();
161
162    // Always include Vertex and Edge as base sorts (shared with ThGraph via colimit).
163    sorts.push(Sort::simple("Vertex"));
164    sorts.push(Sort::simple("Edge"));
165    seen_sorts.insert("Vertex".to_owned());
166    seen_sorts.insert("Edge".to_owned());
167
168    for entry in node_types {
169        // Skip anonymous tokens (punctuation, keywords).
170        if !entry.named {
171            continue;
172        }
173
174        let sort_name = &entry.node_type;
175
176        // Supertype nodes define abstract sorts with subtype inclusions.
177        if let Some(ref subtypes) = entry.subtypes {
178            supertypes.insert(sort_name.clone());
179            let concrete: Vec<String> = subtypes
180                .iter()
181                .filter(|s| s.named)
182                .map(|s| s.node_type.clone())
183                .collect();
184            subtype_map.push((sort_name.clone(), concrete));
185
186            // Register the supertype as a sort if not already present.
187            if seen_sorts.insert(sort_name.clone()) {
188                sorts.push(Sort::simple(sort_name.as_str()));
189                vertex_kinds.push(sort_name.clone());
190            }
191            continue;
192        }
193
194        // Regular named node type: create a sort and operations for its fields.
195        if seen_sorts.insert(sort_name.clone()) {
196            sorts.push(Sort::simple(sort_name.as_str()));
197            vertex_kinds.push(sort_name.clone());
198        }
199
200        // Process fields: each field becomes an operation (edge kind).
201        for (field_name, field_value) in &entry.fields {
202            let spec = parse_field_spec(field_name, field_value)?;
203
204            // Track optional and ordered fields for later composition.
205            if !spec.required {
206                optional_fields.insert(field_name.clone());
207            }
208            if spec.multiple {
209                ordered_fields.insert(field_name.clone());
210            }
211
212            // Create an operation for this field if not already registered.
213            // The operation represents an edge kind: parent_sort --field_name--> child_sort.
214            // Since tree-sitter fields can accept multiple child types, we model
215            // the operation as mapping from the parent sort to Vertex (the abstract base).
216            if edge_kind_set.insert(field_name.clone()) {
217                ops.push(Operation::unary(
218                    field_name.as_str(),
219                    "parent",
220                    "Vertex",
221                    "Vertex",
222                ));
223            }
224        }
225
226        // Process unnamed children (if present).
227        if let Some(ref children) = entry.children {
228            if children.multiple {
229                ordered_fields.insert("children".to_owned());
230            }
231            // Unnamed children use a generic "child_of" edge.
232            if edge_kind_set.insert("child_of".to_owned()) {
233                ops.push(Operation::unary("child_of", "parent", "Vertex", "Vertex"));
234            }
235        }
236    }
237
238    let edge_kinds: Vec<String> = edge_kind_set.into_iter().collect();
239
240    let theory = Theory::new(theory_name, sorts, ops, vec![]);
241
242    Ok(ExtractedTheoryMeta {
243        theory,
244        supertypes,
245        subtype_map,
246        optional_fields,
247        ordered_fields,
248        vertex_kinds,
249        edge_kinds,
250    })
251}
252
253/// Extract a theory at runtime from a tree-sitter `Language` object.
254///
255/// This uses the Language introspection API (`node_kind_count`, `field_count`,
256/// `node_kind_is_named`, `supertypes()`) rather than parsing `node-types.json`.
257///
258/// Supertype information is available via the runtime API. However, field
259/// optionality (`required`) and multiplicity (`multiple`) are NOT exposed
260/// by the Language runtime API. For full metadata, use
261/// [`extract_theory_from_node_types`] with the `NODE_TYPES` constant
262/// embedded in each grammar crate.
263///
264/// # Errors
265///
266/// Returns [`ParseError::TheoryExtraction`] if introspection fails.
267pub fn extract_theory_from_language(
268    theory_name: &str,
269    language: &tree_sitter::Language,
270) -> Result<ExtractedTheoryMeta, ParseError> {
271    let mut sorts: Vec<Sort> = Vec::new();
272    let mut ops: Vec<Operation> = Vec::new();
273    let mut vertex_kinds: Vec<String> = Vec::new();
274    let mut edge_kind_set = FxHashSet::default();
275    let mut seen_sorts = FxHashSet::default();
276    // Base sorts.
277    sorts.push(Sort::simple("Vertex"));
278    sorts.push(Sort::simple("Edge"));
279    seen_sorts.insert("Vertex".to_owned());
280    seen_sorts.insert("Edge".to_owned());
281
282    // Enumerate all named node types as sorts.
283    let node_count = language.node_kind_count();
284    for id in 0..node_count {
285        let Ok(id_u16) = u16::try_from(id) else {
286            continue;
287        };
288        if language.node_kind_is_named(id_u16) {
289            if let Some(name) = language.node_kind_for_id(id_u16) {
290                // Skip internal hidden nodes (prefixed with _).
291                if name.starts_with('_') {
292                    continue;
293                }
294
295                if seen_sorts.insert(name.to_owned()) {
296                    sorts.push(Sort::simple(name));
297                    vertex_kinds.push(name.to_owned());
298                }
299            }
300        }
301    }
302
303    // Enumerate all field names as operations (edge kinds).
304    let field_count = language.field_count();
305    for id in 1..=field_count {
306        let Ok(id_u16) = u16::try_from(id) else {
307            continue;
308        };
309        if let Some(name) = language.field_name_for_id(id_u16) {
310            if edge_kind_set.insert(name.to_owned()) {
311                ops.push(Operation::unary(name, "parent", "Vertex", "Vertex"));
312            }
313        }
314    }
315
316    let edge_kinds: Vec<String> = edge_kind_set.into_iter().collect();
317
318    let theory = Theory::new(theory_name, sorts, ops, vec![]);
319
320    // Note: optional_fields, ordered_fields, supertypes, and subtype_map
321    // cannot be fully determined from the tree-sitter 0.24 Language runtime API.
322    // For full metadata, use extract_theory_from_node_types() with the NODE_TYPES
323    // constant from the grammar crate.
324    Ok(ExtractedTheoryMeta {
325        theory,
326        supertypes: FxHashSet::default(),
327        subtype_map: Vec::new(),
328        optional_fields: FxHashSet::default(),
329        ordered_fields: FxHashSet::default(),
330        vertex_kinds,
331        edge_kinds,
332    })
333}
334
// ─── helpers ──────────────────────────────────────────────────────────────

337/// Parse a field specification from the JSON value in node-types.json.
338fn parse_field_spec(name: &str, value: &serde_json::Value) -> Result<FieldSpec, ParseError> {
339    let obj = value
340        .as_object()
341        .ok_or_else(|| ParseError::TheoryExtraction {
342            reason: format!("field '{name}' is not an object"),
343        })?;
344
345    let required = obj
346        .get("required")
347        .and_then(serde_json::Value::as_bool)
348        .unwrap_or(false);
349
350    let multiple = obj
351        .get("multiple")
352        .and_then(serde_json::Value::as_bool)
353        .unwrap_or(false);
354
355    let types: Vec<SubtypeRef> = obj
356        .get("types")
357        .and_then(|v| serde_json::from_value(v.clone()).ok())
358        .unwrap_or_default();
359
360    Ok(FieldSpec {
361        name: name.to_owned(),
362        required,
363        multiple,
364        types,
365    })
366}
367
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn extract_minimal_grammar() {
        let json = br#"[
            {
                "type": "program",
                "named": true,
                "fields": {},
                "children": {
                    "multiple": true,
                    "required": false,
                    "types": [{"type": "statement", "named": true}]
                }
            },
            {
                "type": "statement",
                "named": true,
                "fields": {
                    "body": {
                        "multiple": false,
                        "required": true,
                        "types": [{"type": "expression", "named": true}]
                    }
                }
            },
            {
                "type": "expression",
                "named": true,
                "fields": {}
            },
            {
                "type": ";",
                "named": false
            }
        ]"#;

        let meta = extract_theory_from_node_types("ThTest", json).unwrap();

        // Should have Vertex, Edge (base) + program, statement, expression = 5 sorts.
        assert_eq!(meta.theory.sorts.len(), 5);

        // Should have body + child_of = 2 operations.
        assert_eq!(meta.theory.ops.len(), 2);

        // "program", "statement", "expression" as vertex kinds.
        assert_eq!(meta.vertex_kinds.len(), 3);
        assert!(meta.vertex_kinds.contains(&"program".to_owned()));
        assert!(meta.vertex_kinds.contains(&"statement".to_owned()));
        assert!(meta.vertex_kinds.contains(&"expression".to_owned()));

        // "body" and "child_of" as edge kinds.
        assert_eq!(meta.edge_kinds.len(), 2);

        // "children" on program is ordered (multiple=true).
        assert!(meta.ordered_fields.contains("children"));
    }

    #[test]
    fn extract_supertype() {
        let json = br#"[
            {
                "type": "_expression",
                "named": true,
                "subtypes": [
                    {"type": "binary_expression", "named": true},
                    {"type": "call_expression", "named": true}
                ]
            },
            {
                "type": "binary_expression",
                "named": true,
                "fields": {
                    "left": {
                        "multiple": false,
                        "required": true,
                        "types": [{"type": "_expression", "named": true}]
                    },
                    "right": {
                        "multiple": false,
                        "required": true,
                        "types": [{"type": "_expression", "named": true}]
                    }
                }
            },
            {
                "type": "call_expression",
                "named": true,
                "fields": {
                    "function": {
                        "multiple": false,
                        "required": true,
                        "types": [{"type": "_expression", "named": true}]
                    },
                    "arguments": {
                        "multiple": true,
                        "required": true,
                        "types": [{"type": "_expression", "named": true}]
                    }
                }
            }
        ]"#;

        let meta = extract_theory_from_node_types("ThExprTest", json).unwrap();

        // _expression is a supertype.
        assert!(meta.supertypes.contains("_expression"));

        // subtype_map: _expression → [binary_expression, call_expression]
        assert_eq!(meta.subtype_map.len(), 1);
        let (st, subs) = &meta.subtype_map[0];
        assert_eq!(st, "_expression");
        assert_eq!(subs.len(), 2);

        // "arguments" is ordered.
        assert!(meta.ordered_fields.contains("arguments"));

        // Operations: left, right, function, arguments = 4 edge kinds.
        assert_eq!(meta.edge_kinds.len(), 4);
    }

    #[test]
    fn anonymous_tokens_skipped() {
        let json = br#"[
            {"type": "identifier", "named": true, "fields": {}},
            {"type": "(", "named": false},
            {"type": ")", "named": false}
        ]"#;

        let meta = extract_theory_from_node_types("ThAnon", json).unwrap();

        // Only "identifier" + base sorts (Vertex, Edge) = 3.
        assert_eq!(meta.theory.sorts.len(), 3);
        assert_eq!(meta.vertex_kinds.len(), 1);
    }

    #[test]
    fn generated_marker_entries_skipped() {
        // Recent tree-sitter releases append metadata markers with no `type` key.
        let json = br#"[
            {"type": "identifier", "named": true, "fields": {}},
            {"@generated": true}
        ]"#;

        let entries = parse_node_types(json).unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].node_type, "identifier");
    }

    #[test]
    fn optional_fields_tracked() {
        let json = br#"[
            {
                "type": "if_statement",
                "named": true,
                "fields": {
                    "condition": {
                        "multiple": false,
                        "required": true,
                        "types": [{"type": "expression", "named": true}]
                    },
                    "alternative": {
                        "multiple": false,
                        "required": false,
                        "types": [{"type": "else_clause", "named": true}]
                    }
                }
            }
        ]"#;

        let meta = extract_theory_from_node_types("ThOptional", json).unwrap();

        assert!(meta.optional_fields.contains("alternative"));
        assert!(!meta.optional_fields.contains("condition"));
        assert!(meta.ordered_fields.is_empty());
        assert_eq!(meta.edge_kinds.len(), 2);
    }

    #[test]
    fn non_object_field_spec_is_an_error() {
        let json = br#"[
            {"type": "broken", "named": true, "fields": {"body": 42}}
        ]"#;

        assert!(matches!(
            extract_theory_from_node_types("ThBroken", json),
            Err(ParseError::TheoryExtraction { .. })
        ));
    }
}