use serde::{Deserialize, Serialize};
use serde_json::Value;
use async_openai::{
Client,
config::Config,
types::{
ChatCompletionRequestMessageContentPartTextArgs, ChatCompletionRequestUserMessageArgs,
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
ResponseFormat,
},
};
use tokio::runtime::Runtime;
pub use secretary_derive::Task;
use crate::SecretaryError;
/// Abstraction over an LLM backend: exposes the configured
/// `async_openai` client and the model identifier to send with requests.
pub trait IsLLM {
    /// Borrow the underlying OpenAI-compatible client.
    fn access_client(&self) -> &Client<impl Config>;
    /// Borrow the model identifier used for chat-completion requests.
    fn access_model(&self) -> &str;
}
/// An extraction task: a `Default`-constructible, (de)serializable type
/// whose example instance doubles as instructions for the LLM.
pub trait Task: Serialize + for<'de> Deserialize<'de> + Default {
    /// JSON form of the instruction instance from
    /// `provide_data_model_instructions`.
    ///
    /// # Panics
    /// Panics if the instance cannot be converted to `serde_json::Value`
    /// (not expected for ordinary derive-based implementations).
    fn get_data_model_instructions() -> Value {
        serde_json::to_value(Self::provide_data_model_instructions())
            .expect("Failed to convert data model to JSON")
    }
    /// Returns an instance whose field values describe what to extract.
    fn provide_data_model_instructions() -> Self;
    /// The system prompt describing this task to the LLM.
    fn get_system_prompt(&self) -> String;
}
pub trait GenerateData
where
Self: IsLLM,
{
fn generate_data<T: Task>(&self, task: &T, target: &str, additional_instructions: &Vec<String>) -> Result<T, Box<dyn std::error::Error + Send + Sync + 'static>> {
let formatted_additional_instructions: String = format_additional_instructions(additional_instructions);
let runtime: Runtime = tokio::runtime::Runtime::new()?;
let result: String = runtime.block_on(async {
let request: CreateChatCompletionRequest = CreateChatCompletionRequestArgs::default()
.model(&self.access_model().to_string())
.response_format(ResponseFormat::JsonObject)
.messages(vec![
ChatCompletionRequestUserMessageArgs::default()
.content(vec![
ChatCompletionRequestMessageContentPartTextArgs::default()
.text(
task.get_system_prompt()
+ &formatted_additional_instructions
+ "\nThis is the basis for generating a json:\n"
+ target,
)
.build().map_err(|e| SecretaryError::BuildRequestError(e.to_string()))?
.into(),
])
.build().map_err(|e| SecretaryError::BuildRequestError(e.to_string()))?
.into(),
])
.build().map_err(|e| SecretaryError::BuildRequestError(e.to_string()))?;
let response: CreateChatCompletionResponse = self
.access_client()
.chat()
.create(request.clone())
.await
.map_err(|e| SecretaryError::BuildRequestError(e.to_string()))?;
if let Some(content) = response.choices[0].clone().message.content {
return Ok(content);
}
return Err(SecretaryError::NoLLMResponse);
})?;
Ok(serde_json::from_str(&result)?)
}
fn force_generate_data<T: Task>(&self, task: &T, target: &str, additional_instructions: &Vec<String>) -> Result<T, Box<dyn std::error::Error + Send + Sync + 'static>> {
let formatted_additional_instructions: String = format_additional_instructions(additional_instructions);
let runtime: Runtime = tokio::runtime::Runtime::new()?;
let result: String = runtime.block_on(async {
let request = CreateChatCompletionRequestArgs::default()
.model(&self.access_model().to_string())
.messages(vec![
ChatCompletionRequestUserMessageArgs::default()
.content(vec![
ChatCompletionRequestMessageContentPartTextArgs::default()
.text(
task.get_system_prompt()
+ &formatted_additional_instructions
+ "\nThis is the basis for generating a json:\n"
+ target,
)
.build()?
.into(),
])
.build().map_err(|e| SecretaryError::BuildRequestError(e.to_string()))?
.into(),
])
.build().map_err(|e| SecretaryError::BuildRequestError(e.to_string()))?;
let response: CreateChatCompletionResponse = self
.access_client()
.chat()
.create(request.clone())
.await
.map_err(|e| SecretaryError::BuildRequestError(e.to_string()))?;
if let Some(content) = response.choices[0].clone().message.content {
return Ok::<String, Box<dyn std::error::Error + Send + Sync + 'static>>(content);
}
return Err(SecretaryError::NoLLMResponse.into());
})?;
Ok(surfing::serde::from_mixed_text(&result)?)
}
}
/// Async data generation: awaitable counterparts of [`GenerateData`] for use
/// inside an existing async runtime.
pub trait AsyncGenerateData
where
    Self: IsLLM,
{
    /// Asks the LLM to produce strict JSON (`ResponseFormat::JsonObject`)
    /// based on `task`'s system prompt plus `target`, then deserializes the
    /// answer into `T`.
    ///
    /// # Errors
    /// Fails if the request cannot be built, the API call fails, the LLM
    /// returns no message content, or the content is not valid JSON for `T`.
    async fn async_generate_data<T: Task>(&self, task: &T, target: &str, additional_instructions: &Vec<String>) -> Result<T, Box<dyn std::error::Error + Send + Sync + 'static>> {
        let formatted_additional_instructions: String =
            format_additional_instructions(additional_instructions);
        let request: CreateChatCompletionRequest = CreateChatCompletionRequestArgs::default()
            // `model` takes `impl Into<String>`; no `to_string()` needed.
            .model(self.access_model())
            .response_format(ResponseFormat::JsonObject)
            .messages(vec![
                ChatCompletionRequestUserMessageArgs::default()
                    .content(vec![
                        ChatCompletionRequestMessageContentPartTextArgs::default()
                            .text(
                                task.get_system_prompt()
                                    + &formatted_additional_instructions
                                    + "\nThis is the basis for generating a json:\n"
                                    + target,
                            )
                            .build()?
                            .into(),
                    ])
                    .build()?
                    .into(),
            ])
            .build()?;
        let response: CreateChatCompletionResponse = self
            .access_client()
            .chat()
            // The request is not reused afterwards: move it, don't clone.
            .create(request)
            .await
            // Use the crate's error type, matching the sync trait, instead of
            // an ad-hoc `format!` string error.
            .map_err(|e| SecretaryError::BuildRequestError(e.to_string()))?;
        // Don't index `choices[0]`: an empty choices array must surface as
        // `NoLLMResponse`, not a panic.
        let content = response
            .choices
            .into_iter()
            .next()
            .and_then(|choice| choice.message.content)
            .ok_or(SecretaryError::NoLLMResponse)?;
        Ok(serde_json::from_str(&content)?)
    }

    /// Like [`AsyncGenerateData::async_generate_data`] but without forcing
    /// JSON-object output: the raw answer is scanned for embedded JSON with
    /// `surfing::serde::from_mixed_text`, so prose-wrapped JSON still parses.
    ///
    /// # Errors
    /// Fails if the request cannot be built, the API call fails, the LLM
    /// returns no message content, or no JSON for `T` can be recovered.
    async fn async_force_generate_data<T: Task>(&self, task: &T, target: &str, additional_instructions: &Vec<String>) -> Result<T, Box<dyn std::error::Error + Send + Sync + 'static>> {
        let formatted_additional_instructions: String =
            format_additional_instructions(additional_instructions);
        let request: CreateChatCompletionRequest = CreateChatCompletionRequestArgs::default()
            .model(self.access_model())
            .messages(vec![
                ChatCompletionRequestUserMessageArgs::default()
                    .content(vec![
                        ChatCompletionRequestMessageContentPartTextArgs::default()
                            .text(
                                task.get_system_prompt()
                                    + &formatted_additional_instructions
                                    + "\nThis is the basis for generating a json:\n"
                                    + target,
                            )
                            .build()?
                            .into(),
                    ])
                    .build()?
                    .into(),
            ])
            .build()?;
        let response: CreateChatCompletionResponse = self
            .access_client()
            .chat()
            .create(request)
            .await
            .map_err(|e| SecretaryError::BuildRequestError(e.to_string()))?;
        // Avoid the panic-prone `choices[0]` indexing.
        let content = response
            .choices
            .into_iter()
            .next()
            .and_then(|choice| choice.message.content)
            .ok_or(SecretaryError::NoLLMResponse)?;
        Ok(surfing::serde::from_mixed_text(&content)?)
    }
}
/// Serialize the implementor into a JSON string.
pub trait ToJSON
where
    Self: serde::Serialize + Sized,
{
    /// Returns the compact JSON representation of `self`.
    ///
    /// # Errors
    /// Propagates any `serde_json` serialization failure, boxed.
    fn to_json(&self) -> Result<String, Box<dyn std::error::Error + Send + Sync + 'static>> {
        serde_json::to_string(self).map_err(Into::into)
    }
}
/// Deserialize the implementor from a JSON string.
pub trait FromJSON {
    /// Parses `json` into `Self`.
    ///
    /// # Errors
    /// Propagates any `serde_json` parse failure, boxed.
    fn from_json(json: &str) -> Result<Self, Box<dyn std::error::Error + Send + Sync + 'static>>
    where
        Self: for<'de> serde::Deserialize<'de> + Sized,
    {
        serde_json::from_str(json).map_err(Into::into)
    }
}
/// Renders extra user instructions as a newline-prefixed bullet list for
/// inclusion in the prompt; returns an empty string when there are none.
///
/// Accepts `&[String]` (more general than the previous `&Vec<String>`);
/// existing `&Vec<String>` call sites coerce transparently.
fn format_additional_instructions(additional_instructions: &[String]) -> String {
    if additional_instructions.is_empty() {
        return String::new();
    }
    let mut prompt = String::from("\nAdditional instructions:\n");
    for instruction in additional_instructions {
        // push_str/push instead of a per-item `format!` allocation.
        prompt.push_str("- ");
        prompt.push_str(instruction);
        prompt.push('\n');
    }
    prompt
}