Skip to main content

ratatui_interact/components/
tree_view.rs

1//! Tree view widget
2//!
3//! A collapsible tree view with selection, status icons, and customizable rendering.
4//!
5//! # Example
6//!
7//! ```rust
8//! use ratatui_interact::components::{TreeView, TreeViewState, TreeNode, TreeStyle};
9//! use ratatui::layout::Rect;
10//!
11//! // Define your tree node type
12//! #[derive(Clone, Debug)]
13//! struct Task {
14//!     id: String,
15//!     name: String,
16//!     status: &'static str,
17//! }
18//!
19//! // Create nodes
20//! let nodes = vec![
21//!     TreeNode::new("1", Task { id: "1".into(), name: "Root".into(), status: "pending" })
22//!         .with_children(vec![
23//!             TreeNode::new("1.1", Task { id: "1.1".into(), name: "Child 1".into(), status: "done" }),
24//!             TreeNode::new("1.2", Task { id: "1.2".into(), name: "Child 2".into(), status: "running" }),
25//!         ]),
26//! ];
27//!
28//! // Create state and view
29//! let mut state = TreeViewState::new();
30//! let tree = TreeView::new(&nodes, &state)
31//!     .render_item(|node, is_selected| {
32//!         format!("{} [{}]", node.data.name, node.data.status)
33//!     });
34//! ```
35
36use std::collections::HashSet;
37
38use ratatui::{
39    buffer::Buffer,
40    layout::Rect,
41    style::{Color, Modifier, Style},
42    text::{Line, Span},
43    widgets::{Paragraph, Widget, Wrap},
44};
45
/// A single entry in the tree: an identifier, a caller-defined payload and
/// any number of nested child entries.
#[derive(Debug, Clone)]
pub struct TreeNode<T> {
    /// Identifier used to look the node up (for example in the collapsed set)
    pub id: String,
    /// Caller-defined payload carried by the node
    pub data: T,
    /// Nested nodes, in display order
    pub children: Vec<TreeNode<T>>,
}

impl<T> TreeNode<T> {
    /// Build a leaf node (no children) from an id and its payload.
    pub fn new(id: impl Into<String>, data: T) -> Self {
        let id = id.into();
        let children = Vec::new();
        Self { id, data, children }
    }

    /// Builder-style setter: returns the node with its child list replaced
    /// by `children` (any previously attached children are discarded).
    pub fn with_children(self, children: Vec<TreeNode<T>>) -> Self {
        Self { children, ..self }
    }

    /// Append one child after the existing children.
    pub fn add_child(&mut self, child: TreeNode<T>) {
        let kids = &mut self.children;
        kids.push(child);
    }

    /// `true` when at least one child is attached.
    pub fn has_children(&self) -> bool {
        let is_leaf = self.children.is_empty();
        !is_leaf
    }
}
83
/// State for the tree view widget
///
/// Holds everything that changes while the user interacts with a tree:
/// which nodes are collapsed, which visible row is selected and how far the
/// view is scrolled.
#[derive(Debug, Clone, Default)]
pub struct TreeViewState {
    /// Set of collapsed node IDs
    pub collapsed: HashSet<String>,
    /// Currently selected index in the flattened visible list
    pub selected_index: usize,
    /// Scroll offset (index of the first visible row)
    pub scroll: u16,
}

impl TreeViewState {
    /// Create a new tree view state: nothing collapsed, first row selected,
    /// no scroll offset.
    pub fn new() -> Self {
        Self::default()
    }

    /// Toggle the collapsed state of a node
    pub fn toggle_collapsed(&mut self, id: &str) {
        // `remove` reports whether the id was present, so a single lookup
        // tells us which direction to toggle.
        if !self.collapsed.remove(id) {
            self.collapsed.insert(id.to_string());
        }
    }

    /// Check if a node is collapsed
    pub fn is_collapsed(&self, id: &str) -> bool {
        self.collapsed.contains(id)
    }

    /// Collapse a node (no-op if it is already collapsed)
    pub fn collapse(&mut self, id: &str) {
        self.collapsed.insert(id.to_string());
    }

