Skip to main content

rab/agent/ui/
model_selector.rs

1//! ModelSelector component — matching pi's ModelSelectorComponent.
2//!
3//! Full-screen overlay for selecting a model with search.
4//! Supports switching between "all" and "scoped" model views (Tab).
5
6use crate::agent::ui::theme::ThemeKey;
7use crate::agent::ui::theme::current_theme;
8use crate::tui::Component;
9use crate::tui::fuzzy::fuzzy_filter;
10use crate::tui::keybindings::{
11    ACTION_INPUT_TAB, ACTION_SELECT_CANCEL, ACTION_SELECT_CONFIRM, ACTION_SELECT_DOWN,
12    ACTION_SELECT_UP, get_keybindings,
13};
14use crossterm::event::{KeyCode, KeyEvent};
15
16// ── Model item for display ─────────────────────────────────────────
17
/// One selectable row in the model list, with derived identifiers cached at construction.
#[derive(Clone)]
struct ModelItem {
    /// Provider the model belongs to; rendered as the `[provider]` badge.
    provider: String,
    /// Model id within its provider; rendered as the row text.
    id: String,
    /// Human-readable display name; shown under the list for the selected row.
    name: String,
    /// `"{provider}/{id}"` — the key used for deduplication, scoping and selection.
    full_id: String,
    /// Haystack handed to the fuzzy matcher, built once in [`ModelItem::new`].
    search_text: String,
}

impl ModelItem {
    /// Builds an item, precomputing `full_id` and the space-joined fuzzy-search haystack
    /// (`provider id name full_id`).
    fn new(provider: String, id: String, name: String) -> Self {
        let full_id = [provider.as_str(), id.as_str()].join("/");
        let search_text = [
            provider.as_str(),
            id.as_str(),
            name.as_str(),
            full_id.as_str(),
        ]
        .join(" ");
        Self {
            search_text,
            full_id,
            name,
            id,
            provider,
        }
    }

