1use anyhow::Result;
7use colored::Colorize;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::fmt;
11use tree_sitter::{Node, Tree};
12
13#[derive(Debug, Clone)]
15pub struct AstVisualizer {
16 config: VisualizationConfig,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct VisualizationConfig {
22 pub max_depth: usize,
24 pub show_positions: bool,
26 pub show_byte_ranges: bool,
28 pub use_colors: bool,
30 pub show_text_content: bool,
32 pub max_text_length: usize,
34 pub named_nodes_only: bool,
36 pub node_color_names: HashMap<String, String>,
38 pub indent_string: String,
40}
41
42impl Default for VisualizationConfig {
43 fn default() -> Self {
44 let mut node_color_names = HashMap::new();
45
46 node_color_names.insert("function_definition".to_string(), "blue".to_string());
48 node_color_names.insert("class_definition".to_string(), "green".to_string());
49 node_color_names.insert("function_call".to_string(), "cyan".to_string());
50 node_color_names.insert("variable".to_string(), "yellow".to_string());
51 node_color_names.insert("string".to_string(), "red".to_string());
52 node_color_names.insert("number".to_string(), "magenta".to_string());
53 node_color_names.insert("comment".to_string(), "brightblack".to_string());
54 node_color_names.insert("keyword".to_string(), "brightblue".to_string());
55 node_color_names.insert("operator".to_string(), "brightyellow".to_string());
56 node_color_names.insert("identifier".to_string(), "white".to_string());
57
58 Self {
59 max_depth: 20,
60 show_positions: true,
61 show_byte_ranges: false,
62 use_colors: true,
63 show_text_content: true,
64 max_text_length: 50,
65 named_nodes_only: false,
66 node_color_names,
67 indent_string: " ".to_string(),
68 }
69 }
70}
71
/// Output formats understood by `AstVisualizer::visualize_node`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum VisualizationFormat {
    /// Indented tree drawn with box-drawing connectors (`├──`, `└──`).
    Tree,
    /// One node per line, indented by depth.
    List,
    /// Pretty-printed JSON with nested `children` arrays.
    Json,
    /// Parenthesised S-expression, e.g. `(kind (child) (child))`.
    SExpression,
    /// Single-line `kind[child,child]` form.
    Compact,
}
86
87impl AstVisualizer {
88 pub fn new() -> Self {
90 Self {
91 config: VisualizationConfig::default(),
92 }
93 }
94
95 pub fn with_config(config: VisualizationConfig) -> Self {
97 Self { config }
98 }
99
100 pub fn visualize_tree(&self, tree: &Tree, source: &str) -> Result<String> {
102 let root_node = tree.root_node();
103 self.visualize_node(&root_node, source, VisualizationFormat::Tree)
104 }
105
106 pub fn visualize_node(
108 &self,
109 node: &Node,
110 source: &str,
111 format: VisualizationFormat,
112 ) -> Result<String> {
113 match format {
114 VisualizationFormat::Tree => self.visualize_tree_format(node, source),
115 VisualizationFormat::List => self.visualize_list_format(node, source),
116 VisualizationFormat::Json => self.visualize_json_format(node, source),
117 VisualizationFormat::SExpression => self.visualize_sexp_format(node, source),
118 VisualizationFormat::Compact => self.visualize_compact_format(node, source),
119 }
120 }
121
122 fn visualize_tree_format(&self, node: &Node, source: &str) -> Result<String> {
124 let mut output = String::new();
125 self.visualize_node_recursive(node, source, 0, "", true, &mut output);
126 Ok(output)
127 }
128
129 fn visualize_node_recursive(
131 &self,
132 node: &Node,
133 source: &str,
134 depth: usize,
135 prefix: &str,
136 is_last: bool,
137 output: &mut String,
138 ) {
139 if depth > self.config.max_depth {
140 output.push_str(&format!("{}{}...\n", prefix, "─── ".dimmed()));
141 return;
142 }
143
144 if self.config.named_nodes_only && !node.is_named() {
146 return;
147 }
148
149 let connector = if is_last { "└── " } else { "├── " };
151 let node_prefix = format!("{prefix}{connector}");
152
153 let node_type = self.format_node_type(node.kind());
155
156 let position_info = if self.config.show_positions {
158 let start = node.start_position();
159 let end = node.end_position();
160 format!(
161 " @{}:{}-{}:{}",
162 start.row + 1,
163 start.column + 1,
164 end.row + 1,
165 end.column + 1
166 )
167 } else {
168 String::new()
169 };
170
171 let byte_range_info = if self.config.show_byte_ranges {
173 format!(" [{}..{}]", node.start_byte(), node.end_byte())
174 } else {
175 String::new()
176 };
177
178 let text_content = if self.config.show_text_content && node.child_count() == 0 {
180 let text = node
181 .utf8_text(source.as_bytes())
182 .unwrap_or("<invalid utf8>");
183 if text.len() <= self.config.max_text_length {
184 format!(" \"{}\"", text.replace('\n', "\\n").replace('\r', "\\r"))
185 } else {
186 format!(
187 " \"{}...\"",
188 &text[..self.config.max_text_length.min(text.len())]
189 )
190 }
191 } else {
192 String::new()
193 };
194
195 output.push_str(&format!(
197 "{node_prefix}{node_type}{position_info}{byte_range_info}{text_content}\n"
198 ));
199
200 let child_count = node.child_count();
202 for i in 0..child_count {
203 if let Some(child) = node.child(i) {
204 let child_prefix = format!("{}{}", prefix, if is_last { " " } else { "│ " });
205 let is_last_child = i == child_count - 1;
206 self.visualize_node_recursive(
207 &child,
208 source,
209 depth + 1,
210 &child_prefix,
211 is_last_child,
212 output,
213 );
214 }
215 }
216 }
217
218 fn format_node_type(&self, node_type: &str) -> String {
220 if !self.config.use_colors {
221 return node_type.to_string();
222 }
223
224 if let Some(color_name) = self.config.node_color_names.get(node_type) {
225 match color_name.as_str() {
226 "blue" => node_type.blue().to_string(),
227 "green" => node_type.green().to_string(),
228 "cyan" => node_type.cyan().to_string(),
229 "red" => node_type.red().to_string(),
230 "yellow" => node_type.yellow().to_string(),
231 "magenta" => node_type.magenta().to_string(),
232 _ => node_type.normal().to_string(),
233 }
234 } else {
235 node_type.normal().to_string()
237 }
238 }
239
240 fn visualize_list_format(&self, node: &Node, _source: &str) -> Result<String> {
242 let mut output = String::new();
243 let mut cursor = node.walk();
244 let mut depth = 0;
245
246 loop {
247 let current_node = cursor.node();
248
249 if !self.config.named_nodes_only || current_node.is_named() {
251 let indent = self.config.indent_string.repeat(depth);
252 let node_type = self.format_node_type(current_node.kind());
253
254 let position_info = if self.config.show_positions {
255 let start = current_node.start_position();
256 format!(" @{}:{}", start.row + 1, start.column + 1)
257 } else {
258 String::new()
259 };
260
261 output.push_str(&format!("{indent}{node_type}{position_info}\n"));
262 }
263
264 if cursor.goto_first_child() {
265 depth += 1;
266 } else if cursor.goto_next_sibling() {
267 } else {
269 loop {
271 if !cursor.goto_parent() {
272 return Ok(output); }
274 depth -= 1;
275 if cursor.goto_next_sibling() {
276 break;
277 }
278 }
279 }
280
281 if depth > self.config.max_depth {
282 break;
283 }
284 }
285
286 Ok(output)
287 }
288
289 fn visualize_json_format(&self, node: &Node, source: &str) -> Result<String> {
291 let json_node = self.node_to_json(node, source, 0)?;
292 Ok(serde_json::to_string_pretty(&json_node)?)
293 }
294
295 fn node_to_json(&self, node: &Node, source: &str, depth: usize) -> Result<serde_json::Value> {
297 if depth > self.config.max_depth {
298 return Ok(serde_json::json!({
299 "type": "...",
300 "truncated": true
301 }));
302 }
303
304 let mut json_node = serde_json::Map::new();
305 json_node.insert(
306 "type".to_string(),
307 serde_json::Value::String(node.kind().to_string()),
308 );
309 json_node.insert(
310 "named".to_string(),
311 serde_json::Value::Bool(node.is_named()),
312 );
313
314 if self.config.show_positions {
315 let start = node.start_position();
316 let end = node.end_position();
317 json_node.insert(
318 "start".to_string(),
319 serde_json::json!({
320 "row": start.row,
321 "column": start.column
322 }),
323 );
324 json_node.insert(
325 "end".to_string(),
326 serde_json::json!({
327 "row": end.row,
328 "column": end.column
329 }),
330 );
331 }
332
333 if self.config.show_byte_ranges {
334 json_node.insert(
335 "start_byte".to_string(),
336 serde_json::Value::Number(node.start_byte().into()),
337 );
338 json_node.insert(
339 "end_byte".to_string(),
340 serde_json::Value::Number(node.end_byte().into()),
341 );
342 }
343
344 if self.config.show_text_content && node.child_count() == 0 {
345 if let Ok(text) = node.utf8_text(source.as_bytes()) {
346 let display_text = if text.len() <= self.config.max_text_length {
347 text.to_string()
348 } else {
349 format!(
350 "{}...",
351 &text[..self.config.max_text_length.min(text.len())]
352 )
353 };
354 json_node.insert("text".to_string(), serde_json::Value::String(display_text));
355 }
356 }
357
358 let mut children = Vec::new();
359 for i in 0..node.child_count() {
360 if let Some(child) = node.child(i) {
361 if !self.config.named_nodes_only || child.is_named() {
362 children.push(self.node_to_json(&child, source, depth + 1)?);
363 }
364 }
365 }
366
367 if !children.is_empty() {
368 json_node.insert("children".to_string(), serde_json::Value::Array(children));
369 }
370
371 Ok(serde_json::Value::Object(json_node))
372 }
373
374 fn visualize_sexp_format(&self, node: &Node, source: &str) -> Result<String> {
376 let mut output = String::new();
377 self.node_to_sexp(node, source, 0, &mut output)?;
378 Ok(output)
379 }
380
381 fn node_to_sexp(
383 &self,
384 node: &Node,
385 source: &str,
386 depth: usize,
387 output: &mut String,
388 ) -> Result<()> {
389 if depth > self.config.max_depth {
390 output.push_str("...");
391 return Ok(());
392 }
393
394 if self.config.named_nodes_only && !node.is_named() {
395 return Ok(());
396 }
397
398 output.push('(');
399 output.push_str(node.kind());
400
401 if node.child_count() == 0 && self.config.show_text_content {
403 if let Ok(text) = node.utf8_text(source.as_bytes()) {
404 let display_text = if text.len() <= self.config.max_text_length {
405 text.to_string()
406 } else {
407 format!(
408 "{}...",
409 &text[..self.config.max_text_length.min(text.len())]
410 )
411 };
412 output.push_str(&format!(" \"{}\"", display_text.replace('"', "\\\"")));
413 }
414 }
415
416 for i in 0..node.child_count() {
418 if let Some(child) = node.child(i) {
419 if !self.config.named_nodes_only || child.is_named() {
420 output.push(' ');
421 self.node_to_sexp(&child, source, depth + 1, output)?;
422 }
423 }
424 }
425
426 output.push(')');
427 Ok(())
428 }
429
430 fn visualize_compact_format(&self, node: &Node, source: &str) -> Result<String> {
432 let mut output = String::new();
433 self.node_to_compact(node, source, 0, &mut output)?;
434 Ok(output.trim().to_string())
435 }
436
437 fn node_to_compact(
439 &self,
440 node: &Node,
441 source: &str,
442 depth: usize,
443 output: &mut String,
444 ) -> Result<()> {
445 if depth > self.config.max_depth {
446 output.push_str("...");
447 return Ok(());
448 }
449
450 if self.config.named_nodes_only && !node.is_named() {
451 return Ok(());
452 }
453
454 output.push_str(node.kind());
455
456 if node.child_count() == 0 && self.config.show_text_content {
457 if let Ok(text) = node.utf8_text(source.as_bytes()) {
458 let display_text = if text.len() <= self.config.max_text_length {
459 text.to_string()
460 } else {
461 format!(
462 "{}...",
463 &text[..self.config.max_text_length.min(text.len())]
464 )
465 };
466 output.push_str(&format!(":{}", display_text.replace(' ', "_")));
467 }
468 }
469
470 if node.child_count() > 0 {
471 output.push('[');
472 for i in 0..node.child_count() {
473 if let Some(child) = node.child(i) {
474 if !self.config.named_nodes_only || child.is_named() {
475 if i > 0 {
476 output.push(',');
477 }
478 self.node_to_compact(&child, source, depth + 1, output)?;
479 }
480 }
481 }
482 output.push(']');
483 }
484
485 Ok(())
486 }
487
488 pub fn get_ast_statistics(&self, node: &Node) -> AstStatistics {
490 let mut stats = AstStatistics::default();
491 self.collect_statistics(node, &mut stats, 0);
492 stats
493 }
494
495 #[allow(clippy::only_used_in_recursion)] fn collect_statistics(&self, node: &Node, stats: &mut AstStatistics, depth: usize) {
498 stats.total_nodes += 1;
499 stats.max_depth = stats.max_depth.max(depth);
500
501 if node.is_named() {
502 stats.named_nodes += 1;
503 } else {
504 stats.unnamed_nodes += 1;
505 }
506
507 *stats
508 .node_type_counts
509 .entry(node.kind().to_string())
510 .or_insert(0) += 1;
511
512 if node.child_count() == 0 {
513 stats.leaf_nodes += 1;
514 }
515
516 for i in 0..node.child_count() {
517 if let Some(child) = node.child(i) {
518 self.collect_statistics(&child, stats, depth + 1);
519 }
520 }
521 }
522
523 pub fn compare_asts(&self, old_node: &Node, new_node: &Node, _source: &str) -> Result<String> {
525 let mut output = String::new();
526 output.push_str("=== AST Comparison ===\n\n");
527
528 let old_stats = self.get_ast_statistics(old_node);
529 let new_stats = self.get_ast_statistics(new_node);
530
531 output.push_str("## Statistics Comparison\n");
532 output.push_str(&format!(
533 "Total nodes: {} -> {} ({}{})\n",
534 old_stats.total_nodes,
535 new_stats.total_nodes,
536 if new_stats.total_nodes >= old_stats.total_nodes {
537 "+"
538 } else {
539 ""
540 },
541 new_stats.total_nodes as i32 - old_stats.total_nodes as i32
542 ));
543
544 output.push_str(&format!(
545 "Max depth: {} -> {} ({}{})\n",
546 old_stats.max_depth,
547 new_stats.max_depth,
548 if new_stats.max_depth >= old_stats.max_depth {
549 "+"
550 } else {
551 ""
552 },
553 new_stats.max_depth as i32 - old_stats.max_depth as i32
554 ));
555
556 output.push_str("\n## Structural Differences\n");
557 if old_node.kind() != new_node.kind() {
558 output.push_str(&format!(
559 "Root node type changed: {} -> {}\n",
560 old_node.kind(),
561 new_node.kind()
562 ));
563 }
564
565 Ok(output)
566 }
567}
568
569impl Default for AstVisualizer {
570 fn default() -> Self {
571 Self::new()
572 }
573}
574
/// Aggregate counts gathered over every node of a subtree by
/// `AstVisualizer::get_ast_statistics`.
#[derive(Default, Debug)]
pub struct AstStatistics {
    /// Number of nodes visited, including the starting node.
    pub total_nodes: usize,
    /// Nodes for which `is_named()` is true.
    pub named_nodes: usize,
    /// Nodes for which `is_named()` is false.
    pub unnamed_nodes: usize,
    /// Nodes without children.
    pub leaf_nodes: usize,
    /// Greatest depth reached; the starting node is depth 0.
    pub max_depth: usize,
    /// Number of occurrences of each node kind.
    pub node_type_counts: HashMap<String, usize>,
}
585
586impl fmt::Display for AstStatistics {
587 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
588 writeln!(f, "AST Statistics:")?;
589 writeln!(f, " Total nodes: {}", self.total_nodes)?;
590 writeln!(f, " Named nodes: {}", self.named_nodes)?;
591 writeln!(f, " Unnamed nodes: {}", self.unnamed_nodes)?;
592 writeln!(f, " Leaf nodes: {}", self.leaf_nodes)?;
593 writeln!(f, " Maximum depth: {}", self.max_depth)?;
594 writeln!(f, " Node types:")?;
595
596 let mut types: Vec<_> = self.node_type_counts.iter().collect();
597 types.sort_by(|a, b| b.1.cmp(a.1)); for (node_type, count) in types.iter().take(10) {
600 writeln!(f, " {node_type}: {count}")?;
602 }
603
604 if types.len() > 10 {
605 writeln!(f, " ... and {} more", types.len() - 10)?;
606 }
607
608 Ok(())
609 }
610}
611
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ast_visualizer_creation() {
        let visualizer = AstVisualizer::new();
        assert_eq!(visualizer.config.max_depth, 20);
        assert!(visualizer.config.show_positions);
    }

    #[test]
    fn test_default_matches_new() {
        let a = AstVisualizer::default();
        let b = AstVisualizer::new();
        assert_eq!(a.config.max_depth, b.config.max_depth);
        assert_eq!(a.config.max_text_length, b.config.max_text_length);
        assert_eq!(a.config.indent_string, b.config.indent_string);
    }

    #[test]
    fn test_custom_config() {
        let config = VisualizationConfig {
            max_depth: 10,
            show_positions: false,
            ..Default::default()
        };

        let visualizer = AstVisualizer::with_config(config);
        assert_eq!(visualizer.config.max_depth, 10);
        assert!(!visualizer.config.show_positions);
    }

    #[test]
    fn test_default_color_map() {
        let config = VisualizationConfig::default();
        assert_eq!(config.node_color_names.len(), 10);
        assert_eq!(
            config
                .node_color_names
                .get("function_definition")
                .map(String::as_str),
            Some("blue")
        );
    }

    #[test]
    fn test_format_node_type_with_colors() {
        let visualizer = AstVisualizer::new();
        let formatted = visualizer.format_node_type("function_definition");
        // Colour escapes (if any) wrap the kind; the kind itself must survive.
        assert!(formatted.contains("function_definition"));
    }

    #[test]
    fn test_format_node_type_without_colors() {
        let config = VisualizationConfig {
            use_colors: false,
            ..Default::default()
        };
        let visualizer = AstVisualizer::with_config(config);

        assert_eq!(
            visualizer.format_node_type("function_definition"),
            "function_definition"
        );
        assert_eq!(visualizer.format_node_type("no_such_kind"), "no_such_kind");
    }

    #[test]
    fn test_visualization_format_equality() {
        assert_eq!(VisualizationFormat::Tree, VisualizationFormat::Tree);
        assert_ne!(VisualizationFormat::Tree, VisualizationFormat::Compact);
    }

    #[test]
    fn test_ast_statistics_display() {
        let mut stats = AstStatistics {
            total_nodes: 100,
            named_nodes: 80,
            unnamed_nodes: 20,
            max_depth: 5,
            ..Default::default()
        };
        stats.node_type_counts.insert("function".to_string(), 10);
        stats.node_type_counts.insert("identifier".to_string(), 30);

        let output = format!("{stats}");
        assert!(output.contains("Total nodes: 100"));
        assert!(output.contains("Named nodes: 80"));
        assert!(output.contains("Maximum depth: 5"));
    }

    #[test]
    fn test_ast_statistics_display_limits_type_list() {
        let mut stats = AstStatistics::default();
        // Distinct counts so the top-ten selection is unambiguous.
        for i in 0..12usize {
            stats.node_type_counts.insert(format!("kind{i}"), i + 1);
        }

        let output = stats.to_string();
        assert!(output.contains("... and 2 more"));
        assert!(output.contains(" kind11: 12"));
        assert!(!output.contains(" kind0: 1\n"));
    }
}