aichat 0.18.0

All-in-one AI CLI Tool
use super::{role::Role, session::Session, GlobalConfig};

use crate::client::{
    init_client, list_models, Client, CompletionData, ImageUrl, Message, MessageContent,
    MessageContentPart, MessageRole, Model,
};
use crate::function::{ToolCallResult, ToolResults};
use crate::utils::{base64_encode, sha256};

use anyhow::{bail, Context, Result};
use fancy_regex::Regex;
use lazy_static::lazy_static;
use mime_guess::from_path;
use std::{
    collections::HashMap,
    fs::File,
    io::Read,
    path::{Path, PathBuf},
};
use unicode_width::{UnicodeWidthChar, UnicodeWidthStr};

/// File extensions (lowercase) that `Input::new` treats as images: such files are
/// attached as media rather than read as text. `is_image_ext` lowercases the
/// path's extension before comparing against this list.
const IMAGE_EXTS: [&str; 5] = ["png", "jpeg", "jpg", "webp", "gif"];

lazy_static! {
    /// Matches strings that start with a URL-like scheme: two or more characters
    /// from `[A-Za-z0-9_-]` followed by `:/` (e.g. `https://...`). Because at least
    /// two characters are required, a single-letter drive prefix such as `C:/`
    /// does not match. `resolve_local_file` uses this to decide that an entry is
    /// not a local path.
    static ref URL_RE: Regex = Regex::new(r"^[A-Za-z0-9_-]{2,}:/").unwrap();
}

/// A single request to send to a model: prompt text, attached images, any
/// pending tool-call results, and the role/session context it is evaluated under.
#[derive(Debug, Clone)]
pub struct Input {
    /// Handle to the global configuration (accessed through `.read()`).
    config: GlobalConfig,
    /// Prompt text, including the contents of any attached text files.
    text: String,
    /// Attached images: base64 `data:` URLs for local files, or the original
    /// string for entries that did not resolve to a local path.
    medias: Vec<String>,
    /// Maps `sha256(data_url)` to the local path the data URL was read from, so
    /// the path can be displayed instead of the encoded payload.
    data_urls: HashMap<String, String>,
    /// Tool-call results accumulated by `merge_tool_call`, together with the most
    /// recent `output` string passed to it.
    tool_call: Option<ToolResults>,
    /// Role and session flag this input is evaluated under.
    context: InputContext,
}

impl Input {
    /// Creates an input consisting only of `text`, with no attached media.
    ///
    /// When `context` is `None`, the role/session context is snapshotted from
    /// `config` via [`InputContext::from_config`].
    pub fn from_str(config: &GlobalConfig, text: &str, context: Option<InputContext>) -> Self {
        Self {
            config: config.clone(),
            text: text.to_string(),
            medias: Default::default(),
            data_urls: Default::default(),
            tool_call: None,
            context: context.unwrap_or_else(|| InputContext::from_config(config)),
        }
    }

