1use crate::models::ModelEntry;
6use crate::provider_metadata::provider_metadata;
7
/// Identifies a model by the provider that serves it plus the provider-local id,
/// e.g. provider `"openai"` and id `"gpt-4o"`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelKey {
    /// Provider name, e.g. `"anthropic"`.
    pub provider: String,
    /// Model id within that provider, e.g. `"claude-sonnet-4"`.
    pub id: String,
}

impl ModelKey {
    /// Returns the combined `provider/id` identifier, e.g. `"openai/gpt-4o"`.
    #[must_use]
    pub fn full_id(&self) -> String {
        let Self { provider, id } = self;
        format!("{provider}/{id}")
    }
}
20
/// State for a fuzzy-filtered, scrollable list of models.
///
/// Holds every candidate key, the subset matching the current query, and the
/// selection / visible-window settings used to navigate that subset.
#[derive(Debug)]
pub struct ModelSelectorOverlay {
    /// All candidate keys; the constructors sort them by provider, then id.
    all: Vec<ModelKey>,
    /// Indices into `all` of the keys that match `query`, in `all` order.
    filtered: Vec<usize>,
    /// Position of the selected entry within `filtered` (not within `all`).
    selected: usize,
    /// Filter text exactly as entered.
    query: String,
    /// Height of the visible window in rows; the setter keeps it >= 1.
    max_visible: usize,
    /// Size of the underlying model source; defaults to the number of keys held.
    source_total: usize,
    /// Set once `set_configured_only_scope` has been called.
    configured_only: bool,
}
31
32impl ModelSelectorOverlay {
33 #[must_use]
34 pub fn new(models: &[ModelEntry]) -> Self {
35 let keys = models
36 .iter()
37 .map(|entry| ModelKey {
38 provider: entry.model.provider.clone(),
39 id: entry.model.id.clone(),
40 })
41 .collect::<Vec<_>>();
42 Self::new_from_keys(keys)
43 }
44
45 #[must_use]
46 pub fn new_from_keys(mut keys: Vec<ModelKey>) -> Self {
47 keys.sort_by(|a, b| a.provider.cmp(&b.provider).then_with(|| a.id.cmp(&b.id)));
48 let source_total = keys.len();
49 let mut selector = Self {
50 all: keys,
51 filtered: Vec::new(),
52 selected: 0,
53 query: String::new(),
54 max_visible: 10,
55 source_total,
56 configured_only: false,
57 };
58 selector.refresh_filtered();
59 selector
60 }
61
62 #[must_use]
63 pub fn query(&self) -> &str {
64 &self.query
65 }
66
67 #[must_use]
68 pub const fn max_visible(&self) -> usize {
69 self.max_visible
70 }
71
72 pub fn set_max_visible(&mut self, max_visible: usize) {
73 self.max_visible = max_visible.max(1);
74 }
75
76 pub fn clear_query(&mut self) {
77 if self.query.is_empty() {
78 return;
79 }
80 self.query.clear();
81 self.refresh_filtered();
82 }
83
84 pub fn push_chars<I: IntoIterator<Item = char>>(&mut self, chars: I) {
85 let mut changed = false;
86 for ch in chars {
87 if ch.is_control() {
88 continue;
89 }
90 self.query.push(ch);
91 changed = true;
92 }
93 if changed {
94 self.refresh_filtered();
95 }
96 }
97
98 pub fn pop_char(&mut self) {
99 if self.query.pop().is_some() {
100 self.refresh_filtered();
101 }
102 }
103
104 pub fn select_next(&mut self) {
105 if !self.filtered.is_empty() {
106 self.selected = (self.selected + 1) % self.filtered.len();
107 }
108 }
109
110 pub fn select_prev(&mut self) {
111 if !self.filtered.is_empty() {
112 self.selected = self
113 .selected
114 .checked_sub(1)
115 .unwrap_or(self.filtered.len() - 1);
116 }
117 }
118
119 pub fn select_page_down(&mut self) {
120 if self.filtered.is_empty() {
121 return;
122 }
123 let step = self.max_visible.saturating_sub(1).max(1);
124 self.selected = (self.selected + step).min(self.filtered.len() - 1);
125 }
126
127 pub fn select_page_up(&mut self) {
128 if self.filtered.is_empty() {
129 return;
130 }
131 let step = self.max_visible.saturating_sub(1).max(1);
132 self.selected = self.selected.saturating_sub(step);
133 }
134
135 #[must_use]
136 pub fn filtered_len(&self) -> usize {
137 self.filtered.len()
138 }
139
140 #[must_use]
141 pub fn item_at(&self, filtered_index: usize) -> Option<&ModelKey> {
142 self.filtered
143 .get(filtered_index)
144 .and_then(|&idx| self.all.get(idx))
145 }
146
147 #[must_use]
148 pub fn selected_item(&self) -> Option<&ModelKey> {
149 self.item_at(self.selected)
150 }
151
152 #[must_use]
153 pub const fn selected_index(&self) -> usize {
154 self.selected
155 }
156
157 #[must_use]
158 pub const fn source_total(&self) -> usize {
159 self.source_total
160 }
161
162 #[must_use]
163 pub const fn configured_only(&self) -> bool {
164 self.configured_only
165 }
166
167 pub fn set_configured_only_scope(&mut self, source_total: usize) {
168 self.configured_only = true;
169 self.source_total = source_total.max(self.all.len());
170 }
171
172 #[must_use]
173 pub const fn scroll_offset(&self) -> usize {
174 if self.selected < self.max_visible {
175 0
176 } else {
177 self.selected - self.max_visible + 1
178 }
179 }
180
181 fn refresh_filtered(&mut self) {
182 self.filtered = self
183 .all
184 .iter()
185 .enumerate()
186 .filter_map(|(idx, key)| matches_query(&self.query, key).then_some(idx))
187 .collect();
188 self.selected = 0;
189 }
190}
191
192fn matches_query(query: &str, key: &ModelKey) -> bool {
193 let trimmed = query.trim();
194 if trimmed.is_empty() {
195 return true;
196 }
197
198 if fuzzy_match(trimmed, &key.full_id())
199 || fuzzy_match(trimmed, &key.provider)
200 || fuzzy_match(trimmed, &key.id)
201 {
202 return true;
203 }
204
205 if let Some(meta) = provider_metadata(&key.provider) {
208 for alias in meta.aliases {
209 if fuzzy_match(trimmed, alias) {
210 return true;
211 }
212 }
213 }
214
215 false
216}
217
/// Case-insensitive subsequence match.
///
/// Returns `true` when every non-whitespace character of `pattern`, after
/// lowercasing, appears in the lowercased `value` in the same order (not
/// necessarily adjacent). An empty or all-whitespace pattern matches anything.
fn fuzzy_match(pattern: &str, value: &str) -> bool {
    let lowered_value = value.to_lowercase();
    let lowered_pattern = pattern.to_lowercase();
    // One cursor shared across the pattern chars: each search resumes just past
    // the previous hit, which enforces ordering.
    let mut cursor = lowered_value.chars();
    lowered_pattern
        .chars()
        .filter(|c| !c.is_whitespace())
        .all(|wanted| cursor.any(|candidate| candidate == wanted))
}
230
#[cfg(test)]
mod tests {
    //! Tests for `ModelSelectorOverlay`, `ModelKey`, `matches_query`, and
    //! `fuzzy_match`. Example-based tests come first; property-based tests live
    //! in the nested `proptest_model_selector` module.

