// aichat 0.18.0
//
// All-in-one AI CLI Tool
use super::message::{Message, MessageContent};

use crate::utils::{estimate_token_length, format_option_value};

use anyhow::{bail, Result};
use serde::Deserialize;

/// Flat token overhead charged for each counted message in `Model::total_tokens`.
const PER_MESSAGES_TOKENS: usize = 5;
/// Fixed token count added on top of the total in `Model::max_input_tokens_limit`.
const BASIS_TOKENS: usize = 2;

/// A model bound to a named client, together with its `ModelData`.
///
/// The pair `client_name:data.name` forms the model's id (see `Model::id`).
#[derive(Debug, Clone)]
pub struct Model {
    /// Name of the client this model belongs to; the prefix of `id()`.
    client_name: String,
    /// Model name plus its limits, prices and capability flags.
    data: ModelData,
}

impl Default for Model {
    fn default() -> Self {
        Model::new("", "")
    }
}

impl Model {
    /// Creates a model named `name` for the client `client_name`.
    ///
    /// All other attributes come from `ModelData::new(name)`.
    pub fn new(client_name: &str, name: &str) -> Self {
        Self {
            client_name: client_name.into(),
            data: ModelData::new(name),
        }
    }

    /// Builds one `Model` per entry of `models`, each bound to `client_name`
    /// and carrying a clone of the entry's data.
    pub fn from_config(client_name: &str, models: &[ModelData]) -> Vec<Self> {
        models
            .iter()
            .map(|data| Self {
                client_name: client_name.to_string(),
                data: data.clone(),
            })
            .collect()
    }

    /// Resolves `value` against `models`.
    ///
    /// `value` is either `client` or `client:model`:
    ///
    /// * `client:model` whose full id matches an entry: a clone of that entry.
    /// * `client:model` with no exact match: a clone of the first entry of
    ///   that client with its name replaced by `model`; every other attribute
    ///   is inherited from that entry.
    /// * `client` (or `client:` with an empty model part): a clone of the
    ///   first entry of that client.
    ///
    /// Returns `None` when no entry belongs to the requested client.
    pub fn find(models: &[&Self], value: &str) -> Option<Self> {
        let (client_name, model_name) = match value.split_once(':') {
            Some((client_name, model_name)) if !model_name.is_empty() => {
                (client_name, Some(model_name))
            }
            // An empty model part is treated like a bare client name.
            Some((client_name, _)) => (client_name, None),
            None => (value, None),
        };

        // An exact id match is only attempted when a model name was given.
        if model_name.is_some() {
            if let Some(found) = models.iter().find(|v| v.id() == value) {
                return Some((**found).clone());
            }
        }

        // Otherwise fall back to the first model of the requested client.
        let fallback = models.iter().find(|v| v.client_name == client_name)?;
        let mut model = (**fallback).clone();
        if let Some(model_name) = model_name {
            model.data.name = model_name.to_string();
        }
        Some(model)
    }

    /// Returns the model id in the form `client_name:model_name`.
    pub fn id(&self) -> String {
        format!("{}:{}", self.client_name, self.data.name)
    }

    /// Returns the name of the client this model belongs to.
    pub fn client_name(&self) -> &str {
        &self.client_name
    }

    /// Returns the model name (without the client prefix).
    pub fn name(&self) -> &str {
        &self.data.name
    }

    /// Returns a shared reference to the underlying model data.
    pub fn data(&self) -> &ModelData {
        &self.data
    }

    /// Returns a mutable reference to the underlying model data.
    pub fn data_mut(&mut self) -> &mut ModelData {
        &mut self.data
    }

    /// Formats a one-line, column-aligned summary of the model:
    /// `max input / max output tokens | input price / output price`,
    /// followed by one icon per supported capability.
    pub fn description(&self) -> String {
        let ModelData {
            max_input_tokens,
            max_output_tokens,
            input_price,
            output_price,
            supports_vision,
            supports_function_calling,
            ..
        } = &self.data;
        let max_input_tokens = format_option_value(max_input_tokens);
        let max_output_tokens = format_option_value(max_output_tokens);
        let input_price = format_option_value(input_price);
        let output_price = format_option_value(output_price);

        // Each enabled capability is rendered as its icon followed by one space.
        // NOTE(review): the function-calling icon was an empty (non-compiling)
        // char literal; U+2692 (hammer and pick) is assumed — confirm against upstream.
        let icons = [
            (*supports_vision, '👁'),
            (*supports_function_calling, '\u{2692}'),
        ];
        let mut capabilities = String::new();
        for (enabled, icon) in icons {
            if enabled {
                capabilities.push(icon);
                capabilities.push(' ');
            }
        }

        format!(
            "{:>8} / {:>8}  |  {:>6} / {:>6}  {:>6}",
            max_input_tokens, max_output_tokens, input_price, output_price, capabilities
        )
    }

    /// Returns the configured input-token limit, if any.
    pub fn max_input_tokens(&self) -> Option<usize> {
        self.data.max_input_tokens
    }

