Skip to main content

battlecommand_forge/
model_picker.rs

1//! Interactive model picker overlay for TUI.
2//! Ported from battleclaw-v2 model_picker.rs, adapted for forge's 9-stage pipeline roles.
3
4use crossterm::event::KeyCode;
5use ratatui::{
6    layout::{Constraint, Direction, Layout, Rect},
7    style::{Color, Modifier, Style},
8    text::{Line, Span},
9    widgets::{Block, Borders, Clear, Paragraph},
10    Frame,
11};
12
13use crate::model_config::{ModelConfig, ModelProvider, Preset, RoleConfig};
14
/// Cloud-hosted Claude models that the picker can offer.
///
/// Each entry is a pair of strings; judging by the values, the first is the
/// model identifier and the second a human-readable label (TODO confirm how
/// callers consume the two halves — they are not used in this file).
pub const CLAUDE_MODELS: &[(&str, &str)] = &[
    ("claude-sonnet-4-6", "Claude Sonnet 4.6"),
    ("claude-opus-4-6", "Claude Opus 4.6"),
    ("claude-haiku-4-5-20251001", "Claude Haiku 4.5"),
];
21
22// ─── Role definitions (forge pipeline) ───
23
/// Display names of the pipeline roles, in the order the picker walks them.
///
/// The order is load-bearing: `ModelPickerState::new` zips this array with the
/// per-role models of the current config, and `ModelPickerState::build_config`
/// maps slots back to config fields by position (0 = architect … 6 = complexity).
const ROLE_NAMES: [&str; 7] = [
    "Architect", "Tester", "Coder", "Security", "Critique", "CTO", "Complexity",
];
33
34// ─── Available model info ───
35
36#[derive(Debug, Clone)]
37pub struct AvailableModel {
38    pub name: String,
39    pub size_gb: f64,
40    pub provider: ModelProvider,
41}
42
43// ─── Picker state ───
44
/// Selection state for a single pipeline role.
#[derive(Clone, Debug)]
pub struct RoleSlot {
    /// Display name of the role (e.g. "Architect").
    pub role_name: String,
    /// Index into `ModelPickerState::available_models` of the model currently
    /// assigned to this role.
    pub selected_index: usize,
    /// Model name this role had in the configuration the picker was opened
    /// with; used as a fallback when `selected_index` is out of range.
    pub default_model: String,
}
51
52pub struct ModelPickerState {
53    pub available_models: Vec<AvailableModel>,
54    pub roles: Vec<RoleSlot>,
55    pub active_role: usize,
56    pub cursor: usize,
57}
58
59pub enum PickerAction {
60    Continue,
61    Confirm(ModelConfig),
62    Cancel,
63}
64
65impl ModelPickerState {
66    pub fn new(available: Vec<AvailableModel>, current_config: &ModelConfig) -> Self {
67        let defaults = [
68            &current_config.architect.model,
69            &current_config.tester.model,
70            &current_config.coder.model,
71            &current_config.security.model,
72            &current_config.critique.model,
73            &current_config.cto.model,
74            &current_config.complexity.model,
75        ];
76
77        let roles: Vec<RoleSlot> = ROLE_NAMES
78            .iter()
79            .zip(defaults.iter())
80            .map(|(name, default)| {
81                let selected_index = available
82                    .iter()
83                    .position(|m| m.name == **default)
84                    .unwrap_or(0);
85                RoleSlot {
86                    role_name: name.to_string(),
87                    selected_index,
88                    default_model: default.to_string(),
89                }
90            })
91            .collect();
92
93        let cursor = roles.first().map(|r| r.selected_index).unwrap_or(0);
94
95        Self {
96            available_models: available,
97            roles,
98            active_role: 0,
99            cursor,
100        }
101    }
102
103    fn sync_cursor(&mut self) {
104        self.cursor = self.roles[self.active_role].selected_index;
105    }
106
107    /// Build a ModelConfig from current selections.
108    pub fn build_config(&self) -> ModelConfig {
109        let mut cfg = ModelConfig::from_preset(Preset::Premium);
110
111        let get = |slot: &RoleSlot| -> RoleConfig {
112            match self.available_models.get(slot.selected_index) {
113                Some(m) => match m.provider {
114                    ModelProvider::Cloud => RoleConfig::cloud(&m.name),
115                    ModelProvider::Local => RoleConfig::local(&m.name),
116                },
117                None => RoleConfig::local(&slot.default_model),
118            }
119        };
120
121        cfg.architect = get(&self.roles[0]);
122        cfg.tester = get(&self.roles[1]);
123        cfg.coder = get(&self.roles[2]);
124        cfg.security = get(&self.roles[3]);
125        cfg.critique = get(&self.roles[4]);
126        cfg.cto = get(&self.roles[5]);
127        cfg.complexity = get(&self.roles[6]);
128
129        cfg
130    }
131
132    /// Generate TOML string from current selections.
133    pub fn to_toml(&self) -> String {
134        let cfg = self.build_config();
135        let mut s = String::from(
136            "# BattleCommand Forge — Model Configuration (generated by model picker)\n",
137        );
138        s.push_str("preset = \"premium\"\n\n");
139
140        let sections = [
141            ("architect", &cfg.architect),
142            ("tester", &cfg.tester),
143            ("coder", &cfg.coder),
144            ("security", &cfg.security),
145            ("critique", &cfg.critique),
146            ("cto", &cfg.cto),
147            ("complexity", &cfg.complexity),
148        ];
149
150        for (name, role) in &sections {
151            s.push_str(&format!("[{}]\n", name));
152            s.push_str(&format!("model = \"{}\"\n", role.model));
153            s.push_str(&format!("provider = \"{}\"\n\n", role.provider));
154        }
155
156        s
157    }
158}
159
160/// Handle key input for the model picker.
161pub fn handle_picker_input(state: &mut ModelPickerState, key: KeyCode) -> PickerAction {
162    match key {
163        KeyCode::Up if state.cursor > 0 => {
164            state.cursor -= 1;
165        }
166        KeyCode::Down if state.cursor + 1 < state.available_models.len() => {
167            state.cursor += 1;
168        }
169        KeyCode::Enter => {
170            state.roles[state.active_role].selected_index = state.cursor;
171            if state.active_role + 1 < state.roles.len() {
172                state.active_role += 1;
173                state.sync_cursor();
174            } else {
175                return PickerAction::Confirm(state.build_config());
176            }
177        }
178        KeyCode::Tab => {
179            state.roles[state.active_role].selected_index = state.cursor;
180            state.active_role = (state.active_role + 1) % state.roles.len();
181            state.sync_cursor();
182        }
183        KeyCode::BackTab => {
184            state.roles[state.active_role].selected_index = state.cursor;
185            if state.active_role > 0 {
186                state.active_role -= 1;
187            } else {
188                state.active_role = state.roles.len() - 1;
189            }
190            state.sync_cursor();
191        }
192        KeyCode::Char(' ') => {
193            state.roles[state.active_role].selected_index = state.cursor;
194        }
195        KeyCode::Esc => {
196            return PickerAction::Cancel;
197        }
198        _ => {}
199    }
200    PickerAction::Continue
201}
202
203/// Draw the model picker overlay.
204pub fn draw_model_picker(f: &mut Frame, state: &ModelPickerState) {
205    let area = f.area();
206    let picker_area = centered_rect(80, 85, area);
207
208    f.render_widget(Clear, picker_area);
209
210    let outer = Block::default()
211        .borders(Borders::ALL)
212        .title(" BattleCommand Forge — Model Setup ")
213        .title_style(Style::default().fg(Color::Red).add_modifier(Modifier::BOLD))
214        .border_style(Style::default().fg(Color::Yellow));
215    let inner = outer.inner(picker_area);
216    f.render_widget(outer, picker_area);
217
218    let chunks = Layout::default()
219        .direction(Direction::Vertical)
220        .constraints([
221            Constraint::Length(2), // Instructions
222            Constraint::Length(2), // Role tabs
223            Constraint::Min(5),    // Model list
224            Constraint::Length(2), // Footer
225        ])
226        .split(inner);
227
228    // ─── Instructions ───
229    let instructions = Paragraph::new(Line::from(vec![
230        Span::raw("  Assign a model to each pipeline role. "),
231        Span::styled(
232            "Up/Down",
233            Style::default()
234                .fg(Color::Cyan)
235                .add_modifier(Modifier::BOLD),
236        ),
237        Span::raw(" navigate  "),
238        Span::styled(
239            "Enter",
240            Style::default()
241                .fg(Color::Cyan)
242                .add_modifier(Modifier::BOLD),
243        ),
244        Span::raw(" select & next  "),
245        Span::styled(
246            "Tab",
247            Style::default()
248                .fg(Color::Cyan)
249                .add_modifier(Modifier::BOLD),
250        ),
251        Span::raw(" switch role  "),
252        Span::styled(
253            "Esc",
254            Style::default()
255                .fg(Color::Cyan)
256                .add_modifier(Modifier::BOLD),
257        ),
258        Span::raw(" cancel"),
259    ]));
260    f.render_widget(instructions, chunks[0]);
261
262    // ─── Role tabs ───
263    let role_spans: Vec<Span> = state
264        .roles
265        .iter()
266        .enumerate()
267        .flat_map(|(i, slot)| {
268            let selected_model = state
269                .available_models
270                .get(slot.selected_index)
271                .map(|m| m.name.as_str())
272                .unwrap_or("?");
273
274            let label = format!(
275                " {} [{}] ",
276                slot.role_name,
277                truncate_model(selected_model, 12)
278            );
279
280            let style = if i == state.active_role {
281                Style::default()
282                    .fg(Color::Black)
283                    .bg(Color::Yellow)
284                    .add_modifier(Modifier::BOLD)
285            } else if slot.selected_index
286                != state
287                    .available_models
288                    .iter()
289                    .position(|m| m.name == slot.default_model)
290                    .unwrap_or(usize::MAX)
291            {
292                Style::default()
293                    .fg(Color::Green)
294                    .add_modifier(Modifier::BOLD)
295            } else {
296                Style::default().fg(Color::DarkGray)
297            };
298
299            vec![Span::styled(label, style), Span::raw(" ")]
300        })
301        .collect();
302
303    let role_line = Paragraph::new(Line::from(role_spans));
304    f.render_widget(role_line, chunks[1]);
305
306    // ─── Model list ───
307    let list_area = chunks[2];
308    let mut lines: Vec<Line> = Vec::new();
309    lines.push(Line::from(""));
310
311    let current_selection = state.roles[state.active_role].selected_index;
312
313    for (i, model) in state.available_models.iter().enumerate() {
314        let is_cursor = i == state.cursor;
315        let is_selected = i == current_selection;
316
317        let pointer = if is_cursor { " > " } else { "   " };
318        let marker = if is_selected { " *" } else { "  " };
319
320        let name_style = if is_cursor {
321            Style::default()
322                .fg(Color::Yellow)
323                .add_modifier(Modifier::BOLD)
324        } else if is_selected {
325            Style::default().fg(Color::Green)
326        } else {
327            Style::default().fg(Color::White)
328        };
329
330        let is_cloud = matches!(model.provider, ModelProvider::Cloud);
331        let size_str = if is_cloud {
332            " CLOUD".to_string()
333        } else {
334            format!("{:>6.1} GB", model.size_gb)
335        };
336
337        let size_style = if is_cloud {
338            Style::default()
339                .fg(Color::Magenta)
340                .add_modifier(Modifier::BOLD)
341        } else {
342            Style::default().fg(Color::DarkGray)
343        };
344
345        lines.push(Line::from(vec![
346            Span::styled(
347                pointer,
348                Style::default()
349                    .fg(Color::Yellow)
350                    .add_modifier(Modifier::BOLD),
351            ),
352            Span::styled(format!("{:<40}", model.name), name_style),
353            Span::styled(size_str, size_style),
354            Span::styled(
355                marker,
356                Style::default()
357                    .fg(Color::Green)
358                    .add_modifier(Modifier::BOLD),
359            ),
360        ]));
361    }
362
363    let model_list = Paragraph::new(lines).block(
364        Block::default()
365            .borders(Borders::TOP)
366            .title(format!(
367                " {} — select model ",
368                state.roles[state.active_role].role_name
369            ))
370            .title_style(Style::default().fg(Color::Cyan)),
371    );
372    f.render_widget(model_list, list_area);
373
374    // ─── Footer ───
375    let mut unique_models: Vec<&str> = state
376        .roles
377        .iter()
378        .filter_map(|slot| {
379            state
380                .available_models
381                .get(slot.selected_index)
382                .filter(|m| !matches!(m.provider, ModelProvider::Cloud))
383                .map(|m| m.name.as_str())
384        })
385        .collect();
386    unique_models.sort();
387    unique_models.dedup();
388    let unique_vram: f64 = unique_models
389        .iter()
390        .map(|name| {
391            state
392                .available_models
393                .iter()
394                .find(|m| m.name == *name)
395                .map(|m| m.size_gb)
396                .unwrap_or(0.0)
397        })
398        .sum();
399
400    let has_cloud = state.roles.iter().any(|slot| {
401        state
402            .available_models
403            .get(slot.selected_index)
404            .map(|m| matches!(m.provider, ModelProvider::Cloud))
405            .unwrap_or(false)
406    });
407
408    let mut footer_spans = vec![
409        Span::raw("  "),
410        Span::styled(
411            format!("VRAM: {:.0} GB (unique local)", unique_vram),
412            Style::default().fg(Color::Cyan),
413        ),
414    ];
415    if has_cloud {
416        footer_spans.push(Span::raw("  "));
417        footer_spans.push(Span::styled(
418            "+ CLOUD (API key required)",
419            Style::default().fg(Color::Magenta),
420        ));
421    }
422    footer_spans.push(Span::raw("  |  "));
423    footer_spans.push(Span::styled(
424        format!("Role {}/{}", state.active_role + 1, state.roles.len()),
425        Style::default().fg(Color::DarkGray),
426    ));
427
428    let footer = Paragraph::new(Line::from(footer_spans));
429    f.render_widget(footer, chunks[3]);
430}
431
432fn centered_rect(percent_x: u16, percent_y: u16, r: Rect) -> Rect {
433    let popup_layout = Layout::default()
434        .direction(Direction::Vertical)
435        .constraints([
436            Constraint::Percentage((100 - percent_y) / 2),
437            Constraint::Percentage(percent_y),
438            Constraint::Percentage((100 - percent_y) / 2),
439        ])
440        .split(r);
441
442    Layout::default()
443        .direction(Direction::Horizontal)
444        .constraints([
445            Constraint::Percentage((100 - percent_x) / 2),
446            Constraint::Percentage(percent_x),
447            Constraint::Percentage((100 - percent_x) / 2),
448        ])
449        .split(popup_layout[1])[1]
450}
451
/// Shortens `name` to at most `max` characters for display.
///
/// Names of `max` characters or fewer are returned unchanged. Longer names
/// are cut to `max - 3` characters followed by `"..."`. Lengths are counted
/// in `char`s rather than bytes, so names containing multi-byte UTF-8
/// characters are never split in the middle of a character.
///
/// When `max` is below 3 the result for an over-long name is just `"..."`,
/// which is longer than `max`.
fn truncate_model(name: &str, max: usize) -> String {
    if name.chars().count() <= max {
        return name.to_string();
    }

    let kept: String = name.chars().take(max.saturating_sub(3)).collect();
    format!("{}...", kept)
}