    /// Expand a node (no-op if it is not collapsed)
    pub fn expand(&mut self, id: &str) {
        self.collapsed.remove(id);
    }

    /// Move selection up, stopping at the first row
    pub fn select_prev(&mut self) {
        self.selected_index = self.selected_index.saturating_sub(1);
    }

    /// Move selection down, stopping at the last row.
    ///
    /// `total_visible` is the number of rows currently visible; the
    /// selection never advances to an index `>= total_visible`.
    pub fn select_next(&mut self, total_visible: usize) {
        // Saturating so a selection already at `usize::MAX` cannot overflow.
        let next = self.selected_index.saturating_add(1);
        if next < total_visible {
            self.selected_index = next;
        }
    }

    /// Ensure selection is visible given viewport height.
    ///
    /// Scrolls up so the selection becomes the first row when it is above the
    /// viewport, or down so it becomes the last row when it is below. With a
    /// `viewport_height` of zero nothing can be visible; the scroll offset
    /// then ends up one past the selection.
    ///
    /// `scroll` is a `u16`, so offsets beyond `u16::MAX` are clamped to
    /// `u16::MAX` rather than silently wrapping around.
    pub fn ensure_visible(&mut self, viewport_height: usize) {
        let top = usize::from(self.scroll);
        if self.selected_index < top {
            self.scroll = Self::clamp_scroll(self.selected_index);
        } else if self.selected_index >= top.saturating_add(viewport_height) {
            // In this branch `selected_index >= viewport_height`, so the
            // subtraction cannot underflow.
            let new_top = (self.selected_index - viewport_height).saturating_add(1);
            self.scroll = Self::clamp_scroll(new_top);
        }
    }

