use super::{Component, EventContext, RenderContext};
use crate::input::{Event, Key};
use crate::scroll::ScrollState;
mod render;
mod traversal;
/// A node of a tree: a label, a payload of type `T`, an ordered list of
/// child nodes, and an expanded/collapsed flag.
///
/// Serializable with serde when the `serialization` feature is enabled.
#[derive(Clone, Debug)]
#[cfg_attr(
feature = "serialization",
derive(serde::Serialize, serde::Deserialize)
)]
pub struct TreeNode<T> {
/// Label text of the node.
label: String,
/// User payload attached to the node.
data: T,
/// Child nodes, in order.
children: Vec<TreeNode<T>>,
/// Whether the node is expanded. `TreeNode::new` creates collapsed nodes,
/// `TreeNode::new_expanded` creates expanded ones.
expanded: bool,
}
impl<T: PartialEq> PartialEq for TreeNode<T> {
fn eq(&self, other: &Self) -> bool {
self.label == other.label
&& self.data == other.data
&& self.children == other.children
&& self.expanded == other.expanded
}
}
// No `T: Clone` bound: none of these methods clone the payload, so the bound
// only prevented building trees over non-`Clone` data.
impl<T> TreeNode<T> {
    /// Creates a collapsed node with the given label and payload and no
    /// children.
    pub fn new(label: impl Into<String>, data: T) -> Self {
        Self {
            label: label.into(),
            data,
            children: Vec::new(),
            expanded: false,
        }
    }

    /// Creates a node with no children that starts in the expanded state.
    pub fn new_expanded(label: impl Into<String>, data: T) -> Self {
        Self {
            expanded: true,
            ..Self::new(label, data)
        }
    }

    /// Returns the node's label.
    pub fn label(&self) -> &str {
        &self.label
    }

    /// Replaces the node's label.
    pub fn set_label(&mut self, label: impl Into<String>) {
        self.label = label.into();
    }

    /// Returns a shared reference to the node's payload.
    pub fn data(&self) -> &T {
        &self.data
    }

    /// Returns a mutable reference to the node's payload.
    pub fn data_mut(&mut self) -> &mut T {
        &mut self.data
    }

    /// Returns the node's children, in order.
    pub fn children(&self) -> &[TreeNode<T>] {
        &self.children
    }

    /// Returns the node's child list for in-place modification.
    pub fn children_mut(&mut self) -> &mut Vec<TreeNode<T>> {
        &mut self.children
    }

    /// Appends `child` as the last child of this node.
    pub fn add_child(&mut self, child: TreeNode<T>) {
        self.children.push(child);
    }

    /// Returns `true` if the node has at least one child.
    pub fn has_children(&self) -> bool {
        !self.children.is_empty()
    }

    /// Returns `true` if the node is expanded.
    pub fn is_expanded(&self) -> bool {
        self.expanded
    }

    /// Sets the expanded flag to `expanded`.
    pub fn set_expanded(&mut self, expanded: bool) {
        self.expanded = expanded;
    }

    /// Marks the node as expanded.
    pub fn expand(&mut self) {
        self.expanded = true;
    }

    /// Marks the node as collapsed.
    pub fn collapse(&mut self) {
        self.expanded = false;
    }

