1use anyhow::Result;
7use colored::Colorize;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::fmt;
11use tree_sitter::{Node, Tree};
12
13#[derive(Debug, Clone)]
15pub struct AstVisualizer {
16 config: VisualizationConfig,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct VisualizationConfig {
22 pub max_depth: usize,
24 pub show_positions: bool,
26 pub show_byte_ranges: bool,
28 pub use_colors: bool,
30 pub show_text_content: bool,
32 pub max_text_length: usize,
34 pub named_nodes_only: bool,
36 pub node_color_names: HashMap<String, String>,
38 pub indent_string: String,
40}
41
42impl Default for VisualizationConfig {
43 fn default() -> Self {
44 let mut node_color_names = HashMap::new();
45
46 node_color_names.insert("function_definition".to_string(), "blue".to_string());
48 node_color_names.insert("class_definition".to_string(), "green".to_string());
49 node_color_names.insert("function_call".to_string(), "cyan".to_string());
50 node_color_names.insert("variable".to_string(), "yellow".to_string());
51 node_color_names.insert("string".to_string(), "red".to_string());
52 node_color_names.insert("number".to_string(), "magenta".to_string());
53 node_color_names.insert("comment".to_string(), "brightblack".to_string());
54 node_color_names.insert("keyword".to_string(), "brightblue".to_string());
55 node_color_names.insert("operator".to_string(), "brightyellow".to_string());
56 node_color_names.insert("identifier".to_string(), "white".to_string());
57
58 Self {
59 max_depth: 20,
60 show_positions: true,
61 show_byte_ranges: false,
62 use_colors: true,
63 show_text_content: true,
64 max_text_length: 50,
65 named_nodes_only: false,
66 node_color_names,
67 indent_string: " ".to_string(),
68 }
69 }
70}
71
/// Output layout produced by [`AstVisualizer::visualize_node`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum VisualizationFormat {
    /// Box-drawing tree, one node per line.
    Tree,
    /// Indented list, one node per line.
    List,
    /// Pretty-printed JSON object per node.
    Json,
    /// Parenthesised S-expression.
    SExpression,
    /// Single-line bracketed form.
    Compact,
}
86
87impl AstVisualizer {
88 pub fn new() -> Self {
90 Self {
91 config: VisualizationConfig::default(),
92 }
93 }
94
95 pub fn with_config(config: VisualizationConfig) -> Self {
97 Self { config }
98 }
99
100 pub fn visualize_tree(&self, tree: &Tree, source: &str) -> Result<String> {
102 let root_node = tree.root_node();
103 self.visualize_node(&root_node, source, VisualizationFormat::Tree)
104 }
105
106 pub fn visualize_node(
108 &self,
109 node: &Node,
110 source: &str,
111 format: VisualizationFormat,
112 ) -> Result<String> {
113 match format {
114 VisualizationFormat::Tree => self.visualize_tree_format(node, source),
115 VisualizationFormat::List => self.visualize_list_format(node, source),
116 VisualizationFormat::Json => self.visualize_json_format(node, source),
117 VisualizationFormat::SExpression => self.visualize_sexp_format(node, source),
118 VisualizationFormat::Compact => self.visualize_compact_format(node, source),
119 }
120 }
121
122 fn visualize_tree_format(&self, node: &Node, source: &str) -> Result<String> {
124 let mut output = String::new();
125 self.visualize_node_recursive(node, source, 0, "", true, &mut output);
126 Ok(output)
127 }
128
129 fn visualize_node_recursive(
131 &self,
132 node: &Node,
133 source: &str,
134 depth: usize,
135 prefix: &str,
136 is_last: bool,
137 output: &mut String,
138 ) {
139 if depth > self.config.max_depth {
140 output.push_str(&format!("{}{}...\n", prefix, "─── ".dimmed()));
141 return;
142 }
143
144 if self.config.named_nodes_only && !node.is_named() {
146 return;
147 }
148
149 let connector = if is_last { "└── " } else { "├── " };
151 let node_prefix = format!("{}{}", prefix, connector);
152
153 let node_type = self.format_node_type(node.kind());
155
156 let position_info = if self.config.show_positions {
158 let start = node.start_position();
159 let end = node.end_position();
160 format!(
161 " @{}:{}-{}:{}",
162 start.row + 1,
163 start.column + 1,
164 end.row + 1,
165 end.column + 1
166 )
167 } else {
168 String::new()
169 };
170
171 let byte_range_info = if self.config.show_byte_ranges {
173 format!(" [{}..{}]", node.start_byte(), node.end_byte())
174 } else {
175 String::new()
176 };
177
178 let text_content = if self.config.show_text_content && node.child_count() == 0 {
180 let text = node
181 .utf8_text(source.as_bytes())
182 .unwrap_or("<invalid utf8>");
183 if text.len() <= self.config.max_text_length {
184 format!(" \"{}\"", text.replace('\n', "\\n").replace('\r', "\\r"))
185 } else {
186 format!(
187 " \"{}...\"",
188 &text[..self.config.max_text_length.min(text.len())]
189 )
190 }
191 } else {
192 String::new()
193 };
194
195 output.push_str(&format!(
197 "{}{}{}{}{}\n",
198 node_prefix, node_type, position_info, byte_range_info, text_content
199 ));
200
201 let child_count = node.child_count();
203 for i in 0..child_count {
204 if let Some(child) = node.child(i) {
205 let child_prefix = format!("{}{}", prefix, if is_last { " " } else { "│ " });
206 let is_last_child = i == child_count - 1;
207 self.visualize_node_recursive(
208 &child,
209 source,
210 depth + 1,
211 &child_prefix,
212 is_last_child,
213 output,
214 );
215 }
216 }
217 }
218
219 fn format_node_type(&self, node_type: &str) -> String {
221 if !self.config.use_colors {
222 return node_type.to_string();
223 }
224
225 if let Some(color_name) = self.config.node_color_names.get(node_type) {
226 match color_name.as_str() {
227 "blue" => node_type.blue().to_string(),
228 "green" => node_type.green().to_string(),
229 "cyan" => node_type.cyan().to_string(),
230 "red" => node_type.red().to_string(),
231 "yellow" => node_type.yellow().to_string(),
232 "magenta" => node_type.magenta().to_string(),
233 _ => node_type.normal().to_string(),
234 }
235 } else {
236 node_type.normal().to_string()
238 }
239 }
240
241 fn visualize_list_format(&self, node: &Node, _source: &str) -> Result<String> {
243 let mut output = String::new();
244 let mut cursor = node.walk();
245 let mut depth = 0;
246
247 loop {
248 let current_node = cursor.node();
249
250 if !self.config.named_nodes_only || current_node.is_named() {
252 let indent = self.config.indent_string.repeat(depth);
253 let node_type = self.format_node_type(current_node.kind());
254
255 let position_info = if self.config.show_positions {
256 let start = current_node.start_position();
257 format!(" @{}:{}", start.row + 1, start.column + 1)
258 } else {
259 String::new()
260 };
261
262 output.push_str(&format!("{}{}{}\n", indent, node_type, position_info));
263 }
264
265 if cursor.goto_first_child() {
266 depth += 1;
267 } else if cursor.goto_next_sibling() {
268 } else {
270 loop {
272 if !cursor.goto_parent() {
273 return Ok(output); }
275 depth -= 1;
276 if cursor.goto_next_sibling() {
277 break;
278 }
279 }
280 }
281
282 if depth > self.config.max_depth {
283 break;
284 }
285 }
286
287 Ok(output)
288 }
289
290 fn visualize_json_format(&self, node: &Node, source: &str) -> Result<String> {
292 let json_node = self.node_to_json(node, source, 0)?;
293 Ok(serde_json::to_string_pretty(&json_node)?)
294 }
295
296 fn node_to_json(&self, node: &Node, source: &str, depth: usize) -> Result<serde_json::Value> {
298 if depth > self.config.max_depth {
299 return Ok(serde_json::json!({
300 "type": "...",
301 "truncated": true
302 }));
303 }
304
305 let mut json_node = serde_json::Map::new();
306 json_node.insert(
307 "type".to_string(),
308 serde_json::Value::String(node.kind().to_string()),
309 );
310 json_node.insert(
311 "named".to_string(),
312 serde_json::Value::Bool(node.is_named()),
313 );
314
315 if self.config.show_positions {
316 let start = node.start_position();
317 let end = node.end_position();
318 json_node.insert(
319 "start".to_string(),
320 serde_json::json!({
321 "row": start.row,
322 "column": start.column
323 }),
324 );
325 json_node.insert(
326 "end".to_string(),
327 serde_json::json!({
328 "row": end.row,
329 "column": end.column
330 }),
331 );
332 }
333
334 if self.config.show_byte_ranges {
335 json_node.insert(
336 "start_byte".to_string(),
337 serde_json::Value::Number(node.start_byte().into()),
338 );
339 json_node.insert(
340 "end_byte".to_string(),
341 serde_json::Value::Number(node.end_byte().into()),
342 );
343 }
344
345 if self.config.show_text_content && node.child_count() == 0 {
346 if let Ok(text) = node.utf8_text(source.as_bytes()) {
347 let display_text = if text.len() <= self.config.max_text_length {
348 text.to_string()
349 } else {
350 format!(
351 "{}...",
352 &text[..self.config.max_text_length.min(text.len())]
353 )
354 };
355 json_node.insert("text".to_string(), serde_json::Value::String(display_text));
356 }
357 }
358
359 let mut children = Vec::new();
360 for i in 0..node.child_count() {
361 if let Some(child) = node.child(i) {
362 if !self.config.named_nodes_only || child.is_named() {
363 children.push(self.node_to_json(&child, source, depth + 1)?);
364 }
365 }
366 }
367
368 if !children.is_empty() {
369 json_node.insert("children".to_string(), serde_json::Value::Array(children));
370 }
371
372 Ok(serde_json::Value::Object(json_node))
373 }
374
375 fn visualize_sexp_format(&self, node: &Node, source: &str) -> Result<String> {
377 let mut output = String::new();
378 self.node_to_sexp(node, source, 0, &mut output)?;
379 Ok(output)
380 }
381
382 fn node_to_sexp(
384 &self,
385 node: &Node,
386 source: &str,
387 depth: usize,
388 output: &mut String,
389 ) -> Result<()> {
390 if depth > self.config.max_depth {
391 output.push_str("...");
392 return Ok(());
393 }
394
395 if self.config.named_nodes_only && !node.is_named() {
396 return Ok(());
397 }
398
399 output.push('(');
400 output.push_str(node.kind());
401
402 if node.child_count() == 0 && self.config.show_text_content {
404 if let Ok(text) = node.utf8_text(source.as_bytes()) {
405 let display_text = if text.len() <= self.config.max_text_length {
406 text.to_string()
407 } else {
408 format!(
409 "{}...",
410 &text[..self.config.max_text_length.min(text.len())]
411 )
412 };
413 output.push_str(&format!(" \"{}\"", display_text.replace('"', "\\\"")));
414 }
415 }
416
417 for i in 0..node.child_count() {
419 if let Some(child) = node.child(i) {
420 if !self.config.named_nodes_only || child.is_named() {
421 output.push(' ');
422 self.node_to_sexp(&child, source, depth + 1, output)?;
423 }
424 }
425 }
426
427 output.push(')');
428 Ok(())
429 }
430
431 fn visualize_compact_format(&self, node: &Node, source: &str) -> Result<String> {
433 let mut output = String::new();
434 self.node_to_compact(node, source, 0, &mut output)?;
435 Ok(output.trim().to_string())
436 }
437
438 fn node_to_compact(
440 &self,
441 node: &Node,
442 source: &str,
443 depth: usize,
444 output: &mut String,
445 ) -> Result<()> {
446 if depth > self.config.max_depth {
447 output.push_str("...");
448 return Ok(());
449 }
450
451 if self.config.named_nodes_only && !node.is_named() {
452 return Ok(());
453 }
454
455 output.push_str(node.kind());
456
457 if node.child_count() == 0 && self.config.show_text_content {
458 if let Ok(text) = node.utf8_text(source.as_bytes()) {
459 let display_text = if text.len() <= self.config.max_text_length {
460 text.to_string()
461 } else {
462 format!(
463 "{}...",
464 &text[..self.config.max_text_length.min(text.len())]
465 )
466 };
467 output.push_str(&format!(":{}", display_text.replace(' ', "_")));
468 }
469 }
470
471 if node.child_count() > 0 {
472 output.push('[');
473 for i in 0..node.child_count() {
474 if let Some(child) = node.child(i) {
475 if !self.config.named_nodes_only || child.is_named() {
476 if i > 0 {
477 output.push(',');
478 }
479 self.node_to_compact(&child, source, depth + 1, output)?;
480 }
481 }
482 }
483 output.push(']');
484 }
485
486 Ok(())
487 }
488
489 pub fn get_ast_statistics(&self, node: &Node) -> AstStatistics {
491 let mut stats = AstStatistics::default();
492 self.collect_statistics(node, &mut stats, 0);
493 stats
494 }
495
496 #[allow(clippy::only_used_in_recursion)] fn collect_statistics(&self, node: &Node, stats: &mut AstStatistics, depth: usize) {
499 stats.total_nodes += 1;
500 stats.max_depth = stats.max_depth.max(depth);
501
502 if node.is_named() {
503 stats.named_nodes += 1;
504 } else {
505 stats.unnamed_nodes += 1;
506 }
507
508 *stats
509 .node_type_counts
510 .entry(node.kind().to_string())
511 .or_insert(0) += 1;
512
513 if node.child_count() == 0 {
514 stats.leaf_nodes += 1;
515 }
516
517 for i in 0..node.child_count() {
518 if let Some(child) = node.child(i) {
519 self.collect_statistics(&child, stats, depth + 1);
520 }
521 }
522 }
523
524 pub fn compare_asts(&self, old_node: &Node, new_node: &Node, _source: &str) -> Result<String> {
526 let mut output = String::new();
527 output.push_str("=== AST Comparison ===\n\n");
528
529 let old_stats = self.get_ast_statistics(old_node);
530 let new_stats = self.get_ast_statistics(new_node);
531
532 output.push_str("## Statistics Comparison\n");
533 output.push_str(&format!(
534 "Total nodes: {} -> {} ({}{})\n",
535 old_stats.total_nodes,
536 new_stats.total_nodes,
537 if new_stats.total_nodes >= old_stats.total_nodes {
538 "+"
539 } else {
540 ""
541 },
542 new_stats.total_nodes as i32 - old_stats.total_nodes as i32
543 ));
544
545 output.push_str(&format!(
546 "Max depth: {} -> {} ({}{})\n",
547 old_stats.max_depth,
548 new_stats.max_depth,
549 if new_stats.max_depth >= old_stats.max_depth {
550 "+"
551 } else {
552 ""
553 },
554 new_stats.max_depth as i32 - old_stats.max_depth as i32
555 ));
556
557 output.push_str("\n## Structural Differences\n");
558 if old_node.kind() != new_node.kind() {
559 output.push_str(&format!(
560 "Root node type changed: {} -> {}\n",
561 old_node.kind(),
562 new_node.kind()
563 ));
564 }
565
566 Ok(output)
567 }
568}
569
570impl Default for AstVisualizer {
571 fn default() -> Self {
572 Self::new()
573 }
574}
575
/// Node counts gathered by [`AstVisualizer::get_ast_statistics`].
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct AstStatistics {
    /// Every visited node, named and anonymous.
    pub total_nodes: usize,
    /// Nodes for which tree-sitter reports `is_named()`.
    pub named_nodes: usize,
    /// Anonymous nodes (e.g. punctuation tokens).
    pub unnamed_nodes: usize,
    /// Nodes with no children.
    pub leaf_nodes: usize,
    /// Deepest nesting level seen; the starting node is depth 0.
    pub max_depth: usize,
    /// Occurrence count per node kind.
    pub node_type_counts: HashMap<String, usize>,
}