    /// The precomputed string that fuzzy search matches against.
    fn search_text(&self) -> &str {
        self.search_text.as_str()
    }
}
44
45// ── Visibility style ───────────────────────────────────────────────
46
/// Which list the selector is showing; Tab flips between the two when a scope exists.
#[derive(PartialEq, Copy, Clone)]
enum ModelScope {
    /// Every (deduplicated) model passed to the selector.
    All,
    /// Only the models named in the scoped id list, in that list's order.
    Scoped,
}
52
53// ── ModelSelector component ────────────────────────────────────────
54
55pub struct ModelSelector {
56    all_models: Vec<ModelItem>,
57    scoped_model_ids: Vec<String>, // "provider/id" strings
58    scope: ModelScope,
59    active_items: Vec<ModelItem>,
60    filtered_indices: Vec<usize>,
61    selected_index: usize,
62    search_query: String,
63    current_model: String,
64    max_visible: usize,
65    callbacks: ModelSelectorCallbacks,
66}
67
/// Hooks the selector invokes when the user finishes interacting with it.
pub struct ModelSelectorCallbacks {
    /// Fired on confirm with the chosen model's full `"provider/id"` string.
    pub on_select: Box<dyn Fn(String)>,
    /// Fired when the user dismisses the overlay without choosing.
    pub on_cancel: Box<dyn Fn()>,
}
74
75impl ModelSelector {
76    pub fn new(
77        all_models: Vec<(String, String, String)>, // (provider, id, name)
78        scoped_model_ids: Vec<String>,             // "provider/id" strings
79        current_model: String,
80        callbacks: ModelSelectorCallbacks,
81    ) -> Self {
82        let mut items: Vec<ModelItem> = all_models
83            .into_iter()
84            .map(|(p, id, name)| ModelItem::new(p, id, name))
85            .collect();
86
87        // Deduplicate by full_id (provider/id) — the model registry may list
88        // the same model ID under multiple providers, but provider_for_model
89        // can resolve all of them to the same provider, creating true duplicates.
90        let mut seen = std::collections::HashSet::new();
91        items.retain(|item| seen.insert(item.full_id.clone()));
92
93        // Sort: current model first, then by provider (matches pi's sortModels)
94        items.sort_by(|a, b| {
95            let a_is_current = a.full_id == current_model;
96            let b_is_current = b.full_id == current_model;
97            if a_is_current && !b_is_current {
98                return std::cmp::Ordering::Less;
99            }
100            if !a_is_current && b_is_current {
101                return std::cmp::Ordering::Greater;
102            }
103            a.provider.cmp(&b.provider)
104        });
105
106        let has_scoped = !scoped_model_ids.is_empty();
107        let scope = if has_scoped {
108            ModelScope::Scoped
109        } else {
110            ModelScope::All
111        };
112
113        let active = if has_scoped {
114            // Respect scoped model order: iterate scoped_model_ids, find matching item.
115            let mut active: Vec<ModelItem> = Vec::new();
116            for full_id in &scoped_model_ids {
117                if let Some(item) = items.iter().find(|i| &i.full_id == full_id) {
118                    active.push(item.clone());
119                }
120            }
121            active
122        } else {
123            items.clone()
124        };
125
126        let current_idx = active
127            .iter()
128            .position(|m| m.full_id == current_model)
129            .unwrap_or(0);
130        let filtered: Vec<usize> = (0..active.len()).collect();
131
132        Self {
133            all_models: items,
134            scoped_model_ids,
135            scope,
136            active_items: active,
137            filtered_indices: filtered,
138            selected_index: current_idx,
139            search_query: String::new(),
140            current_model,
141            max_visible: 10,
142            callbacks,
143        }
144    }
145
146    fn set_scope(&mut self, scope: ModelScope) {
147        if self.scope == scope {
148            return;
149        }
150        self.scope = scope;
151        self.active_items = match scope {
152            ModelScope::All => self.all_models.clone(),
153            ModelScope::Scoped => {
154                // Respect scoped model order
155                let mut active: Vec<ModelItem> = Vec::new();
156                for full_id in &self.scoped_model_ids {
157                    if let Some(item) = self.all_models.iter().find(|i| &i.full_id == full_id) {
158                        active.push(item.clone());
159                    }
160                }
161                active
162            }
163        };
164        let current_idx = self
165            .active_items
166            .iter()
167            .position(|m| m.full_id == self.current_model)
168            .unwrap_or(0);
169        self.selected_index = current_idx;
170        self.refresh();
171    }
172
173    fn refresh(&mut self) {
174        let query = self.search_query.clone();
175        self.filtered_indices = if query.trim().is_empty() {
176            (0..self.active_items.len()).collect()
177        } else {
178            fuzzy_filter(&self.active_items, &query, |item| item.search_text())
179        };
180        self.selected_index = self
181            .selected_index
182            .min(self.filtered_indices.len().saturating_sub(1));
183    }
184
185    fn get_item(&self, filtered_idx: usize) -> Option<&ModelItem> {
186        self.filtered_indices
187            .get(filtered_idx)
188            .and_then(|&idx| self.active_items.get(idx))
189    }
190}
191
192impl Component for ModelSelector {
193    fn render(&mut self, width: usize) -> Vec<String> {
194        use crate::tui::util::truncate_to_width;
195        let theme = current_theme();
196        let mut lines: Vec<String> = Vec::new();
197
198        // Top border (matches pi's DynamicBorder)
199        lines.push(theme.dim(&"─".repeat(width)));
200        lines.push(String::new());
201
202        // Scope / hint (matches pi's scopeText + scopeHintText layout)
203        let has_scoped = !self.scoped_model_ids.is_empty();
204        if has_scoped {
205            let all_text = match self.scope {
206                ModelScope::All => theme.fg_key(ThemeKey::Accent, "all"),
207                ModelScope::Scoped => theme.dim("all"),
208            };
209            let scoped_text = match self.scope {
210                ModelScope::Scoped => theme.fg_key(ThemeKey::Accent, "scoped"),
211                ModelScope::All => theme.dim("scoped"),
212            };
213            lines.push(format!(
214                " {} {} | {}",
215                theme.dim("Scope:"),
216                all_text,
217                scoped_text,
218            ));
219            lines.push(format!(" {}", theme.dim("Tab scope (all/scoped)")));
220        } else {
221            lines.push(format!(
222                " {}",
223                theme.fg_key(
224                    ThemeKey::Warning,
225                    "Only showing models from configured providers. Use /login to add providers."
226                )
227            ));
228        }
229        lines.push(String::new());
230
231        // Search input line (matches pi's Input widget — displayed as single line)
232        let search_value = if self.search_query.is_empty() {
233            String::new()
234        } else {
235            self.search_query.clone()
236        };
237        lines.push(format!(" {}{}", theme.dim("Search: "), search_value));
238        lines.push(String::new());
239
240        // Model list
241        let count = self.filtered_indices.len();
242        if count == 0 {
243            lines.push(theme.dim("  No matching models"));
244        } else {
245            let start = self
246                .selected_index
247                .saturating_sub(self.max_visible / 2)
248                .min(count.saturating_sub(self.max_visible));
249            let end = (start + self.max_visible).min(count);
250
251            for i in start..end {
252                let item = &self.active_items[self.filtered_indices[i]];
253                let is_selected = i == self.selected_index;
254                let is_current = item.full_id == self.current_model;
255
256                let prefix = if is_selected {
257                    theme.fg_key(ThemeKey::Accent, "→ ")
258                } else {
259                    "  ".to_string()
260                };
261                let model_text = if is_selected {
262                    theme.fg_key(ThemeKey::Accent, &item.id)
263                } else {
264                    item.id.clone()
265                };
266                let provider_badge = theme.dim(&format!(" [{}]", item.provider));
267                let checkmark = if is_current {
268                    theme.fg_key(ThemeKey::Success, " ✓")
269                } else {
270                    String::new()
271                };
272
273                lines.push(truncate_to_width(
274                    &format!("{}{}{}{}", prefix, model_text, provider_badge, checkmark),
275                    width.saturating_sub(4),
276                    "",
277                    false,
278                ));
279            }
280
281            // Scroll indicator
282            if count > self.max_visible {
283                lines.push(theme.dim(&format!("  ({}/{})", self.selected_index + 1, count)));
284            }
285
286            // Show model name for selected item
287            if let Some(item) = self.get_item(self.selected_index) {
288                lines.push(String::new());
289                lines.push(theme.dim(&format!("  Model Name: {}", item.name)));
290            }
291        }
292
293        // Bottom border
294        lines.push(theme.dim(&"─".repeat(width)));
295
296        lines
297    }
298
299    fn handle_input(&mut self, key: &KeyEvent) -> bool {
300        let kb = get_keybindings();
301
302        // Tab toggles scope
303        if kb.matches(key, ACTION_INPUT_TAB) {
304            if !self.scoped_model_ids.is_empty() {
305                let next = match self.scope {
306                    ModelScope::All => ModelScope::Scoped,
307                    ModelScope::Scoped => ModelScope::All,
308                };
309                self.set_scope(next);
310            }
311            return true;
312        }
313
314        // Up/Down navigation with wrapping
315        if kb.matches(key, ACTION_SELECT_UP) {
316            if self.filtered_indices.is_empty() {
317                return true;
318            }
319            self.selected_index = if self.selected_index == 0 {
320                self.filtered_indices.len() - 1
321            } else {
322                self.selected_index - 1
323            };
324            return true;
325        }
326
327        if kb.matches(key, ACTION_SELECT_DOWN) {
328            if self.filtered_indices.is_empty() {
329                return true;
330            }
331            self.selected_index = if self.selected_index >= self.filtered_indices.len() - 1 {
332                0
333            } else {
334                self.selected_index + 1
335            };
336            return true;
337        }
338
339        // Enter selects model
340        if kb.matches(key, ACTION_SELECT_CONFIRM) {
341            if let Some(item) = self.get_item(self.selected_index) {
342                (self.callbacks.on_select)(item.full_id.clone());
343            }
344            return true;
345        }
346
347        // Escape cancels - call callback then pop overlay via app
348        if kb.matches(key, ACTION_SELECT_CANCEL) {
349            (self.callbacks.on_cancel)();
350            return false; // Let App's fallback pop the overlay
351        }
352
353        // Backspace - delete from search
354        if key.code == KeyCode::Backspace {
355            if !self.search_query.is_empty() {
356                self.search_query.pop();
357                self.refresh();
358            }
359            return true;
360        }
361
362        // Typeable characters go to search
363        if let KeyCode::Char(c) = key.code
364            && !c.is_control()
365        {
366            self.search_query.push(c);
367            self.refresh();
368            return true;
369        }
370
371        false
372    }
373}