1use std::collections::HashSet;
7
8use ratatui::buffer::Buffer;
9use ratatui::layout::Rect;
10use ratatui::style::Style;
11use ratatui::widgets::{Block, Scrollbar, ScrollbarState, StatefulWidget, Widget};
12use unicode_width::UnicodeWidthStr as _;
13
14pub use crate::flatten::Flattened;
15pub use crate::tree_item::TreeItem;
16pub use crate::tree_state::TreeState;
17
18mod flatten;
19mod tree_item;
20mod tree_state;
21
/// A widget that renders a slice of [`TreeItem`]s as an indented tree whose
/// nodes can be opened and closed.
///
/// The widget itself only holds configuration (items, block, styles, symbols).
/// Which nodes are opened, which item is selected and the scroll offset are kept
/// in a separate [`TreeState`] that is passed to [`StatefulWidget::render`].
#[must_use]
#[derive(Debug, Clone)]
pub struct Tree<'a, Identifier> {
    /// Top-level items; [`Tree::new`] rejects duplicate identifiers among them.
    items: &'a [TreeItem<'a, Identifier>],

    /// Optional block drawn around the tree; items are drawn in its inner area.
    block: Option<Block<'a>>,
    /// Optional (experimental) scrollbar, rendered across the full width of the
    /// widget and over the height of the item rows.
    scrollbar: Option<Scrollbar<'a>>,
    /// Style applied to the whole widget area before anything else is drawn.
    style: Style,

    /// Style applied on top of the selected item's rows after they are drawn.
    highlight_style: Style,
    /// Symbol drawn in front of the selected item.
    highlight_symbol: &'a str,

    /// Symbol drawn in front of a node with children that is not opened.
    node_closed_symbol: &'a str,
    /// Symbol drawn in front of a node with children that is opened.
    node_open_symbol: &'a str,
    /// Symbol drawn in front of an item without children.
    node_no_children_symbol: &'a str,
}
73
74impl<'a, Identifier> Tree<'a, Identifier>
75where
76 Identifier: Clone + PartialEq + Eq + core::hash::Hash,
77{
78 pub fn new(items: &'a [TreeItem<'a, Identifier>]) -> std::io::Result<Self> {
84 let identifiers = items
85 .iter()
86 .map(|item| &item.identifier)
87 .collect::<HashSet<_>>();
88 if identifiers.len() != items.len() {
89 return Err(std::io::Error::new(
90 std::io::ErrorKind::AlreadyExists,
91 "The items contain duplicate identifiers",
92 ));
93 }
94
95 Ok(Self {
96 items,
97 block: None,
98 scrollbar: None,
99 style: Style::new(),
100 highlight_style: Style::new(),
101 highlight_symbol: "",
102 node_closed_symbol: "\u{25b6} ", node_open_symbol: "\u{25bc} ", node_no_children_symbol: " ",
105 })
106 }
107
108 pub fn block(mut self, block: Block<'a>) -> Self {
109 self.block = Some(block);
110 self
111 }
112
113 pub const fn experimental_scrollbar(mut self, scrollbar: Option<Scrollbar<'a>>) -> Self {
119 self.scrollbar = scrollbar;
120 self
121 }
122
123 pub const fn style(mut self, style: Style) -> Self {
124 self.style = style;
125 self
126 }
127
128 pub const fn highlight_style(mut self, style: Style) -> Self {
129 self.highlight_style = style;
130 self
131 }
132
133 pub const fn highlight_symbol(mut self, highlight_symbol: &'a str) -> Self {
134 self.highlight_symbol = highlight_symbol;
135 self
136 }
137
138 pub const fn node_closed_symbol(mut self, symbol: &'a str) -> Self {
139 self.node_closed_symbol = symbol;
140 self
141 }
142
143 pub const fn node_open_symbol(mut self, symbol: &'a str) -> Self {
144 self.node_open_symbol = symbol;
145 self
146 }
147
148 pub const fn node_no_children_symbol(mut self, symbol: &'a str) -> Self {
149 self.node_no_children_symbol = symbol;
150 self
151 }
152}
153
/// `Tree::new` must refuse two top-level items sharing the identifier `"same"`.
#[test]
#[should_panic = "duplicate identifiers"]
fn tree_new_errors_with_duplicate_identifiers() {
    let leaf = TreeItem::new_leaf("same", "text");
    let items = [leaf.clone(), leaf];
    let _tree: Tree<&str> = Tree::new(&items).unwrap();
}
162
163impl<Identifier> StatefulWidget for Tree<'_, Identifier>
164where
165 Identifier: Clone + PartialEq + Eq + core::hash::Hash,
166{
167 type State = TreeState<Identifier>;
168
169 #[allow(clippy::too_many_lines)]
170 fn render(self, full_area: Rect, buf: &mut Buffer, state: &mut Self::State) {
171 buf.set_style(full_area, self.style);
172
173 let area = self.block.map_or(full_area, |block| {
175 let inner_area = block.inner(full_area);
176 block.render(full_area, buf);
177 inner_area
178 });
179
180 state.last_area = area;
181 state.last_rendered_identifiers.clear();
182 if area.width < 1 || area.height < 1 {
183 return;
184 }
185
186 let visible = state.flatten(self.items);
187 state.last_biggest_index = visible.len().saturating_sub(1);
188 if visible.is_empty() {
189 return;
190 }
191 let available_height = area.height as usize;
192
193 let ensure_index_in_view =
194 if state.ensure_selected_in_view_on_next_render && !state.selected.is_empty() {
195 visible
196 .iter()
197 .position(|flattened| flattened.identifier == state.selected)
198 } else {
199 None
200 };
201
202 let mut start = state.offset.min(state.last_biggest_index);
204
205 if let Some(ensure_index_in_view) = ensure_index_in_view {
206 start = start.min(ensure_index_in_view);
207 }
208
209 let mut end = start;
210 let mut height = 0;
211 for item_height in visible
212 .iter()
213 .skip(start)
214 .map(|flattened| flattened.item.height())
215 {
216 if height + item_height > available_height {
217 break;
218 }
219 height += item_height;
220 end += 1;
221 }
222
223 if let Some(ensure_index_in_view) = ensure_index_in_view {
224 while ensure_index_in_view >= end {
225 height += visible[end].item.height();
226 end += 1;
227 while height > available_height {
228 height = height.saturating_sub(visible[start].item.height());
229 start += 1;
230 }
231 }
232 }
233
234 state.offset = start;
235 state.ensure_selected_in_view_on_next_render = false;
236
237 if let Some(scrollbar) = self.scrollbar {
238 let mut scrollbar_state = ScrollbarState::new(visible.len().saturating_sub(height))
239 .position(start)
240 .viewport_content_length(height);
241 let scrollbar_area = Rect {
242 y: area.y,
244 height: area.height,
245 x: full_area.x,
247 width: full_area.width,
248 };
249 scrollbar.render(scrollbar_area, buf, &mut scrollbar_state);
250 }
251
252 let blank_symbol = " ".repeat(self.highlight_symbol.width());
253
254 let mut current_height = 0;
255 let has_selection = !state.selected.is_empty();
256 #[allow(clippy::cast_possible_truncation)]
257 for flattened in visible.iter().skip(state.offset).take(end - start) {
258 let Flattened { identifier, item } = flattened;
259
260 let x = area.x;
261 let y = area.y + current_height;
262 let height = item.height() as u16;
263 current_height += height;
264
265 let area = Rect {
266 x,
267 y,
268 width: area.width,
269 height,
270 };
271
272 let text = &item.text;
273 let item_style = text.style;
274
275 let is_selected = state.selected == *identifier;
276 let after_highlight_symbol_x = if has_selection {
277 let symbol = if is_selected {
278 self.highlight_symbol
279 } else {
280 &blank_symbol
281 };
282 let (x, _) = buf.set_stringn(x, y, symbol, area.width as usize, item_style);
283 x
284 } else {
285 x
286 };
287
288 let after_depth_x = {
289 let indent_width = flattened.depth() * 2;
290 let (after_indent_x, _) = buf.set_stringn(
291 after_highlight_symbol_x,
292 y,
293 " ".repeat(indent_width),
294 indent_width,
295 item_style,
296 );
297 let symbol = if item.children.is_empty() {
298 self.node_no_children_symbol
299 } else if state.opened.contains(identifier) {
300 self.node_open_symbol
301 } else {
302 self.node_closed_symbol
303 };
304 let max_width = area.width.saturating_sub(after_indent_x - x);
305 let (x, _) =
306 buf.set_stringn(after_indent_x, y, symbol, max_width as usize, item_style);
307 x
308 };
309
310 let text_area = Rect {
311 x: after_depth_x,
312 width: area.width.saturating_sub(after_depth_x - x),
313 ..area
314 };
315 text.render(text_area, buf);
316
317 if is_selected {
318 buf.set_style(area, self.highlight_style);
319 }
320
321 state
322 .last_rendered_identifiers
323 .push((area.y, identifier.clone()));
324 }
325 state.last_identifiers = visible
326 .into_iter()
327 .map(|flattened| flattened.identifier)
328 .collect();
329 }
330}
331
332impl<Identifier> Widget for Tree<'_, Identifier>
333where
334 Identifier: Clone + Eq + core::hash::Hash,
335{
336 fn render(self, area: Rect, buf: &mut Buffer) {
337 let mut state = TreeState::default();
338 StatefulWidget::render(self, area, buf, &mut state);
339 }
340}
341
#[cfg(test)]
mod render_tests {
    use super::*;

