use std::panic;
use async_trait::async_trait;
use futures::future;
use reqwest::{
Response,
header::{AUTHORIZATION, CONTENT_TYPE},
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
pub use secretary_derive::Task;
use crate::{generate_from_tuples, message::Message, utilities::{cleanup_thinking_blocks, format_additional_instructions}, SecretaryError};
/// A chat-completion backend.
///
/// Implementors supply the endpoint URL, the `Authorization` header value and the JSON
/// request body; the provided methods POST that body and hand back the raw response text.
#[async_trait]
pub trait IsLLM {
    /// Sends `message` with a blocking HTTP POST and returns the raw response body.
    ///
    /// A fresh `reqwest::blocking::Client` is built for every call. The HTTP status code is
    /// not inspected, so a non-2xx response body is still returned as `Ok`.
    ///
    /// NOTE: per the reqwest documentation, the blocking client must not be used from inside
    /// an async runtime; use [`IsLLM::async_send_message`] there.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be sent or the response body cannot be read.
    fn send_message(
        &self,
        message: Message,
        return_json: bool,
    ) -> Result<String, Box<dyn std::error::Error + Send + Sync + 'static>> {
        let client = reqwest::blocking::Client::new();
        let endpoint = self.get_chat_completion_request_url();
        let credentials = self.get_authorization_credentials();
        let payload = self.get_request_body(message, return_json);
        let response: reqwest::blocking::Response = client
            .post(endpoint)
            .header(AUTHORIZATION, credentials)
            .header(CONTENT_TYPE, "application/json")
            .json(&payload)
            .send()?;
        let body = response.text()?;
        Ok(body)
    }

    /// Asynchronous counterpart of [`IsLLM::send_message`]: POSTs `message` and returns the
    /// raw response body.
    ///
    /// A fresh `reqwest::Client` is built for every call and the HTTP status code is not
    /// inspected, so a non-2xx response body is still returned as `Ok`.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be sent or the response body cannot be read.
    async fn async_send_message(
        &self,
        message: Message,
        return_json: bool,
    ) -> Result<String, Box<dyn std::error::Error + Send + Sync + 'static>> {
        let client = reqwest::Client::new();
        let endpoint = self.get_chat_completion_request_url();
        let credentials = self.get_authorization_credentials();
        let payload = self.get_request_body(message, return_json);
        let response: Response = client
            .post(endpoint)
            .header(AUTHORIZATION, credentials)
            .header(CONTENT_TYPE, "application/json")
            .json(&payload)
            .send()
            .await?;
        let body = response.text().await?;
        Ok(body)
    }

    /// Value sent verbatim as the `Authorization` header.
    ///
    /// NOTE(review): any scheme prefix such as `Bearer ` presumably has to be included by
    /// the implementor — confirm against the concrete backends.
    fn get_authorization_credentials(&self) -> String;

    /// JSON body POSTed to the endpoint for `message`.
    ///
    /// `return_json` is passed through unchanged from the caller; presumably it switches
    /// the provider into a JSON-only response mode — TODO confirm per implementor.
    fn get_request_body(&self, message: Message, return_json: bool) -> Value;

    /// Full URL the chat-completion request is POSTed to.
    fn get_chat_completion_request_url(&self) -> String;

    /// Model identifier of this backend (not used by the provided methods above).
    fn get_model_ref(&self) -> &str;
}
/// A serializable generation target: describes how to prompt an LLM so that its answer
/// can be deserialized into `Self`.
pub trait Task: Serialize + for<'de> Deserialize<'de> + Default {
    /// Instruction text placed at the start of the single-request prompt.
    fn get_system_prompt(&self) -> String;

    /// `(field_name, prompt)` pairs, one request per pair, used for per-field generation.
    fn get_system_prompts_for_distributed_generation(&self) -> Vec<(String, String)>;

    /// Builds the single `"user"` message for generating the whole JSON object.
    ///
    /// The content is the system prompt, then the formatted additional instructions, then
    /// `target` introduced by a fixed separator line.
    fn make_prompt(&self, target: &str, additional_instructions: &Vec<String>) -> Message {
        let system_prompt = self.get_system_prompt();
        let extra_instructions = format_additional_instructions(additional_instructions);
        let content = format!(
            "{}{}\nThis is the basis for generating a json:\n{}",
            system_prompt, extra_instructions, target
        );
        Message {
            role: String::from("user"),
            content,
        }
    }

