use async_trait::async_trait;
use opensearch::http::request::JsonBody;
use opensearch::http::response::Response;
use opensearch::indices::{IndicesCreateParts, IndicesDeleteParts};
use opensearch::{BulkParts, SearchParts};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::error::Error;
use std::sync::Arc;
pub use opensearch::auth::Credentials;
pub use opensearch::cert::CertificateValidation;
pub use opensearch::http::transport::{SingleNodeConnectionPool, TransportBuilder};
pub use opensearch::OpenSearch;
use crate::{
embedding::embedder_trait::Embedder,
schemas::Document,
vectorstore::{VecStoreOptions, VectorStore, VectorStoreError},
};
/// Vector store backed by an OpenSearch index.
///
/// Documents are written with their text under `content_field`, their
/// embedding under `vector_field`, and their metadata under `"metadata"`.
pub struct Store {
/// Client used for every request this store sends.
pub client: OpenSearch,
/// Default embedder used to turn documents and queries into vectors.
pub embedder: Arc<dyn Embedder>,
/// Value sent as `k` in the k-NN clause of similarity-search queries.
pub k: i32,
/// Name of the index that documents are written to, searched in, and deleted from.
pub index: String,
/// Name of the document field holding the embedding (mapped as `knn_vector` by `create_index`).
pub vector_field: String,
/// Name of the document field holding the page content (mapped as `text` by `create_index`).
pub content_field: String,
}
impl Store {
pub async fn delete_index(&self) -> Result<Response, Box<dyn Error>> {
let response = self
.client
.indices()
.delete(IndicesDeleteParts::Index(&[&self.index]))
.send()
.await?;
let result = response.error_for_status_code().map_err(|e| Box::new(e))?;
Ok(result)
}
pub async fn create_index(&self) -> Result<Response, Box<dyn Error>> {
let body = json!({
"settings": {
"index.knn": true,
"knn.algo_param": {
"ef_search": "512"
},
},
"mappings": {
"properties": {
&self.vector_field: {
"type": "knn_vector",
"dimension": 1536,
"method": {
"engine": "faiss",
"name": "hnsw",
"space_type": "l2",
"parameters": {
"ef_construction": 512,
"m": 16
}
}
},
&self.content_field: {
"type": "text"
},
"metadata": {
"properties": {
"source": {
"type": "text",
}
}
}
}
}
});
let response = self
.client
.indices()
.create(IndicesCreateParts::Index(&self.index))
.body(body)
.send()
.await?;
let result = response.error_for_status_code().map_err(|e| Box::new(e))?;
Ok(result)
}
}
#[async_trait]
impl VectorStore for Store {
    type Options = VecStoreOptions<Value>;

    /// Embeds `docs` and indexes them with a single bulk request.
    ///
    /// Uses `opt.embedder` when it is set and `self.embedder` otherwise.
    /// Returns the `_id` reported for each bulk item, in response order.
    /// An empty `docs` slice returns an empty list without any request.
    ///
    /// # Errors
    /// Fails when embedding fails, when the embedder returns a different
    /// number of vectors than documents, when the request fails or is answered
    /// with an error status, or when the response body does not have the
    /// expected shape.
    async fn add_documents(
        &self,
        docs: &[Document],
        opt: &Self::Options,
    ) -> Result<Vec<String>, VectorStoreError> {
        // Nothing to index; mirrors the early return in `delete`.
        if docs.is_empty() {
            return Ok(Vec::new());
        }

        let texts: Vec<String> = docs.iter().map(|d| d.page_content.clone()).collect();
        let embedder = opt.embedder.as_ref().unwrap_or(&self.embedder);
        let vectors = embedder.embed_documents(&texts).await?;
        if vectors.len() != docs.len() {
            return Err(VectorStoreError::InternalError(
                "Number of vectors and documents do not match".to_string(),
            ));
        }

        // The bulk body alternates an action line with the document it applies to.
        let mut body: Vec<JsonBody<_>> = Vec::with_capacity(docs.len() * 2);
        for (doc, vector) in docs.iter().zip(vectors.iter()) {
            let operation = json!({"index": {}});
            body.push(operation.into());
            let document = json!({
                &self.content_field: doc.page_content,
                "metadata": doc.metadata,
                &self.vector_field: vector,
            });
            body.push(document.into());
        }

        let response = self
            .client
            .bulk(BulkParts::Index(&self.index))
            .body(body)
            .send()
            .await
            .map_err(|e| VectorStoreError::Unknown(e.to_string()))?
            .error_for_status_code()
            .map_err(|e| VectorStoreError::Unknown(e.to_string()))?;
        let response_body = response
            .json::<Value>()
            .await
            .map_err(|e| VectorStoreError::Unknown(e.to_string()))?;

        // NOTE(review): only the HTTP status is checked; per-item failures
        // reported inside the bulk response body are not inspected here.
        let items = response_body["items"].as_array().ok_or_else(|| {
            VectorStoreError::InternalError(
                "Bulk response does not contain an `items` array".to_string(),
            )
        })?;
        items
            .iter()
            .map(|item| {
                item["index"]["_id"]
                    .as_str()
                    .map(str::to_owned)
                    .ok_or_else(|| {
                        VectorStoreError::InternalError(
                            "Bulk response item does not contain a string `_id`".to_string(),
                        )
                    })
            })
            .collect::<Result<Vec<_>, _>>()
    }

