1use std::collections::HashSet;
37
38use ratatui::{
39 buffer::Buffer,
40 layout::Rect,
41 style::{Color, Modifier, Style},
42 text::{Line, Span},
43 widgets::{Paragraph, Widget, Wrap},
44};
45
/// A node in a tree: a string identifier, a payload of type `T`, and an
/// ordered list of child nodes.
#[derive(Debug, Clone)]
pub struct TreeNode<T> {
    /// Identifier used to track collapse state and to report the selection.
    pub id: String,
    /// Payload carried by this node.
    pub data: T,
    /// Child nodes, in the order they are displayed.
    pub children: Vec<TreeNode<T>>,
}

impl<T> TreeNode<T> {
    /// Builds a leaf node (no children) from an id and its payload.
    pub fn new(id: impl Into<String>, data: T) -> Self {
        let id: String = id.into();
        TreeNode {
            children: Vec::new(),
            id,
            data,
        }
    }

    /// Builder-style setter: replaces any existing children with `children`.
    pub fn with_children(self, children: Vec<TreeNode<T>>) -> Self {
        TreeNode { children, ..self }
    }

    /// Appends `child` after the existing children.
    pub fn add_child(&mut self, child: TreeNode<T>) {
        self.children.push(child);
    }

    /// Returns `true` when this node has at least one child.
    pub fn has_children(&self) -> bool {
        !self.children.is_empty()
    }
}
83
/// Mutable UI state for a tree view: which nodes are collapsed, which visible
/// row is selected, and which visible row is at the top of the viewport.
#[derive(Debug, Clone, Default)]
pub struct TreeViewState {
    /// Ids of nodes whose children are hidden.
    pub collapsed: HashSet<String>,
    /// Index of the selected row among the currently visible (flattened) rows.
    pub selected_index: usize,
    /// Index of the first visible row in the viewport.
    pub scroll: u16,
}

impl TreeViewState {
    /// Creates an empty state: nothing collapsed, row 0 selected, no scroll.
    pub fn new() -> Self {
        Self::default()
    }

    /// Flips the collapsed flag of the node with the given id.
    pub fn toggle_collapsed(&mut self, id: &str) {
        // `remove` reports whether the id was present, so a single lookup
        // decides the direction of the toggle.
        if !self.collapsed.remove(id) {
            self.collapsed.insert(id.to_string());
        }
    }

    /// Returns `true` if the node with the given id is collapsed.
    pub fn is_collapsed(&self, id: &str) -> bool {
        self.collapsed.contains(id)
    }

    /// Marks the node with the given id as collapsed (idempotent).
    pub fn collapse(&mut self, id: &str) {
        self.collapsed.insert(id.to_string());
    }

    /// Marks the node with the given id as expanded (idempotent).
    pub fn expand(&mut self, id: &str) {
        self.collapsed.remove(id);
    }

    /// Moves the selection up one row, stopping at row 0.
    pub fn select_prev(&mut self) {
        self.selected_index = self.selected_index.saturating_sub(1);
    }

    /// Moves the selection down one row, stopping at the last of
    /// `total_visible` rows. Does nothing when `total_visible` is 0.
    pub fn select_next(&mut self, total_visible: usize) {
        // Written as a comparison against `total_visible - 1` (saturating) so
        // that no `selected_index + 1` overflow is possible.
        if self.selected_index < total_visible.saturating_sub(1) {
            self.selected_index += 1;
        }
    }

    /// Adjusts `scroll` so the selected row lies inside a viewport of
    /// `viewport_height` rows.
    ///
    /// - Selection above the viewport: it becomes the top row.
    /// - Selection below the viewport: it becomes the bottom row.
    /// - Otherwise `scroll` is unchanged.
    ///
    /// With a zero-height viewport the second rule yields
    /// `scroll = selected_index + 1`. The resulting offset is clamped to
    /// `u16::MAX` instead of wrapping.
    pub fn ensure_visible(&mut self, viewport_height: usize) {
        let top = usize::from(self.scroll);
        if self.selected_index < top {
            self.scroll = Self::clamp_scroll(self.selected_index);
        } else if self.selected_index >= top.saturating_add(viewport_height) {
            // The branch condition guarantees `selected_index >= viewport_height`,
            // so the subtraction cannot underflow.
            let first_row = (self.selected_index - viewport_height).saturating_add(1);
            self.scroll = Self::clamp_scroll(first_row);
        }
    }

