codeprism_dev_tools/
ast_visualizer.rs

1//! AST Visualization utilities for parser development
2//!
3//! This module provides tools for visualizing Abstract Syntax Trees (ASTs) in various formats
4//! to help with parser development and debugging.
5
6use anyhow::Result;
7use colored::Colorize;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::fmt;
11use tree_sitter::{Node, Tree};
12
13/// AST visualizer for pretty-printing syntax trees
14#[derive(Debug, Clone)]
15pub struct AstVisualizer {
16    config: VisualizationConfig,
17}
18
19/// Configuration for AST visualization
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct VisualizationConfig {
22    /// Maximum depth to visualize (prevents infinite recursion)
23    pub max_depth: usize,
24    /// Whether to show node positions (line, column)
25    pub show_positions: bool,
26    /// Whether to show node byte ranges
27    pub show_byte_ranges: bool,
28    /// Whether to use colors in output
29    pub use_colors: bool,
30    /// Whether to show node text content
31    pub show_text_content: bool,
32    /// Maximum length of text content to display
33    pub max_text_length: usize,
34    /// Whether to show only named nodes
35    pub named_nodes_only: bool,
36    /// Custom node type colors
37    pub node_color_names: HashMap<String, String>,
38    /// Indentation string for tree structure
39    pub indent_string: String,
40}
41
42impl Default for VisualizationConfig {
43    fn default() -> Self {
44        let mut node_color_names = HashMap::new();
45
46        // Set up default colors for common node types
47        node_color_names.insert("function_definition".to_string(), "blue".to_string());
48        node_color_names.insert("class_definition".to_string(), "green".to_string());
49        node_color_names.insert("function_call".to_string(), "cyan".to_string());
50        node_color_names.insert("variable".to_string(), "yellow".to_string());
51        node_color_names.insert("string".to_string(), "red".to_string());
52        node_color_names.insert("number".to_string(), "magenta".to_string());
53        node_color_names.insert("comment".to_string(), "brightblack".to_string());
54        node_color_names.insert("keyword".to_string(), "brightblue".to_string());
55        node_color_names.insert("operator".to_string(), "brightyellow".to_string());
56        node_color_names.insert("identifier".to_string(), "white".to_string());
57
58        Self {
59            max_depth: 20,
60            show_positions: true,
61            show_byte_ranges: false,
62            use_colors: true,
63            show_text_content: true,
64            max_text_length: 50,
65            named_nodes_only: false,
66            node_color_names,
67            indent_string: "  ".to_string(),
68        }
69    }
70}
71
/// Output layouts supported by [`AstVisualizer::visualize_node`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VisualizationFormat {
    /// Multi-line tree drawn with box-drawing connectors.
    Tree,
    /// One node per line, indented by depth.
    List,
    /// Pretty-printed JSON document.
    Json,
    /// Parenthesised S-expression.
    SExpression,
    /// Single-line bracketed form.
    Compact,
}
86
87impl AstVisualizer {
88    /// Create a new AST visualizer with default configuration
89    pub fn new() -> Self {
90        Self {
91            config: VisualizationConfig::default(),
92        }
93    }
94
95    /// Create an AST visualizer with custom configuration
96    pub fn with_config(config: VisualizationConfig) -> Self {
97        Self { config }
98    }
99
100    /// Visualize a tree-sitter Tree
101    pub fn visualize_tree(&self, tree: &Tree, source: &str) -> Result<String> {
102        let root_node = tree.root_node();
103        self.visualize_node(&root_node, source, VisualizationFormat::Tree)
104    }
105
106    /// Visualize a specific node with the given format
107    pub fn visualize_node(
108        &self,
109        node: &Node,
110        source: &str,
111        format: VisualizationFormat,
112    ) -> Result<String> {
113        match format {
114            VisualizationFormat::Tree => self.visualize_tree_format(node, source),
115            VisualizationFormat::List => self.visualize_list_format(node, source),
116            VisualizationFormat::Json => self.visualize_json_format(node, source),
117            VisualizationFormat::SExpression => self.visualize_sexp_format(node, source),
118            VisualizationFormat::Compact => self.visualize_compact_format(node, source),
119        }
120    }
121
122    /// Visualize in tree format (default pretty-print)
123    fn visualize_tree_format(&self, node: &Node, source: &str) -> Result<String> {
124        let mut output = String::new();
125        self.visualize_node_recursive(node, source, 0, "", true, &mut output);
126        Ok(output)
127    }
128
129    /// Recursive helper for tree visualization
130    fn visualize_node_recursive(
131        &self,
132        node: &Node,
133        source: &str,
134        depth: usize,
135        prefix: &str,
136        is_last: bool,
137        output: &mut String,
138    ) {
139        if depth > self.config.max_depth {
140            output.push_str(&format!("{}{}...\n", prefix, "─── ".dimmed()));
141            return;
142        }
143
144        // Skip unnamed nodes if configured
145        if self.config.named_nodes_only && !node.is_named() {
146            return;
147        }
148
149        // Create the current line prefix
150        let connector = if is_last { "└── " } else { "├── " };
151        let node_prefix = format!("{}{}", prefix, connector);
152
153        // Format the node type
154        let node_type = self.format_node_type(node.kind());
155
156        // Add position information if enabled
157        let position_info = if self.config.show_positions {
158            let start = node.start_position();
159            let end = node.end_position();
160            format!(
161                " @{}:{}-{}:{}",
162                start.row + 1,
163                start.column + 1,
164                end.row + 1,
165                end.column + 1
166            )
167        } else {
168            String::new()
169        };
170
171        // Add byte range information if enabled
172        let byte_range_info = if self.config.show_byte_ranges {
173            format!(" [{}..{}]", node.start_byte(), node.end_byte())
174        } else {
175            String::new()
176        };
177
178        // Add text content if enabled and node is small enough
179        let text_content = if self.config.show_text_content && node.child_count() == 0 {
180            let text = node
181                .utf8_text(source.as_bytes())
182                .unwrap_or("<invalid utf8>");
183            if text.len() <= self.config.max_text_length {
184                format!(" \"{}\"", text.replace('\n', "\\n").replace('\r', "\\r"))
185            } else {
186                format!(
187                    " \"{}...\"",
188                    &text[..self.config.max_text_length.min(text.len())]
189                )
190            }
191        } else {
192            String::new()
193        };
194
195        // Write the formatted node
196        output.push_str(&format!(
197            "{}{}{}{}{}\n",
198            node_prefix, node_type, position_info, byte_range_info, text_content
199        ));
200
201        // Process children
202        let child_count = node.child_count();
203        for i in 0..child_count {
204            if let Some(child) = node.child(i) {
205                let child_prefix = format!("{}{}", prefix, if is_last { "    " } else { "│   " });
206                let is_last_child = i == child_count - 1;
207                self.visualize_node_recursive(
208                    &child,
209                    source,
210                    depth + 1,
211                    &child_prefix,
212                    is_last_child,
213                    output,
214                );
215            }
216        }
217    }
218
219    /// Format node type with colors if enabled
220    fn format_node_type(&self, node_type: &str) -> String {
221        if !self.config.use_colors {
222            return node_type.to_string();
223        }
224
225        if let Some(color_name) = self.config.node_color_names.get(node_type) {
226            match color_name.as_str() {
227                "blue" => node_type.blue().to_string(),
228                "green" => node_type.green().to_string(),
229                "cyan" => node_type.cyan().to_string(),
230                "red" => node_type.red().to_string(),
231                "yellow" => node_type.yellow().to_string(),
232                "magenta" => node_type.magenta().to_string(),
233                _ => node_type.normal().to_string(),
234            }
235        } else {
236            // Default color for unknown node types
237            node_type.normal().to_string()
238        }
239    }
240
241    /// Visualize in list format
242    fn visualize_list_format(&self, node: &Node, _source: &str) -> Result<String> {
243        let mut output = String::new();
244        let mut cursor = node.walk();
245        let mut depth = 0;
246
247        loop {
248            let current_node = cursor.node();
249
250            // Skip unnamed nodes if configured
251            if !self.config.named_nodes_only || current_node.is_named() {
252                let indent = self.config.indent_string.repeat(depth);
253                let node_type = self.format_node_type(current_node.kind());
254
255                let position_info = if self.config.show_positions {
256                    let start = current_node.start_position();
257                    format!(" @{}:{}", start.row + 1, start.column + 1)
258                } else {
259                    String::new()
260                };
261
262                output.push_str(&format!("{}{}{}\n", indent, node_type, position_info));
263            }
264
265            if cursor.goto_first_child() {
266                depth += 1;
267            } else if cursor.goto_next_sibling() {
268                // Stay at same depth
269            } else {
270                // Go back up until we find a sibling or reach the root
271                loop {
272                    if !cursor.goto_parent() {
273                        return Ok(output); // Reached root
274                    }
275                    depth -= 1;
276                    if cursor.goto_next_sibling() {
277                        break;
278                    }
279                }
280            }
281
282            if depth > self.config.max_depth {
283                break;
284            }
285        }
286
287        Ok(output)
288    }
289
290    /// Visualize in JSON format
291    fn visualize_json_format(&self, node: &Node, source: &str) -> Result<String> {
292        let json_node = self.node_to_json(node, source, 0)?;
293        Ok(serde_json::to_string_pretty(&json_node)?)
294    }
295
296    /// Convert a node to JSON representation
297    fn node_to_json(&self, node: &Node, source: &str, depth: usize) -> Result<serde_json::Value> {
298        if depth > self.config.max_depth {
299            return Ok(serde_json::json!({
300                "type": "...",
301                "truncated": true
302            }));
303        }
304
305        let mut json_node = serde_json::Map::new();
306        json_node.insert(
307            "type".to_string(),
308            serde_json::Value::String(node.kind().to_string()),
309        );
310        json_node.insert(
311            "named".to_string(),
312            serde_json::Value::Bool(node.is_named()),
313        );
314
315        if self.config.show_positions {
316            let start = node.start_position();
317            let end = node.end_position();
318            json_node.insert(
319                "start".to_string(),
320                serde_json::json!({
321                    "row": start.row,
322                    "column": start.column
323                }),
324            );
325            json_node.insert(
326                "end".to_string(),
327                serde_json::json!({
328                    "row": end.row,
329                    "column": end.column
330                }),
331            );
332        }
333
334        if self.config.show_byte_ranges {
335            json_node.insert(
336                "start_byte".to_string(),
337                serde_json::Value::Number(node.start_byte().into()),
338            );
339            json_node.insert(
340                "end_byte".to_string(),
341                serde_json::Value::Number(node.end_byte().into()),
342            );
343        }
344
345        if self.config.show_text_content && node.child_count() == 0 {
346            if let Ok(text) = node.utf8_text(source.as_bytes()) {
347                let display_text = if text.len() <= self.config.max_text_length {
348                    text.to_string()
349                } else {
350                    format!(
351                        "{}...",
352                        &text[..self.config.max_text_length.min(text.len())]
353                    )
354                };
355                json_node.insert("text".to_string(), serde_json::Value::String(display_text));
356            }
357        }
358
359        let mut children = Vec::new();
360        for i in 0..node.child_count() {
361            if let Some(child) = node.child(i) {
362                if !self.config.named_nodes_only || child.is_named() {
363                    children.push(self.node_to_json(&child, source, depth + 1)?);
364                }
365            }
366        }
367
368        if !children.is_empty() {
369            json_node.insert("children".to_string(), serde_json::Value::Array(children));
370        }
371
372        Ok(serde_json::Value::Object(json_node))
373    }
374
375    /// Visualize in S-expression format
376    fn visualize_sexp_format(&self, node: &Node, source: &str) -> Result<String> {
377        let mut output = String::new();
378        self.node_to_sexp(node, source, 0, &mut output)?;
379        Ok(output)
380    }
381
382    /// Convert node to S-expression format
383    fn node_to_sexp(
384        &self,
385        node: &Node,
386        source: &str,
387        depth: usize,
388        output: &mut String,
389    ) -> Result<()> {
390        if depth > self.config.max_depth {
391            output.push_str("...");
392            return Ok(());
393        }
394
395        if self.config.named_nodes_only && !node.is_named() {
396            return Ok(());
397        }
398
399        output.push('(');
400        output.push_str(node.kind());
401
402        // Add text for leaf nodes
403        if node.child_count() == 0 && self.config.show_text_content {
404            if let Ok(text) = node.utf8_text(source.as_bytes()) {
405                let display_text = if text.len() <= self.config.max_text_length {
406                    text.to_string()
407                } else {
408                    format!(
409                        "{}...",
410                        &text[..self.config.max_text_length.min(text.len())]
411                    )
412                };
413                output.push_str(&format!(" \"{}\"", display_text.replace('"', "\\\"")));
414            }
415        }
416
417        // Process children
418        for i in 0..node.child_count() {
419            if let Some(child) = node.child(i) {
420                if !self.config.named_nodes_only || child.is_named() {
421                    output.push(' ');
422                    self.node_to_sexp(&child, source, depth + 1, output)?;
423                }
424            }
425        }
426
427        output.push(')');
428        Ok(())
429    }
430
431    /// Visualize in compact format
432    fn visualize_compact_format(&self, node: &Node, source: &str) -> Result<String> {
433        let mut output = String::new();
434        self.node_to_compact(node, source, 0, &mut output)?;
435        Ok(output.trim().to_string())
436    }
437
438    /// Convert node to compact format
439    fn node_to_compact(
440        &self,
441        node: &Node,
442        source: &str,
443        depth: usize,
444        output: &mut String,
445    ) -> Result<()> {
446        if depth > self.config.max_depth {
447            output.push_str("...");
448            return Ok(());
449        }
450
451        if self.config.named_nodes_only && !node.is_named() {
452            return Ok(());
453        }
454
455        output.push_str(node.kind());
456
457        if node.child_count() == 0 && self.config.show_text_content {
458            if let Ok(text) = node.utf8_text(source.as_bytes()) {
459                let display_text = if text.len() <= self.config.max_text_length {
460                    text.to_string()
461                } else {
462                    format!(
463                        "{}...",
464                        &text[..self.config.max_text_length.min(text.len())]
465                    )
466                };
467                output.push_str(&format!(":{}", display_text.replace(' ', "_")));
468            }
469        }
470
471        if node.child_count() > 0 {
472            output.push('[');
473            for i in 0..node.child_count() {
474                if let Some(child) = node.child(i) {
475                    if !self.config.named_nodes_only || child.is_named() {
476                        if i > 0 {
477                            output.push(',');
478                        }
479                        self.node_to_compact(&child, source, depth + 1, output)?;
480                    }
481                }
482            }
483            output.push(']');
484        }
485
486        Ok(())
487    }
488
489    /// Get statistics about the AST
490    pub fn get_ast_statistics(&self, node: &Node) -> AstStatistics {
491        let mut stats = AstStatistics::default();
492        self.collect_statistics(node, &mut stats, 0);
493        stats
494    }
495
496    /// Recursively collect AST statistics
497    #[allow(clippy::only_used_in_recursion)] // Method is used recursively by design
498    fn collect_statistics(&self, node: &Node, stats: &mut AstStatistics, depth: usize) {
499        stats.total_nodes += 1;
500        stats.max_depth = stats.max_depth.max(depth);
501
502        if node.is_named() {
503            stats.named_nodes += 1;
504        } else {
505            stats.unnamed_nodes += 1;
506        }
507
508        *stats
509            .node_type_counts
510            .entry(node.kind().to_string())
511            .or_insert(0) += 1;
512
513        if node.child_count() == 0 {
514            stats.leaf_nodes += 1;
515        }
516
517        for i in 0..node.child_count() {
518            if let Some(child) = node.child(i) {
519                self.collect_statistics(&child, stats, depth + 1);
520            }
521        }
522    }
523
524    /// Compare two ASTs and highlight differences
525    pub fn compare_asts(&self, old_node: &Node, new_node: &Node, _source: &str) -> Result<String> {
526        let mut output = String::new();
527        output.push_str("=== AST Comparison ===\n\n");
528
529        let old_stats = self.get_ast_statistics(old_node);
530        let new_stats = self.get_ast_statistics(new_node);
531
532        output.push_str("## Statistics Comparison\n");
533        output.push_str(&format!(
534            "Total nodes: {} -> {} ({}{})\n",
535            old_stats.total_nodes,
536            new_stats.total_nodes,
537            if new_stats.total_nodes >= old_stats.total_nodes {
538                "+"
539            } else {
540                ""
541            },
542            new_stats.total_nodes as i32 - old_stats.total_nodes as i32
543        ));
544
545        output.push_str(&format!(
546            "Max depth: {} -> {} ({}{})\n",
547            old_stats.max_depth,
548            new_stats.max_depth,
549            if new_stats.max_depth >= old_stats.max_depth {
550                "+"
551            } else {
552                ""
553            },
554            new_stats.max_depth as i32 - old_stats.max_depth as i32
555        ));
556
557        output.push_str("\n## Structural Differences\n");
558        if old_node.kind() != new_node.kind() {
559            output.push_str(&format!(
560                "Root node type changed: {} -> {}\n",
561                old_node.kind(),
562                new_node.kind()
563            ));
564        }
565
566        Ok(output)
567    }
568}
569
570impl Default for AstVisualizer {
571    fn default() -> Self {
572        Self::new()
573    }
574}
575
/// Node counts gathered from an AST.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct AstStatistics {
    /// Number of nodes visited.
    pub total_nodes: usize,
    /// Number of visited nodes that are named.
    pub named_nodes: usize,
    /// Number of visited nodes that are not named.
    pub unnamed_nodes: usize,
    /// Number of visited nodes without children.
    pub leaf_nodes: usize,
    /// Greatest depth reached; the starting node is at depth 0.
    pub max_depth: usize,
    /// Number of visited nodes per node kind.
    pub node_type_counts: HashMap<String, usize>,
}