    /// Creates an input from `text` plus a list of attached `files`.
    ///
    /// Each entry is classified by extension (see `is_image_ext`):
    /// - local images are read and embedded as base64 `data:` URLs;
    /// - local non-images are read as text and appended after `text`, joined by
    ///   newlines. When more than one text file is attached, each one is
    ///   prefixed with its name and fenced with `~~~~~~` so they can be told apart;
    /// - entries that `resolve_local_file` does not resolve (e.g. URL-like
    ///   strings) are passed through unchanged if they are images.
    ///
    /// When `context` is `None`, it is snapshotted from `config`.
    ///
    /// # Errors
    ///
    /// Fails if a local file cannot be read, or if a non-image entry does not
    /// resolve to a local path (remote text files are not supported).
    pub fn new(
        config: &GlobalConfig,
        text: &str,
        files: Vec<String>,
        context: Option<InputContext>,
    ) -> Result<Self> {
        let mut texts = vec![text.to_string()];
        let mut medias = vec![];
        let mut data_urls = HashMap::new();
        let files: Vec<_> = files
            .iter()
            .map(|f| (f, is_image_ext(Path::new(f))))
            .collect();
        // A single text file is inserted verbatim; only label contents with the
        // file name when several text files would otherwise run together.
        let include_filepath = files.iter().filter(|(_, is_image)| !*is_image).count() > 1;
        for (file_item, is_image) in files {
            match (resolve_local_file(file_item), is_image) {
                (Some(file_path), true) => {
                    let data_url = read_media_to_data_url(&file_path)
                        .with_context(|| format!("Unable to read media file '{file_item}'"))?;
                    // Remember which local path produced this data URL so `render`
                    // can show the path instead of the encoded payload.
                    data_urls.insert(sha256(&data_url), file_path.display().to_string());
                    medias.push(data_url);
                }
                (Some(file_path), false) => {
                    let text = read_file(&file_path)
                        .with_context(|| format!("Unable to read file '{file_item}'"))?;
                    if include_filepath {
                        texts.push(format!("`{file_item}`:\n~~~~~~\n{text}\n~~~~~~"));
                    } else {
                        texts.push(text);
                    }
                }
                (None, true) => medias.push(file_item.to_string()),
                (None, false) => bail!("Unable to use remote file '{file_item}'"),
            }
        }

        Ok(Self {
            config: config.clone(),
            text: texts.join("\n"),
            medias,
            data_urls,
            tool_call: Default::default(),
            context: context.unwrap_or_else(|| InputContext::from_config(config)),
        })
    }

    /// Returns `true` when there is neither text nor any attached media.
    pub fn is_empty(&self) -> bool {
        self.text.is_empty() && self.medias.is_empty()
    }

    /// Returns a copy of the `sha256(data_url) -> local path` map.
    pub fn data_urls(&self) -> HashMap<String, String> {
        self.data_urls.clone()
    }

    /// Returns a copy of the prompt text.
    pub fn text(&self) -> String {
        self.text.clone()
    }

    /// Replaces the prompt text.
    pub fn set_text(&mut self, text: String) {
        self.text = text;
    }

    /// Appends `tool_call_results` to any previously merged results and records
    /// `output` as the latest output, returning the updated input.
    pub fn merge_tool_call(
        mut self,
        output: String,
        tool_call_results: Vec<ToolCallResult>,
    ) -> Self {
        match self.tool_call.as_mut() {
            Some(exist_tool_call_results) => {
                exist_tool_call_results.0.extend(tool_call_results);
                exist_tool_call_results.1 = output;
            }
            None => self.tool_call = Some((tool_call_results, output)),
        }
        self
    }

    /// Returns the model this input should be sent to.
    ///
    /// If the context's role names a `model_id` that differs from the
    /// configured model and that id appears in `list_models`, that model is
    /// returned; otherwise the config's current model is returned.
    pub fn model(&self) -> Model {
        let model = self.config.read().model.clone();
        if let Some(model_id) = self.role().and_then(|v| v.model_id.clone()) {
            if model.id() != model_id {
                if let Some(model) = list_models(&self.config.read())
                    .into_iter()
                    .find(|v| v.id() == model_id)
                {
                    return model.clone();
                }
            }
        };
        model
    }

    /// Creates a client for [`Self::model`] via `init_client`.
    pub fn create_client(&self) -> Result<Box<dyn Client>> {
        init_client(&self.config, Some(self.model()))
    }