    /// Converts a row index to a scroll offset, saturating at `u16::MAX`
    /// rather than silently truncating like an `as` cast would.
    fn clamp_scroll(row: usize) -> u16 {
        // Lossless: the value has been clamped into `u16` range first.
        row.min(usize::from(u16::MAX)) as u16
    }
}
146
/// Visual configuration for rendering a tree: text styles plus the string
/// fragments used for the cursor, tree-guide connectors and expand/collapse
/// icons.
#[derive(Debug, Clone)]
pub struct TreeStyle {
    /// Style of the cursor and label on the selected row.
    pub selected_style: Style,
    /// Style of the cursor and label on every other row.
    pub normal_style: Style,
    /// Style applied to all tree-guide connector fragments.
    pub connector_style: Style,
    /// Style applied to the expand/collapse icon.
    pub icon_style: Style,
    /// Icon shown before a node that has children and is collapsed.
    pub collapsed_icon: &'static str,
    /// Icon shown before a node that has children and is expanded.
    pub expanded_icon: &'static str,
    /// Connector before a non-root node that has later siblings.
    pub connector_branch: &'static str,
    /// Connector before a non-root node that is the last of its siblings.
    pub connector_last: &'static str,
    /// Per-ancestor column drawn when that ancestor has later siblings.
    pub connector_vertical: &'static str,
    /// Per-ancestor column drawn when that ancestor was the last sibling.
    pub connector_space: &'static str,
    /// Leading marker on the selected row.
    pub cursor_selected: &'static str,
    /// Leading marker on every other row.
    pub cursor_normal: &'static str,
}
175
176impl Default for TreeStyle {
177 fn default() -> Self {
178 Self {
179 selected_style: Style::default()
180 .fg(Color::Yellow)
181 .add_modifier(Modifier::BOLD),
182 normal_style: Style::default().fg(Color::White),
183 connector_style: Style::default().fg(Color::DarkGray),
184 icon_style: Style::default().fg(Color::Cyan),
185 collapsed_icon: "▶ ",
186 expanded_icon: "▼ ",
187 connector_branch: "├── ",
188 connector_last: "└── ",
189 connector_vertical: "│ ",
190 connector_space: " ",
191 cursor_selected: "> ",
192 cursor_normal: " ",
193 }
194 }
195}
196
197impl TreeStyle {
198 pub fn minimal() -> Self {
200 Self {
201 connector_branch: " ",
202 connector_last: " ",
203 connector_vertical: " ",
204 connector_space: " ",
205 ..Default::default()
206 }
207 }
208}
209
/// One row of the flattened, currently visible tree, in display order.
#[derive(Debug, Clone)]
pub struct FlatNode<'a, T> {
    /// The tree node shown on this row.
    pub node: &'a TreeNode<T>,
    /// Nesting level; root-level nodes have depth 0.
    pub depth: usize,
    /// Whether this node is the last among its siblings.
    pub is_last: bool,
    /// The `is_last` flag of each ancestor, outermost first (one entry per
    /// level above this node, empty for root-level nodes).
    pub parent_is_last: Vec<bool>,
}
222
/// Widget that renders a slice of `TreeNode`s as an indented tree, honouring
/// the collapse, selection and scroll information in a `TreeViewState`.
///
/// `F` produces the label text for a node; its `bool` argument is `true` for
/// the selected row.
pub struct TreeView<'a, T, F>
where
    F: Fn(&TreeNode<T>, bool) -> String,
{
    /// Root-level nodes.
    nodes: &'a [TreeNode<T>],
    /// Collapse/selection/scroll state; only read by this widget.
    state: &'a TreeViewState,
    /// Visual configuration.
    style: TreeStyle,
    /// Produces the label for each row.
    render_fn: F,
}
233
234impl<'a, T> TreeView<'a, T, fn(&TreeNode<T>, bool) -> String> {
235 pub fn new(nodes: &'a [TreeNode<T>], state: &'a TreeViewState) -> Self
237 where
238 T: std::fmt::Debug,
239 {
240 Self {
241 nodes,
242 state,
243 style: TreeStyle::default(),
244 render_fn: |node, _| format!("{:?}", node.id),
245 }
246 }
247}
248
249impl<'a, T, F> TreeView<'a, T, F>
250where
251 F: Fn(&TreeNode<T>, bool) -> String,
252{
253 pub fn render_item<G>(self, render_fn: G) -> TreeView<'a, T, G>
255 where
256 G: Fn(&TreeNode<T>, bool) -> String,
257 {
258 TreeView {
259 nodes: self.nodes,
260 state: self.state,
261 style: self.style,
262 render_fn,
263 }
264 }
265
266 pub fn style(mut self, style: TreeStyle) -> Self {
268 self.style = style;
269 self
270 }
271
272 fn flatten_visible(&self) -> Vec<FlatNode<'a, T>> {
274 let mut result = Vec::new();
275 self.flatten_nodes(self.nodes, 0, &mut result, &[]);
276 result
277 }
278
279 fn flatten_nodes(
280 &self,
281 nodes: &'a [TreeNode<T>],
282 depth: usize,
283 result: &mut Vec<FlatNode<'a, T>>,
284 parent_is_last: &[bool],
285 ) {
286 let count = nodes.len();
287 for (idx, node) in nodes.iter().enumerate() {
288 let is_last = idx == count - 1;
289 result.push(FlatNode {
290 node,
291 depth,
292 is_last,
293 parent_is_last: parent_is_last.to_vec(),
294 });
295
296 if node.has_children() && !self.state.is_collapsed(&node.id) {
298 let mut new_parent_is_last = parent_is_last.to_vec();
299 new_parent_is_last.push(is_last);
300 self.flatten_nodes(&node.children, depth + 1, result, &new_parent_is_last);
301 }
302 }
303 }
304
305 pub fn visible_count(&self) -> usize {
307 self.flatten_visible().len()
308 }
309
310 fn build_lines(&self, area: Rect) -> Vec<Line<'static>> {
312 let visible = self.flatten_visible();
313 let mut lines = Vec::new();
314
315 let scroll = self.state.scroll as usize;
316 let viewport_height = area.height as usize;
317
318 for (idx, flat_node) in visible
319 .iter()
320 .enumerate()
321 .skip(scroll)
322 .take(viewport_height)
323 {
324 let is_selected = idx == self.state.selected_index;
325 let mut spans = Vec::new();
326
327 let cursor = if is_selected {
329 self.style.cursor_selected
330 } else {
331 self.style.cursor_normal
332 };
333 spans.push(Span::styled(
334 cursor.to_string(),
335 if is_selected {
336 self.style.selected_style
337 } else {
338 self.style.normal_style
339 },
340 ));
341
342 for &parent_is_last in flat_node.parent_is_last.iter() {
344 let connector = if parent_is_last {
345 self.style.connector_space
346 } else {
347 self.style.connector_vertical
348 };
349 spans.push(Span::styled(
350 connector.to_string(),
351 self.style.connector_style,
352 ));
353 }
354
355 if flat_node.depth > 0 {
357 let connector = if flat_node.is_last {
358 self.style.connector_last
359 } else {
360 self.style.connector_branch
361 };
362 spans.push(Span::styled(
363 connector.to_string(),
364 self.style.connector_style,
365 ));
366 }
367
368 if flat_node.node.has_children() {
370 let icon = if self.state.is_collapsed(&flat_node.node.id) {
371 self.style.collapsed_icon
372 } else {
373 self.style.expanded_icon
374 };
375 spans.push(Span::styled(icon.to_string(), self.style.icon_style));
376 }
377
378 let content = (self.render_fn)(flat_node.node, is_selected);
380 spans.push(Span::styled(
381 content,
382 if is_selected {
383 self.style.selected_style
384 } else {
385 self.style.normal_style
386 },
387 ));
388
389 lines.push(Line::from(spans));
390 }
391
392 lines
393 }
394}
395
396impl<'a, T, F> Widget for TreeView<'a, T, F>
397where
398 F: Fn(&TreeNode<T>, bool) -> String,
399{
400 fn render(self, area: Rect, buf: &mut Buffer) {
401 let lines = self.build_lines(area);
402 let paragraph = Paragraph::new(lines).wrap(Wrap { trim: false });
403 paragraph.render(area, buf);
404 }
405}
406
407pub fn get_selected_id<T: std::fmt::Debug>(
409 nodes: &[TreeNode<T>],
410 state: &TreeViewState,
411) -> Option<String> {
412 let tree = TreeView::new(nodes, state);
413 let visible = tree.flatten_visible();
414 visible.get(state.selected_index).map(|f| f.node.id.clone())
415}
416
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone)]
    struct TestItem {
        name: String,
    }

    /// Two roots; the first ("1") has two leaf children ("1.1", "1.2").
    fn create_test_tree() -> Vec<TreeNode<TestItem>> {
        vec![
            TreeNode::new(
                "1",
                TestItem {
                    name: "Root 1".into(),
                },
            )
            .with_children(vec![
                TreeNode::new(
                    "1.1",
                    TestItem {
                        name: "Child 1.1".into(),
                    },
                ),
                TreeNode::new(
                    "1.2",
                    TestItem {
                        name: "Child 1.2".into(),
                    },
                ),
            ]),
            TreeNode::new(
                "2",
                TestItem {
                    name: "Root 2".into(),
                },
            ),
        ]
    }

    /// A single chain: root -> level1 -> level2 -> level3.
    fn create_deep_tree() -> Vec<TreeNode<TestItem>> {
        vec![TreeNode::new(
            "root",
            TestItem {
                name: "Root".into(),
            },
        )
        .with_children(vec![TreeNode::new(
            "level1",
            TestItem {
                name: "Level 1".into(),
            },
        )
        .with_children(vec![TreeNode::new(
            "level2",
            TestItem {
                name: "Level 2".into(),
            },
        )
        .with_children(vec![TreeNode::new(
            "level3",
            TestItem {
                name: "Level 3".into(),
            },
        )])])])]
    }

    #[test]
    fn test_tree_node_new() {
        let node: TreeNode<TestItem> = TreeNode::new(
            "test-id",
            TestItem {
                name: "Test".into(),
            },
        );
        assert_eq!(node.id, "test-id");
        assert_eq!(node.data.name, "Test");
        assert!(node.children.is_empty());
    }

    #[test]
    fn test_tree_node_with_children() {
        let node: TreeNode<TestItem> = TreeNode::new(
            "parent",
            TestItem {
                name: "Parent".into(),
            },
        )
        .with_children(vec![
            TreeNode::new(
                "child1",
                TestItem {
                    name: "Child 1".into(),
                },
            ),
            TreeNode::new(
                "child2",
                TestItem {
                    name: "Child 2".into(),
                },
            ),
        ]);
        assert_eq!(node.children.len(), 2);
    }

    #[test]
    fn test_tree_node_has_children() {
        let leaf: TreeNode<TestItem> = TreeNode::new(
            "leaf",
            TestItem {
                name: "Leaf".into(),
            },
        );
        assert!(!leaf.has_children());

        let parent: TreeNode<TestItem> = TreeNode::new(
            "parent",
            TestItem {
                name: "Parent".into(),
            },
        )
        .with_children(vec![leaf.clone()]);
        assert!(parent.has_children());
    }

    #[test]
    fn test_tree_state_new() {
        let state = TreeViewState::new();
        assert_eq!(state.selected_index, 0);
        assert!(state.collapsed.is_empty());
    }

    #[test]
    fn test_tree_state() {
        let mut state = TreeViewState::new();
        assert!(!state.is_collapsed("1"));

        state.collapse("1");
        assert!(state.is_collapsed("1"));

        state.toggle_collapsed("1");
        assert!(!state.is_collapsed("1"));
    }

    #[test]
    fn test_tree_state_expand() {
        let mut state = TreeViewState::new();
        state.collapse("node1");
        state.collapse("node2");

        assert!(state.is_collapsed("node1"));
        state.expand("node1");
        assert!(!state.is_collapsed("node1"));
        assert!(state.is_collapsed("node2"));
    }

    #[test]
    fn test_tree_state_collapse_multiple() {
        let mut state = TreeViewState::new();

        state.collapse("1");
        state.collapse("2");
        assert!(state.is_collapsed("1"));
        assert!(state.is_collapsed("2"));

        state.expand("1");
        state.expand("2");
        assert!(!state.is_collapsed("1"));
        assert!(!state.is_collapsed("2"));
    }

    #[test]
    fn test_tree_state_navigation() {
        let mut state = TreeViewState::new();
        assert_eq!(state.selected_index, 0);

        state.select_next(5);
        assert_eq!(state.selected_index, 1);

        state.select_next(5);
        state.select_next(5);
        state.select_next(5);
        assert_eq!(state.selected_index, 4);

        // Already on the last row: stays put.
        state.select_next(5);
        assert_eq!(state.selected_index, 4);

        state.select_prev();
        assert_eq!(state.selected_index, 3);

        state.select_prev();
        state.select_prev();
        state.select_prev();
        // Already on the first row: stays put.
        state.select_prev();
        assert_eq!(state.selected_index, 0);
    }

    #[test]
    fn test_tree_state_ensure_visible() {
        let mut state = TreeViewState::new();

        // Selection below the viewport: it becomes the bottom row (15 - 10 + 1).
        state.selected_index = 15;
        state.scroll = 5;
        state.ensure_visible(10);
        assert_eq!(state.scroll, 6);

        // Selection above the viewport: it becomes the top row.
        state.selected_index = 2;
        state.scroll = 10;
        state.ensure_visible(10);
        assert_eq!(state.scroll, 2);

        // Selection already visible: scroll is untouched.
        state.selected_index = 7;
        state.scroll = 2;
        state.ensure_visible(10);
        assert_eq!(state.scroll, 2);
    }

    #[test]
    fn test_tree_state_ensure_visible_zero_viewport() {
        let mut state = TreeViewState::new();
        state.scroll = 5;
        state.selected_index = 10;
        state.ensure_visible(0);
        assert_eq!(state.scroll, 11);
    }

    #[test]
    fn test_flatten_visible() {
        let nodes = create_test_tree();
        let state = TreeViewState::new();
        let tree = TreeView::new(&nodes, &state);

        let visible = tree.flatten_visible();
        // 1, 1.1, 1.2, 2
        assert_eq!(visible.len(), 4);
    }

    #[test]
    fn test_flatten_with_collapsed() {
        let nodes = create_test_tree();
        let mut state = TreeViewState::new();
        state.collapse("1");

        let tree = TreeView::new(&nodes, &state);
        let visible = tree.flatten_visible();
        // Children of "1" are hidden: 1, 2
        assert_eq!(visible.len(), 2);
    }

    #[test]
    fn test_flatten_deep_tree() {
        let nodes = create_deep_tree();
        let state = TreeViewState::new();
        let tree = TreeView::new(&nodes, &state);

        let visible = tree.flatten_visible();
        assert_eq!(visible.len(), 4);
        assert_eq!(visible[0].depth, 0);
        assert_eq!(visible[1].depth, 1);
        assert_eq!(visible[2].depth, 2);
        assert_eq!(visible[3].depth, 3);
    }

    #[test]
    fn test_flatten_nested_collapse() {
        let nodes = create_deep_tree();
        let mut state = TreeViewState::new();
        state.collapse("level1");

        let tree = TreeView::new(&nodes, &state);
        let visible = tree.flatten_visible();
        let ids: Vec<&str> = visible.iter().map(|f| f.node.id.as_str()).collect();
        assert_eq!(ids, vec!["root", "level1"]);
    }

    #[test]
    fn test_flatten_connector_flags() {
        let nodes = create_test_tree();
        let state = TreeViewState::new();
        let tree = TreeView::new(&nodes, &state);
        let visible = tree.flatten_visible();

        let summary: Vec<(&str, usize, bool, Vec<bool>)> = visible
            .iter()
            .map(|f| {
                (
                    f.node.id.as_str(),
                    f.depth,
                    f.is_last,
                    f.parent_is_last.clone(),
                )
            })
            .collect();
        let expected: Vec<(&str, usize, bool, Vec<bool>)> = vec![
            ("1", 0, false, vec![]),
            ("1.1", 1, false, vec![false]),
            ("1.2", 1, true, vec![false]),
            ("2", 0, true, vec![]),
        ];
        assert_eq!(summary, expected);
    }

    #[test]
    fn test_visible_count() {
        let nodes = create_test_tree();
        let state = TreeViewState::new();
        let tree = TreeView::new(&nodes, &state);
        assert_eq!(tree.visible_count(), 4);

        let mut collapsed_state = TreeViewState::new();
        collapsed_state.collapse("1");
        let collapsed_tree = TreeView::new(&nodes, &collapsed_state);
        assert_eq!(collapsed_tree.visible_count(), 2);
    }

    #[test]
    fn test_selection_navigation() {
        let nodes = create_test_tree();
        let mut state = TreeViewState::new();
        let tree = TreeView::new(&nodes, &state);
        let count = tree.visible_count();

        assert_eq!(state.selected_index, 0);
        state.select_next(count);
        assert_eq!(state.selected_index, 1);
        state.select_prev();
        assert_eq!(state.selected_index, 0);
    }

    #[test]
    fn test_get_selected_id() {
        let nodes = create_test_tree();
        let mut state = TreeViewState::new();

        let id = get_selected_id(&nodes, &state);
        assert_eq!(id, Some("1".to_string()));

        state.selected_index = 2;
        let id = get_selected_id(&nodes, &state);
        assert_eq!(id, Some("1.2".to_string()));

        state.selected_index = 3;
        let id = get_selected_id(&nodes, &state);
        assert_eq!(id, Some("2".to_string()));
    }

    #[test]
    fn test_get_selected_id_with_collapsed() {
        let nodes = create_test_tree();
        let mut state = TreeViewState::new();
        state.collapse("1");
        state.selected_index = 1;

        let id = get_selected_id(&nodes, &state);
        assert_eq!(id, Some("2".to_string()));
    }

    #[test]
    fn test_get_selected_id_out_of_range() {
        let nodes = create_test_tree();
        let mut state = TreeViewState::new();
        state.selected_index = 4;

        assert_eq!(get_selected_id(&nodes, &state), None);
    }

    #[test]
    fn test_tree_style_default() {
        let style = TreeStyle::default();
        assert_eq!(style.collapsed_icon, "▶ ");
        assert_eq!(style.expanded_icon, "▼ ");
        assert_eq!(style.connector_branch, "├── ");
        assert_eq!(style.connector_last, "└── ");
    }

    #[test]
    fn test_tree_style_minimal() {
        let minimal = TreeStyle::minimal();
        let default = TreeStyle::default();

        // All connectors collapse to the same blank padding.
        assert_eq!(minimal.connector_branch, minimal.connector_last);
        assert_eq!(minimal.connector_branch, minimal.connector_vertical);
        assert_eq!(minimal.connector_branch, minimal.connector_space);
        assert!(minimal.connector_branch.trim().is_empty());

        // Everything else is inherited from the default style.
        assert_eq!(minimal.collapsed_icon, default.collapsed_icon);
        assert_eq!(minimal.expanded_icon, default.expanded_icon);
        assert_eq!(minimal.cursor_selected, default.cursor_selected);
        assert_eq!(minimal.cursor_normal, default.cursor_normal);
    }

    /// Smoke test: rendering with a custom label closure must not panic.
    #[test]
    fn test_tree_view_render() {
        let nodes = create_test_tree();
        let state = TreeViewState::new();
        let tree = TreeView::new(&nodes, &state)
            .render_item(|node, _| format!("Item: {}", node.data.name));

        let mut buf = Buffer::empty(Rect::new(0, 0, 40, 10));
        tree.render(Rect::new(0, 0, 40, 10), &mut buf);
    }

    /// Smoke test: rendering with a custom style must not panic.
    #[test]
    fn test_tree_view_with_style() {
        let nodes = create_test_tree();
        let state = TreeViewState::new();
        let custom_style = TreeStyle {
            collapsed_icon: "+",
            expanded_icon: "-",
            ..TreeStyle::default()
        };
        let tree = TreeView::new(&nodes, &state).style(custom_style);

        let mut buf = Buffer::empty(Rect::new(0, 0, 40, 10));
        tree.render(Rect::new(0, 0, 40, 10), &mut buf);
    }

    #[test]
    fn test_empty_tree() {
        let nodes: Vec<TreeNode<TestItem>> = vec![];
        let state = TreeViewState::new();
        let tree = TreeView::new(&nodes, &state);

        assert_eq!(tree.visible_count(), 0);
        assert!(tree.flatten_visible().is_empty());
    }
}