use anyhow::Error;
use burn_lm_inference::{InferenceJob, InferenceTask, Message, MessageRole, StdOutListener};
use burn_lm_registry::Registry;
use clap::CommandFactory as _;
use rustyline::{history::DefaultHistory, Editor};
use yansi::Paint;
use super::BurnLMPromptHelper;
/// Name of the synthetic subcommand that carries a free-form chat message.
///
/// It is used as the clap name of `MessageCommand::Msg`, and input that is not
/// recognised as a command is rewritten to start with this token so that it
/// parses as that variant.
const MESSAGE_CMD: &str = "<message>";
// Commands understood by the interactive chat shell.
//
// NOTE: plain `//` comments are used here on purpose. clap's derive macros
// turn `///` doc comments into help text, which would alter the output of the
// `Help` command.
#[derive(clap::Subcommand)]
#[command(name = "message", disable_help_subcommand = true)]
pub enum MessageCommand {
    // Free-form user message forwarded to the model.
    #[command(name = MESSAGE_CMD)]
    Msg { message: String },
    // Print the help of the chat shell.
    Help,
    // Toggle the display of stats after each answer.
    Stats,
    // Clear the chat state held by the model plugin.
    Clear,
    // Unload the model and leave the chat session.
    Exit,
}
// clap parser wrapping `MessageCommand`; its generated `clap::Command` is what
// the chat shell prints when the `Help` command is issued.
//
// NOTE: `//` comments rather than `///` so that clap's derive does not pick
// them up as help text.
#[derive(clap::Parser)]
#[command(name = "chat", about = "Burn LM Chat", disable_help_subcommand = true)]
struct MessageCli {
    // The chat command entered by the user.
    #[command(subcommand)]
    command: MessageCommand,
}
/// Thin wrapper around a `rustyline` editor, used as the input source of the
/// chat shell.
struct ChatEditor<H>
where
    H: rustyline::Helper,
{
    /// Underlying line editor, using the default history implementation.
    editor: Editor<H, DefaultHistory>,
}
impl ChatEditor<BurnLMPromptHelper> {
fn new() -> Self {
let mut editor = Editor::<BurnLMPromptHelper, DefaultHistory>::new().unwrap();
let helper = BurnLMPromptHelper::new(yansi::Color::Yellow.bold());
editor.set_helper(Some(helper));
Self { editor }
}
}
impl cloop::InputReader for ChatEditor<BurnLMPromptHelper> {
    /// Reads one line from the editor and rewrites it into the command line
    /// handed to the shell parser.
    ///
    /// * When `parse_command` recognises a command, the line is forwarded as
    ///   `"{cmd} {rest}"`.
    /// * Otherwise the whole line is treated as a chat message and forwarded
    ///   as the `MESSAGE_CMD` subcommand followed by the text as a single
    ///   double-quoted argument.
    ///
    /// Any result of the editor other than `Ok(Input(..))` is returned as is.
    fn read(&mut self, prompt: &str) -> std::io::Result<cloop::InputResult> {
        let line = match self.editor.read(prompt) {
            Ok(cloop::InputResult::Input(line)) => line,
            other => return other,
        };

        let rewritten = match burn_lm_inference::utils::parse_command(&line) {
            (Some(cmd), rest) => format!("{cmd} {rest}"),
            _ => {
                // The message is wrapped in double quotes so that it reaches
                // the parser as one argument. Quotes and backslashes typed by
                // the user must be escaped (backslashes first), otherwise they
                // terminate the argument early or get swallowed.
                // NOTE(review): assumes the shell parser splits its input with
                // shell-style quoting rules (backslash escapes inside double
                // quotes) — confirm against cloop's parser.
                let escaped = line.replace('\\', "\\\\").replace('"', "\\\"");
                format!("{MESSAGE_CMD} \"{escaped}\"")
            }
        };

        Ok(cloop::InputResult::Input(rewritten))
    }
}
/// Mutable state shared by the commands of a chat session.
///
/// NOTE(review): the derived `Default` yields `stats == false`, whereas
/// [`ChatContext::new`] turns stats on — confirm the difference is intended.
#[derive(Default)]
struct ChatContext {
    /// Whether stats are displayed after each answer.
    stats: bool,
}

