use std::panic;
use async_trait::async_trait;
use futures::future;
use reqwest::{
Response,
header::{AUTHORIZATION, CONTENT_TYPE},
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
pub use secretary_derive::Task;
use crate::{
SecretaryError, generate_from_tuples,
message::Message,
utilities::{
cleanup_thinking_blocks, extract_result_content, extract_text_content_from_llm_response,
format_additional_instructions,
},
};
/// HTTP transport contract for a chat-completion style LLM backend.
///
/// Implementors supply the endpoint URL, the `Authorization` header value and the
/// JSON request body. The provided `send_message` / `async_send_message` methods
/// POST that body as `application/json` and return the raw response text.
#[async_trait]
pub trait IsLLM {
    /// POSTs `message` with a blocking `reqwest` client and returns the response body.
    ///
    /// A new client is constructed on every call. `return_json` is forwarded unchanged
    /// to [`IsLLM::get_request_body`].
    ///
    /// # Errors
    /// Returns the `reqwest` error if the request cannot be sent or the body cannot be
    /// read as text. The HTTP status code is not checked, so a non-2xx reply still
    /// yields `Ok(body)`.
    fn send_message(
        &self,
        message: Message,
        return_json: bool,
    ) -> Result<String, Box<dyn std::error::Error + Send + Sync + 'static>> {
        let client = reqwest::blocking::Client::new();
        let endpoint = self.get_chat_completion_request_url();
        let credentials = self.get_authorization_credentials();
        let payload = self.get_request_body(message, return_json);
        let response = client
            .post(endpoint)
            .header(AUTHORIZATION, credentials)
            .header(CONTENT_TYPE, "application/json")
            .json(&payload)
            .send()?;
        let body = response.text()?;
        Ok(body)
    }

    /// Async counterpart of [`IsLLM::send_message`], using a non-blocking `reqwest::Client`.
    ///
    /// A new client is constructed on every call. `return_json` is forwarded unchanged
    /// to [`IsLLM::get_request_body`].
    ///
    /// # Errors
    /// Returns the `reqwest` error if sending fails or the body cannot be read as text.
    /// The HTTP status code is not checked.
    async fn async_send_message(
        &self,
        message: Message,
        return_json: bool,
    ) -> Result<String, Box<dyn std::error::Error + Send + Sync + 'static>> {
        let client = reqwest::Client::new();
        let endpoint = self.get_chat_completion_request_url();
        let credentials = self.get_authorization_credentials();
        let payload = self.get_request_body(message, return_json);
        let response: Response = client
            .post(endpoint)
            .header(AUTHORIZATION, credentials)
            .header(CONTENT_TYPE, "application/json")
            .json(&payload)
            .send()
            .await?;
        let body = response.text().await?;
        Ok(body)
    }

    /// Value sent verbatim as the `Authorization` header by the provided methods.
    fn get_authorization_credentials(&self) -> String;

    /// Builds the JSON request body for `message`. `return_json` is forwarded unchanged
    /// from `send_message` / `async_send_message`.
    fn get_request_body(&self, message: Message, return_json: bool) -> Value;

    /// Endpoint URL that the provided methods POST to.
    fn get_chat_completion_request_url(&self) -> String;

    /// Model reference string. It is not used by the provided methods in this trait.
    /// Presumably this is the model name; confirm with implementors.
    fn get_model_ref(&self) -> &str;
}
/// A serde-(de)serializable, `Default`-constructible output type that can build the
/// prompts used to ask an LLM to produce it.
pub trait Task: Serialize + for<'de> Deserialize<'de> + Default {
    /// Instruction text placed at the start of the single-request prompt built by
    /// [`Task::make_prompt`].
    fn get_system_prompt(&self) -> String;

    /// `(field_name, instruction)` pairs, one per field that is generated by its own request.
    fn get_system_prompts_for_distributed_generation(&self) -> Vec<(String, String)>;

    /// Builds a single `"user"` message asking for the whole object as JSON.
    ///
    /// The content is the system prompt, then the formatted `additional_instructions`,
    /// then `target` introduced by a fixed header line.
    fn make_prompt(&self, target: &str, additional_instructions: &Vec<String>) -> Message {
        let system_prompt = self.get_system_prompt();
        let extra = format_additional_instructions(additional_instructions);
        Message {
            role: String::from("user"),
            content: format!(
                "{system_prompt}{extra}\nThis is the basis for generating a json:\n{target}"
            ),
        }
    }

