use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::database::Database;
use crate::document::{Document, SearchParams};
use crate::error::{Result, VectorDBError};
use crate::filter::Filter;
/// Client-side handle for one collection inside a [`Database`].
///
/// Holds only the owning database handle and the collection name; every
/// document operation is sent through the database's client.
#[derive(Debug, Clone)]
pub struct Collection {
/// Owning database; its `name()` and `client()` are used to address and send requests.
database: Database,
/// Collection name, sent as `collection` with each request.
name: String,
}
/// JSON body for `POST /document/upsert`.
///
/// Keys are camelCased (`build_index` -> `buildIndex`); `None` options are omitted.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct UpsertRequest {
    database: String,
    collection: String,
    documents: Vec<Document>,
    #[serde(skip_serializing_if = "Option::is_none")]
    build_index: Option<bool>,
}
/// JSON body for `POST /document/query`.
///
/// Snake-case fields serialize as camelCase (`documentIds`, `retrieveVector`,
/// `outputFields`); every optional field is dropped from the JSON when `None`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct QueryRequest {
    database: String,
    collection: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    document_ids: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    retrieve_vector: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    limit: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    offset: Option<u32>,
    /// Filter expression as produced by `Filter::condition()`.
    #[serde(skip_serializing_if = "Option::is_none")]
    filter: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    output_fields: Option<Vec<String>>,
    /// Passed through verbatim.
    #[serde(skip_serializing_if = "Option::is_none")]
    sort: Option<Value>,
}
/// JSON body for vector search via `POST /document/search`.
///
/// Snake-case fields serialize as camelCase (`retrieveVector`, `outputFields`);
/// optional fields are omitted when `None`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct SearchRequest {
    database: String,
    collection: String,
    vectors: Vec<Vec<f64>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    filter: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    params: Option<SearchParams>,
    #[serde(skip_serializing_if = "Option::is_none")]
    retrieve_vector: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    limit: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    output_fields: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    radius: Option<f64>,
}
/// JSON body for `POST /document/searchById`.
///
/// `documentIds` is always present; the remaining optional fields are omitted
/// when `None`. Snake-case names serialize as camelCase.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct SearchByIdRequest {
    database: String,
    collection: String,
    document_ids: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    filter: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    params: Option<SearchParams>,
    #[serde(skip_serializing_if = "Option::is_none")]
    retrieve_vector: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    limit: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    output_fields: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    radius: Option<f64>,
}
/// JSON body for `POST /document/update`.
///
/// `data` carries the new field values; documents are selected by
/// `documentIds` and/or `filter`, each omitted when `None`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct UpdateRequest {
    database: String,
    collection: String,
    data: Document,
    #[serde(skip_serializing_if = "Option::is_none")]
    document_ids: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    filter: Option<String>,
}
/// JSON body for `POST /document/delete`.
///
/// Selection (`documentIds`, `filter`) and `limit` are all optional and
/// omitted when `None`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct DeleteRequest {
    database: String,
    collection: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    document_ids: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    filter: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    limit: Option<u32>,
}
/// JSON body for `POST /document/count`; `filter` is omitted when `None`.
///
/// `rename_all` changes nothing here (no multi-word fields) but keeps the
/// declaration uniform with the other request bodies.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct CountRequest {
    database: String,
    collection: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    filter: Option<String>,
}
/// Response envelope deserialized by the typed `Collection` operations.
///
/// `code` and `msg` are required keys; `data` is optional.
#[derive(Debug, Deserialize)]
struct ApiResponse<T> {
/// Status code; the `Collection` methods treat any non-zero value as a server error.
code: i32,
/// Server message, forwarded into the error when `code` is non-zero.
msg: String,
/// Payload; `None` when the key is missing or (as for any `Option`) JSON `null`.
#[serde(default)]
data: Option<T>,
}
impl Collection {
pub fn new(database: Database, name: String) -> Self {
Self { database, name }
}
pub fn name(&self) -> &str {
&self.name
}
pub fn database(&self) -> &Database {
&self.database
}
pub async fn upsert(
&self,
documents: Vec<Document>,
_timeout: Option<u64>,
build_index: bool,
) -> Result<Value> {
let request = UpsertRequest {
database: self.database.name().to_string(),
collection: self.name.clone(),
documents,
build_index: Some(build_index),
};
let response = self
.database
.client()
.post("/document/upsert")
.await?
.json(&request)
.send()
.await?;
let api_response: ApiResponse<Value> =
self.database.client().handle_response(response).await?;
if api_response.code != 0 {
return Err(VectorDBError::server_error(
api_response.code,
api_response.msg,
));
}
Ok(api_response.data.unwrap_or(Value::Null))
}
pub async fn query(
&self,
document_ids: Option<Vec<String>>,
retrieve_vector: bool,
limit: Option<u32>,
offset: Option<u32>,
filter: Option<Filter>,
output_fields: Option<Vec<String>>,
sort: Option<Value>,
) -> Result<Vec<Document>> {
let request = QueryRequest {
database: self.database.name().to_string(),
collection: self.name.clone(),
document_ids,
retrieve_vector: Some(retrieve_vector),
limit,
offset,
filter: filter.map(|f| f.condition().to_string()),
output_fields,
sort,
};
let response = self
.database
.client()
.post("/document/query")
.await?
.json(&request)
.send()
.await?;
let api_response: ApiResponse<Vec<Value>> =
self.database.client().handle_response(response).await?;
if api_response.code != 0 {
return Err(VectorDBError::server_error(
api_response.code,
api_response.msg,
));
}
let documents_data = api_response.data.unwrap_or_default();
let documents: Result<Vec<Document>> = documents_data
.into_iter()
.map(|v| serde_json::from_value(v).map_err(Into::into))
.collect();
documents
}
pub async fn search(
&self,
vectors: Vec<Vec<f64>>,
filter: Option<Filter>,
params: Option<SearchParams>,
retrieve_vector: bool,
limit: u32,
output_fields: Option<Vec<String>>,
_timeout: Option<u64>,
radius: Option<f64>,
) -> Result<Vec<Vec<Document>>> {
let request = SearchRequest {
database: self.database.name().to_string(),
collection: self.name.clone(),
vectors,
filter: filter.map(|f| f.condition().to_string()),
params,
retrieve_vector: Some(retrieve_vector),
limit: Some(limit),
output_fields,
radius,
};
let response = self
.database
.client()
.post("/document/search")
.await?
.json(&request)
.send()
.await?;
let api_response: ApiResponse<Vec<Vec<Value>>> =
self.database.client().handle_response(response).await?;
if api_response.code != 0 {
return Err(VectorDBError::server_error(
api_response.code,
api_response.msg,
));
}
let results_data = api_response.data.unwrap_or_default();
let results: Result<Vec<Vec<Document>>> = results_data
.into_iter()
.map(|batch| {
batch
.into_iter()
.map(|v| serde_json::from_value(v).map_err(Into::into))
.collect()
})
.collect();
results
}
pub async fn search_by_id(
&self,
document_ids: Vec<String>,
filter: Option<Filter>,
params: Option<SearchParams>,
retrieve_vector: bool,
limit: u32,
output_fields: Option<Vec<String>>,
_timeout: Option<u64>,
radius: Option<f64>,
) -> Result<Vec<Vec<Document>>> {
let request = SearchByIdRequest {
database: self.database.name().to_string(),
collection: self.name.clone(),
document_ids,
filter: filter.map(|f| f.condition().to_string()),
params,
retrieve_vector: Some(retrieve_vector),
limit: Some(limit),
output_fields,
radius,
};
let response = self
.database
.client()
.post("/document/searchById")
.await?
.json(&request)
.send()
.await?;
let api_response: ApiResponse<Vec<Vec<Value>>> =
self.database.client().handle_response(response).await?;
if api_response.code != 0 {
return Err(VectorDBError::server_error(
api_response.code,
api_response.msg,
));
}
let results_data = api_response.data.unwrap_or_default();
let results: Result<Vec<Vec<Document>>> = results_data
.into_iter()
.map(|batch| {
batch
.into_iter()
.map(|v| serde_json::from_value(v).map_err(Into::into))
.collect()
})
.collect();
results
}
pub async fn update(
&self,
data: Document,
document_ids: Option<Vec<String>>,
filter: Option<Filter>,
) -> Result<Value> {
let request = UpdateRequest {
database: self.database.name().to_string(),
collection: self.name.clone(),
data,
document_ids,
filter: filter.map(|f| f.condition().to_string()),
};
let response = self
.database
.client()
.post("/document/update")
.await?
.json(&request)
.send()
.await?;
let api_response: ApiResponse<Value> =
self.database.client().handle_response(response).await?;
if api_response.code != 0 {
return Err(VectorDBError::server_error(
api_response.code,
api_response.msg,
));
}
Ok(api_response.data.unwrap_or(Value::Null))
}
pub async fn delete(
&self,
document_ids: Option<Vec<String>>,
filter: Option<Filter>,
limit: Option<u32>,
) -> Result<Value> {
let request = DeleteRequest {
database: self.database.name().to_string(),
collection: self.name.clone(),
document_ids,
filter: filter.map(|f| f.condition().to_string()),
limit,
};
let response = self
.database
.client()
.post("/document/delete")
.await?
.json(&request)
.send()
.await?;
let api_response: ApiResponse<Value> =
self.database.client().handle_response(response).await?;
if api_response.code != 0 {
return Err(VectorDBError::server_error(
api_response.code,
api_response.msg,
));
}
Ok(api_response.data.unwrap_or(Value::Null))
}
pub async fn count(&self, filter: Option<Filter>) -> Result<u64> {
let request = CountRequest {
database: self.database.name().to_string(),
collection: self.name.clone(),
filter: filter.map(|f| f.condition().to_string()),
};
let response = self
.database
.client()
.post("/document/count")
.await?
.json(&request)
.send()
.await?;
let api_response: ApiResponse<Value> =
self.database.client().handle_response(response).await?;
if api_response.code != 0 {
return Err(VectorDBError::server_error(
api_response.code,
api_response.msg,
));
}
let count_data = api_response.data.unwrap_or(Value::Number(0.into()));
let count: u64 = serde_json::from_value(count_data)?;
Ok(count)
}
pub async fn rebuild_index(&self) -> Result<Value> {
let path = format!(
"/index/rebuild?database={}&collection={}",
self.database.name(),
self.name
);
let response = self.database.client().post(&path).await?.send().await?;
let api_response: ApiResponse<Value> =
self.database.client().handle_response(response).await?;
if api_response.code != 0 {
return Err(VectorDBError::server_error(
api_response.code,
api_response.msg,
));
}
Ok(api_response.data.unwrap_or(Value::Null))
}
pub async fn hybrid_search(
&self,
ann_search: Vec<crate::document::AnnSearch>,
keyword_search: Option<crate::document::KeywordSearch>,
rerank: Option<crate::document::Rerank>,
limit: u32,
output_fields: Option<Vec<String>>,
_timeout: Option<u64>,
) -> Result<Vec<Document>> {
let mut request = serde_json::json!({
"database": self.database.name(),
"collection": self.name,
"search": {
"ann": ann_search,
"limit": limit
}
});
if let Some(keyword) = keyword_search {
request["search"]["match"] = serde_json::to_value(vec![keyword])?;
}
if let Some(rerank_config) = rerank {
request["search"]["rerank"] = serde_json::to_value(rerank_config)?;
}
if let Some(fields) = output_fields {
request["search"]["outputFields"] = serde_json::to_value(fields)?;
}
let response = self
.database
.client()
.post("/document/hybridSearch")
.await?
.json(&request)
.send()
.await?;
let response_text = response.text().await?;
let response_json: serde_json::Value = serde_json::from_str(&response_text)?;
if let Some(code) = response_json.get("code").and_then(|v| v.as_i64()) {
if code != 0 {
let msg = response_json
.get("msg")
.and_then(|v| v.as_str())
.unwrap_or("Unknown error");
return Err(VectorDBError::server_error(code as i32, msg));
}
}
let default_array = Value::Array(vec![]);
let documents_data = response_json.get("data").unwrap_or(&default_array);
let documents: Result<Vec<Document>> =
serde_json::from_value(documents_data.clone()).map_err(Into::into);
documents
}
pub async fn search_by_text(
&self,
embedding_items: Vec<String>,
filter: Option<Filter>,
params: Option<SearchParams>,
retrieve_vector: bool,
limit: u32,
output_fields: Option<Vec<String>>,
_timeout: Option<u64>,
radius: Option<f64>,
) -> Result<Vec<Vec<Document>>> {
let mut request = serde_json::json!({
"database": self.database.name(),
"collection": self.name,
"search": {
"embeddingItems": embedding_items,
"retrieveVector": retrieve_vector,
"limit": limit
}
});
if let Some(f) = filter {
request["search"]["filter"] = serde_json::Value::String(f.condition().to_string());
}
if let Some(p) = params {
request["search"]["params"] = serde_json::to_value(p)?;
}
if let Some(fields) = output_fields {
request["search"]["outputFields"] = serde_json::to_value(fields)?;
}
if let Some(r) = radius {
request["search"]["radius"] =
serde_json::Value::Number(serde_json::Number::from_f64(r).unwrap());
}
let response = self
.database
.client()
.post("/document/search")
.await?
.json(&request)
.send()
.await?;
let response_text = response.text().await?;
let response_json: serde_json::Value = serde_json::from_str(&response_text)?;
if let Some(code) = response_json.get("code").and_then(|v| v.as_i64()) {
if code != 0 {
let msg = response_json
.get("msg")
.and_then(|v| v.as_str())
.unwrap_or("Unknown error");
return Err(VectorDBError::server_error(code as i32, msg));
}
}
let default_array = Value::Array(vec![]);
let results_data = response_json.get("data").unwrap_or(&default_array);
let results: Result<Vec<Vec<Document>>> =
serde_json::from_value(results_data.clone()).map_err(Into::into);
results
}
}