use crossterm::event::KeyCode;
use ratatui::{
layout::{Constraint, Direction, Layout, Rect},
style::{Color, Modifier, Style},
text::{Line, Span},
widgets::{Block, Borders, Clear, Paragraph},
Frame,
};
use crate::model_config::{ModelConfig, ModelProvider, Preset, RoleConfig};
/// Claude model ids paired with their human-readable display labels.
pub const CLAUDE_MODELS: &'static [(&'static str, &'static str)] = &[
    ("claude-sonnet-4-6", "Claude Sonnet 4.6"),
    ("claude-opus-4-6", "Claude Opus 4.6"),
    ("claude-haiku-4-5-20251001", "Claude Haiku 4.5"),
];

/// Display names of the pipeline roles, in the order the picker walks through them.
///
/// This order must stay in sync with the `ModelConfig` fields read in
/// `ModelPickerState::new` and written back in `ModelPickerState::build_config`.
const ROLE_NAMES: [&'static str; 7] = [
    "Architect", "Tester", "Coder", "Security",
    "Critique", "CTO", "Complexity",
];
/// A model the user can assign to a pipeline role in the picker.
#[derive(Debug, Clone)]
pub struct AvailableModel {
    /// Model identifier, passed to `RoleConfig::cloud` / `RoleConfig::local` when the
    /// picker builds its configuration.
    pub name: String,
    /// Model size in gigabytes. Shown next to local models and summed (once per
    /// distinct local model) for the picker's VRAM estimate.
    pub size_gb: f64,
    /// Where the model runs. Cloud models are labelled "CLOUD" in the list and are
    /// excluded from the VRAM estimate.
    pub provider: ModelProvider,
}
/// Picker state for a single pipeline role.
#[derive(Debug, Clone)]
pub struct RoleSlot {
    /// Display name shown in the role bar and the model-list title (e.g. "Architect").
    pub role_name: String,
    /// Index into `ModelPickerState::available_models` of the model committed for this
    /// role. It may be out of range (e.g. when no models are available), so readers
    /// look it up with `.get()`.
    pub selected_index: usize,
    /// Model name this role had in the configuration the picker was opened with. It is
    /// used as a fallback when `selected_index` is out of range, and to highlight roles
    /// whose selection has changed.
    pub default_model: String,
}
/// Complete state of the interactive model picker.
///
/// Derives `Debug` and `Clone` like its component types (`AvailableModel`, `RoleSlot`)
/// so the state can be logged or snapshotted.
#[derive(Debug, Clone)]
pub struct ModelPickerState {
    /// Models offered in the list, in display order.
    pub available_models: Vec<AvailableModel>,
    /// One slot per pipeline role, in `ROLE_NAMES` order.
    pub roles: Vec<RoleSlot>,
    /// Index into `roles` of the role currently being edited.
    pub active_role: usize,
    /// Highlighted row in `available_models`. Enter, Tab, BackTab and Space commit it
    /// to the active role.
    pub cursor: usize,
}
/// Outcome of handling one key press in the model picker.
pub enum PickerAction {
    /// Keep showing the picker.
    Continue,
    /// Enter was pressed on the last role; carries the configuration assembled from
    /// every role's selection.
    Confirm(ModelConfig),
    /// Esc was pressed to leave the picker.
    Cancel,
}
impl ModelPickerState {
pub fn new(available: Vec<AvailableModel>, current_config: &ModelConfig) -> Self {
let defaults = [
¤t_config.architect.model,
¤t_config.tester.model,
¤t_config.coder.model,
¤t_config.security.model,
¤t_config.critique.model,
¤t_config.cto.model,
¤t_config.complexity.model,
];
let roles: Vec<RoleSlot> = ROLE_NAMES
.iter()
.zip(defaults.iter())
.map(|(name, default)| {
let selected_index = available
.iter()
.position(|m| m.name == **default)
.unwrap_or(0);
RoleSlot {
role_name: name.to_string(),
selected_index,
default_model: default.to_string(),
}
})
.collect();
let cursor = roles.first().map(|r| r.selected_index).unwrap_or(0);
Self {
available_models: available,
roles,
active_role: 0,
cursor,
}
}
fn sync_cursor(&mut self) {
self.cursor = self.roles[self.active_role].selected_index;
}
pub fn build_config(&self) -> ModelConfig {
let mut cfg = ModelConfig::from_preset(Preset::Premium);
let get = |slot: &RoleSlot| -> RoleConfig {
match self.available_models.get(slot.selected_index) {
Some(m) => match m.provider {
ModelProvider::Cloud => RoleConfig::cloud(&m.name),
ModelProvider::Local => RoleConfig::local(&m.name),
},
None => RoleConfig::local(&slot.default_model),
}
};
cfg.architect = get(&self.roles[0]);
cfg.tester = get(&self.roles[1]);
cfg.coder = get(&self.roles[2]);
cfg.security = get(&self.roles[3]);
cfg.critique = get(&self.roles[4]);
cfg.cto = get(&self.roles[5]);
cfg.complexity = get(&self.roles[6]);
cfg
}
pub fn to_toml(&self) -> String {
let cfg = self.build_config();
let mut s = String::from(
"# BattleCommand Forge — Model Configuration (generated by model picker)\n",
);
s.push_str("preset = \"premium\"\n\n");
let sections = [
("architect", &cfg.architect),
("tester", &cfg.tester),
("coder", &cfg.coder),
("security", &cfg.security),
("critique", &cfg.critique),
("cto", &cfg.cto),
("complexity", &cfg.complexity),
];
for (name, role) in §ions {
s.push_str(&format!("[{}]\n", name));
s.push_str(&format!("model = \"{}\"\n", role.model));
s.push_str(&format!("provider = \"{}\"\n\n", role.provider));
}
s
}
}
/// Applies a single key press to the picker.
///
/// * `Up` / `Down` move the cursor through the model list, stopping at either end.
/// * `Enter` commits the cursor to the active role and moves to the next role. On the
///   last role it returns `PickerAction::Confirm` with the assembled config.
/// * `Tab` / `BackTab` commit the cursor and cycle forward / backward through the
///   roles, wrapping around at either end.
/// * `Space` commits the cursor without changing role.
/// * `Esc` returns `PickerAction::Cancel`.
///
/// Any other key, and a cursor move that is already at the list edge, leaves the state
/// untouched and yields `PickerAction::Continue`.
pub fn handle_picker_input(state: &mut ModelPickerState, key: KeyCode) -> PickerAction {
    // Stores the highlighted model as the active role's selection.
    fn commit(state: &mut ModelPickerState) {
        let role = state.active_role;
        state.roles[role].selected_index = state.cursor;
    }

    let role_count = state.roles.len();
    match key {
        KeyCode::Esc => return PickerAction::Cancel,
        KeyCode::Up => {
            if state.cursor > 0 {
                state.cursor -= 1;
            }
        }
        KeyCode::Down => {
            if state.cursor + 1 < state.available_models.len() {
                state.cursor += 1;
            }
        }
        KeyCode::Char(' ') => commit(state),
        KeyCode::Enter => {
            commit(state);
            let next = state.active_role + 1;
            if next >= role_count {
                return PickerAction::Confirm(state.build_config());
            }
            state.active_role = next;
            state.cursor = state.roles[next].selected_index;
        }
        KeyCode::Tab => {
            commit(state);
            state.active_role = (state.active_role + 1) % role_count;
            state.cursor = state.roles[state.active_role].selected_index;
        }
        KeyCode::BackTab => {
            commit(state);
            state.active_role = match state.active_role {
                0 => role_count - 1,
                current => current - 1,
            };
            state.cursor = state.roles[state.active_role].selected_index;
        }
        _ => {}
    }
    PickerAction::Continue
}
pub fn draw_model_picker(f: &mut Frame, state: &ModelPickerState) {
let area = f.area();
let picker_area = centered_rect(80, 85, area);
f.render_widget(Clear, picker_area);
let outer = Block::default()
.borders(Borders::ALL)
.title(" BattleCommand Forge — Model Setup ")
.title_style(Style::default().fg(Color::Red).add_modifier(Modifier::BOLD))
.border_style(Style::default().fg(Color::Yellow));
let inner = outer.inner(picker_area);
f.render_widget(outer, picker_area);
let chunks = Layout::default()
.direction(Direction::Vertical)
.constraints([
Constraint::Length(2), Constraint::Length(2), Constraint::Min(5), Constraint::Length(2), ])
.split(inner);
let instructions = Paragraph::new(Line::from(vec![
Span::raw(" Assign a model to each pipeline role. "),
Span::styled(
"Up/Down",
Style::default()
.fg(Color::Cyan)
.add_modifier(Modifier::BOLD),
),
Span::raw(" navigate "),
Span::styled(
"Enter",
Style::default()
.fg(Color::Cyan)
.add_modifier(Modifier::BOLD),
),
Span::raw(" select & next "),
Span::styled(
"Tab",
Style::default()
.fg(Color::Cyan)
.add_modifier(Modifier::BOLD),
),
Span::raw(" switch role "),
Span::styled(
"Esc",
Style::default()
.fg(Color::Cyan)
.add_modifier(Modifier::BOLD),
),
Span::raw(" cancel"),
]));
f.render_widget(instructions, chunks[0]);
let role_spans: Vec<Span> = state
.roles
.iter()
.enumerate()
.flat_map(|(i, slot)| {
let selected_model = state
.available_models
.get(slot.selected_index)
.map(|m| m.name.as_str())
.unwrap_or("?");
let label = format!(
" {} [{}] ",
slot.role_name,
truncate_model(selected_model, 12)
);
let style = if i == state.active_role {
Style::default()
.fg(Color::Black)
.bg(Color::Yellow)
.add_modifier(Modifier::BOLD)
} else if slot.selected_index
!= state
.available_models
.iter()
.position(|m| m.name == slot.default_model)
.unwrap_or(usize::MAX)
{
Style::default()
.fg(Color::Green)
.add_modifier(Modifier::BOLD)
} else {
Style::default().fg(Color::DarkGray)
};
vec![Span::styled(label, style), Span::raw(" ")]
})
.collect();
let role_line = Paragraph::new(Line::from(role_spans));
f.render_widget(role_line, chunks[1]);
let list_area = chunks[2];
let mut lines: Vec<Line> = Vec::new();
lines.push(Line::from(""));
let current_selection = state.roles[state.active_role].selected_index;
for (i, model) in state.available_models.iter().enumerate() {
let is_cursor = i == state.cursor;
let is_selected = i == current_selection;
let pointer = if is_cursor { " > " } else { " " };
let marker = if is_selected { " *" } else { " " };
let name_style = if is_cursor {
Style::default()
.fg(Color::Yellow)
.add_modifier(Modifier::BOLD)
} else if is_selected {
Style::default().fg(Color::Green)
} else {
Style::default().fg(Color::White)
};
let is_cloud = matches!(model.provider, ModelProvider::Cloud);
let size_str = if is_cloud {
" CLOUD".to_string()
} else {
format!("{:>6.1} GB", model.size_gb)
};
let size_style = if is_cloud {
Style::default()
.fg(Color::Magenta)
.add_modifier(Modifier::BOLD)
} else {
Style::default().fg(Color::DarkGray)
};
lines.push(Line::from(vec![
Span::styled(
pointer,
Style::default()
.fg(Color::Yellow)
.add_modifier(Modifier::BOLD),
),
Span::styled(format!("{:<40}", model.name), name_style),
Span::styled(size_str, size_style),
Span::styled(
marker,
Style::default()
.fg(Color::Green)
.add_modifier(Modifier::BOLD),
),
]));
}
let model_list = Paragraph::new(lines).block(
Block::default()
.borders(Borders::TOP)
.title(format!(
" {} — select model ",
state.roles[state.active_role].role_name
))
.title_style(Style::default().fg(Color::Cyan)),
);
f.render_widget(model_list, list_area);
let mut unique_models: Vec<&str> = state
.roles
.iter()
.filter_map(|slot| {
state
.available_models
.get(slot.selected_index)
.filter(|m| !matches!(m.provider, ModelProvider::Cloud))
.map(|m| m.name.as_str())
})
.collect();
unique_models.sort();
unique_models.dedup();
let unique_vram: f64 = unique_models
.iter()
.map(|name| {
state
.available_models
.iter()
.find(|m| m.name == *name)
.map(|m| m.size_gb)
.unwrap_or(0.0)
})
.sum();
let has_cloud = state.roles.iter().any(|slot| {
state
.available_models
.get(slot.selected_index)
.map(|m| matches!(m.provider, ModelProvider::Cloud))
.unwrap_or(false)
});
let mut footer_spans = vec![
Span::raw(" "),
Span::styled(
format!("VRAM: {:.0} GB (unique local)", unique_vram),
Style::default().fg(Color::Cyan),
),
];
if has_cloud {
footer_spans.push(Span::raw(" "));
footer_spans.push(Span::styled(
"+ CLOUD (API key required)",
Style::default().fg(Color::Magenta),
));
}
footer_spans.push(Span::raw(" | "));
footer_spans.push(Span::styled(
format!("Role {}/{}", state.active_role + 1, state.roles.len()),
Style::default().fg(Color::DarkGray),
));
let footer = Paragraph::new(Line::from(footer_spans));
f.render_widget(footer, chunks[3]);
}
/// Returns a rectangle `percent_x`% wide and `percent_y`% tall, centred inside `r`.
///
/// Percentages above 100 are clamped to 100. Without the clamp, the margin arithmetic
/// `(100 - percent) / 2` would underflow `u16` and panic in debug builds.
fn centered_rect(percent_x: u16, percent_y: u16, r: Rect) -> Rect {
    let percent_x = percent_x.min(100);
    let percent_y = percent_y.min(100);

    // Split vertically into top margin / popup / bottom margin.
    let popup_layout = Layout::default()
        .direction(Direction::Vertical)
        .constraints([
            Constraint::Percentage((100 - percent_y) / 2),
            Constraint::Percentage(percent_y),
            Constraint::Percentage((100 - percent_y) / 2),
        ])
        .split(r);

    // Split the middle band horizontally and keep its centre column.
    Layout::default()
        .direction(Direction::Horizontal)
        .constraints([
            Constraint::Percentage((100 - percent_x) / 2),
            Constraint::Percentage(percent_x),
            Constraint::Percentage((100 - percent_x) / 2),
        ])
        .split(popup_layout[1])[1]
}
/// Shortens `name` to at most `max` characters for compact display, replacing the
/// tail with `...` when it does not fit.
///
/// Lengths are counted in `char`s rather than bytes, so non-ASCII names are cut on a
/// character boundary instead of panicking mid-code-point. When `max < 3`, an overlong
/// name becomes just `"..."`.
fn truncate_model(name: &str, max: usize) -> String {
    if name.chars().count() <= max {
        return name.to_string();
    }
    let keep = max.saturating_sub(3);
    // Byte offset of the first character that does not fit.
    let cut = name
        .char_indices()
        .nth(keep)
        .map_or(name.len(), |(offset, _)| offset);
    format!("{}...", &name[..cut])
}