    /// Flips the expanded flag.
    pub fn toggle(&mut self) {
        self.expanded = !self.expanded;
    }
}
/// Messages accepted by the tree component's `update`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TreeMessage {
/// Move the selection one row down (no-op on the last row).
Down,
/// Move the selection one row up (no-op on the first row).
Up,
/// Expand the selected node if it has children and is collapsed.
Expand,
/// Collapse the selected node if it has children and is expanded.
Collapse,
/// Flip the expansion state of the selected node if it has children.
Toggle,
/// Report the selected node's path via `TreeOutput::Selected`.
Select,
/// Expand all nodes (handled by `TreeState::expand_all`).
ExpandAll,
/// Collapse all nodes (handled by `TreeState::collapse_all`).
CollapseAll,
/// Replace the filter text with the given string.
SetFilter(String),
/// Clear the filter text.
ClearFilter,
}
/// Outputs emitted by the tree component's `update`.
///
/// Node paths are in the same form that `TreeState::selected_path` returns
/// and `TreeState::get_node` accepts.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TreeOutput {
/// The selected node was activated (`TreeMessage::Select`); carries its path.
Selected(Vec<usize>),
/// The node at this path was expanded.
Expanded(Vec<usize>),
/// The node at this path was collapsed.
Collapsed(Vec<usize>),
/// The filter text changed; carries the new text (empty when cleared).
FilterChanged(String),
}
/// State of a tree component: the root nodes, the current selection, the
/// filter text and scroll bookkeeping.
///
/// Serializable with serde when the `serialization` feature is enabled;
/// `scroll` is skipped and comes back as `ScrollState::default()` on
/// deserialization.
#[derive(Clone, Debug)]
#[cfg_attr(
feature = "serialization",
derive(serde::Serialize, serde::Deserialize)
)]
pub struct TreeState<T> {
/// Top-level nodes of the tree.
roots: Vec<TreeNode<T>>,
/// Index of the selected row within the flattened list of visible rows
/// (`flatten()`), or `None` when nothing is selected.
selected_index: Option<usize>,
/// Current filter text; empty means no filter.
filter_text: String,
/// Scroll bookkeeping. Its content length is updated to the visible row count
/// by the methods that change which rows are visible.
#[cfg_attr(feature = "serialization", serde(skip))]
scroll: ScrollState,
}
/// Compares the tree contents, the selected row index and the filter text.
///
/// `scroll` is not compared: it is view bookkeeping and is also skipped when
/// serializing. Only `T: PartialEq` is needed; the former `Clone` bound was
/// never used.
impl<T: PartialEq> PartialEq for TreeState<T> {
    fn eq(&self, other: &Self) -> bool {
        self.roots == other.roots
            && self.selected_index == other.selected_index
            && self.filter_text == other.filter_text
    }
}
impl<T: Clone> Default for TreeState<T> {
fn default() -> Self {
Self::new(Vec::new())
}
}
impl<T: Clone> TreeState<T> {
    /// Creates a state from the given root nodes.
    ///
    /// The first row is selected when `roots` is non-empty, no filter is
    /// applied, and the scroll content length is initialised to the number of
    /// visible rows.
    pub fn new(roots: Vec<TreeNode<T>>) -> Self {
        let selected_index = if roots.is_empty() { None } else { Some(0) };
        let mut state = Self {
            roots,
            selected_index,
            filter_text: String::new(),
            scroll: ScrollState::default(),
        };
        // Sync scroll bookkeeping up front, as `set_roots` does; otherwise it
        // keeps the default content length until the first mutation.
        let visible = state.flatten().len();
        state.scroll.set_content_length(visible);
        state
    }

    /// Builder-style form of [`set_selected`](Self::set_selected).
    ///
    /// Selects the row at `index`, clamped to the last visible row. The
    /// selection is left unchanged when no row is visible.
    pub fn with_selected(mut self, index: usize) -> Self {
        self.set_selected(Some(index));
        self
    }

    /// Returns the top-level nodes.
    pub fn roots(&self) -> &[TreeNode<T>] {
        &self.roots
    }

    /// Returns the root list for direct modification.
    ///
    /// Changes made through this reference do NOT update the selection or
    /// scroll state. Prefer [`set_roots`](Self::set_roots) or
    /// [`update_root`](Self::update_root) when the visible rows may change.
    pub fn roots_mut(&mut self) -> &mut Vec<TreeNode<T>> {
        &mut self.roots
    }

    /// Applies `f` to the root at `index`; does nothing if `index` is out of
    /// range.
    ///
    /// `f` may change which rows are visible, for example by collapsing nodes
    /// or removing children. Afterwards the scroll content length is
    /// re-synced. A selection that now points past the last visible row is
    /// clamped to the last row, or cleared if no row is visible.
    pub fn update_root(&mut self, index: usize, f: impl FnOnce(&mut TreeNode<T>)) {
        if let Some(root) = self.roots.get_mut(index) {
            f(root);
            let visible = self.flatten().len();
            if let Some(sel) = self.selected_index {
                if sel >= visible {
                    self.selected_index = visible.checked_sub(1);
                }
            }
            self.scroll.set_content_length(visible);
        }
    }

