use anyhow::{anyhow, Result};
use backoff::{future::retry, ExponentialBackoff};
use log::{error, info};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::time::Duration;
use crate::assistants::{OpenAIAssistantResource, OpenAIAssistantVersion};
use crate::domain::AllmsError;
/// Client-side handle for an OpenAI vector store.
///
/// A handle built with `id: None` gets its ID assigned when `upload` creates the store.
/// NOTE(review): `api_key` is part of the derived `Debug` and `Serialize` output, so
/// logging or persisting this struct exposes the key.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct OpenAIVectorStore {
/// Vector store ID returned by the API; `None` until the store has been created.
pub id: Option<String>,
/// Store name sent in the create request.
pub name: String,
/// API key passed to the version-specific header builder on every request.
api_key: String,
/// Last status recorded by this handle (starts as `InProgress`).
status: OpenAIVectorStoreStatus,
/// When true, request URLs and raw API responses are logged with `info!`.
debug: bool,
/// Assistants API version used to resolve endpoints and headers.
version: OpenAIAssistantVersion,
}
impl OpenAIVectorStore {
pub fn new(id: Option<String>, name: &str, api_key: &str) -> Self {
OpenAIVectorStore {
id,
name: name.to_string(),
api_key: api_key.to_string(),
status: OpenAIVectorStoreStatus::InProgress,
debug: false,
version: OpenAIAssistantVersion::V2,
}
}
pub fn debug(mut self) -> Self {
self.debug = !self.debug;
self
}
pub fn version(mut self, version: OpenAIAssistantVersion) -> Self {
let version = match version {
OpenAIAssistantVersion::V1 => OpenAIAssistantVersion::V2,
_ => version,
};
self.version = version;
self
}
async fn create(&mut self, file_ids: Option<Vec<String>>) -> Result<()> {
let vector_store_url = self
.version
.get_endpoint(&OpenAIAssistantResource::VectorStores);
if self.debug {
info!(
"[debug] OpenAI Vector Store Create API URL: {:#?}",
vector_store_url
);
}
let client = Client::new();
let version_headers = self.version.get_headers(&self.api_key);
let mut body = json!({
"name": self.name.clone(),
});
if let Some(ids) = file_ids {
body["file_ids"] = json!(ids.to_vec());
}
let response = client
.post(vector_store_url)
.headers(version_headers)
.json(&body)
.send()
.await?;
let response_status = response.status();
let response_text = response.text().await?;
if self.debug {
info!(
"[allms][OpenAI][VectorStore][debug] VectorStore Create API response: [{}] {:#?}",
&response_status, &response_text
);
}
let response_deser: OpenAIVectorStoreResp =
serde_json::from_str(&response_text).map_err(|error| {
let error = AllmsError {
crate_name: "allms".to_string(),
module: "assistants::openai_vector_store".to_string(),
error_message: format!(
"VectorStore Create API response serialization error: {}",
error
),
error_detail: response_text,
};
error!("{:?}", error);
anyhow!("{:?}", error)
})?;
self.id = Some(response_deser.id);
self.status = response_deser.status;
Ok(())
}
pub async fn upload(&mut self, file_ids: &[String]) -> Result<Self> {
if self.id.is_none() {
self.create(Some(file_ids.to_vec())).await?;
} else {
self.assign_to_store(file_ids).await?;
}
let backoff_config = ExponentialBackoff {
initial_interval: Duration::from_secs(1),
max_interval: Duration::from_secs(1),
max_elapsed_time: Some(Duration::from_secs(90)),
..Default::default()
};
retry(backoff_config.clone(), || async {
let status = self.status().await?;
if status == OpenAIVectorStoreStatus::Completed {
Ok(())
} else {
Err(backoff::Error::transient(anyhow!(
"[allms][OpenAI][Vector Store] Vector store failed to initialize in time."
)))
}
})
.await?;
retry(backoff_config, || async {
let file_count = self.file_count().await?;
if file_count.in_progress == 0 {
Ok(())
} else {
Err(backoff::Error::transient(anyhow!(
"[allms][OpenAI][Vector Store] Files processing did not complete in time."
)))
}
})
.await?;
Ok(self.clone())
}
async fn assign_to_store(&self, file_ids: &[String]) -> Result<()> {
let vs_id = if let Some(id) = &self.id {
id
} else {
return Err(anyhow!(
"[allms][OpenAI][VectorStore][debug] Unable to assign files. No ID provided."
));
};
let vector_store_resource = OpenAIAssistantResource::VectorStoreFileBatches {
vector_store_id: vs_id.to_string(),
};
let url = self.version.get_endpoint(&vector_store_resource);
let version_headers = self.version.get_headers(&self.api_key);
let client = Client::new();
let body = json!({
"file_ids": file_ids.to_vec(),
});
let response = client
.post(&url)
.headers(version_headers)
.json(&body)
.send()
.await?;
let response_status = response.status();
let response_text = response.text().await?;
if self.debug {
info!(
"[allms][OpenAI][VectorStore][debug] VectorStore Batch Upload API response: [{}] {:#?}",
&response_status, &response_text
);
}
serde_json::from_str::<OpenAIVectorStoreFileBatchResp>(&response_text)
.map_err(|error| {
let error = AllmsError {
crate_name: "allms".to_string(),
module: "assistants::openai_vector_store".to_string(),
error_message: format!(
"VectorStore Batch Upload API response serialization error: {}",
error
),
error_detail: response_text,
};
error!("{:?}", error);
anyhow!("{:?}", error)
})
.map(|_| Ok(()))?
}
pub async fn status(&self) -> Result<OpenAIVectorStoreStatus> {
let vs_id = if let Some(id) = &self.id {
id
} else {
return Err(anyhow!(
"[allms][OpenAI][VectorStore][debug] Unable to check status. No ID provided."
));
};
let vector_store_resource = OpenAIAssistantResource::VectorStore {
vector_store_id: vs_id.to_string(),
};
let url = self.version.get_endpoint(&vector_store_resource);
let version_headers = self.version.get_headers(&self.api_key);
let client = Client::new();
let response = client.get(&url).headers(version_headers).send().await?;
let response_status = response.status();
let response_text = response.text().await?;
if self.debug {
info!(
"[allms][OpenAI][VectorStore][debug] VectorStore Status API response: [{}] {:#?}",
&response_status, &response_text
);
}
let response_deser: OpenAIVectorStoreResp =
serde_json::from_str(&response_text).map_err(|error| {
let error = AllmsError {
crate_name: "allms".to_string(),
module: "assistants::openai_vector_store".to_string(),
error_message: format!(
"VectorStore Status API response serialization error: {}",
error
),
error_detail: response_text,
};
error!("{:?}", error);
anyhow!("{:?}", error)
})?;
Ok(response_deser.status)
}
pub async fn file_count(&self) -> Result<OpenAIVectorStoreFileCounts> {
let vs_id = if let Some(id) = &self.id {
id
} else {
return Err(anyhow!(
"[allms][OpenAI][VectorStore][debug] Unable to check status. No ID provided."
));
};
let vector_store_resource = OpenAIAssistantResource::VectorStore {
vector_store_id: vs_id.to_string(),
};
let url = self.version.get_endpoint(&vector_store_resource);
let version_headers = self.version.get_headers(&self.api_key);
let client = Client::new();
let response = client.get(&url).headers(version_headers).send().await?;
let response_status = response.status();
let response_text = response.text().await?;
if self.debug {
info!(
"[allms][OpenAI][VectorStore][debug] VectorStore Status API response: [{}] {:#?}",
&response_status, &response_text
);
}
let response_deser: OpenAIVectorStoreResp =
serde_json::from_str(&response_text).map_err(|error| {
let error = AllmsError {
crate_name: "allms".to_string(),
module: "assistants::openai_vector_store".to_string(),
error_message: format!(
"VectorStore Status API response serialization error: {}",
error
),
error_detail: response_text,
};
error!("{:?}", error);
anyhow!("{:?}", error)
})?;
Ok(response_deser.file_counts)
}
pub async fn delete(&self) -> Result<()> {
let vs_id = if let Some(id) = &self.id {
id
} else {
return Err(anyhow!(
"[allms][OpenAI][VectorStore][debug] Unable to delete. No ID provided."
));
};
let vector_store_resource = OpenAIAssistantResource::VectorStore {
vector_store_id: vs_id.to_string(),
};
let url = self.version.get_endpoint(&vector_store_resource);
let version_headers = self.version.get_headers(&self.api_key);
let client = Client::new();
let response = client.delete(&url).headers(version_headers).send().await?;
let response_status = response.status();
let response_text = response.text().await?;
if self.debug {
info!(
"[allms][OpenAI][VectorStore][debug] VectorStore Delete API response: [{}] {:#?}",
&response_status, &response_text
);
}
serde_json::from_str::<OpenAIVectorStoreDeleteResp>(&response_text)
.map_err(|error| {
let error = AllmsError {
crate_name: "allms".to_string(),
module: "assistants::openai_vector_store".to_string(),
error_message: format!(
"VectorStore Delete API response serialization error: {}",
error
),
error_detail: response_text,
};
error!("{:?}", error);
anyhow!("{:?}", error)
})
.and_then(|response| match response.deleted {
true => Ok(()),
false => Err(anyhow!(
"[OpenAIAssistant] VectorStore Delete API failed to delete the store."
)),
})
}
}
/// Subset of the vector store object returned by the create and retrieve endpoints.
/// Fields in the JSON that are not listed here are ignored when deserializing.
#[derive(Deserialize, Serialize, Debug, Clone)]
struct OpenAIVectorStoreResp {
id: String,
name: String,
status: OpenAIVectorStoreStatus,
// presumably Unix timestamps in seconds — TODO confirm against the API reference
created_at: i64,
expires_at: Option<i64>,
last_active_at: Option<i64>,
file_counts: OpenAIVectorStoreFileCounts,
}
/// Per-state file counts reported for a vector store or a file batch.
///
/// `upload` waits until `in_progress` reaches zero.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct OpenAIVectorStoreFileCounts {
pub in_progress: i32,
pub completed: i32,
pub failed: i32,
pub cancelled: i32,
pub total: i32,
}
/// Lifecycle status of a vector store as reported by the API.
///
/// Variants map to the wire strings `"expired"`, `"in_progress"` and `"completed"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum OpenAIVectorStoreStatus {
    /// Wire value `"expired"`.
    Expired,
    /// Wire value `"in_progress"`.
    InProgress,
    /// Wire value `"completed"`.
    Completed,
}
/// Response of the file-batch creation endpoint.
///
/// It is deserialized only to validate the reply; none of its fields are read.
#[derive(Deserialize, Serialize, Debug, Clone)]
struct OpenAIVectorStoreFileBatchResp {
id: String,
vector_store_id: String,
status: OpenAIVectorStoreFileBatchStatus,
// presumably a Unix timestamp in seconds — TODO confirm
created_at: i64,
file_counts: OpenAIVectorStoreFileCounts,
}
/// Processing status of a vector store file batch.
///
/// Variants map to the wire strings `"in_progress"`, `"completed"`, `"cancelled"` and
/// `"failed"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum OpenAIVectorStoreFileBatchStatus {
    /// Wire value `"in_progress"`.
    InProgress,
    /// Wire value `"completed"`.
    Completed,
    /// Wire value `"cancelled"`.
    Cancelled,
    /// Wire value `"failed"`.
    Failed,
}
/// Response of the vector store delete endpoint.
///
/// `delete` treats `deleted: false` as an error.
#[derive(Deserialize, Serialize, Debug, Clone)]
struct OpenAIVectorStoreDeleteResp {
id: String,
deleted: bool,
}