    /// Builds the request payload for `model`.
    ///
    /// Sampling settings (`temperature`, `top_p`) and the function matcher come
    /// from the active session if the context uses one, else from the role,
    /// else from the global config (no matcher in that case). Functions are
    /// only included when function calling is enabled in the config and
    /// `model` supports it.
    ///
    /// # Errors
    ///
    /// Fails if media is attached but `model` does not support vision, or
    /// propagates the error from `max_input_tokens_limit` when the messages
    /// exceed `model`'s input limit.
    pub fn prepare_completion_data(&self, model: &Model, stream: bool) -> Result<CompletionData> {
        if !self.medias.is_empty() && !model.supports_vision() {
            bail!("The current model does not support vision.");
        }
        // `build_messages` takes its own read lock, so it runs before ours is held.
        let messages = self.build_messages()?;
        // Check against the model the request will actually be sent to (it may be
        // a role-specific model), not the config's global default.
        model.max_input_tokens_limit(&messages)?;

        // Take the config read lock exactly once. Re-acquiring a read lock while
        // one is still held (as a guard kept alive by an `if let` scrutinee
        // would be) can deadlock on task-fair RwLocks when a writer is waiting.
        let config = self.config.read();
        let session = self.session(&config.session);
        let (temperature, top_p) = match (session, self.role()) {
            (Some(session), _) => (session.temperature(), session.top_p()),
            (None, Some(role)) => (role.temperature, role.top_p),
            (None, None) => (config.temperature, config.top_p),
        };
        let functions = if config.function_calling && model.supports_function_calling() {
            let function_matcher = match (session, self.role()) {
                (Some(session), _) => session.function_matcher(),
                (None, Some(role)) => role.function_matcher.as_deref(),
                (None, None) => None,
            };
            config.function.select(function_matcher)
        } else {
            None
        };
        Ok(CompletionData {
            messages,
            temperature,
            top_p,
            functions,
            stream,
        })
    }

    /// Builds the conversation messages for this input.
    ///
    /// Delegates to the active session (if the context uses one), else to the
    /// role, else produces a single user message. Any pending tool-call results
    /// are appended as an assistant message.
    pub fn build_messages(&self) -> Result<Vec<Message>> {
        let mut messages = if let Some(session) = self.session(&self.config.read().session) {
            session.build_messages(self)
        } else if let Some(role) = self.role() {
            role.build_messages(self)
        } else {
            vec![Message::new(MessageRole::User, self.message_content())]
        };
        if let Some(tool_results) = &self.tool_call {
            messages.push(Message::new(
                MessageRole::Assistant,
                MessageContent::ToolResults(tool_results.clone()),
            ))
        }
        Ok(messages)
    }

    /// Renders the messages for display: delegates to the session, else the
    /// role, else falls back to [`Self::render`].
    pub fn echo_messages(&self) -> String {
        if let Some(session) = self.session(&self.config.read().session) {
            session.echo_messages(self)
        } else if let Some(role) = self.role() {
            role.echo_messages(self)
        } else {
            self.render()
        }
    }

    /// Returns the role from this input's context, if any.
    pub fn role(&self) -> Option<&Role> {
        self.context.role.as_ref()
    }

    /// Returns `session` only when this input's context opts into sessions.
    pub fn session<'a>(&self, session: &'a Option<Session>) -> Option<&'a Session> {
        if self.context.session {
            session.as_ref()
        } else {
            None
        }
    }

    /// Mutable counterpart of [`Self::session`].
    pub fn session_mut<'a>(&self, session: &'a mut Option<Session>) -> Option<&'a mut Session> {
        if self.context.session {
            session.as_mut()
        } else {
            None
        }
    }

    /// Returns a one-line summary of the text.
    ///
    /// The text is trimmed and control characters are replaced by spaces. If
    /// its CJK-aware display width exceeds 70 columns, it is cut to at most 67
    /// columns and suffixed with `...`.
    pub fn summary(&self) -> String {
        let text: String = self
            .text
            .trim()
            .chars()
            .map(|c| if c.is_control() { ' ' } else { c })
            .collect();
        if text.width_cjk() > 70 {
            let mut sum_width = 0;
            let mut chars = vec![];
            for c in text.chars() {
                sum_width += c.width_cjk().unwrap_or(1);
                if sum_width > 67 {
                    chars.extend(['.', '.', '.']);
                    break;
                }
                chars.push(c);
            }
            chars.into_iter().collect()
        } else {
            text
        }
    }

    /// Formats the input as text.
    ///
    /// Without media this is just the text. With media it is
    /// `.file <media...>` followed by ` -- <text>` when the text is non-empty.
    /// Embedded data URLs are shown as the local path they were read from
    /// when known (see `resolve_data_url`).
    pub fn render(&self) -> String {
        if self.medias.is_empty() {
            return self.text.clone();
        }
        let files: Vec<String> = self
            .medias
            .iter()
            .map(|url| resolve_data_url(&self.data_urls, url.clone()))
            .collect();
        let files = files.join(" ");
        if self.text.is_empty() {
            format!(".file {files}")
        } else {
            format!(".file {files} -- {}", self.text)
        }
    }

    /// Converts the input into message content.
    ///
    /// Without media this is plain text. With media it is an array holding the
    /// text part first (only when the text is non-empty) followed by one image
    /// part per attached media entry.
    pub fn message_content(&self) -> MessageContent {
        if self.medias.is_empty() {
            return MessageContent::Text(self.text.clone());
        }
        let mut parts = Vec::with_capacity(self.medias.len() + 1);
        if !self.text.is_empty() {
            parts.push(MessageContentPart::Text {
                text: self.text.clone(),
            });
        }
        parts.extend(self.medias.iter().map(|url| MessageContentPart::ImageUrl {
            image_url: ImageUrl { url: url.clone() },
        }));
        MessageContent::Array(parts)
    }
}

