// demand/multiselect.rs

1use std::collections::HashSet;
2use std::io;
3use std::io::Write;
4
5use console::{Alignment, Key, Term};
6use fuzzy_matcher::FuzzyMatcher;
7use fuzzy_matcher::skim::SkimMatcherV2;
8use itertools::Itertools;
9use termcolor::{Buffer, WriteColor};
10
11use crate::theme::Theme;
12use crate::{DemandOption, ctrlc, theme};
13
14/// Select multiple options from a list
15///
16/// # Example
17/// ```rust
18/// use demand::{DemandOption, MultiSelect};
19///
20/// let multiselect = MultiSelect::new("Toppings")
21///   .description("Select your toppings")
22///   .min(1)
23///   .max(4)
24///   .filterable(true)
25///   .option(DemandOption::new("Lettuce").selected(true))
26///   .option(DemandOption::new("Tomatoes").selected(true))
27///   .option(DemandOption::new("Charm Sauce"))
28///   .option(DemandOption::new("Jalapenos").label("Jalapeños"))
29///   .option(DemandOption::new("Cheese"))
30///   .option(DemandOption::new("Vegan Cheese"))
31///   .option(DemandOption::new("Nutella"));
32/// let toppings = match multiselect.run() {
33///   Ok(toppings) => toppings,
34///   Err(e) => {
35///       if e.kind() == std::io::ErrorKind::Interrupted {
36///           println!("Input cancelled");
37///           return;
38///       } else {
39///           panic!("Error: {}", e);
40///       }
41///   }
42/// };
43/// ```
44pub struct MultiSelect<'a, T> {
45    /// The title of the selector
46    pub title: String,
47    /// The colors/style of the selector
48    pub theme: &'a Theme,
49    /// A description to display after the title
50    pub description: String,
51    /// The options which can be selected
52    pub options: Vec<DemandOption<T>>,
53    /// The minimum number of options which must be selected
54    pub min: usize,
55    /// The maximum number of options which can be selected
56    pub max: usize,
57    /// Whether the selector can be filtered with a query
58    pub filterable: bool,
59    /// Whether the selector is currently being filtered
60    pub filtering: bool,
61    /// A filter query to preset when `filtering` is true
62    pub filter: String,
63
64    err: Option<String>,
65    cursor_x: usize,
66    cursor_y: usize,
67    cursor: usize,
68    height: usize,
69    term: Term,
70    pages: usize,
71    cur_page: usize,
72    capacity: usize,
73    fuzzy_matcher: SkimMatcherV2,
74}
75
76impl<'a, T> MultiSelect<'a, T> {
77    /// Create a new multi select with the given title
78    pub fn new<S: Into<String>>(title: S) -> Self {
79        let mut ms = MultiSelect {
80            title: title.into(),
81            description: String::new(),
82            options: vec![],
83            min: 0,
84            max: usize::MAX,
85            filterable: false,
86            theme: &theme::DEFAULT,
87            cursor_x: 0,
88            cursor_y: 0,
89            err: None,
90            cursor: 0,
91            height: 0,
92            term: Term::stderr(),
93            filter: String::new(),
94            filtering: false,
95            pages: 0,
96            cur_page: 0,
97            capacity: 0,
98            fuzzy_matcher: SkimMatcherV2::default().use_cache(true).smart_case(),
99        };
100        let max_height = ms.term.size().0 as usize;
101        ms.capacity = max_height.max(8) - 6;
102        ms
103    }
104
105    /// Set the description of the selector
106    pub fn description(mut self, description: &str) -> Self {
107        self.description = description.to_string();
108        self
109    }
110
111    /// Add an option to the selector
112    pub fn option(mut self, option: DemandOption<T>) -> Self {
113        self.options.push(option);
114        self.pages = self.get_pages();
115        self
116    }
117
118    /// Add multiple options to the selector
119    pub fn options(mut self, options: Vec<DemandOption<T>>) -> Self {
120        for option in options {
121            self.options.push(option);
122        }
123        self.pages = self.get_pages();
124        self
125    }
126
127    /// Set the minimum number of options which must be selected
128    pub fn min(mut self, min: usize) -> Self {
129        self.min = min;
130        self
131    }
132
133    /// Set the maximum number of options which can be selected
134    pub fn max(mut self, max: usize) -> Self {
135        self.max = max;
136        self
137    }
138
139    /// Set whether the selector can be filtered with a query
140    pub fn filterable(mut self, filterable: bool) -> Self {
141        self.filterable = filterable;
142        self
143    }
144
145    pub fn filtering(mut self, filtering: bool) -> Self {
146        self.filtering = filtering;
147        self
148    }
149
150    pub fn filter(mut self, filter: &str) -> Self {
151        self.filter = filter.to_string();
152        self.cursor_x = self.filter.chars().count();
153        self.pages = self.get_pages();
154        self
155    }
156
157    /// Set the theme of the selector
158    pub fn theme(mut self, theme: &'a Theme) -> Self {
159        self.theme = theme;
160        self
161    }
162
163    /// Displays the selector to the user and returns their selected options
164    ///
165    /// This function will block until the user submits the input. If the user cancels the input,
166    /// an error of type `io::ErrorKind::Interrupted` is returned.
167    pub fn run(mut self) -> io::Result<Vec<T>> {
168        let ctrlc_handle = ctrlc::show_cursor_after_ctrlc(&self.term)?;
169
170        self.max = self.max.min(self.options.len());
171        self.min = self.min.min(self.max);
172
173        loop {
174            self.clear()?;
175            let output = self.render()?;
176            self.term.write_all(output.as_bytes())?;
177            self.term.flush()?;
178            self.height = output.lines().count() - 1;
179            if self.filtering {
180                match self.term.read_key()? {
181                    Key::ArrowLeft => self.handle_left()?,
182                    Key::ArrowRight => self.handle_right()?,
183                    Key::Enter => self.handle_stop_filtering(true)?,
184                    Key::Escape => self.handle_stop_filtering(false)?,
185                    Key::Backspace => self.handle_filter_backspace()?,
186                    Key::Char(c) => self.handle_filter_key(c)?,
187                    _ => {}
188                }
189            } else {
190                self.term.hide_cursor()?;
191                match self.term.read_key()? {
192                    Key::ArrowDown | Key::Char('j') => self.handle_down()?,
193                    Key::ArrowUp | Key::Char('k') => self.handle_up()?,
194                    Key::ArrowLeft | Key::Char('h') => self.handle_left()?,
195                    Key::ArrowRight | Key::Char('l') => self.handle_right()?,
196                    Key::Char('x') | Key::Char(' ') => self.handle_toggle(),
197                    Key::Char('a') => self.handle_toggle_all(),
198                    Key::Char('/') if self.filterable => self.handle_start_filtering(),
199                    Key::Escape => {
200                        if self.filter.is_empty() {
201                            self.term.show_cursor()?;
202                            ctrlc_handle.close();
203                            return Err(io::Error::new(
204                                io::ErrorKind::Interrupted,
205                                "user cancelled",
206                            ));
207                        }
208                        self.handle_stop_filtering(false)?
209                    }
210                    Key::Enter => {
211                        let selected = self
212                            .options
213                            .iter()
214                            .filter(|o| o.selected)
215                            .map(|o| o.label.to_string())
216                            .collect::<Vec<_>>();
217                        if selected.len() < self.min {
218                            if self.min == 1 {
219                                self.err = Some("Please select an option".to_string());
220                            } else {
221                                self.err =
222                                    Some(format!("Please select at least {} options", self.min));
223                            }
224                            continue;
225                        }
226                        if selected.len() > self.max {
227                            if self.max == 1 {
228                                self.err = Some("Please select only one option".to_string());
229                            } else {
230                                self.err =
231                                    Some(format!("Please select at most {} options", self.max));
232                            }
233                            continue;
234                        }
235                        self.clear()?;
236                        self.term.show_cursor()?;
237                        ctrlc_handle.close();
238                        let output = self.render_success(&selected)?;
239                        self.term.write_all(output.as_bytes())?;
240                        let selected = self
241                            .options
242                            .into_iter()
243                            .filter(|o| o.selected)
244                            .map(|o| o.item)
245                            .collect::<Vec<_>>();
246                        self.term.clear_to_end_of_screen()?;
247                        return Ok(selected);
248                    }
249                    _ => {}
250                }
251            }
252        }
253    }
254
255    fn filtered_options(&self) -> Vec<&DemandOption<T>> {
256        self.options
257            .iter()
258            .filter_map(|opt| {
259                if self.filter.is_empty() {
260                    Some((0, opt))
261                } else {
262                    self.fuzzy_matcher
263                        .fuzzy_match(&opt.label.to_lowercase(), &self.filter.to_lowercase())
264                        .map(|score| (score, opt))
265                }
266            })
267            .sorted_by_key(|(score, _opt)| -1 * *score)
268            .map(|(_score, opt)| opt)
269            .collect()
270    }
271
272    fn visible_options(&self) -> Vec<&DemandOption<T>> {
273        let filtered_options = self.filtered_options();
274        let start = self.cur_page * self.capacity;
275        filtered_options
276            .into_iter()
277            .skip(start)
278            .take(self.capacity)
279            .collect()
280    }
281
282    fn handle_down(&mut self) -> Result<(), io::Error> {
283        let visible_options = self.visible_options();
284        if self.cursor < visible_options.len().max(1) - 1 {
285            self.cursor += 1;
286        } else if self.pages > 0 && self.cur_page < self.pages - 1 {
287            self.cur_page += 1;
288            self.cursor = 0;
289            self.term.clear_to_end_of_screen()?;
290        }
291        Ok(())
292    }
293
294    fn handle_up(&mut self) -> Result<(), io::Error> {
295        if self.cursor > 0 {
296            self.cursor -= 1;
297        } else if self.cur_page > 0 {
298            self.cur_page -= 1;
299            self.cursor = self.visible_options().len().max(1) - 1;
300            self.term.clear_to_end_of_screen()?;
301        }
302        Ok(())
303    }
304
305    fn handle_left(&mut self) -> Result<(), io::Error> {
306        if self.filtering {
307            if self.cursor_x > 0 {
308                self.cursor_x -= 1;
309            }
310        } else if self.cur_page > 0 {
311            self.cur_page -= 1;
312            self.term.clear_to_end_of_screen()?;
313        }
314        Ok(())
315    }
316
317    fn handle_right(&mut self) -> Result<(), io::Error> {
318        if self.filtering {
319            if self.cursor_x < self.filter.chars().count() {
320                self.cursor_x += 1;
321            }
322        } else if self.pages > 0 && self.cur_page < self.pages - 1 {
323            self.cur_page += 1;
324            if self.cursor_y > self.visible_options().len() - 1 {
325                self.cursor_y = self.visible_options().len() - 1;
326            }
327            self.term.clear_to_end_of_screen()?;
328        }
329        Ok(())
330    }
331
332    fn handle_toggle(&mut self) {
333        self.err = None;
334        let visible_options = self.visible_options();
335        if visible_options.is_empty() {
336            return;
337        }
338        let id = visible_options[self.cursor].id;
339        let selected = visible_options[self.cursor].selected;
340        self.options
341            .iter_mut()
342            .find(|o| o.id == id)
343            .unwrap()
344            .selected = !selected;
345    }
346
347    fn handle_toggle_all(&mut self) {
348        self.err = None;
349        let filtered_options = self.filtered_options();
350        if filtered_options.is_empty() {
351            return;
352        }
353        let select = !filtered_options.iter().all(|o| o.selected);
354        let ids = filtered_options
355            .into_iter()
356            .map(|o| o.id)
357            .collect::<HashSet<_>>();
358        for opt in &mut self.options {
359            if ids.contains(&opt.id) {
360                opt.selected = select;
361            }
362        }
363    }
364
365    fn handle_start_filtering(&mut self) {
366        self.err = None;
367        self.filtering = true;
368    }
369
370    fn handle_stop_filtering(&mut self, save: bool) -> Result<(), io::Error> {
371        self.filtering = false;
372
373        let visible_options = self.visible_options();
374        if !visible_options.is_empty() {
375            self.cursor = self.cursor.min(self.visible_options().len() - 1);
376        }
377        if !save {
378            self.filter.clear();
379            self.reset_paging();
380        }
381        self.term.clear_to_end_of_screen()
382    }
383
384    fn handle_filter_key(&mut self, c: char) -> Result<(), io::Error> {
385        let idx = self.get_char_idx(&self.filter, self.cursor_x);
386        self.filter.insert(idx, c);
387        self.cursor_x += 1;
388        self.cursor_y = 0;
389        self.err = None;
390        self.reset_paging();
391        self.term.clear_to_end_of_screen()
392    }
393
394    fn handle_filter_backspace(&mut self) -> Result<(), io::Error> {
395        let chars_count = self.filter.chars().count();
396        if chars_count > 0 && self.cursor_x > 0 {
397            let idx = self.get_char_idx(&self.filter, self.cursor_x - 1);
398            self.filter.remove(idx);
399        }
400        if self.cursor_x > 0 {
401            self.cursor_x -= 1;
402        }
403        self.cursor_y = 0;
404        self.err = None;
405        self.reset_paging();
406        self.term.clear_to_end_of_screen()
407    }
408
409    fn reset_paging(&mut self) {
410        self.cur_page = 0;
411        self.pages = self.get_pages();
412    }
413
414    fn get_pages(&self) -> usize {
415        if self.filtering || !self.filter.is_empty() {
416            ((self.filtered_options().len() as f64) / self.capacity as f64).ceil() as usize
417        } else {
418            ((self.options.len() as f64) / self.capacity as f64).ceil() as usize
419        }
420    }
421
422    fn render(&self) -> io::Result<String> {
423        let mut out = Buffer::ansi();
424
425        out.set_color(&self.theme.title)?;
426        write!(out, "{}", self.title)?;
427
428        if self.err.is_some() {
429            out.set_color(&self.theme.error_indicator)?;
430            writeln!(out, " *")?;
431        } else {
432            writeln!(out)?;
433        }
434        if !self.description.is_empty() || self.pages > 1 {
435            out.set_color(&self.theme.description)?;
436            write!(out, "{}", self.description)?;
437            writeln!(out)?;
438        }
439        let max_label_len = self
440            .visible_options()
441            .iter()
442            .map(|o| console::measure_text_width(&o.label))
443            .max()
444            .unwrap_or(0);
445        for (i, option) in self.visible_options().into_iter().enumerate() {
446            if self.cursor == i {
447                out.set_color(&self.theme.cursor)?;
448                write!(out, " >")?;
449            } else {
450                write!(out, "  ")?;
451            }
452            if option.selected {
453                out.set_color(&self.theme.selected_prefix_fg)?;
454                write!(out, "{}", self.theme.selected_prefix)?;
455                out.set_color(&self.theme.selected_option)?;
456                self.print_option_label(&mut out, option, max_label_len)?;
457            } else {
458                out.set_color(&self.theme.unselected_prefix_fg)?;
459                write!(out, "{}", self.theme.unselected_prefix)?;
460                out.set_color(&self.theme.unselected_option)?;
461                self.print_option_label(&mut out, option, max_label_len)?;
462            }
463        }
464        if self.pages > 1 {
465            out.set_color(&self.theme.description)?;
466            writeln!(out, " (page {}/{})", self.cur_page + 1, self.pages)?;
467        }
468
469        if self.filtering {
470            out.set_color(&self.theme.input_cursor)?;
471
472            write!(out, "/")?;
473            out.reset()?;
474
475            let cursor_idx = self.get_char_idx(&self.filter, self.cursor_x);
476            write!(out, "{}", &self.filter[..cursor_idx])?;
477
478            if cursor_idx < self.filter.len() {
479                out.set_color(&self.theme.real_cursor_color(None))?;
480                write!(out, "{}", &self.filter[cursor_idx..cursor_idx + 1])?;
481                out.reset()?;
482            }
483            if cursor_idx + 1 < self.filter.len() {
484                out.reset()?;
485                write!(out, "{}", &self.filter[cursor_idx + 1..])?;
486            }
487            if cursor_idx >= self.filter.len() {
488                out.set_color(&self.theme.real_cursor_color(None))?;
489                write!(out, " ")?;
490                out.reset()?;
491            }
492            writeln!(out)?;
493            out.reset()?;
494        } else if !self.filter.is_empty() {
495            out.set_color(&self.theme.description)?;
496            write!(out, "/{}", self.filter)?;
497        } else if let Some(err) = &self.err {
498            out.set_color(&self.theme.error_indicator)?;
499            write!(out, " {err}")?;
500        }
501
502        self.print_help_keys(&mut out)?;
503
504        writeln!(out)?;
505        out.reset()?;
506
507        Ok(std::str::from_utf8(out.as_slice()).unwrap().to_string())
508    }
509
510    fn print_option_label(
511        &self,
512        out: &mut Buffer,
513        option: &DemandOption<T>,
514        max_label_len: usize,
515    ) -> io::Result<()> {
516        if let Some(desc) = &option.description {
517            let label = console::pad_str(&option.label, max_label_len, Alignment::Left, None);
518            if self.filtering && !self.filter.is_empty() {
519                self.highlight_matches(out, &label)?;
520            } else {
521                write!(out, " {label}")?;
522            }
523            out.set_color(&self.theme.description)?;
524            writeln!(out, "  {desc}")?;
525        } else if self.filtering && !self.filter.is_empty() {
526            self.highlight_matches(out, &option.label)?;
527            writeln!(out)?;
528        } else {
529            writeln!(out, " {}", option.label)?;
530        }
531        Ok(())
532    }
533
534    fn print_help_keys(&self, out: &mut Buffer) -> io::Result<()> {
535        let mut help_keys = vec![("↑/↓/k/j", "up/down")];
536        if self.pages > 1 {
537            help_keys.push(("←/→/h/l", "prev/next page"));
538        }
539        help_keys.push(("x/space", "toggle"));
540        help_keys.push(("a", "toggle all"));
541        if self.filterable {
542            if self.filtering {
543                help_keys = vec![("esc", "clear filter"), ("enter", "save filter")];
544            } else {
545                help_keys.push(("/", "filter"));
546                if !self.filter.is_empty() {
547                    help_keys.push(("esc", "clear filter"));
548                }
549            }
550        }
551        if !self.filtering {
552            help_keys.push(("enter", "confirm"));
553        }
554        for (i, (key, desc)) in help_keys.iter().enumerate() {
555            if i > 0 || (!self.filtering && !self.filter.is_empty()) {
556                out.set_color(&self.theme.help_sep)?;
557                write!(out, " • ")?;
558            }
559            out.set_color(&self.theme.help_key)?;
560            write!(out, "{key}")?;
561            out.set_color(&self.theme.help_desc)?;
562            write!(out, " {desc}")?;
563        }
564        Ok(())
565    }
566
567    fn get_char_idx(&self, input: &str, cursor: usize) -> usize {
568        input
569            .char_indices()
570            .nth(cursor)
571            .map(|(i, _)| i)
572            .unwrap_or(input.len())
573    }
574
575    fn highlight_matches(
576        &self,
577        out: &mut dyn WriteColor,
578        label: &str,
579    ) -> Result<(), std::io::Error> {
580        let matches = self
581            .fuzzy_matcher
582            .fuzzy_indices(&label.to_lowercase(), &self.filter.to_lowercase());
583        if let Some((_, indices)) = matches {
584            for (j, c) in label.chars().enumerate() {
585                if indices.contains(&j) {
586                    out.set_color(&self.theme.selected_option)?;
587                } else {
588                    out.set_color(&self.theme.unselected_option)?;
589                }
590                if j == 0 {
591                    write!(out, " ")?;
592                }
593                write!(out, "{c}")?;
594            }
595        } else {
596            write!(out, " {label}")?;
597        }
598        Ok(())
599    }
600
601    fn render_success(&self, selected: &[String]) -> io::Result<String> {
602        let mut out = Buffer::ansi();
603        out.set_color(&self.theme.title)?;
604        write!(out, "{}", self.title)?;
605        out.set_color(&self.theme.selected_option)?;
606        writeln!(out, " {}", selected.join(", "))?;
607        out.reset()?;
608        Ok(std::str::from_utf8(out.as_slice()).unwrap().to_string())
609    }
610
611    fn clear(&mut self) -> io::Result<()> {
612        self.term.clear_last_lines(self.height)?;
613        self.height = 0;
614        Ok(())
615    }
616}
617
#[cfg(test)]
mod tests {
    use crate::test::without_ansi;

