1use imp_llm::model::ModelMeta;
2use ratatui::buffer::Buffer;
3use ratatui::layout::Rect;
4use ratatui::style::{Modifier, Style};
5use ratatui::text::{Line, Span};
6use ratatui::widgets::{Block, Borders, Clear, Widget};
7
8use crate::theme::Theme;
9
10#[derive(Debug, Clone)]
12pub struct ModelSelectorState {
13 pub models: Vec<ModelMeta>,
14 pub filter: String,
15 pub selected: usize,
16 pub current_model: String,
17}
18
19pub enum ModelSelection<'a> {
20 Builtin(&'a ModelMeta),
21 Custom(String),
22}
23
24impl ModelSelectorState {
25 pub fn new(models: Vec<ModelMeta>, current_model: String) -> Self {
26 let selected = models
27 .iter()
28 .position(|model| model.id == current_model)
29 .unwrap_or(0);
30 Self {
31 models,
32 filter: String::new(),
33 selected,
34 current_model,
35 }
36 }
37
38 pub fn filtered(&self) -> Vec<&ModelMeta> {
39 if self.filter.is_empty() {
40 self.models.iter().collect()
41 } else {
42 let lower = self.filter.to_lowercase();
43 self.models
44 .iter()
45 .filter(|m| {
46 m.name.to_lowercase().contains(&lower)
47 || m.id.to_lowercase().contains(&lower)
48 || m.provider.to_lowercase().contains(&lower)
49 })
50 .collect()
51 }
52 }
53
54 pub fn custom_model(&self) -> Option<String> {
55 let trimmed = self.filter.trim();
56 if trimmed.is_empty() {
57 None
58 } else {
59 Some(trimmed.to_string())
60 }
61 }
62
63 fn option_count(&self) -> usize {
64 self.filtered().len() + usize::from(self.custom_model().is_some())
65 }
66
67 pub fn move_up(&mut self) {
68 if self.selected > 0 {
69 self.selected -= 1;
70 }
71 }
72
73 pub fn move_down(&mut self) {
74 let count = self.option_count();
75 if self.selected + 1 < count {
76 self.selected += 1;
77 }
78 }
79
80 pub fn push_filter(&mut self, c: char) {
81 self.filter.push(c);
82 self.selected = 0;
83 }
84
85 pub fn pop_filter(&mut self) {
86 self.filter.pop();
87 self.selected = 0;
88 }
89
90 pub fn selected_choice(&self) -> Option<ModelSelection<'_>> {
91 let filtered = self.filtered();
92
93 if let Some(model) = filtered.get(self.selected).copied() {
94 return Some(ModelSelection::Builtin(model));
95 }
96
97 let custom_index = filtered.len();
98 self.custom_model().and_then(|custom| {
99 (self.selected == custom_index).then_some(ModelSelection::Custom(custom))
100 })
101 }
102}
103
104pub struct ModelSelectorView<'a> {
106 state: &'a ModelSelectorState,
107 theme: &'a Theme,
108}
109
110impl<'a> ModelSelectorView<'a> {
111 pub fn new(state: &'a ModelSelectorState, theme: &'a Theme) -> Self {
112 Self { state, theme }
113 }
114}
115
116impl Widget for ModelSelectorView<'_> {
117 fn render(self, area: Rect, buf: &mut Buffer) {
118 if area.height < 5 || area.width < 20 {
119 return;
120 }
121
122 Clear.render(area, buf);
123
124 let title = if self.state.filter.is_empty() {
125 " Select Model ".to_string()
126 } else {
127 format!(" Select Model [{}] ", self.state.filter)
128 };
129
130 let block = Block::default()
131 .title(title)
132 .borders(Borders::ALL)
133 .border_style(self.theme.accent_style());
134 let inner = block.inner(area);
135 block.render(area, buf);
136
137 let filtered = self.state.filtered();
138 let mut row: usize = 0;
139 let custom_model = self.state.custom_model();
140
141 let mut current_provider = String::new();
142
143 for (i, model) in filtered.iter().enumerate() {
144 if row >= inner.height as usize {
145 break;
146 }
147
148 if model.provider != current_provider {
149 current_provider = model.provider.clone();
150 let header_line = Line::from(Span::styled(
151 format!(" {}", current_provider.to_uppercase()),
152 Style::default()
153 .fg(self.theme.muted)
154 .add_modifier(Modifier::BOLD),
155 ));
156 buf.set_line(inner.x, inner.y + row as u16, &header_line, inner.width);
157 row += 1;
158 if row >= inner.height as usize {
159 break;
160 }
161 }
162
163 let is_selected = i == self.state.selected;
164 let is_current = model.id == self.state.current_model;
165
166 let marker = if is_current { "✓ " } else { " " };
167 let style = if is_selected {
168 self.theme.selected_style()
169 } else {
170 Style::default()
171 };
172
173 let context_str = format!("{}k", model.context_window / 1000);
174 let price_str =
175 if model.pricing.input_per_mtok == 0.0 && model.pricing.output_per_mtok == 0.0 {
176 "n/a".to_string()
177 } else {
178 format!(
179 "${:.2}/{:.2}",
180 model.pricing.input_per_mtok, model.pricing.output_per_mtok
181 )
182 };
183
184 let line = Line::from(vec![
185 Span::styled(format!(" {marker}"), self.theme.accent_style()),
186 Span::styled(model.name.clone(), style),
187 Span::raw(" "),
188 Span::styled(context_str, self.theme.muted_style()),
189 Span::raw(" "),
190 Span::styled(price_str, self.theme.muted_style()),
191 ]);
192
193 buf.set_line(inner.x, inner.y + row as u16, &line, inner.width);
194 row += 1;
195 }
196
197 if let Some(ref custom_model) = custom_model {
198 if row < inner.height as usize && !filtered.is_empty() {
199 let spacer = Line::from(Span::styled(
200 " Custom",
201 Style::default()
202 .fg(self.theme.muted)
203 .add_modifier(Modifier::BOLD),
204 ));
205 buf.set_line(inner.x, inner.y + row as u16, &spacer, inner.width);
206 row += 1;
207 }
208
209 if row < inner.height as usize {
210 let custom_index = filtered.len();
211 let is_selected = self.state.selected == custom_index;
212 let is_current = custom_model == &self.state.current_model;
213 let marker = if is_current { "✓ " } else { " " };
214 let style = if is_selected {
215 self.theme.selected_style()
216 } else {
217 Style::default()
218 };
219
220 let line = Line::from(vec![
221 Span::styled(format!(" {marker}"), self.theme.accent_style()),
222 Span::styled("Use custom model: ", self.theme.muted_style()),
223 Span::styled(custom_model, style),
224 ]);
225 buf.set_line(inner.x, inner.y + row as u16, &line, inner.width);
226 }
227 }
228
229 if filtered.is_empty() && custom_model.is_none() {
230 let line = Line::from(Span::styled(
231 " No matching models",
232 self.theme.muted_style(),
233 ));
234 buf.set_line(inner.x, inner.y, &line, inner.width);
235 }
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use imp_llm::model::{Capabilities, ModelPricing};

    /// Minimal OpenAI model fixture whose id doubles as its display name.
    fn test_model(id: &str) -> ModelMeta {
        let capabilities = Capabilities {
            reasoning: true,
            images: true,
            tool_use: true,
        };
        ModelMeta {
            id: id.into(),
            provider: "openai".into(),
            name: id.into(),
            context_window: 128_000,
            max_output_tokens: 16_384,
            pricing: ModelPricing::default(),
            capabilities,
        }
    }

    /// Feeds `text` into the selector's filter one keystroke at a time.
    fn type_filter(state: &mut ModelSelectorState, text: &str) {
        for ch in text.chars() {
            state.push_filter(ch);
        }
    }

    #[test]
    fn model_selector_initially_selects_current_model() {
        let models = vec![test_model("gpt-5.4"), test_model("gpt-4o")];
        let state = ModelSelectorState::new(models, "gpt-4o".into());

        assert_eq!(state.selected, 1);
        if let Some(ModelSelection::Builtin(model)) = state.selected_choice() {
            assert_eq!(model.id, "gpt-4o");
        } else {
            panic!("expected current model to be selected");
        }
    }

    #[test]
    fn custom_model_is_available_after_builtin_matches() {
        let mut state = ModelSelectorState::new(vec![test_model("gpt-4o")], "gpt-4o".into());
        type_filter(&mut state, "gpt-4o");

        if let Some(ModelSelection::Builtin(model)) = state.selected_choice() {
            assert_eq!(model.id, "gpt-4o");
        } else {
            panic!("expected builtin model selection");
        }

        state.move_down();
        if let Some(ModelSelection::Custom(model)) = state.selected_choice() {
            assert_eq!(model, "gpt-4o");
        } else {
            panic!("expected custom model selection after builtin matches");
        }
    }

    #[test]
    fn custom_model_is_selected_when_no_builtin_matches() {
        let mut state = ModelSelectorState::new(vec![test_model("gpt-5.4")], "gpt-5.4".into());
        type_filter(&mut state, "gpt-4o");

        state.move_down();
        if let Some(ModelSelection::Custom(model)) = state.selected_choice() {
            assert_eq!(model, "gpt-4o");
        } else {
            panic!("expected custom model selection");
        }
    }
}