aichat 0.18.0

All-in-one AI CLI Tool
use crate::{
    config::GlobalConfig,
    utils::{
        dimmed_text, get_env_bool, indent_text, run_command, run_command_with_output, warning_text,
    },
};

use anyhow::{anyhow, bail, Context, Result};
use fancy_regex::Regex;
use indexmap::{IndexMap, IndexSet};
use inquire::{validator::Validation, Text};
use is_terminal::IsTerminal;
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::{
    collections::{HashMap, HashSet},
    fs,
    io::stdout,
    path::Path,
    sync::mpsc::channel,
};
use threadpool::ThreadPool;

const BIN_DIR_NAME: &str = "bin";
const DECLARATIONS_FILE_PATH: &str = "functions.json";

lazy_static! {
    static ref THREAD_POOL: ThreadPool = ThreadPool::new(num_cpus::get());
}

pub type ToolResults = (Vec<ToolCallResult>, String);

pub fn eval_tool_calls(
    config: &GlobalConfig,
    mut calls: Vec<ToolCall>,
) -> Result<Vec<ToolCallResult>> {
    let mut output = vec![];
    if calls.is_empty() {
        return Ok(output);
    }
    calls = ToolCall::dedup(calls);
    let parallel = calls.len() > 1 && calls.iter().all(|v| !v.is_execute());
    if parallel {
        let (tx, rx) = channel();
        let calls_len = calls.len();
        for (index, call) in calls.into_iter().enumerate() {
            let tx = tx.clone();
            let config = config.clone();
            THREAD_POOL.execute(move || {
                let result = call.eval(&config);
                let _ = tx.send((index, call, result));
            });
        }
        let mut list: Vec<(usize, ToolCall, Result<Value>)> = rx.iter().take(calls_len).collect();
        list.sort_by_key(|v| v.0);
        for (_, call, result) in list {
            output.push(ToolCallResult::new(call, result?));
        }
    } else {
        for call in calls {
            let result = call.eval(config)?;
            output.push(ToolCallResult::new(call, result));
        }
    }
    Ok(output)
}

