use std::sync::Arc;
pub use super::api::{ChromaAuthMethod, ChromaTokenHeader};
use super::{
api::APIClientAsync,
commons::{Metadata, Result},
ChromaCollection,
};
use serde::Deserialize;
use serde_json::json;
/// Endpoint `ChromaClient::new` falls back to when neither the options nor the
/// `CHROMA_HOST` / `CHROMA_URL` environment variables supply one.
const DEFAULT_ENDPOINT: &str = "http://localhost:8000";
/// Async client for a Chroma server.
///
/// Holds one reference-counted `APIClientAsync`; every `ChromaCollection`
/// returned by this client is given a clone of that same handle.
pub struct ChromaClient {
// Shared API handle, cloned into each collection this client hands out.
api: Arc<APIClientAsync>,
}
/// Connection settings consumed by `ChromaClient::new`.
#[derive(Debug)]
pub struct ChromaClientOptions {
/// Server endpoint. When `None`, `ChromaClient::new` tries the `CHROMA_HOST`
/// and then `CHROMA_URL` environment variables before using the built-in default.
pub url: Option<String>,
/// Authentication method handed to the API client.
pub auth: ChromaAuthMethod,
/// Database name handed to the API client.
pub database: String,
}
impl Default for ChromaClientOptions {
fn default() -> Self {
Self {
url: None,
auth: ChromaAuthMethod::None,
database: "default_database".to_string(),
}
}
}
impl ChromaClient {
    /// Connects to a Chroma server and returns a client bound to it.
    ///
    /// The endpoint is resolved in this order, stopping at the first hit:
    /// 1. `url` from the options, when it is `Some`;
    /// 2. the `CHROMA_HOST` environment variable;
    /// 3. the `CHROMA_URL` environment variable;
    /// 4. `DEFAULT_ENDPOINT` (`http://localhost:8000`).
    ///
    /// The tenant used by the client is taken from the identity returned by
    /// `APIClientAsync::get_auth`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `APIClientAsync::get_auth` for the resolved
    /// endpoint and auth method.
    pub async fn new(
        ChromaClientOptions {
            url,
            auth,
            database,
        }: ChromaClientOptions,
    ) -> Result<ChromaClient> {
        // Each fallback is evaluated lazily, so later environment lookups and
        // the default-string allocation only happen when actually needed.
        let endpoint = url
            .or_else(|| std::env::var("CHROMA_HOST").ok())
            .or_else(|| std::env::var("CHROMA_URL").ok())
            .unwrap_or_else(|| DEFAULT_ENDPOINT.to_string());

        let user_identity = APIClientAsync::get_auth(&endpoint, &auth).await?;
        let api = APIClientAsync::new(endpoint, auth, user_identity.tenant, database);
        Ok(ChromaClient { api: Arc::new(api) })
    }

    /// Gives `collection` a clone of this client's shared API handle.
    ///
    /// Collections are deserialized from server responses, so their `api`
    /// field has to be overwritten before they are handed to the caller.
    fn attach_api(&self, mut collection: ChromaCollection) -> ChromaCollection {
        collection.api = Arc::clone(&self.api);
        collection
    }

    /// Creates a collection called `name` with optional `metadata`.
    ///
    /// `get_or_create` is forwarded to the server in the request body.
    ///
    /// # Errors
    ///
    /// Propagates request failures and errors from decoding the response
    /// body into a `ChromaCollection`.
    pub async fn create_collection(
        &self,
        name: &str,
        metadata: Option<Metadata>,
        get_or_create: bool,
    ) -> Result<ChromaCollection> {
        let request_body = json!({
            "name": name,
            "metadata": metadata,
            "get_or_create": get_or_create,
        });
        let response = self
            .api
            .post_database("/collections", Some(request_body))
            .await?;
        let collection = response.json::<ChromaCollection>().await?;
        Ok(self.attach_api(collection))
    }

