use std::sync::Arc;
use chroma_api_types::ForkCollectionPayload;
use chroma_types::{
plan::{ReadLevel, SearchPayload},
AddCollectionRecordsRequest, AddCollectionRecordsResponse, Collection, CollectionUuid,
DeleteCollectionRecordsRequest, DeleteCollectionRecordsResponse, GetRequest, GetResponse,
IncludeList, IndexStatusResponse, Metadata, QueryRequest, QueryResponse, Schema, SearchRequest,
SearchResponse, UpdateCollectionRecordsRequest, UpdateCollectionRecordsResponse,
UpdateMetadata, UpsertCollectionRecordsRequest, UpsertCollectionRecordsResponse, Where,
};
use reqwest::Method;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use crate::{client::ChromaHttpClientError, ChromaHttpClient};
/// Body returned by the collection `fork_count` endpoint.
///
/// Only the `count` field is read; it is handed back unchanged by
/// `ChromaCollection::fork_count`.
#[derive(Deserialize)]
struct ForkCountResponse {
    /// Count reported by the server for this collection.
    count: usize,
}
/// Handle to a single Chroma collection.
///
/// All record operations (get/query/search/add/update/upsert/delete/fork) are issued
/// through the wrapped HTTP client against this collection's tenant, database and id.
/// The collection description is shared behind an [`Arc`], so cloning a handle does not
/// copy it.
#[derive(Clone)]
pub struct ChromaCollection {
    /// HTTP client used for every request made through this handle.
    pub(crate) client: ChromaHttpClient,
    /// Locally cached description of the collection (tenant, database, id, name,
    /// metadata, schema, version).
    pub(crate) collection: Arc<Collection>,
}
impl std::fmt::Debug for ChromaCollection {
    /// Formats the identifying fields of the collection: database, tenant, name, id and
    /// version.
    ///
    /// The HTTP client, metadata and schema are not printed. `finish_non_exhaustive`
    /// appends `..` so the output does not look like a complete field listing.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let collection = &self.collection;
        f.debug_struct("ChromaCollection")
            .field("database", &collection.database)
            .field("tenant", &collection.tenant)
            .field("name", &collection.name)
            .field("collection_id", &collection.collection_id)
            .field("version", &collection.version)
            .finish_non_exhaustive()
    }
}
impl ChromaCollection {
    /// Returns the name of the database that contains this collection.
    pub fn database(&self) -> &str {
        &self.collection.database
    }

    /// Returns the collection metadata cached on this handle.
    ///
    /// This accessor does not contact the server. The value is whatever this handle was
    /// created with, as updated by a successful [`ChromaCollection::modify`] call on this
    /// same handle.
    pub fn metadata(&self) -> &Option<Metadata> {
        &self.collection.metadata
    }

    /// Returns the collection schema cached on this handle, if any.
    pub fn schema(&self) -> &Option<Schema> {
        &self.collection.schema
    }

    /// Returns the tenant that owns this collection.
    pub fn tenant(&self) -> &str {
        &self.collection.tenant
    }

    /// Returns the collection name cached on this handle.
    pub fn name(&self) -> &str {
        &self.collection.name
    }

    /// Returns the collection's identifier.
    pub fn id(&self) -> CollectionUuid {
        self.collection.collection_id
    }

    /// Returns the collection version recorded on this handle.
    pub fn version(&self) -> i32 {
        self.collection.version
    }

    /// Counts the records in the collection.
    ///
    /// Equivalent to [`ChromaCollection::count_with_options`] with
    /// [`ReadLevel::IndexAndWal`].
    ///
    /// # Errors
    ///
    /// Returns any error reported by the HTTP client for the `count` request.
    pub async fn count(&self) -> Result<u32, ChromaHttpClientError> {
        self.count_with_options(ReadLevel::IndexAndWal).await
    }

    /// Counts the records in the collection at the given read level.
    ///
    /// Issues `GET .../count` with `read_level` sent as a query parameter.
    ///
    /// # Errors
    ///
    /// Returns any error reported by the HTTP client for the request.
    pub async fn count_with_options(
        &self,
        read_level: ReadLevel,
    ) -> Result<u32, ChromaHttpClientError> {
        #[derive(Serialize)]
        struct CountQueryParams {
            read_level: ReadLevel,
        }

        self.send_with_query::<(), CountQueryParams, u32>(
            "count",
            "count",
            Method::GET,
            None,
            Some(CountQueryParams { read_level }),
        )
        .await
    }

    /// Fetches the indexing status of the collection (`GET .../indexing_status`).
    ///
    /// # Errors
    ///
    /// Returns any error reported by the HTTP client for the request.
    pub async fn get_indexing_status(&self) -> Result<IndexStatusResponse, ChromaHttpClientError> {
        self.send::<(), IndexStatusResponse>(
            "indexing_status",
            "indexing_status",
            Method::GET,
            None,
        )
        .await
    }

    /// Renames the collection and/or sends new metadata for it.
    ///
    /// Issues `PUT` on the collection resource with `new_name` and `new_metadata`
    /// (serialized as `null` when `None`). Only after the request succeeds is the handle's
    /// cached description updated:
    /// - a provided name replaces the cached name;
    /// - provided metadata replaces the cached metadata.
    ///
    /// On error the handle is left unchanged.
    ///
    /// # Errors
    ///
    /// Returns any error reported by the HTTP client for the request.
    pub async fn modify(
        &mut self,
        new_name: Option<impl AsRef<str>>,
        new_metadata: Option<Metadata>,
    ) -> Result<(), ChromaHttpClientError> {
        // Typed request body instead of `serde_json::json!`. The macro calls
        // `to_value(..).unwrap()` internally, which would panic instead of returning an
        // error if serialization ever failed. The produced JSON object is the same.
        #[derive(Serialize)]
        struct ModifyCollectionBody<'a> {
            new_name: Option<&'a str>,
            new_metadata: Option<&'a Metadata>,
        }

        let body = ModifyCollectionBody {
            new_name: new_name.as_ref().map(|name| name.as_ref()),
            new_metadata: new_metadata.as_ref(),
        };
        // An empty sub-path targets the collection resource itself; `send_with_query`
        // strips the resulting trailing slash.
        self.send::<_, serde_json::Value>("modify", "", Method::PUT, Some(body))
            .await?;