    /// Returns the configured output-token limit, if any.
    pub fn max_output_tokens(&self) -> Option<isize> {
        self.data.max_output_tokens
    }

    /// Returns the `supports_vision` flag of the model data.
    pub fn supports_vision(&self) -> bool {
        self.data.supports_vision
    }

    /// Returns the `supports_function_calling` flag of the model data.
    pub fn supports_function_calling(&self) -> bool {
        self.data.supports_function_calling
    }

    /// Returns the output-token limit only when `pass_max_tokens` is set;
    /// otherwise `None`, regardless of the configured limit.
    pub fn max_tokens_param(&self) -> Option<isize> {
        if self.data.pass_max_tokens {
            self.data.max_output_tokens
        } else {
            None
        }
    }

    /// Updates the output-token limit and the `pass_max_tokens` flag.
    ///
    /// A limit of `None` or `Some(0)` clears the stored limit; any other
    /// value is stored as given. Returns `self` to allow chaining.
    pub fn set_max_tokens(
        &mut self,
        max_output_tokens: Option<isize>,
        pass_max_tokens: bool,
    ) -> &mut Self {
        // `Some(0)` is normalised to "no limit".
        self.data.max_output_tokens = max_output_tokens.filter(|&limit| limit != 0);
        self.data.pass_max_tokens = pass_max_tokens;
        self
    }

    /// Sums the estimated token length of every text message.
    ///
    /// Non-text content (`Array`, `ToolResults`) contributes 0 to the estimate.
    pub fn messages_tokens(&self, messages: &[Message]) -> usize {
        messages
            .iter()
            .map(|message| match &message.content {
                MessageContent::Text(text) => estimate_token_length(text),
                MessageContent::Array(_) | MessageContent::ToolResults(_) => 0,
            })
            .sum()
    }

    /// Estimates the total token count of `messages`, including a flat
    /// `PER_MESSAGES_TOKENS` overhead per counted message.
    ///
    /// When the last message is not a user message it is excluded from the
    /// overhead count (its content is still included). Returns 0 for an
    /// empty slice.
    pub fn total_tokens(&self, messages: &[Message]) -> usize {
        let last = match messages.last() {
            Some(last) => last,
            None => return 0,
        };
        let counted_messages = if last.role.is_user() {
            messages.len()
        } else {
            messages.len() - 1
        };
        counted_messages * PER_MESSAGES_TOKENS + self.messages_tokens(messages)
    }

    /// Checks `messages` against the model's input-token limit.
    ///
    /// # Errors
    ///
    /// Fails when a limit is configured and the estimated total plus
    /// `BASIS_TOKENS` reaches or exceeds it. Without a configured limit the
    /// check always succeeds.
    pub fn max_input_tokens_limit(&self, messages: &[Message]) -> Result<()> {
        let total_tokens = self.total_tokens(messages) + BASIS_TOKENS;
        if let Some(max_input_tokens) = self.data.max_input_tokens {
            if total_tokens >= max_input_tokens {
                bail!("Exceed max input tokens limit")
            }
        }
        Ok(())
    }
}

/// Deserializable description of a single model: its name, token limits,
/// prices and capability flags.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct ModelData {
    /// Model name; the part after the colon in `Model::id`.
    pub name: String,
    /// Input-token limit enforced by `Model::max_input_tokens_limit`;
    /// `None` means no limit is enforced.
    pub max_input_tokens: Option<usize>,
    /// Output-token limit; exposed through `Model::max_tokens_param` only
    /// when `pass_max_tokens` is set.
    pub max_output_tokens: Option<isize>,
    /// Whether `Model::max_tokens_param` should return `max_output_tokens`.
    /// Defaults to `false` when absent from the deserialized input.
    #[serde(default)]
    pub pass_max_tokens: bool,
    /// Input price shown in `Model::description`.
    /// NOTE(review): unit/currency is not visible here — confirm.
    pub input_price: Option<f64>,
    /// Output price shown in `Model::description`.
    /// NOTE(review): unit/currency is not visible here — confirm.
    pub output_price: Option<f64>,
    /// Capability flag reported by `Model::supports_vision`.
    /// Defaults to `false` when absent from the deserialized input.
    #[serde(default)]
    pub supports_vision: bool,
    /// Capability flag reported by `Model::supports_function_calling`.
    /// Defaults to `false` when absent from the deserialized input.
    #[serde(default)]
    pub supports_function_calling: bool,
}

impl ModelData {
    /// Creates model data carrying only `name`; every other field keeps its
    /// `Default` value (no limits, no prices, all flags `false`).
    pub fn new(name: &str) -> Self {
        let name = name.to_owned();
        Self {
            name,
            ..Self::default()
        }
    }
}

/// Deserializable group of `ModelData` entries listed under one `platform`.
#[derive(Debug, Clone, Deserialize)]
pub struct BuiltinModels {
    /// Platform identifier for this group.
    /// NOTE(review): presumably matches a client/provider name — confirm against the loader.
    pub platform: String,
    /// Models declared for this platform.
    pub models: Vec<ModelData>,
}