1use ratatui::buffer::Buffer;
9use ratatui::layout::Rect;
10use ratatui::style::Style;
11use ratatui::text::ToText;
12use ratatui::widgets::{Block, Scrollbar, ScrollbarState, StatefulWidget, Widget};
13use std::collections::HashSet;
14use std::fmt::Display;
15use std::hash::Hash;
16use unicode_width::UnicodeWidthStr;
17
18pub use crate::flatten::Flattened;
19pub use crate::tree_item::TreeItem;
20pub use crate::tree_state::TreeState;
21
22mod flatten;
23mod tree_item;
24mod tree_state;
25
/// A widget that renders a tree of [`TreeItem`]s.
///
/// The tree only borrows its items. Interaction state (opened nodes, selection,
/// scroll offset) lives in a separate [`TreeState`] that is passed to
/// [`StatefulWidget::render`].
///
/// Create it with [`Tree::new`], which rejects duplicate top-level identifiers,
/// then customise it with the builder methods.
#[must_use]
#[derive(Debug, Clone)]
pub struct Tree<'a, T>
where
    T: ToText + Clone + Default + Display + Hash + PartialEq + Eq,
{
    /// Top-level items; nested items are reached through each item's children.
    items: &'a [TreeItem<T>],

    /// Optional block drawn around the tree; the items are rendered inside it.
    block: Option<Block<'a>>,
    /// Optional scrollbar, drawn over the full width and the inner height.
    scrollbar: Option<Scrollbar<'a>>,
    /// Base style applied to the whole area before anything else is drawn.
    style: Style,

    /// Style applied to the row of the selected item.
    highlight_style: Style,
    /// Symbol drawn in front of the selected item. While anything is selected,
    /// every other row gets blanks of the same display width instead.
    highlight_symbol: &'a str,

    /// Symbol in front of an item that has children and is closed.
    node_closed_symbol: &'a str,
    /// Symbol in front of an item that has children and is open.
    node_open_symbol: &'a str,
    /// Symbol in front of an item without children.
    node_no_children_symbol: &'a str,
}
77
78impl<'a, T> Tree<'a, T>
79where
80 T: ToText + Clone + Default + Display + Hash + PartialEq + Eq,
81{
82 pub fn new(items: &'a [TreeItem<T>]) -> std::io::Result<Self> {
88 let identifiers = items
89 .iter()
90 .map(|item| &item.identifier)
91 .collect::<HashSet<_>>();
92 if identifiers.len() != items.len() {
93 return Err(std::io::Error::new(
94 std::io::ErrorKind::AlreadyExists,
95 "The items contain duplicate identifiers",
96 ));
97 }
98
99 Ok(Self {
100 items,
101 block: None,
102 scrollbar: None,
103 style: Style::new(),
104 highlight_style: Style::new(),
105 highlight_symbol: "",
106 node_closed_symbol: "\u{25b6} ", node_open_symbol: "\u{25bc} ", node_no_children_symbol: " ",
109 })
110 }
111
112 #[allow(clippy::missing_const_for_fn)]
113 pub fn block(mut self, block: Block<'a>) -> Self {
114 self.block = Some(block);
115 self
116 }
117
118 pub const fn experimental_scrollbar(mut self, scrollbar: Option<Scrollbar<'a>>) -> Self {
124 self.scrollbar = scrollbar;
125 self
126 }
127
128 pub const fn style(mut self, style: Style) -> Self {
129 self.style = style;
130 self
131 }
132
133 pub const fn highlight_style(mut self, style: Style) -> Self {
134 self.highlight_style = style;
135 self
136 }
137
138 pub const fn highlight_symbol(mut self, highlight_symbol: &'a str) -> Self {
139 self.highlight_symbol = highlight_symbol;
140 self
141 }
142
143 pub const fn node_closed_symbol(mut self, symbol: &'a str) -> Self {
144 self.node_closed_symbol = symbol;
145 self
146 }
147
148 pub const fn node_open_symbol(mut self, symbol: &'a str) -> Self {
149 self.node_open_symbol = symbol;
150 self
151 }
152
153 pub const fn node_no_children_symbol(mut self, symbol: &'a str) -> Self {
154 self.node_no_children_symbol = symbol;
155 self
156 }
157}
158
/// `Tree::new` must refuse top-level items that share an identifier.
#[test]
#[should_panic = "duplicate identifiers"]
fn tree_new_errors_with_duplicate_identifiers() {
    let leaf = TreeItem::new_leaf("text".to_owned());
    let items = [leaf.clone(), leaf];
    let _ = Tree::new(&items).unwrap();
}
167
168impl<'a, T> StatefulWidget for Tree<'a, T>
169where
170 T: ToText + Clone + Default + Display + Hash + PartialEq + Eq,
171{
172 type State = TreeState;
173
174 #[allow(clippy::too_many_lines)]
175 fn render(self, full_area: Rect, buf: &mut Buffer, state: &mut Self::State) {
176 buf.set_style(full_area, self.style);
177
178 let area = self.block.map_or(full_area, |block| {
180 let inner_area = block.inner(full_area);
181 block.render(full_area, buf);
182 inner_area
183 });
184
185 state.last_area = area;
186 state.last_rendered_identifiers.clear();
187 if area.width < 1 || area.height < 1 {
188 return;
189 }
190
191 let visible = state.flatten(self.items);
192 state.last_biggest_index = visible.len().saturating_sub(1);
193 if visible.is_empty() {
194 return;
195 }
196 let available_height = area.height as usize;
197
198 let ensure_index_in_view =
199 if state.ensure_selected_in_view_on_next_render && !state.selected.is_empty() {
200 visible
201 .iter()
202 .position(|flattened| flattened.identifier == state.selected)
203 } else {
204 None
205 };
206
207 let mut start = state.offset.min(state.last_biggest_index);
209
210 if let Some(ensure_index_in_view) = ensure_index_in_view {
211 start = start.min(ensure_index_in_view);
212 }
213
214 let mut end = start;
215 let mut height = 0;
216 for item_height in visible
217 .iter()
218 .skip(start)
219 .map(|flattened| flattened.item.height())
220 {
221 if height + item_height > available_height {
222 break;
223 }
224 height += item_height;
225 end += 1;
226 }
227
228 if let Some(ensure_index_in_view) = ensure_index_in_view {
229 while ensure_index_in_view >= end {
230 height += visible[end].item.height();
231 end += 1;
232 while height > available_height {
233 height = height.saturating_sub(visible[start].item.height());
234 start += 1;
235 }
236 }
237 }
238
239 state.offset = start;
240 state.ensure_selected_in_view_on_next_render = false;
241
242 if let Some(scrollbar) = self.scrollbar {
243 let mut scrollbar_state = ScrollbarState::new(visible.len().saturating_sub(height))
244 .position(start)
245 .viewport_content_length(height);
246 let scrollbar_area = Rect {
247 y: area.y,
249 height: area.height,
250 x: full_area.x,
252 width: full_area.width,
253 };
254 scrollbar.render(scrollbar_area, buf, &mut scrollbar_state);
255 }
256
257 let blank_symbol = " ".repeat(self.highlight_symbol.width());
258
259 let mut current_height = 0;
260 let has_selection = !state.selected.is_empty();
261 #[allow(clippy::cast_possible_truncation)]
262 for flattened in visible.iter().skip(state.offset).take(end - start) {
263 let Flattened { identifier, item } = flattened;
264
265 let x = area.x;
266 let y = area.y + current_height;
267 let height = item.height() as u16;
268 current_height += height;
269
270 let area = Rect {
271 x,
272 y,
273 width: area.width,
274 height,
275 };
276
277 let text = item.content.to_text();
278 let item_style = text.style;
279
280 let is_selected = state.selected == *identifier;
281 let after_highlight_symbol_x = if has_selection {
282 let symbol = if is_selected {
283 self.highlight_symbol
284 } else {
285 &blank_symbol
286 };
287 let (x, _) = buf.set_stringn(x, y, symbol, area.width as usize, item_style);
288 x
289 } else {
290 x
291 };
292
293 let after_depth_x = {
294 let indent_width = flattened.depth() * 2;
295 let (after_indent_x, _) = buf.set_stringn(
296 after_highlight_symbol_x,
297 y,
298 " ".repeat(indent_width),
299 indent_width,
300 item_style,
301 );
302 let symbol = if item.children.is_empty() {
303 self.node_no_children_symbol
304 } else if state.opened.contains(identifier.as_slice()) {
305 self.node_open_symbol
306 } else {
307 self.node_closed_symbol
308 };
309 let max_width = area.width.saturating_sub(after_indent_x - x);
310 let (x, _) =
311 buf.set_stringn(after_indent_x, y, symbol, max_width as usize, item_style);
312 x
313 };
314
315 let text_area = Rect {
316 x: after_depth_x,
317 width: area.width.saturating_sub(after_depth_x - x),
318 ..area
319 };
320 text.render(text_area, buf);
321
322 if is_selected {
323 buf.set_style(area, self.highlight_style);
324 }
325
326 state
327 .last_rendered_identifiers
328 .push((area.y, identifier.clone()));
329 }
330 state.last_identifiers = visible
331 .into_iter()
332 .map(|flattened| flattened.identifier)
333 .collect();
334 }
335}
336
337impl<'a, T> Widget for Tree<'a, T>
338where
339 T: ToText + Clone + Default + Display + Hash + PartialEq + Eq,
340{
341 fn render(self, area: Rect, buf: &mut Buffer) {
342 let mut state = TreeState::default();
343 StatefulWidget::render(self, area, buf, &mut state);
344 }
345}
346
#[cfg(test)]
mod render_tests {
    use super::*;
    use std::hash::{DefaultHasher, Hasher};

