Skip to main content

ratatui_tag_picker/
lib.rs

1use std::collections::BTreeSet;
2
3use ratatui_core::{
4    buffer::Buffer,
5    layout::{Constraint, Layout, Rect},
6    style::{Color, Style},
7    text::{Line, Span},
8    widgets::{StatefulWidget, Widget},
9};
10use ratatui_widgets::{
11    block::Block,
12    borders::Borders,
13    paragraph::{Paragraph, Wrap},
14};
15
/// Default value of `TagPickerConfig::input_height`.
const DEFAULT_INPUT_HEIGHT: u16 = 5;
/// Smallest accepted `input_height`; smaller configured values are raised to this.
const MIN_INPUT_HEIGHT: u16 = 2;
/// Text placed between tags in the selected-tags row.
const SELECTED_SEPARATOR: &str = " | ";
19
/// Tag picker widget.
///
/// Holds the tags the user can pick from (de-duplicated and sorted at
/// construction) plus display settings. All interactive state lives in a
/// separate [`TagPickerState`]; the widget is drawn through its
/// `StatefulWidget` implementation for `&TagPicker`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TagPicker {
    /// Candidate tags, sorted and free of duplicates.
    available_tags: Vec<String>,
    /// Rows requested for the match preview; never below `MIN_INPUT_HEIGHT`.
    input_height: u16,
    /// Color used for the prompt and for highlighted entries.
    accent_color: Color,
}
27
/// Tag picker widget state.
///
/// [TagPicker] follows Ratatui List widget's implementation pattern and
/// implements [StatefulWidget] so it uses external state.
///
/// Selections are stored as indices into the picker's tag list, so a state
/// should be used with the same [TagPicker] it was built with.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TagPickerState {
    /// Indices into the picker's tag list, in the order the tags were selected.
    selected_indices: Vec<usize>,
    /// Which area currently receives navigation and editing input.
    focus: TagPickerFocus,
    /// Current search query.
    input: String,
    /// Position of the highlighted entry within the current match list.
    match_cursor: usize,
    /// Position of the highlighted entry within the selected tags.
    selected_cursor: usize,
    /// Horizontal scroll offset (in columns) of the selected-tags row.
    selected_scroll_x: usize,
}
41
42impl Default for TagPickerState {
43    fn default() -> Self {
44        Self {
45            selected_indices: Vec::new(),
46            focus: TagPickerFocus::Input,
47            input: String::new(),
48            match_cursor: 0,
49            selected_cursor: 0,
50            selected_scroll_x: 0,
51        }
52    }
53}
54
55/// Optional configuration for [TagPicker].
56pub struct TagPickerConfig {
57    /// Set the height of the widget's upper part made of the text input and
58    /// matches.
59    pub input_height: u16,
60    /// Navigation highlight color.
61    pub accent_color: Color,
62}
63
64impl Default for TagPickerConfig {
65    fn default() -> Self {
66        Self {
67            input_height: DEFAULT_INPUT_HEIGHT,
68            accent_color: Color::Yellow,
69        }
70    }
71}
72
/// Which part of the [TagPicker] currently receives user input.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TagPickerFocus {
    /// The search input and its list of matching tags.
    Input,
    /// The row of already selected tags.
    SelectedTags,
}
78
79impl TagPicker {
80    /// Create a widget with a set of tags to pick from.
81    pub fn new<I, S>(available_tags: I) -> Self
82    where
83        I: IntoIterator<Item = S>,
84        S: Into<String>,
85    {
86        Self::with_config(available_tags, TagPickerConfig::default())
87    }
88
89    /// Create a picker with additional configuration.
90    pub fn with_config<I, S>(available_tags: I, config: TagPickerConfig) -> Self
91    where
92        I: IntoIterator<Item = S>,
93        S: Into<String>,
94    {
95        let available_tags = available_tags
96            .into_iter()
97            .map(Into::into)
98            .collect::<BTreeSet<_>>()
99            .into_iter()
100            .collect();
101        Self {
102            available_tags,
103            input_height: config.input_height.max(MIN_INPUT_HEIGHT),
104            accent_color: config.accent_color,
105        }
106    }
107
108    fn tag(&self, index: usize) -> Option<&str> {
109        self.available_tags.get(index).map(String::as_str)
110    }
111
112    fn matched_tag_indices(&self, state: &TagPickerState) -> Vec<usize> {
113        let selected_indices = state
114            .selected_indices
115            .iter()
116            .copied()
117            .collect::<BTreeSet<_>>();
118        let mut matches = self
119            .available_tags
120            .iter()
121            .enumerate()
122            .filter(|(index, _)| !selected_indices.contains(index))
123            .filter_map(|(index, tag)| fuzzy_score(&state.input, tag).map(|score| (score, index)))
124            .collect::<Vec<_>>();
125
126        matches.sort_unstable_by(|(score_a, index_a), (score_b, index_b)| {
127            let tag_a = &self.available_tags[*index_a];
128            let tag_b = &self.available_tags[*index_b];
129            score_b
130                .cmp(score_a)
131                .then_with(|| tag_a.to_lowercase().cmp(&tag_b.to_lowercase()))
132        });
133
134        matches.into_iter().map(|(_, index)| index).collect()
135    }
136
137    fn retain_selected_indices(&self, state: &mut TagPickerState) {
138        let mut seen = BTreeSet::new();
139        state
140            .selected_indices
141            .retain(|index| *index < self.available_tags.len() && seen.insert(*index));
142    }
143
144    fn sync_render_state(&self, state: &mut TagPickerState) {
145        self.retain_selected_indices(state);
146
147        let match_count = self.matched_tag_indices(state).len();
148        state.match_cursor = if match_count == 0 {
149            0
150        } else {
151            state.match_cursor.min(match_count - 1)
152        };
153
154        state.selected_cursor = if state.selected_indices.is_empty() {
155            state.selected_scroll_x = 0;
156            0
157        } else {
158            state.selected_cursor.min(state.selected_indices.len() - 1)
159        };
160    }
161
162    fn valid_selected_raw_positions(&self, state: &TagPickerState) -> Vec<usize> {
163        state
164            .selected_indices
165            .iter()
166            .enumerate()
167            .filter_map(|(position, &tag_index)| self.tag(tag_index).map(|_| position))
168            .collect()
169    }
170
171    fn render_input_area(&self, state: &TagPickerState, area: Rect, buf: &mut Buffer) {
172        let block = Block::default().borders(Borders::BOTTOM);
173        let inner = block.inner(area);
174        block.render(area, buf);
175
176        let matches = self.matched_tag_indices(state);
177        let mut lines = vec![Line::from(vec![
178            Span::styled("> ", Style::new().fg(self.accent_color)),
179            Span::raw(if state.input.is_empty() {
180                "<type to search>".to_string()
181            } else {
182                state.input.clone()
183            }),
184        ])];
185
186        let preview_limit = inner.height.saturating_sub(1) as usize;
187        if preview_limit > 0 {
188            if matches.is_empty() {
189                lines.push(Line::from(Span::styled(
190                    "No matching tags",
191                    Style::new().fg(Color::DarkGray),
192                )));
193            } else {
194                for row in visible_match_rows(matches.len(), state.match_cursor, preview_limit) {
195                    match row {
196                        MatchRow::EllipsisBelow => {
197                            lines.push(Line::from(Span::styled(
198                                "...",
199                                Style::new().fg(Color::DarkGray),
200                            )));
201                        }
202                        MatchRow::Item(index) => {
203                            let Some(tag) = self.tag(matches[index]) else {
204                                continue;
205                            };
206                            let style = if index == state.match_cursor
207                                && state.focus == TagPickerFocus::Input
208                            {
209                                Style::new().fg(Color::Black).bg(self.accent_color)
210                            } else {
211                                Style::new().fg(Color::DarkGray)
212                            };
213                            lines.push(Line::from(Span::styled(format!("{tag}"), style)));
214                        }
215                    }
216                }
217            }
218        }
219
220        Paragraph::new(lines)
221            .wrap(Wrap { trim: false })
222            .render(inner, buf);
223    }
224
225    fn render_selected_area(&self, state: &mut TagPickerState, area: Rect, buf: &mut Buffer) {
226        let line = if state.selected_indices.is_empty() {
227            Line::from(Span::styled(
228                "No tags selected",
229                Style::new().fg(Color::DarkGray),
230            ))
231        } else {
232            let mut spans = Vec::new();
233            let mut selected_bounds = None;
234            let mut line_width = 0;
235
236            for (index, tag_index) in state.selected_indices.iter().copied().enumerate() {
237                let Some(tag) = self.tag(tag_index) else {
238                    continue;
239                };
240
241                if index > 0 {
242                    spans.push(Span::raw(SELECTED_SEPARATOR));
243                    line_width += SELECTED_SEPARATOR.chars().count();
244                }
245
246                let is_selected = index == state.selected_cursor;
247                if is_selected {
248                    let text = format!("{tag}");
249                    let separator_width = SELECTED_SEPARATOR.chars().count();
250                    let start = if index > 0 {
251                        line_width.saturating_sub(separator_width)
252                    } else {
253                        line_width
254                    };
255                    let mut end = line_width + text.chars().count();
256                    if index + 1 < state.selected_indices.len() {
257                        end += separator_width;
258                    }
259                    selected_bounds = Some((start, end));
260                    let style = if state.focus == TagPickerFocus::SelectedTags {
261                        Style::new().fg(Color::Black).bg(self.accent_color)
262                    } else {
263                        Style::new().fg(Color::default())
264                    };
265                    line_width += text.chars().count();
266                    spans.push(Span::styled(text, style));
267                } else {
268                    line_width += tag.chars().count();
269                    spans.push(Span::raw(tag));
270                }
271            }
272
273            sync_scroll_to_visible(
274                &mut state.selected_scroll_x,
275                area.width as usize,
276                line_width,
277                selected_bounds,
278            );
279            Line::from(spans)
280        };
281
282        Paragraph::new(vec![line])
283            .scroll((0, state.selected_scroll_x.min(u16::MAX as usize) as u16))
284            .render(area, buf);
285    }
286}
287
288impl TagPickerState {
289    /// Create the initial state.
290    pub fn new() -> Self {
291        Self::default()
292    }
293
294    /// Create state with tags preselected.
295    ///
296    /// This requires the widget because it is the only holder of the tag data.
297    pub fn new_with_selected_tags<I, S>(picker: &TagPicker, selected_tags: I) -> Self
298    where
299        I: IntoIterator<Item = S>,
300        S: Into<String>,
301    {
302        let mut state = Self::new();
303        state.set_selected_tags(picker, selected_tags);
304        state
305    }
306
307    /// Get indices of selected tags.
308    pub fn selected_indices(&self) -> &[usize] {
309        &self.selected_indices
310    }
311
312    /// Use the widget's data to convert selected indices to actual tags.
313    pub fn selected_tags<'a>(
314        &'a self,
315        picker: &'a TagPicker,
316    ) -> impl Iterator<Item = &'a str> + 'a {
317        self.selected_indices
318            .iter()
319            .filter_map(|&index| picker.tag(index))
320    }
321
322    /// Cycle focus between the tag search input and selected tags.
323    pub fn cycle_focus(&mut self) {
324        self.focus = match self.focus {
325            TagPickerFocus::Input => TagPickerFocus::SelectedTags,
326            TagPickerFocus::SelectedTags => TagPickerFocus::Input,
327        };
328    }
329
330    /// Focus widget's input area.
331    pub fn focus_input(&mut self) {
332        self.focus = TagPickerFocus::Input;
333    }
334
335    /// Focus widget's selected tags area.
336    pub fn focus_selected_tags(&mut self) {
337        self.focus = TagPickerFocus::SelectedTags;
338    }
339
340    /// Clears query text.
341    pub fn clear_input(&mut self) {
342        self.input.clear();
343        self.match_cursor = 0;
344    }
345
346    /// Insert a character into the tag search input.
347    pub fn insert_char(&mut self, ch: char) {
348        if self.focus != TagPickerFocus::Input || ch.is_control() {
349            return;
350        }
351
352        self.input.push(ch);
353        self.match_cursor = 0;
354    }
355
356    /// Delete a character from the search input.
357    pub fn backspace(&mut self) {
358        if self.focus != TagPickerFocus::Input {
359            return;
360        }
361
362        self.input.pop();
363        self.match_cursor = 0;
364    }
365
366    /// Move to next candidate:
367    /// - Focused input will scroll the matching tags list
368    /// - Focused selected tags area will move tag selection
369    pub fn move_next(&mut self, picker: &TagPicker) {
370        match self.focus {
371            TagPickerFocus::Input => {
372                let match_count = picker.matched_tag_indices(self).len();
373                if match_count > 0 {
374                    self.match_cursor = (self.match_cursor + 1) % match_count;
375                }
376            }
377            TagPickerFocus::SelectedTags => {
378                let selected_count = picker.valid_selected_raw_positions(self).len();
379                if selected_count > 0 {
380                    let cursor = self.selected_cursor.min(selected_count - 1);
381                    self.selected_cursor = (cursor + 1) % selected_count;
382                }
383            }
384        }
385    }
386
387    /// Move to previous candidate:
388    /// - Focused input will scroll the matching tags list
389    /// - Focused selected tags area will move tag selection
390    pub fn move_previous(&mut self, picker: &TagPicker) {
391        match self.focus {
392            TagPickerFocus::Input => {
393                let match_count = picker.matched_tag_indices(self).len();
394                if match_count > 0 {
395                    self.match_cursor = if self.match_cursor == 0 {
396                        match_count - 1
397                    } else {
398                        self.match_cursor - 1
399                    };
400                }
401            }
402            TagPickerFocus::SelectedTags => {
403                let selected_count = picker.valid_selected_raw_positions(self).len();
404                if selected_count > 0 {
405                    let cursor = self.selected_cursor.min(selected_count - 1);
406                    self.selected_cursor = if cursor == 0 {
407                        selected_count - 1
408                    } else {
409                        cursor - 1
410                    };
411                }
412            }
413        }
414    }
415
416    /// Add the highlighted tag to the result set.
417    pub fn confirm(&mut self, picker: &TagPicker) {
418        if self.focus != TagPickerFocus::Input {
419            return;
420        }
421
422        let matches = picker.matched_tag_indices(self);
423        let Some(selected_index) = matches.get(self.match_cursor).copied() else {
424            return;
425        };
426        if self.selected_indices.contains(&selected_index) {
427            return;
428        }
429
430        self.selected_indices.push(selected_index);
431        self.selected_cursor = self.selected_indices.len().saturating_sub(1);
432        self.input.clear();
433        self.match_cursor = 0;
434    }
435
436    /// Remove the highlighted tag from the result set.
437    pub fn remove_selected_tag(&mut self, picker: &TagPicker) {
438        if self.focus != TagPickerFocus::SelectedTags || self.selected_indices.is_empty() {
439            return;
440        }
441
442        let valid_positions = picker.valid_selected_raw_positions(self);
443        let Some(raw_index) = valid_positions
444            .get(
445                self.selected_cursor
446                    .min(valid_positions.len().saturating_sub(1)),
447            )
448            .copied()
449        else {
450            return;
451        };
452
453        self.selected_indices.remove(raw_index);
454        let remaining_count = valid_positions.len().saturating_sub(1);
455        if remaining_count == 0 {
456            self.selected_cursor = 0;
457            self.selected_scroll_x = 0;
458        } else {
459            self.selected_cursor = self.selected_cursor.min(remaining_count - 1);
460        }
461    }
462
463    fn set_selected_tags<I, S>(&mut self, picker: &TagPicker, selected_tags: I)
464    where
465        I: IntoIterator<Item = S>,
466        S: Into<String>,
467    {
468        self.selected_indices.clear();
469        let mut seen = BTreeSet::new();
470
471        for tag in selected_tags {
472            let tag = tag.into();
473            let Some(index) = picker
474                .available_tags
475                .iter()
476                .position(|candidate| candidate == &tag)
477            else {
478                continue;
479            };
480
481            if seen.insert(index) {
482                self.selected_indices.push(index);
483            }
484        }
485
486        self.selected_cursor = 0;
487        self.selected_scroll_x = 0;
488    }
489}
490
491impl StatefulWidget for &TagPicker {
492    type State = TagPickerState;
493
494    fn render(self, area: Rect, buf: &mut Buffer, state: &mut Self::State) {
495        self.sync_render_state(state);
496
497        let outer = Block::default().borders(Borders::ALL).title("Tags");
498        let inner = outer.inner(area);
499        outer.render(area, buf);
500
501        let sections = Layout::vertical([
502            Constraint::Length(self.input_height.saturating_add(2)),
503            Constraint::Length(3),
504        ])
505        .split(inner);
506        self.render_input_area(state, sections[0], buf);
507        self.render_selected_area(state, sections[1], buf);
508    }
509}
510
/// Adjust a horizontal scroll offset so the selected span stays on screen.
///
/// - `scroll_x`: current offset, updated in place.
/// - `viewport_width`: visible width; `0` resets the offset.
/// - `content_width`: total width of the scrolled line.
/// - `selected_bounds`: half-open `(start, end)` range to keep visible, if any.
///
/// Content that fits resets the offset to `0`. Otherwise the offset is
/// clamped to the maximum and then moved minimally to reveal the selection.
/// A selection wider than the viewport is aligned to its start.
fn sync_scroll_to_visible(
    scroll_x: &mut usize,
    viewport_width: usize,
    content_width: usize,
    selected_bounds: Option<(usize, usize)>,
) {
    let nothing_to_scroll = viewport_width == 0 || content_width <= viewport_width;
    if nothing_to_scroll {
        *scroll_x = 0;
        return;
    }

    let max_scroll = content_width - viewport_width;
    let clamped = (*scroll_x).min(max_scroll);

    *scroll_x = match selected_bounds {
        None => clamped,
        Some((start, end)) if end.saturating_sub(start) >= viewport_width => {
            start.min(max_scroll)
        }
        Some((_, end)) if end > clamped + viewport_width => (end - viewport_width).min(max_scroll),
        Some((start, _)) if start < clamped => start,
        Some(_) => clamped,
    };
}
537
/// Score how well `query` fuzzy-matches `candidate`, case-insensitively.
///
/// Every query character must appear in the candidate, in order; otherwise
/// `None` is returned. An empty query matches everything with score `0`.
///
/// Per matched query character the score gains:
/// - `+10` for the match itself,
/// - `-1` for every candidate character skipped since the previous match,
/// - `+8` when the match is the first character of the candidate,
/// - `+6` when it directly follows the previous match (no characters skipped).
///
/// Distances are counted in characters, not UTF-8 bytes, so non-ASCII tags
/// are scored the same way as ASCII ones.
fn fuzzy_score(query: &str, candidate: &str) -> Option<i64> {
    if query.is_empty() {
        return Some(0);
    }

    let query = query.to_lowercase();
    let candidate = candidate.to_lowercase();

    let mut score = 0_i64;
    // Byte offset in `candidate` where the search for the next query char starts.
    let mut search_from = 0_usize;
    let mut is_first = true;

    for query_char in query.chars() {
        let rest = candidate.get(search_from..)?;
        let offset = rest.find(query_char)?;
        // `offset` is in bytes; count the skipped characters instead.
        let skipped = rest[..offset].chars().count();
        let match_index = search_from + offset;

        score += 10;
        score -= skipped as i64;

        if match_index == 0 {
            score += 8;
        }

        // Contiguous with the previous match, regardless of its UTF-8 length.
        if !is_first && skipped == 0 {
            score += 6;
        }

        is_first = false;
        search_from = match_index + query_char.len_utf8();
    }

    Some(score)
}
574
/// One row of the match preview, as produced by `visible_match_rows`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MatchRow {
    /// A match, identified by its position in the ordered match list.
    Item(usize),
    /// Marker shown when further matches exist below the visible window.
    EllipsisBelow,
}
580
581fn visible_match_rows(match_count: usize, cursor: usize, max_rows: usize) -> Vec<MatchRow> {
582    if match_count == 0 || max_rows == 0 {
583        return Vec::new();
584    }
585
586    let cursor = cursor.min(match_count - 1);
587
588    if match_count <= max_rows {
589        return (0..match_count).map(MatchRow::Item).collect();
590    }
591
592    if max_rows == 1 {
593        return vec![MatchRow::EllipsisBelow];
594    }
595
596    let visible_items = max_rows - 1;
597    let mut start = cursor.saturating_sub(visible_items.saturating_sub(1));
598    let mut end = start + visible_items;
599
600    if end >= match_count {
601        end = match_count;
602        start = end.saturating_sub(max_rows);
603    }
604
605    let mut rows = (start..end).map(MatchRow::Item).collect::<Vec<_>>();
606    if end < match_count {
607        rows.push(MatchRow::EllipsisBelow);
608    }
609    rows
610}
611
#[cfg(test)]
mod tests {
    use ratatui::{buffer::Buffer, layout::Rect, widgets::StatefulWidget};