impl fmt::Display for AstStatistics {
    /// Writes a multi-line summary followed by the ten most frequent node types.
    ///
    /// Types are ordered by count, highest first; equal counts are ordered by type name so
    /// that the output does not depend on hash-map iteration order.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // How many node types are listed before the remainder is summarized.
        const TOP_TYPES: usize = 10;

        writeln!(f, "AST Statistics:")?;
        writeln!(f, "  Total nodes: {}", self.total_nodes)?;
        writeln!(f, "  Named nodes: {}", self.named_nodes)?;
        writeln!(f, "  Unnamed nodes: {}", self.unnamed_nodes)?;
        writeln!(f, "  Leaf nodes: {}", self.leaf_nodes)?;
        writeln!(f, "  Maximum depth: {}", self.max_depth)?;
        writeln!(f, "  Node types:")?;

        let mut types: Vec<(&String, &usize)> = self.node_type_counts.iter().collect();
        // Keys are unique, so the (count, name) ordering is total and an unstable sort
        // still yields a deterministic result.
        types.sort_unstable_by(|a, b| b.1.cmp(a.1).then_with(|| a.0.cmp(b.0)));

        for (node_type, count) in types.iter().take(TOP_TYPES) {
            writeln!(f, "    {}: {}", node_type, count)?;
        }