    /// Shorthand for [`ChromaClient::create_collection`] with
    /// `get_or_create` set to `true`.
    ///
    /// # Errors
    ///
    /// Same as [`ChromaClient::create_collection`].
    pub async fn get_or_create_collection(
        &self,
        name: &str,
        metadata: Option<Metadata>,
    ) -> Result<ChromaCollection> {
        self.create_collection(name, metadata, true).await
    }

    /// Lists the collections returned by the `/collections` endpoint.
    ///
    /// # Errors
    ///
    /// Propagates request failures and errors from decoding the response
    /// body into a list of `ChromaCollection`.
    pub async fn list_collections(&self) -> Result<Vec<ChromaCollection>> {
        let response = self.api.get_database("/collections").await?;
        let collections = response.json::<Vec<ChromaCollection>>().await?;
        Ok(collections
            .into_iter()
            .map(|collection| self.attach_api(collection))
            .collect())
    }

    /// Fetches the collection called `name`.
    ///
    /// # Errors
    ///
    /// Propagates request failures and errors from decoding the response
    /// body into a `ChromaCollection`.
    pub async fn get_collection(&self, name: &str) -> Result<ChromaCollection> {
        // NOTE(review): `name` is interpolated into the path verbatim; confirm
        // whether the API layer percent-encodes it or callers must.
        let path = format!("/collections/{}", name);
        let response = self.api.get_database(&path).await?;
        let collection = response.json::<ChromaCollection>().await?;
        Ok(self.attach_api(collection))
    }

    /// Deletes the collection called `name`.
    ///
    /// # Errors
    ///
    /// Propagates the error from the underlying delete request.
    pub async fn delete_collection(&self, name: &str) -> Result<()> {
        // NOTE(review): `name` is interpolated into the path verbatim; confirm
        // whether the API layer percent-encodes it or callers must.
        let path = format!("/collections/{}", name);
        self.api.delete_database(&path).await?;
        Ok(())
    }

    /// Updates the name and/or metadata of the collection with id
    /// `collection_id`.
    ///
    /// Both values are always present in the request body as `new_name` and
    /// `new_metadata`; a `None` is serialized as JSON `null`.
    ///
    /// # Errors
    ///
    /// Propagates the error from the underlying put request.
    pub async fn update_collection(
        &self,
        collection_id: &str,
        new_name: Option<&str>,
        metadata: Option<Metadata>,
    ) -> Result<()> {
        let path = format!("/collections/{}", collection_id);
        let request_body = json!({
            "new_name": new_name,
            "new_metadata": metadata,
        });
        self.api.put_database(&path, Some(request_body)).await?;
        Ok(())
    }

    /// Returns the string served by the `/version` endpoint.
    ///
    /// # Errors
    ///
    /// Propagates request failures and errors from decoding the response
    /// body as a JSON string.
    pub async fn version(&self) -> Result<String> {
        let response = self.api.get_v1("/version").await?;
        let version = response.json::<String>().await?;
        Ok(version)
    }

    /// Returns the `nanosecond heartbeat` value served by the `/heartbeat`
    /// endpoint.
    ///
    /// # Errors
    ///
    /// Propagates request failures and errors from decoding the response
    /// body.
    pub async fn heartbeat(&self) -> Result<u64> {
        let response = self.api.get_v1("/heartbeat").await?;
        let body = response.json::<HeartbeatResponse>().await?;
        Ok(body.heartbeat)
    }
}
/// JSON body deserialized by `ChromaClient::heartbeat`.
#[derive(Deserialize)]
struct HeartbeatResponse {
/// Value of the `"nanosecond heartbeat"` key in the response; the key
/// contains a space, hence the explicit rename.
#[serde(rename = "nanosecond heartbeat")]
pub heartbeat: u64,
}
// Integration tests: every test builds a client with default options, so a
// Chroma server must be reachable at the resolved endpoint.
#[cfg(test)]
mod tests {
    use super::*;
    use tokio;

    const TEST_COLLECTION: &str = "8-recipies-for-octopus";