    use super::*;
    use indoc::indoc;

    /// A freshly built selector renders title, description, all options with
    /// their selection state and the default help line.
    #[test]
    fn test_render() {
        let select = MultiSelect::new("Toppings")
            .description("Select your toppings")
            .option(DemandOption::new("Lettuce").selected(true))
            .option(DemandOption::new("Tomatoes").selected(true))
            .option(DemandOption::new("Charm Sauce"))
            .option(DemandOption::new("Jalapenos").label("Jalapeños"))
            .option(DemandOption::new("Cheese"))
            .option(DemandOption::new("Vegan Cheese"))
            .option(DemandOption::new("Nutella"));

        let expected = indoc! {
              "Toppings
            Select your toppings
             >[•] Lettuce
              [•] Tomatoes
              [ ] Charm Sauce
              [ ] Jalapeños
              [ ] Cheese
              [ ] Vegan Cheese
              [ ] Nutella
            ↑/↓/k/j up/down • x/space toggle • a toggle all • enter confirm
            "
        };
        let rendered = select.render().unwrap();

        assert_eq!(expected, without_ansi(rendered.as_str()));
    }

    /// Items that do not implement `Display` can be used as long as every
    /// option is given a label.
    #[test]
    fn non_display() {
        struct Thing {
            num: u32,
            _thing: Option<()>,
        }
        let things = [
            Thing {
                num: 1,
                _thing: Some(()),
            },
            Thing {
                num: 2,
                _thing: None,
            },
            Thing {
                num: 3,
                _thing: None,
            },
        ];

        // The first thing gets an explicit label; the rest are labelled by
        // their number and preselected.
        let choices: Vec<_> = things
            .iter()
            .enumerate()
            .map(|(i, t)| {
                if i == 0 {
                    DemandOption::with_label("First", t)
                } else {
                    DemandOption::new(t.num).item(t).selected(true)
                }
            })
            .collect();
        let select = MultiSelect::new("things")
            .description("pick a thing")
            .options(choices);

        let expected = indoc! {
              "things
            pick a thing
             >[ ] First
              [•] 2
              [•] 3
            ↑/↓/k/j up/down • x/space toggle • a toggle all • enter confirm
            "
        };
        let rendered = select.render().unwrap();

        assert_eq!(expected, without_ansi(rendered.as_str()));
    }
}