use anyhow::{anyhow, Error, Result};
use async_openai::{
    config::Config,
    types::{
        ChatCompletionRequestMessageContentPartTextArgs, ChatCompletionRequestUserMessageArgs,
        CreateChatCompletionRequest, CreateChatCompletionRequestArgs,
        CreateChatCompletionResponse, ResponseFormat,
    },
    Client,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::runtime::Runtime;

pub use secretary_derive::Task;
/// Abstraction over an LLM backend: anything that can expose an
/// `async_openai` client together with the name of the model to call.
pub trait IsLLM {
    /// Returns a reference to the underlying `async_openai` client.
    fn access_client(&self) -> &Client<impl Config>;

    /// Returns the model identifier to use for requests.
    fn access_model(&self) -> &str;
}
/// A structured-output task: a serializable data model plus the prompts that
/// tell the LLM how to fill it in.
pub trait Task: Serialize + for<'de> Deserialize<'de> + Default {
    /// Serializes the field-level instructions from
    /// `provide_data_model_instructions` into a JSON value.
    fn get_data_model_instructions() -> Value {
        serde_json::to_value(Self::provide_data_model_instructions())
            .expect("Failed to convert data model to JSON")
    }

    /// Returns an instance of `Self` whose field values describe what the LLM
    /// should generate for each field.
    fn provide_data_model_instructions() -> Self;

    /// Returns the system prompt that frames the generation task.
    fn get_system_prompt(&self) -> String;
}
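// Hedged sketch of a manual `Task` implementation. `ContactTask`, its fields,
// and the prompt wording are illustrative assumptions, not part of this
// crate's API; in practice the re-exported `#[derive(Task)]` macro is meant
// to generate an equivalent implementation.
#[cfg(test)]
mod task_example {
    use super::Task;
    use serde::{Deserialize, Serialize};

    #[derive(Serialize, Deserialize, Default)]
    struct ContactTask {
        name: String,
        email: String,
    }

    impl Task for ContactTask {
        fn provide_data_model_instructions() -> Self {
            // Each field holds the instruction for generating that field.
            Self {
                name: "The person's full name".to_string(),
                email: "The person's e-mail address".to_string(),
            }
        }

        fn get_system_prompt(&self) -> String {
            format!(
                "Extract the fields described by this JSON data model and reply with matching JSON:\n{}",
                Self::get_data_model_instructions()
            )
        }
    }
}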
/// Blocking generation: builds a chat completion request from a `Task` and a
/// target text, runs it on an internally created Tokio runtime, and
/// deserializes the model's JSON reply back into the task type.
pub trait GenerateData
where
    Self: IsLLM,
{
    fn generate_data<T: Task>(
        &self,
        task: &T,
        target: &str,
        additional_instructions: &[String],
    ) -> Result<T, Error> {
        let formatted_additional_instructions =
            format_additional_instructions(additional_instructions);
        let runtime = Runtime::new()?;
        let result: String = runtime.block_on(async {
            let request = CreateChatCompletionRequestArgs::default()
                .model(self.access_model())
                .response_format(ResponseFormat::JsonObject)
                .messages(vec![ChatCompletionRequestUserMessageArgs::default()
                    .content(vec![ChatCompletionRequestMessageContentPartTextArgs::default()
                        .text(
                            task.get_system_prompt()
                                + &formatted_additional_instructions
                                + "\nThis is the basis for generating the JSON:\n"
                                + target,
                        )
                        .build()?
                        .into()])
                    .build()?
                    .into()])
                .build()?;
            let response: CreateChatCompletionResponse = self
                .access_client()
                .chat()
                .create(request)
                .await
                .map_err(|e| anyhow!("Failed to execute chat completion request: {}", e))?;
            // `choices` may be empty, so avoid indexing into it directly.
            response
                .choices
                .first()
                .and_then(|choice| choice.message.content.clone())
                .ok_or_else(|| anyhow!("No content was returned by the LLM"))
        })?;
        Ok(serde_json::from_str(&result)?)
    }
}
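// Hedged usage sketch for `GenerateData`. `OpenAiLlm`, `SentimentTask`, and
// the model name are hypothetical illustrations, not part of this crate's
// API, and the sketch assumes `OpenAIConfig::default()` picks the API key up
// from the `OPENAI_API_KEY` environment variable. The test is ignored by
// default because it performs a live network call.
#[cfg(test)]
mod generate_data_example {
    use super::{GenerateData, IsLLM, Task};
    use async_openai::{
        config::{Config, OpenAIConfig},
        Client,
    };
    use serde::{Deserialize, Serialize};

    pub(crate) struct OpenAiLlm {
        pub(crate) client: Client<OpenAIConfig>,
        pub(crate) model: String,
    }

    impl IsLLM for OpenAiLlm {
        fn access_client(&self) -> &Client<impl Config> {
            &self.client
        }

        fn access_model(&self) -> &str {
            &self.model
        }
    }

    // `GenerateData` has no required methods beyond the `IsLLM` bound, so a
    // marker impl is all that is needed to opt in.
    impl GenerateData for OpenAiLlm {}

    #[derive(Serialize, Deserialize, Default)]
    pub(crate) struct SentimentTask {
        pub(crate) sentiment: String,
    }

    impl Task for SentimentTask {
        fn provide_data_model_instructions() -> Self {
            Self {
                sentiment: "One of: positive, negative, neutral".to_string(),
            }
        }

        fn get_system_prompt(&self) -> String {
            format!(
                "Fill in this JSON data model from the provided text:\n{}",
                Self::get_data_model_instructions()
            )
        }
    }

    #[test]
    #[ignore = "requires a live OpenAI API key"]
    fn blocking_generation() {
        let llm = OpenAiLlm {
            client: Client::with_config(OpenAIConfig::default()),
            model: "gpt-4o-mini".to_string(),
        };
        let extra = vec!["Answer in lowercase.".to_string()];
        let result: SentimentTask = llm
            .generate_data(&SentimentTask::default(), "I loved this film!", &extra)
            .expect("generation failed");
        println!("sentiment = {}", result.sentiment);
    }
}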
/// Non-blocking counterpart to `GenerateData`, for callers that already run
/// inside an async runtime.
pub trait AsyncGenerateData
where
    Self: IsLLM,
{
    async fn async_generate_data<T: Task>(
        &self,
        task: &T,
        target: &str,
        additional_instructions: &[String],
    ) -> Result<T, Error> {
        let formatted_additional_instructions =
            format_additional_instructions(additional_instructions);
        let request: CreateChatCompletionRequest = CreateChatCompletionRequestArgs::default()
            .model(self.access_model())
            .response_format(ResponseFormat::JsonObject)
            .messages(vec![ChatCompletionRequestUserMessageArgs::default()
                .content(vec![ChatCompletionRequestMessageContentPartTextArgs::default()
                    .text(
                        task.get_system_prompt()
                            + &formatted_additional_instructions
                            + "\nThis is the basis for generating the JSON:\n"
                            + target,
                    )
                    .build()?
                    .into()])
                .build()?
                .into()])
            .build()?;
        let response: CreateChatCompletionResponse = self
            .access_client()
            .chat()
            .create(request)
            .await
            .map_err(|e| anyhow!("Failed to execute chat completion request: {}", e))?;
        // `choices` may be empty, so avoid indexing into it directly.
        let content = response
            .choices
            .first()
            .and_then(|choice| choice.message.content.clone())
            .ok_or_else(|| anyhow!("No content was returned by the LLM"))?;
        Ok(serde_json::from_str(&content)?)
    }
}
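// Hedged async counterpart of the sketch above, reusing the hypothetical
// `OpenAiLlm` and `SentimentTask` types. A plain test plus an explicit
// `Runtime` is used so no extra tokio features are assumed; the only
// difference from `GenerateData` is that the caller drives the future.
#[cfg(test)]
mod async_generate_data_example {
    use super::generate_data_example::{OpenAiLlm, SentimentTask};
    use super::AsyncGenerateData;
    use async_openai::{config::OpenAIConfig, Client};
    use tokio::runtime::Runtime;

    impl AsyncGenerateData for OpenAiLlm {}

    #[test]
    #[ignore = "requires a live OpenAI API key"]
    fn async_generation() {
        let llm = OpenAiLlm {
            client: Client::with_config(OpenAIConfig::default()),
            model: "gpt-4o-mini".to_string(),
        };
        let result: SentimentTask = Runtime::new()
            .unwrap()
            .block_on(llm.async_generate_data(
                &SentimentTask::default(),
                "I loved this film!",
                &[],
            ))
            .expect("generation failed");
        println!("sentiment = {}", result.sentiment);
    }
}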
/// Serializes `self` into a JSON string via serde.
pub trait ToJSON
where
    Self: Serialize + Sized,
{
    fn to_json(&self) -> Result<String, Error> {
        Ok(serde_json::to_string(self)?)
    }
}

/// Deserializes `Self` from a JSON string via serde.
pub trait FromJSON {
    fn from_json(json: &str) -> Result<Self, Error>
    where
        Self: for<'de> Deserialize<'de> + Sized,
    {
        Ok(serde_json::from_str(json)?)
    }
}
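// Hedged round-trip sketch for `ToJSON`/`FromJSON`. Both traits have only
// default methods, so empty marker impls opt a type in; `Point` is an
// illustrative type, not part of this crate.
#[cfg(test)]
mod json_roundtrip_example {
    use super::{FromJSON, ToJSON};
    use serde::{Deserialize, Serialize};

    #[derive(Serialize, Deserialize, Debug, PartialEq)]
    struct Point {
        x: i32,
        y: i32,
    }

    impl ToJSON for Point {}
    impl FromJSON for Point {}

    #[test]
    fn round_trip() {
        let p = Point { x: 1, y: 2 };
        let json = p.to_json().unwrap();
        assert_eq!(json, r#"{"x":1,"y":2}"#);
        assert_eq!(Point::from_json(&json).unwrap(), p);
    }
}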
/// Renders optional extra instructions as a bulleted block to append to the
/// prompt; returns an empty string when there are none.
fn format_additional_instructions(additional_instructions: &[String]) -> String {
    let mut prompt = String::new();
    if !additional_instructions.is_empty() {
        prompt.push_str("\nAdditional instructions:\n");
        for instruction in additional_instructions {
            prompt.push_str(&format!("- {}\n", instruction));
        }
    }
    prompt
}
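// Unit tests pinning down the exact formatting produced above: a header line
// followed by one `- ` bullet per instruction, or an empty string when no
// instructions are given.
#[cfg(test)]
mod format_additional_instructions_tests {
    use super::format_additional_instructions;

    #[test]
    fn formats_bulleted_block() {
        let instructions = vec!["Be terse.".to_string(), "Use metric units.".to_string()];
        assert_eq!(
            format_additional_instructions(&instructions),
            "\nAdditional instructions:\n- Be terse.\n- Use metric units.\n"
        );
    }

    #[test]
    fn empty_input_yields_empty_string() {
        assert_eq!(format_additional_instructions(&[]), "");
    }
}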