use chroma::client::ChromaHttpClientOptions;
use chroma::embed::bm25::BM25SparseEmbeddingFunction;
use chroma::types::{
rrf, Aggregate, GroupBy, Key, Metadata, MetadataValue, QueryVector, RankExpr, Schema,
SearchPayload, SparseVectorIndexConfig,
};
use chroma::ChromaHttpClient;
// Number of synthetic documents generated and inserted.
const TOTAL_DOCS: usize = 256;
// Documents sent per `collection.add` call during ingestion.
const BATCH_SIZE: usize = 128;
// Candidate-pool size for every KNN rank expression below.
const KNN_LIMIT: u32 = 64;
// Collection used (and cleaned up) by this example.
const COLLECTION_NAME: &str = "comprehensive_search_example";
// Cycled through when assigning per-document metadata.
const CATEGORIES: [&str; 3] = ["machine_learning", "quantum_computing", "bioinformatics"];
// Cycled through when assigning per-document metadata.
const AUTHORS: [&str; 5] = [
    "Alice Chen",
    "Bob Smith",
    "Carol Johnson",
    "David Lee",
    "Emma Wilson",
];
/// Builds a deterministic synthetic paper abstract for document index `i`.
///
/// The topic, method, and result phrases are picked by cycling `i` through
/// three fixed phrase lists, so every index maps to a stable sentence.
fn generate_abstract(i: usize) -> String {
    const TOPICS: [&str; 9] = [
        "deep neural networks for image classification",
        "quantum entanglement in superconducting qubits",
        "protein folding prediction using transformers",
        "reinforcement learning in autonomous systems",
        "quantum error correction codes",
        "genomic sequence alignment algorithms",
        "attention mechanisms in language models",
        "topological quantum computing",
        "single-cell RNA sequencing analysis",
    ];
    const METHODS: [&str; 5] = [
        "novel optimization techniques",
        "experimental validation",
        "theoretical framework",
        "large-scale benchmarks",
        "ablation studies",
    ];
    const RESULTS: [&str; 4] = [
        "significant improvements over baselines",
        "state-of-the-art performance",
        "promising preliminary results",
        "robust generalization",
    ];
    let topic = TOPICS[i % TOPICS.len()];
    let method = METHODS[i % METHODS.len()];
    let result = RESULTS[i % RESULTS.len()];
    format!(
        "This paper investigates {topic} using {method}. Our experiments demonstrate {result}, opening new directions for future research in this domain."
    )
}
/// Generates the synthetic corpus: ids, 3-d dense embeddings, abstracts, and
/// per-document metadata (including a BM25 sparse encoding of the abstract
/// stored under the "sparse_embedding" key).
///
/// Returns `(ids, embeddings, documents, metadatas)`, each `TOTAL_DOCS` long.
#[allow(clippy::type_complexity)]
fn generate_test_data() -> (
    Vec<String>,
    Vec<Vec<f32>>,
    Vec<Option<String>>,
    Vec<Option<Metadata>>,
) {
    let bm25 = BM25SparseEmbeddingFunction::default_murmur3_abs();
    let mut ids = Vec::with_capacity(TOTAL_DOCS);
    let mut embeddings = Vec::with_capacity(TOTAL_DOCS);
    let mut documents = Vec::with_capacity(TOTAL_DOCS);
    let mut metadatas = Vec::with_capacity(TOTAL_DOCS);
    for i in 0..TOTAL_DOCS {
        ids.push(format!("paper_{:03}", i + 1));
        // Dense embeddings sweep linearly from [1,0,0] to [0,0,1] across the
        // corpus, so nearest-neighbor order against [1,0,0] is deterministic.
        let t = i as f32 / (TOTAL_DOCS - 1) as f32;
        embeddings.push(vec![1.0 - t, 0.0, t]);
        let doc_text = generate_abstract(i);
        documents.push(Some(doc_text.clone()));
        let sparse_vector = bm25
            .encode(&doc_text)
            .expect("BM25 encoding should not fail");
        let mut metadata = Metadata::new();
        // Use `.len()` instead of hard-coded moduli (previously `% 3` / `% 5`)
        // so resizing CATEGORIES/AUTHORS cannot silently skew the distribution
        // or panic on an out-of-bounds index.
        metadata.insert("category".into(), CATEGORIES[i % CATEGORIES.len()].into());
        metadata.insert("year".into(), MetadataValue::Int(2020 + (i % 5) as i64));
        metadata.insert("author".into(), AUTHORS[i % AUTHORS.len()].into());
        metadata.insert("citations".into(), MetadataValue::Int((i * 7 % 500) as i64));
        metadata.insert(
            "quality_score".into(),
            MetadataValue::Float(0.5 + (i % 50) as f64 / 100.0),
        );
        metadata.insert("peer_reviewed".into(), MetadataValue::Bool(i % 2 == 0));
        metadata.insert("sparse_embedding".into(), sparse_vector.into());
        metadatas.push(Some(metadata));
    }
    (ids, embeddings, documents, metadatas)
}
/// Pretty-prints the first query's results from a `SearchResponse`: a header,
/// then one line per hit (rank, id, score, and the requested metadata
/// `fields`), showing at most `max_display` rows plus a total count.
fn print_results(
    title: &str,
    response: &chroma::types::SearchResponse,
    fields: &[&str],
    max_display: usize,
) {
    println!("\n{}", title);
    println!("{}", "=".repeat(title.len()));
    let total = response.ids[0].len();
    let shown = total.min(max_display);
    for (rank, id) in response.ids[0].iter().enumerate().take(shown) {
        // A score can be absent (e.g. not selected); render "N/A" then.
        let score = match response.scores[0].as_ref().and_then(|s| s.get(rank)) {
            Some(Some(s)) => format!("{:.4}", s),
            _ => "N/A".to_string(),
        };
        let metadata = response.metadatas[0]
            .as_ref()
            .and_then(|m| m.get(rank))
            .and_then(|m| m.as_ref());
        // Only fields actually present in this hit's metadata are shown.
        let mut rendered = Vec::new();
        for field in fields {
            if let Some(value) = metadata.and_then(|m| m.get(*field)) {
                rendered.push(format!("{}={:?}", field, value));
            }
        }
        let suffix = if rendered.is_empty() {
            String::new()
        } else {
            format!(", {}", rendered.join(", "))
        };
        println!(" {}. {} (score={}{})", rank + 1, id, score, suffix);
    }
    if total > shown {
        println!(" ... and {} more results", total - shown);
    }
    println!(" Total: {} results", total);
}
/// End-to-end demonstration of the Chroma search API: collection setup with a
/// BM25 sparse index, batched ingestion of a synthetic corpus, twelve search
/// variations (dense, sparse, hybrid, filtered, grouped, paginated, batched),
/// and final cleanup.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== Comprehensive Chroma Search API Example ===\n");

    // Connect to Chroma Cloud; replace the placeholders with real credentials
    // before running.
    let client = ChromaHttpClient::new(ChromaHttpClientOptions::cloud(
        "<chroma-api-key>",
        "<chroma-database-name>",
    )?);
    println!("Connected to Chroma");

    // Best-effort cleanup so the example is rerunnable; the error (collection
    // not existing yet) is deliberately ignored.
    println!("Deleting existing collection if present...");
    let _ = client.delete_collection(COLLECTION_NAME).await;

    // Schema declaring a sparse-vector index over the "sparse_embedding"
    // metadata key with BM25 scoring enabled.
    let schema = Schema::default().create_index(
        Some("sparse_embedding"),
        SparseVectorIndexConfig {
            embedding_function: None,
            source_key: None,
            bm25: Some(true),
        }
        .into(),
    )?;
    let collection = client
        .get_or_create_collection(COLLECTION_NAME, Some(schema), None)
        .await?;
    let initial_version = collection.version();
    println!(
        "Created collection '{}' (version: {})",
        collection.name(),
        initial_version
    );

    println!("\nGenerating {} test documents...", TOTAL_DOCS);
    let (ids, embeddings, documents, metadatas) = generate_test_data();

    println!("Inserting documents in batches of {}...", BATCH_SIZE);
    // Ceiling division (instead of the former truncating `TOTAL_DOCS /
    // BATCH_SIZE`) so a trailing partial batch is not silently dropped if
    // TOTAL_DOCS ever stops being an exact multiple of BATCH_SIZE; `end` is
    // clamped to match.
    let num_batches = TOTAL_DOCS.div_ceil(BATCH_SIZE);
    for batch_idx in 0..num_batches {
        let start = batch_idx * BATCH_SIZE;
        let end = (start + BATCH_SIZE).min(TOTAL_DOCS);
        collection
            .add(
                ids[start..end].to_vec(),
                embeddings[start..end].to_vec(),
                Some(documents[start..end].to_vec()),
                None,
                Some(metadatas[start..end].to_vec()),
            )
            .await?;
        println!(" Batch {}/{} inserted", batch_idx + 1, num_batches);
    }
    let count = collection.count().await?;
    println!("Collection ready with {} documents\n", count);
    println!("\n>>> Starting Search Demonstrations <<<\n");

    // Shared query inputs: a dense vector matching the low end of the corpus
    // sweep (see generate_test_data) and a BM25-encoded keyword query.
    let dense_query = vec![1.0, 0.0, 0.0];
    let bm25 = BM25SparseEmbeddingFunction::default_murmur3_abs();
    let sparse_query = bm25
        .encode("deep learning neural networks image classification")
        .expect("BM25 encoding should not fail");

    // 1. Plain KNN over the built-in dense embedding.
    let search = SearchPayload::default()
        .rank(RankExpr::Knn {
            query: QueryVector::Dense(dense_query.clone()),
            key: Key::Embedding,
            limit: KNN_LIMIT,
            default: None,
            return_rank: false,
        })
        .limit(Some(5), 0)
        .select([Key::Score, Key::field("category"), Key::field("year")]);
    let response = collection.search(vec![search]).await?;
    print_results(
        "1. Basic KNN Search (Dense Vectors)",
        &response,
        &["category", "year"],
        5,
    );

    // 2. KNN over the sparse (BM25) index stored under "sparse_embedding".
    let search = SearchPayload::default()
        .rank(RankExpr::Knn {
            query: QueryVector::Sparse(sparse_query.clone()),
            key: Key::field("sparse_embedding"),
            limit: KNN_LIMIT,
            default: None,
            return_rank: false,
        })
        .limit(Some(5), 0)
        .select([Key::Score, Key::field("category"), Key::Document]);
    let response = collection.search(vec![search]).await?;
    print_results(
        "2. Sparse Vector Search (Keyword-based with BM25)",
        &response,
        &["category"],
        5,
    );

    // 3. Hybrid ranking: weighted sum of dense and sparse KNN scores.
    //    `default: Some(10.0)` supplies a poor fallback score for documents
    //    that appear in only one of the two candidate lists, so the sum is
    //    still defined for them.
    let dense_knn = RankExpr::Knn {
        query: QueryVector::Dense(dense_query.clone()),
        key: Key::Embedding,
        limit: KNN_LIMIT,
        default: Some(10.0),
        return_rank: false,
    };
    let sparse_knn = RankExpr::Knn {
        query: QueryVector::Sparse(sparse_query.clone()),
        key: Key::field("sparse_embedding"),
        limit: KNN_LIMIT,
        default: Some(10.0),
        return_rank: false,
    };
    let hybrid_rank = dense_knn * 0.7 + sparse_knn * 0.3;
    let search = SearchPayload::default()
        .rank(hybrid_rank)
        .limit(Some(5), 0)
        .select([Key::Score, Key::field("category"), Key::field("author")]);
    let response = collection.search(vec![search]).await?;
    print_results(
        "3. Hybrid Search (Linear Combination: 70% dense + 30% sparse)",
        &response,
        &["category", "author"],
        5,
    );

    // 4. Hybrid ranking via Reciprocal Rank Fusion. `return_rank: true` makes
    //    each KNN expression yield ranks (what RRF consumes) rather than raw
    //    scores; k=60 is the conventional RRF smoothing constant, and the
    //    weights mirror the 70/30 split above. NOTE(review): the trailing
    //    `false` argument's meaning is not visible here — confirm against the
    //    chroma `rrf` documentation.
    let dense_knn_rank = RankExpr::Knn {
        query: QueryVector::Dense(dense_query.clone()),
        key: Key::Embedding,
        limit: KNN_LIMIT,
        default: None,
        return_rank: true,
    };
    let sparse_knn_rank = RankExpr::Knn {
        query: QueryVector::Sparse(sparse_query.clone()),
        key: Key::field("sparse_embedding"),
        limit: KNN_LIMIT,
        default: None,
        return_rank: true,
    };
    let rrf_rank = rrf(
        vec![dense_knn_rank, sparse_knn_rank],
        Some(60),
        Some(vec![0.7, 0.3]),
        false,
    )?;
    let search = SearchPayload::default()
        .rank(rrf_rank)
        .limit(Some(5), 0)
        .select([Key::Score, Key::field("category"), Key::field("citations")]);
    let response = collection.search(vec![search]).await?;
    print_results(
        "4. Hybrid Search (RRF Fusion)",
        &response,
        &["category", "citations"],
        5,
    );

    // 5. KNN restricted by an exact-match metadata filter.
    let search = SearchPayload::default()
        .r#where(Key::field("category").eq("machine_learning"))
        .rank(RankExpr::Knn {
            query: QueryVector::Dense(dense_query.clone()),
            key: Key::Embedding,
            limit: KNN_LIMIT,
            default: None,
            return_rank: false,
        })
        .limit(Some(5), 0)
        .select([Key::Score, Key::field("category"), Key::field("year")]);
    let response = collection.search(vec![search]).await?;
    print_results(
        "5. Metadata Filter (category = 'machine_learning')",
        &response,
        &["category", "year"],
        5,
    );

    // 6. KNN restricted by a full-text containment filter on the document.
    let search = SearchPayload::default()
        .r#where(Key::Document.contains("neural networks"))
        .rank(RankExpr::Knn {
            query: QueryVector::Dense(dense_query.clone()),
            key: Key::Embedding,
            limit: KNN_LIMIT,
            default: None,
            return_rank: false,
        })
        .limit(Some(5), 0)
        .select([Key::Score, Key::Document, Key::field("category")]);
    let response = collection.search(vec![search]).await?;
    print_results(
        "6. Full-Text Search (document contains 'neural networks')",
        &response,
        &["category"],
        5,
    );

    // 7. KNN restricted by a regex filter on the document text.
    let search = SearchPayload::default()
        .r#where(Key::Document.regex(r"quantum\s+\w+"))
        .rank(RankExpr::Knn {
            query: QueryVector::Dense(dense_query.clone()),
            key: Key::Embedding,
            limit: KNN_LIMIT,
            default: None,
            return_rank: false,
        })
        .limit(Some(5), 0)
        .select([Key::Score, Key::Document, Key::field("category")]);
    let response = collection.search(vec![search]).await?;
    print_results(
        r"7. Regex Filter (document matches 'quantum\s+\w+')",
        &response,
        &["category"],
        5,
    );

    // 8. Conjunction of metadata and full-text predicates via `&`.
    let search = SearchPayload::default()
        .r#where(
            (Key::field("year").gte(2022))
                & (Key::field("peer_reviewed").eq(true))
                & (Key::Document.contains("learning")),
        )
        .rank(RankExpr::Knn {
            query: QueryVector::Dense(dense_query.clone()),
            key: Key::Embedding,
            limit: KNN_LIMIT,
            default: None,
            return_rank: false,
        })
        .limit(Some(5), 0)
        .select([
            Key::Score,
            Key::Document,
            Key::field("year"),
            Key::field("peer_reviewed"),
        ]);
    let response = collection.search(vec![search]).await?;
    print_results(
        "8. Complex Filter (year >= 2022 AND peer_reviewed AND document contains 'learning')",
        &response,
        &["year", "peer_reviewed"],
        5,
    );

    // 9. Group results by category, keeping the 2 lowest-scoring (best) hits
    //    per group via the MinK aggregate.
    let search = SearchPayload::default()
        .rank(RankExpr::Knn {
            query: QueryVector::Dense(dense_query.clone()),
            key: Key::Embedding,
            limit: KNN_LIMIT,
            default: None,
            return_rank: false,
        })
        .group_by(GroupBy {
            keys: vec![Key::field("category")],
            aggregate: Some(Aggregate::MinK {
                keys: vec![Key::Score],
                k: 2,
            }),
        })
        .limit(Some(10), 0)
        .select([Key::Score, Key::field("category"), Key::field("author")]);
    let response = collection.search(vec![search]).await?;
    print_results(
        "9. Group By Category (Top 2 per category by score)",
        &response,
        &["category", "author"],
        10,
    );

    // 10. Explicit field selection; printed manually (not via print_results)
    //     so the document text can be shown as well.
    let search = SearchPayload::default()
        .rank(RankExpr::Knn {
            query: QueryVector::Dense(dense_query.clone()),
            key: Key::Embedding,
            limit: KNN_LIMIT,
            default: None,
            return_rank: false,
        })
        .limit(Some(3), 0)
        .select([
            Key::Document,
            Key::Score,
            Key::field("author"),
            Key::field("quality_score"),
        ]);
    let response = collection.search(vec![search]).await?;
    println!("\n10. Field Selection (Document, Score, author, quality_score)");
    println!("=============================================================");
    for (i, id) in response.ids[0].iter().take(3).enumerate() {
        let doc = response.documents[0]
            .as_ref()
            .and_then(|d| d.get(i))
            .and_then(|d| d.as_ref())
            .map(|d| {
                // Truncate long abstracts for display. Cutting at the byte
                // offset of the 81st character keeps the slice on a char
                // boundary (a fixed `&d[..80]` would panic on multi-byte
                // UTF-8); for documents of 80 chars or fewer, clone as-is.
                match d.char_indices().nth(80) {
                    Some((cut, _)) => format!("{}...", &d[..cut]),
                    None => d.clone(),
                }
            })
            .unwrap_or_default();
        let score = response.scores[0]
            .as_ref()
            .and_then(|s| s.get(i))
            .and_then(|s| *s)
            .map(|s| format!("{:.4}", s))
            .unwrap_or_else(|| "N/A".to_string());
        let metadata = response.metadatas[0]
            .as_ref()
            .and_then(|m| m.get(i))
            .and_then(|m| m.as_ref());
        let author = metadata
            .and_then(|m| m.get("author"))
            .map(|v| format!("{:?}", v))
            .unwrap_or_default();
        let quality = metadata
            .and_then(|m| m.get("quality_score"))
            .map(|v| format!("{:?}", v))
            .unwrap_or_default();
        println!(" {}. {}", i + 1, id);
        println!(" Score: {}", score);
        println!(" Author: {}", author);
        println!(" Quality: {}", quality);
        println!(" Doc: {}", doc);
    }

    // 11. Pagination: same ranked query, skipping the first 10 hits.
    println!("\n11. Pagination (limit=5, offset=10)");
    println!("====================================");
    let search = SearchPayload::default()
        .rank(RankExpr::Knn {
            query: QueryVector::Dense(dense_query.clone()),
            key: Key::Embedding,
            limit: KNN_LIMIT,
            default: None,
            return_rank: false,
        })
        .limit(Some(5), 10)
        .select([Key::Score, Key::field("category")]);
    let response = collection.search(vec![search]).await?;
    print_results("Results (offset=10)", &response, &["category"], 5);

    // 12. Batch search: three independent payloads in one round trip; the
    //     response arrays are indexed by query position.
    println!("\n12. Batch Search (3 queries in one request)");
    println!("============================================");
    let searches = vec![
        SearchPayload::default()
            .r#where(Key::field("category").eq("machine_learning"))
            .rank(RankExpr::Knn {
                query: QueryVector::Dense(dense_query.clone()),
                key: Key::Embedding,
                limit: KNN_LIMIT,
                default: None,
                return_rank: false,
            })
            .limit(Some(3), 0)
            .select([Key::Score, Key::field("category")]),
        SearchPayload::default()
            .r#where(Key::field("category").eq("quantum_computing"))
            .rank(RankExpr::Knn {
                query: QueryVector::Dense(dense_query.clone()),
                key: Key::Embedding,
                limit: KNN_LIMIT,
                default: None,
                return_rank: false,
            })
            .limit(Some(3), 0)
            .select([Key::Score, Key::field("category")]),
        SearchPayload::default()
            .r#where(Key::field("category").eq("bioinformatics"))
            .rank(RankExpr::Knn {
                query: QueryVector::Dense(dense_query.clone()),
                key: Key::Embedding,
                limit: KNN_LIMIT,
                default: None,
                return_rank: false,
            })
            .limit(Some(3), 0)
            .select([Key::Score, Key::field("category")]),
    ];
    let response = collection.search(searches).await?;
    for (query_idx, category) in ["machine_learning", "quantum_computing", "bioinformatics"]
        .iter()
        .enumerate()
    {
        println!("\n Query {}: {} papers", query_idx + 1, category);
        for (i, id) in response.ids[query_idx].iter().enumerate() {
            let score = response.scores[query_idx]
                .as_ref()
                .and_then(|s| s.get(i))
                .and_then(|s| *s)
                .map(|s| format!("{:.4}", s))
                .unwrap_or_else(|| "N/A".to_string());
            println!(" {}. {} (score={})", i + 1, id, score);
        }
    }

    // Drop the demo collection so repeated runs start clean.
    println!("\n\n>>> Cleanup <<<");
    client.delete_collection(COLLECTION_NAME).await?;
    println!("Deleted collection '{}'", COLLECTION_NAME);
    println!("\n=== Example Complete ===");
    Ok(())
}