impl ChatContext {
    /// Creates a context with the stats display enabled.
    pub fn new() -> Self {
        let stats = true;
        Self { stats }
    }
}
/// Builds the `chat` command with one subcommand per plugin found in the
/// registry.
///
/// Plugins are visited in ascending order of their registry key. Each
/// subcommand is named after the plugin's CLI parameter name and receives the
/// arguments of the command produced by the plugin's CLI-flags factory.
pub(crate) fn create() -> clap::Command {
    let chat = clap::Command::new("chat").about("Start a chat session with the chosen model");

    let registry = Registry::new();
    let mut plugins: Vec<_> = registry.get().iter().collect();
    plugins.sort_by_key(|(key, ..)| *key);

    plugins.into_iter().fold(chat, |chat, (_, plugin)| {
        let named = clap::Command::new(plugin.model_cli_param_name());
        let described = named.about(format!("Chat with {} model", plugin.model_name()));
        // Only the arguments of the plugin's command are borrowed; `flags`
        // must stay alive until they have been copied over.
        let flags = (plugin.create_cli_flags_fn())();
        let model_cmd = described.args(flags.get_arguments());
        chat.subcommand(model_cmd)
    })
}
/// Runs an interactive chat session with the model selected on the command
/// line.
///
/// # Arguments
///
/// * `args` - matches of the `chat` command; its subcommand names the model
///   and carries the model-specific flags.
/// * `backend`, `dtype` - only used to build the prompt label.
///
/// # Returns
///
/// `Ok(None)` once the session is closed, or right after printing the help
/// when no model subcommand was given.
///
/// # Errors
///
/// Returns an error when the help cannot be printed, when the model is not
/// downloaded, when loading the model fails, or when the shell itself fails.
///
/// # Panics
///
/// Panics if the selected subcommand does not match any plugin of the
/// registry. Subcommands are generated from that same registry by `create`,
/// so this is treated as a broken invariant.
pub(crate) fn handle(
    args: &clap::ArgMatches,
    backend: &str,
    dtype: &str,
) -> super::HandleCommandResult {
    // `subcommand()` yields the name and its matches together, so no second
    // lookup (and no second failure path) is needed.
    let (plugin_name, plugin_args) = match args.subcommand() {
        Some(selection) => selection,
        None => {
            create().print_help()?;
            return Ok(None);
        }
    };

    let registry = Registry::new();
    // Lowercase once instead of once per registry entry.
    let wanted = plugin_name.to_lowercase();
    let plugin = registry
        .get()
        .iter()
        .find(|(_, p)| p.model_cli_param_name() == wanted)
        .map(|(_, plugin)| plugin)
        .unwrap_or_else(|| panic!("Plugin should be registered: {plugin_name}"));

    if !plugin.is_downloaded() {
        return Err(Error::msg(format!(
            "Model is not downloaded, run `download {}` first.",
            plugin_name
        )));
    }
    plugin.parse_cli_config(plugin_args);

    let mut spin_msg = super::SpinningMessage::new(
        &format!("loading model '{}'...", plugin.model_name()),
        "model loaded!",
    );
    plugin.load()?;
    spin_msg.end(false);

    let app_name = format!("({backend}-{dtype}) chat|{}", plugin.model_name());
    let delim = "> ";

    // Executes one parsed chat command against the loaded plugin.
    let handler = |args: MessageCommand, ctx: &mut ChatContext| -> cloop::ShellResult {
        match args {
            MessageCommand::Msg { message } => {
                let formatted_msg = Message {
                    role: MessageRole::User,
                    content: message,
                    refusal: None,
                };
                let task = InferenceTask::Message(formatted_msg);
                let (job, handle) = InferenceJob::create(task, StdOutListener::default());
                let result = plugin.run_job(job);
                // Wait on the job handle before reporting the outcome.
                handle.join();
                match result {
                    Ok(answer) => {
                        println!();
                        if ctx.stats {
                            crate::utils::display_stats(&answer);
                        }
                    }
                    Err(err) => anyhow::bail!("An error occurred: {err}"),
                }
                Ok(cloop::ShellAction::Continue)
            }
            MessageCommand::Help => {
                MessageCli::command()
                    .override_usage("<command>")
                    .print_help()
                    .expect("help output should be printed");
                Ok(cloop::ShellAction::Continue)
            }
            MessageCommand::Stats => {
                ctx.stats = !ctx.stats;
                let msg = format!("Stats toggled {}!", if ctx.stats { "on" } else { "off" });
                println!("{}", msg.bright_black().bold());
                Ok(cloop::ShellAction::Continue)
            }
            MessageCommand::Exit => {
                // Leaving the session must not fail because of the unload, but
                // a failure is reported instead of being silently dropped.
                match plugin.unload() {
                    Ok(Some(stats)) => println!("{}", stats.display_stats()),
                    Ok(None) => {}
                    Err(err) => eprintln!("Failed to unload model: {err}"),
                }
                Ok(cloop::ShellAction::Exit)
            }
            MessageCommand::Clear => {
                match plugin.clear_state() {
                    Ok(_) => {
                        let msg = "Chat state cleared!".to_string();
                        println!("{}", msg.bright_black().bold());
                    }
                    Err(err) => anyhow::bail!("An error occurred: {err}"),
                }
                Ok(cloop::ShellAction::Continue)
            }
        }
    };

    println!();
    let mut shell = cloop::Shell::new(
        format!("{app_name}{delim}"),
        ChatContext::new(),
        ChatEditor::new(),
        cloop::ClapSubcommandParser::default(),
        handler,
    );
    println!("Chat session started! (press CTRL+D or type /exit to close session)");
    // Propagate shell failures to the caller rather than panicking.
    shell.run()?;
    println!("Chat session closed!");
    Ok(None)
}