ai-cli 0.2.1

A CLI tool for all things AI (generating images or audio, chatting with LLM's, you name it)
Documentation
use async_recursion::async_recursion;
use clap::{Args,ValueEnum};
use serde::{Serialize,Deserialize};
use reqwest::Client;
use derive_more::From;
use tiktoken_rs::p50k_base;
use crate::openai::chat::OpenAIChatCommand;
use crate::openai::OpenAIError;
use crate::completion::{CompletionOptions,CompletionFile,ClashingArgumentsError};
use crate::Config;
use log::debug;

#[derive(Args, Clone, Debug, Default, Serialize, Deserialize)]
pub struct ChatCommand {
    #[command(flatten)]
    #[serde(flatten)]
    pub completion: CompletionOptions,

    #[arg(long, short)]
    pub system: Option<String>,

    #[arg(long, short)]
    pub direction: Option<String>,

    #[arg(long)]
    pub provider: Option<ChatProvider>,
}

impl ChatCommand {
    #[async_recursion]
    pub async fn run(&self, client: &Client, config: &Config) -> ChatResult {
        let mut options = ChatOptions::try_from((self, config))?;
        let print_output = !options.completion.quiet.unwrap_or(false);

        if print_output && options.file.transcript.len() > 0 {
            print!("{}", options.file.transcript);
        }

        if !options.ai_responds_first {
            let append = options.completion.append.as_ref().map(|a| &**a);
            let prefix_user = Some(&*options.prefix_user);

            if let None = options.file.read(append, prefix_user, options.no_context) {
                return Ok(vec![]);
            }
        }

        let mut command = OpenAIChatCommand::try_from(options)?;
        command.run(client, config).await
    }
}

#[derive(Default, Debug)]
pub(crate) struct ChatOptions {
    pub ai_responds_first: bool,
    pub completion: CompletionOptions,
    pub direction: Option<ChatMessage>,
    pub system: String,
    pub file: CompletionFile<ChatCommand>,
    pub no_context: bool,
    pub provider: ChatProvider,
    pub prefix_ai: String,
    pub prefix_user: String,
    pub stream: bool,
    pub temperature: f32,
    pub tokens_max: Option<usize>,
    pub tokens_balance: f32,
    pub stop: Vec<String>
}

impl TryFrom<(&ChatCommand, &Config)> for ChatOptions {
    type Error = ChatError;

    fn try_from((command, config): (&ChatCommand, &Config)) -> Result<Self, Self::Error> {
        let file = command.completion.load_session_file::<ChatCommand>(config, command.clone());
        let completion = if file.file.is_some() {
            command.completion.merge(&file.overrides.completion)
        } else {
            command.completion.clone()
        };

        let stream = completion.parse_stream_option()?;
        let system = command.system
            .clone()
            .or_else(|| file.overrides.system.clone())
            .clone()
            .unwrap_or_else(|| String::from("A friendly and helpful AI assistant."));

        let provider = command.provider.unwrap_or_default();

        Ok(ChatOptions {
            ai_responds_first: completion.ai_responds_first.unwrap_or(false),
            direction: command.direction.clone()
                .map(|direction| ChatMessage::new(ChatRole::System, direction)),
            temperature: completion.temperature.unwrap_or(0.8),
            no_context: completion.no_context.unwrap_or(false),
            provider,
            prefix_ai: completion.prefix_ai.clone().unwrap_or_else(|| String::from("AI")),
            prefix_user: completion.prefix_user.clone().unwrap_or_else(|| String::from("USER")),
            system,
            tokens_balance: completion.tokens_balance.unwrap_or(0.5),
            stop: completion.parse_stops(),
            tokens_max: completion.tokens_max,
            completion,
            stream,
            file,
        })
    }
}

#[derive(Debug, From)]
pub enum ChatError {
    ClashingArguments(ClashingArgumentsError),
    ChatTranscriptionError(ChatTranscriptionError),
    TranscriptDeserializationError(serde_json::Error),
    OpenAIError(OpenAIError),
    NetworkError(reqwest::Error),
    IOError(std::io::Error),
    EventSource(reqwest_eventsource::Error),
    Unauthorized
}

#[derive(Debug)]
pub struct ChatTranscriptionError(pub String);

pub type ChatResult = Result<Vec<ChatMessage>, ChatError>;

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct ChatMessage {
    pub role: ChatRole,
    pub content: String,
    #[serde(skip)]
    pub tokens: usize
}

impl ChatMessage {
    pub fn new(role: ChatRole, content: impl AsRef<str>) -> Self {
        let tokens = p50k_base().unwrap()
            .encode_with_special_tokens(&format!("{}{}", role, content.as_ref()))
            .len();

        ChatMessage {
            role,
            content: content.as_ref().to_string(),
            tokens
        }
    }
}

pub type ChatMessages = Vec<ChatMessage>;

impl TryFrom<&ChatOptions> for ChatMessages {
    type Error = ChatError;

    fn try_from(options: &ChatOptions) -> Result<Self, Self::Error> {
        let ChatOptions { file, system, .. } = options;

        let mut messages = vec![];
        let mut message: Option<ChatMessage> = None;

        messages.push(ChatMessage::new(ChatRole::System, system));

        let handle_continuing_line = |line, message: &mut Option<ChatMessage>| match message {
            Some(m) => {
                *message = Some(ChatMessage::new(m.role, {
                    let mut content = m.content.clone();
                    content += "\n";
                    content += line;
                    content
                }));
                Ok(())
            },
            None => {
                return Err(ChatError::ChatTranscriptionError(ChatTranscriptionError(
                    "Missing opening chat role".into()
                )));
            }
        };

        for line in file.transcript.lines() {
            match line.split_once(':') {
                Some((role, dialog)) => match ChatRole::try_from((role, options)) {
                    Ok(normalized_role) => {
                        if let Some(message) = message {
                            messages.push(message);
                        }

                        let mut dialog = dialog.trim_start().to_string();
                        if role != "ai" && role != "assitant" && role != "user" && role != "system"
                            && !dialog.to_lowercase().starts_with(role) {
                            dialog = format!("{role}: {dialog}");
                        }

                        message = Some(ChatMessage::new(normalized_role, dialog));
                    },
                    Err(_) => {
                        println!("lkjlkj");
                        handle_continuing_line(line, &mut message)?
                    }
                },
                None => handle_continuing_line(line, &mut message)?
            }
        }

        if let Some(message) = message {
            messages.push(message);
        }

        if options.no_context {
            messages.push(ChatMessage::new(ChatRole::User, file.last_read_input.clone()));
        }

        if let Some(direction) = &options.direction {
            messages.push(direction.clone());
        }

        if options.no_context {
            messages.push(ChatMessage::new(ChatRole::Ai, file.last_written_input.clone()))
        }

        let lab = messages.labotomize(&options)?;
        return Ok(lab);
    }
}

pub(crate) trait ChatMessagesInternalExt {
    fn labotomize(&self, options: &ChatOptions) -> Result<Self, ChatError> where Self: Sized;
}

impl ChatMessagesInternalExt for ChatMessages {
    fn labotomize(&self, options: &ChatOptions) -> Result<Self, ChatError> {
        let tokens_max = options.tokens_max
            .expect("The max tokens should have been assigned when creating the command");

        let tokens_balance = options.tokens_balance;
        let upper_bound = (tokens_max as f32 * tokens_balance).floor() as usize;
        let current_token_length: usize = self.iter().map(|m| m.tokens).sum();

        if current_token_length > upper_bound {
            let system = ChatMessage::new(ChatRole::System, options.system.clone());
            let mut messages = vec![];
            let mut remaining = upper_bound.checked_sub(system.tokens)
                .ok_or_else(|| ChatTranscriptionError(format!(
                    "Cannot fit your system message into the chat messages list. This means \
                    that your tokens_max value is either too small or your system message is \
                    too long. You're upper bound on transcript tokens is {upper_bound} and \
                    your system message has {} tokens", system.tokens)))?;

            for message in self.iter().skip(1).rev() {
                match remaining.checked_sub(message.tokens) {
                    Some(subtracted) => {
                        remaining = subtracted;
                        messages.push(message);
                    },
                    None => {
                        debug!(target: "AI chat", "Lobotomized message {:?}", message);
                    },
                }
            }

            messages.push(&system);
            Ok(messages.iter().rev().map(|i| i.clone()).cloned().collect())
        } else {
            Ok(self.clone())
        }
    }
}

#[derive(Copy, Clone, Debug, Default, ValueEnum, Serialize, Deserialize)]
#[allow(non_camel_case_types)]
pub enum ChatProvider {
    #[default]
    OpenAiGPT3Turbo,
    OpenAiGPT3Turbo_0301,
    OpenAiGPT4,
    OpenAiGPT4_0314,
    OpenAiGPT4_32K,
    OpenAiGPT4_32K_0314
}

impl std::fmt::Display for ChatProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        write!(f, "{}", match self {
            Self::OpenAiGPT3Turbo => "gpt-3.5-turbo",
            Self::OpenAiGPT3Turbo_0301 => "gpt-3.5-turbo-0301",
            Self::OpenAiGPT4 => "gpt-4",
            Self::OpenAiGPT4_0314 => "gpt-4-0314",
            Self::OpenAiGPT4_32K => "gpt-4-32k",
            Self::OpenAiGPT4_32K_0314 => "gpt-4-32k-0314"
        })
    }
}

#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq)]
pub enum ChatRole {
    #[serde(rename = "assistant")]
    Ai,
    #[serde(rename = "user")]
    User,
    #[serde(rename = "system")]
    System
}

impl std::fmt::Display for ChatRole {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        write!(f, "{}", match self {
            Self::Ai => "AI: ",
            Self::User => "USER: ",
            Self::System => "SYSTEM: "
        })
    }
}

impl TryFrom<(&str, &ChatOptions)> for ChatRole {
    type Error = ChatError;

    fn try_from((role, options): (&str, &ChatOptions)) -> Result<Self, Self::Error> {
        let role = role.to_lowercase();
        let role = role.trim();

        if role == options.prefix_ai.to_lowercase() {
            return Ok(ChatRole::Ai)
        }

        if role == options.prefix_user.to_lowercase() {
            return Ok(ChatRole::User)
        }

        match &*role {
            "ai" |
            "assistant" => Ok(ChatRole::Ai),
            "system" => Ok(ChatRole::System),
            _ => Ok(ChatRole::User),
        }
    }
}