// File: pi/model_selector.rs
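//! Model selector overlay state.
//!
//! Used by the interactive TUI to present a searchable list of models. Queries
//! are matched fuzzily (case-insensitive subsequence) against each model's
//! `provider/id` string and against known aliases of its provider.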
4
5use crate::models::ModelEntry;
6use crate::provider_metadata::provider_metadata;
7
/// Identifies a model by its provider and its provider-local model id.
///
/// The derived ordering compares `provider` first and then `id` (field
/// declaration order). That is the same order the selector lists models in.
/// `Hash` lets keys be used directly in hash-based sets and maps.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ModelKey {
    /// Provider name, e.g. `"openai"`.
    pub provider: String,
    /// Model id within the provider, e.g. `"gpt-4o"`.
    pub id: String,
}

impl ModelKey {
    /// Returns the canonical `provider/id` string for this key.
    #[must_use]
    pub fn full_id(&self) -> String {
        format!("{}/{}", self.provider, self.id)
    }
}
20
/// State for the searchable model-selector overlay.
///
/// Holds every candidate model and the subset matching the current query.
/// It also holds the cursor and page-size bookkeeping used to render a
/// windowed list.
#[derive(Debug)]
pub struct ModelSelectorOverlay {
    /// Every candidate key, sorted by provider then id at construction.
    all: Vec<ModelKey>,
    /// Indices into `all` of the keys matching `query`, in `all` order.
    filtered: Vec<usize>,
    /// Cursor position, as an index into `filtered` (not into `all`).
    selected: usize,
    /// Query text as typed; control characters are never stored.
    query: String,
    /// Rows shown per page; starts at 10 and is never set below 1.
    max_visible: usize,
    /// Total model count reported to the UI. Starts as `all.len()`.
    /// `set_configured_only_scope` may raise it. Presumably this is the size
    /// of the unrestricted catalogue (TODO confirm with callers).
    source_total: usize,
    /// Set once `set_configured_only_scope` is called. Presumably marks that
    /// `all` was pre-filtered to configured providers (confirm with callers).
    configured_only: bool,
}
31
32impl ModelSelectorOverlay {
33    #[must_use]
34    pub fn new(models: &[ModelEntry]) -> Self {
35        let keys = models
36            .iter()
37            .map(|entry| ModelKey {
38                provider: entry.model.provider.clone(),
39                id: entry.model.id.clone(),
40            })
41            .collect::<Vec<_>>();
42        Self::new_from_keys(keys)
43    }
44
45    #[must_use]
46    pub fn new_from_keys(mut keys: Vec<ModelKey>) -> Self {
47        keys.sort_by(|a, b| a.provider.cmp(&b.provider).then_with(|| a.id.cmp(&b.id)));
48        let source_total = keys.len();
49        let mut selector = Self {
50            all: keys,
51            filtered: Vec::new(),
52            selected: 0,
53            query: String::new(),
54            max_visible: 10,
55            source_total,
56            configured_only: false,
57        };
58        selector.refresh_filtered();
59        selector
60    }
61
62    #[must_use]
63    pub fn query(&self) -> &str {
64        &self.query
65    }
66
67    #[must_use]
68    pub const fn max_visible(&self) -> usize {
69        self.max_visible
70    }
71
72    pub fn set_max_visible(&mut self, max_visible: usize) {
73        self.max_visible = max_visible.max(1);
74    }
75
76    pub fn clear_query(&mut self) {
77        if self.query.is_empty() {
78            return;
79        }
80        self.query.clear();
81        self.refresh_filtered();
82    }
83
84    pub fn push_chars<I: IntoIterator<Item = char>>(&mut self, chars: I) {
85        let mut changed = false;
86        for ch in chars {
87            if ch.is_control() {
88                continue;
89            }
90            self.query.push(ch);
91            changed = true;
92        }
93        if changed {
94            self.refresh_filtered();
95        }
96    }
97
98    pub fn pop_char(&mut self) {
99        if self.query.pop().is_some() {
100            self.refresh_filtered();
101        }
102    }
103
104    pub fn select_next(&mut self) {
105        if !self.filtered.is_empty() {
106            self.selected = (self.selected + 1) % self.filtered.len();
107        }
108    }
109
110    pub fn select_prev(&mut self) {
111        if !self.filtered.is_empty() {
112            self.selected = self
113                .selected
114                .checked_sub(1)
115                .unwrap_or(self.filtered.len() - 1);
116        }
117    }
118
119    pub fn select_page_down(&mut self) {
120        if self.filtered.is_empty() {
121            return;
122        }
123        let step = self.max_visible.saturating_sub(1).max(1);
124        self.selected = (self.selected + step).min(self.filtered.len() - 1);
125    }
126
127    pub fn select_page_up(&mut self) {
128        if self.filtered.is_empty() {
129            return;
130        }
131        let step = self.max_visible.saturating_sub(1).max(1);
132        self.selected = self.selected.saturating_sub(step);
133    }
134
135    #[must_use]
136    pub fn filtered_len(&self) -> usize {
137        self.filtered.len()
138    }
139
140    #[must_use]
141    pub fn item_at(&self, filtered_index: usize) -> Option<&ModelKey> {
142        self.filtered
143            .get(filtered_index)
144            .and_then(|&idx| self.all.get(idx))
145    }
146
147    #[must_use]
148    pub fn selected_item(&self) -> Option<&ModelKey> {
149        self.item_at(self.selected)
150    }
151
152    #[must_use]
153    pub const fn selected_index(&self) -> usize {
154        self.selected
155    }
156
157    #[must_use]
158    pub const fn source_total(&self) -> usize {
159        self.source_total
160    }
161
162    #[must_use]
163    pub const fn configured_only(&self) -> bool {
164        self.configured_only
165    }
166
167    pub fn set_configured_only_scope(&mut self, source_total: usize) {
168        self.configured_only = true;
169        self.source_total = source_total.max(self.all.len());
170    }
171
172    #[must_use]
173    pub const fn scroll_offset(&self) -> usize {
174        if self.selected < self.max_visible {
175            0
176        } else {
177            self.selected - self.max_visible + 1
178        }
179    }
180
181    fn refresh_filtered(&mut self) {
182        self.filtered = self
183            .all
184            .iter()
185            .enumerate()
186            .filter_map(|(idx, key)| matches_query(&self.query, key).then_some(idx))
187            .collect();
188        self.selected = 0;
189    }
190}
191
192fn matches_query(query: &str, key: &ModelKey) -> bool {
193    let trimmed = query.trim();
194    if trimmed.is_empty() {
195        return true;
196    }
197
198    if fuzzy_match(trimmed, &key.full_id())
199        || fuzzy_match(trimmed, &key.provider)
200        || fuzzy_match(trimmed, &key.id)
201    {
202        return true;
203    }
204
205    // Also match against provider aliases so users can search by common
206    // names (e.g. "grok" finds xai models, "together" finds togetherai).
207    if let Some(meta) = provider_metadata(&key.provider) {
208        for alias in meta.aliases {
209            if fuzzy_match(trimmed, alias) {
210                return true;
211            }
212        }
213    }
214
215    false
216}
217
/// Case-insensitive subsequence test.
///
/// Returns `true` when every non-whitespace character of `pattern` appears in
/// `value` in the same relative order, not necessarily contiguously. Both
/// strings are lowercased first. A pattern with no non-whitespace characters
/// matches anything.
fn fuzzy_match(pattern: &str, value: &str) -> bool {
    let pattern_lower = pattern.to_lowercase();
    let value_lower = value.to_lowercase();
    // One shared cursor over the haystack. Each `any` call resumes where the
    // previous match left off, which is what enforces the ordering.
    let mut remaining = value_lower.chars();
    pattern_lower
        .chars()
        .filter(|c| !c.is_whitespace())
        .all(|wanted| remaining.any(|candidate| candidate == wanted))
}
230
#[cfg(test)]
mod tests {
    // Unit tests for the selector's navigation/filtering state, plus direct
    // tests of the private `fuzzy_match` and `matches_query` helpers, which
    // are reachable through `super::*`.
    use super::*;

    /// Builds an overlay from `(provider, id)` pairs via `new_from_keys`.
    fn selector(keys: &[(&str, &str)]) -> ModelSelectorOverlay {
        ModelSelectorOverlay::new_from_keys(
            keys.iter()
                .map(|(provider, id)| ModelKey {
                    provider: (*provider).to_string(),
                    id: (*id).to_string(),
                })
                .collect(),
        )
    }

    #[test]
    fn filters_with_fuzzy_subsequence() {
        let mut selector = selector(&[("openai", "gpt-4o"), ("anthropic", "claude-sonnet-4")]);
        // "og" is a subsequence of "openai/gpt-4o" but not of
        // "anthropic/claude-sonnet-4" (no 'g' after its first 'o').
        selector.push_chars("og".chars());
        assert_eq!(selector.filtered_len(), 1);
        assert_eq!(selector.selected_item().unwrap().full_id(), "openai/gpt-4o");
    }

    #[test]
    fn filters_case_insensitive_and_whitespace_insensitive() {
        let mut selector = selector(&[("openai", "gpt-4o"), ("openai", "gpt-4o-mini")]);
        // Whitespace inside the query is skipped, so "GPT 4O" acts like "gpt4o".
        selector.push_chars("GPT 4O".chars());
        assert_eq!(selector.filtered_len(), 2);
    }

    #[test]
    fn selection_wraps() {
        let mut selector = selector(&[("openai", "gpt-4o"), ("openai", "gpt-4o-mini")]);
        selector.select_prev();
        assert_eq!(
            selector.selected_item().unwrap().full_id(),
            "openai/gpt-4o-mini"
        );
        selector.select_next();
        assert_eq!(selector.selected_item().unwrap().full_id(), "openai/gpt-4o");
    }

    #[test]
    fn new_from_keys_sorts_provider_then_id() {
        let selector = selector(&[
            ("openai", "gpt-4o-mini"),
            ("anthropic", "claude-sonnet-4"),
            ("openai", "gpt-4o"),
        ]);
        let ids = (0..selector.filtered_len())
            .map(|idx| selector.item_at(idx).unwrap().full_id())
            .collect::<Vec<_>>();
        assert_eq!(
            ids,
            vec![
                "anthropic/claude-sonnet-4",
                "openai/gpt-4o",
                "openai/gpt-4o-mini"
            ]
        );
    }

    #[test]
    fn page_navigation_respects_window_and_bounds() {
        let mut selector = selector(&[
            ("openai", "a"),
            ("openai", "b"),
            ("openai", "c"),
            ("openai", "d"),
            ("openai", "e"),
        ]);
        // With 3 visible rows, each page step moves max_visible - 1 = 2 rows.
        selector.set_max_visible(3);
        assert_eq!(selector.max_visible(), 3);
        assert_eq!(selector.selected_index(), 0);
        assert_eq!(selector.scroll_offset(), 0);

        selector.select_page_down();
        assert_eq!(selector.selected_index(), 2);
        assert_eq!(selector.scroll_offset(), 0);

        selector.select_page_down();
        assert_eq!(selector.selected_index(), 4);
        assert_eq!(selector.scroll_offset(), 2);

        // Already on the last row: paging down clamps instead of wrapping.
        selector.select_page_down();
        assert_eq!(selector.selected_index(), 4);

        selector.select_page_up();
        assert_eq!(selector.selected_index(), 2);
        assert_eq!(selector.scroll_offset(), 0);

        selector.select_page_up();
        assert_eq!(selector.selected_index(), 0);
    }

    #[test]
    fn set_max_visible_clamps_to_one() {
        let mut selector = selector(&[("openai", "a"), ("openai", "b"), ("openai", "c")]);
        selector.set_max_visible(0);
        assert_eq!(selector.max_visible(), 1);

        // A one-row page still advances by one row, so paging never stalls.
        selector.select_page_down();
        assert_eq!(selector.selected_index(), 1);
        selector.select_page_down();
        assert_eq!(selector.selected_index(), 2);
    }

    #[test]
    fn query_input_ignores_control_chars_and_pop_refreshes() {
        let mut selector = selector(&[("openai", "gpt-4o"), ("openai", "o1")]);
        // '\n' and '\t' are control characters and are dropped from the query.
        selector.push_chars("o1\n\t".chars());
        assert_eq!(selector.query(), "o1");
        assert_eq!(selector.filtered_len(), 1);
        assert_eq!(selector.selected_item().unwrap().full_id(), "openai/o1");

        selector.pop_char();
        assert_eq!(selector.query(), "o");
        assert_eq!(selector.filtered_len(), 2);
    }

    #[test]
    fn clear_query_noop_when_empty_and_reset_when_non_empty() {
        let mut selector = selector(&[("openai", "gpt-4o"), ("openai", "o1")]);

        // clear_query on an empty query returns early, so it does not re-filter
        // and the selection is left untouched.
        selector.select_next();
        assert_eq!(selector.selected_index(), 1);
        selector.clear_query();
        assert_eq!(selector.selected_index(), 1);

        selector.push_chars("1".chars());
        assert_eq!(selector.filtered_len(), 1);
        selector.clear_query();
        assert_eq!(selector.query(), "");
        assert_eq!(selector.filtered_len(), 2);
        assert_eq!(selector.selected_index(), 0);
    }

    #[test]
    fn no_match_has_no_selected_item_and_navigation_is_stable() {
        let mut selector = selector(&[("openai", "gpt-4o"), ("openai", "o1")]);
        selector.push_chars("zzz".chars());

        assert_eq!(selector.filtered_len(), 0);
        assert!(selector.selected_item().is_none());
        assert!(selector.item_at(0).is_none());

        // Every navigation method is a no-op on an empty filtered list.
        selector.select_next();
        selector.select_prev();
        selector.select_page_down();
        selector.select_page_up();

        assert_eq!(selector.selected_index(), 0);
        assert_eq!(selector.scroll_offset(), 0);
    }

    #[test]
    fn empty_selector_stays_stable_for_navigation_and_queries() {
        let mut selector = selector(&[]);
        assert_eq!(selector.filtered_len(), 0);
        assert!(selector.selected_item().is_none());

        selector.select_next();
        selector.select_prev();
        selector.select_page_down();
        selector.select_page_up();
        selector.push_chars("abc".chars());
        selector.pop_char();
        selector.clear_query();

        assert_eq!(selector.filtered_len(), 0);
        assert_eq!(selector.selected_index(), 0);
        assert_eq!(selector.scroll_offset(), 0);
        assert_eq!(selector.query(), "");
    }

    #[test]
    fn whitespace_only_query_keeps_all_models_visible() {
        let mut selector = selector(&[
            ("openai", "gpt-4o"),
            ("openai", "gpt-4o-mini"),
            ("anthropic", "claude-sonnet-4"),
        ]);
        // Spaces are not control characters, so they are stored verbatim.
        // Matching trims the query, so nothing gets filtered out.
        selector.push_chars("   ".chars());

        assert_eq!(selector.query(), "   ");
        assert_eq!(selector.filtered_len(), 3);
        assert_eq!(
            selector.selected_item().unwrap().full_id(),
            "anthropic/claude-sonnet-4"
        );
    }

    #[test]
    fn query_refresh_resets_selection_to_first_match() {
        let mut selector = selector(&[("openai", "gpt-4o"), ("openai", "gpt-4o-mini")]);
        selector.select_next();
        assert_eq!(selector.selected_index(), 1);

        selector.push_chars("mini".chars());
        assert_eq!(selector.filtered_len(), 1);
        assert_eq!(selector.selected_index(), 0);
        assert_eq!(
            selector.selected_item().unwrap().full_id(),
            "openai/gpt-4o-mini"
        );
    }

    // ── ModelKey::full_id ────────────────────────────────────────────

    #[test]
    fn model_key_full_id() {
        let key = ModelKey {
            provider: "anthropic".to_string(),
            id: "claude-sonnet-4".to_string(),
        };
        assert_eq!(key.full_id(), "anthropic/claude-sonnet-4");
    }

    // ── fuzzy_match function ─────────────────────────────────────────

    #[test]
    fn fuzzy_match_exact() {
        assert!(fuzzy_match("hello", "hello"));
    }

    #[test]
    fn fuzzy_match_subsequence() {
        assert!(fuzzy_match("gpt", "gpt-4o-mini"));
    }

    #[test]
    fn fuzzy_match_no_match() {
        assert!(!fuzzy_match("xyz", "abc"));
    }

    #[test]
    fn fuzzy_match_case_insensitive() {
        assert!(fuzzy_match("GPT", "gpt-4o"));
    }

    #[test]
    fn fuzzy_match_empty_pattern() {
        assert!(fuzzy_match("", "anything"));
    }

    // ── matches_query function ───────────────────────────────────────

    #[test]
    fn matches_query_by_provider() {
        let key = ModelKey {
            provider: "anthropic".to_string(),
            id: "claude".to_string(),
        };
        assert!(matches_query("anth", &key));
    }

    #[test]
    fn matches_query_by_id() {
        let key = ModelKey {
            provider: "openai".to_string(),
            id: "gpt-4o".to_string(),
        };
        assert!(matches_query("gpt", &key));
    }

    #[test]
    fn matches_query_by_full_id() {
        let key = ModelKey {
            provider: "openai".to_string(),
            id: "gpt-4o".to_string(),
        };
        // "oi/g" spans the '/' separator, so only the combined `provider/id`
        // form can match it.
        assert!(matches_query("oi/g", &key));
    }

    // ── matches_query via provider aliases ─────────────────────────────
    // These depend on the alias tables returned by `provider_metadata`,
    // which is defined outside this file.

    #[test]
    fn matches_query_by_provider_alias_grok_finds_xai() {
        let key = ModelKey {
            provider: "xai".to_string(),
            id: "grok-2".to_string(),
        };
        assert!(matches_query("grok", &key));
    }

    #[test]
    fn matches_query_by_provider_alias_together_finds_togetherai() {
        let key = ModelKey {
            provider: "togetherai".to_string(),
            id: "llama-3".to_string(),
        };
        assert!(matches_query("together", &key));
    }

    #[test]
    fn matches_query_by_provider_alias_hf_finds_huggingface() {
        let key = ModelKey {
            provider: "huggingface".to_string(),
            id: "meta-llama".to_string(),
        };
        assert!(matches_query("hf", &key));
    }

    #[test]
    fn matches_query_by_provider_alias_gemini_finds_google() {
        let key = ModelKey {
            provider: "google".to_string(),
            id: "gemini-2.0-flash".to_string(),
        };
        assert!(matches_query("gemini", &key));
    }

    #[test]
    fn matches_query_alias_no_false_positive_for_unknown_provider() {
        let key = ModelKey {
            provider: "unknown-provider".to_string(),
            id: "model-x".to_string(),
        };
        assert!(!matches_query("grok", &key));
    }

    // ── pop_char on empty query ──────────────────────────────────────

    #[test]
    fn pop_char_on_empty_is_noop() {
        let mut s = selector(&[("a", "b")]);
        s.pop_char();
        assert_eq!(s.query(), "");
        assert_eq!(s.filtered_len(), 1);
    }

    // ── item_at out of bounds ────────────────────────────────────────

    #[test]
    fn item_at_out_of_bounds_returns_none() {
        let s = selector(&[("a", "b")]);
        assert!(s.item_at(100).is_none());
    }

    // ── duplicate keys ──────────────────────────────────────────────

    #[test]
    fn duplicate_keys_are_preserved() {
        // new_from_keys sorts but does not dedupe.
        let s = selector(&[("a", "m1"), ("a", "m1")]);
        assert_eq!(s.filtered_len(), 2);
    }

    // ── scroll_offset edge cases ─────────────────────────────────────

    #[test]
    fn scroll_offset_zero_when_within_window() {
        let s = selector(&[("a", "1"), ("a", "2"), ("a", "3")]);
        assert_eq!(s.scroll_offset(), 0);
    }

    #[test]
    fn scroll_offset_tracks_selection_beyond_window() {
        let mut s = selector(&[("a", "1"), ("a", "2"), ("a", "3"), ("a", "4"), ("a", "5")]);
        s.set_max_visible(2);
        // Select past the visible window
        s.select_next(); // index 1
        s.select_next(); // index 2 → scroll_offset should be 1
        assert_eq!(s.scroll_offset(), 1);
    }

    // Property-based tests driven by the `proptest` crate.
    mod proptest_model_selector {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            /// `full_id` always has format `provider/id`.
            #[test]
            fn full_id_format(
                provider in "[a-z]{1,15}",
                id in "[a-z0-9-]{1,20}"
            ) {
                let key = ModelKey { provider: provider.clone(), id: id.clone() };
                let full = key.full_id();
                assert_eq!(full, format!("{provider}/{id}"));
                assert!(full.contains('/'));
            }

            /// `fuzzy_match` never panics on arbitrary input.
            #[test]
            fn fuzzy_match_never_panics(
                pattern in ".{0,50}",
                value in ".{0,50}"
            ) {
                let _ = fuzzy_match(&pattern, &value);
            }

            /// Empty pattern always matches any value.
            #[test]
            fn fuzzy_match_empty_pattern_matches(value in ".{0,50}") {
                assert!(fuzzy_match("", &value));
            }

            /// `fuzzy_match` is case-insensitive.
            #[test]
            fn fuzzy_match_case_insensitive(
                pattern in "[a-z]{1,10}",
                value in "[a-z]{1,30}"
            ) {
                let lower = fuzzy_match(&pattern, &value);
                let upper = fuzzy_match(&pattern.to_uppercase(), &value);
                assert_eq!(lower, upper, "case mismatch for pattern={pattern} value={value}");
            }

            /// Exact match always returns true (case-insensitive).
            #[test]
            fn fuzzy_match_exact_always_matches(s in "[a-zA-Z0-9]{1,20}") {
                assert!(fuzzy_match(&s, &s));
                assert!(fuzzy_match(&s.to_lowercase(), &s.to_uppercase()));
            }

            /// `matches_query` never panics.
            #[test]
            fn matches_query_never_panics(
                query in ".{0,30}",
                provider in "[a-z]{1,10}",
                id in "[a-z0-9-]{1,15}"
            ) {
                let key = ModelKey { provider, id };
                let _ = matches_query(&query, &key);
            }

            /// Empty/whitespace query matches everything.
            #[test]
            fn empty_query_matches_all(
                ws in "[ \\t]{0,5}",
                provider in "[a-z]{1,10}",
                id in "[a-z0-9]{1,10}"
            ) {
                let key = ModelKey { provider, id };
                assert!(matches_query(&ws, &key));
            }

            /// `set_max_visible(0)` clamps to 1.
            #[test]
            fn max_visible_clamps_to_one(n in 0..100usize) {
                let mut s = ModelSelectorOverlay::new_from_keys(vec![]);
                s.set_max_visible(n);
                assert!(s.max_visible() >= 1);
                if n > 0 {
                    assert_eq!(s.max_visible(), n);
                }
            }

            /// `scroll_offset` is always <= selected.
            #[test]
            fn scroll_offset_bounded(
                n_items in 1..20usize,
                max_vis in 1..10usize,
                n_next in 0..30usize
            ) {
                let keys: Vec<ModelKey> = (0..n_items)
                    .map(|i| ModelKey {
                        provider: "p".to_string(),
                        id: format!("m{i}"),
                    })
                    .collect();
                let mut s = ModelSelectorOverlay::new_from_keys(keys);
                s.set_max_visible(max_vis);
                for _ in 0..n_next {
                    s.select_next();
                }
                assert!(s.scroll_offset() <= s.selected_index());
            }

            /// `select_next` wraps around at the end.
            #[test]
            fn select_next_wraps(n_items in 1..10usize) {
                let keys: Vec<ModelKey> = (0..n_items)
                    .map(|i| ModelKey {
                        provider: "p".to_string(),
                        id: format!("m{i}"),
                    })
                    .collect();
                let mut s = ModelSelectorOverlay::new_from_keys(keys);
                for _ in 0..n_items {
                    s.select_next();
                }
                // After n_items next calls, should wrap back to 0
                assert_eq!(s.selected_index(), 0);
            }

            /// `select_prev` wraps around at the beginning.
            #[test]
            fn select_prev_wraps(n_items in 1..10usize) {
                let keys: Vec<ModelKey> = (0..n_items)
                    .map(|i| ModelKey {
                        provider: "p".to_string(),
                        id: format!("m{i}"),
                    })
                    .collect();
                let mut s = ModelSelectorOverlay::new_from_keys(keys);
                // From 0, prev wraps to last
                s.select_prev();
                assert_eq!(s.selected_index(), n_items - 1);
            }

            /// `select_next` then `select_prev` returns to original index.
            #[test]
            fn next_prev_roundtrip(
                n_items in 1..10usize,
                n_next in 0..20usize
            ) {
                let keys: Vec<ModelKey> = (0..n_items)
                    .map(|i| ModelKey {
                        provider: "p".to_string(),
                        id: format!("m{i}"),
                    })
                    .collect();
                let mut s = ModelSelectorOverlay::new_from_keys(keys);
                for _ in 0..n_next {
                    s.select_next();
                }
                let idx_after_next = s.selected_index();
                s.select_prev();
                s.select_next();
                assert_eq!(s.selected_index(), idx_after_next);
            }

            /// Filtering by exact provider name includes that provider's models.
            #[test]
            fn query_by_provider_filters(
                p1 in "[a-z]{3,8}",
                p2 in "[a-z]{3,8}"
            ) {
                // Only test when providers differ
                if p1 != p2 {
                    let mut s = ModelSelectorOverlay::new_from_keys(vec![
                        ModelKey { provider: p1.clone(), id: "m1".to_string() },
                        ModelKey { provider: p2, id: "m2".to_string() },
                    ]);
                    s.push_chars(p1.chars());
                    // The filtered set should include at least the p1 model.
                    // The p2 model may also match if p1 happens to be a
                    // subsequence of it, hence `>=` rather than `==`.
                    assert!(s.filtered_len() >= 1);
                }
            }
        }
    }
}