impl fmt::Display for AstStatistics {
    /// Writes the summary counters followed by up to ten node kinds, ordered
    /// by descending count (ties broken alphabetically).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const MAX_LISTED_TYPES: usize = 10;

        writeln!(f, "AST Statistics:")?;
        writeln!(f, " Total nodes: {}", self.total_nodes)?;
        writeln!(f, " Named nodes: {}", self.named_nodes)?;
        writeln!(f, " Unnamed nodes: {}", self.unnamed_nodes)?;
        writeln!(f, " Leaf nodes: {}", self.leaf_nodes)?;
        writeln!(f, " Maximum depth: {}", self.max_depth)?;
        writeln!(f, " Node types:")?;

        let mut types: Vec<(&String, &usize)> = self.node_type_counts.iter().collect();
        // HashMap iteration order is randomised, so equal counts need an explicit
        // tie-breaker; otherwise the listing and the cut-off vary between runs.
        types.sort_by(|a, b| b.1.cmp(a.1).then_with(|| a.0.cmp(b.0)));

        for (node_type, count) in types.iter().take(MAX_LISTED_TYPES) {
            writeln!(f, " {}: {}", node_type, count)?;
        }

        if types.len() > MAX_LISTED_TYPES {
            writeln!(f, " ... and {} more", types.len() - MAX_LISTED_TYPES)?;
        }