        if types.len() > TOP_TYPES {
            writeln!(f, "    ... and {} more", types.len() - TOP_TYPES)?;
        }

        Ok(())
    }
}
612
#[cfg(test)]
mod tests {
    //! Unit tests for the parts of the visualizer that need no parsed tree.

    use super::*;

    #[test]
    fn test_ast_visualizer_creation() {
        let visualizer = AstVisualizer::new();
        assert_eq!(visualizer.config.max_depth, 20);
        assert!(visualizer.config.show_positions);
    }

    #[test]
    fn test_custom_config() {
        let config = VisualizationConfig {
            max_depth: 10,
            show_positions: false,
            ..Default::default()
        };

        let visualizer = AstVisualizer::with_config(config);
        assert_eq!(visualizer.config.max_depth, 10);
        assert!(!visualizer.config.show_positions);
    }

    #[test]
    fn test_format_node_type_with_colors() {
        let visualizer = AstVisualizer::new();
        let formatted = visualizer.format_node_type("function_definition");
        // Escape sequences may or may not be present depending on the environment,
        // but the node type itself must always survive formatting.
        assert!(formatted.contains("function_definition"));
    }

    #[test]
    fn test_format_node_type_unmapped_kind() {
        let visualizer = AstVisualizer::new();
        let formatted = visualizer.format_node_type("no_such_kind");
        assert!(formatted.contains("no_such_kind"));
    }

