use anyhow::{anyhow, Error, Result};
use async_openai::{
    config::Config,
    types::{
        ChatCompletionRequestMessageContentPartTextArgs, ChatCompletionRequestUserMessageArgs,
        CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
        ResponseFormat,
    },
    Client,
};
use serde::{de::DeserializeOwned, Serialize};
use serde_json::Value;
use tokio::runtime::Runtime;

use crate::message_list::{Message, MessageList, Role};
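
/// Exposes the underlying `async_openai` client and the model name so that
/// the generation traits below can issue requests.
///
/// A minimal implementor sketch, assuming a hypothetical `MyLLM` wrapper
/// around an `OpenAIConfig`-backed client (the struct and its fields are
/// illustrative, not part of this module):
///
/// ```ignore
/// use async_openai::{config::OpenAIConfig, Client};
///
/// struct MyLLM {
///     client: Client<OpenAIConfig>,
///     model: String,
/// }
///
/// impl IsLLM for MyLLM {
///     fn access_client(&self) -> &Client<OpenAIConfig> {
///         &self.client
///     }
///
///     fn access_model(&self) -> &str {
///         &self.model
///     }
/// }
/// ```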
pub trait IsLLM {
    fn access_client(&self) -> &Client<impl Config>;
    fn access_model(&self) -> &str;
}
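
/// Describes a data model that can be (de)serialized and that ships a
/// template instance describing the JSON shape the LLM should produce.
///
/// A minimal sketch, assuming a hypothetical `Recipe` type (illustrative
/// only):
///
/// ```ignore
/// #[derive(Serialize, Deserialize)]
/// struct Recipe {
///     name: String,
///     steps: Vec<String>,
/// }
///
/// impl DataModel for Recipe {
///     fn provide_data_model_instructions() -> Self {
///         Recipe {
///             name: "name of the dish".to_string(),
///             steps: vec!["one instruction per entry".to_string()],
///         }
///     }
/// }
/// ```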
pub trait DataModel: Serialize + DeserializeOwned {
    fn get_data_model_instructions() -> Value {
        serde_json::to_value(Self::provide_data_model_instructions())
            .expect("Failed to convert data model to JSON")
    }

    fn provide_data_model_instructions() -> Self;
}
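
/// Supplies the system prompt that frames a generation task.
///
/// A trivial implementor sketch (`Summarizer` is an illustrative name):
///
/// ```ignore
/// struct Summarizer;
///
/// impl SystemPrompt for Summarizer {
///     fn get_system_prompt(&self) -> String {
///         "Summarize the input and respond as JSON.".to_string()
///     }
/// }
/// ```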
pub trait SystemPrompt {
    fn get_system_prompt(&self) -> String;
}
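
/// Maintains a conversation history as a `MessageList`.
///
/// An implementor sketch, assuming a hypothetical `Chat` struct and that
/// `MessageList` implements `Clone` (both are assumptions for illustration):
///
/// ```ignore
/// struct Chat {
///     history: MessageList,
/// }
///
/// impl Context for Chat {
///     fn get_context_mut(&mut self) -> &mut MessageList {
///         &mut self.history
///     }
///
///     fn get_context(&self) -> MessageList {
///         self.history.clone()
///     }
/// }
/// ```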
pub trait Context {
    /// Appends a message to the context; roles other than `User`,
    /// `Assistant`, and `System` are rejected.
    fn push(&mut self, role: Role, content: &str) -> Result<(), Error> {
        match role {
            Role::User | Role::Assistant | Role::System => {
                self.get_context_mut()
                    .push(Message::new(role, content.to_string()));
                Ok(())
            }
            _ => Err(anyhow!("Unsupported role")),
        }
    }

    fn get_context_mut(&mut self) -> &mut MessageList;

    fn get_context(&self) -> MessageList;
}
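
/// Blocking JSON generation for callers that are not inside an async
/// runtime. Each call creates a fresh Tokio runtime, so prefer
/// `AsyncGenerateJSON` from async code.
///
/// A usage sketch, assuming the illustrative `MyLLM` and `Summarizer` types
/// from above, which implement `IsLLM` and `SystemPrompt` respectively:
///
/// ```ignore
/// impl GenerateJSON for MyLLM {}
///
/// let llm: MyLLM = /* construct the client wrapper */;
/// let json = llm.generate_json(&Summarizer, "raw text to summarize")?;
/// ```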
pub trait GenerateJSON
where
    Self: IsLLM,
{
    fn generate_json(&self, task: &impl SystemPrompt, target: &str) -> Result<String, Error> {
        let runtime = Runtime::new()?;
        runtime.block_on(async {
            let request = CreateChatCompletionRequestArgs::default()
                .model(self.access_model())
                .response_format(ResponseFormat::JsonObject)
                .messages(vec![ChatCompletionRequestUserMessageArgs::default()
                    .content(vec![ChatCompletionRequestMessageContentPartTextArgs::default()
                        .text(
                            task.get_system_prompt()
                                + "\nThis is the basis for generating the JSON:\n"
                                + target,
                        )
                        .build()?
                        .into()])
                    .build()?
                    .into()])
                .build()?;
            let response: CreateChatCompletionResponse = self
                .access_client()
                .chat()
                .create(request)
                .await
                .map_err(|e| anyhow!("Failed to execute chat completion: {}", e))?;
            // `choices` may be empty, so avoid indexing into it directly.
            response
                .choices
                .first()
                .and_then(|choice| choice.message.content.clone())
                .ok_or_else(|| anyhow!("No content was returned by the LLM"))
        })
    }

    fn generate_json_with_context<T>(&self, task: &T) -> Result<String, Error>
    where
        T: SystemPrompt + Context,
    {
        let runtime = Runtime::new()?;
        runtime.block_on(async {
            let request: CreateChatCompletionRequest = CreateChatCompletionRequestArgs::default()
                .model(self.access_model())
                .response_format(ResponseFormat::JsonObject)
                .messages(task.get_context())
                .build()?;
            let response: CreateChatCompletionResponse = self
                .access_client()
                .chat()
                .create(request)
                .await
                .map_err(|e| anyhow!("Failed to execute chat completion: {}", e))?;
            response
                .choices
                .first()
                .and_then(|choice| choice.message.content.clone())
                .ok_or_else(|| anyhow!("No content was returned by the LLM"))
        })
    }
}
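
/// Async counterpart of `GenerateJSON`, for use inside an existing Tokio
/// runtime.
///
/// A usage sketch (same illustrative `MyLLM` and `Summarizer` as above):
///
/// ```ignore
/// impl AsyncGenerateJSON for MyLLM {}
///
/// let json = llm
///     .async_generate_json(&Summarizer, "raw text to summarize")
///     .await?;
/// ```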
pub trait AsyncGenerateJSON
where
    Self: IsLLM,
{
    async fn async_generate_json(
        &self,
        task: &impl SystemPrompt,
        target: &str,
    ) -> Result<String, Error> {
        let request = CreateChatCompletionRequestArgs::default()
            .model(self.access_model())
            .response_format(ResponseFormat::JsonObject)
            .messages(vec![ChatCompletionRequestUserMessageArgs::default()
                .content(vec![ChatCompletionRequestMessageContentPartTextArgs::default()
                    .text(
                        task.get_system_prompt()
                            + "\nThis is the basis for generating the JSON:\n"
                            + target,
                    )
                    .build()?
                    .into()])
                .build()?
                .into()])
            .build()?;
        let response: CreateChatCompletionResponse = self
            .access_client()
            .chat()
            .create(request)
            .await
            .map_err(|e| anyhow!("Failed to execute chat completion: {}", e))?;
        // `choices` may be empty, so avoid indexing into it directly.
        response
            .choices
            .first()
            .and_then(|choice| choice.message.content.clone())
            .ok_or_else(|| anyhow!("No content was returned by the LLM"))
    }

    async fn async_generate_json_with_context<T>(&self, task: &T) -> Result<String, Error>
    where
        T: SystemPrompt + Context,
    {
        let request: CreateChatCompletionRequest = CreateChatCompletionRequestArgs::default()
            .model(self.access_model())
            .response_format(ResponseFormat::JsonObject)
            .messages(task.get_context())
            .build()?;
        let response: CreateChatCompletionResponse = self
            .access_client()
            .chat()
            .create(request)
            .await
            .map_err(|e| anyhow!("Failed to execute chat completion: {}", e))?;
        response
            .choices
            .first()
            .and_then(|choice| choice.message.content.clone())
            .ok_or_else(|| anyhow!("No content was returned by the LLM"))
    }
}
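
/// Serializes the implementor to a JSON string via `serde_json`.
///
/// A usage sketch (using the illustrative `Recipe` from above):
///
/// ```ignore
/// impl ToJSON for Recipe {}
///
/// let json = recipe.to_json()?;
/// ```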
pub trait ToJSON
where
    Self: Serialize + Sized,
{
    fn to_json(&self) -> Result<String, Error> {
        Ok(serde_json::to_string(self)?)
    }
}
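
/// Deserializes the implementor from a JSON string via `serde_json`.
///
/// A usage sketch (using the illustrative `Recipe` from above):
///
/// ```ignore
/// impl FromJSON for Recipe {}
///
/// let recipe = Recipe::from_json(&json)?;
/// ```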
pub trait FromJSON {
    fn from_json(json: &str) -> Result<Self, Error>
    where
        Self: DeserializeOwned,
    {
        Ok(serde_json::from_str(json)?)
    }
}