1mod filter;
2
3use reqwest::StatusCode;
4use rig::{
5 Embed, OneOrMany,
6 embeddings::{Embedding, EmbeddingModel},
7 vector_store::{
8 InsertDocuments, TopNResults, VectorStoreError, VectorStoreIndex, VectorStoreIndexDyn,
9 request::{Filter as CoreFilter, SearchFilter, VectorSearchRequest},
10 },
11 wasm_compat::WasmBoxedFuture,
12};
13use serde::{Deserialize, Serialize};
14
15use crate::filter::Filter;
16
/// A vector store that talks to a Milvus server over HTTP.
///
/// `M` is the embedding model used to turn search queries into vectors.
pub struct MilvusVectorStore<M> {
    /// Embedding model; the search methods call `embed_text` on it for each query.
    model: M,
    /// Base URL that endpoint paths such as `/v2/vectordb/entities/insert` are appended to.
    base_url: String,
    /// HTTP client used for every request issued by this store.
    client: reqwest::Client,
    /// Database name, sent as `dbName` in request bodies.
    database_name: String,
    /// Collection name, sent as `collectionName` in request bodies.
    collection_name: String,
    /// Optional `username:password` credential (set by `auth`), sent as a bearer value.
    token: Option<String>,
}
27
28#[derive(Debug, Serialize, Deserialize)]
29pub struct CreateRecord {
30 document: String,
31 embedded_text: String,
32 embedding: Vec<f64>,
33}
34
/// JSON body posted to the insert endpoint.
///
/// Field names are serialized in camelCase (`collectionName`, `dbName`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct InsertRequest<'a> {
    /// Rows to insert.
    data: Vec<CreateRecord>,
    /// Target collection, borrowed from the store.
    collection_name: &'a str,
    /// Target database, borrowed from the store.
    db_name: &'a str,
}
42
/// JSON body posted to the search endpoint.
///
/// Field names are serialized in camelCase (`collectionName`, `dbName`, `annsField`,
/// `outputFields`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SearchRequest<'a> {
    /// Collection to search, borrowed from the store.
    collection_name: &'a str,
    /// Database to search, borrowed from the store.
    db_name: &'a str,
    /// The query embedding.
    // NOTE(review): this is a single flat vector; the Milvus v2 search examples send
    // `data` as a list of vectors — confirm the server accepts this shape.
    data: Vec<f64>,
    /// Filter expression; left out of the JSON entirely when empty.
    #[serde(skip_serializing_if = "String::is_empty")]
    filter: String,
    /// Name of the vector field to search against.
    anns_field: &'a str,
    /// Maximum number of hits to return.
    limit: usize,
    /// Names of the fields to include in each hit.
    output_fields: Vec<&'a str>,
}
55
/// Response body of a search that returns full documents.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SearchResult<T> {
    /// Response code reported in the body; not inspected by the code in this file.
    code: i64,
    /// The hits, one entry per matching row.
    data: Vec<SearchResultData<T>>,
}
62
/// A single search hit including the stored document.
///
/// Keys are read in camelCase, so `embedded_text` is expected as `embeddedText`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SearchResultData<T> {
    /// Row identifier; callers receive it converted to a string.
    id: i64,
    /// Distance value reported for this hit.
    distance: f64,
    /// The stored document, deserialized into the caller's type.
    document: T,
    /// The text that was embedded for this row.
    embedded_text: String,
}
71
/// Response body of a search that only asks for ids and distances.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SearchResultOnlyId {
    /// Response code reported in the body; not inspected by the code in this file.
    code: i64,
    /// The hits, one entry per matching row.
    data: Vec<SearchResultDataOnlyId>,
}
78
/// A single search hit carrying only the row id and its distance.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SearchResultDataOnlyId {
    /// Row identifier; callers receive it converted to a string.
    id: i64,
    /// Distance value reported for this hit.
    distance: f64,
}
85
86impl<M> MilvusVectorStore<M>
87where
88 M: EmbeddingModel,
89{
90 pub fn new(model: M, base_url: String, database_name: String, collection_name: String) -> Self {
98 Self {
99 model,
100 base_url,
101 client: reqwest::Client::new(),
102 database_name,
103 collection_name,
104 token: None,
105 }
106 }
107
108 pub fn auth(mut self, username: String, password: String) -> Self {
110 let str = format!("{username}:{password}");
111 self.token = Some(str);
112
113 self
114 }
115
116 fn create_insert_request(&self, data: Vec<CreateRecord>) -> InsertRequest<'_> {
118 InsertRequest {
119 data,
120 collection_name: &self.collection_name,
121 db_name: &self.database_name,
122 }
123 }
124
125 fn create_search_request(
127 &self,
128 data: Vec<f64>,
129 req: &VectorSearchRequest<Filter>,
130 id_only: bool,
131 ) -> SearchRequest<'_> {
132 const OUTPUT_FIELDS: [&str; 4] = ["id", "distance", "document", "embeddedText"];
133 const OUTPUT_FIELDS_ID_ONLY: [&str; 2] = ["id", "distance"];
134
135 let output_fields = if id_only {
136 OUTPUT_FIELDS_ID_ONLY.to_vec()
137 } else {
138 OUTPUT_FIELDS.to_vec()
139 };
140
141 let threshold = req
142 .threshold()
143 .map(|thresh| Filter::gte("distance".into(), thresh.into()));
144
145 let filter = match (threshold, req.filter()) {
146 (Some(thresh), Some(filter)) => thresh.and(filter.clone()).into_inner(),
147 (Some(thresh), _) => thresh.into_inner(),
148 (_, Some(filter)) => filter.clone().into_inner(),
149 _ => String::new(),
150 };
151
152 SearchRequest {
153 collection_name: &self.collection_name,
154 db_name: &self.database_name,
155 data,
156 filter,
157 anns_field: "embedding",
158 limit: req.samples() as usize,
159 output_fields,
160 }
161 }
162}
163
164impl<Model> InsertDocuments for MilvusVectorStore<Model>
165where
166 Model: EmbeddingModel + Send + Sync,
167{
168 async fn insert_documents<Doc: Serialize + Embed + Send>(
169 &self,
170 documents: Vec<(Doc, OneOrMany<Embedding>)>,
171 ) -> Result<(), VectorStoreError> {
172 let url = format!(
173 "{base_url}/v2/vectordb/entities/insert",
174 base_url = self.base_url
175 );
176
177 let data = documents
178 .into_iter()
179 .map(|(document, embeddings)| {
180 let json_document: serde_json::Value = serde_json::to_value(&document)?;
181 let json_document_as_string = serde_json::to_string(&json_document)?;
182
183 let embeddings = embeddings
184 .into_iter()
185 .map(|embedding| {
186 let embedded_text = embedding.document;
187 let embedding: Vec<f64> = embedding.vec;
188
189 CreateRecord {
190 document: json_document_as_string.clone(),
191 embedded_text,
192 embedding,
193 }
194 })
195 .collect::<Vec<CreateRecord>>();
196 Ok(embeddings)
197 })
198 .collect::<Result<Vec<Vec<CreateRecord>>, VectorStoreError>>()?
199 .into_iter()
200 .flatten()
201 .collect::<Vec<CreateRecord>>();
202
203 let mut client = self.client.post(url);
204 if let Some(ref token) = self.token {
205 client = client.header("Authentication", format!("Bearer {token}"));
206 }
207
208 let insert_request = self.create_insert_request(data);
209
210 let body = serde_json::to_string(&insert_request).unwrap();
211
212 let res = client.body(body).send().await?;
213
214 if res.status() != StatusCode::OK {
215 let status = res.status();
216 let text = res.text().await?;
217
218 return Err(VectorStoreError::ExternalAPIError(status, text));
219 }
220
221 Ok(())
222 }
223}
224
225impl<M> VectorStoreIndex for MilvusVectorStore<M>
226where
227 M: EmbeddingModel,
228{
229 type Filter = Filter;
230
231 async fn top_n<T: for<'a> Deserialize<'a> + Send>(
234 &self,
235 req: VectorSearchRequest<Filter>,
236 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
237 let embedding = self.model.embed_text(req.query()).await?;
238 let url = format!(
239 "{base_url}/v2/vectordb/entities/search",
240 base_url = self.base_url
241 );
242
243 let body = self.create_search_request(embedding.vec, &req, false);
244
245 let mut client = self.client.post(url);
246 if let Some(ref token) = self.token {
247 client = client.header("Authentication", format!("Bearer {token}"));
248 }
249
250 let body = serde_json::to_string(&body)?;
251
252 let res = client.body(body).send().await?;
253
254 if res.status() != StatusCode::OK {
255 let status = res.status();
256 let text = res.text().await?;
257
258 return Err(VectorStoreError::ExternalAPIError(status, text));
259 }
260
261 let json: SearchResult<T> = res.json().await?;
262
263 let res = json
264 .data
265 .into_iter()
266 .map(|x| (x.distance, x.id.to_string(), x.document))
267 .collect();
268
269 Ok(res)
270 }
271
272 async fn top_n_ids(
275 &self,
276 req: VectorSearchRequest<Filter>,
277 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
278 let embedding = self.model.embed_text(req.query()).await?;
279 let url = format!(
280 "{base_url}/v2/vectordb/entities/search",
281 base_url = self.base_url
282 );
283
284 let body = self.create_search_request(embedding.vec, &req, true);
285
286 let mut client = self.client.post(url);
287 if let Some(ref token) = self.token {
288 client = client.header("Authentication", format!("Bearer {token}"));
289 }
290
291 let body = serde_json::to_string(&body)?;
292
293 let res = client.body(body).send().await?;
294
295 if res.status() != StatusCode::OK {
296 let status = res.status();
297 let text = res.text().await?;
298
299 return Err(VectorStoreError::ExternalAPIError(status, text));
300 }
301
302 let json: SearchResultOnlyId = res.json().await?;
303
304 let res = json
305 .data
306 .into_iter()
307 .map(|x| (x.distance, x.id.to_string()))
308 .collect();
309
310 Ok(res)
311 }
312}
313
314impl<M> VectorStoreIndexDyn for MilvusVectorStore<M>
315where
316 M: EmbeddingModel + Sync + Send,
317{
318 fn top_n<'a>(
319 &'a self,
320 req: VectorSearchRequest<CoreFilter<serde_json::Value>>,
321 ) -> WasmBoxedFuture<'a, TopNResults> {
322 Box::pin(async move {
323 let req = req.try_map_filter(Filter::try_from)?;
324 let results = <Self as VectorStoreIndex>::top_n::<serde_json::Value>(self, req).await?;
325
326 Ok(results)
327 })
328 }
329
330 fn top_n_ids<'a>(
332 &'a self,
333 req: VectorSearchRequest<CoreFilter<serde_json::Value>>,
334 ) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
335 Box::pin(async move {
336 let req = req.try_map_filter(Filter::try_from)?;
337 let results = <Self as VectorStoreIndex>::top_n_ids(self, req).await?;
338
339 Ok(results)
340 })
341 }
342}