    /// The heartbeat endpoint reports a positive value.
    #[tokio::test]
    async fn test_heartbeat() {
        let client: ChromaClient = ChromaClient::new(Default::default()).await.unwrap();
        let heartbeat = client.heartbeat().await.unwrap();
        assert!(heartbeat > 0);
    }

    /// The version string has three dot-separated components.
    #[tokio::test]
    async fn test_version() {
        let client: ChromaClient = ChromaClient::new(Default::default()).await.unwrap();
        let version = client.version().await.unwrap();
        assert_eq!(version.split('.').count(), 3);
    }

    /// Creating (or fetching) a collection returns it under the requested name.
    #[tokio::test]
    async fn test_create_collection() {
        let client: ChromaClient = ChromaClient::new(Default::default()).await.unwrap();
        let result = client
            .create_collection(TEST_COLLECTION, None, true)
            .await
            .unwrap();
        assert_eq!(result.name(), TEST_COLLECTION);
    }

    /// A collection that exists can be fetched by name.
    #[tokio::test]
    async fn test_get_collection() {
        let client: ChromaClient = ChromaClient::new(Default::default()).await.unwrap();
        const GET_TEST_COLLECTION: &str = "100-recipes-for-octopus";
        client
            .create_collection(GET_TEST_COLLECTION, None, true)
            .await
            .unwrap();
        let collection = client.get_collection(GET_TEST_COLLECTION).await.unwrap();
        assert_eq!(collection.name(), GET_TEST_COLLECTION);
        assert!(collection.configuration_json.is_some());
    }

    /// Listing returns at least one collection.
    // NOTE(review): this relies on some other test (or earlier run) having
    // created a collection on the server first.
    #[tokio::test]
    async fn test_list_collection() {
        let client: ChromaClient = ChromaClient::new(Default::default()).await.unwrap();
        let result = client.list_collections().await.unwrap();
        assert!(!result.is_empty());
    }

    /// Deleting an existing collection succeeds; deleting it again fails.
    #[tokio::test]
    async fn test_delete_collection() {
        let client: ChromaClient = ChromaClient::new(Default::default()).await.unwrap();
        const DELETE_TEST_COLLECTION: &str = "6-recipies-for-octopus";
        client
            .get_or_create_collection(DELETE_TEST_COLLECTION, None)
            .await
            .unwrap();
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        let collection = client.delete_collection(DELETE_TEST_COLLECTION).await;
        assert!(collection.is_ok());
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        let collection = client.delete_collection(DELETE_TEST_COLLECTION).await;
        assert!(collection.is_err());
    }

    /// A collection can be updated with no changes, renamed, and given new
    /// metadata, keeping its id throughout.
    #[tokio::test]
    async fn test_update_collection() {
        let client: ChromaClient = ChromaClient::new(Default::default()).await.unwrap();
        let new_name = "new_name";

        // This test renames a collection to `new_name` and never removes it, so
        // an earlier run against the same server leaves one behind. Drop it
        // first so the rename below does not target a name that already exists
        // (presumably rejected by the server). The error for the usual
        // "nothing to delete" case is deliberately ignored.
        let _ = client.delete_collection(new_name).await;

        let collection = client
            .get_or_create_collection(TEST_COLLECTION, None)
            .await
            .unwrap();
        let collection_id = collection.id();

        // Update with nothing to change.
        let result = client.update_collection(collection_id, None, None).await;
        assert!(result.is_ok());

        // Rename only.
        let result = client
            .update_collection(collection_id, Some(new_name), None)
            .await;
        assert!(result.is_ok());
        let updated_collection = client.get_collection(new_name).await.unwrap();
        assert_eq!(collection_id, updated_collection.id());

        // Metadata only.
        let new_metadata = Some(json!({"foo": "bar"}).as_object().unwrap().clone());
        let result = client
            .update_collection(collection_id, None, new_metadata.clone())
            .await;
        assert!(result.is_ok());
        let updated_collection = client.get_collection(new_name).await.unwrap();
        assert_eq!(updated_collection.metadata(), new_metadata.as_ref());
    }
}