// burn-lm-cli 0.0.1
//
// Burn Large Models Engine - CLI.
// Documentation
use anyhow::Error;
use burn_lm_inference::{InferenceJob, InferenceTask, Message, MessageRole, StdOutListener};
use burn_lm_registry::Registry;
use clap::CommandFactory as _;
use rustyline::{history::DefaultHistory, Editor};
use yansi::Paint;

use super::BurnLMPromptHelper;

/// Name of the clap subcommand that carries a free-form chat message. The chat
/// editor prepends it to every input line that is not a slash command.
const MESSAGE_CMD: &str = "<message>";

// Commands handled by the chat shell: `Msg` carries a free-form prompt (exposed
// under the `MESSAGE_CMD` name), the remaining variants are the slash commands
// (help, stats, clear, exit).
// NOTE(review): clap's derive macros use the `///` comments below as help text,
// so they are effectively runtime strings — edit them only to change the help
// output.
/// Message subcommand.
#[derive(clap::Subcommand)]
#[command(name = "message", disable_help_subcommand = true)]
pub enum MessageCommand {
    /// Message (prompt) for inference
    #[command(name = MESSAGE_CMD)]
    Msg { message: String },
    /// Display slash commands help
    Help,
    /// Toggle stats
    Stats,
    /// Clear chat session context
    Clear,
    /// Exit chat session
    Exit,
}

// Wrapper whose only purpose is to obtain a `clap::CommandFactory`
// implementation (derived through `clap::Parser`) for `MessageCommand`; the
// chat shell's help command calls `MessageCli::command()` to print the help.
// A plain `//` comment is used on purpose: a `///` comment here could be picked
// up by clap as part of the command description.
#[derive(clap::Parser)]
#[command(name = "chat", about = "Burn LM Chat", disable_help_subcommand = true)]
struct MessageCli {
    #[command(subcommand)]
    command: MessageCommand,
}

/// Custom rustyline editor for the chat shell.
///
/// Its `cloop::InputReader` implementation inserts the `MESSAGE_CMD` command in
/// front of free-form messages and recognises slash commands (for instance
/// `/exit`).
struct ChatEditor<H>
where
    H: rustyline::Helper,
{
    /// Wrapped rustyline editor using the default history implementation.
    editor: Editor<H, DefaultHistory>,
}

impl ChatEditor<BurnLMPromptHelper> {
    fn new() -> Self {
        let mut editor = Editor::<BurnLMPromptHelper, DefaultHistory>::new().unwrap();
        let helper = BurnLMPromptHelper::new(yansi::Color::Yellow.bold());
        editor.set_helper(Some(helper));
        Self { editor }
    }
}

impl cloop::InputReader for ChatEditor<BurnLMPromptHelper> {
    /// Reads one line from the wrapped editor and rewrites it into a command
    /// line for the shell parser.
    ///
    /// When `parse_command` recognises a command, the line is forwarded as
    /// `<cmd> <rest>`. Any other input is treated as a chat message and becomes
    /// the double-quoted argument of the `MESSAGE_CMD` subcommand. Results that
    /// are not `Input`, as well as errors, are returned untouched.
    fn read(&mut self, prompt: &str) -> std::io::Result<cloop::InputResult> {
        match self.editor.read(prompt) {
            Ok(cloop::InputResult::Input(s)) => {
                if let (Some(cmd), rest) = burn_lm_inference::utils::parse_command(&s) {
                    Ok(cloop::InputResult::Input(format!("{cmd} {rest}")))
                } else {
                    // Consider any free-form input a message.
                    //
                    // The message is wrapped in double quotes so that it reaches
                    // clap as a single argument. Backslashes and double quotes
                    // typed by the user are escaped first, otherwise a message
                    // such as `say "hi"` would lose its quotes or fail to parse.
                    // NOTE(review): assumes the shell parser splits its input
                    // with POSIX-style double-quote rules — confirm against cloop.
                    let escaped = s.replace('\\', "\\\\").replace('"', "\\\"");
                    Ok(cloop::InputResult::Input(format!(
                        "{MESSAGE_CMD} \"{escaped}\""
                    )))
                }
            }
            other => other,
        }
    }
}

/// Mutable state shared by the commands of one chat session.
struct ChatContext {
    /// Whether statistics are displayed after each answer.
    stats: bool,
}

impl Default for ChatContext {
    /// Statistics are enabled by default.
    ///
    /// Implemented by hand because the derived `Default` would start with
    /// `stats: false` and disagree with [`ChatContext::new`].
    fn default() -> Self {
        Self { stats: true }
    }
}

impl ChatContext {
    /// Creates the context of a new chat session, with statistics enabled.
    pub fn new() -> Self {
        Self::default()
    }
}

