use async_trait::async_trait;
use reqwest::{header::{AUTHORIZATION, CONTENT_TYPE}, Response};
use serde::{Deserialize, Serialize};
use serde_json::Value;
pub use secretary_derive::Task;
use crate::{SecretaryError, message::Message};
/// Transport layer for an LLM backend reached over HTTP.
///
/// Implementors supply the endpoint URL, the authorization credentials and the
/// JSON request body. The provided methods then POST that body and return the
/// raw response text:
/// - `send_message` uses a blocking client.
/// - `async_send_message` uses an async client.
///
/// Note: neither provided method inspects the HTTP status code, so a non-2xx
/// response is still returned as `Ok(body)`.
#[async_trait]
pub trait IsLLM {
    /// POSTs `message` using `reqwest`'s blocking client and returns the response body.
    ///
    /// Only the second element of [`IsLLM::get_authorization_credentials`] is
    /// sent, as the `Authorization` header value. Per reqwest's documentation,
    /// the blocking client must not be used from inside an async runtime.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be sent or the body cannot be read.
    fn send_message(
        &self,
        message: Message,
        return_json: bool,
    ) -> Result<String, Box<dyn std::error::Error + Send + Sync + 'static>> {
        // The first tuple element is not used by this method.
        let (_, authorization) = self.get_authorization_credentials();
        let url: String = self.get_chat_completion_request_url();
        let body: Value = self.get_reqeust_body(message, return_json);
        let response: reqwest::blocking::Response = reqwest::blocking::Client::new()
            .post(url)
            .header(AUTHORIZATION, authorization)
            .header(CONTENT_TYPE, "application/json")
            .json(&body)
            .send()?;
        let text: String = response.text()?;
        Ok(text)
    }

    /// Async counterpart of [`IsLLM::send_message`]: POSTs `message` and returns the body text.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be sent or the body cannot be read.
    async fn async_send_message(
        &self,
        message: Message,
        return_json: bool,
    ) -> Result<String, Box<dyn std::error::Error + Send + Sync + 'static>> {
        // The first tuple element is not used by this method.
        let (_, authorization) = self.get_authorization_credentials();
        let url: String = self.get_chat_completion_request_url();
        let body: Value = self.get_reqeust_body(message, return_json);
        let response: Response = reqwest::Client::new()
            .post(url)
            .header(AUTHORIZATION, authorization)
            .header(CONTENT_TYPE, "application/json")
            .json(&body)
            .send()
            .await?;
        let text: String = response.text().await?;
        Ok(text)
    }

    /// Returns a pair of credential strings.
    ///
    /// The provided methods send only `.1`, as the `Authorization` header value.
    /// NOTE(review): the meaning of `.0` is implementation-defined; confirm.
    fn get_authorization_credentials(&self) -> (String, String);

    /// Builds the JSON request body for `message`.
    ///
    /// `return_json` is forwarded unchanged from the caller. Presumably it
    /// asks the backend for a JSON-formatted reply; confirm per implementation.
    fn get_reqeust_body(&self, message: Message, return_json: bool) -> Value;

    /// Returns the URL the provided methods POST the request to.
    fn get_chat_completion_request_url(&self) -> String;

    /// Returns a model reference string; presumably the model identifier (not
    /// used by the provided methods in this trait).
    fn get_model_ref(&self) -> &str;
}
/// A serializable data shape that an LLM is asked to fill in.
///
/// Implementors must be serializable and deserializable (for any lifetime,
/// i.e. owned deserialization), and must implement `Default`. The
/// `secretary_derive::Task` derive macro re-exported by this module presumably
/// generates this impl; confirm against the macro.
pub trait Task
where
    Self: Serialize + for<'de> Deserialize<'de> + Default,
{
    /// Returns the instructions placed at the start of the generation prompt.
    fn get_system_prompt(&self) -> String;
}
/// Blocking structured-data generation on top of [`IsLLM`].
///
/// Both methods follow the same steps:
/// 1. Build a single `"user"` [`Message`] from the task's system prompt, any
///    additional instructions and the `target` text.
/// 2. Send it with [`IsLLM::send_message`].
/// 3. Read the string at `choices[0].message.content` in the JSON response
///    body, then deserialize it into `T`.
pub trait GenerateData
where
    Self: IsLLM + Sync,
{
    /// Generates a `T` by requesting a JSON reply (`return_json = true`).
    ///
    /// The reply content is parsed directly with `serde_json`.
    ///
    /// # Errors
    ///
    /// Returns an error if any of the following happens:
    /// - the request fails;
    /// - the response body is not valid JSON;
    /// - the body has no string at `choices[0].message.content` (for example,
    ///   an API error payload);
    /// - that content does not deserialize into `T`.
    fn generate_data<T: Task>(
        &self,
        task: &T,
        target: &str,
        additional_instructions: &Vec<String>,
    ) -> Result<T, Box<dyn std::error::Error + Send + Sync + 'static>> {
        let formatted_additional_instructions: String =
            format_additional_instructions(additional_instructions);
        let response: String = self.send_message(
            Message {
                role: "user".to_string(),
                content: format!(
                    "{}{}\nThis is the basis for generating a json:\n{}",
                    task.get_system_prompt(),
                    formatted_additional_instructions,
                    target
                ),
            },
            true,
        )?;
        // Propagate instead of panicking: the body may be an API error page or payload.
        let value: Value = serde_json::from_str(&response)?;
        // Indexing a `Value` yields `Null` for missing keys, so a malformed
        // response surfaces here as `None` rather than a panic.
        let content: &str = value["choices"][0]["message"]["content"]
            .as_str()
            .ok_or_else(|| {
                format!(
                    "LLM response has no string at choices[0].message.content: {}",
                    response
                )
            })?;
        Ok(serde_json::from_str::<T>(content)?)
    }

    /// Generates a `T` without requesting a JSON reply (`return_json = false`).
    ///
    /// The reply content is parsed with `surfing::serde::from_mixed_text`,
    /// presumably so that JSON embedded in surrounding prose can still be
    /// extracted.
    ///
    /// # Errors
    ///
    /// Returns an error if any of the following happens:
    /// - the request fails;
    /// - the response body is not valid JSON;
    /// - the body has no string at `choices[0].message.content`;
    /// - `from_mixed_text` fails to produce a `T`.
    fn force_generate_data<T: Task>(
        &self,
        task: &T,
        target: &str,
        additional_instructions: &Vec<String>,
    ) -> Result<T, Box<dyn std::error::Error + Send + Sync + 'static>> {
        let formatted_additional_instructions: String =
            format_additional_instructions(additional_instructions);
        let response: String = self.send_message(
            Message {
                role: "user".to_string(),
                content: format!(
                    "{}{}\nThis is the basis for generating a json:\n{}",
                    task.get_system_prompt(),
                    formatted_additional_instructions,
                    target
                ),
            },
            false,
        )?;
        // Propagate instead of panicking: the body may be an API error page or payload.
        let value: Value = serde_json::from_str(&response)?;
        let content: &str = value["choices"][0]["message"]["content"]
            .as_str()
            .ok_or_else(|| {
                format!(
                    "LLM response has no string at choices[0].message.content: {}",
                    response
                )
            })?;
        Ok(surfing::serde::from_mixed_text(content)?)
    }
}
#[async_trait]
pub trait AsyncGenerateData
where
Self: IsLLM,
{
async fn async_generate_data<T: Task + Sync + Send>(
&self,
task: &T,
target: &str,
additional_instructions: &Vec<String>,
) -> Result<T, Box<dyn std::error::Error + Send + Sync + 'static>> {
let formatted_additional_instructions: String =
format_additional_instructions(additional_instructions);
let request: Result<String, Box<dyn std::error::Error + Send + Sync>> = self
.async_send_message(
Message {
role: "user".to_string(),
content: format!(
"{}{}\nThis is the basis for generating a json:\n{}",
task.get_system_prompt(),
formatted_additional_instructions,
target
),
},
true,
)
.await;
let result = match request {
Ok(result) => {
dbg!(&result);
let value: Value = serde_json::from_str(&result).unwrap();
value["choices"][0]["message"]["content"]
.as_str()
.unwrap()
.to_string()
}
Err(error) => return Err(SecretaryError::BuildRequestError(error.to_string()).into()),
};
Ok(serde_json::from_str(&result)?)
}
async fn async_force_generate_data<T: Task + Sync + Send>(
&self,
task: &T,
target: &str,
additional_instructions: &Vec<String>,
) -> Result<T, Box<dyn std::error::Error + Send + Sync + 'static>> {
let formatted_additional_instructions: String =
format_additional_instructions(additional_instructions);
let request: Result<String, Box<dyn std::error::Error + Send + Sync>> = self
.async_send_message(
Message {
role: "user".to_string(),
content: format!(
"{}{}\nThis is the basis for generating a json:\n{}",
task.get_system_prompt(),
formatted_additional_instructions,
target
),
},
false,
)
.await;
let result: String = match request {
Ok(result) => {
let value: Value = serde_json::from_str(&result).unwrap();
value["choices"][0]["message"]["content"]
.as_str()
.unwrap()
.to_string()
}
Err(error) => return Err(SecretaryError::BuildRequestError(error.to_string()).into()),
};
Ok(surfing::serde::from_mixed_text(&result)?)
}
}
/// Renders caller-supplied extra instructions as a bulleted prompt section.
///
/// Returns an empty string when `additional_instructions` is empty, so the
/// result can be spliced into a prompt unconditionally. Otherwise returns
/// `"\nAdditional instructions:\n"` followed by one `"- {instruction}\n"` line
/// per entry, in input order.
fn format_additional_instructions(additional_instructions: &[String]) -> String {
    if additional_instructions.is_empty() {
        return String::new();
    }
    let mut prompt = String::from("\nAdditional instructions:\n");
    for instruction in additional_instructions {
        // Push the pieces directly rather than allocating a `format!` string per item.
        prompt.push_str("- ");
        prompt.push_str(instruction);
        prompt.push('\n');
    }
    prompt
}