1use async_trait::async_trait;
21use serde_json::{json, Value};
22use std::collections::HashMap;
23use synaptic_core::{Document, Embeddings, SynapticError, VectorStore};
24
/// Connection settings for a Milvus collection reached over its HTTP API.
///
/// `Debug` is implemented by hand so that the API key is never written to
/// logs or panic messages; every other field is printed as usual.
#[derive(Clone)]
pub struct MilvusConfig {
    /// Base URL of the Milvus HTTP endpoint, e.g. `http://localhost:19530`.
    pub endpoint: String,
    /// Name of the collection that requests are issued against.
    pub collection: String,
    /// Optional API key / token used to authenticate requests.
    pub api_key: Option<String>,
    /// Dimensionality of the vectors stored in the collection.
    pub dim: usize,
}

impl std::fmt::Debug for MilvusConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only reveal whether a key is configured, never its value.
        let api_key = self.api_key.as_ref().map(|_| "<redacted>");
        f.debug_struct("MilvusConfig")
            .field("endpoint", &self.endpoint)
            .field("collection", &self.collection)
            .field("api_key", &api_key)
            .field("dim", &self.dim)
            .finish()
    }
}
37
38impl MilvusConfig {
39 pub fn new(endpoint: impl Into<String>, collection: impl Into<String>, dim: usize) -> Self {
41 Self {
42 endpoint: endpoint.into(),
43 collection: collection.into(),
44 api_key: None,
45 dim,
46 }
47 }
48
49 pub fn with_api_key(mut self, key: impl Into<String>) -> Self {
51 self.api_key = Some(key.into());
52 self
53 }
54}
55
56pub struct MilvusVectorStore {
62 config: MilvusConfig,
63 client: reqwest::Client,
64}
65
66impl MilvusVectorStore {
67 pub fn new(config: MilvusConfig) -> Self {
69 Self {
70 config,
71 client: reqwest::Client::new(),
72 }
73 }
74
75 pub fn config(&self) -> &MilvusConfig {
77 &self.config
78 }
79
80 pub async fn initialize(&self) -> Result<(), SynapticError> {
85 let body = json!({
86 "collectionName": self.config.collection,
87 "dimension": self.config.dim,
88 "metricType": "COSINE",
89 });
90 let resp = self
91 .request("POST", "/v2/vectordb/collections/create", &body)
92 .await?;
93
94 let code = resp["code"].as_i64().unwrap_or(0);
95 if code != 0 {
96 let msg = resp["message"].as_str().unwrap_or("");
97 if !msg.to_lowercase().contains("already exist") {
99 return Err(SynapticError::VectorStore(format!(
100 "Milvus create collection error (code {code}): {msg}"
101 )));
102 }
103 }
104 Ok(())
105 }
106
107 async fn request(
109 &self,
110 method: &str,
111 path: &str,
112 body: &Value,
113 ) -> Result<Value, SynapticError> {
114 let url = format!("{}{}", self.config.endpoint.trim_end_matches('/'), path);
115 let mut req = match method {
116 "POST" => self.client.post(&url),
117 "DELETE" => self.client.delete(&url),
118 _ => self.client.get(&url),
119 };
120 req = req.header("Content-Type", "application/json");
121 if let Some(ref key) = self.config.api_key {
122 req = req.header("Authorization", format!("Bearer {}", key));
123 }
124 let resp = req
125 .json(body)
126 .send()
127 .await
128 .map_err(|e| SynapticError::VectorStore(format!("Milvus request error: {e}")))?;
129 let status = resp.status().as_u16();
130 let json: Value = resp
131 .json()
132 .await
133 .map_err(|e| SynapticError::VectorStore(format!("Milvus response parse error: {e}")))?;
134 if status >= 400 {
135 return Err(SynapticError::VectorStore(format!(
136 "Milvus HTTP error ({status}): {json}"
137 )));
138 }
139 Ok(json)
140 }
141
142 async fn search_by_vector_with_score(
144 &self,
145 vector: &[f32],
146 k: usize,
147 ) -> Result<Vec<(Document, f32)>, SynapticError> {
148 let body = json!({
149 "collectionName": self.config.collection,
150 "data": [vector],
151 "limit": k,
152 "outputFields": ["docId", "content", "metadata"],
153 });
154 let resp = self
155 .request("POST", "/v2/vectordb/entities/search", &body)
156 .await?;
157
158 let results = resp["data"].as_array().cloned().unwrap_or_default();
159 let docs = results
160 .iter()
161 .filter_map(|r| {
162 let doc_id = r["docId"].as_str()?.to_string();
163 let content = r["content"].as_str()?.to_string();
164 let metadata_str = r["metadata"].as_str().unwrap_or("{}");
165 let metadata: HashMap<String, Value> =
166 serde_json::from_str(metadata_str).unwrap_or_default();
167 let score = r["distance"].as_f64().unwrap_or(0.0) as f32;
168 Some((Document::with_metadata(doc_id, content, metadata), score))
169 })
170 .collect();
171 Ok(docs)
172 }
173}
174
175#[async_trait]
176impl VectorStore for MilvusVectorStore {
177 async fn add_documents(
178 &self,
179 docs: Vec<Document>,
180 embeddings: &dyn Embeddings,
181 ) -> Result<Vec<String>, SynapticError> {
182 if docs.is_empty() {
183 return Ok(vec![]);
184 }
185
186 let texts: Vec<&str> = docs.iter().map(|d| d.content.as_str()).collect();
187 let vectors = embeddings.embed_documents(&texts).await?;
188
189 let data: Vec<Value> = docs
190 .iter()
191 .zip(vectors.iter())
192 .map(|(doc, vec)| {
193 let metadata_str =
194 serde_json::to_string(&doc.metadata).unwrap_or_else(|_| "{}".to_string());
195 json!({
196 "docId": doc.id,
197 "content": doc.content,
198 "metadata": metadata_str,
199 "vector": vec,
200 })
201 })
202 .collect();
203
204 let body = json!({
205 "collectionName": self.config.collection,
206 "data": data,
207 });
208 self.request("POST", "/v2/vectordb/entities/insert", &body)
209 .await?;
210
211 Ok(docs.into_iter().map(|d| d.id).collect())
212 }
213
214 async fn similarity_search(
215 &self,
216 query: &str,
217 k: usize,
218 embeddings: &dyn Embeddings,
219 ) -> Result<Vec<Document>, SynapticError> {
220 let results = self
221 .similarity_search_with_score(query, k, embeddings)
222 .await?;
223 Ok(results.into_iter().map(|(doc, _)| doc).collect())
224 }
225
226 async fn similarity_search_with_score(
227 &self,
228 query: &str,
229 k: usize,
230 embeddings: &dyn Embeddings,
231 ) -> Result<Vec<(Document, f32)>, SynapticError> {
232 let qvec = embeddings.embed_query(query).await?;
233 self.search_by_vector_with_score(&qvec, k).await
234 }
235
236 async fn similarity_search_by_vector(
237 &self,
238 embedding: &[f32],
239 k: usize,
240 ) -> Result<Vec<Document>, SynapticError> {
241 let results = self.search_by_vector_with_score(embedding, k).await?;
242 Ok(results.into_iter().map(|(doc, _)| doc).collect())
243 }
244
245 async fn delete(&self, ids: &[&str]) -> Result<(), SynapticError> {
246 if ids.is_empty() {
247 return Ok(());
248 }
249 let filter = format!(
250 "docId in [{}]",
251 ids.iter()
252 .map(|id| format!("\"{}\"", id))
253 .collect::<Vec<_>>()
254 .join(",")
255 );
256 let body = json!({
257 "collectionName": self.config.collection,
258 "filter": filter,
259 });
260 self.request("POST", "/v2/vectordb/entities/delete", &body)
261 .await?;
262 Ok(())
263 }
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` stores its arguments verbatim and leaves the API key unset.
    #[test]
    fn config_new_sets_fields() {
        let MilvusConfig {
            endpoint,
            collection,
            api_key,
            dim,
        } = MilvusConfig::new("http://localhost:19530", "test_collection", 1536);

        assert_eq!(endpoint, "http://localhost:19530");
        assert_eq!(collection, "test_collection");
        assert_eq!(dim, 1536);
        assert!(api_key.is_none());
    }

    /// `with_api_key` records the supplied token.
    #[test]
    fn config_with_api_key() {
        let base = MilvusConfig::new("http://localhost:19530", "test", 768);
        let keyed = base.with_api_key("my-token");
        assert_eq!(keyed.api_key, Some("my-token".to_string()));
    }

    /// A store hands back the configuration it was constructed with.
    #[test]
    fn store_new_creates_instance() {
        let store = MilvusVectorStore::new(MilvusConfig::new("http://localhost:19530", "coll", 512));
        let cfg = store.config();
        assert_eq!(cfg.collection, "coll");
        assert_eq!(cfg.dim, 512);
    }
}