Skip to main content

imp_tui/views/
model_selector.rs

1use imp_llm::model::ModelMeta;
2use ratatui::buffer::Buffer;
3use ratatui::layout::Rect;
4use ratatui::style::{Modifier, Style};
5use ratatui::text::{Line, Span};
6use ratatui::widgets::{Block, Borders, Clear, Widget};
7
8use crate::theme::Theme;
9
10/// State for the model selector overlay.
11#[derive(Debug, Clone)]
12pub struct ModelSelectorState {
13    pub models: Vec<ModelMeta>,
14    pub filter: String,
15    pub selected: usize,
16    pub current_model: String,
17}
18
19pub enum ModelSelection<'a> {
20    Builtin(&'a ModelMeta),
21    Custom(String),
22}
23
24impl ModelSelectorState {
25    pub fn new(models: Vec<ModelMeta>, current_model: String) -> Self {
26        let selected = models
27            .iter()
28            .position(|model| model.id == current_model)
29            .unwrap_or(0);
30        Self {
31            models,
32            filter: String::new(),
33            selected,
34            current_model,
35        }
36    }
37
38    pub fn filtered(&self) -> Vec<&ModelMeta> {
39        if self.filter.is_empty() {
40            self.models.iter().collect()
41        } else {
42            let lower = self.filter.to_lowercase();
43            self.models
44                .iter()
45                .filter(|m| {
46                    m.name.to_lowercase().contains(&lower)
47                        || m.id.to_lowercase().contains(&lower)
48                        || m.provider.to_lowercase().contains(&lower)
49                })
50                .collect()
51        }
52    }
53
54    pub fn custom_model(&self) -> Option<String> {
55        let trimmed = self.filter.trim();
56        if trimmed.is_empty() {
57            None
58        } else {
59            Some(trimmed.to_string())
60        }
61    }
62
63    fn option_count(&self) -> usize {
64        self.filtered().len() + usize::from(self.custom_model().is_some())
65    }
66
67    pub fn move_up(&mut self) {
68        if self.selected > 0 {
69            self.selected -= 1;
70        }
71    }
72
73    pub fn move_down(&mut self) {
74        let count = self.option_count();
75        if self.selected + 1 < count {
76            self.selected += 1;
77        }
78    }
79
80    pub fn push_filter(&mut self, c: char) {
81        self.filter.push(c);
82        self.selected = 0;
83    }
84
85    pub fn pop_filter(&mut self) {
86        self.filter.pop();
87        self.selected = 0;
88    }
89
90    pub fn selected_choice(&self) -> Option<ModelSelection<'_>> {
91        let filtered = self.filtered();
92
93        if let Some(model) = filtered.get(self.selected).copied() {
94            return Some(ModelSelection::Builtin(model));
95        }
96
97        let custom_index = filtered.len();
98        self.custom_model().and_then(|custom| {
99            (self.selected == custom_index).then_some(ModelSelection::Custom(custom))
100        })
101    }
102}
103
104/// Model selector overlay widget.
105pub struct ModelSelectorView<'a> {
106    state: &'a ModelSelectorState,
107    theme: &'a Theme,
108}
109
110impl<'a> ModelSelectorView<'a> {
111    pub fn new(state: &'a ModelSelectorState, theme: &'a Theme) -> Self {
112        Self { state, theme }
113    }
114}
115
116impl Widget for ModelSelectorView<'_> {
117    fn render(self, area: Rect, buf: &mut Buffer) {
118        if area.height < 5 || area.width < 20 {
119            return;
120        }
121
122        Clear.render(area, buf);
123
124        let title = if self.state.filter.is_empty() {
125            " Select Model ".to_string()
126        } else {
127            format!(" Select Model [{}] ", self.state.filter)
128        };
129
130        let block = Block::default()
131            .title(title)
132            .borders(Borders::ALL)
133            .border_style(self.theme.accent_style());
134        let inner = block.inner(area);
135        block.render(area, buf);
136
137        let filtered = self.state.filtered();
138        let mut row: usize = 0;
139        let custom_model = self.state.custom_model();
140
141        let mut current_provider = String::new();
142
143        for (i, model) in filtered.iter().enumerate() {
144            if row >= inner.height as usize {
145                break;
146            }
147
148            if model.provider != current_provider {
149                current_provider = model.provider.clone();
150                let header_line = Line::from(Span::styled(
151                    format!("  {}", current_provider.to_uppercase()),
152                    Style::default()
153                        .fg(self.theme.muted)
154                        .add_modifier(Modifier::BOLD),
155                ));
156                buf.set_line(inner.x, inner.y + row as u16, &header_line, inner.width);
157                row += 1;
158                if row >= inner.height as usize {
159                    break;
160                }
161            }
162
163            let is_selected = i == self.state.selected;
164            let is_current = model.id == self.state.current_model;
165
166            let marker = if is_current { "✓ " } else { "  " };
167            let style = if is_selected {
168                self.theme.selected_style()
169            } else {
170                Style::default()
171            };
172
173            let context_str = format!("{}k", model.context_window / 1000);
174            let price_str =
175                if model.pricing.input_per_mtok == 0.0 && model.pricing.output_per_mtok == 0.0 {
176                    "n/a".to_string()
177                } else {
178                    format!(
179                        "${:.2}/{:.2}",
180                        model.pricing.input_per_mtok, model.pricing.output_per_mtok
181                    )
182                };
183
184            let line = Line::from(vec![
185                Span::styled(format!("    {marker}"), self.theme.accent_style()),
186                Span::styled(model.name.clone(), style),
187                Span::raw("  "),
188                Span::styled(context_str, self.theme.muted_style()),
189                Span::raw("  "),
190                Span::styled(price_str, self.theme.muted_style()),
191            ]);
192
193            buf.set_line(inner.x, inner.y + row as u16, &line, inner.width);
194            row += 1;
195        }
196
197        if let Some(ref custom_model) = custom_model {
198            if row < inner.height as usize && !filtered.is_empty() {
199                let spacer = Line::from(Span::styled(
200                    "  Custom",
201                    Style::default()
202                        .fg(self.theme.muted)
203                        .add_modifier(Modifier::BOLD),
204                ));
205                buf.set_line(inner.x, inner.y + row as u16, &spacer, inner.width);
206                row += 1;
207            }
208
209            if row < inner.height as usize {
210                let custom_index = filtered.len();
211                let is_selected = self.state.selected == custom_index;
212                let is_current = custom_model == &self.state.current_model;
213                let marker = if is_current { "✓ " } else { "  " };
214                let style = if is_selected {
215                    self.theme.selected_style()
216                } else {
217                    Style::default()
218                };
219
220                let line = Line::from(vec![
221                    Span::styled(format!("  {marker}"), self.theme.accent_style()),
222                    Span::styled("Use custom model: ", self.theme.muted_style()),
223                    Span::styled(custom_model, style),
224                ]);
225                buf.set_line(inner.x, inner.y + row as u16, &line, inner.width);
226            }
227        }
228
229        if filtered.is_empty() && custom_model.is_none() {
230            let line = Line::from(Span::styled(
231                "  No matching models",
232                self.theme.muted_style(),
233            ));
234            buf.set_line(inner.x, inner.y, &line, inner.width);
235        }
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use imp_llm::model::{Capabilities, ModelPricing};

    /// Builds an `openai` model whose id and display name are both `id`.
    fn openai_model(id: &str) -> ModelMeta {
        let capabilities = Capabilities {
            reasoning: true,
            images: true,
            tool_use: true,
        };
        ModelMeta {
            id: id.into(),
            provider: "openai".into(),
            name: id.into(),
            context_window: 128_000,
            max_output_tokens: 16_384,
            pricing: ModelPricing::default(),
            capabilities,
        }
    }

    /// Feeds `text` into the selector's filter one character at a time.
    fn type_filter(state: &mut ModelSelectorState, text: &str) {
        for ch in text.chars() {
            state.push_filter(ch);
        }
    }

    #[test]
    fn model_selector_initially_selects_current_model() {
        let models = vec![openai_model("gpt-5.4"), openai_model("gpt-4o")];
        let state = ModelSelectorState::new(models, "gpt-4o".into());

        assert_eq!(state.selected, 1);
        let Some(ModelSelection::Builtin(model)) = state.selected_choice() else {
            panic!("expected current model to be selected");
        };
        assert_eq!(model.id, "gpt-4o");
    }

    #[test]
    fn custom_model_is_available_after_builtin_matches() {
        let mut state = ModelSelectorState::new(vec![openai_model("gpt-4o")], "gpt-4o".into());
        type_filter(&mut state, "gpt-4o");

        let Some(ModelSelection::Builtin(model)) = state.selected_choice() else {
            panic!("expected builtin model selection");
        };
        assert_eq!(model.id, "gpt-4o");

        state.move_down();
        let Some(ModelSelection::Custom(custom)) = state.selected_choice() else {
            panic!("expected custom model selection after builtin matches");
        };
        assert_eq!(custom, "gpt-4o");
    }

    #[test]
    fn custom_model_is_selected_when_no_builtin_matches() {
        let mut state = ModelSelectorState::new(vec![openai_model("gpt-5.4")], "gpt-5.4".into());
        type_filter(&mut state, "gpt-4o");

        state.move_down();
        let Some(ModelSelection::Custom(custom)) = state.selected_choice() else {
            panic!("expected custom model selection");
        };
        assert_eq!(custom, "gpt-4o");
    }
}
312}