    /// Builds one `"user"` message per entry of
    /// [`Task::get_system_prompts_for_distributed_generation`], each paired with its field
    /// name, in the same order as that list.
    fn make_dstributed_generation_prompts(
        &self,
        target: &str,
        additional_instructions: &Vec<String>,
    ) -> Vec<(String, Message)> {
        self.get_system_prompts_for_distributed_generation()
            .into_iter()
            .map(|(field_name, field_prompt)| {
                let content = format!(
                    "{}{}\nThis is the basis for generating the result:\n{}",
                    field_prompt,
                    format_additional_instructions(additional_instructions),
                    target
                );
                let message = Message {
                    role: String::from("user"),
                    content,
                };
                (field_name, message)
            })
            .collect()
    }
}
/// Synchronous generation of a [`Task`] value on top of [`IsLLM::send_message`].
///
/// Each method sends one or more chat-completion requests, reads
/// `choices[0].message.content` from every response body and turns it into `T`.
///
/// # Errors
///
/// Instead of panicking, every method returns an error when the request fails, when a
/// response body is not valid JSON, when it has no string at `choices[0].message.content`
/// (for example an API error payload), or when the content cannot be deserialized into `T`.
pub trait GenerateData
where
    Self: IsLLM + Sync,
{
    /// Generates `T` with a single request in JSON mode (`return_json = true`) and parses
    /// the message content directly with `serde_json`.
    fn generate_data<T: Task>(
        &self,
        task: &T,
        target: &str,
        additional_instructions: &Vec<String>,
    ) -> Result<T, Box<dyn std::error::Error + Send + Sync + 'static>> {
        let response: String = self.send_message(
            task.make_prompt(target, additional_instructions),
            true,
        )?;
        let value: Value = serde_json::from_str(&response)?;
        // A missing `choices[0].message.content` usually means the API returned an error
        // object; surface the raw body so the caller can see it.
        let content: String = value["choices"][0]["message"]["content"]
            .as_str()
            .ok_or_else(|| {
                format!("LLM response has no `choices[0].message.content`: {response}")
            })?
            .to_string();
        Ok(serde_json::from_str::<T>(&content)?)
    }

    /// Generates `T` with a single request without JSON mode (`return_json = false`) and
    /// extracts the JSON from the free-form answer with `surfing::serde::from_mixed_text`.
    fn force_generate_data<T: Task>(
        &self,
        task: &T,
        target: &str,
        additional_instructions: &Vec<String>,
    ) -> Result<T, Box<dyn std::error::Error + Send + Sync + 'static>> {
        let response: String = self.send_message(
            task.make_prompt(target, additional_instructions),
            false,
        )?;
        let value: Value = serde_json::from_str(&response)?;
        let content: String = value["choices"][0]["message"]["content"]
            .as_str()
            .ok_or_else(|| {
                format!("LLM response has no `choices[0].message.content`: {response}")
            })?
            .to_string();
        Ok(surfing::serde::from_mixed_text(&content)?)
    }

    /// Generates `T` field by field: one request per entry of
    /// [`Task::make_dstributed_generation_prompts`], each sent on its own scoped thread.
    ///
    /// Each answer is passed through `cleanup_thinking_blocks`, and the resulting
    /// `(field_name, text)` pairs are assembled into `T` by `generate_from_tuples!`.
    ///
    /// # Errors
    ///
    /// Returns the first error (in prompt order) produced by any field request.
    ///
    /// # Panics
    ///
    /// If a worker thread panics, the panic is re-raised on the calling thread with its
    /// original payload.
    fn fields_generate_data<T: Task>(
        &self,
        task: &T,
        target: &str,
        additional_instructions: &Vec<String>,
    ) -> Result<T, Box<dyn std::error::Error + Send + Sync + 'static>> {
        let messages: Vec<(String, Message)> =
            task.make_dstributed_generation_prompts(target, additional_instructions);
        let distributed_tasks_results: Vec<(String, String)> = std::thread::scope(|s| {
            // Collect the handles first so every request is in flight before any join.
            let handles: Vec<_> = messages
                .into_iter()
                .map(|(field_name, message)| {
                    s.spawn(
                        move || -> Result<(String, String), Box<dyn std::error::Error + Send + Sync>> {
                            let raw_result: String = self.send_message(message, false)?;
                            let value: Value = serde_json::from_str(&raw_result)?;
                            let content: String = value["choices"][0]["message"]["content"]
                                .as_str()
                                .ok_or_else(|| {
                                    format!(
                                        "LLM response for field `{field_name}` has no \
                                         `choices[0].message.content`: {raw_result}"
                                    )
                                })?
                                .to_string();
                            Ok((field_name, cleanup_thinking_blocks(content)))
                        },
                    )
                })
                .collect();
            handles
                .into_iter()
                .map(|handle| {
                    // Re-raise a worker panic with its original message instead of a bare
                    // `panic!()` that would discard it.
                    handle
                        .join()
                        .unwrap_or_else(|payload| panic::resume_unwind(payload))
                })
                .collect::<Result<Vec<(String, String)>, Box<dyn std::error::Error + Send + Sync>>>()
        })?;
        Ok(generate_from_tuples!(T, distributed_tasks_results))
    }
}
#[async_trait]
pub trait AsyncGenerateData
where
Self: IsLLM,
{
async fn async_generate_data<T: Task + Sync + Send>(
&self,
task: &T,
target: &str,
additional_instructions: &Vec<String>,
) -> Result<T, Box<dyn std::error::Error + Send + Sync + 'static>> {
let request: Result<String, Box<dyn std::error::Error + Send + Sync>> = self
.async_send_message(
task.make_prompt(target, additional_instructions),
true,
)
.await;
let result = match request {
Ok(result) => {
let value: Value = serde_json::from_str(&result).unwrap();
value["choices"][0]["message"]["content"]
.as_str()
.unwrap()
.to_string()
}
Err(error) => return Err(SecretaryError::BuildRequestError(error.to_string()).into()),
};
Ok(serde_json::from_str(&result)?)
}
async fn async_force_generate_data<T: Task + Sync + Send>(
&self,
task: &T,
target: &str,
additional_instructions: &Vec<String>,
) -> Result<T, Box<dyn std::error::Error + Send + Sync + 'static>> {
let request: Result<String, Box<dyn std::error::Error + Send + Sync>> = self
.async_send_message(
task.make_prompt(target, additional_instructions),
false,
)
.await;
let result: String = match request {
Ok(result) => {
let value: Value = serde_json::from_str(&result).unwrap();
value["choices"][0]["message"]["content"]
.as_str()
.unwrap()
.to_string()
}
Err(error) => return Err(SecretaryError::BuildRequestError(error.to_string()).into()),
};
Ok(surfing::serde::from_mixed_text(&result)?)
}
async fn async_fields_generate_data<T: Task + Sync + Send>(
&self,
task: &T,
target: &str,
additional_instructions: &Vec<String>,
) -> Result<T, Box<dyn std::error::Error + Send + Sync + 'static>> {
let messages: Vec<(String, Message)> = task.make_dstributed_generation_prompts(target, additional_instructions);
let mut distributed_tasks = Vec::new();
for (field_name, message) in messages {
let task_future = async move {
let raw_result: String = self.async_send_message(message, false).await?;
let value: Value = serde_json::from_str(&raw_result).unwrap();
let content: String = value["choices"][0]["message"]["content"].as_str().unwrap().to_string();
Ok::<(String, String), Box<dyn std::error::Error + Send + Sync>>((field_name, cleanup_thinking_blocks(content)))
};
distributed_tasks.push(task_future);
}
let distributed_tasks_results: Result<Vec<(String, String)>, Box<dyn std::error::Error + Send + Sync + 'static>> =
future::try_join_all(distributed_tasks).await;
let distributed_tasks_results: Vec<(String, String)> = distributed_tasks_results?;
Ok(generate_from_tuples!(T, distributed_tasks_results))
}
}