    use super::{
        MatchRow, TagPicker, TagPickerFocus, TagPickerState, fuzzy_score, sync_scroll_to_visible,
        visible_match_rows,
    };
    use crate::TagPickerConfig;

    // Prefix matches and runs of consecutive characters must outrank scattered ones.
    #[test]
    fn fuzzy_score_prefers_prefix_and_contiguous_matches() {
        assert!(fuzzy_score("rs", "rust").unwrap() > fuzzy_score("rs", "crates").unwrap());
        assert!(fuzzy_score("tag", "tags").unwrap() > fuzzy_score("tag", "meta graph").unwrap());
    }

    // Typing narrows the candidates; confirming selects the highlighted match
    // and clears the query.
    #[test]
    fn input_focus_filters_and_confirms_a_match() {
        let picker = TagPicker::new(["rust", "ratatui", "ruby"]);
        let mut state = TagPickerState::new();

        state.insert_char('r');
        state.insert_char('a');

        assert_eq!(state.focus, TagPickerFocus::Input);
        assert_eq!(
            picker
                .matched_tag_indices(&state)
                .into_iter()
                .filter_map(|index| picker.tag(index))
                .collect::<Vec<_>>(),
            vec!["ratatui"]
        );

        state.confirm(&picker);
        // Tags are stored sorted, so "ratatui" has index 0.
        assert_eq!(state.selected_indices(), &[0]);
        assert_eq!(
            state.selected_tags(&picker).collect::<Vec<_>>(),
            vec!["ratatui"]
        );
        assert_eq!(state.input, "");
    }