    /// Convert a row index into a scroll offset, saturating at `u16::MAX`.
    fn clamp_scroll(row: usize) -> u16 {
        // Lossless cast: the value has just been capped at `u16::MAX`.
        row.min(usize::from(u16::MAX)) as u16
    }
}
146
147/// Style configuration for tree view
148#[derive(Debug, Clone)]
149pub struct TreeStyle {
150    /// Style for selected items
151    pub selected_style: Style,
152    /// Style for normal items
153    pub normal_style: Style,
154    /// Style for tree connectors
155    pub connector_style: Style,
156    /// Style for expand/collapse icons
157    pub icon_style: Style,
158    /// Collapsed icon
159    pub collapsed_icon: &'static str,
160    /// Expanded icon
161    pub expanded_icon: &'static str,
162    /// Tree connector: branch (has siblings after)
163    pub connector_branch: &'static str,
164    /// Tree connector: last (no siblings after)
165    pub connector_last: &'static str,
166    /// Tree connector: vertical line
167    pub connector_vertical: &'static str,
168    /// Tree connector: empty space
169    pub connector_space: &'static str,
170    /// Selection cursor for selected item
171    pub cursor_selected: &'static str,
172    /// Selection cursor for non-selected items
173    pub cursor_normal: &'static str,
174}
175
176impl Default for TreeStyle {
177    fn default() -> Self {
178        Self {
179            selected_style: Style::default()
180                .fg(Color::Yellow)
181                .add_modifier(Modifier::BOLD),
182            normal_style: Style::default().fg(Color::White),
183            connector_style: Style::default().fg(Color::DarkGray),
184            icon_style: Style::default().fg(Color::Cyan),
185            collapsed_icon: "▶ ",
186            expanded_icon: "▼ ",
187            connector_branch: "├── ",
188            connector_last: "└── ",
189            connector_vertical: "│   ",
190            connector_space: "    ",
191            cursor_selected: "> ",
192            cursor_normal: "  ",
193        }
194    }
195}
196
197impl From<&crate::theme::Theme> for TreeStyle {
198    fn from(theme: &crate::theme::Theme) -> Self {
199        let p = &theme.palette;
200        Self {
201            selected_style: Style::default().fg(p.primary).add_modifier(Modifier::BOLD),
202            normal_style: Style::default().fg(p.text),
203            connector_style: Style::default().fg(p.text_disabled),
204            icon_style: Style::default().fg(p.secondary),
205            collapsed_icon: "▶ ",
206            expanded_icon: "▼ ",
207            connector_branch: "├── ",
208            connector_last: "└── ",
209            connector_vertical: "│   ",
210            connector_space: "    ",
211            cursor_selected: "> ",
212            cursor_normal: "  ",
213        }
214    }
215}
216
217impl TreeStyle {
218    /// Create a minimal style without tree connectors
219    pub fn minimal() -> Self {
220        Self {
221            connector_branch: "  ",
222            connector_last: "  ",
223            connector_vertical: "  ",
224            connector_space: "  ",
225            ..Default::default()
226        }
227    }
228}
229
/// Flattened node info for rendering
///
/// One entry per visible tree node, produced in display order when the tree
/// is flattened. It borrows the node and records the position information
/// needed to draw the connectors in front of it.
#[derive(Debug, Clone)]
pub struct FlatNode<'a, T> {
    /// Reference to the original node
    pub node: &'a TreeNode<T>,
    /// Depth in the tree (0 = root)
    pub depth: usize,
    /// Whether this is the last sibling at its level
    pub is_last: bool,
    /// Path of is_last values from root to parent: one entry per ancestor,
    /// outermost first (so its length equals `depth`)
    pub parent_is_last: Vec<bool>,
}
242
243/// Tree view widget
244pub struct TreeView<'a, T, F>
245where
246    F: Fn(&TreeNode<T>, bool) -> String,
247{
248    nodes: &'a [TreeNode<T>],
249    state: &'a TreeViewState,
250    style: TreeStyle,
251    render_fn: F,
252}
253
254impl<'a, T> TreeView<'a, T, fn(&TreeNode<T>, bool) -> String> {
255    /// Create a new tree view with default rendering
256    pub fn new(nodes: &'a [TreeNode<T>], state: &'a TreeViewState) -> Self
257    where
258        T: std::fmt::Debug,
259    {
260        Self {
261            nodes,
262            state,
263            style: TreeStyle::default(),
264            render_fn: |node, _| format!("{:?}", node.id),
265        }
266    }
267}
268
269impl<'a, T, F> TreeView<'a, T, F>
270where
271    F: Fn(&TreeNode<T>, bool) -> String,
272{
273    /// Set the render function for items
274    pub fn render_item<G>(self, render_fn: G) -> TreeView<'a, T, G>
275    where
276        G: Fn(&TreeNode<T>, bool) -> String,
277    {
278        TreeView {
279            nodes: self.nodes,
280            state: self.state,
281            style: self.style,
282            render_fn,
283        }
284    }
285
286    /// Set the style
287    pub fn style(mut self, style: TreeStyle) -> Self {
288        self.style = style;
289        self
290    }
291
292    /// Apply a theme to derive the style
293    pub fn theme(self, theme: &crate::theme::Theme) -> Self {
294        self.style(TreeStyle::from(theme))
295    }
296
297    /// Flatten the tree into a list of visible nodes
298    fn flatten_visible(&self) -> Vec<FlatNode<'a, T>> {
299        let mut result = Vec::new();
300        self.flatten_nodes(self.nodes, 0, &mut result, &[]);
301        result
302    }
303
304    fn flatten_nodes(
305        &self,
306        nodes: &'a [TreeNode<T>],
307        depth: usize,
308        result: &mut Vec<FlatNode<'a, T>>,
309        parent_is_last: &[bool],
310    ) {
311        let count = nodes.len();
312        for (idx, node) in nodes.iter().enumerate() {
313            let is_last = idx == count - 1;
314            result.push(FlatNode {
315                node,
316                depth,
317                is_last,
318                parent_is_last: parent_is_last.to_vec(),
319            });
320
321            // Only recurse into children if not collapsed
322            if node.has_children() && !self.state.is_collapsed(&node.id) {
323                let mut new_parent_is_last = parent_is_last.to_vec();
324                new_parent_is_last.push(is_last);
325                self.flatten_nodes(&node.children, depth + 1, result, &new_parent_is_last);
326            }
327        }
328    }
329
330    /// Get the total number of visible nodes
331    pub fn visible_count(&self) -> usize {
332        self.flatten_visible().len()
333    }
334
335    /// Build the lines for rendering
336    fn build_lines(&self, area: Rect) -> Vec<Line<'static>> {
337        let visible = self.flatten_visible();
338        let mut lines = Vec::new();
339
340        let scroll = self.state.scroll as usize;
341        let viewport_height = area.height as usize;
342
343        for (idx, flat_node) in visible
344            .iter()
345            .enumerate()
346            .skip(scroll)
347            .take(viewport_height)
348        {
349            let is_selected = idx == self.state.selected_index;
350            let mut spans = Vec::new();
351
352            // Selection cursor
353            let cursor = if is_selected {
354                self.style.cursor_selected
355            } else {
356                self.style.cursor_normal
357            };
358            spans.push(Span::styled(
359                cursor.to_string(),
360                if is_selected {
361                    self.style.selected_style
362                } else {
363                    self.style.normal_style
364                },
365            ));
366
367            // Tree connectors
368            for &parent_is_last in flat_node.parent_is_last.iter() {
369                let connector = if parent_is_last {
370                    self.style.connector_space
371                } else {
372                    self.style.connector_vertical
373                };
374                spans.push(Span::styled(
375                    connector.to_string(),
376                    self.style.connector_style,
377                ));
378            }
379
380            // Branch connector for this node (if not root)
381            if flat_node.depth > 0 {
382                let connector = if flat_node.is_last {
383                    self.style.connector_last
384                } else {
385                    self.style.connector_branch
386                };
387                spans.push(Span::styled(
388                    connector.to_string(),
389                    self.style.connector_style,
390                ));
391            }
392
393            // Expand/collapse icon (if has children)
394            if flat_node.node.has_children() {
395                let icon = if self.state.is_collapsed(&flat_node.node.id) {
396                    self.style.collapsed_icon
397                } else {
398                    self.style.expanded_icon
399                };
400                spans.push(Span::styled(icon.to_string(), self.style.icon_style));
401            }
402
403            // Node content
404            let content = (self.render_fn)(flat_node.node, is_selected);
405            spans.push(Span::styled(
406                content,
407                if is_selected {
408                    self.style.selected_style
409                } else {
410                    self.style.normal_style
411                },
412            ));
413
414            lines.push(Line::from(spans));
415        }
416
417        lines
418    }
419}
420
421impl<'a, T, F> Widget for TreeView<'a, T, F>
422where
423    F: Fn(&TreeNode<T>, bool) -> String,
424{
425    fn render(self, area: Rect, buf: &mut Buffer) {
426        let lines = self.build_lines(area);
427        let paragraph = Paragraph::new(lines).wrap(Wrap { trim: false });
428        paragraph.render(area, buf);
429    }
430}
431
432/// Get the selected node ID from a tree view state and nodes
433pub fn get_selected_id<T: std::fmt::Debug>(
434    nodes: &[TreeNode<T>],
435    state: &TreeViewState,
436) -> Option<String> {
437    let tree = TreeView::new(nodes, state);
438    let visible = tree.flatten_visible();
439    visible.get(state.selected_index).map(|f| f.node.id.clone())
440}
441
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal payload type for the trees built in these tests.
    #[derive(Debug, Clone)]
    struct TestItem {
        name: String,
    }

    /// Build a leaf node whose payload is a `TestItem` named `name`.
    fn item_node(id: &str, name: &str) -> TreeNode<TestItem> {
        TreeNode::new(id, TestItem { name: name.into() })
    }

    /// Two roots; the first has two children:
    /// `1 { 1.1, 1.2 }, 2`.
    fn create_test_tree() -> Vec<TreeNode<TestItem>> {
        vec![
            item_node("1", "Root 1").with_children(vec![
                item_node("1.1", "Child 1.1"),
                item_node("1.2", "Child 1.2"),
            ]),
            item_node("2", "Root 2"),
        ]
    }

    /// A single chain four levels deep:
    /// `root -> level1 -> level2 -> level3`.
    fn create_deep_tree() -> Vec<TreeNode<TestItem>> {
        let level3 = item_node("level3", "Level 3");
        let level2 = item_node("level2", "Level 2").with_children(vec![level3]);
        let level1 = item_node("level1", "Level 1").with_children(vec![level2]);
        vec![item_node("root", "Root").with_children(vec![level1])]
    }

    #[test]
    fn test_tree_node_new() {
        let node = item_node("test-id", "Test");
        assert_eq!(node.id, "test-id");
        assert_eq!(node.data.name, "Test");
        assert!(node.children.is_empty());
    }

    #[test]
    fn test_tree_node_with_children() {
        let node = item_node("parent", "Parent").with_children(vec![
            item_node("child1", "Child 1"),
            item_node("child2", "Child 2"),
        ]);
        assert_eq!(node.children.len(), 2);
    }

    #[test]
    fn test_tree_node_has_children() {
        let leaf = item_node("leaf", "Leaf");
        assert!(!leaf.has_children());

        let parent = item_node("parent", "Parent").with_children(vec![leaf]);
        assert!(parent.has_children());
    }

    #[test]
    fn test_tree_state_new() {
        let state = TreeViewState::new();
        assert_eq!(state.selected_index, 0);
        assert!(state.collapsed.is_empty());
    }

    #[test]
    fn test_tree_state() {
        let mut state = TreeViewState::new();
        assert!(!state.is_collapsed("1"));

        state.collapse("1");
        assert!(state.is_collapsed("1"));

        state.toggle_collapsed("1");
        assert!(!state.is_collapsed("1"));
    }

    #[test]
    fn test_tree_state_expand() {
        let mut state = TreeViewState::new();
        state.collapse("node1");
        state.collapse("node2");

        assert!(state.is_collapsed("node1"));
        state.expand("node1");
        assert!(!state.is_collapsed("node1"));
        assert!(state.is_collapsed("node2"));
    }

    #[test]
    fn test_tree_state_collapse_multiple() {
        let mut state = TreeViewState::new();

        state.collapse("1");
        state.collapse("2");
        assert!(state.is_collapsed("1"));
        assert!(state.is_collapsed("2"));

        state.expand("1");
        state.expand("2");
        assert!(!state.is_collapsed("1"));
        assert!(!state.is_collapsed("2"));
    }

    #[test]
    fn test_tree_state_navigation() {
        let mut state = TreeViewState::new();
        assert_eq!(state.selected_index, 0);

        state.select_next(5);
        assert_eq!(state.selected_index, 1);

        state.select_next(5);
        state.select_next(5);
        state.select_next(5);
        assert_eq!(state.selected_index, 4);

        state.select_next(5); // At max, should not increase
        assert_eq!(state.selected_index, 4);

        state.select_prev();
        assert_eq!(state.selected_index, 3);

        state.select_prev();
        state.select_prev();
        state.select_prev();
        state.select_prev(); // At min, should not decrease
        assert_eq!(state.selected_index, 0);
    }

    #[test]
    fn test_tree_state_ensure_visible() {
        let mut state = TreeViewState::new();

        // Selection below the viewport: it becomes the last visible row.
        state.selected_index = 15;
        state.scroll = 5;
        state.ensure_visible(10);
        assert_eq!(state.scroll, 6); // 15 - 10 + 1 = 6

        // Selection above the viewport: it becomes the first visible row.
        state.selected_index = 2;
        state.scroll = 10;
        state.ensure_visible(10);
        assert_eq!(state.scroll, 2);
    }

    #[test]
    fn test_tree_state_ensure_visible_zero_viewport() {
        let mut state = TreeViewState::new();
        state.scroll = 5;
        state.selected_index = 10;
        state.ensure_visible(0);
        // With viewport 0, condition (10 >= 5 + 0) is true, so scroll updates
        assert_eq!(state.scroll, 11); // selected_index - 0 + 1
    }

    #[test]
    fn test_flatten_visible() {
        let nodes = create_test_tree();
        let state = TreeViewState::new();
        let tree = TreeView::new(&nodes, &state);

        let visible = tree.flatten_visible();
        assert_eq!(visible.len(), 4); // Root1, Child1.1, Child1.2, Root2
    }

    #[test]
    fn test_flatten_with_collapsed() {
        let nodes = create_test_tree();
        let mut state = TreeViewState::new();
        state.collapse("1");

        let tree = TreeView::new(&nodes, &state);
        let visible = tree.flatten_visible();
        assert_eq!(visible.len(), 2); // Root1 (collapsed), Root2
    }

    #[test]
    fn test_flatten_deep_tree() {
        let nodes = create_deep_tree();
        let state = TreeViewState::new();
        let tree = TreeView::new(&nodes, &state);

        let visible = tree.flatten_visible();
        assert_eq!(visible.len(), 4); // root, level1, level2, level3

        // Check depth levels
        assert_eq!(visible[0].depth, 0);
        assert_eq!(visible[1].depth, 1);
        assert_eq!(visible[2].depth, 2);
        assert_eq!(visible[3].depth, 3);
    }

    #[test]
    fn test_visible_count() {
        let nodes = create_test_tree();
        let state = TreeViewState::new();
        let tree = TreeView::new(&nodes, &state);
        assert_eq!(tree.visible_count(), 4);

        let mut collapsed_state = TreeViewState::new();
        collapsed_state.collapse("1");
        let collapsed_tree = TreeView::new(&nodes, &collapsed_state);
        assert_eq!(collapsed_tree.visible_count(), 2);
    }

    #[test]
    fn test_selection_navigation() {
        let nodes = create_test_tree();
        let mut state = TreeViewState::new();
        // The view only borrows `state` for this expression, so the state
        // can be mutated afterwards.
        let count = TreeView::new(&nodes, &state).visible_count();

        assert_eq!(state.selected_index, 0);
        state.select_next(count);
        assert_eq!(state.selected_index, 1);
        state.select_prev();
        assert_eq!(state.selected_index, 0);
    }

    #[test]
    fn test_get_selected_id() {
        let nodes = create_test_tree();
        let mut state = TreeViewState::new();

        let id = get_selected_id(&nodes, &state);
        assert_eq!(id, Some("1".to_string()));

        state.selected_index = 2;
        let id = get_selected_id(&nodes, &state);
        assert_eq!(id, Some("1.2".to_string()));

        state.selected_index = 3;
        let id = get_selected_id(&nodes, &state);
        assert_eq!(id, Some("2".to_string()));
    }

    #[test]
    fn test_get_selected_id_with_collapsed() {
        let nodes = create_test_tree();
        let mut state = TreeViewState::new();
        state.collapse("1");
        state.selected_index = 1;

        let id = get_selected_id(&nodes, &state);
        assert_eq!(id, Some("2".to_string()));
    }

    #[test]
    fn test_tree_style_default() {
        let style = TreeStyle::default();
        assert_eq!(style.collapsed_icon, "▶ ");
        assert_eq!(style.expanded_icon, "▼ ");
        assert_eq!(style.connector_branch, "├── ");
        assert_eq!(style.connector_last, "└── ");
    }

    #[test]
    fn test_tree_view_render() {
        let nodes = create_test_tree();
        let state = TreeViewState::new();
        let tree = TreeView::new(&nodes, &state)
            .render_item(|node, _| format!("Item: {}", node.data.name));

        let area = Rect::new(0, 0, 40, 10);
        let mut buf = Buffer::empty(area);
        tree.render(area, &mut buf);
        // Should not panic
    }

    #[test]
    fn test_tree_view_with_style() {
        let nodes = create_test_tree();
        let state = TreeViewState::new();
        let custom_style = TreeStyle {
            collapsed_icon: "+",
            expanded_icon: "-",
            ..TreeStyle::default()
        };
        let tree = TreeView::new(&nodes, &state).style(custom_style);

        let area = Rect::new(0, 0, 40, 10);
        let mut buf = Buffer::empty(area);
        tree.render(area, &mut buf);
        // Should not panic
    }

    #[test]
    fn test_empty_tree() {
        let nodes: Vec<TreeNode<TestItem>> = vec![];
        let state = TreeViewState::new();
        let tree = TreeView::new(&nodes, &state);

        assert_eq!(tree.visible_count(), 0);
        assert!(tree.flatten_visible().is_empty());
    }
}