codeprism_dev_tools/
ast_visualizer.rs

//! AST visualization utilities for parser development.
//!
//! This module provides tools for visualizing Abstract Syntax Trees (ASTs) in various formats
//! to help with parser development and debugging.
5
6use anyhow::Result;
7use colored::Colorize;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::fmt;
11use tree_sitter::{Node, Tree};
12
13/// AST visualizer for pretty-printing syntax trees
14#[derive(Debug, Clone)]
15pub struct AstVisualizer {
16    config: VisualizationConfig,
17}
18
19/// Configuration for AST visualization
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct VisualizationConfig {
22    /// Maximum depth to visualize (prevents infinite recursion)
23    pub max_depth: usize,
24    /// Whether to show node positions (line, column)
25    pub show_positions: bool,
26    /// Whether to show node byte ranges
27    pub show_byte_ranges: bool,
28    /// Whether to use colors in output
29    pub use_colors: bool,
30    /// Whether to show node text content
31    pub show_text_content: bool,
32    /// Maximum length of text content to display
33    pub max_text_length: usize,
34    /// Whether to show only named nodes
35    pub named_nodes_only: bool,
36    /// Custom node type colors
37    pub node_color_names: HashMap<String, String>,
38    /// Indentation string for tree structure
39    pub indent_string: String,
40}
41
42impl Default for VisualizationConfig {
43    fn default() -> Self {
44        let mut node_color_names = HashMap::new();
45
46        // Set up default colors for common node types
47        node_color_names.insert("function_definition".to_string(), "blue".to_string());
48        node_color_names.insert("class_definition".to_string(), "green".to_string());
49        node_color_names.insert("function_call".to_string(), "cyan".to_string());
50        node_color_names.insert("variable".to_string(), "yellow".to_string());
51        node_color_names.insert("string".to_string(), "red".to_string());
52        node_color_names.insert("number".to_string(), "magenta".to_string());
53        node_color_names.insert("comment".to_string(), "brightblack".to_string());
54        node_color_names.insert("keyword".to_string(), "brightblue".to_string());
55        node_color_names.insert("operator".to_string(), "brightyellow".to_string());
56        node_color_names.insert("identifier".to_string(), "white".to_string());
57
58        Self {
59            max_depth: 20,
60            show_positions: true,
61            show_byte_ranges: false,
62            use_colors: true,
63            show_text_content: true,
64            max_text_length: 50,
65            named_nodes_only: false,
66            node_color_names,
67            indent_string: "  ".to_string(),
68        }
69    }
70}
71
/// Output formats understood by [`AstVisualizer::visualize_node`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum VisualizationFormat {
    /// Multi-line tree drawn with box-drawing connectors.
    Tree,
    /// One node per line, indented by depth.
    List,
    /// Nested JSON objects.
    Json,
    /// Parenthesized S-expression.
    SExpression,
    /// Single-line bracketed form.
    Compact,
}
86
87impl AstVisualizer {
88    /// Create a new AST visualizer with default configuration
89    pub fn new() -> Self {
90        Self {
91            config: VisualizationConfig::default(),
92        }
93    }
94
95    /// Create an AST visualizer with custom configuration
96    pub fn with_config(config: VisualizationConfig) -> Self {
97        Self { config }
98    }
99
100    /// Visualize a tree-sitter Tree
101    pub fn visualize_tree(&self, tree: &Tree, source: &str) -> Result<String> {
102        let root_node = tree.root_node();
103        self.visualize_node(&root_node, source, VisualizationFormat::Tree)
104    }
105
106    /// Visualize a specific node with the given format
107    pub fn visualize_node(
108        &self,
109        node: &Node,
110        source: &str,
111        format: VisualizationFormat,
112    ) -> Result<String> {
113        match format {
114            VisualizationFormat::Tree => self.visualize_tree_format(node, source),
115            VisualizationFormat::List => self.visualize_list_format(node, source),
116            VisualizationFormat::Json => self.visualize_json_format(node, source),
117            VisualizationFormat::SExpression => self.visualize_sexp_format(node, source),
118            VisualizationFormat::Compact => self.visualize_compact_format(node, source),
119        }
120    }
121
122    /// Visualize in tree format (default pretty-print)
123    fn visualize_tree_format(&self, node: &Node, source: &str) -> Result<String> {
124        let mut output = String::new();
125        self.visualize_node_recursive(node, source, 0, "", true, &mut output);
126        Ok(output)
127    }
128
129    /// Recursive helper for tree visualization
130    fn visualize_node_recursive(
131        &self,
132        node: &Node,
133        source: &str,
134        depth: usize,
135        prefix: &str,
136        is_last: bool,
137        output: &mut String,
138    ) {
139        if depth > self.config.max_depth {
140            output.push_str(&format!("{}{}...\n", prefix, "─── ".dimmed()));
141            return;
142        }
143
144        // Skip unnamed nodes if configured
145        if self.config.named_nodes_only && !node.is_named() {
146            return;
147        }
148
149        // Create the current line prefix
150        let connector = if is_last { "└── " } else { "├── " };
151        let node_prefix = format!("{prefix}{connector}");
152
153        // Format the node type
154        let node_type = self.format_node_type(node.kind());
155
156        // Add position information if enabled
157        let position_info = if self.config.show_positions {
158            let start = node.start_position();
159            let end = node.end_position();
160            format!(
161                " @{}:{}-{}:{}",
162                start.row + 1,
163                start.column + 1,
164                end.row + 1,
165                end.column + 1
166            )
167        } else {
168            String::new()
169        };
170
171        // Add byte range information if enabled
172        let byte_range_info = if self.config.show_byte_ranges {
173            format!(" [{}..{}]", node.start_byte(), node.end_byte())
174        } else {
175            String::new()
176        };
177
178        // Add text content if enabled and node is small enough
179        let text_content = if self.config.show_text_content && node.child_count() == 0 {
180            let text = node
181                .utf8_text(source.as_bytes())
182                .unwrap_or("<invalid utf8>");
183            if text.len() <= self.config.max_text_length {
184                format!(" \"{}\"", text.replace('\n', "\\n").replace('\r', "\\r"))
185            } else {
186                format!(
187                    " \"{}...\"",
188                    &text[..self.config.max_text_length.min(text.len())]
189                )
190            }
191        } else {
192            String::new()
193        };
194
195        // Write the formatted node
196        output.push_str(&format!(
197            "{node_prefix}{node_type}{position_info}{byte_range_info}{text_content}\n"
198        ));
199
200        // Process children
201        let child_count = node.child_count();
202        for i in 0..child_count {
203            if let Some(child) = node.child(i) {
204                let child_prefix = format!("{}{}", prefix, if is_last { "    " } else { "│   " });
205                let is_last_child = i == child_count - 1;
206                self.visualize_node_recursive(
207                    &child,
208                    source,
209                    depth + 1,
210                    &child_prefix,
211                    is_last_child,
212                    output,
213                );
214            }
215        }
216    }
217
218    /// Format node type with colors if enabled
219    fn format_node_type(&self, node_type: &str) -> String {
220        if !self.config.use_colors {
221            return node_type.to_string();
222        }
223
224        if let Some(color_name) = self.config.node_color_names.get(node_type) {
225            match color_name.as_str() {
226                "blue" => node_type.blue().to_string(),
227                "green" => node_type.green().to_string(),
228                "cyan" => node_type.cyan().to_string(),
229                "red" => node_type.red().to_string(),
230                "yellow" => node_type.yellow().to_string(),
231                "magenta" => node_type.magenta().to_string(),
232                _ => node_type.normal().to_string(),
233            }
234        } else {
235            // Default color for unknown node types
236            node_type.normal().to_string()
237        }
238    }
239
240    /// Visualize in list format
241    fn visualize_list_format(&self, node: &Node, _source: &str) -> Result<String> {
242        let mut output = String::new();
243        let mut cursor = node.walk();
244        let mut depth = 0;
245
246        loop {
247            let current_node = cursor.node();
248
249            // Skip unnamed nodes if configured
250            if !self.config.named_nodes_only || current_node.is_named() {
251                let indent = self.config.indent_string.repeat(depth);
252                let node_type = self.format_node_type(current_node.kind());
253
254                let position_info = if self.config.show_positions {
255                    let start = current_node.start_position();
256                    format!(" @{}:{}", start.row + 1, start.column + 1)
257                } else {
258                    String::new()
259                };
260
261                output.push_str(&format!("{indent}{node_type}{position_info}\n"));
262            }
263
264            if cursor.goto_first_child() {
265                depth += 1;
266            } else if cursor.goto_next_sibling() {
267                // Stay at same depth
268            } else {
269                // Go back up until we find a sibling or reach the root
270                loop {
271                    if !cursor.goto_parent() {
272                        return Ok(output); // Reached root
273                    }
274                    depth -= 1;
275                    if cursor.goto_next_sibling() {
276                        break;
277                    }
278                }
279            }
280
281            if depth > self.config.max_depth {
282                break;
283            }
284        }
285
286        Ok(output)
287    }
288
289    /// Visualize in JSON format
290    fn visualize_json_format(&self, node: &Node, source: &str) -> Result<String> {
291        let json_node = self.node_to_json(node, source, 0)?;
292        Ok(serde_json::to_string_pretty(&json_node)?)
293    }
294
295    /// Convert a node to JSON representation
296    fn node_to_json(&self, node: &Node, source: &str, depth: usize) -> Result<serde_json::Value> {
297        if depth > self.config.max_depth {
298            return Ok(serde_json::json!({
299                "type": "...",
300                "truncated": true
301            }));
302        }
303
304        let mut json_node = serde_json::Map::new();
305        json_node.insert(
306            "type".to_string(),
307            serde_json::Value::String(node.kind().to_string()),
308        );
309        json_node.insert(
310            "named".to_string(),
311            serde_json::Value::Bool(node.is_named()),
312        );
313
314        if self.config.show_positions {
315            let start = node.start_position();
316            let end = node.end_position();
317            json_node.insert(
318                "start".to_string(),
319                serde_json::json!({
320                    "row": start.row,
321                    "column": start.column
322                }),
323            );
324            json_node.insert(
325                "end".to_string(),
326                serde_json::json!({
327                    "row": end.row,
328                    "column": end.column
329                }),
330            );
331        }
332
333        if self.config.show_byte_ranges {
334            json_node.insert(
335                "start_byte".to_string(),
336                serde_json::Value::Number(node.start_byte().into()),
337            );
338            json_node.insert(
339                "end_byte".to_string(),
340                serde_json::Value::Number(node.end_byte().into()),
341            );
342        }
343
344        if self.config.show_text_content && node.child_count() == 0 {
345            if let Ok(text) = node.utf8_text(source.as_bytes()) {
346                let display_text = if text.len() <= self.config.max_text_length {
347                    text.to_string()
348                } else {
349                    format!(
350                        "{}...",
351                        &text[..self.config.max_text_length.min(text.len())]
352                    )
353                };
354                json_node.insert("text".to_string(), serde_json::Value::String(display_text));
355            }
356        }
357
358        let mut children = Vec::new();
359        for i in 0..node.child_count() {
360            if let Some(child) = node.child(i) {
361                if !self.config.named_nodes_only || child.is_named() {
362                    children.push(self.node_to_json(&child, source, depth + 1)?);
363                }
364            }
365        }
366
367        if !children.is_empty() {
368            json_node.insert("children".to_string(), serde_json::Value::Array(children));
369        }
370
371        Ok(serde_json::Value::Object(json_node))
372    }
373
374    /// Visualize in S-expression format
375    fn visualize_sexp_format(&self, node: &Node, source: &str) -> Result<String> {
376        let mut output = String::new();
377        self.node_to_sexp(node, source, 0, &mut output)?;
378        Ok(output)
379    }
380
381    /// Convert node to S-expression format
382    fn node_to_sexp(
383        &self,
384        node: &Node,
385        source: &str,
386        depth: usize,
387        output: &mut String,
388    ) -> Result<()> {
389        if depth > self.config.max_depth {
390            output.push_str("...");
391            return Ok(());
392        }
393
394        if self.config.named_nodes_only && !node.is_named() {
395            return Ok(());
396        }
397
398        output.push('(');
399        output.push_str(node.kind());
400
401        // Add text for leaf nodes
402        if node.child_count() == 0 && self.config.show_text_content {
403            if let Ok(text) = node.utf8_text(source.as_bytes()) {
404                let display_text = if text.len() <= self.config.max_text_length {
405                    text.to_string()
406                } else {
407                    format!(
408                        "{}...",
409                        &text[..self.config.max_text_length.min(text.len())]
410                    )
411                };
412                output.push_str(&format!(" \"{}\"", display_text.replace('"', "\\\"")));
413            }
414        }
415
416        // Process children
417        for i in 0..node.child_count() {
418            if let Some(child) = node.child(i) {
419                if !self.config.named_nodes_only || child.is_named() {
420                    output.push(' ');
421                    self.node_to_sexp(&child, source, depth + 1, output)?;
422                }
423            }
424        }
425
426        output.push(')');
427        Ok(())
428    }
429
430    /// Visualize in compact format
431    fn visualize_compact_format(&self, node: &Node, source: &str) -> Result<String> {
432        let mut output = String::new();
433        self.node_to_compact(node, source, 0, &mut output)?;
434        Ok(output.trim().to_string())
435    }
436
437    /// Convert node to compact format
438    fn node_to_compact(
439        &self,
440        node: &Node,
441        source: &str,
442        depth: usize,
443        output: &mut String,
444    ) -> Result<()> {
445        if depth > self.config.max_depth {
446            output.push_str("...");
447            return Ok(());
448        }
449
450        if self.config.named_nodes_only && !node.is_named() {
451            return Ok(());
452        }
453
454        output.push_str(node.kind());
455
456        if node.child_count() == 0 && self.config.show_text_content {
457            if let Ok(text) = node.utf8_text(source.as_bytes()) {
458                let display_text = if text.len() <= self.config.max_text_length {
459                    text.to_string()
460                } else {
461                    format!(
462                        "{}...",
463                        &text[..self.config.max_text_length.min(text.len())]
464                    )
465                };
466                output.push_str(&format!(":{}", display_text.replace(' ', "_")));
467            }
468        }
469
470        if node.child_count() > 0 {
471            output.push('[');
472            for i in 0..node.child_count() {
473                if let Some(child) = node.child(i) {
474                    if !self.config.named_nodes_only || child.is_named() {
475                        if i > 0 {
476                            output.push(',');
477                        }
478                        self.node_to_compact(&child, source, depth + 1, output)?;
479                    }
480                }
481            }
482            output.push(']');
483        }
484
485        Ok(())
486    }
487
488    /// Get statistics about the AST
489    pub fn get_ast_statistics(&self, node: &Node) -> AstStatistics {
490        let mut stats = AstStatistics::default();
491        self.collect_statistics(node, &mut stats, 0);
492        stats
493    }
494
495    /// Recursively collect AST statistics
496    #[allow(clippy::only_used_in_recursion)] // Method is used recursively by design
497    fn collect_statistics(&self, node: &Node, stats: &mut AstStatistics, depth: usize) {
498        stats.total_nodes += 1;
499        stats.max_depth = stats.max_depth.max(depth);
500
501        if node.is_named() {
502            stats.named_nodes += 1;
503        } else {
504            stats.unnamed_nodes += 1;
505        }
506
507        *stats
508            .node_type_counts
509            .entry(node.kind().to_string())
510            .or_insert(0) += 1;
511
512        if node.child_count() == 0 {
513            stats.leaf_nodes += 1;
514        }
515
516        for i in 0..node.child_count() {
517            if let Some(child) = node.child(i) {
518                self.collect_statistics(&child, stats, depth + 1);
519            }
520        }
521    }
522
523    /// Compare two ASTs and highlight differences
524    pub fn compare_asts(&self, old_node: &Node, new_node: &Node, _source: &str) -> Result<String> {
525        let mut output = String::new();
526        output.push_str("=== AST Comparison ===\n\n");
527
528        let old_stats = self.get_ast_statistics(old_node);
529        let new_stats = self.get_ast_statistics(new_node);
530
531        output.push_str("## Statistics Comparison\n");
532        output.push_str(&format!(
533            "Total nodes: {} -> {} ({}{})\n",
534            old_stats.total_nodes,
535            new_stats.total_nodes,
536            if new_stats.total_nodes >= old_stats.total_nodes {
537                "+"
538            } else {
539                ""
540            },
541            new_stats.total_nodes as i32 - old_stats.total_nodes as i32
542        ));
543
544        output.push_str(&format!(
545            "Max depth: {} -> {} ({}{})\n",
546            old_stats.max_depth,
547            new_stats.max_depth,
548            if new_stats.max_depth >= old_stats.max_depth {
549                "+"
550            } else {
551                ""
552            },
553            new_stats.max_depth as i32 - old_stats.max_depth as i32
554        ));
555
556        output.push_str("\n## Structural Differences\n");
557        if old_node.kind() != new_node.kind() {
558            output.push_str(&format!(
559                "Root node type changed: {} -> {}\n",
560                old_node.kind(),
561                new_node.kind()
562            ));
563        }
564
565        Ok(output)
566    }
567}
568
569impl Default for AstVisualizer {
570    fn default() -> Self {
571        Self::new()
572    }
573}
574
/// Statistics about an AST, as gathered by [`AstVisualizer::get_ast_statistics`].
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct AstStatistics {
    /// Number of nodes visited, named and unnamed.
    pub total_nodes: usize,
    /// Number of named nodes.
    pub named_nodes: usize,
    /// Number of unnamed (anonymous) nodes.
    pub unnamed_nodes: usize,
    /// Number of nodes without children.
    pub leaf_nodes: usize,
    /// Depth of the deepest node, with the root at depth 0.
    pub max_depth: usize,
    /// Number of occurrences of each node type.
    pub node_type_counts: HashMap<String, usize>,
}

