use crate::models::converse::call_converse;
use anyhow::anyhow;
use aws_sdk_bedrockruntime::types::{ContentBlock, ConversationRole, Message};
use dialoguer::Confirm;
use rand::distr::Alphanumeric;
use rand::{rng, Rng};
use crate::utils::print_warning;
use serde::{Deserialize, Serialize};
use std::{
fmt::Display,
fs,
io::{self, Write},
};
use regex::Regex;
use handlebars::{
Handlebars,
};
use convert_case::{Case, Casing};
use colored::*;
use chrono::prelude::*;
use dirs::home_dir;
use crate::constants;
/// Serde-friendly mirror of a Bedrock `Message`, used to persist chat history as JSON.
///
/// Only text content is represented; conversions to/from
/// `aws_sdk_bedrockruntime::types::Message` are provided via `From` impls.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct SerializableMessage {
/// Role name as returned by `ConversationRole::as_str`.
pub role: String,
/// Text segments; each entry becomes one `ContentBlock::Text` when converted back into a `Message`.
pub content: Vec<String>,
}
impl From<Message> for SerializableMessage {
    /// Converts a Bedrock [`Message`] into its serializable form.
    ///
    /// Every `ContentBlock::Text` block is kept, in order; non-text blocks (images,
    /// tool use, reasoning, ...) are skipped. A message with no text blocks yields an
    /// empty `content` vector instead of panicking.
    fn from(message: Message) -> Self {
        SerializableMessage {
            role: message.role().as_str().to_string(),
            content: message
                .content()
                .iter()
                .filter_map(|block| match block {
                    ContentBlock::Text(text) => Some(text.clone()),
                    // `ContentBlock` is non-exhaustive; anything that is not text is dropped.
                    _ => None,
                })
                .collect(),
        }
    }
}
impl From<SerializableMessage> for Message {
fn from(serializable: SerializableMessage) -> Self {
Message::builder()
.role(ConversationRole::from(serializable.role.as_str()))
.set_content(Some(
serializable
.content
.into_iter()
.map(ContentBlock::Text)
.collect(),
))
.build()
.unwrap()
}
}
/// Speaker of a `Conversation` entry.
///
/// `to_str` yields the lowercase name ("user"/"assistant"); `Display` yields the
/// capitalised label ("User"/"Assistant").
#[derive(Debug, Deserialize, Serialize)]
pub enum ConversationEntity {
User,
Assistant,
}
impl ConversationEntity {
pub fn to_str(&self) -> &'static str {
match self {
ConversationEntity::User => "user",
ConversationEntity::Assistant => "assistant",
}
}
}
impl Display for ConversationEntity {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ConversationEntity::User => write!(f, "User"),
ConversationEntity::Assistant => write!(f, "Assistant"),
}
}
}
/// A single speaker/content pair; displays as `Role: content`.
#[derive(Debug, Deserialize, Serialize)]
pub struct Conversation {
/// Who produced this entry.
pub role: ConversationEntity,
/// The entry's text.
pub content: String,
}
impl Conversation {
pub fn new(role: ConversationEntity, content: String) -> Conversation {
Conversation { role, content }
}
}
impl Display for Conversation {
    /// Formats the entry as `<Role>: <content>`, e.g. `User: hello`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (speaker, text) = (&self.role, &self.content);
        write!(f, "{speaker}: {text}")
    }
}
/// Minimal wrapper around a text payload.
/// NOTE(review): not referenced anywhere else in this file — confirm it is still used elsewhere.
#[derive(Debug, Deserialize, Serialize)]
pub struct Content {
pub text: String,
}
/// A chat session that can be persisted as JSON by `save_chat_history`
/// and read back by `load_chat_history`.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ConversationHistory {
/// Human-readable title; filled with a model-generated title when the chat is first saved.
pub title: Option<String>,
/// Name of the JSON file under the chats directory, once the chat has been saved.
pub filename: Option<String>,
/// Model-generated summary, regenerated on every save.
pub summary: Option<String>,
/// The message log; `None` for an empty/cleared history.
pub messages: Option<Vec<SerializableMessage>>,
/// Local time at which this record was created, as a display string.
pub timestamp: String,
}
impl ConversationHistory {
pub fn new(
title: Option<String>,
filename: Option<String>,
summary: Option<String>,
messages: Option<Vec<SerializableMessage>>,
) -> ConversationHistory {
let local = Local::now().format("%Y-%m-%d %H:%M"); ConversationHistory {
title,
filename,
summary,
messages,
timestamp: local.to_string(),
}
}
pub fn to_messages_string(&self) -> String {
match &self.messages {
Some(messages) => messages
.iter()
.map(|msg| format!("{}:{}", msg.role, msg.content.join("\n")))
.collect::<Vec<String>>()
.join("\n\n"),
None => String::new(),
}
}
pub fn save_as_html(&self) -> Result<(), anyhow::Error> {
let mut handlebars = Handlebars::new();
handlebars.register_helper(
"nl2br_with_code",
Box::new(
|h: &handlebars::Helper,
_: &handlebars::Handlebars,
_: &handlebars::Context,
_: &mut handlebars::RenderContext,
out: &mut dyn handlebars::Output| {
if let Some(value) = h.param(0) {
let text = if value.value().is_array() {
value
.value()
.as_array()
.unwrap()
.iter()
.filter_map(|v| v.as_str())
.collect::<Vec<_>>()
.join("\n")
} else {
value.value().as_str().unwrap_or("").to_string()
};
let (p1, p2) = ("<bedrust_be", "gin_source>");
let (p3, p4) = ("</bedrust_en", "d_source>");
let pattern = format!(r"{}{}\s*[\s\S]*?\s*{}{}",
regex::escape(p1),
regex::escape(p2),
regex::escape(p3),
regex::escape(p4)
);
let source_code_regex = Regex::new(&pattern).unwrap();
let text_without_source = source_code_regex.replace_all(&text,
r#"<div class="source-removed"><div class="source-removed-content">ℹ️ <span>The source code has been removed from the export</span></div></div>"#
);
if text_without_source.starts_with("<pre><code") &&
text_without_source.ends_with("</code></pre>") {
out.write(&text_without_source)?;
return Ok(());
}
let mut last_pos = 0;
let mut result = String::new();
let code_block_regex = Regex::new(r"```(\w*)\n([\s\S]*?)\n```").unwrap();
let mut positions = Vec::new();
for cap in code_block_regex.captures_iter(&text_without_source) {
let start = cap.get(0).unwrap().start();
let end = cap.get(0).unwrap().end();
positions.push((
start,
end,
cap.get(1).unwrap().as_str(),
cap.get(2).unwrap().as_str(),
));
}
let inline_code_regex = Regex::new(r"`([^`]+)`").unwrap();
for (start, end, lang, code) in positions {
let before_text = &text_without_source[last_pos..start];
let processed_before =
process_inline_code(before_text, &inline_code_regex);
result.push_str(&processed_before.replace("\n", "<br>"));
result.push_str(&format!(
r#"<pre><code class="language-{}">{}</code></pre>"#,
if lang.is_empty() { "plaintext" } else { lang },
html_escape::encode_text(code)
));
last_pos = end;
}
if last_pos < text_without_source.len() {
let remaining = &text_without_source[last_pos..];
let processed_remaining =
process_inline_code(remaining, &inline_code_regex);
result.push_str(&processed_remaining.replace("\n", "<br>"));
}
out.write(&result)?;
}
Ok(())
},
),
);
handlebars.register_helper(
"format_title",
Box::new(
|h: &handlebars::Helper,
_: &handlebars::Handlebars,
_: &handlebars::Context,
_: &mut handlebars::RenderContext,
out: &mut dyn handlebars::Output| {
if let Some(value) = h.param(0) {
if let Some(text) = value.value().as_str() {
let formatted = text.to_case(Case::Title);
out.write(&formatted)?;
}
}
Ok(())
},
),
);
match handlebars.register_template_string("chat_export", crate::constants::HTML_TW_TEMPLATE)
{
Ok(_) => {
match handlebars.render("chat_export", &self) {
Ok(render) => {
std::fs::write("conversation.html", render)?;
println!("Succesfully saved the conversation to conversation.html");
}
Err(e) => eprintln!(
"Error: Something went wrong with rendering the HTML template: {}",
e
),
};
}
Err(e) => eprintln!(
"Error: Something went wrong with Registering the template: {}",
e
),
};
Ok(())
}
pub fn clear(&self) -> Self {
let local: DateTime<Local> = Local::now(); ConversationHistory {
title: None,
filename: None,
summary: None,
messages: None,
timestamp: local.to_string(),
}
}
async fn generate_title(
&self,
client: &aws_sdk_bedrockruntime::Client,
) -> Result<String, anyhow::Error> {
let messages_str = &self.to_messages_string();
let query = constants::CONVERSATION_TITLE_PROMPT.replace("{}", messages_str);
let model_id = constants::CONVERSATION_HISTORY_MODEL_ID;
let content = ContentBlock::Text(query);
println!("⏳ | Generating a new file name for this conversation... ");
let max_retries = 3;
let mut retry_count = 0;
while retry_count < max_retries {
match call_converse(
client,
model_id.to_string(),
constants::CONVERSATION_HISTORY_TITLE_INF_PARAMS.clone(),
content.clone(),
None,
false,
)
.await
{
Ok(response) => {
println!("✅ | Done ");
return Ok(response);
}
Err(e) => {
println!("🔴 | Error: {}", e);
retry_count += 1;
}
}
if retry_count >= max_retries {
return Err(anyhow!(
"Failed to get a response after {} retries",
max_retries
));
}
tokio::time::sleep(std::time::Duration::from_secs(2u64.pow(retry_count))).await;
}
Err(anyhow!("Unexpected error in generate_title"))
}
async fn generate_summary(
&self,
client: &aws_sdk_bedrockruntime::Client,
) -> Result<String, anyhow::Error> {
let messages_str = &self.to_messages_string();
let query = constants::CONVERSATION_SUMMARY_PROMPT.replace("{}", messages_str);
let model_id = constants::CONVERSATION_HISTORY_MODEL_ID;
let content = ContentBlock::Text(query);
println!("⏳ | Generating a summary for this conversation... ");
println!();
let max_retries = 3;
let mut retry_count = 0;
while retry_count < max_retries {
match call_converse(
client,
model_id.to_string(),
constants::CONVERSATION_HISTORY_INF_PARAMS.clone(),
content.clone(),
None,
false,
)
.await
{
Ok(response) => return Ok(response),
Err(e) => {
println!("🔴 | Error: {}", e);
retry_count += 1;
}
}
if retry_count >= max_retries {
return Err(anyhow!(
"Failed to get a response after {} retries",
max_retries
));
}
tokio::time::sleep(std::time::Duration::from_secs(2u64.pow(retry_count))).await;
}
Err(anyhow!("Unexpected error in generate_summary"))
}
}
pub async fn save_chat_history(
filename: Option<&str>,
client: &aws_sdk_bedrockruntime::Client,
ch: &mut ConversationHistory,
) -> Result<String, anyhow::Error> {
let home_dir = home_dir().expect("Failed to get HOME directory");
let save_dir = home_dir.join(format!(".config/{}/chats", constants::CONFIG_DIR_NAME));
fs::create_dir_all(&save_dir)?;
ch.summary = Some(ch.generate_summary(client).await?);
let (filename, file_path) = if let Some(existing_filename) = filename {
(
existing_filename.to_string(),
save_dir.join(existing_filename),
)
} else {
let title = ch.generate_title(client).await?;
let random_string: String = rng()
.sample_iter(Alphanumeric) .take(5)
.map(char::from) .collect();
let new_filename = format!("{}-{}.json", title, random_string);
ch.title = Some(title.clone());
ch.filename = Some(new_filename.clone());
(new_filename.clone(), save_dir.join(&new_filename))
};
fs::write(&file_path, serde_json::to_string_pretty(&ch)?)?;
Ok(filename)
}
pub fn load_chat_history(
filename: &str,
) -> Result<(Vec<SerializableMessage>, String, String, String), anyhow::Error> {
let home_dir = home_dir().expect("Failed to get HOME directory");
let chat_dir = home_dir.join(format!(".config/{}/chats", constants::CONFIG_DIR_NAME));
let file_path = chat_dir.join(filename);
let content = fs::read_to_string(file_path)?;
let ch = serde_json::from_str::<ConversationHistory>(content.as_str())?;
Ok((
ch.messages.unwrap(), filename.to_string(),
ch.title.expect("NO_TITLE").to_string(),
ch.summary.expect("NO_SUMMARY"),
))
}
/// Interactively prints a conversation's transcript to the terminal.
///
/// Asks for confirmation first. Transcripts longer than 1000 characters trigger a
/// second y/n prompt; any answer other than "y"/"Y" prints only the first 1000
/// characters followed by a truncation marker.
pub fn print_conversation_history(history: &ConversationHistory) {
    const MAX_CHARACTERS_WITHOUT_PROMPT: usize = 1000;
    print_warning("----------------------------------------");
    // If the prompt cannot be shown (e.g. stdin is not a terminal), behave as if the
    // user declined instead of panicking.
    let confirmation = Confirm::new()
        .with_prompt("Do you want to print the conversation history?")
        .interact()
        .unwrap_or(false);
    if !confirmation {
        return;
    }
    let history = history.to_messages_string();
    print_warning("----------------------------------------");
    println!("Conversation history: ");
    // Count characters, not bytes: the limit and the message are in "characters", and
    // slicing at a raw byte offset can split a multi-byte character and panic.
    let char_count = history.chars().count();
    if char_count <= MAX_CHARACTERS_WITHOUT_PROMPT {
        println!("{}", history.yellow());
        return;
    }
    println!(
        "This conversation history is very long ({} characters).",
        char_count
    );
    print!("Do you want to display the entire history? (y/n): ");
    // A failed flush only affects whether the prompt text is visible.
    let _ = io::stdout().flush();
    let mut user_input = String::new();
    // An unreadable answer is treated like "no" (truncated output).
    let show_all = io::stdin().read_line(&mut user_input).is_ok()
        && user_input.trim().to_lowercase() == "y";
    if show_all {
        println!("{}", history.yellow());
    } else {
        println!(
            "Displaying first {} characters:",
            MAX_CHARACTERS_WITHOUT_PROMPT
        );
        // Byte offset of the first character past the limit (always a char boundary).
        let cut = history
            .char_indices()
            .nth(MAX_CHARACTERS_WITHOUT_PROMPT)
            .map_or(history.len(), |(idx, _)| idx);
        println!("{}", history[..cut].yellow());
        println!("... (truncated)");
    }
}
pub fn list_chat_histories() -> Result<Vec<String>, anyhow::Error> {
let home_dir = home_dir().expect("Failed to get HOME directory");
let chat_dir = home_dir.join(format!(".config/{}/chats", constants::CONFIG_DIR_NAME));
let mut chat_files = Vec::new();
for entry in fs::read_dir(chat_dir)? {
let entry = entry?;
let path = entry.path();
if path.is_file() && path.extension().and_then(|s| s.to_str()) == Some("json") {
if let Some(filename) = path.file_name().and_then(|s| s.to_str()) {
chat_files.push(filename.to_string());
}
}
}
chat_files.sort_by(|a, b| b.cmp(a)); Ok(chat_files)
}
/// Replaces every match of `regex` in `text` with an HTML `<code>` element holding the
/// HTML-escaped contents of capture group 1; text between matches is copied unchanged.
///
/// Expects `regex` to have at least one capture group.
fn process_inline_code(text: &str, regex: &Regex) -> String {
    let mut output = String::with_capacity(text.len());
    let mut cursor = 0;
    for caps in regex.captures_iter(text) {
        let (whole, inner) = (caps.get(0).unwrap(), caps.get(1).unwrap());
        output.push_str(&text[cursor..whole.start()]);
        let escaped = html_escape::encode_text(inner.as_str());
        output.push_str(r#"<code class="language-plaintext inline-code px-1 py-0.5 rounded bg-gray-100 text-sm font-mono">"#);
        output.push_str(&escaped);
        output.push_str("</code>");
        cursor = whole.end();
    }
    output.push_str(&text[cursor..]);
    output
}