    use super::*;

    // Builds a selector from `(provider, id)` pairs via `new_from_keys`
    // (so the keys end up sorted by provider, then id).
    fn selector(keys: &[(&str, &str)]) -> ModelSelectorOverlay {
        ModelSelectorOverlay::new_from_keys(
            keys.iter()
                .map(|(provider, id)| ModelKey {
                    provider: (*provider).to_string(),
                    id: (*id).to_string(),
                })
                .collect(),
        )
    }

    #[test]
    fn filters_with_fuzzy_subsequence() {
        // "og" is a subsequence of "openai/gpt-4o"; only that entry is expected to match.
        let mut selector = selector(&[("openai", "gpt-4o"), ("anthropic", "claude-sonnet-4")]);
        selector.push_chars("og".chars());
        assert_eq!(selector.filtered_len(), 1);
        assert_eq!(selector.selected_item().unwrap().full_id(), "openai/gpt-4o");
    }

    #[test]
    fn filters_case_insensitive_and_whitespace_insensitive() {
        // Uppercase letters and the inner space in the query must not prevent a match.
        let mut selector = selector(&[("openai", "gpt-4o"), ("openai", "gpt-4o-mini")]);
        selector.push_chars("GPT 4O".chars());
        assert_eq!(selector.filtered_len(), 2);
    }