impl fmt::Display for AstStatistics {
    /// Write a multi-line summary followed by the most frequent node types.
    ///
    /// Node types are ordered by descending count; equal counts are ordered by name so
    /// the output does not depend on `HashMap` iteration order.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        /// Number of node types listed before the remainder is summarized.
        const TOP_TYPES: usize = 10;

        writeln!(f, "AST Statistics:")?;
        writeln!(f, "  Total nodes: {}", self.total_nodes)?;
        writeln!(f, "  Named nodes: {}", self.named_nodes)?;
        writeln!(f, "  Unnamed nodes: {}", self.unnamed_nodes)?;
        writeln!(f, "  Leaf nodes: {}", self.leaf_nodes)?;
        writeln!(f, "  Maximum depth: {}", self.max_depth)?;
        writeln!(f, "  Node types:")?;

        let mut types: Vec<(&String, &usize)> = self.node_type_counts.iter().collect();
        // Count descending, then name ascending as a deterministic tie-break.
        types.sort_unstable_by(|a, b| b.1.cmp(a.1).then_with(|| a.0.cmp(b.0)));

        for (node_type, count) in types.iter().take(TOP_TYPES) {
            writeln!(f, "    {node_type}: {count}")?;
        }

        if types.len() > TOP_TYPES {
            writeln!(f, "    ... and {} more", types.len() - TOP_TYPES)?;
        }