    /// Renders `TreeItem::example()` with the default `Tree` settings into a
    /// `width` × `height` buffer and returns that buffer.
    #[must_use]
    #[track_caller]
    fn render(width: u16, height: u16, state: &mut TreeState<&'static str>) -> Buffer {
        let items = TreeItem::example();
        let tree = Tree::new(&items).unwrap();
        let area = Rect::new(0, 0, width, height);
        let mut buffer = Buffer::empty(area);
        StatefulWidget::render(tree, area, &mut buffer, state);
        buffer
    }

    #[test]
    fn does_not_panic() {
        _ = render(0, 0, &mut TreeState::default());
        _ = render(10, 0, &mut TreeState::default());
        _ = render(0, 10, &mut TreeState::default());
        _ = render(10, 10, &mut TreeState::default());
    }

    // Every expected line is padded to the full render width, so that the
    // expected buffer covers the same area as the rendered one.

    #[test]
    fn nothing_open() {
        let buffer = render(10, 4, &mut TreeState::default());
        #[rustfmt::skip]
        let expected = Buffer::with_lines([
            "  Alfa    ",
            "▶ Bravo   ",
            "  Hotel   ",
            "          ",
        ]);
        assert_eq!(buffer, expected);
    }

    #[test]
    fn depth_one() {
        let mut state = TreeState::default();
        state.open(vec!["b"]);
        let buffer = render(13, 7, &mut state);
        #[rustfmt::skip]
        let expected = Buffer::with_lines([
            "  Alfa       ",
            "▼ Bravo      ",
            "    Charlie  ",
            "  ▶ Delta    ",
            "    Golf     ",
            "  Hotel      ",
            "             ",
        ]);
        assert_eq!(buffer, expected);
    }

    #[test]
    fn depth_two() {
        let mut state = TreeState::default();
        state.open(vec!["b"]);
        state.open(vec!["b", "d"]);
        let buffer = render(15, 9, &mut state);
        #[rustfmt::skip]
        let expected = Buffer::with_lines([
            "  Alfa         ",
            "▼ Bravo        ",
            "    Charlie    ",
            "  ▼ Delta      ",
            "      Echo     ",
            "      Foxtrot  ",
            "    Golf       ",
            "  Hotel        ",
            "               ",
        ]);
        assert_eq!(buffer, expected);
    }
}