    /// Replaces all root nodes.
    ///
    /// Clears the filter, selects the first row (or nothing for an empty
    /// tree) and re-syncs the scroll content length.
    pub fn set_roots(&mut self, roots: Vec<TreeNode<T>>) {
        self.roots = roots;
        self.filter_text.clear();
        self.selected_index = if self.roots.is_empty() { None } else { Some(0) };
        self.scroll.set_content_length(self.flatten().len());
    }

    /// Returns the index of the selected row among the visible rows.
    pub fn selected_index(&self) -> Option<usize> {
        self.selected_index
    }

    /// Alias for [`selected_index`](Self::selected_index).
    pub fn selected(&self) -> Option<usize> {
        self.selected_index()
    }

    /// Sets the selected row.
    ///
    /// `Some(i)` is clamped to the last visible row. It is ignored when no row
    /// is visible: the tree is empty or the filter hides every node, so there
    /// is nothing to select. `None` clears the selection.
    pub fn set_selected(&mut self, index: Option<usize>) {
        match index {
            Some(i) => {
                let visible = self.flatten().len();
                if visible > 0 {
                    self.selected_index = Some(i.min(visible - 1));
                }
            }
            None => self.selected_index = None,
        }
    }

    /// Returns `true` if the tree has no root nodes.
    pub fn is_empty(&self) -> bool {
        self.roots.is_empty()
    }

    /// Returns the path of the selected node, or `None` if nothing (valid) is
    /// selected.
    pub fn selected_path(&self) -> Option<Vec<usize>> {
        let flat = self.flatten();
        flat.get(self.selected_index?).map(|n| n.path.clone())
    }

    /// Returns the selected node, if any.
    pub fn selected_node(&self) -> Option<&TreeNode<T>> {
        let path = self.selected_path()?;
        self.get_node(&path)
    }

    /// Alias for [`selected_node`](Self::selected_node).
    pub fn selected_item(&self) -> Option<&TreeNode<T>> {
        self.selected_node()
    }

    /// Expands every node and re-syncs the scroll content length.
    ///
    /// Expanding inserts rows above the selected row, so keeping the raw index
    /// would move the highlight onto a different node. The previous selection
    /// path is therefore re-resolved with `revalidate_selection`, the same
    /// mechanism the filter setters use.
    pub fn expand_all(&mut self) {
        let prev_path = self.selected_path();
        for root in &mut self.roots {
            Self::expand_all_recursive(root);
        }
        self.revalidate_selection(prev_path);
        self.scroll.set_content_length(self.flatten().len());
    }

    /// Collapses every node, selects the first visible row (or nothing if no
    /// row is visible) and re-syncs the scroll content length.
    pub fn collapse_all(&mut self) {
        for root in &mut self.roots {
            Self::collapse_all_recursive(root);
        }
        let visible = self.flatten().len();
        self.selected_index = if visible == 0 { None } else { Some(0) };
        self.scroll.set_content_length(visible);
    }

    /// Returns the number of currently visible rows.
    pub fn visible_count(&self) -> usize {
        self.flatten().len()
    }

    /// Returns the current filter text (empty when unfiltered).
    pub fn filter_text(&self) -> &str {
        &self.filter_text
    }

    /// Sets the filter text.
    ///
    /// The previously selected path is passed to `revalidate_selection` to
    /// re-resolve the selection against the new visible rows. The scroll
    /// content length is re-synced.
    pub fn set_filter_text(&mut self, text: &str) {
        let prev_path = self.selected_path();
        self.filter_text = text.to_string();
        self.revalidate_selection(prev_path);
        self.scroll.set_content_length(self.flatten().len());
    }