pub fn need_send_call_results(arr: &[ToolCallResult]) -> bool {
    arr.iter().any(|v| !v.output.is_null())
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ToolCallResult {
    pub call: ToolCall,
    pub output: Value,
}

impl ToolCallResult {
    pub fn new(call: ToolCall, output: Value) -> Self {
        Self { call, output }
    }
}

#[derive(Debug, Clone, Default)]
pub struct Function {
    names: IndexSet<String>,
    declarations: Vec<FunctionDeclaration>,
    #[cfg(windows)]
    bin_dir: std::path::PathBuf,
    env_path: Option<String>,
}

impl Function {
    pub fn init(functions_dir: &Path) -> Result<Self> {
        let bin_dir = functions_dir.join(BIN_DIR_NAME);
        let env_path = if bin_dir.exists() {
            prepend_env_path(&bin_dir).ok()
        } else {
            None
        };

        let declarations_file = functions_dir.join(DECLARATIONS_FILE_PATH);

        let declarations: Vec<FunctionDeclaration> = if declarations_file.exists() {
            let ctx = || {
                format!(
                    "Failed to load function declarations at {}",
                    declarations_file.display()
                )
            };
            let content = fs::read_to_string(&declarations_file).with_context(ctx)?;
            serde_json::from_str(&content).with_context(ctx)?
        } else {
            vec![]
        };

        let func_names = declarations.iter().map(|v| v.name.clone()).collect();

        Ok(Self {
            names: func_names,
            declarations,
            #[cfg(windows)]
            bin_dir,
            env_path,
        })
    }

    pub fn select(&self, matcher: Option<&str>) -> Option<Vec<FunctionDeclaration>> {
        let matcher = matcher?;
        let regex = Regex::new(&format!("^({matcher})$")).ok()?;
        let output: Vec<FunctionDeclaration> = self
            .declarations
            .iter()
            .filter(|v| regex.is_match(&v.name).unwrap_or_default())
            .cloned()
            .collect();
        if output.is_empty() {
            None
        } else {
            Some(output)
        }
    }
}

#[derive(Debug, Clone, Deserialize)]
pub struct FunctionConfig {
    pub enable: bool,
    pub declarations_file: String,
    pub functions_dir: String,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDeclaration {
    pub name: String,
    pub description: String,
    pub parameters: JsonSchema,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonSchema {
    #[serde(rename = "type")]
    pub type_value: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub properties: Option<IndexMap<String, JsonSchema>>,
    #[serde(rename = "enum", skip_serializing_if = "Option::is_none")]
    pub enum_value: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub required: Option<Vec<String>>,
}

#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct ToolCall {
    pub name: String,
    pub arguments: Value,
    pub id: Option<String>,
}

impl ToolCall {
    pub fn dedup(calls: Vec<Self>) -> Vec<Self> {
        let mut new_calls = vec![];
        let mut seen_ids = HashSet::new();

        for call in calls.into_iter().rev() {
            if let Some(id) = &call.id {
                if !seen_ids.contains(id) {
                    seen_ids.insert(id.clone());
                    new_calls.push(call);
                }
            } else {
                new_calls.push(call);
            }
        }

        new_calls.reverse();
        new_calls
    }

    pub fn new(name: String, arguments: Value, id: Option<String>) -> Self {
        Self {
            name,
            arguments,
            id,
        }
    }

    pub fn eval(&self, config: &GlobalConfig) -> Result<Value> {
        let name = self.name.clone();
        if !config.read().function.names.contains(&name) {
            bail!("Unexpected call: {name} {}", self.arguments);
        }
        let arguments = if self.arguments.is_object() {
            self.arguments.clone()
        } else if let Some(arguments) = self.arguments.as_str() {
            let args: Value = serde_json::from_str(arguments)
                .map_err(|_| anyhow!("The {name} call has invalid arguments: {arguments}"))?;
            args
        } else {
            bail!("The {name} call has invalid arguments: {}", self.arguments);
        };

        let arguments = arguments.to_string();
        let prompt = format!("Call {name} '{arguments}'",);

        let mut envs = HashMap::new();
        if let Some(env_path) = config.read().function.env_path.clone() {
            envs.insert("PATH".into(), env_path);
        };
        #[cfg(windows)]
        let name = polyfill_cmd_name(&name, &config.read().function.bin_dir);

        let output = if self.is_execute() {
            if stdout().is_terminal() {
                println!("{prompt}");
                let answer = Text::new("[1] Run, [2] Run & Retrieve, [3] Skip:")
                    .with_default("1")
                    .with_validator(|input: &str| match matches!(input, "1" | "2" | "3") {
                        true => Ok(Validation::Valid),
                        false => Ok(Validation::Invalid(
                            "Invalid input, please select 1, 2 or 3".into(),
                        )),
                    })
                    .prompt()?;
                match answer.as_str() {
                    "1" => {
                        let exit_code = run_command(&name, &[arguments], Some(envs))?;
                        if exit_code != 0 {
                            bail!("Exit {exit_code}");
                        }
                        Value::Null
                    }
                    "2" => run_and_retrieve(&name, &arguments, envs, &prompt)?,
                    _ => Value::Null,
                }
            } else {
                println!("Skipped {prompt}");
                Value::Null
            }
        } else {
            println!("{}", dimmed_text(&prompt));
            run_and_retrieve(&name, &arguments, envs, &prompt)?
        };

        Ok(output)
    }

    pub fn is_execute(&self) -> bool {
        if get_env_bool("function_auto_execute") {
            false
        } else {
            self.name.starts_with("may_") || self.name.contains("__may_")
        }
    }
}

fn run_and_retrieve(
    name: &str,
    arguments: &str,
    envs: HashMap<String, String>,
    prompt: &str,
) -> Result<Value> {
    let (success, stdout, stderr) = run_command_with_output(name, &[arguments], Some(envs))?;

    if success {
        if !stderr.is_empty() {
            eprintln!(
                "{}",
                warning_text(&format!("{prompt}:\n{}", indent_text(&stderr, 4)))
            );
        }
        let value = if !stdout.is_empty() {
            serde_json::from_str(&stdout)
                .ok()
                .unwrap_or_else(|| json!({"output": stdout}))
        } else {
            Value::Null
        };
        Ok(value)
    } else {
        let err = if stderr.is_empty() {
            if stdout.is_empty() {
                "Something wrong"
            } else {
                &stdout
            }
        } else {
            &stderr
        };
        bail!("{}", &format!("{prompt}:\n{}", indent_text(err, 4)));
    }
}

fn prepend_env_path(bin_dir: &Path) -> Result<String> {
    let current_path = std::env::var("PATH").context("No PATH environment variable")?;

    let new_path = if cfg!(target_os = "windows") {
        format!("{};{}", bin_dir.display(), current_path)
    } else {
        format!("{}:{}", bin_dir.display(), current_path)
    };
    Ok(new_path)
}

#[cfg(windows)]
fn polyfill_cmd_name(name: &str, bin_dir: &std::path::Path) -> String {
    let mut name = name.to_string();
    if let Ok(exts) = std::env::var("PATHEXT") {
        if let Some(cmd_path) = exts
            .split(';')
            .map(|ext| bin_dir.join(format!("{}{}", name, ext)))
            .find(|path| path.exists())
        {
            name = cmd_path.display().to_string();
        }
    }
    name
}