        // The server accepted the change: mirror it in the locally cached description.
        let mut updated_collection = (*self.collection).clone();
        if let Some(name) = new_name {
            updated_collection.name = name.as_ref().to_string();
        }
        if let Some(metadata) = new_metadata {
            updated_collection.metadata = Some(metadata);
        }
        self.collection = Arc::new(updated_collection);
        Ok(())
    }

    /// Fetches records from the collection (`POST .../get`).
    ///
    /// * `ids` / `r#where` — optional id list and metadata filter.
    /// * `limit` / `offset` — pagination; `offset` defaults to `0`.
    /// * `include` — fields to return; defaults to [`IncludeList::default_get`].
    ///
    /// # Errors
    ///
    /// Returns an error if building the request fails (`GetRequest::try_new` or
    /// `into_payload`) or if the HTTP request fails.
    pub async fn get(
        &self,
        ids: Option<Vec<String>>,
        r#where: Option<Where>,
        limit: Option<u32>,
        offset: Option<u32>,
        include: Option<IncludeList>,
    ) -> Result<GetResponse, ChromaHttpClientError> {
        let request = GetRequest::try_new(
            self.collection.tenant.clone(),
            self.collection.database.clone(),
            self.collection.collection_id,
            ids,
            r#where,
            limit,
            offset.unwrap_or_default(),
            include.unwrap_or_else(IncludeList::default_get),
        )?;
        let request = request.into_payload()?;
        self.send("get", "get", Method::POST, Some(request)).await
    }

    /// Runs a nearest-neighbour query for each embedding in `query_embeddings`
    /// (`POST .../query`).
    ///
    /// * `n_results` — results per query embedding; defaults to `10`.
    /// * `r#where` / `ids` — optional metadata filter and id restriction.
    /// * `include` — fields to return; defaults to [`IncludeList::default_query`].
    ///
    /// # Errors
    ///
    /// Returns an error if building the request fails (`QueryRequest::try_new` or
    /// `into_payload`) or if the HTTP request fails.
    pub async fn query(
        &self,
        query_embeddings: Vec<Vec<f32>>,
        n_results: Option<u32>,
        r#where: Option<Where>,
        ids: Option<Vec<String>>,
        include: Option<IncludeList>,
    ) -> Result<QueryResponse, ChromaHttpClientError> {
        let request = QueryRequest::try_new(
            self.collection.tenant.clone(),
            self.collection.database.clone(),
            self.collection.collection_id,
            ids,
            r#where,
            query_embeddings,
            n_results.unwrap_or(10),
            include.unwrap_or_else(IncludeList::default_query),
        )?;
        let request = request.into_payload()?;
        self.send("query", "query", Method::POST, Some(request))
            .await
    }

    /// Executes one or more search payloads against the collection.
    ///
    /// Equivalent to [`ChromaCollection::search_with_options`] with
    /// [`ReadLevel::IndexAndWal`].
    ///
    /// # Errors
    ///
    /// See [`ChromaCollection::search_with_options`].
    pub async fn search(
        &self,
        searches: Vec<SearchPayload>,
    ) -> Result<SearchResponse, ChromaHttpClientError> {
        self.search_with_options(searches, ReadLevel::IndexAndWal)
            .await
    }

    /// Executes one or more search payloads at the given read level (`POST .../search`).
    ///
    /// # Errors
    ///
    /// Returns an error if `SearchRequest::try_new` rejects the request or if the HTTP
    /// request fails.
    pub async fn search_with_options(
        &self,
        searches: Vec<SearchPayload>,
        read_level: ReadLevel,
    ) -> Result<SearchResponse, ChromaHttpClientError> {
        let request = SearchRequest::try_new(
            self.collection.tenant.clone(),
            self.collection.database.clone(),
            self.collection.collection_id,
            searches,
            read_level,
        )?;
        let request = request.into_payload();
        self.send("search", "search", Method::POST, Some(request))
            .await
    }

    /// Adds new records to the collection (`POST .../add`).
    ///
    /// `ids` and `embeddings` are required; `documents`, `uris` and `metadatas` are
    /// optional per-record columns.
    ///
    /// # Errors
    ///
    /// Returns an error if `AddCollectionRecordsRequest::try_new` rejects the input or if
    /// the HTTP request fails.
    pub async fn add(
        &self,
        ids: Vec<String>,
        embeddings: Vec<Vec<f32>>,
        documents: Option<Vec<Option<String>>>,
        uris: Option<Vec<Option<String>>>,
        metadatas: Option<Vec<Option<Metadata>>>,
    ) -> Result<AddCollectionRecordsResponse, ChromaHttpClientError> {
        let request = AddCollectionRecordsRequest::try_new(
            self.collection.tenant.clone(),
            self.collection.database.clone(),
            self.collection.collection_id,
            ids,
            embeddings,
            documents,
            uris,
            metadatas,
        )?;
        let request = request.into_payload();
        self.send("add", "add", Method::POST, Some(request)).await
    }

    /// Updates existing records identified by `ids` (`POST .../update`).
    ///
    /// Every column is optional, so callers can change only embeddings, documents, URIs
    /// or metadata.
    ///
    /// # Errors
    ///
    /// Returns an error if `UpdateCollectionRecordsRequest::try_new` rejects the input or
    /// if the HTTP request fails.
    pub async fn update(
        &self,
        ids: Vec<String>,
        embeddings: Option<Vec<Option<Vec<f32>>>>,
        documents: Option<Vec<Option<String>>>,
        uris: Option<Vec<Option<String>>>,
        metadatas: Option<Vec<Option<UpdateMetadata>>>,
    ) -> Result<UpdateCollectionRecordsResponse, ChromaHttpClientError> {
        let request = UpdateCollectionRecordsRequest::try_new(
            self.collection.tenant.clone(),
            self.collection.database.clone(),
            self.collection.collection_id,
            ids,
            embeddings,
            documents,
            uris,
            metadatas,
        )?;
        let request = request.into_payload();
        self.send("update", "update", Method::POST, Some(request))
            .await
    }

    /// Inserts records that do not exist and updates those that do (`POST .../upsert`).
    ///
    /// # Errors
    ///
    /// Returns an error if `UpsertCollectionRecordsRequest::try_new` rejects the input or
    /// if the HTTP request fails.
    pub async fn upsert(
        &self,
        ids: Vec<String>,
        embeddings: Vec<Vec<f32>>,
        documents: Option<Vec<Option<String>>>,
        uris: Option<Vec<Option<String>>>,
        metadatas: Option<Vec<Option<UpdateMetadata>>>,
    ) -> Result<UpsertCollectionRecordsResponse, ChromaHttpClientError> {
        let request = UpsertCollectionRecordsRequest::try_new(
            self.collection.tenant.clone(),
            self.collection.database.clone(),
            self.collection.collection_id,
            ids,
            embeddings,
            documents,
            uris,
            metadatas,
        )?;
        let request = request.into_payload();
        self.send("upsert", "upsert", Method::POST, Some(request))
            .await
    }

    /// Deletes records matching `ids` and/or `r#where` (`POST .../delete`).
    ///
    /// `limit` optionally caps how many matching records are deleted.
    ///
    /// # Errors
    ///
    /// Returns an error if building the request fails
    /// (`DeleteCollectionRecordsRequest::try_new` or `into_payload`) or if the HTTP
    /// request fails.
    pub async fn delete(
        &self,
        ids: Option<Vec<String>>,
        r#where: Option<Where>,
        limit: Option<u32>,
    ) -> Result<DeleteCollectionRecordsResponse, ChromaHttpClientError> {
        let request = DeleteCollectionRecordsRequest::try_new(
            self.collection.tenant.clone(),
            self.collection.database.clone(),
            self.collection.collection_id,
            ids,
            r#where,
            limit,
        )?;
        let request = request.into_payload()?;
        self.send("delete", "delete", Method::POST, Some(request))
            .await
    }

    /// Forks this collection into a new collection named `new_name` (`POST .../fork`).
    ///
    /// Returns a handle to the new collection that shares this handle's client.
    ///
    /// # Errors
    ///
    /// Returns any error reported by the HTTP client for the request.
    pub async fn fork(
        &self,
        new_name: impl Into<String>,
    ) -> Result<ChromaCollection, ChromaHttpClientError> {
        let request = ForkCollectionPayload {
            new_name: new_name.into(),
        };
        let collection: Collection = self
            .send("fork", "fork", Method::POST, Some(request))
            .await?;
        Ok(ChromaCollection {
            client: self.client.clone(),
            collection: Arc::new(collection),
        })
    }

    /// Returns the `count` field reported by `GET .../fork_count` for this collection.
    ///
    /// # Errors
    ///
    /// Returns any error reported by the HTTP client for the request.
    pub async fn fork_count(&self) -> Result<usize, ChromaHttpClientError> {
        let response: ForkCountResponse = self
            .send::<(), _>("fork_count", "fork_count", Method::GET, None)
            .await?;
        Ok(response.count)
    }

    /// Sends a collection-scoped request without query parameters.
    async fn send<Body: Serialize, Response: DeserializeOwned>(
        &self,
        operation: &str,
        path: &str,
        method: Method,
        body: Option<Body>,
    ) -> Result<Response, ChromaHttpClientError> {
        self.send_with_query::<Body, (), Response>(operation, path, method, body, None::<()>)
            .await
    }

    /// Sends a collection-scoped request.
    ///
    /// The URL is `/api/v2/tenants/{tenant}/databases/{database}/collections/{id}/{path}`
    /// with trailing slashes removed. The client receives the operation name
    /// `collection_{operation}`.
    async fn send_with_query<
        Body: Serialize,
        QueryParams: Serialize,
        Response: DeserializeOwned,
    >(
        &self,
        operation: &str,
        path: &str,
        method: Method,
        body: Option<Body>,
        query_params: Option<QueryParams>,
    ) -> Result<Response, ChromaHttpClientError> {
        let operation_name = format!("collection_{operation}");
        let path = format!(
            "/api/v2/tenants/{}/databases/{}/collections/{}/{}",
            self.collection.tenant, self.collection.database, self.collection.collection_id, path
        );
        // An empty `path` (as used by `modify`) would otherwise leave a trailing slash.
        let path = path.trim_end_matches('/');
        self.client
            .send(&operation_name, method, path, body, query_params)
            .await
    }
}
// Integration tests for `ChromaCollection`. Each test obtains a client through
// `with_client` and creates collections with `new_collection` (forks are registered with
// `track`). The `k8s_integration` prefix suggests these expect a live Chroma deployment
// rather than a mock.
#[cfg(test)]
mod tests {
    use crate::tests::{unique_collection_name, with_client};
    use chroma_types::operator::{Key, QueryVector, RankExpr};
    use chroma_types::plan::{ReadLevel, SearchPayload};
    use chroma_types::{
        Include, IncludeList, Metadata, MetadataComparison, MetadataExpression, MetadataValue,
        PrimitiveOperator, UpdateMetadata, UpdateMetadataValue, Where,
    };