/// Builds the `chat` command with one subcommand per registered model.
///
/// Models are added in the order of their sorted registry keys; each
/// subcommand is named after the plugin's CLI parameter name and exposes the
/// flags produced by the plugin's CLI flags factory.
pub(crate) fn create() -> clap::Command {
    let registry = Registry::new();
    // Sort the registered plugins by key so the subcommand order is stable.
    let mut plugins: Vec<_> = registry.get().iter().collect();
    plugins.sort_by_key(|(name, ..)| *name);

    // Create a subcommand for each registered model with its associated flags.
    plugins.into_iter().fold(
        clap::Command::new("chat").about("Start a chat session with the chosen model"),
        |root, (_, plugin)| {
            let model_cmd = clap::Command::new(plugin.model_cli_param_name())
                .about(format!("Chat with {} model", plugin.model_name()))
                .args((plugin.create_cli_flags_fn())().get_arguments());
            root.subcommand(model_cmd)
        },
    )
}

/// Runs an interactive chat session with the model selected on the command line.
///
/// `args` holds the matches of the `chat` command: its subcommand selects the
/// model and carries the model flags. `backend` and `dtype` are only used to
/// build the shell prompt.
///
/// When no model subcommand is given, the help of the `chat` command is printed
/// and `Ok(None)` is returned. Otherwise the model is loaded and the chat shell
/// runs until it is closed, after which `Ok(None)` is returned.
///
/// # Errors
///
/// Returns an error when the help cannot be printed, when the selected model is
/// not downloaded, or when loading the model fails.
///
/// # Panics
///
/// Panics if the selected subcommand matches no registered plugin, or if the
/// shell loop itself returns an error.
pub(crate) fn handle(
    args: &clap::ArgMatches,
    backend: &str,
    dtype: &str,
) -> super::HandleCommandResult {
    // A single lookup yields both the model name and its flag matches, so no
    // second (and necessarily successful) lookup is needed later on.
    let (plugin_name, plugin_args) = match args.subcommand() {
        Some(selected) => selected,
        None => {
            create().print_help()?;
            return Ok(None);
        }
    };

    // retrieve registered plugin
    let registry = Registry::new();
    // Lowercase once instead of once per registry entry.
    let wanted = plugin_name.to_lowercase();
    let plugin = registry
        .get()
        .iter()
        .find(|(_, p)| p.model_cli_param_name() == wanted)
        .map(|(_, plugin)| plugin)
        .unwrap_or_else(|| panic!("Plugin should be registered: {plugin_name}"));

    if !plugin.is_downloaded() {
        return Err(Error::msg(format!(
            "Model is not downloaded, run `download {}` first.",
            plugin_name
        )));
    }

    plugin.parse_cli_config(plugin_args);

    // load the model
    let mut spin_msg = super::SpinningMessage::new(
        &format!("loading model '{}'...", plugin.model_name()),
        "model loaded!",
    );
    plugin.load()?;
    spin_msg.end(false);

    // create chat shell
    let app_name = format!("({backend}-{dtype}) chat|{}", plugin.model_name());
    let delim = "> ";
    // Invoked by the shell for every parsed command; `ctx` persists across
    // commands of the session.
    let handler = |args: MessageCommand, ctx: &mut ChatContext| -> cloop::ShellResult {
        match args {
            MessageCommand::Msg { message } => {
                let formatted_msg = Message {
                    role: MessageRole::User,
                    content: message,
                    refusal: None,
                };
                let task = InferenceTask::Message(formatted_msg);
                let (job, handle) = InferenceJob::create(task, StdOutListener::default());
                let result = plugin.run_job(job);
                // Wait on the job handle before reporting the result.
                handle.join();

                match result {
                    Ok(answer) => {
                        println!();

                        if ctx.stats {
                            crate::utils::display_stats(&answer);
                        }
                    }
                    Err(err) => anyhow::bail!("An error occurred: {err}"),
                }
                Ok(cloop::ShellAction::Continue)
            }
            MessageCommand::Help => {
                MessageCli::command()
                    .override_usage("<command>")
                    .print_help()
                    .expect("help output should be printed");
                Ok(cloop::ShellAction::Continue)
            }
            MessageCommand::Stats => {
                ctx.stats = !ctx.stats;
                let msg = format!("Stats toggled {}!", if ctx.stats { "on" } else { "off" });
                println!("{}", msg.bright_black().bold());
                Ok(cloop::ShellAction::Continue)
            }
            MessageCommand::Exit => {
                // Stats returned by the unload are printed when present; an
                // unload error is ignored and the session still exits.
                let stats = plugin.unload();
                if let Ok(Some(stats)) = stats {
                    println!("{}", stats.display_stats());
                }
                Ok(cloop::ShellAction::Exit)
            }
            MessageCommand::Clear => {
                match plugin.clear_state() {
                    Ok(_) => {
                        let msg = "Chat state cleared!".to_string();
                        println!("{}", msg.bright_black().bold());
                    }
                    Err(err) => anyhow::bail!("An error occurred: {err}"),
                }
                Ok(cloop::ShellAction::Continue)
            }
        }
    };

    println!();
    let mut shell = cloop::Shell::new(
        format!("{app_name}{delim}"),
        ChatContext::new(),
        ChatEditor::new(),
        cloop::ClapSubcommandParser::default(),
        handler,
    );

    println!("Chat session started! (press CTRL+D or type /exit to close session)");
    shell.run().unwrap();
    println!("Chat session closed!");
    Ok(None)
}