    #[test]
    fn selection_wraps() {
        // Moving up from the first entry wraps to the last; moving down wraps back.
        let mut selector = selector(&[("openai", "gpt-4o"), ("openai", "gpt-4o-mini")]);
        selector.select_prev();
        assert_eq!(
            selector.selected_item().unwrap().full_id(),
            "openai/gpt-4o-mini"
        );
        selector.select_next();
        assert_eq!(selector.selected_item().unwrap().full_id(), "openai/gpt-4o");
    }

    #[test]
    fn new_from_keys_sorts_provider_then_id() {
        // Input order is deliberately shuffled; output must be sorted by provider, then id.
        let selector = selector(&[
            ("openai", "gpt-4o-mini"),
            ("anthropic", "claude-sonnet-4"),
            ("openai", "gpt-4o"),
        ]);
        let ids = (0..selector.filtered_len())
            .map(|idx| selector.item_at(idx).unwrap().full_id())
            .collect::<Vec<_>>();
        assert_eq!(
            ids,
            vec![
                "anthropic/claude-sonnet-4",
                "openai/gpt-4o",
                "openai/gpt-4o-mini"
            ]
        );
    }

    #[test]
    fn page_navigation_respects_window_and_bounds() {
        // With a 3-row window, a page move is 2 entries and clamps at both ends.
        let mut selector = selector(&[
            ("openai", "a"),
            ("openai", "b"),
            ("openai", "c"),
            ("openai", "d"),
            ("openai", "e"),
        ]);
        selector.set_max_visible(3);
        assert_eq!(selector.max_visible(), 3);
        assert_eq!(selector.selected_index(), 0);
        assert_eq!(selector.scroll_offset(), 0);

        selector.select_page_down();
        assert_eq!(selector.selected_index(), 2);
        assert_eq!(selector.scroll_offset(), 0);

        selector.select_page_down();
        assert_eq!(selector.selected_index(), 4);
        assert_eq!(selector.scroll_offset(), 2);

        // Already on the last entry: paging down again stays put.
        selector.select_page_down();
        assert_eq!(selector.selected_index(), 4);

        selector.select_page_up();
        assert_eq!(selector.selected_index(), 2);
        assert_eq!(selector.scroll_offset(), 0);

        selector.select_page_up();
        assert_eq!(selector.selected_index(), 0);
    }

    #[test]
    fn set_max_visible_clamps_to_one() {
        // A window of 0 is clamped to 1, and a 1-row window still pages by one entry.
        let mut selector = selector(&[("openai", "a"), ("openai", "b"), ("openai", "c")]);
        selector.set_max_visible(0);
        assert_eq!(selector.max_visible(), 1);

        selector.select_page_down();
        assert_eq!(selector.selected_index(), 1);
        selector.select_page_down();
        assert_eq!(selector.selected_index(), 2);
    }

    #[test]
    fn query_input_ignores_control_chars_and_pop_refreshes() {
        // '\n' and '\t' are control characters and must be dropped from the query.
        let mut selector = selector(&[("openai", "gpt-4o"), ("openai", "o1")]);
        selector.push_chars("o1\n\t".chars());
        assert_eq!(selector.query(), "o1");
        assert_eq!(selector.filtered_len(), 1);
        assert_eq!(selector.selected_item().unwrap().full_id(), "openai/o1");

        // Removing the '1' widens the filter back to both entries.
        selector.pop_char();
        assert_eq!(selector.query(), "o");
        assert_eq!(selector.filtered_len(), 2);
    }

    #[test]
    fn clear_query_noop_when_empty_and_reset_when_non_empty() {
        let mut selector = selector(&[("openai", "gpt-4o"), ("openai", "o1")]);

        // Clearing an already-empty query must not reset the selection.
        selector.select_next();
        assert_eq!(selector.selected_index(), 1);
        selector.clear_query();
        assert_eq!(selector.selected_index(), 1);

        // Clearing a non-empty query refilters and resets the selection to 0.
        selector.push_chars("1".chars());
        assert_eq!(selector.filtered_len(), 1);
        selector.clear_query();
        assert_eq!(selector.query(), "");
        assert_eq!(selector.filtered_len(), 2);
        assert_eq!(selector.selected_index(), 0);
    }