    /// Embeds `query` and runs a k-NN search against `self.index`.
    ///
    /// Uses `opt.embedder` when it is set and `self.embedder` otherwise, and
    /// forwards `opt.filters` as the k-NN filter. At most `limit` hits are
    /// requested through the `size` entry of the request body; no `size` URL
    /// parameter is set, because a URL parameter would take precedence over the
    /// body and silently cap the result count.
    ///
    /// # Errors
    /// Fails when embedding fails, when the request fails or is answered with
    /// an error status, or when a hit does not have the expected shape.
    async fn similarity_search(
        &self,
        query: &str,
        limit: usize,
        opt: &Self::Options,
    ) -> Result<Vec<Document>, VectorStoreError> {
        let embedder = opt.embedder.as_ref().unwrap_or(&self.embedder);
        let query_vector = embedder.embed_query(query).await?;
        let query = build_similarity_search_query(
            query_vector,
            &self.vector_field,
            limit,
            self.k,
            opt.filters.clone(),
        );
        let response = self
            .client
            .search(SearchParts::Index(&[&self.index]))
            .from(0)
            .body(query)
            .send()
            .await
            .map_err(|e| VectorStoreError::Unknown(e.to_string()))?
            // Without this check an error response would only surface later as
            // a confusing "missing hits" failure.
            .error_for_status_code()
            .map_err(|e| VectorStoreError::Unknown(e.to_string()))?;
        let response_body = response
            .json::<Value>()
            .await
            .map_err(|e| VectorStoreError::Unknown(e.to_string()))?;

        let hits = response_body["hits"]["hits"].as_array().ok_or_else(|| {
            VectorStoreError::InternalError(
                "Search response does not contain a `hits.hits` array".to_string(),
            )
        })?;
        hits.iter()
            .map(|hit| document_from_hit(hit, &self.content_field))
            .collect::<Result<Vec<_>, _>>()
    }

    /// Deletes the documents with the given `ids` using a single bulk request.
    ///
    /// An empty `ids` slice is a no-op.
    ///
    /// # Errors
    /// Fails when the request fails or is answered with an error status.
    async fn delete(
        &self,
        ids: &[String],
        _opt: &VecStoreOptions<Value>,
    ) -> Result<(), VectorStoreError> {
        if ids.is_empty() {
            return Ok(());
        }
        let body: Vec<JsonBody<_>> = ids
            .iter()
            .map(|id| json!({"delete": {"_id": id}}).into())
            .collect();
        self.client
            .bulk(BulkParts::Index(&self.index))
            .body(body)
            .send()
            .await
            .map_err(|e| VectorStoreError::Unknown(e.to_string()))?
            .error_for_status_code()
            .map_err(|e| VectorStoreError::Unknown(e.to_string()))?;
        Ok(())
    }
}

/// Converts one entry of `hits.hits` into a [`Document`].
///
/// Reads the page content from `_source[content_field]`, the metadata from
/// `_source["metadata"]` (absent or `null` metadata becomes an empty map), and
/// the score from `_score`.
fn document_from_hit(hit: &Value, content_field: &str) -> Result<Document, VectorStoreError> {
    let source = &hit["_source"];

    let page_content = source[content_field]
        .as_str()
        .ok_or_else(|| {
            VectorStoreError::InternalError(format!(
                "Search hit does not contain a string `{content_field}` field"
            ))
        })?
        .to_owned();

    let metadata = match &source["metadata"] {
        Value::Null => HashMap::new(),
        raw => serde_json::from_value::<HashMap<String, Value>>(raw.clone()).map_err(|e| {
            VectorStoreError::InternalError(format!("Search hit has invalid metadata: {e}"))
        })?,
    };

    let score = hit["_score"].as_f64().ok_or_else(|| {
        VectorStoreError::InternalError("Search hit does not contain a numeric `_score`".to_string())
    })?;

    Ok(Document {
        page_content,
        metadata,
        score,
    })
}
/// Builds the request body for a k-NN similarity search.
///
/// The body asks for `size` hits and contains a single `knn` query on
/// `vector_field` with the query vector and `k`. When `maybe_filter` is
/// `Some`, it is attached as the `filter` entry of that k-NN clause; otherwise
/// the clause carries only `vector` and `k`.
fn build_similarity_search_query(
    embedded_query: Vec<f64>,
    vector_field: &str,
    size: usize,
    k: i32,
    maybe_filter: Option<Value>,
) -> Value {
    // Start from the part shared by both the filtered and unfiltered query.
    let mut knn_clause = json!({
        "vector": embedded_query,
        "k": k,
    });
    if let Some(filter) = maybe_filter {
        knn_clause["filter"] = filter;
    }

    json!({
        "size": size,
        "query": {
            "knn": {
                vector_field: knn_clause
            }
        }
    })
}