use anyhow::{anyhow, Result};
use jsonschema::JSONSchema;
use log::error;
use log::info;
use reqwest::Client;
use schemars::JsonSchema;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::time::Duration;
use tokio::time;
use tokio::time::timeout;
use crate::assistants::{OpenAIAssistantResource, OpenAIAssistantVersion, OpenAIVectorStore};
use crate::constants::{OPENAI_ASSISTANT_INSTRUCTIONS, OPENAI_ASSISTANT_POLL_FREQ};
use crate::domain::{
AllmsError, OpenAIAssistantResp, OpenAIMessageListResp, OpenAIMessageResp, OpenAIRunResp,
OpenAIThreadResp,
};
use crate::enums::{OpenAIAssistantRole, OpenAIRunStatus};
use crate::llm_models::{LLMModel, OpenAIModels};
use crate::utils::get_type_schema;
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct OpenAIAssistant {
id: Option<String>,
thread_id: Option<String>,
run_id: Option<String>,
model: OpenAIModels,
instructions: String,
debug: bool,
api_key: String,
version: OpenAIAssistantVersion,
vector_store: Option<OpenAIVectorStore>,
temperature: f32,
poll_interval: usize,
}
impl OpenAIAssistant {
pub fn new(model: OpenAIModels, open_ai_key: &str) -> Self {
let model = match model.assistants_support() {
false => OpenAIModels::Gpt4o,
_ => model,
};
OpenAIAssistant {
id: None,
thread_id: None,
run_id: None,
temperature: model.get_default_temperature(),
model,
instructions: OPENAI_ASSISTANT_INSTRUCTIONS.to_string(),
debug: false,
api_key: open_ai_key.to_string(),
version: OpenAIAssistantVersion::V1,
vector_store: None,
poll_interval: OPENAI_ASSISTANT_POLL_FREQ,
}
}
pub fn debug(mut self) -> Self {
self.debug = true;
self
}
pub fn version(mut self, version: OpenAIAssistantVersion) -> Self {
self.version = version;
self
}
pub fn temperature(mut self, temp_target: u32) -> Self {
self.temperature = self.model.get_normalized_temperature(temp_target);
self
}
pub fn temperature_unchecked(mut self, temp: f32) -> Self {
self.temperature = temp;
self
}
pub fn poll_interval(mut self, poll_interval: usize) -> Self {
self.poll_interval = poll_interval;
self
}
async fn create_assistant(&mut self) -> Result<()> {
let assistant_url = self
.version
.get_endpoint(&OpenAIAssistantResource::Assistants);
if self.debug {
info!(
"[debug] OpenAI Assistant Create API URL: {:#?}",
assistant_url
);
}
let version_headers = self.version.get_headers(&self.api_key);
let mut assistant_body = json!({
"instructions": self.instructions.clone(),
"model": self.model.as_str(),
"temperature": self.temperature,
});
if self.model.tools_support() {
if let Some(assistant_body_object) = assistant_body.as_object_mut() {
let tools_payload = self.version.get_tools_payload();
assistant_body_object.insert("tools".to_string(), tools_payload);
}
}
let client = Client::new();
let response = client
.post(assistant_url)
.headers(version_headers)
.json(&assistant_body)
.send()
.await?;
let response_status = response.status();
let response_text = response.text().await?;
if self.debug {
info!(
"[debug] OpenAI Assistant API response: [{}] {:#?}",
&response_status, &response_text
);
}
let response_deser: OpenAIAssistantResp =
serde_json::from_str(&response_text).map_err(|error| {
let error = AllmsError {
crate_name: "allms".to_string(),
module: "assistants::openai_assistant".to_string(),
error_message: format!("Assistant API response serialization error: {}", error),
error_detail: response_text,
};
error!("{:?}", error);
anyhow!("{:?}", error)
})?;
self.id = Some(response_deser.id);
Ok(())
}
pub async fn get_answer<T: JsonSchema + DeserializeOwned>(
&mut self,
message: &str,
file_ids: &[String],
) -> Result<T> {
let schema_string = get_type_schema::<T>()?;
let assistant_response = self
.call_assistant(&schema_string, message, file_ids)
.await?;
serde_json::from_str::<T>(&assistant_response).map_err(|e| {
let error = AllmsError {
crate_name: "alms".to_string(),
module: "assistants::openai_assistant".to_string(),
error_message: format!("Deserialization error: {:?}", e),
error_detail: assistant_response,
};
anyhow!("{:?}", error)
})
}
pub async fn get_json_answer(
&mut self,
message: &str,
json_schema: &str,
file_ids: &[String],
) -> Result<Value> {
let assistant_response = self.call_assistant(json_schema, message, file_ids).await?;
self.get_valid_json(json_schema, &assistant_response)
}
async fn call_assistant(
&mut self,
json_schema: &str,
message: &str,
file_ids: &[String],
) -> Result<String> {
if self.id.is_none() {
self.create_assistant().await?;
self.add_message(OPENAI_ASSISTANT_INSTRUCTIONS, &Vec::new())
.await?;
}
let schema_message = format!(
"Response should include only the data portion of a Json formatted as per the following schema: {}.
The response should only include well-formatted data, and not the schema itself.
Do not include any other words or characters, including the word 'json'. Only respond with the data.
You need to validate the Json before returning.",
json_schema
);
self.add_message(&schema_message, &Vec::new()).await?;
self.add_message(message, file_ids).await?;
self.start_run().await?;
let operation_timeout = Duration::from_secs(600); let poll_interval = Duration::from_secs(self.poll_interval as u64);
let _result = timeout(operation_timeout, async {
let mut interval = time::interval(poll_interval);
loop {
interval.tick().await; match self.get_run_status().await {
Ok(resp) => match resp.status {
OpenAIRunStatus::Completed => {
break Ok(());
}
OpenAIRunStatus::RequiresAction
| OpenAIRunStatus::Cancelling
| OpenAIRunStatus::Cancelled
| OpenAIRunStatus::Failed
| OpenAIRunStatus::Expired => {
return Err(anyhow!("Failed to validate status of the run"));
}
_ => continue, },
Err(e) => return Err(e), }
}
})
.await?;
let messages = self.get_message_thread().await?;
messages
.iter()
.filter(|message| message.role == OpenAIAssistantRole::Assistant)
.find_map(|message| {
message.content.iter().find_map(|content| {
content
.text
.as_ref()
.map(|text| self.model.sanitize_json_response(&text.value))
})
})
.ok_or_else(|| {
let error = AllmsError {
crate_name: "allms".to_string(),
module: "assistants::openai_assistant".to_string(),
error_message: "No valid response from OpenAI Assistant found.".to_string(),
error_detail: format!("{:?}", &messages),
};
error!("{:?}", error);
anyhow!("{:?}", error)
})
}
fn get_valid_json(&self, schema: &str, value: &str) -> Result<Value> {
let schema_value = serde_json::from_str(schema).map_err(|e| {
let error = AllmsError {
crate_name: "alms".to_string(),
module: "assistants::openai_assistant".to_string(),
error_message: format!("Json Schema parsing error: {:?}", e),
error_detail: format!("Schema: {:?}", schema),
};
anyhow!("{:?}", error)
})?;
let compiled_schema = JSONSchema::compile(&schema_value).map_err(|e| {
let error = AllmsError {
crate_name: "alms".to_string(),
module: "assistants::openai_assistant".to_string(),
error_message: format!("Json Schema compilation error: {:?}", e),
error_detail: format!("Schema: {:?}", schema_value),
};
anyhow!("{:?}", error)
})?;
let data_value = serde_json::from_str(value).map_err(|e| {
let error = AllmsError {
crate_name: "alms".to_string(),
module: "assistants::openai_assistant".to_string(),
error_message: format!("Json data parsing error: {:?}", e),
error_detail: format!("Data: {:?}", value),
};
anyhow!("{:?}", error)
})?;
compiled_schema.validate(&data_value).map_err(|_| {
let error = AllmsError {
crate_name: "alms".to_string(),
module: "assistants::openai_assistant".to_string(),
error_message: "Json Schema validation error".to_string(),
error_detail: format!("Data: {:?}\nSchema: {:?}", &data_value, &schema_value),
};
anyhow!("{:?}", error)
})?;
Ok(data_value)
}
pub async fn set_context<T: Serialize>(mut self, dataset_name: &str, data: &T) -> Result<Self> {
if self.id.is_none() {
self.create_assistant().await?;
self.add_message(OPENAI_ASSISTANT_INSTRUCTIONS, &Vec::new())
.await?;
}
let serialized_data = if let Ok(json) = serde_json::to_string(&data) {
json
} else {
return Err(anyhow!("Unable serialize provided input data."));
};
let message = format!("'{dataset_name}'= {serialized_data}");
let file_ids = Vec::new();
self.add_message(&message, &file_ids).await?;
Ok(self)
}
async fn add_message(&mut self, message: &str, file_ids: &[String]) -> Result<()> {
let mut message = json!({
"role": "user",
"content": message.to_string(),
});
if !file_ids.is_empty() {
message = self.version.add_message_attachments(&message, file_ids);
}
match self.thread_id {
None => {
let body = json!({
"messages": vec![message],
});
self.create_thread(&body).await
}
Some(_) => self.add_message_thread(&message).await,
}
}
async fn create_thread(&mut self, body: &serde_json::Value) -> Result<()> {
let thread_url = self.version.get_endpoint(&OpenAIAssistantResource::Threads);
let version_headers = self.version.get_headers(&self.api_key);
let client = Client::new();
let response = client
.post(thread_url)
.headers(version_headers)
.json(&body)
.send()
.await?;
let response_status = response.status();
let response_text = response.text().await?;
if self.debug {
info!(
"[debug] OpenAI Threads API response: [{}] {:#?}",
&response_status, &response_text
);
}
let response_deser: OpenAIThreadResp =
serde_json::from_str(&response_text).map_err(|error| {
let error = AllmsError {
crate_name: "allms".to_string(),
module: "assistants::openai_assistant".to_string(),
error_message: format!("Thread API response serialization error: {}", error),
error_detail: response_text,
};
error!("{:?}", error);
anyhow!("{:?}", error)
})?;
self.thread_id = Some(response_deser.id);
Ok(())
}
async fn add_message_thread(&self, body: &serde_json::Value) -> Result<()> {
if self.thread_id.is_none() {
return Err(anyhow!("No active thread detected."));
}
let messages_resource = OpenAIAssistantResource::Messages {
thread_id: self.thread_id.clone().unwrap_or_default(),
};
let message_url = self.version.get_endpoint(&messages_resource);
let version_headers = self.version.get_headers(&self.api_key);
let client = Client::new();
let response = client
.post(message_url)
.headers(version_headers)
.json(&body)
.send()
.await?;
let response_status = response.status();
let response_text = response.text().await?;
if self.debug {
info!(
"[debug] OpenAI Messages API response: [{}] {:#?}",
&response_status, &response_text
);
}
let _response_deser: OpenAIMessageResp =
serde_json::from_str(&response_text).map_err(|error| {
let error = AllmsError {
crate_name: "allms".to_string(),
module: "assistants::openai_assistant".to_string(),
error_message: format!("Messages API response serialization error: {}", error),
error_detail: response_text,
};
error!("{:?}", error);
anyhow!("{:?}", error)
})?;
Ok(())
}
async fn get_message_thread(&self) -> Result<Vec<OpenAIMessageResp>> {
if self.thread_id.is_none() {
return Err(anyhow!("No active thread detected."));
}
let message_resource = OpenAIAssistantResource::Messages {
thread_id: self.thread_id.clone().unwrap_or_default(),
};
let message_url = self.version.get_endpoint(&message_resource);
let version_headers = self.version.get_headers(&self.api_key);
let client = Client::new();
let response = client
.get(message_url)
.headers(version_headers)
.send()
.await?;
let response_status = response.status();
let response_text = response.text().await?;
if self.debug {
info!(
"[debug] OpenAI Messages API response: [{}] {:#?}",
&response_status, &response_text
);
}
let response_deser: OpenAIMessageListResp =
serde_json::from_str(&response_text).map_err(|error| {
let error = AllmsError {
crate_name: "allms".to_string(),
module: "assistants::openai_assistant".to_string(),
error_message: format!("Messages API response serialization error: {}", error),
error_detail: response_text,
};
error!("{:?}", error);
anyhow!("{:?}", error)
})?;
Ok(response_deser.data)
}
async fn start_run(&mut self) -> Result<()> {
let assistant_id = if let Some(id) = self.id.clone() {
id
} else {
return Err(anyhow!("No active assistant detected."));
};
let thread_id = if let Some(id) = self.thread_id.clone() {
id
} else {
return Err(anyhow!("No active thread detected."));
};
let run_resource = OpenAIAssistantResource::Runs { thread_id };
let run_url = self.version.get_endpoint(&run_resource);
let version_headers = self.version.get_headers(&self.api_key);
let body = json!({
"assistant_id": assistant_id,
});
let client = Client::new();
let response = client
.post(run_url)
.headers(version_headers)
.json(&body)
.send()
.await?;
let response_status = response.status();
let response_text = response.text().await?;
if self.debug {
info!(
"[debug] OpenAI Messages API response: [{}] {:#?}",
&response_status, &response_text
);
}
let response_deser: OpenAIRunResp =
serde_json::from_str(&response_text).map_err(|error| {
let error = AllmsError {
crate_name: "allms".to_string(),
module: "assistants::openai_assistant".to_string(),
error_message: format!("Run API response serialization error: {}", error),
error_detail: response_text,
};
error!("{:?}", error);
anyhow!("{:?}", error)
})?;
self.run_id = Some(response_deser.id);
Ok(())
}
async fn get_run_status(&self) -> Result<OpenAIRunResp> {
let thread_id = if let Some(id) = self.thread_id.clone() {
id
} else {
return Err(anyhow!("No active thread detected."));
};
let run_id = if let Some(id) = self.run_id.clone() {
id
} else {
return Err(anyhow!("No active run detected."));
};
let run_resource = OpenAIAssistantResource::Run { thread_id, run_id };
let run_url = self.version.get_endpoint(&run_resource);
let version_headers = self.version.get_headers(&self.api_key);
let client = Client::new();
let response = client.get(run_url).headers(version_headers).send().await?;
let response_status = response.status();
let response_text = response.text().await?;
if self.debug {
info!(
"[debug] OpenAI Run status API response: [{}] {:#?}",
&response_status, &response_text
);
}
let response_deser: OpenAIRunResp =
serde_json::from_str(&response_text).map_err(|error| {
let error = AllmsError {
crate_name: "allms".to_string(),
module: "assistants::openai_assistant".to_string(),
error_message: format!("Run API response serialization error: {}", error),
error_detail: response_text,
};
error!("{:?}", error);
anyhow!("{:?}", error)
})?;
Ok(response_deser)
}
pub async fn vector_store(&mut self, vector_store: OpenAIVectorStore) -> Result<Self> {
if self.version == OpenAIAssistantVersion::V1 {
return Err(anyhow!(
"[OpenAI][Assistants] OpenAI Assistants API v1 does not support Vector Store."
));
}
if vector_store.id.is_none() {
return Err(anyhow!(
"[OpenAI][Assistants] Unable to attach Vector Store. No valid ID found."
));
}
self.attach_vector_store(&vector_store).await?;
self.vector_store = Some(vector_store);
Ok(self.clone())
}
async fn attach_vector_store(&mut self, vector_store: &OpenAIVectorStore) -> Result<()> {
if self.id.is_none() {
self.create_assistant().await?;
self.add_message(OPENAI_ASSISTANT_INSTRUCTIONS, &Vec::new())
.await?;
}
let vector_store_id = if let Some(id) = &vector_store.id {
id.to_string()
} else {
return Err(anyhow!(
"[OpenAI][Assistants] Unable to attach Vector Store. No valid ID found."
));
};
let assistant_resource = OpenAIAssistantResource::Assistant {
assistant_id: self.id.clone().unwrap_or_default(),
};
let assistant_url = self.version.get_endpoint(&assistant_resource);
let version_headers = self.version.get_headers(&self.api_key);
let body = json!({
"tool_resources": {
"file_search": {
"vector_store_ids": vec![vector_store_id]
}
},
});
let client = Client::new();
let response = client
.post(assistant_url)
.headers(version_headers)
.json(&body)
.send()
.await?;
let response_status = response.status();
let response_text = response.text().await?;
if self.debug {
info!(
"[debug] OpenAI Vector Store Attach API response: [{}] {:#?}",
&response_status, &response_text
);
}
serde_json::from_str::<OpenAIAssistantResp>(&response_text)
.map_err(|error| {
let error = AllmsError {
crate_name: "allms".to_string(),
module: "assistants::openai_assistant".to_string(),
error_message: format!(
"Vector Store Attach API response serialization error: {}",
error
),
error_detail: response_text,
};
error!("{:?}", error);
anyhow!("{:?}", error)
})
.map(|_| Ok(()))?
}
}