use std::collections::HashSet;
use ratatui::{
buffer::Buffer,
layout::Rect,
style::{Color, Modifier, Style},
text::{Line, Span},
widgets::{Paragraph, Widget, Wrap},
};
#[derive(Debug, Clone)]
pub struct TreeNode<T> {
pub id: String,
pub data: T,
pub children: Vec<TreeNode<T>>,
}
impl<T> TreeNode<T> {
pub fn new(id: impl Into<String>, data: T) -> Self {
Self {
id: id.into(),
data,
children: Vec::new(),
}
}
pub fn with_children(mut self, children: Vec<TreeNode<T>>) -> Self {
self.children = children;
self
}
pub fn add_child(&mut self, child: TreeNode<T>) {
self.children.push(child);
}
pub fn has_children(&self) -> bool {
!self.children.is_empty()
}
}
#[derive(Debug, Clone, Default)]
pub struct TreeViewState {
pub collapsed: HashSet<String>,
pub selected_index: usize,
pub scroll: u16,
}
impl TreeViewState {
pub fn new() -> Self {
Self::default()
}
pub fn toggle_collapsed(&mut self, id: &str) {
if self.collapsed.contains(id) {
self.collapsed.remove(id);
} else {
self.collapsed.insert(id.to_string());
}
}
pub fn is_collapsed(&self, id: &str) -> bool {
self.collapsed.contains(id)
}
pub fn collapse(&mut self, id: &str) {
self.collapsed.insert(id.to_string());
}
pub fn expand(&mut self, id: &str) {
self.collapsed.remove(id);
}
pub fn select_prev(&mut self) {
self.selected_index = self.selected_index.saturating_sub(1);
}
pub fn select_next(&mut self, total_visible: usize) {
if self.selected_index + 1 < total_visible {
self.selected_index += 1;
}
}
pub fn ensure_visible(&mut self, viewport_height: usize) {
if self.selected_index < self.scroll as usize {
self.scroll = self.selected_index as u16;
} else if self.selected_index >= self.scroll as usize + viewport_height {
self.scroll = (self.selected_index - viewport_height + 1) as u16;
}
}
}
#[derive(Debug, Clone)]
pub struct TreeStyle {
pub selected_style: Style,
pub normal_style: Style,
pub connector_style: Style,
pub icon_style: Style,
pub collapsed_icon: &'static str,
pub expanded_icon: &'static str,
pub connector_branch: &'static str,
pub connector_last: &'static str,
pub connector_vertical: &'static str,
pub connector_space: &'static str,
pub cursor_selected: &'static str,
pub cursor_normal: &'static str,
}
impl Default for TreeStyle {
fn default() -> Self {
Self {
selected_style: Style::default()
.fg(Color::Yellow)
.add_modifier(Modifier::BOLD),
normal_style: Style::default().fg(Color::White),
connector_style: Style::default().fg(Color::DarkGray),
icon_style: Style::default().fg(Color::Cyan),
collapsed_icon: "▶ ",
expanded_icon: "▼ ",
connector_branch: "├── ",
connector_last: "└── ",
connector_vertical: "│ ",
connector_space: " ",
cursor_selected: "> ",
cursor_normal: " ",
}
}
}
impl From<&crate::theme::Theme> for TreeStyle {
fn from(theme: &crate::theme::Theme) -> Self {
let p = &theme.palette;
Self {
selected_style: Style::default().fg(p.primary).add_modifier(Modifier::BOLD),
normal_style: Style::default().fg(p.text),
connector_style: Style::default().fg(p.text_disabled),
icon_style: Style::default().fg(p.secondary),
collapsed_icon: "▶ ",
expanded_icon: "▼ ",
connector_branch: "├── ",
connector_last: "└── ",
connector_vertical: "│ ",
connector_space: " ",
cursor_selected: "> ",
cursor_normal: " ",
}
}
}
impl TreeStyle {
pub fn minimal() -> Self {
Self {
connector_branch: " ",
connector_last: " ",
connector_vertical: " ",
connector_space: " ",
..Default::default()
}
}
}
#[derive(Debug, Clone)]
pub struct FlatNode<'a, T> {
pub node: &'a TreeNode<T>,
pub depth: usize,
pub is_last: bool,
pub parent_is_last: Vec<bool>,
}
pub struct TreeView<'a, T, F>
where
F: Fn(&TreeNode<T>, bool) -> String,
{
nodes: &'a [TreeNode<T>],
state: &'a TreeViewState,
style: TreeStyle,
render_fn: F,
}
impl<'a, T> TreeView<'a, T, fn(&TreeNode<T>, bool) -> String> {
pub fn new(nodes: &'a [TreeNode<T>], state: &'a TreeViewState) -> Self
where
T: std::fmt::Debug,
{
Self {
nodes,
state,
style: TreeStyle::default(),
render_fn: |node, _| format!("{:?}", node.id),
}
}
}
impl<'a, T, F> TreeView<'a, T, F>
where
F: Fn(&TreeNode<T>, bool) -> String,
{
pub fn render_item<G>(self, render_fn: G) -> TreeView<'a, T, G>
where
G: Fn(&TreeNode<T>, bool) -> String,
{
TreeView {
nodes: self.nodes,
state: self.state,
style: self.style,
render_fn,
}
}
pub fn style(mut self, style: TreeStyle) -> Self {
self.style = style;
self
}
pub fn theme(self, theme: &crate::theme::Theme) -> Self {
self.style(TreeStyle::from(theme))
}
fn flatten_visible(&self) -> Vec<FlatNode<'a, T>> {
let mut result = Vec::new();
self.flatten_nodes(self.nodes, 0, &mut result, &[]);
result
}
fn flatten_nodes(
&self,
nodes: &'a [TreeNode<T>],
depth: usize,
result: &mut Vec<FlatNode<'a, T>>,
parent_is_last: &[bool],
) {
let count = nodes.len();
for (idx, node) in nodes.iter().enumerate() {
let is_last = idx == count - 1;
result.push(FlatNode {
node,
depth,
is_last,
parent_is_last: parent_is_last.to_vec(),
});
if node.has_children() && !self.state.is_collapsed(&node.id) {
let mut new_parent_is_last = parent_is_last.to_vec();
new_parent_is_last.push(is_last);
self.flatten_nodes(&node.children, depth + 1, result, &new_parent_is_last);
}
}
}
pub fn visible_count(&self) -> usize {
self.flatten_visible().len()
}
fn build_lines(&self, area: Rect) -> Vec<Line<'static>> {
let visible = self.flatten_visible();
let mut lines = Vec::new();
let scroll = self.state.scroll as usize;
let viewport_height = area.height as usize;
for (idx, flat_node) in visible
.iter()
.enumerate()
.skip(scroll)
.take(viewport_height)
{
let is_selected = idx == self.state.selected_index;
let mut spans = Vec::new();
let cursor = if is_selected {
self.style.cursor_selected
} else {
self.style.cursor_normal
};
spans.push(Span::styled(
cursor.to_string(),
if is_selected {
self.style.selected_style
} else {
self.style.normal_style
},
));
for &parent_is_last in flat_node.parent_is_last.iter() {
let connector = if parent_is_last {
self.style.connector_space
} else {
self.style.connector_vertical
};
spans.push(Span::styled(
connector.to_string(),
self.style.connector_style,
));
}
if flat_node.depth > 0 {
let connector = if flat_node.is_last {
self.style.connector_last
} else {
self.style.connector_branch
};
spans.push(Span::styled(
connector.to_string(),
self.style.connector_style,
));
}
if flat_node.node.has_children() {
let icon = if self.state.is_collapsed(&flat_node.node.id) {
self.style.collapsed_icon
} else {
self.style.expanded_icon
};
spans.push(Span::styled(icon.to_string(), self.style.icon_style));
}
let content = (self.render_fn)(flat_node.node, is_selected);
spans.push(Span::styled(
content,
if is_selected {
self.style.selected_style
} else {
self.style.normal_style
},
));
lines.push(Line::from(spans));
}
lines
}
}
impl<'a, T, F> Widget for TreeView<'a, T, F>
where
F: Fn(&TreeNode<T>, bool) -> String,
{
fn render(self, area: Rect, buf: &mut Buffer) {
let lines = self.build_lines(area);
let paragraph = Paragraph::new(lines).wrap(Wrap { trim: false });
paragraph.render(area, buf);
}
}
pub fn get_selected_id<T: std::fmt::Debug>(
nodes: &[TreeNode<T>],
state: &TreeViewState,
) -> Option<String> {
let tree = TreeView::new(nodes, state);
let visible = tree.flatten_visible();
visible.get(state.selected_index).map(|f| f.node.id.clone())
}
#[cfg(test)]
mod tests {
use super::*;
#[derive(Debug, Clone)]
struct TestItem {
name: String,
}
fn create_test_tree() -> Vec<TreeNode<TestItem>> {
vec![
TreeNode::new(
"1",
TestItem {
name: "Root 1".into(),
},
)
.with_children(vec![
TreeNode::new(
"1.1",
TestItem {
name: "Child 1.1".into(),
},
),
TreeNode::new(
"1.2",
TestItem {
name: "Child 1.2".into(),
},
),
]),
TreeNode::new(
"2",
TestItem {
name: "Root 2".into(),
},
),
]
}
fn create_deep_tree() -> Vec<TreeNode<TestItem>> {
vec![
TreeNode::new(
"root",
TestItem {
name: "Root".into(),
},
)
.with_children(vec![
TreeNode::new(
"level1",
TestItem {
name: "Level 1".into(),
},
)
.with_children(vec![
TreeNode::new(
"level2",
TestItem {
name: "Level 2".into(),
},
)
.with_children(vec![TreeNode::new(
"level3",
TestItem {
name: "Level 3".into(),
},
)]),
]),
]),
]
}
#[test]
fn test_tree_node_new() {
let node: TreeNode<TestItem> = TreeNode::new(
"test-id",
TestItem {
name: "Test".into(),
},
);
assert_eq!(node.id, "test-id");
assert_eq!(node.data.name, "Test");
assert!(node.children.is_empty());
}
#[test]
fn test_tree_node_with_children() {
let node: TreeNode<TestItem> = TreeNode::new(
"parent",
TestItem {
name: "Parent".into(),
},
)
.with_children(vec![
TreeNode::new(
"child1",
TestItem {
name: "Child 1".into(),
},
),
TreeNode::new(
"child2",
TestItem {
name: "Child 2".into(),
},
),
]);
assert_eq!(node.children.len(), 2);
}
#[test]
fn test_tree_node_has_children() {
let leaf: TreeNode<TestItem> = TreeNode::new(
"leaf",
TestItem {
name: "Leaf".into(),
},
);
assert!(!leaf.has_children());
let parent: TreeNode<TestItem> = TreeNode::new(
"parent",
TestItem {
name: "Parent".into(),
},
)
.with_children(vec![leaf.clone()]);
assert!(parent.has_children());
}
#[test]
fn test_tree_state_new() {
let state = TreeViewState::new();
assert_eq!(state.selected_index, 0);
assert!(state.collapsed.is_empty());
}
#[test]
fn test_tree_state() {
let mut state = TreeViewState::new();
assert!(!state.is_collapsed("1"));
state.collapse("1");
assert!(state.is_collapsed("1"));
state.toggle_collapsed("1");
assert!(!state.is_collapsed("1"));
}
#[test]
fn test_tree_state_expand() {
let mut state = TreeViewState::new();
state.collapse("node1");
state.collapse("node2");
assert!(state.is_collapsed("node1"));
state.expand("node1");
assert!(!state.is_collapsed("node1"));
assert!(state.is_collapsed("node2"));
}
#[test]
fn test_tree_state_collapse_multiple() {
let mut state = TreeViewState::new();
state.collapse("1");
state.collapse("2");
assert!(state.is_collapsed("1"));
assert!(state.is_collapsed("2"));
state.expand("1");
state.expand("2");
assert!(!state.is_collapsed("1"));
assert!(!state.is_collapsed("2"));
}
#[test]
fn test_tree_state_navigation() {
let mut state = TreeViewState::new();
assert_eq!(state.selected_index, 0);
state.select_next(5);
assert_eq!(state.selected_index, 1);
state.select_next(5);
state.select_next(5);
state.select_next(5);
assert_eq!(state.selected_index, 4);
state.select_next(5); assert_eq!(state.selected_index, 4);
state.select_prev();
assert_eq!(state.selected_index, 3);
state.select_prev();
state.select_prev();
state.select_prev();
state.select_prev(); assert_eq!(state.selected_index, 0);
}
#[test]
fn test_tree_state_ensure_visible() {
let mut state = TreeViewState::new();
state.selected_index = 15;
state.scroll = 5;
state.ensure_visible(10);
assert!(state.scroll >= 6);
state.selected_index = 2;
state.scroll = 10;
state.ensure_visible(10);
assert_eq!(state.scroll, 2);
}
#[test]
fn test_tree_state_ensure_visible_zero_viewport() {
let mut state = TreeViewState::new();
state.scroll = 5;
state.selected_index = 10;
state.ensure_visible(0);
assert_eq!(state.scroll, 11); }
#[test]
fn test_flatten_visible() {
let nodes = create_test_tree();
let state = TreeViewState::new();
let tree = TreeView::new(&nodes, &state);
let visible = tree.flatten_visible();
assert_eq!(visible.len(), 4); }
#[test]
fn test_flatten_with_collapsed() {
let nodes = create_test_tree();
let mut state = TreeViewState::new();
state.collapse("1");
let tree = TreeView::new(&nodes, &state);
let visible = tree.flatten_visible();
assert_eq!(visible.len(), 2); }
#[test]
fn test_flatten_deep_tree() {
let nodes = create_deep_tree();
let state = TreeViewState::new();
let tree = TreeView::new(&nodes, &state);
let visible = tree.flatten_visible();
assert_eq!(visible.len(), 4);
assert_eq!(visible[0].depth, 0);
assert_eq!(visible[1].depth, 1);
assert_eq!(visible[2].depth, 2);
assert_eq!(visible[3].depth, 3);
}
#[test]
fn test_visible_count() {
let nodes = create_test_tree();
let state = TreeViewState::new();
let tree = TreeView::new(&nodes, &state);
assert_eq!(tree.visible_count(), 4);
let mut collapsed_state = TreeViewState::new();
collapsed_state.collapse("1");
let collapsed_tree = TreeView::new(&nodes, &collapsed_state);
assert_eq!(collapsed_tree.visible_count(), 2);
}
#[test]
fn test_selection_navigation() {
let nodes = create_test_tree();
let mut state = TreeViewState::new();
let tree = TreeView::new(&nodes, &state);
let count = tree.visible_count();
assert_eq!(state.selected_index, 0);
state.select_next(count);
assert_eq!(state.selected_index, 1);
state.select_prev();
assert_eq!(state.selected_index, 0);
}
#[test]
fn test_get_selected_id() {
let nodes = create_test_tree();
let mut state = TreeViewState::new();
let id = get_selected_id(&nodes, &state);
assert_eq!(id, Some("1".to_string()));
state.selected_index = 2;
let id = get_selected_id(&nodes, &state);
assert_eq!(id, Some("1.2".to_string()));
state.selected_index = 3;
let id = get_selected_id(&nodes, &state);
assert_eq!(id, Some("2".to_string()));
}
#[test]
fn test_get_selected_id_with_collapsed() {
let nodes = create_test_tree();
let mut state = TreeViewState::new();
state.collapse("1");
state.selected_index = 1;
let id = get_selected_id(&nodes, &state);
assert_eq!(id, Some("2".to_string()));
}
#[test]
fn test_tree_style_default() {
let style = TreeStyle::default();
assert_eq!(style.collapsed_icon, "▶ ");
assert_eq!(style.expanded_icon, "▼ ");
assert_eq!(style.connector_branch, "├── ");
assert_eq!(style.connector_last, "└── ");
}
#[test]
fn test_tree_view_render() {
let nodes = create_test_tree();
let state = TreeViewState::new();
let tree = TreeView::new(&nodes, &state)
.render_item(|node, _| format!("Item: {}", node.data.name));
let mut buf = Buffer::empty(Rect::new(0, 0, 40, 10));
tree.render(Rect::new(0, 0, 40, 10), &mut buf);
}
#[test]
fn test_tree_view_with_style() {
let nodes = create_test_tree();
let state = TreeViewState::new();
let custom_style = TreeStyle {
collapsed_icon: "+",
expanded_icon: "-",
..TreeStyle::default()
};
let tree = TreeView::new(&nodes, &state).style(custom_style);
let mut buf = Buffer::empty(Rect::new(0, 0, 40, 10));
tree.render(Rect::new(0, 0, 40, 10), &mut buf);
}
#[test]
fn test_empty_tree() {
let nodes: Vec<TreeNode<TestItem>> = vec![];
let state = TreeViewState::new();
let tree = TreeView::new(&nodes, &state);
assert_eq!(tree.visible_count(), 0);
assert!(tree.flatten_visible().is_empty());
}
}