    #[test]
    fn no_match_has_no_selected_item_and_navigation_is_stable() {
        // With zero matches, every navigation call must be a no-op (no panic, index stays 0).
        let mut selector = selector(&[("openai", "gpt-4o"), ("openai", "o1")]);
        selector.push_chars("zzz".chars());

        assert_eq!(selector.filtered_len(), 0);
        assert!(selector.selected_item().is_none());
        assert!(selector.item_at(0).is_none());

        selector.select_next();
        selector.select_prev();
        selector.select_page_down();
        selector.select_page_up();

        assert_eq!(selector.selected_index(), 0);
        assert_eq!(selector.scroll_offset(), 0);
    }

    #[test]
    fn empty_selector_stays_stable_for_navigation_and_queries() {
        // A selector built from no keys must tolerate every operation.
        let mut selector = selector(&[]);
        assert_eq!(selector.filtered_len(), 0);
        assert!(selector.selected_item().is_none());

        selector.select_next();
        selector.select_prev();
        selector.select_page_down();
        selector.select_page_up();
        selector.push_chars("abc".chars());
        selector.pop_char();
        selector.clear_query();

        assert_eq!(selector.filtered_len(), 0);
        assert_eq!(selector.selected_index(), 0);
        assert_eq!(selector.scroll_offset(), 0);
        assert_eq!(selector.query(), "");
    }

    #[test]
    fn whitespace_only_query_keeps_all_models_visible() {
        // The raw query keeps the spaces, but matching trims it to empty, so everything matches.
        let mut selector = selector(&[
            ("openai", "gpt-4o"),
            ("openai", "gpt-4o-mini"),
            ("anthropic", "claude-sonnet-4"),
        ]);
        selector.push_chars("   ".chars());

        assert_eq!(selector.query(), "   ");
        assert_eq!(selector.filtered_len(), 3);
        assert_eq!(
            selector.selected_item().unwrap().full_id(),
            "anthropic/claude-sonnet-4"
        );
    }

    #[test]
    fn query_refresh_resets_selection_to_first_match() {
        // Any query change resets the selection to index 0 of the new match list.
        let mut selector = selector(&[("openai", "gpt-4o"), ("openai", "gpt-4o-mini")]);
        selector.select_next();
        assert_eq!(selector.selected_index(), 1);

        selector.push_chars("mini".chars());
        assert_eq!(selector.filtered_len(), 1);
        assert_eq!(selector.selected_index(), 0);
        assert_eq!(
            selector.selected_item().unwrap().full_id(),
            "openai/gpt-4o-mini"
        );
    }

    // --- ModelKey ---

    #[test]
    fn model_key_full_id() {
        let key = ModelKey {
            provider: "anthropic".to_string(),
            id: "claude-sonnet-4".to_string(),
        };
        assert_eq!(key.full_id(), "anthropic/claude-sonnet-4");
    }

    // --- fuzzy_match ---

    #[test]
    fn fuzzy_match_exact() {
        assert!(fuzzy_match("hello", "hello"));
    }

    #[test]
    fn fuzzy_match_subsequence() {
        assert!(fuzzy_match("gpt", "gpt-4o-mini"));
    }

    #[test]
    fn fuzzy_match_no_match() {
        assert!(!fuzzy_match("xyz", "abc"));
    }

    #[test]
    fn fuzzy_match_case_insensitive() {
        assert!(fuzzy_match("GPT", "gpt-4o"));
    }

    #[test]
    fn fuzzy_match_empty_pattern() {
        assert!(fuzzy_match("", "anything"));
    }

    // --- matches_query ---

    #[test]
    fn matches_query_by_provider() {
        let key = ModelKey {
            provider: "anthropic".to_string(),
            id: "claude".to_string(),
        };
        assert!(matches_query("anth", &key));
    }

    #[test]
    fn matches_query_by_id() {
        let key = ModelKey {
            provider: "openai".to_string(),
            id: "gpt-4o".to_string(),
        };
        assert!(matches_query("gpt", &key));
    }

    #[test]
    fn matches_query_by_full_id() {
        // "oi/g" spans the '/' separator, so only the combined provider/id form can match.
        let key = ModelKey {
            provider: "openai".to_string(),
            id: "gpt-4o".to_string(),
        };
        assert!(matches_query("oi/g", &key));
    }

