use std::collections::HashMap;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tracing::{debug, error, info, warn};
use super::base::{Filters, OutputData, VectorStoreBase};
use crate::config::ChromaConfig;
use crate::error::{NeomemxError, Result};
/// Response fields requested in the `include` list of `/query` requests (used by `search`).
const SEARCH_INCLUDE: &[&str] = &["metadatas", "distances"];
/// Response fields requested in the `include` list of `/get` requests (used by `get_batch`).
const GET_INCLUDE: &[&str] = &["metadatas"];
/// Vector store client that talks to a ChromaDB server over HTTP and operates on the
/// single collection named in its configuration.
pub struct ChromaDB {
/// Store configuration; supplies the base URL and the collection name.
config: ChromaConfig,
/// HTTP client used for every request.
client: Client,
/// Server root URL, without the `/api/v2/...` path.
base_url: String,
/// Tenant path segment used when building collection URLs.
tenant: String,
/// Database path segment used when building collection URLs.
database: String,
/// Server-assigned collection ID; `None` until `ensure_collection` has succeeded.
collection_id: Option<String>,
}
/// Subset of a collection description returned by the server; only the fields this
/// module reads are deserialized.
#[derive(Debug, Deserialize)]
struct CollectionResponse {
/// Collection identifier, used as the path segment for record-level endpoints.
id: String,
/// Collection name.
name: String,
}
/// JSON body posted to the collection's `/add` endpoint by `insert`.
#[derive(Debug, Serialize)]
struct AddRequest {
/// Record IDs.
ids: Vec<String>,
/// One embedding vector per record.
embeddings: Vec<Vec<f32>>,
/// Optional metadata maps; the key is omitted from the JSON entirely when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
metadatas: Option<Vec<HashMap<String, serde_json::Value>>>,
}
/// JSON body posted to the collection's `/query` endpoint by `search`.
#[derive(Debug, Serialize)]
struct QueryRequest {
/// Query vectors; `search` always sends exactly one.
query_embeddings: Vec<Vec<f32>>,
/// Maximum number of results requested per query vector.
n_results: usize,
/// Optional metadata filter; omitted from the JSON when `None`.
/// `where` is a Rust keyword, hence the raw identifier.
#[serde(skip_serializing_if = "Option::is_none")]
r#where: Option<serde_json::Value>,
/// Names of the response fields the server should return.
include: Vec<String>,
}
/// Response of the `/query` endpoint. Every field is a list with one inner list per
/// query vector; `parse_query_response` reads only the first inner list.
// NOTE(review): inner metadata entries are deserialized as maps, so a `null` entry
// would make parsing fail — confirm the server never returns `null` for a record
// that has no metadata.
#[derive(Debug, Deserialize)]
struct QueryResponse {
/// Matched record IDs.
ids: Vec<Vec<String>>,
/// Distances aligned by position with `ids`; may be absent.
distances: Option<Vec<Vec<f32>>>,
/// Metadata maps aligned by position with `ids`; may be absent.
metadatas: Option<Vec<Vec<HashMap<String, serde_json::Value>>>>,
}
/// JSON body posted to the collection's `/get` endpoint by `get_batch`.
#[derive(Debug, Serialize)]
struct GetRequest {
/// IDs of the records to fetch.
ids: Vec<String>,
/// Names of the response fields the server should return.
include: Vec<String>,
}
/// Response of the `/get` endpoint, as consumed by `get_batch` and `list`.
// NOTE(review): metadata entries are deserialized as maps, so a `null` entry would
// make parsing fail — confirm the server never returns `null` for a record that has
// no metadata.
#[derive(Debug, Deserialize)]
struct GetResponse {
/// IDs of the returned records.
ids: Vec<String>,
/// Metadata maps aligned by position with `ids`; may be absent.
metadatas: Option<Vec<HashMap<String, serde_json::Value>>>,
}
/// JSON body posted to the collection's `/update` endpoint by `update`.
#[derive(Debug, Serialize)]
struct UpdateRequest {
/// IDs of the records to update.
ids: Vec<String>,
/// Replacement embeddings; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
embeddings: Option<Vec<Vec<f32>>>,
/// Replacement metadata maps; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
metadatas: Option<Vec<HashMap<String, serde_json::Value>>>,
}
/// JSON body posted to the collection's `/delete` endpoint by `delete_batch`.
#[derive(Debug, Serialize)]
struct DeleteRequest {
/// IDs of the records to delete.
ids: Vec<String>,
}
impl ChromaDB {
/// Builds a store for `config` and makes sure its collection exists on the server.
///
/// The base URL comes from the configuration and falls back to
/// `http://localhost:8000`. The HTTP client is built with proxies disabled. Tenant and
/// database are fixed to `default_tenant` / `default_database`.
///
/// # Errors
///
/// Returns `NeomemxError::VectorStoreError` when the HTTP client cannot be built, and
/// propagates any error from `ensure_collection`.
pub async fn new(config: ChromaConfig) -> Result<Self> {
    let base_url = config
        .get_base_url()
        .unwrap_or("http://localhost:8000".to_string());

    let client = match Client::builder().no_proxy().build() {
        Ok(built) => built,
        Err(e) => {
            return Err(NeomemxError::VectorStoreError(format!(
                "Failed to create HTTP client: {}",
                e
            )));
        }
    };

    let mut store = Self {
        config,
        client,
        base_url,
        tenant: String::from("default_tenant"),
        database: String::from("default_database"),
        collection_id: None,
    };
    // The collection ID is unknown until the server has been asked for it.
    store.ensure_collection().await?;
    Ok(store)
}
/// Returns the URL of the collections resource for the configured tenant and database.
fn collections_base_url(&self) -> String {
    [
        self.base_url.as_str(),
        "/api/v2/tenants/",
        self.tenant.as_str(),
        "/databases/",
        self.database.as_str(),
        "/collections",
    ]
    .concat()
}
/// Resolves the ID of the configured collection, creating the collection if needed,
/// and stores it in `self.collection_id`.
///
/// The collection is first looked up by name. The lookup is best effort: any failure
/// (transport error, non-success status, unparsable body) is logged at debug level and
/// the method falls through to creation. The create request carries
/// `get_or_create: true`, so it also succeeds when the collection already exists —
/// for example when another process created it between the lookup and the create.
///
/// # Errors
///
/// Returns `NeomemxError::VectorStoreError` when the create request cannot be sent,
/// is answered with a non-success status, or its response cannot be parsed.
async fn ensure_collection(&mut self) -> Result<()> {
    let get_url = format!(
        "{}/{}",
        self.collections_base_url(),
        self.config.collection_name
    );
    debug!(
        "Checking for existing collection: {}",
        self.config.collection_name
    );

    match self.client.get(&get_url).send().await {
        Ok(response) if response.status().is_success() => {
            match response.json::<CollectionResponse>().await {
                Ok(collection) => {
                    info!(
                        "Found existing collection: {} ({})",
                        collection.name, collection.id
                    );
                    self.collection_id = Some(collection.id);
                    return Ok(());
                }
                Err(e) => debug!("Failed to parse collection lookup response: {}", e),
            }
        }
        Ok(response) => debug!("Collection lookup returned status {}", response.status()),
        Err(e) => debug!("Collection lookup request failed: {}", e),
    }

    let create_url = self.collections_base_url();
    debug!("Creating collection: {}", self.config.collection_name);
    // `get_or_create` makes creation idempotent, so a collection that exists despite
    // the failed lookup above is returned instead of causing a conflict error.
    let request = serde_json::json!({
        "name": self.config.collection_name,
        "get_or_create": true
    });
    let response = self
        .client
        .post(&create_url)
        .json(&request)
        .send()
        .await
        .map_err(|e| {
            NeomemxError::VectorStoreError(format!("Failed to connect to ChromaDB: {}", e))
        })?;
    if !response.status().is_success() {
        let body = response.text().await.unwrap_or_default();
        return Err(NeomemxError::VectorStoreError(format!(
            "Failed to create collection: {}",
            body
        )));
    }
    let collection: CollectionResponse = response.json().await.map_err(|e| {
        NeomemxError::VectorStoreError(format!("Failed to parse collection response: {}", e))
    })?;
    info!(
        "Created collection: {} ({})",
        collection.name, collection.id
    );
    self.collection_id = Some(collection.id);
    Ok(())
}
/// Returns the URL of the current collection, addressed by its stored ID.
///
/// # Errors
///
/// Returns `NeomemxError::VectorStoreError` when no collection ID has been stored yet.
fn collection_url(&self) -> Result<String> {
    match self.collection_id.as_ref() {
        Some(id) => Ok(format!("{}/{}", self.collections_base_url(), id)),
        None => Err(NeomemxError::VectorStoreError(
            "Collection not initialized".to_string(),
        )),
    }
}
/// Translates generic `filters` into a Chroma `where` document.
///
/// Per filter entry:
/// - keys `$or`, `$and` and `$not` are passed through unchanged;
/// - an object value is read as `{operator: operand}` pairs, each operator mapped to
///   its `$`-prefixed form (`eq`, `ne`, `gt`, `gte`, `lt`, `lte`, `in`, `nin`); any
///   other operator name is treated as `$eq`;
/// - the string `"*"` is a wildcard and produces no condition;
/// - any other value becomes an `$eq` condition.
///
/// Returns `Null` when no condition results, the condition itself when there is
/// exactly one, and an `$and` of all conditions otherwise.
fn generate_where_clause(filters: &Filters) -> serde_json::Value {
    const LOGICAL_KEYS: [&str; 3] = ["$or", "$and", "$not"];
    const OPERATORS: [(&str, &str); 8] = [
        ("eq", "$eq"),
        ("ne", "$ne"),
        ("gt", "$gt"),
        ("gte", "$gte"),
        ("lt", "$lt"),
        ("lte", "$lte"),
        ("in", "$in"),
        ("nin", "$nin"),
    ];

    let mut clauses: Vec<serde_json::Value> = Vec::new();
    for (field, condition) in filters {
        if LOGICAL_KEYS.iter().any(|&logical| field == logical) {
            clauses.push(serde_json::json!({ field: condition }));
        } else if let serde_json::Value::Object(operators) = condition {
            for (name, operand) in operators {
                // Unknown operator names fall back to equality.
                let chroma_op = OPERATORS
                    .iter()
                    .find(|(plain, _)| *plain == name.as_str())
                    .map_or("$eq", |(_, prefixed)| *prefixed);
                clauses.push(serde_json::json!({ field: { chroma_op: operand } }));
            }
        } else if condition.as_str() != Some("*") {
            clauses.push(serde_json::json!({ field: { "$eq": condition } }));
        }
    }

    match clauses.len() {
        0 => serde_json::Value::Null,
        1 => clauses.swap_remove(0),
        _ => serde_json::json!({ "$and": clauses }),
    }
}
fn parse_query_response(response: QueryResponse) -> Vec<OutputData> {
let ids = response.ids.into_iter().next().unwrap_or_default();
let distances = response
.distances
.and_then(|d| d.into_iter().next())
.unwrap_or_default();
let metadatas = response
.metadatas
.and_then(|m| m.into_iter().next())
.unwrap_or_default();
ids.into_iter()
.enumerate()
.map(|(i, id)| {
OutputData::new(
id,
distances.get(i).copied(),
metadatas.get(i).cloned().unwrap_or_default(),
)
})
.collect()
}
}
#[async_trait]
impl VectorStoreBase for ChromaDB {
/// Creates a collection called `name` in the configured tenant and database.
///
/// The response body is not read on success, so the ID of the new collection is not
/// recorded on `self`.
///
/// # Errors
///
/// Propagates transport errors and returns `NeomemxError::VectorStoreError` with the
/// response body when the server answers with a non-success status.
async fn create_collection(&self, name: &str) -> Result<()> {
    let endpoint = self.collections_base_url();
    let payload = serde_json::json!({ "name": name });
    let response = self.client.post(&endpoint).json(&payload).send().await?;
    if response.status().is_success() {
        return Ok(());
    }
    let body = response.text().await.unwrap_or_default();
    Err(NeomemxError::VectorStoreError(format!(
        "Failed to create collection: {}",
        body
    )))
}
/// Adds `vectors` to the collection.
///
/// * `payloads` — optional metadata, one map per vector.
/// * `ids` — optional record IDs, one per vector; random UUIDs are generated when
///   omitted.
///
/// An empty batch is a no-op and sends no request, matching `delete_batch` and
/// `get_batch`.
///
/// # Errors
///
/// Returns `NeomemxError::VectorStoreError` when `ids` or `payloads` do not have the
/// same length as `vectors`, when the collection is not initialized, or when the
/// server answers with a non-success status. Transport errors are propagated.
async fn insert(
    &self,
    vectors: Vec<Vec<f32>>,
    payloads: Option<Vec<HashMap<String, serde_json::Value>>>,
    ids: Option<Vec<String>>,
) -> Result<()> {
    // Reject misaligned input locally instead of sending a request that cannot
    // be matched up record by record.
    if let Some(given) = ids.as_ref() {
        if given.len() != vectors.len() {
            return Err(NeomemxError::VectorStoreError(format!(
                "Failed to insert vectors: got {} ids for {} vectors",
                given.len(),
                vectors.len()
            )));
        }
    }
    if let Some(given) = payloads.as_ref() {
        if given.len() != vectors.len() {
            return Err(NeomemxError::VectorStoreError(format!(
                "Failed to insert vectors: got {} payloads for {} vectors",
                given.len(),
                vectors.len()
            )));
        }
    }
    if vectors.is_empty() {
        return Ok(());
    }

    let ids = ids.unwrap_or_else(|| {
        (0..vectors.len())
            .map(|_| uuid::Uuid::new_v4().to_string())
            .collect()
    });
    info!("Inserting {} vectors into collection", vectors.len());
    let request = AddRequest {
        ids,
        embeddings: vectors,
        metadatas: payloads,
    };
    let url = format!("{}/add", self.collection_url()?);
    let response = self.client.post(&url).json(&request).send().await?;
    if !response.status().is_success() {
        let body = response.text().await.unwrap_or_default();
        error!("Failed to insert vectors: {}", body);
        return Err(NeomemxError::VectorStoreError(format!(
            "Failed to insert vectors: {}",
            body
        )));
    }
    Ok(())
}
/// Runs a nearest-neighbour query for `vectors` and returns up to `limit` records.
///
/// `_query` is ignored; only the embedding is sent. `filters` are translated with
/// `generate_where_clause`, and a filter set that produces no condition is left out
/// of the request.
///
/// # Errors
///
/// Returns `NeomemxError::VectorStoreError` when the collection is not initialized,
/// the server answers with a non-success status, or the response cannot be parsed.
/// Transport errors are propagated.
async fn search(
    &self,
    _query: &str,
    vectors: &[f32],
    limit: usize,
    filters: Option<Filters>,
) -> Result<Vec<OutputData>> {
    let where_clause = filters
        .as_ref()
        .map(Self::generate_where_clause)
        .filter(|clause| !clause.is_null());
    let include: Vec<String> = SEARCH_INCLUDE
        .iter()
        .map(|field| String::from(*field))
        .collect();
    let payload = QueryRequest {
        query_embeddings: vec![vectors.to_vec()],
        n_results: limit,
        r#where: where_clause,
        include,
    };

    let endpoint = format!("{}/query", self.collection_url()?);
    let response = self.client.post(&endpoint).json(&payload).send().await?;
    if response.status().is_success() {
        let parsed = response.json::<QueryResponse>().await.map_err(|e| {
            NeomemxError::VectorStoreError(format!("Failed to parse query response: {}", e))
        })?;
        return Ok(Self::parse_query_response(parsed));
    }
    let body = response.text().await.unwrap_or_default();
    Err(NeomemxError::VectorStoreError(format!(
        "Failed to search vectors: {}",
        body
    )))
}
/// Deletes the record with ID `vector_id` by delegating to `delete_batch`.
async fn delete(&self, vector_id: &str) -> Result<()> {
    let single = [String::from(vector_id)];
    self.delete_batch(&single).await
}
/// Deletes the records whose IDs are listed in `vector_ids`.
///
/// An empty slice is a no-op and sends no request.
///
/// # Errors
///
/// Returns `NeomemxError::VectorStoreError` when the collection is not initialized or
/// the server answers with a non-success status. Transport errors are propagated.
async fn delete_batch(&self, vector_ids: &[String]) -> Result<()> {
    if vector_ids.is_empty() {
        return Ok(());
    }

    let endpoint = format!("{}/delete", self.collection_url()?);
    let payload = DeleteRequest {
        ids: Vec::from(vector_ids),
    };
    let response = self.client.post(&endpoint).json(&payload).send().await?;
    if response.status().is_success() {
        return Ok(());
    }
    let body = response.text().await.unwrap_or_default();
    Err(NeomemxError::VectorStoreError(format!(
        "Failed to delete vectors: {}",
        body
    )))
}
/// Updates the record with ID `vector_id`.
///
/// `vector` and `payload` are both optional; whichever is `None` is left out of the
/// request body.
///
/// # Errors
///
/// Returns `NeomemxError::VectorStoreError` when the collection is not initialized or
/// the server answers with a non-success status. Transport errors are propagated.
async fn update(
    &self,
    vector_id: &str,
    vector: Option<Vec<f32>>,
    payload: Option<HashMap<String, serde_json::Value>>,
) -> Result<()> {
    // The endpoint is batch-shaped, so each value is wrapped in a one-element list.
    let body = UpdateRequest {
        ids: vec![String::from(vector_id)],
        embeddings: vector.map(|embedding| vec![embedding]),
        metadatas: payload.map(|metadata| vec![metadata]),
    };
    let endpoint = format!("{}/update", self.collection_url()?);
    let response = self.client.post(&endpoint).json(&body).send().await?;
    if response.status().is_success() {
        return Ok(());
    }
    let body = response.text().await.unwrap_or_default();
    Err(NeomemxError::VectorStoreError(format!(
        "Failed to update vector: {}",
        body
    )))
}
/// Fetches the record with ID `vector_id`, or `None` when the server returns no record.
///
/// Delegates to `get_batch` and keeps the first returned record.
async fn get(&self, vector_id: &str) -> Result<Option<OutputData>> {
    let wanted = [String::from(vector_id)];
    let found = self.get_batch(&wanted).await?;
    Ok(found.into_iter().next())
}
/// Fetches the records whose IDs are listed in `vector_ids`.
///
/// An empty slice returns an empty list without sending a request. Returned records
/// carry no score; a record without a metadata entry gets an empty map.
///
/// # Errors
///
/// Returns `NeomemxError::VectorStoreError` when the collection is not initialized,
/// the server answers with a non-success status, or the response cannot be parsed.
/// Transport errors are propagated.
async fn get_batch(&self, vector_ids: &[String]) -> Result<Vec<OutputData>> {
    if vector_ids.is_empty() {
        return Ok(Vec::new());
    }

    let payload = GetRequest {
        ids: Vec::from(vector_ids),
        include: GET_INCLUDE
            .iter()
            .map(|field| String::from(*field))
            .collect(),
    };
    let endpoint = format!("{}/get", self.collection_url()?);
    let response = self.client.post(&endpoint).json(&payload).send().await?;
    if !response.status().is_success() {
        let body = response.text().await.unwrap_or_default();
        return Err(NeomemxError::VectorStoreError(format!(
            "Failed to get vectors: {}",
            body
        )));
    }

    let parsed = response.json::<GetResponse>().await.map_err(|e| {
        NeomemxError::VectorStoreError(format!("Failed to parse get response: {}", e))
    })?;
    // Metadata is matched to IDs by position; the iterator yields `None` once exhausted.
    let mut payloads = parsed.metadatas.unwrap_or_default().into_iter();
    let records = parsed
        .ids
        .into_iter()
        .map(|id| OutputData::new(id, None, payloads.next().unwrap_or_default()))
        .collect();
    Ok(records)
}
/// Returns the names of all collections in the configured tenant and database.
///
/// # Errors
///
/// Returns `NeomemxError::VectorStoreError` when the server answers with a
/// non-success status or the response cannot be parsed. Transport errors are
/// propagated.
async fn list_collections(&self) -> Result<Vec<String>> {
    let endpoint = self.collections_base_url();
    let response = self.client.get(&endpoint).send().await?;
    if !response.status().is_success() {
        let body = response.text().await.unwrap_or_default();
        return Err(NeomemxError::VectorStoreError(format!(
            "Failed to list collections: {}",
            body
        )));
    }

    let listed = response
        .json::<Vec<CollectionResponse>>()
        .await
        .map_err(|e| {
            NeomemxError::VectorStoreError(format!(
                "Failed to parse collections response: {}",
                e
            ))
        })?;
    let mut names = Vec::with_capacity(listed.len());
    for collection in listed {
        names.push(collection.name);
    }
    Ok(names)
}
/// Deletes the configured collection from the server.
///
/// The collection is addressed by name, the same way `ensure_collection` looks it up
/// on this route. The previous implementation put the collection ID in the path.
///
/// # Errors
///
/// Returns `NeomemxError::VectorStoreError` with the response body when the server
/// answers with a non-success status. Transport errors are propagated.
async fn delete_collection(&self) -> Result<()> {
    // NOTE(review): relies on the server resolving this path segment as the
    // collection name rather than its ID — confirm against the deployed Chroma version.
    let url = format!(
        "{}/{}",
        self.collections_base_url(),
        self.config.collection_name
    );
    let response = self.client.delete(&url).send().await?;
    if !response.status().is_success() {
        let body = response.text().await.unwrap_or_default();
        return Err(NeomemxError::VectorStoreError(format!(
            "Failed to delete collection: {}",
            body
        )));
    }
    warn!("Deleted collection: {}", self.config.collection_name);
    Ok(())
}
/// Returns the server's description of the configured collection as raw JSON.
///
/// The collection is addressed by name, the same way `ensure_collection` looks it up
/// on this route. The previous implementation put the collection ID in the path.
///
/// # Errors
///
/// Returns `NeomemxError::VectorStoreError` when the server answers with a
/// non-success status or the response is not valid JSON. Transport errors are
/// propagated.
async fn collection_info(&self) -> Result<serde_json::Value> {
    // NOTE(review): relies on the server resolving this path segment as the
    // collection name rather than its ID — confirm against the deployed Chroma version.
    let url = format!(
        "{}/{}",
        self.collections_base_url(),
        self.config.collection_name
    );
    let response = self.client.get(&url).send().await?;
    if !response.status().is_success() {
        let body = response.text().await.unwrap_or_default();
        return Err(NeomemxError::VectorStoreError(format!(
            "Failed to get collection info: {}",
            body
        )));
    }
    let info: serde_json::Value = response.json().await.map_err(|e| {
        NeomemxError::VectorStoreError(format!("Failed to parse collection info: {}", e))
    })?;
    Ok(info)
}
/// Lists up to `limit` records, optionally restricted by `filters`.
///
/// `filters` are translated with `generate_where_clause`; a filter set that produces
/// no condition is left out of the request. Returned records carry no score, and a
/// record without a metadata entry gets an empty map.
///
/// # Errors
///
/// Returns `NeomemxError::VectorStoreError` when the collection is not initialized,
/// the server answers with a non-success status, or the response cannot be parsed.
/// Transport errors are propagated.
async fn list(&self, filters: Option<Filters>, limit: usize) -> Result<Vec<OutputData>> {
    let where_clause = filters
        .as_ref()
        .map(Self::generate_where_clause)
        .filter(|clause| !clause.is_null());

    let mut fields = serde_json::Map::new();
    fields.insert("include".to_string(), serde_json::json!(["metadatas"]));
    fields.insert("limit".to_string(), serde_json::json!(limit));
    if let Some(clause) = where_clause {
        fields.insert("where".to_string(), clause);
    }
    let payload = serde_json::Value::Object(fields);

    let endpoint = format!("{}/get", self.collection_url()?);
    let response = self.client.post(&endpoint).json(&payload).send().await?;
    if !response.status().is_success() {
        let body = response.text().await.unwrap_or_default();
        return Err(NeomemxError::VectorStoreError(format!(
            "Failed to list vectors: {}",
            body
        )));
    }

    let parsed = response.json::<GetResponse>().await.map_err(|e| {
        NeomemxError::VectorStoreError(format!("Failed to parse list response: {}", e))
    })?;
    // Metadata is matched to IDs by position; the iterator yields `None` once exhausted.
    let mut payloads = parsed.metadatas.unwrap_or_default().into_iter();
    let records = parsed
        .ids
        .into_iter()
        .map(|id| OutputData::new(id, None, payloads.next().unwrap_or_default()))
        .collect();
    Ok(records)
}
async fn reset(&self) -> Result<()> {
warn!("Resetting collection: {}", self.config.collection_name);
self.delete_collection().await?;
self.create_collection(&self.config.collection_name).await?;
Ok(())
}
}