/// The role and session settings an `Input` is evaluated under.
#[derive(Debug, Clone, Default)]
pub struct InputContext {
    /// Role to apply (its model id, sampling settings, function matcher and
    /// message building are consulted by `Input`), if any.
    role: Option<Role>,
    /// Whether the config's active session, when one exists, should be used.
    session: bool,
}

impl InputContext {
    pub fn new(role: Option<Role>, session: bool) -> Self {
        Self { role, session }
    }

    pub fn from_config(config: &GlobalConfig) -> Self {
        let config = config.read();
        InputContext::new(config.role.clone(), config.session.is_some())
    }

    pub fn role(role: Role) -> Self {
        Self {
            role: Some(role),
            session: false,
        }
    }
}

/// Maps an embedded `data:` URL back to the local path it was read from.
///
/// If `data_url` starts with `data:` and its sha256 hash is a key in
/// `data_urls`, the stored path is returned; in every other case `data_url`
/// is returned unchanged.
pub fn resolve_data_url(data_urls: &HashMap<String, String>, data_url: String) -> String {
    if !data_url.starts_with("data:") {
        return data_url;
    }
    match data_urls.get(&sha256(&data_url)) {
        Some(path) => path.to_string(),
        None => data_url,
    }
}

/// Resolves `file` to a local filesystem path.
///
/// Returns `None` when `file` looks like a URL (see `URL_RE`) or when the
/// current directory cannot be determined. A leading `~/` is expanded to the
/// home directory when one is known; anything else is joined onto the current
/// working directory (joining an absolute path yields that path unchanged).
fn resolve_local_file(file: &str) -> Option<PathBuf> {
    let looks_like_url = matches!(URL_RE.is_match(file), Ok(true));
    if looks_like_url {
        return None;
    }
    let home_relative = file.strip_prefix("~/").zip(dirs::home_dir());
    match home_relative {
        Some((rest, home)) => Some(home.join(rest)),
        None => std::env::current_dir().ok().map(|cwd| cwd.join(file)),
    }
}

/// Returns `true` if `path` has an extension listed in `IMAGE_EXTS`,
/// compared after lowercasing; paths without an extension are not images.
fn is_image_ext(path: &Path) -> bool {
    match path.extension() {
        Some(raw_ext) => {
            let ext = raw_ext.to_string_lossy().to_lowercase();
            IMAGE_EXTS.contains(&ext.as_str())
        }
        None => false,
    }
}

fn read_media_to_data_url<P: AsRef<Path>>(image_path: P) -> Result<String> {
    let image_path = image_path.as_ref();

    let mime_type = from_path(image_path).first_or_octet_stream().to_string();
    let mut file = File::open(image_path)?;
    let mut buffer = Vec::new();
    file.read_to_end(&mut buffer)?;

    let encoded_image = base64_encode(buffer);
    let data_url = format!("data:{};base64,{}", mime_type, encoded_image);

    Ok(data_url)
}

/// Reads the whole file at `file_path` as UTF-8 text.
///
/// Fails if the file cannot be opened or its contents are not valid UTF-8.
fn read_file<P: AsRef<Path>>(file_path: P) -> Result<String> {
    let contents = std::fs::read_to_string(file_path.as_ref())?;
    Ok(contents)
}