use super::SlashCommand;
use async_trait::async_trait;
use serenity::all::{
CommandInteraction, Context, CreateActionRow, CreateSelectMenu, CreateSelectMenuKind,
CreateSelectMenuOption, EditInteractionResponse,
};
use std::sync::Arc;
use crate::agent::AiAgent;
use tracing::{error, info};
/// The `/model` slash command: shows the models available to the current channel's agent
/// as select menus. Applying a selection is done by `handle_model_select`.
pub struct ModelCommand;
/// Number of model options placed in each select menu (each menu gets its own action row).
const SELECT_CHUNK_SIZE: usize = 25;
/// Most models `/model` will ever offer: five menus of `SELECT_CHUNK_SIZE` options each.
const MAX_SELECT_OPTIONS: usize = SELECT_CHUNK_SIZE * 5;

/// Clamps a model count to the number of options the `/model` menus can display.
fn capped_model_count(models_len: usize) -> usize {
    std::cmp::min(models_len, MAX_SELECT_OPTIONS)
}
/// Encodes a provider and model id as the `provider|model_id` string used as a
/// select-option value.
fn build_model_value(provider: &str, model_id: &str) -> String {
    [provider, model_id].join("|")
}
/// Splits a `provider|model_id` value at its first `|`.
///
/// Returns `None` when the value contains no `|`. Everything after the first `|`,
/// including any further `|` characters, is returned as the model id.
fn parse_model_value(composite: &str) -> Option<(&str, &str)> {
    let sep = composite.find('|')?;
    // '|' is a single byte, so `sep + 1` is always a char boundary.
    Some((&composite[..sep], &composite[sep + 1..]))
}
#[async_trait]
impl SlashCommand for ModelCommand {
fn name(&self) -> &'static str {
"model"
}
fn description(&self, i18n: &crate::i18n::I18n) -> String {
i18n.get("cmd_model_desc")
}
fn options(&self, _i18n: &crate::i18n::I18n) -> Vec<serenity::all::CreateCommandOption> {
vec![]
}
async fn execute(
&self,
ctx: &Context,
command: &CommandInteraction,
state: &crate::AppState,
) -> anyhow::Result<()> {
command.defer_ephemeral(&ctx.http).await?;
let channel_id_str = command.channel_id.to_string();
let channel_config = crate::commands::agent::ChannelConfig::load()
.await
.unwrap_or_default();
let agent_type = channel_config.get_agent_type(&channel_id_str);
let (agent, _) = state
.session_manager
.get_or_create_session(command.channel_id.get(), agent_type, &state.backend_manager)
.await?;
let i18n = state.i18n.read().await;
let models = match agent.get_available_models().await {
Ok(m) => {
info!("Fetched {} models for /model command", m.len());
m
}
Err(e) => {
error!("Failed to fetch models: {}", e);
command
.edit_response(
&ctx.http,
EditInteractionResponse::new()
.content(i18n.get_args("model_fetch_failed", &[e.to_string()])),
)
.await?;
return Ok(());
}
};
if models.is_empty() {
command
.edit_response(
&ctx.http,
EditInteractionResponse::new().content(i18n.get("model_no_available")),
)
.await?;
return Ok(());
}
let mut action_rows = Vec::new();
let total_models = capped_model_count(models.len());
let models_slice = &models[..total_models];
for (idx, chunk) in models_slice.chunks(SELECT_CHUNK_SIZE).enumerate() {
let select_options: Vec<CreateSelectMenuOption> = chunk
.iter()
.map(|m| {
let value = build_model_value(&m.provider, &m.id);
CreateSelectMenuOption::new(&m.label, value)
.description(i18n.get_args("model_provider_desc", &[m.provider.clone()]))
})
.collect();
let select_menu = CreateSelectMenu::new(
format!("model_select_{}", idx), CreateSelectMenuKind::String {
options: select_options,
},
)
.placeholder(i18n.get_args("model_placeholder", &[(idx + 1).to_string()]))
.min_values(1)
.max_values(1);
action_rows.push(CreateActionRow::SelectMenu(select_menu));
}
match command
.edit_response(
&ctx.http,
EditInteractionResponse::new()
.content(i18n.get_args("model_fetched", &[total_models.to_string()]))
.components(action_rows),
)
.await
{
Ok(_) => info!("Successfully sent model select menu(s)"),
Err(e) => error!("Failed to send model select menu: {}", e),
}
Ok(())
}
}
/// Handles a selection from one of the `/model` select menus.
///
/// Expects a string-select interaction whose first value is a `provider|model_id` string
/// produced by `build_model_value`. It switches `agent` to that model, then replaces the
/// deferred ephemeral response with a localized success or failure message and removes the
/// menus.
///
/// Any interaction that is not a string select, has no selected value, or has a value that
/// cannot be parsed gets the localized `model_invalid` message. This ensures the deferred
/// response is always resolved rather than left pending.
///
/// # Errors
/// Returns an error if deferring or editing the interaction response fails. A failure from
/// `set_model` is reported to the user, not returned.
pub async fn handle_model_select(
    ctx: &Context,
    interaction: &serenity::all::ComponentInteraction,
    agent: Arc<dyn AiAgent>,
    state: &crate::AppState,
) -> anyhow::Result<()> {
    interaction.defer_ephemeral(&ctx.http).await?;

    // First selected value, if this is a string select with any selection at all.
    let selected = match &interaction.data.kind {
        serenity::all::ComponentInteractionDataKind::StringSelect { values } => values.first(),
        _ => None,
    };
    let choice = selected.and_then(|composite_id| {
        parse_model_value(composite_id).map(|(provider, model)| (composite_id, provider, model))
    });

    let content = match choice {
        Some((composite_id, provider, model)) => {
            // Switch first, then take the i18n read lock, so the lock is not held across
            // the backend call.
            let result = agent.set_model(provider, model).await;
            let i18n = state.i18n.read().await;
            match result {
                Ok(_) => i18n.get_args("model_switched", &[composite_id.to_string()]),
                Err(e) => i18n.get_args("model_failed", &[e.to_string()]),
            }
        }
        None => state.i18n.read().await.get("model_invalid"),
    };

    interaction
        .edit_response(
            &ctx.http,
            EditInteractionResponse::new()
                .content(content)
                .components(vec![]),
        )
        .await?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::{build_model_value, capped_model_count, parse_model_value};

    #[test]
    fn test_capped_model_count_limited_to_125() {
        assert_eq!(capped_model_count(0), 0);
        assert_eq!(capped_model_count(24), 24);
        assert_eq!(capped_model_count(125), 125);
        // One past the limit is the off-by-one boundary.
        assert_eq!(capped_model_count(126), 125);
        assert_eq!(capped_model_count(200), 125);
    }

    #[test]
    fn test_build_and_parse_model_value_roundtrip() {
        let composite = build_model_value("openai", "gpt-4.1");
        assert_eq!(composite, "openai|gpt-4.1");
        let (provider, model) = parse_model_value(&composite).expect("must parse");
        assert_eq!(provider, "openai");
        assert_eq!(model, "gpt-4.1");
    }

    #[test]
    fn test_parse_model_value_splits_on_first_delimiter() {
        // Only the first '|' separates the provider, so a model id containing '|'
        // still round-trips.
        let composite = build_model_value("openrouter", "vendor|model");
        assert_eq!(
            parse_model_value(&composite),
            Some(("openrouter", "vendor|model"))
        );
    }

    #[test]
    fn test_parse_model_value_rejects_invalid() {
        assert!(parse_model_value("no-delimiter").is_none());
        assert!(parse_model_value("").is_none());
    }
}