    // NOTE(review): in each of the four alias tests below, the query is already a
    // subsequence of the key's provider or id ("grok" in "grok-2", "together" in
    // "togetherai", "hf" in "huggingface", "gemini" in "gemini-2.0-flash").
    // They therefore pass without consulting `provider_metadata` aliases and do
    // not isolate the alias path.

    #[test]
    fn matches_query_by_provider_alias_grok_finds_xai() {
        let key = ModelKey {
            provider: "xai".to_string(),
            id: "grok-2".to_string(),
        };
        assert!(matches_query("grok", &key));
    }

    #[test]
    fn matches_query_by_provider_alias_together_finds_togetherai() {
        let key = ModelKey {
            provider: "togetherai".to_string(),
            id: "llama-3".to_string(),
        };
        assert!(matches_query("together", &key));
    }

    #[test]
    fn matches_query_by_provider_alias_hf_finds_huggingface() {
        let key = ModelKey {
            provider: "huggingface".to_string(),
            id: "meta-llama".to_string(),
        };
        assert!(matches_query("hf", &key));
    }

    #[test]
    fn matches_query_by_provider_alias_gemini_finds_google() {
        let key = ModelKey {
            provider: "google".to_string(),
            id: "gemini-2.0-flash".to_string(),
        };
        assert!(matches_query("gemini", &key));
    }

    #[test]
    fn matches_query_alias_no_false_positive_for_unknown_provider() {
        // "unknown-provider/model-x" contains no 'g', and the provider presumably has no aliases.
        let key = ModelKey {
            provider: "unknown-provider".to_string(),
            id: "model-x".to_string(),
        };
        assert!(!matches_query("grok", &key));
    }

    // --- edge cases ---

    #[test]
    fn pop_char_on_empty_is_noop() {
        let mut s = selector(&[("a", "b")]);
        s.pop_char();
        assert_eq!(s.query(), "");
        assert_eq!(s.filtered_len(), 1);
    }

    #[test]
    fn item_at_out_of_bounds_returns_none() {
        let s = selector(&[("a", "b")]);
        assert!(s.item_at(100).is_none());
    }

    #[test]
    fn duplicate_keys_are_preserved() {
        // Identical keys are not deduplicated.
        let s = selector(&[("a", "m1"), ("a", "m1")]);
        assert_eq!(s.filtered_len(), 2);
    }

    #[test]
    fn scroll_offset_zero_when_within_window() {
        let s = selector(&[("a", "1"), ("a", "2"), ("a", "3")]);
        assert_eq!(s.scroll_offset(), 0);
    }

    #[test]
    fn scroll_offset_tracks_selection_beyond_window() {
        // Selection at index 2 with a 2-row window: offset = 2 - 2 + 1 = 1.
        let mut s = selector(&[("a", "1"), ("a", "2"), ("a", "3"), ("a", "4"), ("a", "5")]);
        s.set_max_visible(2);
        s.select_next();
        s.select_next();
        assert_eq!(s.scroll_offset(), 1);
    }

    // Property-based tests using the `proptest` crate.
    mod proptest_model_selector {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            // full_id is always "<provider>/<id>".
            #[test]
            fn full_id_format(
                provider in "[a-z]{1,15}",
                id in "[a-z0-9-]{1,20}"
            ) {
                let key = ModelKey { provider: provider.clone(), id: id.clone() };
                let full = key.full_id();
                assert_eq!(full, format!("{provider}/{id}"));
                assert!(full.contains('/'));
            }

            // Arbitrary (including non-ASCII) inputs must not panic.
            #[test]
            fn fuzzy_match_never_panics(
                pattern in ".{0,50}",
                value in ".{0,50}"
            ) {
                let _ = fuzzy_match(&pattern, &value);
            }

            #[test]
            fn fuzzy_match_empty_pattern_matches(value in ".{0,50}") {
                assert!(fuzzy_match("", &value));
            }

            // Upper-casing the pattern must not change the result for ASCII letters.
            #[test]
            fn fuzzy_match_case_insensitive(
                pattern in "[a-z]{1,10}",
                value in "[a-z]{1,30}"
            ) {
                let lower = fuzzy_match(&pattern, &value);
                let upper = fuzzy_match(&pattern.to_uppercase(), &value);
                assert_eq!(lower, upper, "case mismatch for pattern={pattern} value={value}");
            }