    #[test]
    fn test_format_node_type_without_colors() {
        let config = VisualizationConfig {
            use_colors: false,
            ..Default::default()
        };
        let visualizer = AstVisualizer::with_config(config);

        let formatted = visualizer.format_node_type("function_definition");
        assert_eq!(formatted, "function_definition");
    }

    #[test]
    fn test_ast_statistics_display() {
        let mut stats = AstStatistics {
            total_nodes: 100,
            named_nodes: 80,
            unnamed_nodes: 20,
            leaf_nodes: 40,
            max_depth: 5,
            ..Default::default()
        };
        stats.node_type_counts.insert("function".to_string(), 10);
        stats.node_type_counts.insert("identifier".to_string(), 30);

        let output = format!("{}", stats);
        assert!(output.contains("Total nodes: 100"));
        assert!(output.contains("Named nodes: 80"));
        assert!(output.contains("Unnamed nodes: 20"));
        assert!(output.contains("Leaf nodes: 40"));
        assert!(output.contains("Maximum depth: 5"));

        // The more frequent type is listed first.
        let identifier_at = output
            .find("identifier: 30")
            .expect("identifier count is listed");
        let function_at = output
            .find("function: 10")
            .expect("function count is listed");
        assert!(identifier_at < function_at);
        assert!(!output.contains("more"));
    }

    #[test]
    fn test_ast_statistics_display_limits_types() {
        let mut stats = AstStatistics::default();
        for i in 1..=12usize {
            stats.node_type_counts.insert(format!("kind{:02}", i), i);
        }

        let output = format!("{}", stats);
        assert!(output.contains("kind12: 12"));
        assert!(!output.contains("kind01: 1"));
        assert!(output.contains("... and 2 more"));
    }
}