        Ok(())
    }
}
611
#[cfg(test)]
mod tests {
    //! Unit tests for the parts of the visualizer that do not need a parsed tree.
    //! No language grammar is available here, so rendering of real nodes is not covered.

    use super::*;

    #[test]
    fn test_ast_visualizer_creation() {
        let visualizer = AstVisualizer::new();
        assert_eq!(visualizer.config.max_depth, 20);
        assert!(visualizer.config.show_positions);
    }

    #[test]
    fn test_default_matches_new() {
        let from_default = AstVisualizer::default();
        let from_new = AstVisualizer::new();
        assert_eq!(from_default.config.max_depth, from_new.config.max_depth);
        assert_eq!(
            from_default.config.indent_string,
            from_new.config.indent_string
        );
        assert_eq!(
            from_default.config.node_color_names,
            from_new.config.node_color_names
        );
    }

    #[test]
    fn test_default_config_colors() {
        let config = VisualizationConfig::default();
        assert_eq!(
            config
                .node_color_names
                .get("function_definition")
                .map(String::as_str),
            Some("blue")
        );
        assert_eq!(
            config.node_color_names.get("comment").map(String::as_str),
            Some("brightblack")
        );
        assert_eq!(config.node_color_names.len(), 10);
    }

    #[test]
    fn test_custom_config() {
        let config = VisualizationConfig {
            max_depth: 10,
            show_positions: false,
            ..Default::default()
        };

        let visualizer = AstVisualizer::with_config(config);
        assert_eq!(visualizer.config.max_depth, 10);
        assert!(!visualizer.config.show_positions);
    }