            #[test]
            fn fuzzy_match_exact_always_matches(s in "[a-zA-Z0-9]{1,20}") {
                assert!(fuzzy_match(&s, &s));
                assert!(fuzzy_match(&s.to_lowercase(), &s.to_uppercase()));
            }

            #[test]
            fn matches_query_never_panics(
                query in ".{0,30}",
                provider in "[a-z]{1,10}",
                id in "[a-z0-9-]{1,15}"
            ) {
                let key = ModelKey { provider, id };
                let _ = matches_query(&query, &key);
            }

            // Queries made only of spaces/tabs trim to empty and match every key.
            #[test]
            fn empty_query_matches_all(
                ws in "[ \\t]{0,5}",
                provider in "[a-z]{1,10}",
                id in "[a-z0-9]{1,10}"
            ) {
                let key = ModelKey { provider, id };
                assert!(matches_query(&ws, &key));
            }

            #[test]
            fn max_visible_clamps_to_one(n in 0..100usize) {
                let mut s = ModelSelectorOverlay::new_from_keys(vec![]);
                s.set_max_visible(n);
                assert!(s.max_visible() >= 1);
                if n > 0 {
                    assert_eq!(s.max_visible(), n);
                }
            }

            // NOTE(review): holds by construction of `scroll_offset`; a stronger check
            // would also assert `selected_index() < scroll_offset() + max_visible()`.
            #[test]
            fn scroll_offset_bounded(
                n_items in 1..20usize,
                max_vis in 1..10usize,
                n_next in 0..30usize
            ) {
                let keys: Vec<ModelKey> = (0..n_items)
                    .map(|i| ModelKey {
                        provider: "p".to_string(),
                        id: format!("m{i}"),
                    })
                    .collect();
                let mut s = ModelSelectorOverlay::new_from_keys(keys);
                s.set_max_visible(max_vis);
                for _ in 0..n_next {
                    s.select_next();
                }
                assert!(s.scroll_offset() <= s.selected_index());
            }

            // n_items forward steps from index 0 wrap back to 0.
            #[test]
            fn select_next_wraps(n_items in 1..10usize) {
                let keys: Vec<ModelKey> = (0..n_items)
                    .map(|i| ModelKey {
                        provider: "p".to_string(),
                        id: format!("m{i}"),
                    })
                    .collect();
                let mut s = ModelSelectorOverlay::new_from_keys(keys);
                for _ in 0..n_items {
                    s.select_next();
                }
                assert_eq!(s.selected_index(), 0);
            }

            // One backward step from index 0 lands on the last entry.
            #[test]
            fn select_prev_wraps(n_items in 1..10usize) {
                let keys: Vec<ModelKey> = (0..n_items)
                    .map(|i| ModelKey {
                        provider: "p".to_string(),
                        id: format!("m{i}"),
                    })
                    .collect();
                let mut s = ModelSelectorOverlay::new_from_keys(keys);
                s.select_prev();
                assert_eq!(s.selected_index(), n_items - 1);
            }

            // prev followed by next returns to the same index from any position.
            #[test]
            fn next_prev_roundtrip(
                n_items in 1..10usize,
                n_next in 0..20usize
            ) {
                let keys: Vec<ModelKey> = (0..n_items)
                    .map(|i| ModelKey {
                        provider: "p".to_string(),
                        id: format!("m{i}"),
                    })
                    .collect();
                let mut s = ModelSelectorOverlay::new_from_keys(keys);
                for _ in 0..n_next {
                    s.select_next();
                }
                let idx_after_next = s.selected_index();
                s.select_prev();
                s.select_next();
                assert_eq!(s.selected_index(), idx_after_next);
            }

            // Typing a provider name keeps at least that provider's model visible.
            #[test]
            fn query_by_provider_filters(
                p1 in "[a-z]{3,8}",
                p2 in "[a-z]{3,8}"
            ) {
                if p1 != p2 {
                    let mut s = ModelSelectorOverlay::new_from_keys(vec![
                        ModelKey { provider: p1.clone(), id: "m1".to_string() },
                        ModelKey { provider: p2, id: "m2".to_string() },
                    ]);
                    s.push_chars(p1.chars());
                    assert!(s.filtered_len() >= 1);
                }
            }
        }
    }
}