use anyhow::{anyhow, Context, Result};
use log::{error, info};
use reqwest::{
header::{self, HeaderMap, HeaderValue},
multipart, Client,
};
use schemars::{schema_for, JsonSchema};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::path::Path;
use std::time::Duration;
use tokio::time;
use tokio::time::timeout;
use super::OpenAIModels;
use crate::constants::{OPENAI_API_URL, OPENAI_ASSISTANT_INSTRUCTIONS};
use crate::domain::{
OpenAIAssistantResp, OpenAIMessageListResp, OpenAIMessageResp, OpenAIRunResp, OpenAIThreadResp,
};
use crate::enums::{OpenAIAssistantRole, OpenAIRunStatus};
use crate::utils::remove_json_wrapper;
/// Client-side state for one conversation with the OpenAI Assistants API.
///
/// The remote identifiers (`id`, `thread_id`, `run_id`) start out as `None` and are filled in
/// as the corresponding API calls succeed.
///
/// NOTE(review): the derived `Debug` and `Serialize` implementations include `api_key`
/// verbatim; take care when logging or persisting values of this type.
#[deprecated(
since = "0.6.1",
note = "This struct is deprecated. Please use the `assistants::OpenAIAssistant` struct for latest functionality including Assistants API v2."
)]
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct OpenAIAssistant {
// Identifier of the remote assistant; set from the response of `create_assistant`.
id: Option<String>,
// Identifier of the conversation thread; set from the response of `create_thread`.
thread_id: Option<String>,
// Identifier of the most recently started run; set from the response of `start_run`.
run_id: Option<String>,
// Model sent as the `model` field when the assistant is created.
model: OpenAIModels,
// Text sent as the `instructions` field when the assistant is created.
instructions: String,
// When true, every raw API response is logged through `info!`.
debug: bool,
// Key sent as the bearer token on every request.
api_key: String,
// Selects the `OpenAI-Beta` header, the tools payload and the attachment format.
version: OpenAIAssistantVersion,
}
impl OpenAIAssistant {
pub async fn new(model: OpenAIModels, open_ai_key: &str, debug: bool) -> Result<Self> {
Ok(OpenAIAssistant {
id: None,
thread_id: None,
run_id: None,
model,
instructions: OPENAI_ASSISTANT_INSTRUCTIONS.to_string(),
debug,
api_key: open_ai_key.to_string(),
version: OpenAIAssistantVersion::V1,
})
}
pub fn version(mut self, version: OpenAIAssistantVersion) -> Self {
self.version = version;
self
}
async fn create_assistant(&mut self) -> Result<()> {
let assistant_url = format!("{}/assistants", self.version.get_endpoint());
let version_headers = self.version.get_headers();
let tools_payload = self.version.get_tools_payload();
let assistant_body = json!({
"instructions": self.instructions.clone(),
"model": self.model.as_str(),
"tools": tools_payload,
});
let client = Client::new();
let response = client
.post(assistant_url)
.headers(version_headers)
.bearer_auth(&self.api_key)
.json(&assistant_body)
.send()
.await?;
let response_status = response.status();
let response_text = response.text().await?;
if self.debug {
info!(
"[debug] OpenAI Assistant API response: [{}] {:#?}",
&response_status, &response_text
);
}
let response_deser: OpenAIAssistantResp =
serde_json::from_str(&response_text).map_err(|error| {
error!(
"[OpenAIAssistant] Assistant API response serialization error: {}",
&error
);
anyhow!("Error: {}", error)
})?;
self.id = Some(response_deser.id);
Ok(())
}
pub async fn get_answer<T: JsonSchema + DeserializeOwned>(
mut self,
message: &str,
file_ids: &[String],
) -> Result<T> {
if self.id.is_none() {
self.create_assistant().await?;
self.add_message(OPENAI_ASSISTANT_INSTRUCTIONS, &Vec::new())
.await?;
}
let schema = schema_for!(T);
let schema_json: Value = serde_json::to_value(&schema)?;
let schema_string = serde_json::to_string(&schema_json).unwrap_or_default();
let schema_message = format!(
"Response should include only the data portion of a Json formatted as per the following schema: {}.
The response should only include well-formatted data, and not the schema itself.
Do not include any other words or characters, including the word 'json'. Only respond with the data.
You need to validate the Json before returning.",
schema_string
);
self.add_message(&schema_message, &Vec::new()).await?;
self.add_message(message, file_ids).await?;
self.start_run().await?;
let operation_timeout = Duration::from_secs(600); let poll_interval = Duration::from_secs(10);
let _result = timeout(operation_timeout, async {
let mut interval = time::interval(poll_interval);
loop {
interval.tick().await; match self.get_run_status().await {
Ok(resp) => match resp.status {
OpenAIRunStatus::Completed => {
break Ok(());
}
OpenAIRunStatus::RequiresAction
| OpenAIRunStatus::Cancelling
| OpenAIRunStatus::Cancelled
| OpenAIRunStatus::Failed
| OpenAIRunStatus::Expired => {
return Err(anyhow!("Failed to validate status of the run"));
}
_ => continue, },
Err(e) => return Err(e), }
}
})
.await?;
let messages = self.get_message_thread().await?;
messages
.into_iter()
.filter(|message| message.role == OpenAIAssistantRole::Assistant)
.find_map(|message| {
message.content.into_iter().find_map(|content| {
content.text.and_then(|text| {
let sanitized_text = remove_json_wrapper(&text.value);
serde_json::from_str::<T>(&sanitized_text).ok()
})
})
})
.ok_or(anyhow!("No valid response form OpenAI Assistant found."))
}
pub async fn set_context<T: Serialize>(mut self, dataset_name: &str, data: &T) -> Result<Self> {
if self.id.is_none() {
self.create_assistant().await?;
self.add_message(OPENAI_ASSISTANT_INSTRUCTIONS, &Vec::new())
.await?;
}
let serialized_data = if let Ok(json) = serde_json::to_string(&data) {
json
} else {
return Err(anyhow!("Unable serialize provided input data."));
};
let message = format!("'{dataset_name}'= {serialized_data}");
let file_ids = Vec::new();
self.add_message(&message, &file_ids).await?;
Ok(self)
}
async fn add_message(&mut self, message: &str, file_ids: &[String]) -> Result<()> {
let mut message = json!({
"role": "user",
"content": message.to_string(),
});
if !file_ids.is_empty() {
message = self.version.add_message_attachments(&message, file_ids);
}
match self.thread_id {
None => {
let body = json!({
"messages": vec![message],
});
self.create_thread(&body).await
}
Some(_) => self.add_message_thread(&message).await,
}
}
async fn create_thread(&mut self, body: &serde_json::Value) -> Result<()> {
let thread_url = format!("{}/threads", self.version.get_endpoint());
let version_headers = self.version.get_headers();
let client = Client::new();
let response = client
.post(thread_url)
.headers(version_headers)
.bearer_auth(&self.api_key)
.json(&body)
.send()
.await?;
let response_status = response.status();
let response_text = response.text().await?;
if self.debug {
info!(
"[debug] OpenAI Threads API response: [{}] {:#?}",
&response_status, &response_text
);
}
let response_deser: OpenAIThreadResp =
serde_json::from_str(&response_text).map_err(|error| {
error!(
"[OpenAIAssistant] Thread API response serialization error: {}",
&error
);
anyhow!("Error: {}", error)
})?;
self.thread_id = Some(response_deser.id);
Ok(())
}
async fn add_message_thread(&self, body: &serde_json::Value) -> Result<()> {
if self.thread_id.is_none() {
return Err(anyhow!("No active thread detected."));
}
let message_url = format!(
"{}/threads/{}/messages",
self.version.get_endpoint(),
self.thread_id.clone().unwrap_or_default(),
);
let version_headers = self.version.get_headers();
let client = Client::new();
let response = client
.post(message_url)
.headers(version_headers)
.bearer_auth(&self.api_key)
.json(&body)
.send()
.await?;
let response_status = response.status();
let response_text = response.text().await?;
if self.debug {
info!(
"[debug] OpenAI Messages API response: [{}] {:#?}",
&response_status, &response_text
);
}
let _response_deser: OpenAIMessageResp =
serde_json::from_str(&response_text).map_err(|error| {
error!(
"[OpenAIAssistant] Messages API response serialization error: {}",
&error
);
anyhow!("Error: {}", error)
})?;
Ok(())
}
async fn get_message_thread(&self) -> Result<Vec<OpenAIMessageResp>> {
if self.thread_id.is_none() {
return Err(anyhow!("No active thread detected."));
}
let message_url = format!(
"{}/threads/{}/messages",
self.version.get_endpoint(),
self.thread_id.clone().unwrap_or_default(),
);
let version_headers = self.version.get_headers();
let client = Client::new();
let response = client
.get(message_url)
.headers(version_headers)
.bearer_auth(&self.api_key)
.send()
.await?;
let response_status = response.status();
let response_text = response.text().await?;
if self.debug {
info!(
"[debug] OpenAI Messages API response: [{}] {:#?}",
&response_status, &response_text
);
}
let response_deser: OpenAIMessageListResp =
serde_json::from_str(&response_text).map_err(|error| {
error!(
"[OpenAIAssistant] Messages API response serialization error: {}",
&error
);
anyhow!("Error: {}", error)
})?;
Ok(response_deser.data)
}
async fn start_run(&mut self) -> Result<()> {
let assistant_id = if let Some(id) = self.id.clone() {
id
} else {
return Err(anyhow!("No active assistant detected."));
};
let thread_id = if let Some(id) = self.thread_id.clone() {
id
} else {
return Err(anyhow!("No active thread detected."));
};
let run_url = format!("{}/threads/{}/runs", self.version.get_endpoint(), thread_id,);
let version_headers = self.version.get_headers();
let body = json!({
"assistant_id": assistant_id,
});
let client = Client::new();
let response = client
.post(run_url)
.headers(version_headers)
.bearer_auth(&self.api_key)
.json(&body)
.send()
.await?;
let response_status = response.status();
let response_text = response.text().await?;
if self.debug {
info!(
"[debug] OpenAI Messages API response: [{}] {:#?}",
&response_status, &response_text
);
}
let response_deser: OpenAIRunResp =
serde_json::from_str(&response_text).map_err(|error| {
error!(
"[OpenAIAssistant] Run API response serialization error: {}",
&error
);
anyhow!("Error: {}", error)
})?;
self.run_id = Some(response_deser.id);
Ok(())
}
async fn get_run_status(&self) -> Result<OpenAIRunResp> {
let thread_id = if let Some(id) = self.thread_id.clone() {
id
} else {
return Err(anyhow!("No active thread detected."));
};
let run_id = if let Some(id) = self.run_id.clone() {
id
} else {
return Err(anyhow!("No active run detected."));
};
let run_url = format!(
"{}/threads/{}/runs/{}",
self.version.get_endpoint(),
thread_id,
run_id,
);
let version_headers = self.version.get_headers();
let client = Client::new();
let response = client
.get(run_url)
.headers(version_headers)
.bearer_auth(&self.api_key)
.send()
.await?;
let response_status = response.status();
let response_text = response.text().await?;
if self.debug {
info!(
"[debug] OpenAI Run status API response: [{}] {:#?}",
&response_status, &response_text
);
}
let response_deser: OpenAIRunResp =
serde_json::from_str(&response_text).map_err(|error| {
error!(
"[OpenAIAssistant] Run API response serialization error: {}",
&error
);
anyhow!("Error: {}", error)
})?;
Ok(response_deser)
}
}
/// Version of the OpenAI Assistants API that requests are issued against.
///
/// The variant decides the `OpenAI-Beta` header, the tool type declared when the assistant is
/// created, and how file ids are attached to a message.
#[deprecated(
since = "0.6.1",
note = "This struct is deprecated. Please use the `assistants::OpenAIAssistantVersion` struct for latest functionality including Assistants API v2+."
)]
#[derive(Deserialize, Serialize, Debug, Clone)]
pub enum OpenAIAssistantVersion {
/// Sends `OpenAI-Beta: assistants=v1`, uses the `retrieval` tool and a `file_ids` array.
V1,
/// Sends `OpenAI-Beta: assistants=v2`, uses the `file_search` tool and `attachments`.
V2,
}
impl OpenAIAssistantVersion {
pub(crate) fn get_endpoint(&self) -> String {
match self {
OpenAIAssistantVersion::V1 | OpenAIAssistantVersion::V2 => {
format!("{OPENAI_API_URL}/v1", OPENAI_API_URL = *OPENAI_API_URL)
}
}
}
pub(crate) fn get_headers(&self) -> HeaderMap {
let mut headers = HeaderMap::new();
headers.insert(
header::CONTENT_TYPE,
HeaderValue::from_static("application/json"),
);
match self {
OpenAIAssistantVersion::V1 => {
headers.insert("OpenAI-Beta", HeaderValue::from_static("assistants=v1"))
}
OpenAIAssistantVersion::V2 => {
headers.insert("OpenAI-Beta", HeaderValue::from_static("assistants=v2"))
}
};
headers
}
pub(crate) fn get_tools_payload(&self) -> Vec<Value> {
match self {
OpenAIAssistantVersion::V1 => vec![json!({
"type": "retrieval"
})],
OpenAIAssistantVersion::V2 => vec![json!({
"type": "file_search"
})],
}
}
pub(crate) fn add_message_attachments(
&self,
message_payload: &Value,
file_ids: &[String],
) -> Value {
let mut message_payload = message_payload.clone();
match self {
OpenAIAssistantVersion::V1 => {
message_payload["file_ids"] = json!(file_ids);
}
OpenAIAssistantVersion::V2 => {
let file_search_json = json!({
"type": "file_search"
});
let attachments_vec: Vec<Value> = file_ids
.iter()
.map(|file_id| {
json!({
"file_id": file_id.to_string(),
"tools": [file_search_json.clone()]
})
})
.collect();
message_payload["attachments"] = json!(attachments_vec);
}
}
message_payload
}
}
/// Handle to a file uploaded to OpenAI through the Files API.
///
/// NOTE(review): the derived `Debug` and `Serialize` implementations include `api_key`
/// verbatim; take care when logging or persisting values of this type.
#[deprecated(
since = "0.6.1",
note = "This struct is deprecated. Please use the `assistants::OpenAIFile` struct for latest functionality."
)]
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct OpenAIFile {
/// Identifier of the uploaded file, taken from the upload response.
pub id: String,
// When true, every raw API response is logged through `info!`.
debug: bool,
// Key sent as the bearer token on every request.
api_key: String,
}
/// Part of the file-upload response that this module reads.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct OpenAIFileResp {
// Identifier assigned to the uploaded file; copied into `OpenAIFile::id`.
id: String,
}
/// Response body deserialized by `OpenAIFile::delete_file`.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct OpenAIDFileDeleteResp {
// Identifier echoed back in the response; not read in this file.
id: String,
// Must be present for deserialization to succeed; not otherwise read in this file.
object: String,
// `delete_file` returns `Ok` only when this flag is true.
deleted: bool,
}
impl OpenAIFile {
pub async fn new(
file_name: &str,
file_bytes: Vec<u8>,
open_ai_key: &str,
debug: bool,
) -> Result<Self> {
let mut new_file = OpenAIFile {
id: "this-will-be-overwritten".to_string(),
debug,
api_key: open_ai_key.to_string(),
};
new_file.upload_file(file_name, file_bytes).await?;
Ok(new_file)
}
async fn upload_file(&mut self, file_name: &str, file_bytes: Vec<u8>) -> Result<()> {
let files_url = "https://api.openai.com/v1/files";
let mime_type = match Path::new(file_name)
.extension()
.and_then(std::ffi::OsStr::to_str)
{
Some("pdf") => "application/pdf",
Some("json") => "application/json",
Some("txt") => "text/plain",
Some("html") => "text/html",
Some("c") => "text/x-c",
Some("cpp") => "text/x-c++",
Some("docx") => {
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
}
Some("java") => "text/x-java",
Some("md") => "text/markdown",
Some("php") => "text/x-php",
Some("pptx") => {
"application/vnd.openxmlformats-officedocument.presentationml.presentation"
}
Some("py") => "text/x-python",
Some("rb") => "text/x-ruby",
Some("tex") => "text/x-tex",
Some("css") => "text/css",
Some("jpeg") | Some("jpg") => "image/jpeg",
Some("js") => "text/javascript",
Some("gif") => "image/gif",
Some("png") => "image/png",
Some("tar") => "application/x-tar",
Some("ts") => "application/typescript",
Some("xlsx") => "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
Some("xml") => "application/xml",
Some("zip") => "application/zip",
_ => anyhow::bail!("Unsupported file type"),
};
let form = multipart::Form::new().text("purpose", "assistants").part(
"file",
multipart::Part::bytes(file_bytes)
.file_name(file_name.to_string())
.mime_str(mime_type)
.context("Failed to set MIME type")?,
);
let client = Client::new();
let response = client
.post(files_url)
.header(header::CONTENT_TYPE, "application/json")
.header("OpenAI-Beta", "assistants=v1")
.bearer_auth(&self.api_key)
.multipart(form)
.send()
.await?;
let response_status = response.status();
let response_text = response.text().await?;
if self.debug {
info!(
"[debug] OpenAI Files status API response: [{}] {:#?}",
&response_status, &response_text
);
}
let response_deser: OpenAIFileResp =
serde_json::from_str(&response_text).map_err(|error| {
error!(
"[OpenAIAssistant] Files API response serialization error: {}",
&error
);
anyhow!("Error: {}", error)
})?;
self.id = response_deser.id;
Ok(())
}
pub async fn delete_file(&self) -> Result<()> {
let files_url = format!("https://api.openai.com/v1/files/{}", self.id);
let client = Client::new();
let response = client
.delete(files_url)
.bearer_auth(&self.api_key)
.send()
.await?;
let response_status = response.status();
let response_text = response.text().await?;
if self.debug {
info!(
"[debug] OpenAI Files status API response: [{}] {:#?}",
&response_status, &response_text
);
}
serde_json::from_str::<OpenAIDFileDeleteResp>(&response_text)
.map_err(|error| {
error!(
"[OpenAIAssistant] Files Delete API response serialization error: {}",
&error
);
anyhow!(
"[OpenAIAssistant] Files Delete API response serialization error: {}",
error
)
})
.and_then(|response| match response.deleted {
true => Ok(()),
false => Err(anyhow!("[OpenAIAssistant] Failed to delete the file.")),
})
}
}