    /// Builds the filter `category == value`, shared by the `where`-clause tests.
    fn category_eq(value: &str) -> Where {
        Where::Metadata(MetadataExpression {
            key: "category".to_string(),
            comparison: MetadataComparison::Primitive(
                PrimitiveOperator::Equal,
                MetadataValue::Str(value.to_string()),
            ),
        })
    }

    /// The local accessors expose the collection description returned at creation.
    #[tokio::test]
    #[test_log::test]
    async fn test_k8s_integration_accessor_methods() {
        with_client(|mut client| async move {
            let collection = client.new_collection("test_accessors").await;
            assert!(!collection.database().is_empty());
            assert_eq!(collection.metadata(), &None);
            assert!(collection.schema().is_some());
            assert!(!collection.tenant().is_empty());
            assert!(!collection.name().is_empty());
        })
        .await;
    }

    /// A freshly created collection has no records.
    #[tokio::test]
    #[test_log::test]
    async fn test_k8s_integration_count_empty_collection() {
        with_client(|mut client| async move {
            let collection = client.new_collection("test_count_empty").await;
            let count = collection.count().await.unwrap();
            println!("Empty collection count: {}", count);
            assert_eq!(count, 0);
        })
        .await;
    }

    /// Indexing status starts at zero and counts ops after an add.
    #[tokio::test]
    #[test_log::test]
    async fn test_k8s_integration_get_indexing_status() {
        with_client(|mut client| async move {
            let collection = client.new_collection("test_indexing_status").await;
            let status = collection.get_indexing_status().await.unwrap();
            println!("Indexing status: {:?}", status);
            assert_eq!(status.total_ops, 0);
            assert_eq!(status.num_indexed_ops, 0);
            assert_eq!(status.num_unindexed_ops, 0);
            collection
                .add(
                    vec!["id1".to_string(), "id2".to_string()],
                    vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]],
                    None,
                    None,
                    None,
                )
                .await
                .unwrap();
            let status = collection.get_indexing_status().await.unwrap();
            println!("Indexing status after add: {:?}", status);
            assert_eq!(status.total_ops, 2);
            assert!(status.op_indexing_progress >= 0.0 && status.op_indexing_progress <= 1.0);
        })
        .await;
    }

    /// Adding one record makes the count 1.
    #[tokio::test]
    #[test_log::test]
    async fn test_k8s_integration_add_single_record() {
        with_client(|mut client| async move {
            let collection = client.new_collection("test_add_single").await;
            collection
                .add(
                    vec!["id1".to_string()],
                    vec![vec![1.0, 2.0, 3.0]],
                    Some(vec![Some("document1".to_string())]),
                    None,
                    None,
                )
                .await
                .unwrap();
            let count = collection.count().await.unwrap();
            println!("Collection count after add: {}", count);
            assert_eq!(count, 1);
        })
        .await;
    }

    /// Adding several records in one call counts each of them.
    #[tokio::test]
    #[test_log::test]
    async fn test_k8s_integration_add_multiple_records() {
        with_client(|mut client| async move {
            let collection = client.new_collection("test_add_multiple").await;
            collection
                .add(
                    vec!["id1".to_string(), "id2".to_string(), "id3".to_string()],
                    vec![
                        vec![1.0, 2.0, 3.0],
                        vec![4.0, 5.0, 6.0],
                        vec![7.0, 8.0, 9.0],
                    ],
                    Some(vec![
                        Some("first document".to_string()),
                        Some("second document".to_string()),
                        Some("third document".to_string()),
                    ]),
                    None,
                    None,
                )
                .await
                .unwrap();
            let count = collection.count().await.unwrap();
            println!("Collection count after adding multiple: {}", count);
            assert_eq!(count, 3);
        })
        .await;
    }

    /// Records can be added with metadata.
    #[tokio::test]
    #[test_log::test]
    async fn test_k8s_integration_add_with_metadata() {
        with_client(|mut client| async move {
            let collection = client.new_collection("test_add_metadata").await;
            let mut metadata = Metadata::new();
            metadata.insert("category".to_string(), "test".into());
            metadata.insert("version".to_string(), 1.into());
            collection
                .add(
                    vec!["id1".to_string()],
                    vec![vec![1.0, 2.0, 3.0]],
                    Some(vec![Some("document with metadata".to_string())]),
                    None,
                    Some(vec![Some(metadata)]),
                )
                .await
                .unwrap();
            let count = collection.count().await.unwrap();
            assert_eq!(count, 1);
        })
        .await;
    }

    /// Records can be added with URIs.
    #[tokio::test]
    #[test_log::test]
    async fn test_k8s_integration_add_with_uris() {
        with_client(|mut client| async move {
            let collection = client.new_collection("test_add_uris").await;
            collection
                .add(
                    vec!["id1".to_string()],
                    vec![vec![1.0, 2.0, 3.0]],
                    Some(vec![Some("document with uri".to_string())]),
                    Some(vec![Some("https://example.com/doc1".to_string())]),
                    None,
                )
                .await
                .unwrap();
            let count = collection.count().await.unwrap();
            assert_eq!(count, 1);
        })
        .await;
    }

    /// `get` with no filters returns every record.
    #[tokio::test]
    #[test_log::test]
    async fn test_k8s_integration_get_all_records() {
        with_client(|mut client| async move {
            let collection = client.new_collection("test_get_all").await;
            collection
                .add(
                    vec!["id1".to_string(), "id2".to_string()],
                    vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]],
                    Some(vec![Some("first".to_string()), Some("second".to_string())]),
                    None,
                    None,
                )
                .await
                .unwrap();
            let response = collection.get(None, None, None, None, None).await.unwrap();
            println!("Get all response: {:?}", response);
            assert_eq!(response.ids.len(), 2);
            assert!(response.ids.contains(&"id1".to_string()));
            assert!(response.ids.contains(&"id2".to_string()));
        })
        .await;
    }

    /// `get` by id returns exactly the requested records.
    #[tokio::test]
    #[test_log::test]
    async fn test_k8s_integration_get_by_ids() {
        with_client(|mut client| async move {
            let collection = client.new_collection("test_get_by_ids").await;
            collection
                .add(
                    vec!["id1".to_string(), "id2".to_string(), "id3".to_string()],
                    vec![
                        vec![1.0, 2.0, 3.0],
                        vec![4.0, 5.0, 6.0],
                        vec![7.0, 8.0, 9.0],
                    ],
                    None,
                    None,
                    None,
                )
                .await
                .unwrap();
            let response = collection
                .get(
                    Some(vec!["id1".to_string(), "id3".to_string()]),
                    None,
                    None,
                    None,
                    None,
                )
                .await
                .unwrap();
            println!("Get by ids response: {:?}", response);
            assert_eq!(response.ids.len(), 2);
            assert!(response.ids.contains(&"id1".to_string()));
            assert!(response.ids.contains(&"id3".to_string()));
        })
        .await;
    }

    /// `limit`/`offset` return a page of the expected size drawn from the inserted ids.
    #[tokio::test]
    #[test_log::test]
    async fn test_k8s_integration_get_with_limit_and_offset() {
        with_client(|mut client| async move {
            let collection = client.new_collection("test_get_limit_offset").await;
            collection
                .add(
                    vec![
                        "id1".to_string(),
                        "id2".to_string(),
                        "id3".to_string(),
                        "id4".to_string(),
                    ],
                    vec![
                        vec![1.0, 2.0, 3.0],
                        vec![4.0, 5.0, 6.0],
                        vec![7.0, 8.0, 9.0],
                        vec![10.0, 11.0, 12.0],
                    ],
                    None,
                    None,
                    None,
                )
                .await
                .unwrap();
            let response = collection
                .get(None, None, Some(2), Some(1), None)
                .await
                .unwrap();
            println!("Get with limit and offset response: {:?}", response);
            assert_eq!(response.ids.len(), 2);
            let inserted = ["id1", "id2", "id3", "id4"];
            assert!(response
                .ids
                .iter()
                .all(|id| inserted.contains(&id.as_str())));
        })
        .await;
    }

    /// A metadata `where` filter narrows `get` to matching records.
    #[tokio::test]
    #[test_log::test]
    async fn test_k8s_integration_get_with_where_clause() {
        with_client(|mut client| async move {
            let collection = client.new_collection("test_get_where").await;
            let mut metadata1 = Metadata::new();
            metadata1.insert("category".to_string(), "a".into());
            let mut metadata2 = Metadata::new();
            metadata2.insert("category".to_string(), "b".into());
            collection
                .add(
                    vec!["id1".to_string(), "id2".to_string()],
                    vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]],
                    None,
                    None,
                    Some(vec![Some(metadata1), Some(metadata2)]),
                )
                .await
                .unwrap();
            let response = collection
                .get(None, Some(category_eq("a")), None, None, None)
                .await
                .unwrap();
            println!("Get with where clause response: {:?}", response);
            assert_eq!(response.ids.len(), 1);
            assert_eq!(response.ids[0], "id1");
        })
        .await;
    }

    /// An explicit include list returns the requested documents and embeddings.
    #[tokio::test]
    #[test_log::test]
    async fn test_k8s_integration_get_with_include_list() {
        with_client(|mut client| async move {
            let collection = client.new_collection("test_get_include").await;
            collection
                .add(
                    vec!["id1".to_string()],
                    vec![vec![1.0, 2.0, 3.0]],
                    Some(vec![Some("test document".to_string())]),
                    None,
                    None,
                )
                .await
                .unwrap();
            let include = IncludeList(vec![
                Include::Document,
                Include::Embedding,
                Include::Metadata,
            ]);
            let response = collection
                .get(None, None, None, None, Some(include))
                .await
                .unwrap();
            println!("Get with include list response: {:?}", response);
            assert_eq!(response.ids.len(), 1);
            assert_eq!(response.ids[0], "id1");
            assert!(response.documents.is_some());
            assert_eq!(
                response.documents.as_ref().unwrap()[0],
                Some("test document".to_string())
            );
            assert!(response.embeddings.is_some());
            assert_eq!(
                response.embeddings.as_ref().unwrap()[0],
                vec![1.0, 2.0, 3.0]
            );
        })
        .await;
    }

    /// A basic query finds the exact-match embedding and returns distances.
    #[tokio::test]
    #[test_log::test]
    async fn test_k8s_integration_query_basic() {
        with_client(|mut client| async move {
            let collection = client.new_collection("test_query_basic").await;
            collection
                .add(
                    vec!["id1".to_string(), "id2".to_string(), "id3".to_string()],
                    vec![
                        vec![1.0, 2.0, 3.0],
                        vec![1.1, 2.1, 3.1],
                        vec![10.0, 20.0, 30.0],
                    ],
                    Some(vec![
                        Some("first".to_string()),
                        Some("second".to_string()),
                        Some("third".to_string()),
                    ]),
                    None,
                    None,
                )
                .await
                .unwrap();
            let response = collection
                .query(vec![vec![1.0, 2.0, 3.0]], None, None, None, None)
                .await
                .unwrap();
            println!("Query basic response: {:?}", response);
            assert_eq!(response.ids.len(), 1);
            assert!(!response.ids[0].is_empty());
            assert!(response.ids[0].contains(&"id1".to_string()));
            assert!(response.distances.is_some());
        })
        .await;
    }

    /// `n_results` caps the number of results per query embedding.
    #[tokio::test]
    #[test_log::test]
    async fn test_k8s_integration_query_with_n_results() {
        with_client(|mut client| async move {
            let collection = client.new_collection("test_query_n_results").await;
            collection
                .add(
                    vec![
                        "id1".to_string(),
                        "id2".to_string(),
                        "id3".to_string(),
                        "id4".to_string(),
                        "id5".to_string(),
                    ],
                    vec![
                        vec![1.0, 2.0, 3.0],
                        vec![1.1, 2.1, 3.1],
                        vec![1.2, 2.2, 3.2],
                        vec![1.3, 2.3, 3.3],
                        vec![1.4, 2.4, 3.4],
                    ],
                    None,
                    None,
                    None,
                )
                .await
                .unwrap();
            let response = collection
                .query(vec![vec![1.0, 2.0, 3.0]], Some(3), None, None, None)
                .await
                .unwrap();
            println!("Query with n_results response: {:?}", response);
            assert_eq!(response.ids.len(), 1);
            assert_eq!(response.ids[0].len(), 3);
            assert!(response.distances.is_some());
        })
        .await;
    }

    /// A metadata `where` filter restricts query candidates.
    #[tokio::test]
    #[test_log::test]
    async fn test_k8s_integration_query_with_where_clause() {
        with_client(|mut client| async move {
            let collection = client.new_collection("test_query_where").await;
            let mut metadata1 = Metadata::new();
            metadata1.insert("category".to_string(), "a".into());
            let mut metadata2 = Metadata::new();
            metadata2.insert("category".to_string(), "b".into());
            collection
                .add(
                    vec!["id1".to_string(), "id2".to_string()],
                    vec![vec![1.0, 2.0, 3.0], vec![1.1, 2.1, 3.1]],
                    None,
                    None,
                    Some(vec![Some(metadata1), Some(metadata2)]),
                )
                .await
                .unwrap();
            let response = collection
                .query(
                    vec![vec![1.0, 2.0, 3.0]],
                    None,
                    Some(category_eq("a")),
                    None,
                    None,
                )
                .await
                .unwrap();
            println!("Query with where clause response: {:?}", response);
            assert_eq!(response.ids.len(), 1);
            assert_eq!(response.ids[0].len(), 1);
            assert_eq!(response.ids[0][0], "id1");
        })
        .await;
    }

    /// Multiple query embeddings produce one result list each.
    #[tokio::test]
    #[test_log::test]
    async fn test_k8s_integration_query_multiple_embeddings() {
        with_client(|mut client| async move {
            let collection = client.new_collection("test_query_multiple").await;
            collection
                .add(
                    vec!["id1".to_string(), "id2".to_string(), "id3".to_string()],
                    vec![
                        vec![1.0, 2.0, 3.0],
                        vec![4.0, 5.0, 6.0],
                        vec![7.0, 8.0, 9.0],
                    ],
                    None,
                    None,
                    None,
                )
                .await
                .unwrap();
            let response = collection
                .query(
                    vec![vec![1.0, 2.0, 3.0], vec![7.0, 8.0, 9.0]],
                    Some(1),
                    None,
                    None,
                    None,
                )
                .await
                .unwrap();
            println!("Query multiple embeddings response: {:?}", response);
            assert_eq!(response.ids.len(), 2);
            assert_eq!(response.ids[0].len(), 1);
            assert_eq!(response.ids[1].len(), 1);
        })
        .await;
    }

    /// Search succeeds at every read level and returns the selected columns.
    #[tokio::test]
    #[test_log::test]
    async fn test_k8s_integration_search_with_read_levels() {
        with_client(|mut client| async move {
            let collection = client.new_collection("test_search_read_level").await;
            collection
                .add(
                    vec!["id1".to_string(), "id2".to_string(), "id3".to_string()],
                    vec![
                        vec![1.0, 2.0, 3.0],
                        vec![1.1, 2.1, 3.1],
                        vec![0.9, 1.9, 2.9],
                    ],
                    Some(vec![
                        Some("first".to_string()),
                        Some("second".to_string()),
                        Some("third".to_string()),
                    ]),
                    None,
                    None,
                )
                .await
                .unwrap();
            let search = SearchPayload::default()
                .rank(RankExpr::Knn {
                    query: QueryVector::Dense(vec![1.0, 2.0, 3.0]),
                    key: Key::Embedding,
                    limit: 10,
                    default: None,
                    return_rank: false,
                })
                .limit(Some(5), 0)
                .select([Key::Document, Key::Score]);
            let index_and_wal = collection
                .search_with_options(vec![search.clone()], ReadLevel::IndexAndWal)
                .await
                .unwrap();
            assert_eq!(index_and_wal.ids.len(), 1);
            assert_eq!(index_and_wal.ids[0].len(), 3);
            assert!(index_and_wal.documents[0].is_some());
            assert!(index_and_wal.scores[0].is_some());
            let index_only = collection
                .search_with_options(vec![search.clone()], ReadLevel::IndexOnly)
                .await
                .unwrap();
            assert_eq!(index_only.ids.len(), 1);
            assert!(index_only.documents[0].is_some());
            assert!(index_only.scores[0].is_some());
            let bounded = collection
                .search_with_options(vec![search], ReadLevel::IndexAndBoundedWal)
                .await
                .unwrap();
            assert_eq!(bounded.ids.len(), 1);
            assert!(bounded.documents[0].is_some());
            assert!(bounded.scores[0].is_some());
        })
        .await;
    }

    /// Counts at weaker read levels never exceed the full (index + WAL) count.
    #[tokio::test]
    #[test_log::test]
    async fn test_k8s_integration_count_with_read_levels() {
        with_client(|mut client| async move {
            let collection = client.new_collection("test_count_read_level").await;
            collection
                .add(
                    vec!["id1".to_string(), "id2".to_string(), "id3".to_string()],
                    vec![
                        vec![1.0, 2.0, 3.0],
                        vec![1.1, 2.1, 3.1],
                        vec![0.9, 1.9, 2.9],
                    ],
                    None,
                    None,
                    None,
                )
                .await
                .unwrap();
            let count = collection
                .count_with_options(ReadLevel::IndexAndWal)
                .await
                .unwrap();
            assert_eq!(count, 3);
            let count = collection
                .count_with_options(ReadLevel::IndexOnly)
                .await
                .unwrap();
            assert!(count <= 3);
            let count = collection
                .count_with_options(ReadLevel::IndexAndBoundedWal)
                .await
                .unwrap();
            assert!(count <= 3);
        })
        .await;
    }

    /// `update` replaces a record's embedding.
    #[tokio::test]
    #[test_log::test]
    async fn test_k8s_integration_update_embeddings() {
        with_client(|mut client| async move {
            let collection = client.new_collection("test_update_embeddings").await;
            collection
                .add(
                    vec!["id1".to_string()],
                    vec![vec![1.0, 2.0, 3.0]],
                    Some(vec![Some("original".to_string())]),
                    None,
                    None,
                )
                .await
                .unwrap();
            collection
                .update(
                    vec!["id1".to_string()],
                    Some(vec![Some(vec![4.0, 5.0, 6.0])]),
                    None,
                    None,
                    None,
                )
                .await
                .unwrap();
            let get_response = collection
                .get(
                    Some(vec!["id1".to_string()]),
                    None,
                    None,
                    None,
                    Some(IncludeList(vec![Include::Embedding])),
                )
                .await
                .unwrap();
            println!("Get after update response: {:?}", get_response);
            assert!(get_response.embeddings.is_some());
            assert_eq!(
                get_response.embeddings.as_ref().unwrap()[0],
                vec![4.0, 5.0, 6.0]
            );
        })
        .await;
    }

    /// `update` replaces a record's document.
    #[tokio::test]
    #[test_log::test]
    async fn test_k8s_integration_update_documents() {
        with_client(|mut client| async move {
            let collection = client.new_collection("test_update_documents").await;
            collection
                .add(
                    vec!["id1".to_string()],
                    vec![vec![1.0, 2.0, 3.0]],
                    Some(vec![Some("original".to_string())]),
                    None,
                    None,
                )
                .await
                .unwrap();
            collection
                .update(
                    vec!["id1".to_string()],
                    None,
                    Some(vec![Some("updated document".to_string())]),
                    None,
                    None,
                )
                .await
                .unwrap();
            let get_response = collection
                .get(
                    Some(vec!["id1".to_string()]),
                    None,
                    None,
                    None,
                    Some(IncludeList(vec![Include::Document])),
                )
                .await
                .unwrap();
            println!("Get after update response: {:?}", get_response);
            assert!(get_response.documents.is_some());
            assert_eq!(
                get_response.documents.as_ref().unwrap()[0],
                Some("updated document".to_string())
            );
        })
        .await;
    }

    /// `update` overwrites existing metadata keys and adds new ones.
    #[tokio::test]
    #[test_log::test]
    async fn test_k8s_integration_update_metadata() {
        with_client(|mut client| async move {
            let collection = client.new_collection("test_update_metadata").await;
            let mut original_metadata = Metadata::new();
            original_metadata.insert("version".to_string(), 1.into());
            collection
                .add(
                    vec!["id1".to_string()],
                    vec![vec![1.0, 2.0, 3.0]],
                    None,
                    None,
                    Some(vec![Some(original_metadata)]),
                )
                .await
                .unwrap();
            let mut updated_metadata = UpdateMetadata::new();
            updated_metadata.insert("version".to_string(), UpdateMetadataValue::Int(2));
            updated_metadata.insert(
                "new_field".to_string(),
                UpdateMetadataValue::Str("test".to_string()),
            );
            collection
                .update(
                    vec!["id1".to_string()],
                    None,
                    None,
                    None,
                    Some(vec![Some(updated_metadata)]),
                )
                .await
                .unwrap();
            let get_response = collection
                .get(
                    Some(vec!["id1".to_string()]),
                    None,
                    None,
                    None,
                    Some(IncludeList(vec![Include::Metadata])),
                )
                .await
                .unwrap();
            println!("Get after update response: {:?}", get_response);
            assert!(get_response.metadatas.is_some());
            let metadata = get_response.metadatas.as_ref().unwrap()[0]
                .as_ref()
                .unwrap();
            assert_eq!(metadata.get("version"), Some(&MetadataValue::Int(2)));
            assert_eq!(
                metadata.get("new_field"),
                Some(&MetadataValue::Str("test".to_string()))
            );
        })
        .await;
    }

    /// `upsert` of an unknown id inserts it.
    #[tokio::test]
    #[test_log::test]
    async fn test_k8s_integration_upsert_insert_new() {
        with_client(|mut client| async move {
            let collection = client.new_collection("test_upsert_insert").await;
            collection
                .upsert(
                    vec!["id1".to_string()],
                    vec![vec![1.0, 2.0, 3.0]],
                    Some(vec![Some("new document".to_string())]),
                    None,
                    None,
                )
                .await
                .unwrap();
            let count = collection.count().await.unwrap();
            println!("Count after upsert insert: {}", count);
            assert_eq!(count, 1);
        })
        .await;
    }

    /// `upsert` of an existing id replaces its data without adding a record.
    #[tokio::test]
    #[test_log::test]
    async fn test_k8s_integration_upsert_update_existing() {
        with_client(|mut client| async move {
            let collection = client.new_collection("test_upsert_update").await;
            collection
                .add(
                    vec!["id1".to_string()],
                    vec![vec![1.0, 2.0, 3.0]],
                    Some(vec![Some("original".to_string())]),
                    None,
                    None,
                )
                .await
                .unwrap();
            collection
                .upsert(
                    vec!["id1".to_string()],
                    vec![vec![4.0, 5.0, 6.0]],
                    Some(vec![Some("updated via upsert".to_string())]),
                    None,
                    None,
                )
                .await
                .unwrap();
            let count = collection.count().await.unwrap();
            println!("Count after upsert update: {}", count);
            assert_eq!(count, 1);
            // The count alone cannot tell an update from a no-op: check the stored data.
            let get_response = collection
                .get(
                    Some(vec!["id1".to_string()]),
                    None,
                    None,
                    None,
                    Some(IncludeList(vec![Include::Document, Include::Embedding])),
                )
                .await
                .unwrap();
            assert_eq!(
                get_response.documents.as_ref().unwrap()[0],
                Some("updated via upsert".to_string())
            );
            assert_eq!(
                get_response.embeddings.as_ref().unwrap()[0],
                vec![4.0, 5.0, 6.0]
            );
        })
        .await;
    }

    /// `upsert` with one existing and one new id updates the first and inserts the second.
    #[tokio::test]
    #[test_log::test]
    async fn test_k8s_integration_upsert_mixed() {
        with_client(|mut client| async move {
            let collection = client.new_collection("test_upsert_mixed").await;
            collection
                .add(
                    vec!["id1".to_string()],
                    vec![vec![1.0, 2.0, 3.0]],
                    Some(vec![Some("existing".to_string())]),
                    None,
                    None,
                )
                .await
                .unwrap();
            collection
                .upsert(
                    vec!["id1".to_string(), "id2".to_string()],
                    vec![vec![4.0, 5.0, 6.0], vec![7.0, 8.0, 9.0]],
                    Some(vec![Some("updated".to_string()), Some("new".to_string())]),
                    None,
                    None,
                )
                .await
                .unwrap();
            let count = collection.count().await.unwrap();
            println!("Count after upsert mixed: {}", count);
            assert_eq!(count, 2);
        })
        .await;
    }

    /// Deleting by id removes exactly the listed records.
    #[tokio::test]
    #[test_log::test]
    async fn test_k8s_integration_delete_by_ids() {
        with_client(|mut client| async move {
            let collection = client.new_collection("test_delete_by_ids").await;
            collection
                .add(
                    vec!["id1".to_string(), "id2".to_string(), "id3".to_string()],
                    vec![
                        vec![1.0, 2.0, 3.0],
                        vec![4.0, 5.0, 6.0],
                        vec![7.0, 8.0, 9.0],
                    ],
                    None,
                    None,
                    None,
                )
                .await
                .unwrap();
            collection
                .delete(Some(vec!["id1".to_string(), "id3".to_string()]), None, None)
                .await
                .unwrap();
            let count = collection.count().await.unwrap();
            println!("Count after delete: {}", count);
            assert_eq!(count, 1);
            let remaining = collection.get(None, None, None, None, None).await.unwrap();
            assert_eq!(remaining.ids, vec!["id2".to_string()]);
        })
        .await;
    }

    /// Deleting by a `where` filter removes only matching records.
    #[tokio::test]
    #[test_log::test]
    async fn test_k8s_integration_delete_by_where() {
        with_client(|mut client| async move {
            let collection = client.new_collection("test_delete_by_where").await;
            let mut metadata1 = Metadata::new();
            metadata1.insert("category".to_string(), "a".into());
            let mut metadata2 = Metadata::new();
            metadata2.insert("category".to_string(), "b".into());
            collection
                .add(
                    vec!["id1".to_string(), "id2".to_string()],
                    vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]],
                    None,
                    None,
                    Some(vec![Some(metadata1), Some(metadata2)]),
                )
                .await
                .unwrap();
            collection
                .delete(None, Some(category_eq("a")), None)
                .await
                .unwrap();
            let count = collection.count().await.unwrap();
            println!("Count after delete: {}", count);
            assert_eq!(count, 1);
            let remaining = collection.get(None, None, None, None, None).await.unwrap();
            assert_eq!(remaining.ids, vec!["id2".to_string()]);
        })
        .await;
    }

    /// `limit` caps how many matching records a filtered delete removes.
    #[tokio::test]
    #[test_log::test]
    async fn test_k8s_integration_delete_with_limit() {
        with_client(|mut client| async move {
            let collection = client.new_collection("test_delete_with_limit").await;
            let mut metadata_a = Metadata::new();
            metadata_a.insert("category".to_string(), "a".into());
            let mut metadata_b = Metadata::new();
            metadata_b.insert("category".to_string(), "b".into());
            collection
                .add(
                    vec![
                        "id1".to_string(),
                        "id2".to_string(),
                        "id3".to_string(),
                        "id4".to_string(),
                        "id5".to_string(),
                    ],
                    vec![
                        vec![1.0, 2.0, 3.0],
                        vec![4.0, 5.0, 6.0],
                        vec![7.0, 8.0, 9.0],
                        vec![10.0, 11.0, 12.0],
                        vec![13.0, 14.0, 15.0],
                    ],
                    None,
                    None,
                    Some(vec![
                        Some(metadata_a.clone()),
                        Some(metadata_a.clone()),
                        Some(metadata_a),
                        Some(metadata_b.clone()),
                        Some(metadata_b),
                    ]),
                )
                .await
                .unwrap();
            let response = collection
                .delete(None, Some(category_eq("a")), Some(2))
                .await
                .unwrap();
            assert_eq!(response.deleted, 2);
            let count = collection.count().await.unwrap();
            assert_eq!(count, 3);
        })
        .await;
    }

    /// A fork gets the requested name, a new id, and the source's records.
    #[tokio::test]
    #[test_log::test]
    async fn test_k8s_integration_fork_basic() {
        with_client(|mut client| async move {
            let collection = client.new_collection("test_fork_source").await;
            collection
                .add(
                    vec!["id1".to_string(), "id2".to_string()],
                    vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]],
                    Some(vec![Some("first".to_string()), Some("second".to_string())]),
                    None,
                    None,
                )
                .await
                .unwrap();
            let target_name = unique_collection_name("test_fork_target");
            let forked = collection.fork(target_name.clone()).await.unwrap();
            client.track(&forked);
            println!("Forked collection: {:?}", forked);
            assert_eq!(forked.collection.name, target_name);
            assert_ne!(
                forked.collection.collection_id,
                collection.collection.collection_id
            );
            let forked_count = forked.count().await.unwrap();
            println!("Forked collection count: {}", forked_count);
            assert_eq!(forked_count, 2);
        })
        .await;
    }

    /// Records in the source are readable through the fork.
    #[tokio::test]
    #[test_log::test]
    async fn test_k8s_integration_fork_preserves_data() {
        with_client(|mut client| async move {
            let collection = client.new_collection("test_fork_preserves_source").await;
            collection
                .add(
                    vec!["id1".to_string()],
                    vec![vec![1.0, 2.0, 3.0]],
                    Some(vec![Some("test document".to_string())]),
                    None,
                    None,
                )
                .await
                .unwrap();
            let target_name = unique_collection_name("test_fork_preserves_target");
            let forked = collection.fork(target_name).await.unwrap();
            client.track(&forked);
            let forked_get_response = forked
                .get(
                    None,
                    None,
                    None,
                    None,
                    Some(IncludeList(vec![Include::Document])),
                )
                .await
                .unwrap();
            println!("Forked collection get response: {:?}", forked_get_response);
            assert_eq!(forked_get_response.ids.len(), 1);
            assert_eq!(forked_get_response.ids[0], "id1");
        })
        .await;
    }

    /// Writes to a fork do not affect the source collection.
    #[tokio::test]
    #[test_log::test]
    async fn test_k8s_integration_fork_independence() {
        with_client(|mut client| async move {
            let collection = client.new_collection("test_fork_independence_source").await;
            collection
                .add(
                    vec!["id1".to_string()],
                    vec![vec![1.0, 2.0, 3.0]],
                    None,
                    None,
                    None,
                )
                .await
                .unwrap();
            let target_name = unique_collection_name("test_fork_independence_target");
            let forked = collection.fork(target_name).await.unwrap();
            client.track(&forked);
            forked
                .add(
                    vec!["id2".to_string()],
                    vec![vec![4.0, 5.0, 6.0]],
                    None,
                    None,
                    None,
                )
                .await
                .unwrap();
            let original_count = collection.count().await.unwrap();
            let forked_count = forked.count().await.unwrap();
            println!(
                "Original count: {}, Forked count: {}",
                original_count, forked_count
            );
            assert_eq!(original_count, 1);
            assert_eq!(forked_count, 2);
        })
        .await;
    }

    /// `modify` updates both the local handle and the server-side metadata.
    #[tokio::test]
    #[test_log::test]
    async fn test_k8s_integration_modify() {
        with_client(|mut client| async move {
            let mut collection = client.new_collection("test_modify").await;
            let mut new_metadata = Metadata::new();
            new_metadata.insert("foo".into(), "bar".into());
            collection
                .modify(None::<String>, Some(new_metadata))
                .await
                .unwrap();
            assert_eq!(
                collection.metadata().as_ref().unwrap().get("foo"),
                Some(&"bar".into())
            );
            // Re-fetch to confirm the change reached the server, not just the local cache.
            let collection = client.get_collection(collection.name()).await.unwrap();
            assert_eq!(
                collection.metadata().as_ref().unwrap().get("foo"),
                Some(&"bar".into())
            );
        })
        .await;
    }
}