use std::collections::HashMap;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tracing::{debug, error, info, warn};
use super::base::{Filters, OutputData, VectorStoreBase};
use crate::config::MilvusConfig;
use crate::error::{NeomemxError, Result};
/// Vector store backend that talks to a Milvus server over HTTP.
///
/// Instances are created with `Milvus::new`, which also makes sure the
/// configured collection exists before returning.
pub struct Milvus {
/// Settings read by the operations: database name, collection name,
/// vector dimension and metric type.
config: MilvusConfig,
/// HTTP client built once in `new` with the default headers attached.
client: Client,
/// Prefix of every endpoint URL; obtained from `config.get_base_url()`.
base_url: String,
}
/// JSON body posted to the `/v2/vectordb/collections/create` endpoint.
///
/// Keys are serialized in camelCase (`dbName`, `collectionName`, ...).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct CreateCollectionRequest {
/// Database name, taken from the configuration.
db_name: String,
/// Name of the collection to create.
collection_name: String,
/// Vector dimension, taken from the configuration.
dimension: usize,
/// String form of the configured metric type.
metric_type: String,
/// Left out of the JSON when `None`; this file always sends `"id"`.
#[serde(skip_serializing_if = "Option::is_none")]
primary_field_name: Option<String>,
/// Left out of the JSON when `None`; this file always sends `"vector"`.
#[serde(skip_serializing_if = "Option::is_none")]
vector_field_name: Option<String>,
}
/// JSON body posted to the `/v2/vectordb/entities/insert` endpoint.
///
/// Keys are serialized in camelCase.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct InsertRequest {
/// Database name, taken from the configuration.
db_name: String,
/// Collection the entities are written to.
collection_name: String,
/// One element per entity to insert.
data: Vec<InsertData>,
}
/// A single entity as sent in insert and upsert requests.
///
/// No key renaming is applied here, so payload keys are emitted verbatim.
#[derive(Debug, Serialize)]
struct InsertData {
/// Primary key of the entity.
id: String,
/// Embedding values.
vector: Vec<f32>,
/// Payload entries; flattened so each key becomes a top-level JSON field
/// next to `id` and `vector`.
#[serde(flatten)]
metadata: HashMap<String, serde_json::Value>,
}
/// JSON body posted to the `/v2/vectordb/entities/search` endpoint.
///
/// Keys are serialized in camelCase.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct SearchRequest {
/// Database name, taken from the configuration.
db_name: String,
/// Collection to search.
collection_name: String,
/// Query vectors; `search` always sends exactly one.
data: Vec<Vec<f32>>,
/// Name of the vector field to search; `search` passes `"vector"`.
#[serde(rename = "annsField")]
anns_field: String,
/// Maximum number of hits requested.
limit: usize,
/// Fields to return for each hit; `search` passes `["*"]`.
output_fields: Vec<String>,
/// Filter expression; left out of the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
filter: Option<String>,
}
/// JSON body posted to the `/v2/vectordb/entities/query` endpoint (used by `list`).
///
/// Keys are serialized in camelCase.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct QueryRequest {
/// Database name, taken from the configuration.
db_name: String,
/// Collection to query.
collection_name: String,
/// Filter expression; left out of the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
filter: Option<String>,
/// Fields to return for each entity; `list` passes `["*"]`.
output_fields: Vec<String>,
/// Maximum number of entities requested.
limit: usize,
}
/// JSON body posted to the `/v2/vectordb/entities/get` endpoint.
///
/// Keys are serialized in camelCase.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GetRequest {
/// Database name, taken from the configuration.
db_name: String,
/// Collection to read from.
collection_name: String,
/// Primary keys of the entities to fetch.
id: Vec<String>,
/// Fields to return for each entity; `get_batch` passes `["*"]`.
output_fields: Vec<String>,
}
/// JSON body posted to the `/v2/vectordb/entities/delete` endpoint.
///
/// Keys are serialized in camelCase.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct DeleteRequest {
/// Database name, taken from the configuration.
db_name: String,
/// Collection to delete from.
collection_name: String,
/// Expression selecting the entities to delete; `delete_batch` builds an
/// `id in [...]` expression here.
filter: String,
}
/// JSON body posted to the `/v2/vectordb/entities/upsert` endpoint (used by `update`).
///
/// Keys are serialized in camelCase.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct UpsertRequest {
/// Database name, taken from the configuration.
db_name: String,
/// Collection the entities are written to.
collection_name: String,
/// Entities to write; `update` sends exactly one.
data: Vec<InsertData>,
}
/// JSON body posted to the `/v2/vectordb/collections/drop` endpoint.
///
/// Keys are serialized in camelCase.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct DropCollectionRequest {
/// Database name, taken from the configuration.
db_name: String,
/// Collection to drop.
collection_name: String,
}
/// JSON body posted to the `/v2/vectordb/collections/describe` endpoint.
///
/// Keys are serialized in camelCase.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct DescribeCollectionRequest {
/// Database name, taken from the configuration.
db_name: String,
/// Collection to describe.
collection_name: String,
}
/// JSON body posted to the `/v2/vectordb/collections/list` endpoint.
///
/// Keys are serialized in camelCase.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct ListCollectionsRequest {
/// Database name, taken from the configuration.
db_name: String,
}
/// Response envelope shared by every endpoint used in this file.
///
/// `T` is the type the `data` member is decoded into.
#[derive(Debug, Deserialize)]
struct MilvusResponse<T> {
/// Status code of the operation; callers in this file treat `0` as success
/// and any other value as an error.
code: i32,
/// Optional error text; callers substitute "Unknown error" when it is absent.
#[serde(default)]
message: Option<String>,
/// Operation result, if any.
data: Option<T>,
}
/// One hit in a search response.
#[derive(Debug, Deserialize)]
struct SearchResultData {
/// Primary key, kept as raw JSON; `parse_search_results` turns a string or
/// number into the textual id.
id: serde_json::Value,
/// Distance reported for the hit; passed through as the result score.
distance: f32,
/// Every other key of the hit, collected into a map.
#[serde(flatten)]
fields: HashMap<String, serde_json::Value>,
}
/// One entity in a get or query response.
#[derive(Debug, Deserialize)]
struct QueryResultData {
/// Primary key, kept as raw JSON; `parse_query_results` turns a string or
/// number into the textual id.
id: serde_json::Value,
/// Every other key of the entity, collected into a map.
#[serde(flatten)]
fields: HashMap<String, serde_json::Value>,
}
/// Collection description decoded from the describe endpoint.
///
/// `ensure_collection` only checks that this value is present and never reads
/// its fields, which is why dead-code warnings are suppressed.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
#[allow(dead_code)] struct CollectionInfo {
/// Decoded from the `collectionName` key; required.
collection_name: String,
/// Defaults to an empty string when the key is missing.
#[serde(default)]
description: String,
/// Raw field descriptions; defaults to an empty list when the key is missing.
#[serde(default)]
fields: Vec<serde_json::Value>,
}
impl Milvus {
/// Builds a Milvus client from `config` and makes sure the configured
/// collection exists.
///
/// Every request carries a JSON `Content-Type` header and, when the
/// configuration provides an API key, an `Authorization: Bearer <key>` header.
///
/// # Errors
///
/// Returns `NeomemxError::VectorStoreError` when the API key cannot be used as
/// a header value or the HTTP client cannot be built, and propagates any
/// error from `ensure_collection`.
pub async fn new(config: MilvusConfig) -> Result<Self> {
    let base_url = config.get_base_url();

    let mut headers = reqwest::header::HeaderMap::new();
    // A constant header value needs no runtime parse (and no `unwrap`).
    headers.insert(
        reqwest::header::CONTENT_TYPE,
        reqwest::header::HeaderValue::from_static("application/json"),
    );

    if let Some(api_key) = config.get_api_key() {
        let mut bearer =
            reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key)).map_err(
                |_| NeomemxError::VectorStoreError("Invalid API key format".to_string()),
            )?;
        // Keep the credential out of `Debug` output of headers and requests.
        bearer.set_sensitive(true);
        headers.insert(reqwest::header::AUTHORIZATION, bearer);
    }

    let client = Client::builder()
        .default_headers(headers)
        .build()
        .map_err(|e| {
            NeomemxError::VectorStoreError(format!("Failed to create HTTP client: {}", e))
        })?;

    let milvus = Self {
        config,
        client,
        base_url,
    };
    milvus.ensure_collection().await?;
    Ok(milvus)
}
/// Creates the configured collection unless a describe call confirms that it
/// already exists.
///
/// The describe call is best-effort: a transport error, a non-success status,
/// an undecodable body, a non-zero `code` or missing `data` all fall through
/// to creating the collection.
async fn ensure_collection(&self) -> Result<()> {
    let collection = &self.config.collection_name;
    let probe = DescribeCollectionRequest {
        db_name: self.config.database.clone(),
        collection_name: collection.clone(),
    };
    let probe_url = format!("{}/v2/vectordb/collections/describe", self.base_url);

    debug!("Checking for existing collection: {}", collection);

    let already_exists = match self.client.post(&probe_url).json(&probe).send().await {
        Ok(resp) if resp.status().is_success() => {
            match resp.json::<MilvusResponse<CollectionInfo>>().await {
                Ok(described) => described.code == 0 && described.data.is_some(),
                Err(_) => false,
            }
        }
        _ => false,
    };

    if already_exists {
        info!("Found existing collection: {}", collection);
        return Ok(());
    }
    self.create_collection_internal().await
}
/// Creates the configured collection with primary field `id` and vector
/// field `vector`.
///
/// # Errors
///
/// Returns `NeomemxError::VectorStoreError` when the request cannot be sent,
/// the HTTP status is not a success, the body is not a valid response
/// envelope, or the envelope carries a non-zero `code`.
async fn create_collection_internal(&self) -> Result<()> {
    let cfg = &self.config;
    let payload = CreateCollectionRequest {
        db_name: cfg.database.clone(),
        collection_name: cfg.collection_name.clone(),
        dimension: cfg.dimension,
        metric_type: cfg.metric_type.as_str().to_string(),
        primary_field_name: Some("id".to_string()),
        vector_field_name: Some("vector".to_string()),
    };
    let endpoint = format!("{}/v2/vectordb/collections/create", self.base_url);

    debug!(
        "Creating collection: {} with dimension {}",
        cfg.collection_name, cfg.dimension
    );

    let response = match self.client.post(&endpoint).json(&payload).send().await {
        Ok(resp) => resp,
        Err(e) => {
            return Err(NeomemxError::VectorStoreError(format!(
                "Failed to connect to Milvus: {}",
                e
            )))
        }
    };

    // Capture the status first: reading the body consumes the response.
    let succeeded = response.status().is_success();
    let body = response.text().await.unwrap_or_default();
    if !succeeded {
        return Err(NeomemxError::VectorStoreError(format!(
            "Failed to create collection: {}",
            body
        )));
    }

    let envelope: MilvusResponse<serde_json::Value> = match serde_json::from_str(&body) {
        Ok(parsed) => parsed,
        Err(e) => {
            return Err(NeomemxError::VectorStoreError(format!(
                "Failed to parse response: {}",
                e
            )))
        }
    };
    if envelope.code != 0 {
        let reason = envelope
            .message
            .unwrap_or_else(|| "Unknown error".to_string());
        return Err(NeomemxError::VectorStoreError(format!(
            "Milvus error: {}",
            reason
        )));
    }

    info!("Created collection: {}", cfg.collection_name);
    Ok(())
}
/// Translates `filters` into a filter expression string.
///
/// * Keys starting with `$` are skipped.
/// * An object value is read as a set of operators (`eq`, `ne`, `gt`, `gte`,
///   `lt`, `lte`, `in`); unknown operators and `in` with a non-array operand
///   are ignored.
/// * Any other value becomes an equality test, except the string `"*"`,
///   which contributes nothing.
///
/// The resulting clauses are combined with ` and `. Returns `None` when no
/// clause was produced.
fn generate_filter_expression(filters: &Filters) -> Option<String> {
    let mut clauses: Vec<String> = Vec::new();

    for (field, spec) in filters {
        if field.starts_with('$') {
            continue;
        }
        match spec {
            serde_json::Value::Object(operators) => {
                clauses.extend(operators.iter().filter_map(|(name, operand)| {
                    let symbol = match name.as_str() {
                        "eq" => "==",
                        "ne" => "!=",
                        "gt" => ">",
                        "gte" => ">=",
                        "lt" => "<",
                        "lte" => "<=",
                        "in" => {
                            return operand.as_array().map(|items| {
                                let rendered: Vec<String> =
                                    items.iter().map(format_value).collect();
                                format!("{} in [{}]", field, rendered.join(", "))
                            });
                        }
                        _ => return None,
                    };
                    format_condition(field, symbol, operand)
                }));
            }
            wildcard if wildcard.as_str() == Some("*") => {}
            plain => clauses.extend(format_condition(field, "==", plain)),
        }
    }

    if clauses.is_empty() {
        None
    } else {
        // Joining a single clause yields that clause unchanged.
        Some(clauses.join(" and "))
    }
}
fn parse_search_results(results: Vec<SearchResultData>) -> Vec<OutputData> {
results
.into_iter()
.map(|r| {
let id = match &r.id {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Number(n) => n.to_string(),
_ => r.id.to_string(),
};
let payload: HashMap<String, serde_json::Value> = r
.fields
.into_iter()
.filter(|(k, _)| k != "vector" && k != "id")
.collect();
let score = Some(r.distance);
OutputData::new(id, score, payload)
})
.collect()
}
fn parse_query_results(results: Vec<QueryResultData>) -> Vec<OutputData> {
results
.into_iter()
.map(|r| {
let id = match &r.id {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Number(n) => n.to_string(),
_ => r.id.to_string(),
};
let payload: HashMap<String, serde_json::Value> = r
.fields
.into_iter()
.filter(|(k, _)| k != "vector" && k != "id")
.collect();
OutputData::new(id, None, payload)
})
.collect()
}
}
/// Renders `key op value` as a single clause, with the value formatted by
/// `format_value`.
///
/// Always returns `Some`; the `Option` return type is part of the existing
/// signature.
fn format_condition(key: &str, op: &str, value: &serde_json::Value) -> Option<String> {
    let literal = format_value(value);
    let clause = [key, op, literal.as_str()].join(" ");
    Some(clause)
}
/// Renders a JSON value as a literal for a filter expression.
///
/// Strings are wrapped in double quotes with backslashes and double quotes
/// escaped; numbers and booleans use their plain textual form; every other
/// value falls back to its JSON representation.
fn format_value(value: &serde_json::Value) -> String {
    match value {
        serde_json::Value::String(s) => {
            // Backslashes must be escaped first. Otherwise a value ending in
            // `\` would turn the closing quote into an escaped quote, and the
            // backslashes added for quotes would themselves be doubled.
            let escaped = s.replace('\\', "\\\\").replace('"', "\\\"");
            format!("\"{}\"", escaped)
        }
        serde_json::Value::Number(n) => n.to_string(),
        serde_json::Value::Bool(b) => b.to_string(),
        _ => value.to_string(),
    }
}
#[async_trait]
impl VectorStoreBase for Milvus {
async fn create_collection(&self, name: &str) -> Result<()> {
let url = format!("{}/v2/vectordb/collections/create", self.base_url);
let request = CreateCollectionRequest {
db_name: self.config.database.clone(),
collection_name: name.to_string(),
dimension: self.config.dimension,
metric_type: self.config.metric_type.as_str().to_string(),
primary_field_name: Some("id".to_string()),
vector_field_name: Some("vector".to_string()),
};
let response = self.client.post(&url).json(&request).send().await?;
if !response.status().is_success() {
let body = response.text().await.unwrap_or_default();
return Err(NeomemxError::VectorStoreError(format!(
"Failed to create collection: {}",
body
)));
}
info!("Created collection: {}", name);
Ok(())
}
/// Inserts `vectors` into the configured collection.
///
/// * `payloads` - optional payload per vector; empty payloads are used when
///   omitted.
/// * `ids` - optional id per vector; random UUIDs are generated when omitted.
///
/// An empty `vectors` list is a no-op.
///
/// # Errors
///
/// Returns `NeomemxError::VectorStoreError` when `ids` or `payloads` is given
/// with a length different from `vectors`, when the HTTP status is not a
/// success, or when the response envelope carries a non-zero `code`.
/// Transport and decoding errors are propagated.
async fn insert(
    &self,
    vectors: Vec<Vec<f32>>,
    payloads: Option<Vec<HashMap<String, serde_json::Value>>>,
    ids: Option<Vec<String>>,
) -> Result<()> {
    let count = vectors.len();

    // `zip` stops at the shortest input, so mismatched lengths would silently
    // drop vectors; reject them up front instead.
    if let Some(given) = ids.as_ref() {
        if given.len() != count {
            return Err(NeomemxError::VectorStoreError(format!(
                "Number of ids ({}) does not match number of vectors ({})",
                given.len(),
                count
            )));
        }
    }
    if let Some(given) = payloads.as_ref() {
        if given.len() != count {
            return Err(NeomemxError::VectorStoreError(format!(
                "Number of payloads ({}) does not match number of vectors ({})",
                given.len(),
                count
            )));
        }
    }
    if count == 0 {
        debug!("No vectors to insert");
        return Ok(());
    }

    let ids = ids.unwrap_or_else(|| {
        (0..count)
            .map(|_| uuid::Uuid::new_v4().to_string())
            .collect()
    });
    let payloads = payloads.unwrap_or_else(|| vec![HashMap::new(); count]);

    info!("Inserting {} vectors into collection", count);

    let data: Vec<InsertData> = ids
        .into_iter()
        .zip(vectors)
        .zip(payloads)
        .map(|((id, vector), metadata)| InsertData {
            id,
            vector,
            metadata,
        })
        .collect();
    let request = InsertRequest {
        db_name: self.config.database.clone(),
        collection_name: self.config.collection_name.clone(),
        data,
    };

    let url = format!("{}/v2/vectordb/entities/insert", self.base_url);
    let response = self.client.post(&url).json(&request).send().await?;
    if !response.status().is_success() {
        let body = response.text().await.unwrap_or_default();
        error!("Failed to insert vectors: {}", body);
        return Err(NeomemxError::VectorStoreError(format!(
            "Failed to insert vectors: {}",
            body
        )));
    }

    let result: MilvusResponse<serde_json::Value> = response.json().await?;
    if result.code != 0 {
        return Err(NeomemxError::VectorStoreError(format!(
            "Milvus insert error: {}",
            result
                .message
                .unwrap_or_else(|| "Unknown error".to_string())
        )));
    }

    debug!("Successfully inserted vectors");
    Ok(())
}
/// Searches the configured collection for the entities closest to `vectors`.
///
/// `_query` is not used. `filters`, when given, is translated with
/// `generate_filter_expression`. At most `limit` hits are requested, with all
/// output fields.
///
/// # Errors
///
/// Returns `NeomemxError::VectorStoreError` when the HTTP status is not a
/// success, the body cannot be decoded, or the envelope carries a non-zero
/// `code`. Transport errors are propagated.
async fn search(
    &self,
    _query: &str,
    vectors: &[f32],
    limit: usize,
    filters: Option<Filters>,
) -> Result<Vec<OutputData>> {
    let expression = match &filters {
        Some(given) => Self::generate_filter_expression(given),
        None => None,
    };
    let body = SearchRequest {
        db_name: self.config.database.clone(),
        collection_name: self.config.collection_name.clone(),
        data: vec![vectors.to_vec()],
        anns_field: "vector".to_string(),
        limit,
        output_fields: vec!["*".to_string()],
        filter: expression,
    };
    let endpoint = format!("{}/v2/vectordb/entities/search", self.base_url);

    let response = self.client.post(&endpoint).json(&body).send().await?;
    if !response.status().is_success() {
        let text = response.text().await.unwrap_or_default();
        return Err(NeomemxError::VectorStoreError(format!(
            "Failed to search vectors: {}",
            text
        )));
    }

    let envelope = match response
        .json::<MilvusResponse<Vec<SearchResultData>>>()
        .await
    {
        Ok(parsed) => parsed,
        Err(e) => {
            return Err(NeomemxError::VectorStoreError(format!(
                "Failed to parse search response: {}",
                e
            )))
        }
    };
    if envelope.code != 0 {
        let reason = envelope
            .message
            .unwrap_or_else(|| "Unknown error".to_string());
        return Err(NeomemxError::VectorStoreError(format!(
            "Milvus search error: {}",
            reason
        )));
    }

    let hits = envelope.data.unwrap_or_default();
    Ok(Self::parse_search_results(hits))
}
/// Deletes the entity with id `vector_id` by delegating to `delete_batch`.
async fn delete(&self, vector_id: &str) -> Result<()> {
    let single = [vector_id.to_string()];
    self.delete_batch(&single).await
}
/// Deletes every entity whose id is in `vector_ids`.
///
/// An empty list is a no-op. The ids are sent as an `id in [...]` filter
/// expression with each id written as a double-quoted string literal.
///
/// # Errors
///
/// Returns `NeomemxError::VectorStoreError` when the HTTP status is not a
/// success or the envelope carries a non-zero `code`. Transport and decoding
/// errors are propagated.
async fn delete_batch(&self, vector_ids: &[String]) -> Result<()> {
    if vector_ids.is_empty() {
        return Ok(());
    }

    let ids_list: Vec<String> = vector_ids
        .iter()
        .map(|id| {
            // Escape backslashes before quotes so an id ending in `\` cannot
            // swallow the closing quote of its literal.
            let escaped = id.replace('\\', "\\\\").replace('"', "\\\"");
            format!("\"{}\"", escaped)
        })
        .collect();
    let filter = format!("id in [{}]", ids_list.join(", "));

    let request = DeleteRequest {
        db_name: self.config.database.clone(),
        collection_name: self.config.collection_name.clone(),
        filter,
    };
    let url = format!("{}/v2/vectordb/entities/delete", self.base_url);

    let response = self.client.post(&url).json(&request).send().await?;
    if !response.status().is_success() {
        let body = response.text().await.unwrap_or_default();
        return Err(NeomemxError::VectorStoreError(format!(
            "Failed to delete vectors: {}",
            body
        )));
    }

    let result: MilvusResponse<serde_json::Value> = response.json().await?;
    if result.code != 0 {
        return Err(NeomemxError::VectorStoreError(format!(
            "Milvus delete error: {}",
            result
                .message
                .unwrap_or_else(|| "Unknown error".to_string())
        )));
    }

    debug!("Deleted {} vectors", vector_ids.len());
    Ok(())
}
/// Upserts the entity `vector_id`.
///
/// The current entity is fetched first. When it exists, `payload` entries are
/// merged over the stored payload (new keys win); when it does not, `payload`
/// (or an empty payload) is used as is. A `vector` is required in both cases.
///
/// # Errors
///
/// Returns `NeomemxError::VectorStoreError` when `vector` is `None`, when the
/// HTTP status is not a success, or when the envelope carries a non-zero
/// `code`. Errors from `get`, transport and decoding are propagated.
async fn update(
    &self,
    vector_id: &str,
    vector: Option<Vec<f32>>,
    payload: Option<HashMap<String, serde_json::Value>>,
) -> Result<()> {
    let (final_vector, final_payload) = if let Some(current) = self.get(vector_id).await? {
        let mut combined = current.payload;
        if let Some(changes) = payload {
            combined.extend(changes);
        }
        match vector {
            Some(replacement) => (replacement, combined),
            None => {
                return Err(NeomemxError::VectorStoreError(
                    "Vector is required for updates in Milvus".to_string(),
                ))
            }
        }
    } else {
        match vector {
            Some(fresh) => (fresh, payload.unwrap_or_default()),
            None => {
                return Err(NeomemxError::VectorStoreError(
                    "Vector is required for insert".to_string(),
                ))
            }
        }
    };

    let body = UpsertRequest {
        db_name: self.config.database.clone(),
        collection_name: self.config.collection_name.clone(),
        data: vec![InsertData {
            id: vector_id.to_string(),
            vector: final_vector,
            metadata: final_payload,
        }],
    };
    let endpoint = format!("{}/v2/vectordb/entities/upsert", self.base_url);

    let response = self.client.post(&endpoint).json(&body).send().await?;
    if !response.status().is_success() {
        let text = response.text().await.unwrap_or_default();
        return Err(NeomemxError::VectorStoreError(format!(
            "Failed to update vector: {}",
            text
        )));
    }

    let envelope = response
        .json::<MilvusResponse<serde_json::Value>>()
        .await?;
    if envelope.code != 0 {
        let reason = envelope
            .message
            .unwrap_or_else(|| "Unknown error".to_string());
        return Err(NeomemxError::VectorStoreError(format!(
            "Milvus upsert error: {}",
            reason
        )));
    }

    debug!("Updated vector: {}", vector_id);
    Ok(())
}
/// Fetches the entity with id `vector_id` through `get_batch`, returning the
/// first result or `None` when nothing came back.
async fn get(&self, vector_id: &str) -> Result<Option<OutputData>> {
    let wanted = [vector_id.to_string()];
    let mut found = self.get_batch(&wanted).await?.into_iter();
    Ok(found.next())
}
/// Fetches the entities whose ids are in `vector_ids`, with all output fields.
///
/// An empty list returns an empty result without contacting the server.
///
/// # Errors
///
/// Returns `NeomemxError::VectorStoreError` when the HTTP status is not a
/// success, the body cannot be decoded, or the envelope carries a non-zero
/// `code`. Transport errors are propagated.
async fn get_batch(&self, vector_ids: &[String]) -> Result<Vec<OutputData>> {
    if vector_ids.is_empty() {
        return Ok(Vec::new());
    }

    let body = GetRequest {
        db_name: self.config.database.clone(),
        collection_name: self.config.collection_name.clone(),
        id: vector_ids.to_vec(),
        output_fields: vec!["*".to_string()],
    };
    let endpoint = format!("{}/v2/vectordb/entities/get", self.base_url);

    let response = self.client.post(&endpoint).json(&body).send().await?;
    if !response.status().is_success() {
        let text = response.text().await.unwrap_or_default();
        return Err(NeomemxError::VectorStoreError(format!(
            "Failed to get vectors: {}",
            text
        )));
    }

    let envelope = match response
        .json::<MilvusResponse<Vec<QueryResultData>>>()
        .await
    {
        Ok(parsed) => parsed,
        Err(e) => {
            return Err(NeomemxError::VectorStoreError(format!(
                "Failed to parse get response: {}",
                e
            )))
        }
    };
    if envelope.code != 0 {
        let reason = envelope
            .message
            .unwrap_or_else(|| "Unknown error".to_string());
        return Err(NeomemxError::VectorStoreError(format!(
            "Milvus get error: {}",
            reason
        )));
    }

    let rows = envelope.data.unwrap_or_default();
    Ok(Self::parse_query_results(rows))
}
/// Lists the collection names of the configured database.
///
/// Missing `data` in the response is treated as an empty list.
///
/// # Errors
///
/// Returns `NeomemxError::VectorStoreError` when the HTTP status is not a
/// success, the body cannot be decoded, or the envelope carries a non-zero
/// `code`. Transport errors are propagated.
async fn list_collections(&self) -> Result<Vec<String>> {
    let body = ListCollectionsRequest {
        db_name: self.config.database.clone(),
    };
    let endpoint = format!("{}/v2/vectordb/collections/list", self.base_url);

    let response = self.client.post(&endpoint).json(&body).send().await?;
    if !response.status().is_success() {
        let text = response.text().await.unwrap_or_default();
        return Err(NeomemxError::VectorStoreError(format!(
            "Failed to list collections: {}",
            text
        )));
    }

    let envelope = match response.json::<MilvusResponse<Vec<String>>>().await {
        Ok(parsed) => parsed,
        Err(e) => {
            return Err(NeomemxError::VectorStoreError(format!(
                "Failed to parse list response: {}",
                e
            )))
        }
    };
    if envelope.code != 0 {
        let reason = envelope
            .message
            .unwrap_or_else(|| "Unknown error".to_string());
        return Err(NeomemxError::VectorStoreError(format!(
            "Milvus list error: {}",
            reason
        )));
    }

    Ok(envelope.data.unwrap_or_else(Vec::new))
}
/// Drops the configured collection.
///
/// # Errors
///
/// Returns `NeomemxError::VectorStoreError` when the HTTP status is not a
/// success or the envelope carries a non-zero `code`. Transport and decoding
/// errors are propagated.
async fn delete_collection(&self) -> Result<()> {
    let collection = &self.config.collection_name;
    let body = DropCollectionRequest {
        db_name: self.config.database.clone(),
        collection_name: collection.clone(),
    };
    let endpoint = format!("{}/v2/vectordb/collections/drop", self.base_url);

    let response = self.client.post(&endpoint).json(&body).send().await?;
    if !response.status().is_success() {
        let text = response.text().await.unwrap_or_default();
        return Err(NeomemxError::VectorStoreError(format!(
            "Failed to delete collection: {}",
            text
        )));
    }

    let envelope = response
        .json::<MilvusResponse<serde_json::Value>>()
        .await?;
    if envelope.code != 0 {
        let reason = envelope
            .message
            .unwrap_or_else(|| "Unknown error".to_string());
        return Err(NeomemxError::VectorStoreError(format!(
            "Milvus drop error: {}",
            reason
        )));
    }

    warn!("Deleted collection: {}", collection);
    Ok(())
}
/// Returns the raw description of the configured collection, or JSON `null`
/// when the response carries no `data`.
///
/// # Errors
///
/// Returns `NeomemxError::VectorStoreError` when the HTTP status is not a
/// success, the body cannot be decoded, or the envelope carries a non-zero
/// `code`. Transport errors are propagated.
async fn collection_info(&self) -> Result<serde_json::Value> {
    let body = DescribeCollectionRequest {
        db_name: self.config.database.clone(),
        collection_name: self.config.collection_name.clone(),
    };
    let endpoint = format!("{}/v2/vectordb/collections/describe", self.base_url);

    let response = self.client.post(&endpoint).json(&body).send().await?;
    if !response.status().is_success() {
        let text = response.text().await.unwrap_or_default();
        return Err(NeomemxError::VectorStoreError(format!(
            "Failed to get collection info: {}",
            text
        )));
    }

    let envelope = match response
        .json::<MilvusResponse<serde_json::Value>>()
        .await
    {
        Ok(parsed) => parsed,
        Err(e) => {
            return Err(NeomemxError::VectorStoreError(format!(
                "Failed to parse collection info: {}",
                e
            )))
        }
    };
    if envelope.code != 0 {
        let reason = envelope
            .message
            .unwrap_or_else(|| "Unknown error".to_string());
        return Err(NeomemxError::VectorStoreError(format!(
            "Milvus describe error: {}",
            reason
        )));
    }

    match envelope.data {
        Some(description) => Ok(description),
        None => Ok(serde_json::Value::Null),
    }
}
/// Lists up to `limit` entities of the configured collection, optionally
/// restricted by `filters` (translated with `generate_filter_expression`).
///
/// # Errors
///
/// Returns `NeomemxError::VectorStoreError` when the HTTP status is not a
/// success, the body cannot be decoded, or the envelope carries a non-zero
/// `code`. Transport errors are propagated.
async fn list(&self, filters: Option<Filters>, limit: usize) -> Result<Vec<OutputData>> {
    let expression = match &filters {
        Some(given) => Self::generate_filter_expression(given),
        None => None,
    };
    let body = QueryRequest {
        db_name: self.config.database.clone(),
        collection_name: self.config.collection_name.clone(),
        filter: expression,
        output_fields: vec!["*".to_string()],
        limit,
    };
    let endpoint = format!("{}/v2/vectordb/entities/query", self.base_url);

    let response = self.client.post(&endpoint).json(&body).send().await?;
    if !response.status().is_success() {
        let text = response.text().await.unwrap_or_default();
        return Err(NeomemxError::VectorStoreError(format!(
            "Failed to list vectors: {}",
            text
        )));
    }

    let envelope = match response
        .json::<MilvusResponse<Vec<QueryResultData>>>()
        .await
    {
        Ok(parsed) => parsed,
        Err(e) => {
            return Err(NeomemxError::VectorStoreError(format!(
                "Failed to parse list response: {}",
                e
            )))
        }
    };
    if envelope.code != 0 {
        let reason = envelope
            .message
            .unwrap_or_else(|| "Unknown error".to_string());
        return Err(NeomemxError::VectorStoreError(format!(
            "Milvus query error: {}",
            reason
        )));
    }

    let rows = envelope.data.unwrap_or_default();
    Ok(Self::parse_query_results(rows))
}
/// Drops the configured collection and creates it again.
///
/// If dropping fails, the error is returned and no creation is attempted.
async fn reset(&self) -> Result<()> {
    let collection = &self.config.collection_name;
    warn!("Resetting collection: {}", collection);
    self.delete_collection().await?;
    self.create_collection_internal().await
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A plain string value becomes a quoted equality clause.
    #[test]
    fn test_filter_expression() {
        let mut filters = HashMap::new();
        filters.insert("user_id".to_string(), serde_json::json!("user_123"));
        let expr = Milvus::generate_filter_expression(&filters);
        assert!(expr.is_some());
        assert!(expr.unwrap().contains("user_id == \"user_123\""));
    }

    /// An operator object yields one clause per recognised operator.
    #[test]
    fn test_filter_with_operators() {
        let mut filters = HashMap::new();
        filters.insert(
            "score".to_string(),
            serde_json::json!({"gt": 0.5, "lt": 1.0}),
        );
        let expr = Milvus::generate_filter_expression(&filters);
        assert!(expr.is_some());
        let expr_str = expr.unwrap();
        assert!(expr_str.contains("score > 0.5"));
        assert!(expr_str.contains("score < 1"));
    }

    /// Strings are quoted; numbers and booleans are rendered verbatim.
    #[test]
    fn test_format_value() {
        assert_eq!(format_value(&serde_json::json!("hello")), "\"hello\"");
        assert_eq!(format_value(&serde_json::json!(42)), "42");
        // Deliberately not 3.14: that literal trips clippy's `approx_constant`.
        assert_eq!(format_value(&serde_json::json!(2.5)), "2.5");
        assert_eq!(format_value(&serde_json::json!(true)), "true");
    }
}