        Ok(())
    }
}
612
#[cfg(test)]
mod tests {
    use super::*;
    use tree_sitter::Parser;

    // Scaffold for future parser-backed tests. No grammar is configured yet
    // and nothing calls it, hence the dead_code allowance.
    #[allow(dead_code)]
    fn create_test_parser() -> Parser {
        Parser::new()
    }

    #[test]
    fn test_ast_visualizer_creation() {
        let visualizer = AstVisualizer::new();
        let config = &visualizer.config;
        assert!(config.show_positions);
        assert_eq!(20, config.max_depth);
    }

    #[test]
    fn test_custom_config() {
        let visualizer = AstVisualizer::with_config(VisualizationConfig {
            show_positions: false,
            max_depth: 10,
            ..VisualizationConfig::default()
        });
        let config = &visualizer.config;
        assert_eq!(10, config.max_depth);
        assert!(!config.show_positions);
    }

    #[test]
    fn test_format_node_type_with_colors() {
        let formatted = AstVisualizer::new().format_node_type("function_definition");
        assert!(!formatted.is_empty());
    }

    #[test]
    fn test_format_node_type_without_colors() {
        let visualizer = AstVisualizer::with_config(VisualizationConfig {
            use_colors: false,
            ..VisualizationConfig::default()
        });
        assert_eq!(
            "function_definition",
            visualizer.format_node_type("function_definition")
        );
    }

    #[test]
    fn test_ast_statistics_display() {
        let mut stats = AstStatistics {
            max_depth: 5,
            unnamed_nodes: 20,
            named_nodes: 80,
            total_nodes: 100,
            ..Default::default()
        };
        for (kind, count) in [("function", 10), ("identifier", 30)] {
            stats.node_type_counts.insert(kind.to_string(), count);
        }

        let rendered = stats.to_string();
        for expected in ["Total nodes: 100", "Named nodes: 80", "Maximum depth: 5"] {
            assert!(
                rendered.contains(expected),
                "missing {:?} in {}",
                expected,
                rendered
            );
        }
    }
}