// ratatui_interact/components/tree_view.rs

//! Tree view widget
//!
//! A collapsible tree view with keyboard-style selection, expand/collapse icons,
//! and customizable item rendering.
//!
//! # Example
//!
//! ```rust
//! use ratatui_interact::components::{TreeNode, TreeView, TreeViewState};
//!
//! // Define your tree node type
//! #[derive(Clone, Debug)]
//! struct Task {
//!     id: String,
//!     name: String,
//!     status: &'static str,
//! }
//!
//! // Create nodes
//! let nodes = vec![
//!     TreeNode::new("1", Task { id: "1".into(), name: "Root".into(), status: "pending" })
//!         .with_children(vec![
//!             TreeNode::new("1.1", Task { id: "1.1".into(), name: "Child 1".into(), status: "done" }),
//!             TreeNode::new("1.2", Task { id: "1.2".into(), name: "Child 2".into(), status: "running" }),
//!         ]),
//! ];
//!
//! // Create state and view
//! let state = TreeViewState::new();
//! let tree = TreeView::new(&nodes, &state)
//!     .render_item(|node, _is_selected| {
//!         format!("{} [{}]", node.data.name, node.data.status)
//!     });
//! ```

36use std::collections::HashSet;
37
38use ratatui::{
39    buffer::Buffer,
40    layout::Rect,
41    style::{Color, Modifier, Style},
42    text::{Line, Span},
43    widgets::{Paragraph, Widget, Wrap},
44};
45
/// One node of a tree: an identifier, a user payload, and an ordered list of
/// child nodes (empty for a leaf).
///
/// [`TreeViewState`] remembers collapsed nodes by `id`, so ids should be unique
/// within a tree; nodes sharing an id also share their collapsed state.
#[derive(Debug, Clone)]
pub struct TreeNode<T> {
    /// Identifier used to track this node's collapsed state; expected to be unique.
    pub id: String,
    /// User data attached to this node (reachable from the item render function).
    pub data: T,
    /// Direct descendants, in the order they are displayed.
    pub children: Vec<TreeNode<T>>,
}

57impl<T> TreeNode<T> {
58    /// Create a new tree node
59    pub fn new(id: impl Into<String>, data: T) -> Self {
60        Self {
61            id: id.into(),
62            data,
63            children: Vec::new(),
64        }
65    }
66
67    /// Add children to this node
68    pub fn with_children(mut self, children: Vec<TreeNode<T>>) -> Self {
69        self.children = children;
70        self
71    }
72
73    /// Add a single child to this node
74    pub fn add_child(&mut self, child: TreeNode<T>) {
75        self.children.push(child);
76    }
77
78    /// Check if this node has children
79    pub fn has_children(&self) -> bool {
80        !self.children.is_empty()
81    }
82}
83
/// Interaction state for a [`TreeView`]: which nodes are collapsed, which
/// visible row is selected, and how many rows are scrolled off the top.
#[derive(Debug, Clone, Default)]
pub struct TreeViewState {
    /// IDs of nodes whose children are currently hidden.
    pub collapsed: HashSet<String>,
    /// Index of the selected row within the flattened list of visible nodes.
    pub selected_index: usize,
    /// Number of visible rows skipped before the first rendered row.
    pub scroll: u16,
}