    // Preselected tags come back in the order they were given, not sorted order.
    #[test]
    fn state_with_selected_tags_applies_selection() {
        let picker = TagPicker::new(["rust", "ratatui", "ruby"]);
        let state = TagPickerState::new_with_selected_tags(&picker, ["rust", "ratatui"]);

        assert_eq!(
            state.selected_tags(&picker).collect::<Vec<_>>(),
            vec!["rust", "ratatui"]
        );
    }

    // After switching focus, moving to the second selected tag and removing it
    // leaves only the first one.
    #[test]
    fn cycling_focus_and_removing_selected_tag_works() {
        let picker = TagPicker::new(["rust", "ratatui", "ruby"]);
        let mut state = TagPickerState::new_with_selected_tags(&picker, ["rust", "ratatui"]);

        state.cycle_focus();
        state.move_next(&picker);

        assert_eq!(state.focus, TagPickerFocus::SelectedTags);
        state.remove_selected_tag(&picker);
        assert_eq!(
            state.selected_tags(&picker).collect::<Vec<_>>(),
            vec!["rust"]
        );
    }

    // Text editing is ignored while the selected-tags area has focus.
    #[test]
    fn input_methods_do_nothing_when_selected_tags_are_focused() {
        let mut state = TagPickerState::new();

        state.cycle_focus();
        state.insert_char('r');
        state.backspace();

        assert_eq!(state.focus, TagPickerFocus::SelectedTags);
        assert_eq!(state.input, "");
    }