    /// Clears the filter text, re-resolving the selection and re-syncing the
    /// scroll content length as [`set_filter_text`](Self::set_filter_text)
    /// does.
    pub fn clear_filter(&mut self) {
        let prev_path = self.selected_path();
        self.filter_text.clear();
        self.revalidate_selection(prev_path);
        self.scroll.set_content_length(self.flatten().len());
    }
}
impl<T: Clone + 'static> TreeState<T> {
pub fn update(&mut self, msg: TreeMessage) -> Option<TreeOutput> {
Tree::update(self, msg)
}
}
/// Tree component over nodes carrying payloads of type `T`.
///
/// A zero-sized marker type that implements `Component`; all state lives in
/// `TreeState<T>`.
pub struct Tree<T>(std::marker::PhantomData<T>);
impl<T: Clone + 'static> Component for Tree<T> {
type State = TreeState<T>;
type Message = TreeMessage;
type Output = TreeOutput;
fn init() -> Self::State {
TreeState::default()
}
fn update(state: &mut Self::State, msg: Self::Message) -> Option<Self::Output> {
match msg {
TreeMessage::SetFilter(ref text) => {
state.set_filter_text(text);
return Some(TreeOutput::FilterChanged(text.clone()));
}
TreeMessage::ClearFilter => {
state.clear_filter();
return Some(TreeOutput::FilterChanged(String::new()));
}
_ => {}
}
let flat = state.flatten();
if flat.is_empty() {
return None;
}
let selected = state.selected_index?;
match msg {
TreeMessage::Down => {
if selected < flat.len() - 1 {
state.selected_index = Some(selected + 1);
}
None
}
TreeMessage::Up => {
if selected > 0 {
state.selected_index = Some(selected - 1);
}
None
}
TreeMessage::Expand => {
if let Some(node_info) = flat.get(selected) {
if node_info.has_children && !node_info.is_expanded {
let path = node_info.path.clone();
if let Some(node) = state.get_node_mut(&path) {
node.expand();
state.scroll.set_content_length(state.flatten().len());
return Some(TreeOutput::Expanded(path));
}
}
}
None
}
TreeMessage::Collapse => {
if let Some(node_info) = flat.get(selected) {
if node_info.has_children && node_info.is_expanded {
let path = node_info.path.clone();
if let Some(node) = state.get_node_mut(&path) {
node.collapse();
let new_flat = state.flatten();
if selected >= new_flat.len() {
state.selected_index = Some(new_flat.len().saturating_sub(1));
}
state.scroll.set_content_length(new_flat.len());
return Some(TreeOutput::Collapsed(path));
}
}
}
None
}
TreeMessage::Toggle => {
if let Some(node_info) = flat.get(selected) {
if node_info.has_children {
let path = node_info.path.clone();
let was_expanded = node_info.is_expanded;
if let Some(node) = state.get_node_mut(&path) {
node.toggle();
if was_expanded {
let new_flat = state.flatten();
if selected >= new_flat.len() {
state.selected_index = Some(new_flat.len().saturating_sub(1));
}
state.scroll.set_content_length(new_flat.len());
return Some(TreeOutput::Collapsed(path));
} else {
state.scroll.set_content_length(state.flatten().len());
return Some(TreeOutput::Expanded(path));
}
}
}
}
None
}
TreeMessage::Select => flat
.get(selected)
.map(|node_info| TreeOutput::Selected(node_info.path.clone())),
TreeMessage::ExpandAll => {
state.expand_all();
None
}
TreeMessage::CollapseAll => {
state.collapse_all();
None
}
TreeMessage::SetFilter(_) | TreeMessage::ClearFilter => {
unreachable!("handled above")
}
}
}
fn handle_event(
_state: &Self::State,
event: &Event,
ctx: &EventContext,
) -> Option<Self::Message> {
if !ctx.focused || ctx.disabled {
return None;
}
if let Some(key) = event.as_key() {
match key.code {
Key::Up | Key::Char('k') => Some(TreeMessage::Up),
Key::Down | Key::Char('j') => Some(TreeMessage::Down),
Key::Left | Key::Char('h') => Some(TreeMessage::Collapse),
Key::Right | Key::Char('l') => Some(TreeMessage::Expand),
Key::Char(' ') => Some(TreeMessage::Toggle),
Key::Enter => Some(TreeMessage::Select),
_ => None,
}
} else {
None
}
}
fn view(state: &Self::State, ctx: &mut RenderContext<'_, '_>) {
render::view(state, ctx);
}
}
#[cfg(test)]
mod snapshot_tests;
#[cfg(test)]
mod tests;