Skip to main content

rab/agent/ui/components/
scoped_models_selector.rs

//! ScopedModelsSelector component — matching pi's ScopedModelsSelectorComponent.
//!
//! Full-screen overlay for enabling/disabling the models used by Ctrl+P cycling.
//! Changes are session-only until explicitly persisted with Ctrl+S.
//!
//! The owner is notified through [`ModelsCallbacks`]: `on_change` after each
//! session-only edit, `on_persist` on Ctrl+S and `on_cancel` on cancel (Escape).
//! When cancelling, `handle_input` returns `false` so the owning app can close
//! the overlay.
8
9use crate::agent::ui::theme::ThemeKey;
10use crate::agent::ui::theme::current_theme;
11use crate::tui::Component;
12use crate::tui::fuzzy::fuzzy_filter;
13use crate::tui::keybindings::{
14    ACTION_SELECT_CANCEL, ACTION_SELECT_CONFIRM, ACTION_SELECT_DOWN, ACTION_SELECT_UP,
15    get_keybindings,
16};
17use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
18
19// ── Types ──────────────────────────────────────────────────────────
20
/// The set of models enabled for Ctrl+P cycling.
///
/// `None` means every model is enabled (no filter); `Some(ids)` is an explicit,
/// ordered subset of `"provider/model_id"` ids.
pub type EnabledIds = Option<Vec<String>>;
23
24fn is_enabled(enabled_ids: &EnabledIds, id: &str) -> bool {
25    match enabled_ids {
26        None => true,
27        Some(ids) => ids.contains(&id.to_string()),
28    }
29}
30
31fn toggle(enabled_ids: &EnabledIds, id: &str) -> EnabledIds {
32    match enabled_ids {
33        None => Some(vec![id.to_string()]),
34        Some(ids) => {
35            let id_s = id.to_string();
36            if ids.contains(&id_s) {
37                let result: Vec<String> = ids.iter().filter(|i| *i != &id_s).cloned().collect();
38                Some(result)
39            } else {
40                let mut result = ids.clone();
41                result.push(id_s);
42                Some(result)
43            }
44        }
45    }
46}
47
48fn enable_all(
49    enabled_ids: &EnabledIds,
50    all_ids: &[String],
51    target_ids: Option<&[String]>,
52) -> EnabledIds {
53    match enabled_ids {
54        None => None, // Already all enabled
55        Some(ids) => {
56            let targets = target_ids.unwrap_or(all_ids);
57            let mut result = ids.clone();
58            for id in targets {
59                if !result.contains(id) {
60                    result.push(id.clone());
61                }
62            }
63            if result.len() == all_ids.len() {
64                None
65            } else {
66                Some(result)
67            }
68        }
69    }
70}
71
72fn clear_all(
73    enabled_ids: &EnabledIds,
74    all_ids: &[String],
75    target_ids: Option<&[String]>,
76) -> EnabledIds {
77    match enabled_ids {
78        None => match target_ids {
79            Some(targets) => {
80                let result: Vec<String> = all_ids
81                    .iter()
82                    .filter(|id| !targets.contains(id))
83                    .cloned()
84                    .collect();
85                Some(result)
86            }
87            None => Some(vec![]),
88        },
89        Some(ids) => {
90            let targets_set: std::collections::HashSet<&str> = target_ids
91                .unwrap_or(ids)
92                .iter()
93                .map(|s| s.as_str())
94                .collect();
95            let result: Vec<String> = ids
96                .iter()
97                .filter(|id| !targets_set.contains(id.as_str()))
98                .cloned()
99                .collect();
100            Some(result)
101        }
102    }
103}
104
105fn move_item(enabled_ids: &EnabledIds, id: &str, delta: isize) -> EnabledIds {
106    match enabled_ids {
107        None => None,
108        Some(ids) => {
109            let mut list = ids.clone();
110            let pos = list.iter().position(|i| i == id);
111            match pos {
112                Some(idx) => {
113                    let new_idx = idx as isize + delta;
114                    if new_idx < 0 || new_idx >= list.len() as isize {
115                        return Some(list);
116                    }
117                    list.swap(idx, new_idx as usize);
118                    Some(list)
119                }
120                None => Some(list),
121            }
122        }
123    }
124}
125
126fn get_sorted_ids(enabled_ids: &EnabledIds, all_ids: &[String]) -> Vec<String> {
127    match enabled_ids {
128        None => all_ids.to_vec(),
129        Some(ids) => {
130            let enabled_set: std::collections::HashSet<&str> =
131                ids.iter().map(|s| s.as_str()).collect();
132            let mut result = ids.clone();
133            for id in all_ids {
134                if !enabled_set.contains(id.as_str()) {
135                    result.push(id.clone());
136                }
137            }
138            result
139        }
140    }
141}
142
143// ── Model item for display ─────────────────────────────────────────
144
/// One selectable row in the overlay.
#[derive(Clone)]
struct ModelItem {
    /// Unique `"provider/model_id"` key; the value stored in `EnabledIds`.
    full_id: String,
    /// Provider name, shown as the `[provider]` badge and used by Ctrl+P.
    provider: String,
    /// Model id, shown as the row's main text.
    model_id: String,
    /// Human-readable model name, shown below the list for the selected row.
    model_name: String,
    /// Whether this model is currently enabled for cycling.
    enabled: bool,
}
153
154// ── Config and callbacks ───────────────────────────────────────────
155
/// Input for [`ScopedModelsSelector::new`].
pub struct ModelsConfig {
    /// Every available model as `(provider, model_id, display_name)`.
    pub all_models: Vec<(String, String, String)>,
    /// Models currently enabled for cycling, as `"provider/model_id"` ids;
    /// `None` means all models are enabled.
    pub enabled_model_ids: Option<Vec<String>>,
}
160
/// Hooks through which the selector reports user actions to its owner.
pub struct ModelsCallbacks {
    /// Invoked with the new enabled set (`None` = all enabled) after every
    /// toggle, bulk change or reorder. Session-only: nothing is persisted.
    pub on_change: Box<dyn Fn(Option<Vec<String>>)>,
    /// Invoked with the current enabled set when the user asks to save it to settings.
    pub on_persist: Box<dyn Fn(Option<Vec<String>>)>,
    /// Invoked when the user cancels the overlay.
    pub on_cancel: Box<dyn Fn()>,
}
169
170// ── ScopedModelsSelector component ──────────────────────────────────
171
172pub struct ScopedModelsSelector {
173    items: Vec<ModelItem>,
174    all_ids: Vec<String>,
175    enabled_ids: EnabledIds,
176    all_items_sorted: Vec<ModelItem>,
177    filtered_indices: Vec<usize>,
178    selected_index: usize,
179    search_query: String,
180    max_visible: usize,
181    is_dirty: bool,
182    callbacks: ModelsCallbacks,
183}
184
185impl ScopedModelsSelector {
186    pub fn new(config: ModelsConfig, callbacks: ModelsCallbacks) -> Self {
187        let all_ids: Vec<String> = config
188            .all_models
189            .iter()
190            .map(|(p, id, _)| format!("{}/{}", p, id))
191            .collect();
192
193        let items: Vec<ModelItem> = config
194            .all_models
195            .iter()
196            .map(|(provider, model_id, name)| ModelItem {
197                full_id: format!("{}/{}", provider, model_id),
198                provider: provider.clone(),
199                model_id: model_id.clone(),
200                model_name: name.clone(),
201                enabled: is_enabled(
202                    &config.enabled_model_ids,
203                    &format!("{}/{}", provider, model_id),
204                ),
205            })
206            .collect();
207
208        let enabled_ids = config.enabled_model_ids;
209
210        let sorted = get_sorted_ids(&enabled_ids, &all_ids);
211        let all_items_sorted: Vec<ModelItem> = sorted
212            .iter()
213            .filter_map(|full_id| {
214                items
215                    .iter()
216                    .find(|item| item.full_id == *full_id)
217                    .cloned()
218                    .map(|mut item| {
219                        item.enabled = is_enabled(&enabled_ids, &item.full_id);
220                        item
221                    })
222            })
223            .collect();
224
225        let filtered_indices: Vec<usize> = (0..all_items_sorted.len()).collect();
226
227        Self {
228            items,
229            all_ids,
230            enabled_ids,
231            all_items_sorted,
232            filtered_indices,
233            selected_index: 0,
234            search_query: String::new(),
235            max_visible: 10,
236            is_dirty: false,
237            callbacks,
238        }
239    }
240
241    fn rebuild_sorted(&mut self) {
242        let sorted = get_sorted_ids(&self.enabled_ids, &self.all_ids);
243        self.all_items_sorted = sorted
244            .iter()
245            .filter_map(|full_id| {
246                self.items
247                    .iter()
248                    .find(|item| item.full_id == *full_id)
249                    .cloned()
250                    .map(|mut item| {
251                        item.enabled = is_enabled(&self.enabled_ids, &item.full_id);
252                        item
253                    })
254            })
255            .collect();
256    }
257
258    fn refresh(&mut self) {
259        self.rebuild_sorted();
260        let query = self.search_query.clone();
261        self.filtered_indices = if query.trim().is_empty() {
262            (0..self.all_items_sorted.len()).collect()
263        } else {
264            fuzzy_filter(&self.all_items_sorted, &query, |item| &item.full_id)
265        };
266        self.selected_index = self
267            .selected_index
268            .min(self.filtered_indices.len().saturating_sub(1));
269    }
270
271    fn get_item(&self, filtered_idx: usize) -> Option<&ModelItem> {
272        self.filtered_indices
273            .get(filtered_idx)
274            .and_then(|&idx| self.all_items_sorted.get(idx))
275    }
276
277    fn notify_change(&self) {
278        (self.callbacks.on_change)(self.enabled_ids.clone());
279    }
280}
281
282impl Component for ScopedModelsSelector {
283    fn render(&mut self, width: usize) -> Vec<String> {
284        use crate::tui::util::truncate_to_width;
285        let theme = current_theme();
286        let mut lines: Vec<String> = Vec::new();
287
288        // Top border (matches pi's DynamicBorder)
289        lines.push(theme.dim(&"─".repeat(width.saturating_sub(2))));
290        lines.push(String::new());
291
292        // Title (matches pi's theme.fg("accent", theme.bold("Model Configuration")))
293        lines.push(format!(
294            "  {}",
295            theme.bold(&theme.fg_key(ThemeKey::Accent, "Model Configuration"))
296        ));
297
298        // Session-only hint (matches pi's "Session-only. <key> to save to settings.")
299        lines.push(format!(
300            "  {}",
301            theme.dim("Session-only. Ctrl+S to save to settings.")
302        ));
303        lines.push(String::new());
304
305        // Search input line (matches pi's Input — displayed as a single line)
306        let search_value = if self.search_query.is_empty() {
307            String::new()
308        } else {
309            self.search_query.clone()
310        };
311        lines.push(format!(" {}{}", theme.dim("Search: "), search_value));
312        lines.push(String::new());
313
314        // Model list
315        let count = self.filtered_indices.len();
316        let start = self
317            .selected_index
318            .saturating_sub(self.max_visible / 2)
319            .min(count.saturating_sub(self.max_visible));
320        let end = (start + self.max_visible).min(count);
321
322        if count == 0 {
323            lines.push(theme.dim("  No matching models"));
324        } else {
325            for i in start..end {
326                let item = &self.all_items_sorted[self.filtered_indices[i]];
327                let is_selected = i == self.selected_index;
328                let prefix = if is_selected {
329                    theme.fg_key(ThemeKey::Accent, "→ ")
330                } else {
331                    "  ".to_string()
332                };
333                let model_text = if is_selected {
334                    theme.fg_key(ThemeKey::Accent, &item.model_id)
335                } else {
336                    item.model_id.clone()
337                };
338                let provider_badge = theme.dim(&format!(" [{}]", item.provider));
339                let all_enabled = self.enabled_ids.is_none();
340                let status = if all_enabled {
341                    // All enabled: no ✓/✗ needed
342                    String::new()
343                } else if item.enabled {
344                    theme.fg_key(ThemeKey::Success, " ✓")
345                } else {
346                    theme.dim(" ✗")
347                };
348                lines.push(truncate_to_width(
349                    &format!("{}{}{}{}", prefix, model_text, provider_badge, status),
350                    width.saturating_sub(4),
351                    "",
352                    false,
353                ));
354            }
355
356            // Scroll indicator
357            if count > self.max_visible {
358                lines.push(theme.dim(&format!("  ({}/{})", self.selected_index + 1, count)));
359            }
360
361            // Show model name for selected item
362            if let Some(item) = self.get_item(self.selected_index) {
363                lines.push(String::new());
364                lines.push(theme.dim(&format!("  Model Name: {}", item.model_name)));
365            }
366        }
367
368        // Footer hint (matches pi's footerText with count + dirty indicator)
369        let enabled_count = match &self.enabled_ids {
370            None => self.all_ids.len(),
371            Some(ids) => ids.len(),
372        };
373        let all_enabled = self.enabled_ids.is_none();
374        let count_text = if all_enabled {
375            "all enabled".to_string()
376        } else {
377            format!("{}/{} enabled", enabled_count, self.all_ids.len())
378        };
379        let hints = [
380            "Enter: toggle",
381            "Ctrl+A: all",
382            "Ctrl+D: clear",
383            "Ctrl+P: provider",
384            "Ctrl+\u{2191}/\u{2193}: reorder",
385            "Ctrl+S: save",
386        ];
387        let footer = if self.is_dirty {
388            format!(
389                "{} {} {}",
390                theme.dim(&format!("  {}", hints.join(" · "))),
391                count_text,
392                theme.fg_key(ThemeKey::Warning, "(unsaved)"),
393            )
394        } else {
395            format!(
396                "{} {}",
397                theme.dim(&format!("  {}", hints.join(" · "))),
398                count_text,
399            )
400        };
401        lines.push(String::new());
402        lines.push(footer);
403
404        // Bottom border
405        lines.push(theme.dim(&"─".repeat(width.saturating_sub(2))));
406
407        lines
408    }
409
410    fn handle_input(&mut self, key: &KeyEvent) -> bool {
411        let kb = get_keybindings();
412
413        // Up/Down navigation
414        if kb.matches(key, ACTION_SELECT_UP) {
415            if self.filtered_indices.is_empty() {
416                return true;
417            }
418            self.selected_index = if self.selected_index == 0 {
419                self.filtered_indices.len() - 1
420            } else {
421                self.selected_index - 1
422            };
423            return true;
424        }
425
426        if kb.matches(key, ACTION_SELECT_DOWN) {
427            if self.filtered_indices.is_empty() {
428                return true;
429            }
430            self.selected_index = if self.selected_index >= self.filtered_indices.len() - 1 {
431                0
432            } else {
433                self.selected_index + 1
434            };
435            return true;
436        }
437
438        // Toggle on Enter
439        if kb.matches(key, ACTION_SELECT_CONFIRM) {
440            if let Some(item) = self.get_item(self.selected_index) {
441                self.enabled_ids = toggle(&self.enabled_ids, &item.full_id);
442                self.is_dirty = true;
443                self.refresh();
444                self.notify_change();
445            }
446            return true;
447        }
448
449        // Cancel on Escape
450        if kb.matches(key, ACTION_SELECT_CANCEL) {
451            (self.callbacks.on_cancel)();
452            return false; // Let App's fallback pop the overlay
453        }
454
455        // Ctrl+A - Enable all (filtered if search active)
456        if key.code == KeyCode::Char('a') && key.modifiers == KeyModifiers::CONTROL {
457            let target_ids = if self.search_query.trim().is_empty() {
458                None
459            } else {
460                let ids: Vec<String> = self
461                    .filtered_indices
462                    .iter()
463                    .filter_map(|&idx| self.all_items_sorted.get(idx))
464                    .map(|item| item.full_id.clone())
465                    .collect();
466                Some(ids)
467            };
468            self.enabled_ids = enable_all(&self.enabled_ids, &self.all_ids, target_ids.as_deref());
469            self.is_dirty = true;
470            self.refresh();
471            self.notify_change();
472            return true;
473        }
474
475        // Ctrl+D - Clear all (filtered if search active)
476        if key.code == KeyCode::Char('d') && key.modifiers == KeyModifiers::CONTROL {
477            let target_ids = if self.search_query.trim().is_empty() {
478                None
479            } else {
480                let ids: Vec<String> = self
481                    .filtered_indices
482                    .iter()
483                    .filter_map(|&idx| self.all_items_sorted.get(idx))
484                    .map(|item| item.full_id.clone())
485                    .collect();
486                Some(ids)
487            };
488            self.enabled_ids = clear_all(&self.enabled_ids, &self.all_ids, target_ids.as_deref());
489            self.is_dirty = true;
490            self.refresh();
491            self.notify_change();
492            return true;
493        }
494
495        // Ctrl+P - Toggle provider of current item
496        if key.code == KeyCode::Char('p') && key.modifiers == KeyModifiers::CONTROL {
497            if let Some(item) = self.get_item(self.selected_index) {
498                let provider = &item.provider;
499                let provider_ids: Vec<String> = self
500                    .all_ids
501                    .iter()
502                    .filter(|id| id.starts_with(&format!("{}/", provider)))
503                    .cloned()
504                    .collect();
505                let all_enabled = provider_ids
506                    .iter()
507                    .all(|id| is_enabled(&self.enabled_ids, id));
508                self.enabled_ids = if all_enabled {
509                    clear_all(&self.enabled_ids, &self.all_ids, Some(&provider_ids))
510                } else {
511                    enable_all(&self.enabled_ids, &self.all_ids, Some(&provider_ids))
512                };
513                self.is_dirty = true;
514                self.refresh();
515                self.notify_change();
516            }
517            return true;
518        }
519
520        // Ctrl+Up - Reorder up
521        if key.code == KeyCode::Up && key.modifiers == KeyModifiers::CONTROL {
522            if let Some(item) = self.get_item(self.selected_index) {
523                let full_id = item.full_id.clone();
524                let new_ids = move_item(&self.enabled_ids, &full_id, -1);
525                if new_ids != self.enabled_ids {
526                    self.enabled_ids = new_ids;
527                    self.is_dirty = true;
528                    self.refresh();
529                    // Re-find the item after refresh to track selection
530                    if let Some(new_idx) = self
531                        .all_items_sorted
532                        .iter()
533                        .position(|i| i.full_id == full_id)
534                    {
535                        // Find position in filtered_indices
536                        if let Some(pos) = self.filtered_indices.iter().position(|&i| i == new_idx)
537                        {
538                            self.selected_index = pos;
539                        }
540                    }
541                    self.notify_change();
542                }
543            }
544            return true;
545        }
546
547        // Ctrl+Down - Reorder down
548        if key.code == KeyCode::Down && key.modifiers == KeyModifiers::CONTROL {
549            if let Some(item) = self.get_item(self.selected_index) {
550                let full_id = item.full_id.clone();
551                let new_ids = move_item(&self.enabled_ids, &full_id, 1);
552                if new_ids != self.enabled_ids {
553                    self.enabled_ids = new_ids;
554                    self.is_dirty = true;
555                    self.refresh();
556                    // Re-find the item after refresh to track selection
557                    if let Some(new_idx) = self
558                        .all_items_sorted
559                        .iter()
560                        .position(|i| i.full_id == full_id)
561                        && let Some(pos) = self.filtered_indices.iter().position(|&i| i == new_idx)
562                    {
563                        self.selected_index = pos;
564                    }
565                    self.notify_change();
566                }
567            }
568            return true;
569        }
570
571        // Ctrl+S - Save/persist to settings
572        if key.code == KeyCode::Char('s') && key.modifiers == KeyModifiers::CONTROL {
573            (self.callbacks.on_persist)(self.enabled_ids.clone());
574            self.is_dirty = false;
575            return true;
576        }
577
578        // Ctrl+C - Clear search or cancel if empty
579        if key.code == KeyCode::Char('c') && key.modifiers == KeyModifiers::CONTROL {
580            if !self.search_query.is_empty() {
581                self.search_query.clear();
582                self.refresh();
583                return true;
584            }
585            return false;
586        }
587
588        // Backspace - delete from search
589        if key.code == KeyCode::Backspace {
590            if !self.search_query.is_empty() {
591                self.search_query.pop();
592                self.refresh();
593            }
594            return true;
595        }
596
597        // Typeable characters go to search
598        if let KeyCode::Char(c) = key.code
599            && !c.is_control()
600        {
601            self.search_query.push(c);
602            self.refresh();
603            return true;
604        }
605
606        false
607    }
608}