    // Heights below MIN_INPUT_HEIGHT are raised to it.
    #[test]
    fn constructor_clamps_input_height() {
        let picker = TagPicker::with_config(
            ["rust"],
            TagPickerConfig {
                input_height: 1,
                ..Default::default()
            },
        );

        assert_eq!(picker.input_height, 2);
    }

    // With 8 matches and 4 rows, the window ends at the cursor and an
    // ellipsis occupies the last row.
    #[test]
    fn overflowing_matches_show_ellipsis_and_scroll_with_cursor() {
        let rows = visible_match_rows(8, 4, 4);

        assert_eq!(
            rows,
            vec![
                MatchRow::Item(2),
                MatchRow::Item(3),
                MatchRow::Item(4),
                MatchRow::EllipsisBelow,
            ]
        );
    }

    // End-to-end render: the highlighted match is drawn together with the ellipsis.
    #[test]
    fn rendering_overflowing_matches_shows_ellipsis() {
        let picker = TagPicker::with_config(
            [
                "tag-0", "tag-1", "tag-2", "tag-3", "tag-4", "tag-5", "tag-6", "tag-7",
            ],
            TagPickerConfig {
                input_height: 4,
                ..Default::default()
            },
        );
        let mut state = TagPickerState::new();
        state.insert_char('t');
        state.move_next(&picker);
        state.move_next(&picker);
        state.move_next(&picker);
        state.move_next(&picker);
        let mut buffer = Buffer::empty(Rect::new(0, 0, 40, 12));

        (&picker).render(buffer.area, &mut buffer, &mut state);

        let rendered = buffer
            .content
            .iter()
            .map(|cell| cell.symbol())
            .collect::<String>();

        assert!(rendered.contains("..."));
        assert!(rendered.contains("tag-4"));
    }

