1use std::collections::HashSet;
2use std::io;
3use std::io::Write;
4
5use console::{Alignment, Key, Term};
6use fuzzy_matcher::FuzzyMatcher;
7use fuzzy_matcher::skim::SkimMatcherV2;
8use itertools::Itertools;
9use termcolor::{Buffer, WriteColor};
10
11use crate::theme::Theme;
12use crate::{DemandOption, ctrlc, theme};
13
14pub struct MultiSelect<'a, T> {
45 pub title: String,
47 pub theme: &'a Theme,
49 pub description: String,
51 pub options: Vec<DemandOption<T>>,
53 pub min: usize,
55 pub max: usize,
57 pub filterable: bool,
59 pub filtering: bool,
61 pub filter: String,
63
64 err: Option<String>,
65 cursor_x: usize,
66 cursor_y: usize,
67 cursor: usize,
68 height: usize,
69 term: Term,
70 pages: usize,
71 cur_page: usize,
72 capacity: usize,
73 fuzzy_matcher: SkimMatcherV2,
74}
75
76impl<'a, T> MultiSelect<'a, T> {
77 pub fn new<S: Into<String>>(title: S) -> Self {
79 let mut ms = MultiSelect {
80 title: title.into(),
81 description: String::new(),
82 options: vec![],
83 min: 0,
84 max: usize::MAX,
85 filterable: false,
86 theme: &theme::DEFAULT,
87 cursor_x: 0,
88 cursor_y: 0,
89 err: None,
90 cursor: 0,
91 height: 0,
92 term: Term::stderr(),
93 filter: String::new(),
94 filtering: false,
95 pages: 0,
96 cur_page: 0,
97 capacity: 0,
98 fuzzy_matcher: SkimMatcherV2::default().use_cache(true).smart_case(),
99 };
100 let max_height = ms.term.size().0 as usize;
101 ms.capacity = max_height.max(8) - 6;
102 ms
103 }
104
105 pub fn description(mut self, description: &str) -> Self {
107 self.description = description.to_string();
108 self
109 }
110
111 pub fn option(mut self, option: DemandOption<T>) -> Self {
113 self.options.push(option);
114 self.pages = self.get_pages();
115 self
116 }
117
118 pub fn options(mut self, options: Vec<DemandOption<T>>) -> Self {
120 for option in options {
121 self.options.push(option);
122 }
123 self.pages = self.get_pages();
124 self
125 }
126
127 pub fn min(mut self, min: usize) -> Self {
129 self.min = min;
130 self
131 }
132
133 pub fn max(mut self, max: usize) -> Self {
135 self.max = max;
136 self
137 }
138
139 pub fn filterable(mut self, filterable: bool) -> Self {
141 self.filterable = filterable;
142 self
143 }
144
145 pub fn filtering(mut self, filtering: bool) -> Self {
146 self.filtering = filtering;
147 self
148 }
149
150 pub fn filter(mut self, filter: &str) -> Self {
151 self.filter = filter.to_string();
152 self.cursor_x = self.filter.chars().count();
153 self.pages = self.get_pages();
154 self
155 }
156
157 pub fn theme(mut self, theme: &'a Theme) -> Self {
159 self.theme = theme;
160 self
161 }
162
163 pub fn run(mut self) -> io::Result<Vec<T>> {
168 let ctrlc_handle = ctrlc::show_cursor_after_ctrlc(&self.term)?;
169
170 self.max = self.max.min(self.options.len());
171 self.min = self.min.min(self.max);
172
173 loop {
174 self.clear()?;
175 let output = self.render()?;
176 self.term.write_all(output.as_bytes())?;
177 self.term.flush()?;
178 self.height = output.lines().count() - 1;
179 if self.filtering {
180 match self.term.read_key()? {
181 Key::ArrowLeft => self.handle_left()?,
182 Key::ArrowRight => self.handle_right()?,
183 Key::Enter => self.handle_stop_filtering(true)?,
184 Key::Escape => self.handle_stop_filtering(false)?,
185 Key::Backspace => self.handle_filter_backspace()?,
186 Key::Char(c) => self.handle_filter_key(c)?,
187 _ => {}
188 }
189 } else {
190 self.term.hide_cursor()?;
191 match self.term.read_key()? {
192 Key::ArrowDown | Key::Char('j') => self.handle_down()?,
193 Key::ArrowUp | Key::Char('k') => self.handle_up()?,
194 Key::ArrowLeft | Key::Char('h') => self.handle_left()?,
195 Key::ArrowRight | Key::Char('l') => self.handle_right()?,
196 Key::Char('x') | Key::Char(' ') => self.handle_toggle(),
197 Key::Char('a') => self.handle_toggle_all(),
198 Key::Char('/') if self.filterable => self.handle_start_filtering(),
199 Key::Escape => {
200 if self.filter.is_empty() {
201 self.term.show_cursor()?;
202 ctrlc_handle.close();
203 return Err(io::Error::new(
204 io::ErrorKind::Interrupted,
205 "user cancelled",
206 ));
207 }
208 self.handle_stop_filtering(false)?
209 }
210 Key::Enter => {
211 let selected = self
212 .options
213 .iter()
214 .filter(|o| o.selected)
215 .map(|o| o.label.to_string())
216 .collect::<Vec<_>>();
217 if selected.len() < self.min {
218 if self.min == 1 {
219 self.err = Some("Please select an option".to_string());
220 } else {
221 self.err =
222 Some(format!("Please select at least {} options", self.min));
223 }
224 continue;
225 }
226 if selected.len() > self.max {
227 if self.max == 1 {
228 self.err = Some("Please select only one option".to_string());
229 } else {
230 self.err =
231 Some(format!("Please select at most {} options", self.max));
232 }
233 continue;
234 }
235 self.clear()?;
236 self.term.show_cursor()?;
237 ctrlc_handle.close();
238 let output = self.render_success(&selected)?;
239 self.term.write_all(output.as_bytes())?;
240 let selected = self
241 .options
242 .into_iter()
243 .filter(|o| o.selected)
244 .map(|o| o.item)
245 .collect::<Vec<_>>();
246 self.term.clear_to_end_of_screen()?;
247 return Ok(selected);
248 }
249 _ => {}
250 }
251 }
252 }
253 }
254
255 fn filtered_options(&self) -> Vec<&DemandOption<T>> {
256 self.options
257 .iter()
258 .filter_map(|opt| {
259 if self.filter.is_empty() {
260 Some((0, opt))
261 } else {
262 self.fuzzy_matcher
263 .fuzzy_match(&opt.label.to_lowercase(), &self.filter.to_lowercase())
264 .map(|score| (score, opt))
265 }
266 })
267 .sorted_by_key(|(score, _opt)| -1 * *score)
268 .map(|(_score, opt)| opt)
269 .collect()
270 }
271
272 fn visible_options(&self) -> Vec<&DemandOption<T>> {
273 let filtered_options = self.filtered_options();
274 let start = self.cur_page * self.capacity;
275 filtered_options
276 .into_iter()
277 .skip(start)
278 .take(self.capacity)
279 .collect()
280 }
281
282 fn handle_down(&mut self) -> Result<(), io::Error> {
283 let visible_options = self.visible_options();
284 if self.cursor < visible_options.len().max(1) - 1 {
285 self.cursor += 1;
286 } else if self.pages > 0 && self.cur_page < self.pages - 1 {
287 self.cur_page += 1;
288 self.cursor = 0;
289 self.term.clear_to_end_of_screen()?;
290 }
291 Ok(())
292 }
293
294 fn handle_up(&mut self) -> Result<(), io::Error> {
295 if self.cursor > 0 {
296 self.cursor -= 1;
297 } else if self.cur_page > 0 {
298 self.cur_page -= 1;
299 self.cursor = self.visible_options().len().max(1) - 1;
300 self.term.clear_to_end_of_screen()?;
301 }
302 Ok(())
303 }
304
305 fn handle_left(&mut self) -> Result<(), io::Error> {
306 if self.filtering {
307 if self.cursor_x > 0 {
308 self.cursor_x -= 1;
309 }
310 } else if self.cur_page > 0 {
311 self.cur_page -= 1;
312 self.term.clear_to_end_of_screen()?;
313 }
314 Ok(())
315 }
316
317 fn handle_right(&mut self) -> Result<(), io::Error> {
318 if self.filtering {
319 if self.cursor_x < self.filter.chars().count() {
320 self.cursor_x += 1;
321 }
322 } else if self.pages > 0 && self.cur_page < self.pages - 1 {
323 self.cur_page += 1;
324 if self.cursor_y > self.visible_options().len() - 1 {
325 self.cursor_y = self.visible_options().len() - 1;
326 }
327 self.term.clear_to_end_of_screen()?;
328 }
329 Ok(())
330 }
331
332 fn handle_toggle(&mut self) {
333 self.err = None;
334 let visible_options = self.visible_options();
335 if visible_options.is_empty() {
336 return;
337 }
338 let id = visible_options[self.cursor].id;
339 let selected = visible_options[self.cursor].selected;
340 self.options
341 .iter_mut()
342 .find(|o| o.id == id)
343 .unwrap()
344 .selected = !selected;
345 }
346
347 fn handle_toggle_all(&mut self) {
348 self.err = None;
349 let filtered_options = self.filtered_options();
350 if filtered_options.is_empty() {
351 return;
352 }
353 let select = !filtered_options.iter().all(|o| o.selected);
354 let ids = filtered_options
355 .into_iter()
356 .map(|o| o.id)
357 .collect::<HashSet<_>>();
358 for opt in &mut self.options {
359 if ids.contains(&opt.id) {
360 opt.selected = select;
361 }
362 }
363 }
364
365 fn handle_start_filtering(&mut self) {
366 self.err = None;
367 self.filtering = true;
368 }
369
370 fn handle_stop_filtering(&mut self, save: bool) -> Result<(), io::Error> {
371 self.filtering = false;
372
373 let visible_options = self.visible_options();
374 if !visible_options.is_empty() {
375 self.cursor = self.cursor.min(self.visible_options().len() - 1);
376 }
377 if !save {
378 self.filter.clear();
379 self.reset_paging();
380 }
381 self.term.clear_to_end_of_screen()
382 }
383
384 fn handle_filter_key(&mut self, c: char) -> Result<(), io::Error> {
385 let idx = self.get_char_idx(&self.filter, self.cursor_x);
386 self.filter.insert(idx, c);
387 self.cursor_x += 1;
388 self.cursor_y = 0;
389 self.err = None;
390 self.reset_paging();
391 self.term.clear_to_end_of_screen()
392 }
393
394 fn handle_filter_backspace(&mut self) -> Result<(), io::Error> {
395 let chars_count = self.filter.chars().count();
396 if chars_count > 0 && self.cursor_x > 0 {
397 let idx = self.get_char_idx(&self.filter, self.cursor_x - 1);
398 self.filter.remove(idx);
399 }
400 if self.cursor_x > 0 {
401 self.cursor_x -= 1;
402 }
403 self.cursor_y = 0;
404 self.err = None;
405 self.reset_paging();
406 self.term.clear_to_end_of_screen()
407 }
408
409 fn reset_paging(&mut self) {
410 self.cur_page = 0;
411 self.pages = self.get_pages();
412 }
413
414 fn get_pages(&self) -> usize {
415 if self.filtering || !self.filter.is_empty() {
416 ((self.filtered_options().len() as f64) / self.capacity as f64).ceil() as usize
417 } else {
418 ((self.options.len() as f64) / self.capacity as f64).ceil() as usize
419 }
420 }
421
422 fn render(&self) -> io::Result<String> {
423 let mut out = Buffer::ansi();
424
425 out.set_color(&self.theme.title)?;
426 write!(out, "{}", self.title)?;
427
428 if self.err.is_some() {
429 out.set_color(&self.theme.error_indicator)?;
430 writeln!(out, " *")?;
431 } else {
432 writeln!(out)?;
433 }
434 if !self.description.is_empty() || self.pages > 1 {
435 out.set_color(&self.theme.description)?;
436 write!(out, "{}", self.description)?;
437 writeln!(out)?;
438 }
439 let max_label_len = self
440 .visible_options()
441 .iter()
442 .map(|o| console::measure_text_width(&o.label))
443 .max()
444 .unwrap_or(0);
445 for (i, option) in self.visible_options().into_iter().enumerate() {
446 if self.cursor == i {
447 out.set_color(&self.theme.cursor)?;
448 write!(out, " >")?;
449 } else {
450 write!(out, " ")?;
451 }
452 if option.selected {
453 out.set_color(&self.theme.selected_prefix_fg)?;
454 write!(out, "{}", self.theme.selected_prefix)?;
455 out.set_color(&self.theme.selected_option)?;
456 self.print_option_label(&mut out, option, max_label_len)?;
457 } else {
458 out.set_color(&self.theme.unselected_prefix_fg)?;
459 write!(out, "{}", self.theme.unselected_prefix)?;
460 out.set_color(&self.theme.unselected_option)?;
461 self.print_option_label(&mut out, option, max_label_len)?;
462 }
463 }
464 if self.pages > 1 {
465 out.set_color(&self.theme.description)?;
466 writeln!(out, " (page {}/{})", self.cur_page + 1, self.pages)?;
467 }
468
469 if self.filtering {
470 out.set_color(&self.theme.input_cursor)?;
471
472 write!(out, "/")?;
473 out.reset()?;
474
475 let cursor_idx = self.get_char_idx(&self.filter, self.cursor_x);
476 write!(out, "{}", &self.filter[..cursor_idx])?;
477
478 if cursor_idx < self.filter.len() {
479 out.set_color(&self.theme.real_cursor_color(None))?;
480 write!(out, "{}", &self.filter[cursor_idx..cursor_idx + 1])?;
481 out.reset()?;
482 }
483 if cursor_idx + 1 < self.filter.len() {
484 out.reset()?;
485 write!(out, "{}", &self.filter[cursor_idx + 1..])?;
486 }
487 if cursor_idx >= self.filter.len() {
488 out.set_color(&self.theme.real_cursor_color(None))?;
489 write!(out, " ")?;
490 out.reset()?;
491 }
492 writeln!(out)?;
493 out.reset()?;
494 } else if !self.filter.is_empty() {
495 out.set_color(&self.theme.description)?;
496 write!(out, "/{}", self.filter)?;
497 } else if let Some(err) = &self.err {
498 out.set_color(&self.theme.error_indicator)?;
499 write!(out, " {err}")?;
500 }
501
502 self.print_help_keys(&mut out)?;
503
504 writeln!(out)?;
505 out.reset()?;
506
507 Ok(std::str::from_utf8(out.as_slice()).unwrap().to_string())
508 }
509
510 fn print_option_label(
511 &self,
512 out: &mut Buffer,
513 option: &DemandOption<T>,
514 max_label_len: usize,
515 ) -> io::Result<()> {
516 if let Some(desc) = &option.description {
517 let label = console::pad_str(&option.label, max_label_len, Alignment::Left, None);
518 if self.filtering && !self.filter.is_empty() {
519 self.highlight_matches(out, &label)?;
520 } else {
521 write!(out, " {label}")?;
522 }
523 out.set_color(&self.theme.description)?;
524 writeln!(out, " {desc}")?;
525 } else if self.filtering && !self.filter.is_empty() {
526 self.highlight_matches(out, &option.label)?;
527 writeln!(out)?;
528 } else {
529 writeln!(out, " {}", option.label)?;
530 }
531 Ok(())
532 }
533
534 fn print_help_keys(&self, out: &mut Buffer) -> io::Result<()> {
535 let mut help_keys = vec![("↑/↓/k/j", "up/down")];
536 if self.pages > 1 {
537 help_keys.push(("←/→/h/l", "prev/next page"));
538 }
539 help_keys.push(("x/space", "toggle"));
540 help_keys.push(("a", "toggle all"));
541 if self.filterable {
542 if self.filtering {
543 help_keys = vec![("esc", "clear filter"), ("enter", "save filter")];
544 } else {
545 help_keys.push(("/", "filter"));
546 if !self.filter.is_empty() {
547 help_keys.push(("esc", "clear filter"));
548 }
549 }
550 }
551 if !self.filtering {
552 help_keys.push(("enter", "confirm"));
553 }
554 for (i, (key, desc)) in help_keys.iter().enumerate() {
555 if i > 0 || (!self.filtering && !self.filter.is_empty()) {
556 out.set_color(&self.theme.help_sep)?;
557 write!(out, " • ")?;
558 }
559 out.set_color(&self.theme.help_key)?;
560 write!(out, "{key}")?;
561 out.set_color(&self.theme.help_desc)?;
562 write!(out, " {desc}")?;
563 }
564 Ok(())
565 }
566
567 fn get_char_idx(&self, input: &str, cursor: usize) -> usize {
568 input
569 .char_indices()
570 .nth(cursor)
571 .map(|(i, _)| i)
572 .unwrap_or(input.len())
573 }
574
575 fn highlight_matches(
576 &self,
577 out: &mut dyn WriteColor,
578 label: &str,
579 ) -> Result<(), std::io::Error> {
580 let matches = self
581 .fuzzy_matcher
582 .fuzzy_indices(&label.to_lowercase(), &self.filter.to_lowercase());
583 if let Some((_, indices)) = matches {
584 for (j, c) in label.chars().enumerate() {
585 if indices.contains(&j) {
586 out.set_color(&self.theme.selected_option)?;
587 } else {
588 out.set_color(&self.theme.unselected_option)?;
589 }
590 if j == 0 {
591 write!(out, " ")?;
592 }
593 write!(out, "{c}")?;
594 }
595 } else {
596 write!(out, " {label}")?;
597 }
598 Ok(())
599 }
600
601 fn render_success(&self, selected: &[String]) -> io::Result<String> {
602 let mut out = Buffer::ansi();
603 out.set_color(&self.theme.title)?;
604 write!(out, "{}", self.title)?;
605 out.set_color(&self.theme.selected_option)?;
606 writeln!(out, " {}", selected.join(", "))?;
607 out.reset()?;
608 Ok(std::str::from_utf8(out.as_slice()).unwrap().to_string())
609 }
610
611 fn clear(&mut self) -> io::Result<()> {
612 self.term.clear_last_lines(self.height)?;
613 self.height = 0;
614 Ok(())
615 }
616}
617
// Unit tests for `MultiSelect::render`. Both tests compare the rendered frame,
// with ANSI sequences removed by `without_ansi`, against an expected text block.
// NOTE(review): whitespace inside the expected text blocks takes part in the
// comparison — confirm the leading spaces of the option rows against `render`
// before editing them.
618#[cfg(test)]
619mod tests {
620 use crate::test::without_ansi;
621
622 use super::*;
623 use indoc::indoc;
624
    // Renders a prompt with a description and seven options, the first two
    // pre-selected, and checks the complete first frame: cursor on the first
    // row, no page indicator, and the default help line.
625 #[test]
626 fn test_render() {
627 let select = MultiSelect::new("Toppings")
628 .description("Select your toppings")
629 .option(DemandOption::new("Lettuce").selected(true))
630 .option(DemandOption::new("Tomatoes").selected(true))
631 .option(DemandOption::new("Charm Sauce"))
632 .option(DemandOption::new("Jalapenos").label("Jalapeños"))
633 .option(DemandOption::new("Cheese"))
634 .option(DemandOption::new("Vegan Cheese"))
635 .option(DemandOption::new("Nutella"));
636
637 assert_eq!(
638 indoc! {
639 "Toppings
640 Select your toppings
641 >[•] Lettuce
642 [•] Tomatoes
643 [ ] Charm Sauce
644 [ ] Jalapeños
645 [ ] Cheese
646 [ ] Vegan Cheese
647 [ ] Nutella
648 ↑/↓/k/j up/down • x/space toggle • a toggle all • enter confirm
649 "
650 },
651 without_ansi(select.render().unwrap().as_str())
652 );
653 }
654
    // Builds options whose items are references to `Thing`, a local struct
    // that derives no traits. The first option gets an explicit label via
    // `with_label`; the others are created from `t.num` and then given the
    // item via `.item(t)`. The expected frame shows the labels "First", "2"
    // and "3".
655 #[test]
656 fn non_display() {
657 struct Thing {
658 num: u32,
659 _thing: Option<()>,
660 }
661 let things = [
662 Thing {
663 num: 1,
664 _thing: Some(()),
665 },
666 Thing {
667 num: 2,
668 _thing: None,
669 },
670 Thing {
671 num: 3,
672 _thing: None,
673 },
674 ];
675 let select = MultiSelect::new("things")
676 .description("pick a thing")
677 .options(
678 things
679 .iter()
680 .enumerate()
681 .map(|(i, t)| {
682 if i == 0 {
683 DemandOption::with_label("First", t)
684 } else {
685 DemandOption::new(t.num).item(t).selected(true)
686 }
687 })
688 .collect(),
689 );
690 assert_eq!(
691 indoc! {
692 "things
693 pick a thing
694 >[ ] First
695 [•] 2
696 [•] 3
697 ↑/↓/k/j up/down • x/space toggle • a toggle all • enter confirm
698 "
699 },
700 without_ansi(select.render().unwrap().as_str())
701 );
702 }
703}