    /// Hashes `text` with a fresh `DefaultHasher`; used to build the identifier
    /// paths passed to `TreeState::open` (presumably matching how the example
    /// items derive their identifiers — confirm against `TreeItem`).
    fn text_hash(text: &str) -> u64 {
        let mut hasher = DefaultHasher::new();
        text.hash(&mut hasher);
        hasher.finish()
    }

    /// Renders `TreeItem::example()` into a new empty `width` x `height` buffer.
    #[must_use]
    #[track_caller]
    fn render(width: u16, height: u16, state: &mut TreeState) -> Buffer {
        let items = TreeItem::example();
        let area = Rect::new(0, 0, width, height);
        let mut buffer = Buffer::empty(area);
        let tree = Tree::new(&items).unwrap();
        StatefulWidget::render(tree, area, &mut buffer, state);
        buffer
    }

    #[test]
    fn does_not_panic() {
        for (width, height) in [(0, 0), (10, 0), (0, 10), (10, 10)] {
            _ = render(width, height, &mut TreeState::default());
        }
    }

    #[test]
    fn nothing_open() {
        let buffer = render(10, 4, &mut TreeState::default());
        #[rustfmt::skip]
        let expected = Buffer::with_lines([
            " Alfa ",
            "▶ Bravo ",
            " Hotel ",
            " ",
        ]);
        assert_eq!(buffer, expected);
    }

    #[test]
    fn depth_one() {
        let mut state = TreeState::default();
        state.open(vec![text_hash("Bravo")]);
        let buffer = render(13, 7, &mut state);
        let expected = Buffer::with_lines([
            " Alfa ",
            "▼ Bravo ",
            " Charlie ",
            " ▶ Delta ",
            " Golf ",
            " Hotel ",
            " ",
        ]);
        assert_eq!(buffer, expected);
    }

    #[test]
    fn depth_two() {
        let mut state = TreeState::default();
        let bravo = text_hash("Bravo");
        let delta = text_hash("Delta");
        state.open(vec![bravo]);
        state.open(vec![bravo, delta]);
        let buffer = render(15, 9, &mut state);
        let expected = Buffer::with_lines([
            " Alfa ",
            "▼ Bravo ",
            " Charlie ",
            " ▼ Delta ",
            " Echo ",
            " Foxtrot ",
            " Golf ",
            " Hotel ",
            " ",
        ]);
        assert_eq!(buffer, expected);
    }
}