    /// Builds one `"user"` message per entry of
    /// [`Task::get_system_prompts_for_distributed_generation`], paired with its field name
    /// and kept in the same order.
    ///
    /// NOTE: the spelling of this method's name is part of the public API and is kept as-is.
    fn make_dstributed_generation_prompts(
        &self,
        target: &str,
        additional_instructions: &Vec<String>,
    ) -> Vec<(String, Message)> {
        self.get_system_prompts_for_distributed_generation()
            .into_iter()
            .map(|(field_name, field_prompt)| {
                let extra = format_additional_instructions(additional_instructions);
                let message = Message {
                    role: String::from("user"),
                    content: format!(
                        "{field_prompt}{extra}\nThis is the basis for generating the result:\n{target}"
                    ),
                };
                (field_name, message)
            })
            .collect()
    }
}
/// Blocking helpers that turn an [`IsLLM`] backend into typed [`Task`] values.
///
/// Every method has a default implementation. `Sync` is required because
/// [`GenerateData::fields_generate_data`] shares `&self` across scoped threads.
pub trait GenerateData
where
    Self: IsLLM + Sync,
{
    /// Requests JSON output (`return_json = true`) for `task` and deserializes it into `T`.
    ///
    /// # Errors
    /// Propagates errors from `send_message` and `extract_text_content_from_llm_response`.
    /// Returns `SecretaryError::SerdeJsonError` if the extracted text is not valid JSON for `T`.
    fn generate_data<T: Task>(
        &self,
        task: &T,
        target: &str,
        additional_instructions: &Vec<String>,
    ) -> Result<T, Box<dyn std::error::Error + Send + Sync + 'static>> {
        let response: String =
            self.send_message(task.make_prompt(target, additional_instructions), true)?;
        let result: String = extract_text_content_from_llm_response(&response)?;
        match serde_json::from_str::<T>(&result) {
            Ok(result) => Ok(result),
            Err(error) => Err(Box::new(SecretaryError::SerdeJsonError(error))),
        }
    }

    /// Requests free-form output (`return_json = false`) and parses `T` out of it with
    /// `surfing::serde::from_mixed_text`. That parser presumably tolerates text around the
    /// JSON; confirm against the `surfing` docs.
    ///
    /// # Errors
    /// Propagates errors from `send_message` and `extract_text_content_from_llm_response`.
    /// A parse failure becomes `SecretaryError::JsonParsingError` holding the parser's message.
    fn force_generate_data<T: Task>(
        &self,
        task: &T,
        target: &str,
        additional_instructions: &Vec<String>,
    ) -> Result<T, Box<dyn std::error::Error + Send + Sync + 'static>> {
        let response: String =
            self.send_message(task.make_prompt(target, additional_instructions), false)?;
        let result: String = extract_text_content_from_llm_response(&response)?;
        match surfing::serde::from_mixed_text(&result) {
            Ok(result) => Ok(result),
            Err(error) => Err(Box::new(SecretaryError::JsonParsingError(
                error.to_string(),
            ))),
        }
    }

    /// Generates each field of `T` with its own request, then assembles `T`.
    ///
    /// Requests use `return_json = false`, with one scoped thread per field. Each reply is
    /// cleaned with `cleanup_thinking_blocks` and `extract_result_content`. The resulting
    /// `(field_name, value)` pairs are combined into `T` by `generate_from_tuples!`.
    ///
    /// # Errors
    /// - Once every worker has finished, the first worker error in field order is returned
    ///   (from `send_message` or `extract_text_content_from_llm_response`).
    /// - A worker thread that panics is reported as an error naming the field, rather than
    ///   panicking in the caller.
    /// - A panic while assembling `T` becomes `SecretaryError::JsonParsingError`. It carries
    ///   the panic message if that message contains "Failed to deserialize", and a generic
    ///   message otherwise.
    fn fields_generate_data<T: Task>(
        &self,
        task: &T,
        target: &str,
        additional_instructions: &Vec<String>,
    ) -> Result<T, Box<dyn std::error::Error + Send + Sync + 'static>> {
        /// Returns the text of a panic payload, if it carries one. `panic!` with format
        /// arguments produces a `String` payload; a bare literal produces `&'static str`.
        fn panic_message(payload: &(dyn std::any::Any + Send)) -> Option<&str> {
            match payload.downcast_ref::<String>() {
                Some(message) => Some(message.as_str()),
                None => payload.downcast_ref::<&'static str>().copied(),
            }
        }

        let messages: Vec<(String, Message)> =
            task.make_dstributed_generation_prompts(target, additional_instructions);
        let distributed_tasks_results: Vec<(String, String)> = std::thread::scope(
            |scope| -> Result<
                Vec<(String, String)>,
                Box<dyn std::error::Error + Send + Sync + 'static>,
            > {
                let mut workers = Vec::with_capacity(messages.len());
                for (field_name, message) in messages {
                    // Keep a copy of the field name for error reporting; the original is
                    // moved into the worker.
                    let label = field_name.clone();
                    let worker = scope.spawn(move || {
                        let content: String = extract_text_content_from_llm_response(
                            &self.send_message(message, false)?,
                        )?;
                        Ok::<(String, String), Box<dyn std::error::Error + Send + Sync + 'static>>((
                            field_name,
                            extract_result_content(&cleanup_thinking_blocks(content)),
                        ))
                    });
                    workers.push((label, worker));
                }

                // Join every worker before inspecting any outcome. Returning early would leave
                // later handles unjoined, and `thread::scope` panics if an unjoined thread panicked.
                let mut outcomes = Vec::with_capacity(workers.len());
                for (label, worker) in workers {
                    outcomes.push((label, worker.join()));
                }

                let mut collected = Vec::with_capacity(outcomes.len());
                for (label, outcome) in outcomes {
                    match outcome {
                        Ok(result) => collected.push(result?),
                        Err(payload) => {
                            let reason =
                                panic_message(&*payload).unwrap_or("non-string panic payload");
                            return Err(format!(
                                "generation thread for field `{label}` panicked: {reason}"
                            )
                            .into());
                        }
                    }
                }
                Ok(collected)
            },
        )?;

        // `generate_from_tuples!` reports deserialization failures by panicking; convert that
        // into an error. The default panic hook still prints the panic to stderr.
        match panic::catch_unwind(|| generate_from_tuples!(T, distributed_tasks_results)) {
            Ok(result) => Ok(result),
            Err(panic_payload) => {
                let error_msg = match panic_message(&*panic_payload) {
                    Some(message) if message.contains("Failed to deserialize") => {
                        message.to_string()
                    }
                    _ => "Field deserialization failed with unknown error".to_string(),
                };
                Err(Box::new(SecretaryError::JsonParsingError(error_msg)))
            }
        }
    }
}
/// Asynchronous helpers that turn an [`IsLLM`] backend into typed [`Task`] values,
/// built on [`IsLLM::async_send_message`]. Every method has a default implementation.
#[async_trait]
pub trait AsyncGenerateData
where
    Self: IsLLM,
{
    /// Requests JSON output (`return_json = true`) for `task` and deserializes it into `T`.
    ///
    /// # Errors
    /// - A transport failure from `async_send_message` is converted to
    ///   `SecretaryError::BuildRequestError` holding the error's text.
    /// - Errors from `extract_text_content_from_llm_response` are propagated.
    /// - Returns `SecretaryError::SerdeJsonError` if the text is not valid JSON for `T`.
    async fn async_generate_data<T: Task + Sync + Send>(
        &self,
        task: &T,
        target: &str,
        additional_instructions: &Vec<String>,
    ) -> Result<T, Box<dyn std::error::Error + Send + Sync + 'static>> {
        let request: Result<String, Box<dyn std::error::Error + Send + Sync>> = self
            .async_send_message(task.make_prompt(target, additional_instructions), true)
            .await;
        let result = match request {
            Ok(result) => extract_text_content_from_llm_response(&result)?,
            Err(error) => return Err(SecretaryError::BuildRequestError(error.to_string()).into()),
        };
        match serde_json::from_str::<T>(&result) {
            Ok(result) => Ok(result),
            Err(error) => Err(Box::new(SecretaryError::SerdeJsonError(error))),
        }
    }

    /// Requests free-form output (`return_json = false`) and parses `T` out of it with
    /// `surfing::serde::from_mixed_text`. That parser presumably tolerates text around the
    /// JSON; confirm against the `surfing` docs.
    ///
    /// # Errors
    /// - A transport failure becomes `SecretaryError::BuildRequestError`.
    /// - Extraction errors are propagated.
    /// - A parse failure becomes `SecretaryError::JsonParsingError`.
    async fn async_force_generate_data<T: Task + Sync + Send>(
        &self,
        task: &T,
        target: &str,
        additional_instructions: &Vec<String>,
    ) -> Result<T, Box<dyn std::error::Error + Send + Sync + 'static>> {
        let request: Result<String, Box<dyn std::error::Error + Send + Sync>> = self
            .async_send_message(task.make_prompt(target, additional_instructions), false)
            .await;
        let result: String = match request {
            Ok(result) => extract_text_content_from_llm_response(&result)?,
            Err(error) => return Err(SecretaryError::BuildRequestError(error.to_string()).into()),
        };
        match surfing::serde::from_mixed_text(&result) {
            Ok(result) => Ok(result),
            Err(error) => Err(Box::new(SecretaryError::JsonParsingError(
                error.to_string(),
            ))),
        }
    }

    /// Generates each field of `T` with its own concurrent request, then assembles `T`.
    ///
    /// Requests use `return_json = false`. Each reply is cleaned with
    /// `cleanup_thinking_blocks` and `extract_result_content`. The resulting
    /// `(field_name, value)` pairs are combined into `T` by `generate_from_tuples!`.
    ///
    /// # Errors
    /// - The first request or extraction error is returned as-is (not wrapped), and the
    ///   remaining requests are dropped.
    /// - A panic while assembling `T` becomes `SecretaryError::JsonParsingError`. It carries
    ///   the panic message if that message contains "Failed to deserialize", and a generic
    ///   message otherwise.
    async fn async_fields_generate_data<T: Task + Sync + Send>(
        &self,
        task: &T,
        target: &str,
        additional_instructions: &Vec<String>,
    ) -> Result<T, Box<dyn std::error::Error + Send + Sync + 'static>> {
        let messages: Vec<(String, Message)> =
            task.make_dstributed_generation_prompts(target, additional_instructions);
        let mut distributed_tasks = Vec::with_capacity(messages.len());
        for (field_name, message) in messages {
            distributed_tasks.push(async move {
                let content: String = extract_text_content_from_llm_response(
                    &self.async_send_message(message, false).await?,
                )?;
                Ok::<(String, String), Box<dyn std::error::Error + Send + Sync>>((
                    field_name,
                    extract_result_content(&cleanup_thinking_blocks(content)),
                ))
            });
        }
        // `try_join_all` polls every request concurrently. It returns the first error and
        // cancels the rest; on success the results keep the input (field) order.
        let distributed_tasks_results: Vec<(String, String)> =
            future::try_join_all(distributed_tasks).await?;

        // `generate_from_tuples!` reports deserialization failures by panicking; convert that
        // into an error. The default panic hook still prints the panic to stderr.
        match panic::catch_unwind(|| generate_from_tuples!(T, distributed_tasks_results)) {
            Ok(result) => Ok(result),
            Err(panic_payload) => {
                // `panic!` with format arguments carries a `String` payload; a bare literal
                // carries a `&'static str`. Accept both so the message is not lost.
                let panic_message: Option<&str> = match panic_payload.downcast_ref::<String>() {
                    Some(message) => Some(message.as_str()),
                    None => panic_payload.downcast_ref::<&'static str>().copied(),
                };
                let error_msg = match panic_message {
                    Some(message) if message.contains("Failed to deserialize") => {
                        message.to_string()
                    }
                    _ => "Field deserialization failed with unknown error".to_string(),
                };
                Err(Box::new(SecretaryError::JsonParsingError(error_msg)))
            }
        }
    }
}