pub mod config;
pub mod encryption;
pub mod gpt;
use anyhow::Context;
use colored::*;
use config::CacheValue;
use gpt::{GPTClient, GPTQuery, GPTResponse, GPTRole};
use indicatif::{ProgressBar, ProgressStyle};
use rand::Rng;
use std::collections::VecDeque;
use std::sync::Arc;
use std::thread::JoinHandle;
use termimad::crossterm::style::Color::*;
use termimad::{rgb, MadSkin};
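/// Playful status messages shown at random while a query is in flight.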
const LOADING_MESSAGES: [&str; 10] = [
"Consulting neural network...",
"Bribing data set...",
"Thinking hard...",
"Crunching 0s and 1s...",
"Cooking up a response...",
"AI is thinking...",
"Turning AI gears...",
"Consulting the AI crystal...",
"Praying for your answer...",
"Circuits tingling...",
];
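/// Builds a spinner with globe tick frames and a randomly chosen loading message.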
fn create_spinner() -> anyhow::Result<ProgressBar> {
let spinner = ProgressBar::new_spinner();
spinner.set_style(
ProgressStyle::default_spinner()
.tick_chars("🌍🌎🌏")
.template("{msg} {spinner:.green}")
.with_context(|| format!("Failed to set template"))?,
);
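    // Pick a random message so repeated queries don't all look the same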
let mut rng = rand::thread_rng();
let rand_idx = rng.gen_range(0..LOADING_MESSAGES.len());
let rand_msg = LOADING_MESSAGES[rand_idx];
spinner.set_message(rand_msg);
Ok(spinner)
}
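/// Configures the termimad skin used to render GPT's markdown responses.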
fn create_skin() -> MadSkin {
let mut skin = MadSkin::default();
skin.bold.set_fg(Yellow);
skin.italic.set_fgbg(Yellow, rgb(30, 30, 40));
    skin
}
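/// Runs the query on a worker thread (building it from cached conversation
/// context, querying GPT, and updating the cache) while animating a spinner
/// on the main thread, then renders the response to the terminal.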
pub fn run_query(args: Arc<config::QueryArgs>, config: Arc<config::Config>) -> anyhow::Result<()> {
let args_clone = Arc::clone(&args);
let config_clone = Arc::clone(&config);
let spinner = create_spinner()?;
let skin = create_skin();
    // A model passed on the command line overrides the configured default
    let model = args.model.clone().unwrap_or_else(|| config.model.clone());
    // The worker thread only needs its own owned copy; no Arc required
    let model_clone = model.clone();
let handle: JoinHandle<anyhow::Result<(GPTResponse, Vec<CacheValue>)>> =
std::thread::spawn(move || {
let gpt = GPTClient::new(&config_clone.api_key)?;
let mut query_builder = GPTQuery::builder();
query_builder.model(&model_clone);
let cache = config::utils::load_cache()?;
            // A context length from the CLI overrides the configured default
            let context = args_clone.context.unwrap_or(config_clone.context);
            // Clamp so we never slice past the start of the cache
            let start = cache.len().saturating_sub(context);
let context_messages = Vec::from(&cache[start..]);
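            // Replay each cached exchange so the model sees recent conversation history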
for message in context_messages.iter() {
query_builder.message(GPTRole::User, &message.prompt);
query_builder.message(GPTRole::System, &message.response);
}
query_builder.message(GPTRole::User, &args_clone.query);
let query = query_builder.build()?;
let response = gpt.query(&query)?;
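            // Build the updated cache: the replayed context plus the new exchange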
let mut queue_cache: VecDeque<CacheValue> = VecDeque::from(context_messages.clone());
let cache_value = CacheValue {
prompt: args_clone.query.to_string(),
response: response.choices[0].message.content.to_string(),
};
queue_cache.push_back(cache_value);
            // Evict the oldest entries once the cache exceeds its configured length
            while queue_cache.len() > config_clone.cache_length {
                queue_cache.pop_front();
            }
let updated_cache = Vec::from(queue_cache);
let cache_size = updated_cache.len();
config::utils::save_cache(updated_cache)?;
println!(
"{}",
format!(
"Cache capacity {}/{}",
cache_size, config_clone.cache_length
)
.green()
);
Ok((response, context_messages))
});
    // Animate the spinner from the main thread until the worker finishes
    while !handle.is_finished() {
        spinner.tick();
        std::thread::sleep(std::time::Duration::from_millis(200));
    }
    // Clear the spinner line so it doesn't linger above the output
    spinner.finish_and_clear();
    match handle.join() {
        Ok(Ok((response, context_messages))) => {
if args.show_context {
for message in context_messages {
println!("{}:\n{}", "You said".yellow(), message.prompt);
println!(
"{}:\n{}",
"GPT said".magenta(),
skin.term_text(&message.response)
)
}
println!("{}:\n{}", "You said".yellow(), args.query);
}
println!("");
println!(
"{}",
format!("Response from {}", response.model.magenta()).cyan()
);
println!(
"{}",
skin.term_text(response.choices[0].message.content.as_str())
);
if args.cost {
println!(
"{}: ${:.6}",
"Cost".green(),
response.usage.total_cost(&model)
);
}
}
        Ok(Err(e)) => return Err(e),
        // join() itself only fails if the worker thread panicked
        Err(_) => return Err(anyhow::anyhow!("Query thread panicked")),
}
Ok(())
}