95impl TreeViewState {
96    /// Create a new tree view state
97    pub fn new() -> Self {
98        Self::default()
99    }
100
101    /// Toggle the collapsed state of a node
102    pub fn toggle_collapsed(&mut self, id: &str) {
103        if self.collapsed.contains(id) {
104            self.collapsed.remove(id);
105        } else {
106            self.collapsed.insert(id.to_string());
107        }
108    }
109
110    /// Check if a node is collapsed
111    pub fn is_collapsed(&self, id: &str) -> bool {
112        self.collapsed.contains(id)
113    }
114
115    /// Collapse a node
116    pub fn collapse(&mut self, id: &str) {
117        self.collapsed.insert(id.to_string());
118    }
119
120    /// Expand a node
121    pub fn expand(&mut self, id: &str) {
122        self.collapsed.remove(id);
123    }
124
125    /// Move selection up
126    pub fn select_prev(&mut self) {
127        self.selected_index = self.selected_index.saturating_sub(1);
128    }
129
130    /// Move selection down (needs total count)
131    pub fn select_next(&mut self, total_visible: usize) {
132        if self.selected_index + 1 < total_visible {
133            self.selected_index += 1;
134        }
135    }
136
137    /// Ensure selection is visible given viewport height
138    pub fn ensure_visible(&mut self, viewport_height: usize) {
139        if self.selected_index < self.scroll as usize {
140            self.scroll = self.selected_index as u16;
141        } else if self.selected_index >= self.scroll as usize + viewport_height {
142            self.scroll = (self.selected_index - viewport_height + 1) as u16;
143        }
144    }
145}
146
147/// Style configuration for tree view
148#[derive(Debug, Clone)]
149pub struct TreeStyle {
150    /// Style for selected items
151    pub selected_style: Style,
152    /// Style for normal items
153    pub normal_style: Style,
154    /// Style for tree connectors
155    pub connector_style: Style,
156    /// Style for expand/collapse icons
157    pub icon_style: Style,
158    /// Collapsed icon
159    pub collapsed_icon: &'static str,
160    /// Expanded icon
161    pub expanded_icon: &'static str,
162    /// Tree connector: branch (has siblings after)
163    pub connector_branch: &'static str,
164    /// Tree connector: last (no siblings after)
165    pub connector_last: &'static str,
166    /// Tree connector: vertical line
167    pub connector_vertical: &'static str,
168    /// Tree connector: empty space
169    pub connector_space: &'static str,
170    /// Selection cursor for selected item
171    pub cursor_selected: &'static str,
172    /// Selection cursor for non-selected items
173    pub cursor_normal: &'static str,
174}
175
176impl Default for TreeStyle {
177    fn default() -> Self {
178        Self {
179            selected_style: Style::default()
180                .fg(Color::Yellow)
181                .add_modifier(Modifier::BOLD),
182            normal_style: Style::default().fg(Color::White),
183            connector_style: Style::default().fg(Color::DarkGray),
184            icon_style: Style::default().fg(Color::Cyan),
185            collapsed_icon: "▶ ",
186            expanded_icon: "▼ ",
187            connector_branch: "├── ",
188            connector_last: "└── ",
189            connector_vertical: "│   ",
190            connector_space: "    ",
191            cursor_selected: "> ",
192            cursor_normal: "  ",
193        }
194    }
195}
196
197impl TreeStyle {
198    /// Create a minimal style without tree connectors
199    pub fn minimal() -> Self {
200        Self {
201            connector_branch: "  ",
202            connector_last: "  ",
203            connector_vertical: "  ",
204            connector_space: "  ",
205            ..Default::default()
206        }
207    }
208}
209
210/// Flattened node info for rendering
211#[derive(Debug, Clone)]
212pub struct FlatNode<'a, T> {
213    /// Reference to the original node
214    pub node: &'a TreeNode<T>,
215    /// Depth in the tree (0 = root)
216    pub depth: usize,
217    /// Whether this is the last sibling at its level
218    pub is_last: bool,
219    /// Path of is_last values from root to parent
220    pub parent_is_last: Vec<bool>,
221}
222
223/// Tree view widget
224pub struct TreeView<'a, T, F>
225where
226    F: Fn(&TreeNode<T>, bool) -> String,
227{
228    nodes: &'a [TreeNode<T>],
229    state: &'a TreeViewState,
230    style: TreeStyle,
231    render_fn: F,
232}
233
234impl<'a, T> TreeView<'a, T, fn(&TreeNode<T>, bool) -> String> {
235    /// Create a new tree view with default rendering
236    pub fn new(nodes: &'a [TreeNode<T>], state: &'a TreeViewState) -> Self
237    where
238        T: std::fmt::Debug,
239    {
240        Self {
241            nodes,
242            state,
243            style: TreeStyle::default(),
244            render_fn: |node, _| format!("{:?}", node.id),
245        }
246    }
247}
248
249impl<'a, T, F> TreeView<'a, T, F>
250where
251    F: Fn(&TreeNode<T>, bool) -> String,
252{
253    /// Set the render function for items
254    pub fn render_item<G>(self, render_fn: G) -> TreeView<'a, T, G>
255    where
256        G: Fn(&TreeNode<T>, bool) -> String,
257    {
258        TreeView {
259            nodes: self.nodes,
260            state: self.state,
261            style: self.style,
262            render_fn,
263        }
264    }
265
266    /// Set the style
267    pub fn style(mut self, style: TreeStyle) -> Self {
268        self.style = style;
269        self
270    }
271
272    /// Flatten the tree into a list of visible nodes
273    fn flatten_visible(&self) -> Vec<FlatNode<'a, T>> {
274        let mut result = Vec::new();
275        self.flatten_nodes(self.nodes, 0, &mut result, &[]);
276        result
277    }
278
279    fn flatten_nodes(
280        &self,
281        nodes: &'a [TreeNode<T>],
282        depth: usize,
283        result: &mut Vec<FlatNode<'a, T>>,
284        parent_is_last: &[bool],
285    ) {
286        let count = nodes.len();
287        for (idx, node) in nodes.iter().enumerate() {
288            let is_last = idx == count - 1;
289            result.push(FlatNode {
290                node,
291                depth,
292                is_last,
293                parent_is_last: parent_is_last.to_vec(),
294            });
295
296            // Only recurse into children if not collapsed
297            if node.has_children() && !self.state.is_collapsed(&node.id) {
298                let mut new_parent_is_last = parent_is_last.to_vec();
299                new_parent_is_last.push(is_last);
300                self.flatten_nodes(&node.children, depth + 1, result, &new_parent_is_last);
301            }
302        }
303    }
304
305    /// Get the total number of visible nodes
306    pub fn visible_count(&self) -> usize {
307        self.flatten_visible().len()
308    }
309
310    /// Build the lines for rendering
311    fn build_lines(&self, area: Rect) -> Vec<Line<'static>> {
312        let visible = self.flatten_visible();
313        let mut lines = Vec::new();
314
315        let scroll = self.state.scroll as usize;
316        let viewport_height = area.height as usize;
317
318        for (idx, flat_node) in visible
319            .iter()
320            .enumerate()
321            .skip(scroll)
322            .take(viewport_height)
323        {
324            let is_selected = idx == self.state.selected_index;
325            let mut spans = Vec::new();
326
327            // Selection cursor
328            let cursor = if is_selected {
329                self.style.cursor_selected
330            } else {
331                self.style.cursor_normal
332            };
333            spans.push(Span::styled(
334                cursor.to_string(),
335                if is_selected {
336                    self.style.selected_style
337                } else {
338                    self.style.normal_style
339                },
340            ));
341
342            // Tree connectors
343            for &parent_is_last in flat_node.parent_is_last.iter() {
344                let connector = if parent_is_last {
345                    self.style.connector_space
346                } else {
347                    self.style.connector_vertical
348                };
349                spans.push(Span::styled(
350                    connector.to_string(),
351                    self.style.connector_style,
352                ));
353            }
354
355            // Branch connector for this node (if not root)
356            if flat_node.depth > 0 {
357                let connector = if flat_node.is_last {
358                    self.style.connector_last
359                } else {
360                    self.style.connector_branch
361                };
362                spans.push(Span::styled(
363                    connector.to_string(),
364                    self.style.connector_style,
365                ));
366            }
367
368            // Expand/collapse icon (if has children)
369            if flat_node.node.has_children() {
370                let icon = if self.state.is_collapsed(&flat_node.node.id) {
371                    self.style.collapsed_icon
372                } else {
373                    self.style.expanded_icon
374                };
375                spans.push(Span::styled(icon.to_string(), self.style.icon_style));
376            }
377
378            // Node content
379            let content = (self.render_fn)(flat_node.node, is_selected);
380            spans.push(Span::styled(
381                content,
382                if is_selected {
383                    self.style.selected_style
384                } else {
385                    self.style.normal_style
386                },
387            ));
388
389            lines.push(Line::from(spans));
390        }
391
392        lines
393    }
394}
395
396impl<'a, T, F> Widget for TreeView<'a, T, F>
397where
398    F: Fn(&TreeNode<T>, bool) -> String,
399{
400    fn render(self, area: Rect, buf: &mut Buffer) {
401        let lines = self.build_lines(area);
402        let paragraph = Paragraph::new(lines).wrap(Wrap { trim: false });
403        paragraph.render(area, buf);
404    }
405}
406
407/// Get the selected node ID from a tree view state and nodes
408pub fn get_selected_id<T: std::fmt::Debug>(
409    nodes: &[TreeNode<T>],
410    state: &TreeViewState,
411) -> Option<String> {
412    let tree = TreeView::new(nodes, state);
413    let visible = tree.flatten_visible();
414    visible.get(state.selected_index).map(|f| f.node.id.clone())
415}
416
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone)]
    struct TestItem {
        name: String,
    }

    /// Shorthand for a childless node whose payload is a `TestItem` named `name`.
    fn leaf(id: &str, name: &str) -> TreeNode<TestItem> {
        TreeNode::new(id, TestItem { name: name.into() })
    }

    /// Two roots, the first with two leaf children: `1 { 1.1, 1.2 }`, `2`.
    fn create_test_tree() -> Vec<TreeNode<TestItem>> {
        vec![
            leaf("1", "Root 1").with_children(vec![
                leaf("1.1", "Child 1.1"),
                leaf("1.2", "Child 1.2"),
            ]),
            leaf("2", "Root 2"),
        ]
    }

    /// A single chain four levels deep: `root > level1 > level2 > level3`.
    fn create_deep_tree() -> Vec<TreeNode<TestItem>> {
        let level3 = leaf("level3", "Level 3");
        let level2 = leaf("level2", "Level 2").with_children(vec![level3]);
        let level1 = leaf("level1", "Level 1").with_children(vec![level2]);
        vec![leaf("root", "Root").with_children(vec![level1])]
    }

    #[test]
    fn test_tree_node_new() {
        let node = leaf("test-id", "Test");
        assert_eq!(node.id, "test-id");
        assert_eq!(node.data.name, "Test");
        assert!(node.children.is_empty());
    }

    #[test]
    fn test_tree_node_with_children() {
        let node = leaf("parent", "Parent")
            .with_children(vec![leaf("child1", "Child 1"), leaf("child2", "Child 2")]);
        assert_eq!(node.children.len(), 2);
    }

    #[test]
    fn test_tree_node_has_children() {
        let lone = leaf("leaf", "Leaf");
        assert!(!lone.has_children());

        let parent = leaf("parent", "Parent").with_children(vec![lone.clone()]);
        assert!(parent.has_children());
    }

    #[test]
    fn test_tree_state_new() {
        let state = TreeViewState::new();
        assert_eq!(state.selected_index, 0);
        assert!(state.collapsed.is_empty());
    }

    #[test]
    fn test_tree_state() {
        let mut state = TreeViewState::new();
        assert!(!state.is_collapsed("1"));

        state.collapse("1");
        assert!(state.is_collapsed("1"));

        state.toggle_collapsed("1");
        assert!(!state.is_collapsed("1"));
    }

    #[test]
    fn test_tree_state_expand() {
        let mut state = TreeViewState::new();
        for id in ["node1", "node2"] {
            state.collapse(id);
        }

        assert!(state.is_collapsed("node1"));
        state.expand("node1");
        assert!(!state.is_collapsed("node1"));
        assert!(state.is_collapsed("node2"));
    }

    #[test]
    fn test_tree_state_collapse_multiple() {
        let mut state = TreeViewState::new();
        let ids = ["1", "2"];

        ids.iter().for_each(|id| state.collapse(id));
        assert!(ids.iter().all(|id| state.is_collapsed(id)));

        ids.iter().for_each(|id| state.expand(id));
        assert!(ids.iter().all(|id| !state.is_collapsed(id)));
    }

    #[test]
    fn test_tree_state_navigation() {
        let mut state = TreeViewState::new();
        assert_eq!(state.selected_index, 0);

        state.select_next(5);
        assert_eq!(state.selected_index, 1);

        for _ in 0..3 {
            state.select_next(5);
        }
        assert_eq!(state.selected_index, 4);

        state.select_next(5); // already on the last row: must stay put
        assert_eq!(state.selected_index, 4);

        state.select_prev();
        assert_eq!(state.selected_index, 3);

        for _ in 0..4 {
            state.select_prev(); // the final call would go below 0 and must clamp
        }
        assert_eq!(state.selected_index, 0);
    }

    #[test]
    fn test_tree_state_ensure_visible() {
        let mut state = TreeViewState {
            selected_index: 15,
            scroll: 5,
            ..TreeViewState::new()
        };
        state.ensure_visible(10);
        assert!(state.scroll >= 6); // 15 - 10 + 1 = 6

        state.selected_index = 2;
        state.scroll = 10;
        state.ensure_visible(10);
        assert_eq!(state.scroll, 2);
    }

    #[test]
    fn test_tree_state_ensure_visible_zero_viewport() {
        let mut state = TreeViewState {
            scroll: 5,
            selected_index: 10,
            ..TreeViewState::new()
        };
        state.ensure_visible(0);
        // An empty viewport always puts the selection "below" the view, so the
        // scroll lands at selected_index - 0 + 1.
        assert_eq!(state.scroll, 11);
    }

    #[test]
    fn test_flatten_visible() {
        let nodes = create_test_tree();
        let state = TreeViewState::new();
        let visible = TreeView::new(&nodes, &state).flatten_visible();
        assert_eq!(visible.len(), 4); // Root1, Child1.1, Child1.2, Root2
    }

    #[test]
    fn test_flatten_with_collapsed() {
        let nodes = create_test_tree();
        let mut state = TreeViewState::new();
        state.collapse("1");

        let visible = TreeView::new(&nodes, &state).flatten_visible();
        assert_eq!(visible.len(), 2); // Root1 (collapsed), Root2
    }

    #[test]
    fn test_flatten_deep_tree() {
        let nodes = create_deep_tree();
        let state = TreeViewState::new();
        let visible = TreeView::new(&nodes, &state).flatten_visible();

        assert_eq!(visible.len(), 4); // root, level1, level2, level3
        let depths: Vec<usize> = visible.iter().map(|row| row.depth).collect();
        assert_eq!(depths, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_visible_count() {
        let nodes = create_test_tree();

        let expanded = TreeViewState::new();
        assert_eq!(TreeView::new(&nodes, &expanded).visible_count(), 4);

        let mut collapsed = TreeViewState::new();
        collapsed.collapse("1");
        assert_eq!(TreeView::new(&nodes, &collapsed).visible_count(), 2);
    }

    #[test]
    fn test_selection_navigation() {
        let nodes = create_test_tree();
        let mut state = TreeViewState::new();
        let count = TreeView::new(&nodes, &state).visible_count();

        assert_eq!(state.selected_index, 0);
        state.select_next(count);
        assert_eq!(state.selected_index, 1);
        state.select_prev();
        assert_eq!(state.selected_index, 0);
    }

    #[test]
    fn test_get_selected_id() {
        let nodes = create_test_tree();
        let mut state = TreeViewState::new();

        for (index, expected) in [(0, "1"), (2, "1.2"), (3, "2")] {
            state.selected_index = index;
            assert_eq!(get_selected_id(&nodes, &state), Some(expected.to_string()));
        }
    }

    #[test]
    fn test_get_selected_id_with_collapsed() {
        let nodes = create_test_tree();
        let mut state = TreeViewState::new();
        state.collapse("1");
        state.selected_index = 1;

        assert_eq!(get_selected_id(&nodes, &state), Some("2".to_string()));
    }

    #[test]
    fn test_tree_style_default() {
        let style = TreeStyle::default();
        assert_eq!(style.collapsed_icon, "▶ ");
        assert_eq!(style.expanded_icon, "▼ ");
        assert_eq!(style.connector_branch, "├── ");
        assert_eq!(style.connector_last, "└── ");
    }

    #[test]
    fn test_tree_view_render() {
        let nodes = create_test_tree();
        let state = TreeViewState::new();
        let area = Rect::new(0, 0, 40, 10);
        let mut buf = Buffer::empty(area);

        TreeView::new(&nodes, &state)
            .render_item(|node, _| format!("Item: {}", node.data.name))
            .render(area, &mut buf);
        // Rendering simply must not panic.
    }

    #[test]
    fn test_tree_view_with_style() {
        let nodes = create_test_tree();
        let state = TreeViewState::new();
        let area = Rect::new(0, 0, 40, 10);
        let mut buf = Buffer::empty(area);

        let custom_style = TreeStyle {
            collapsed_icon: "+",
            expanded_icon: "-",
            ..TreeStyle::default()
        };
        TreeView::new(&nodes, &state)
            .style(custom_style)
            .render(area, &mut buf);
    }

    #[test]
    fn test_empty_tree() {
        let nodes: Vec<TreeNode<TestItem>> = Vec::new();
        let state = TreeViewState::new();
        let tree = TreeView::new(&nodes, &state);

        assert_eq!(tree.visible_count(), 0);
        assert!(tree.flatten_visible().is_empty());
    }
}