    #[test]
    fn test_format_node_type_with_colors() {
        let visualizer = AstVisualizer::new();

        // Whether escape codes are emitted depends on the environment, so only require
        // that the node type itself survives formatting.
        let formatted = visualizer.format_node_type("function_definition");
        assert!(formatted.contains("function_definition"));

        let unmapped = visualizer.format_node_type("some_unmapped_kind");
        assert!(unmapped.contains("some_unmapped_kind"));
    }

    #[test]
    fn test_format_node_type_without_colors() {
        let config = VisualizationConfig {
            use_colors: false,
            ..Default::default()
        };
        let visualizer = AstVisualizer::with_config(config);

        let formatted = visualizer.format_node_type("function_definition");
        assert_eq!(formatted, "function_definition");
    }

    #[test]
    fn test_ast_statistics_display() {
        let mut stats = AstStatistics {
            total_nodes: 100,
            named_nodes: 80,
            unnamed_nodes: 20,
            leaf_nodes: 40,
            max_depth: 5,
            ..Default::default()
        };
        stats.node_type_counts.insert("function".to_string(), 10);
        stats.node_type_counts.insert("identifier".to_string(), 30);

        let output = format!("{stats}");
        assert!(output.contains("Total nodes: 100"));
        assert!(output.contains("Named nodes: 80"));
        assert!(output.contains("Unnamed nodes: 20"));
        assert!(output.contains("Leaf nodes: 40"));
        assert!(output.contains("Maximum depth: 5"));

        // Node types are listed by descending count.
        let identifier_at = output.find("identifier: 30").expect("identifier listed");
        let function_at = output.find("function: 10").expect("function listed");
        assert!(identifier_at < function_at);
    }

    #[test]
    fn test_ast_statistics_display_summarizes_remainder() {
        let mut stats = AstStatistics::default();
        for i in 0..12 {
            stats.node_type_counts.insert(format!("kind_{i}"), 1);
        }

        let output = stats.to_string();
        assert!(output.contains("... and 2 more"));
    }
}