use std::io;
use std::io::Write;
use anyhow::Result;
use aws_sdk_bedrockruntime::types::ContentBlock;
use aws_sdk_bedrockruntime::types::ConversationRole;
use aws_sdk_bedrockruntime::types::InferenceConfiguration;
use aws_sdk_bedrockruntime::types::Message;
use bedrust::config;
use bedrust::utils;
use colored::*;
use dialoguer::{theme::ColorfulTheme, FuzzySelect};
use bedrust::configure_aws;
use bedrust::utils::prompt_for_model_selection;
use bedrust::captioner::caption_process;
use bedrust::chat::{
list_chat_histories, load_chat_history, print_conversation_history, save_chat_history,
ConversationHistory,
};
use bedrust::utils::{check_for_config, print_warning};
use clap::Parser;
use bedrust::code::code_chat_process;
use bedrust::models::converse_stream::call_converse_stream;
#[tokio::main]
async fn main() -> Result<()> {
let arguments = utils::Args::parse();
if arguments.init {
config::prompt_init_config()?;
}
if !check_for_config()? {
print_warning("****************************************");
print_warning("WARNING:");
println!("Your Bedrust configuration files are not set up correctly.");
println!("To use Bedrust you need the appropriate `bedrust_config.ron and `model_config.ron` in your $HOME/.config/bedrust/ directory.");
println!("You can configure the application by running `bedrust --init`");
print_warning("****************************************");
print_warning("Bedrust will now exit");
std::process::exit(1);
}
let bedrust_config = utils::load_bedrust_config()?;
let config = configure_aws(String::from("us-east-1"), &bedrust_config.aws_profile).await;
let bedrock_runtime_client = aws_sdk_bedrockruntime::Client::new(&config);
let bedrock_client = aws_sdk_bedrock::Client::new(&config);
let model_id = arguments.model_id.or(bedrust_config.default_model);
let model_id = match model_id {
Some(model_id) => model_id,
None => prompt_for_model_selection()?,
}
.to_str();
let inference_parameters = InferenceConfiguration::builder()
.max_tokens(bedrust_config.inference_params.max_tokens)
.top_p(bedrust_config.inference_params.top_p)
.temperature(bedrust_config.inference_params.temperature)
.build();
let system_prompt = bedrust_config.system_prompt.as_deref().unwrap_or("You are a helpful assistant");
if arguments.caption.is_some() {
caption_process(
model_id,
&bedrock_client,
&bedrock_runtime_client,
arguments.caption,
&bedrust_config,
arguments.xml,
)
.await?;
} else {
utils::hello_header("Bedrust")?;
let mut message_count = 0;
let mut conversation_history = ConversationHistory::new(None, None, None, None);
let mut current_file: Option<String> = None;
let code: Option<String> = match arguments.source {
Some(ref source_path) => {
Some(code_chat_process(source_path.to_path_buf(), &bedrock_runtime_client).await?)
}
None => None,
};
loop {
println!("----------------------------------------");
println!("🤖 | What would you like to know today?");
print!("😎 | Human: ");
io::stdout().flush()?;
let mut question = String::new();
io::stdin().read_line(&mut question)?;
message_count += 1;
let question = question.trim();
if question.is_empty() {
println!("Please enter a question.");
continue;
}
if question == "/q" {
println!("Bye!");
break;
} else if question == "/h" {
conversation_history.save_as_html()?;
continue;
} else if question == "/c" {
println!("Clearing current chat history");
conversation_history.clear();
current_file = None;
continue;
} else if question == "/s" {
let filename = if let Some(ref file) = current_file {
save_chat_history(
Some(file),
&bedrock_runtime_client,
&mut conversation_history,
)
.await?
} else {
match save_chat_history(
None,
&bedrock_runtime_client,
&mut conversation_history,
)
.await
{
Ok(name) => {
current_file = Some(name.clone());
name
}
Err(e) => {
eprintln!("Error saving chat history: {}", e);
continue;
}
}
};
println!("Chat history saved to: {}", filename.cyan());
continue;
} else if question == "/r" {
match list_chat_histories() {
Ok(histories) => {
if histories.is_empty() {
println!("No chat histories found.");
continue;
}
let selection = FuzzySelect::with_theme(&ColorfulTheme::default())
.with_prompt("Select a chat history to recall:")
.default(0)
.items(&histories[..])
.interact()
.unwrap();
let selected_history = &histories[selection];
match load_chat_history(selected_history) {
Ok((content, filename, existing_title, summary)) => {
conversation_history.messages = Some(content);
conversation_history.title = Some(existing_title.clone());
conversation_history.summary = Some(summary.clone());
current_file = Some(filename);
utils::print_warning("----------------------------------------");
println!("Loaded chat history from: {}", selected_history.yellow());
println!();
println!("Loaded chat summary: ");
println!("{}", summary);
print_conversation_history(&conversation_history);
println!("You can now continue the conversation.");
}
Err(e) => eprintln!("Error loading chat history: {}", e),
}
}
Err(e) => eprintln!("Error listing chat histories: {}", e),
}
continue;
} else if question.starts_with('/') {
utils::print_warning("Special command detected: /");
utils::print_warning("----------------------------------------");
utils::print_warning("Currently supported chat commands: ");
utils::print_warning("/c\t \t - Clear current chat history");
utils::print_warning("/s\t \t - (BETA) Save chat history");
utils::print_warning("/r\t \t - (BETA) Recall and load a chat history");
utils::print_warning(
"/h\t \t - (BETA) Export history as HTML(saves in current dir)",
);
utils::print_warning("/q\t \t - Quit");
continue;
}
let message = if arguments.source.is_some() && message_count == 1 {
let question_with_code = code
.as_ref()
.map(|src_code| format!("{}\n<question>{}</question>", src_code, question))
.unwrap_or_else(|| question.to_string());
Message::builder()
.set_role(Some(ConversationRole::User))
.set_content(Some(vec![ContentBlock::Text(question_with_code)]))
.build()?
} else {
Message::builder()
.set_role(Some(ConversationRole::User))
.set_content(Some(vec![ContentBlock::Text(question.to_string())]))
.build()?
};
let mut messages = conversation_history.messages.unwrap_or_default().clone();
messages.push(message.into());
conversation_history.messages = Some(messages);
println!("----------------------------------------");
println!("☎️ | Calling Model: {}", &model_id);
println!("----------------------------------------");
let streamresp = call_converse_stream(
&bedrock_runtime_client,
model_id.to_string(),
&conversation_history,
inference_parameters.clone(),
system_prompt
)
.await?;
let message = Message::builder()
.set_role(Some(ConversationRole::Assistant))
.set_content(Some(vec![ContentBlock::Text(streamresp.to_string())]))
.build()?;
let mut messages = conversation_history.messages.unwrap();
messages.push(message.into());
conversation_history.messages = Some(messages);
}
}
Ok(())
}