    // Selecting the last tag in a narrow area must scroll the row to the right.
    #[test]
    fn selected_tags_scroll_horizontally_to_keep_selection_visible() {
        let picker = TagPicker::new(["alpha", "beta", "gamma"]);
        let mut state = TagPickerState::new_with_selected_tags(&picker, ["alpha", "beta", "gamma"]);
        state.cycle_focus();
        state.move_next(&picker);
        state.move_next(&picker);
        let mut buffer = Buffer::empty(Rect::new(0, 0, 20, 10));

        (&picker).render(buffer.area, &mut buffer, &mut state);

        assert!(state.selected_scroll_x > 0);
    }

    // A freshly confirmed tag becomes the highlighted selected tag.
    #[test]
    fn confirming_selects_the_newly_added_tag() {
        let picker = TagPicker::new(["alpha", "beta", "gamma"]);
        let mut state = TagPickerState::new_with_selected_tags(&picker, ["alpha"]);

        state.insert_char('g');
        state.confirm(&picker);

        assert_eq!(
            state.selected_tags(&picker).nth(state.selected_cursor),
            Some("gamma")
        );
    }

    // A selection ending past the viewport scrolls just far enough to show its end.
    #[test]
    fn scroll_sync_keeps_selection_in_view() {
        let mut scroll_x = 0;

        sync_scroll_to_visible(&mut scroll_x, 8, 22, Some((15, 22)));

        assert_eq!(scroll_x, 14);
    }
}