1use reqwest::StatusCode;
2use rig::{
3 Embed, OneOrMany,
4 embeddings::{Embedding, EmbeddingModel},
5 vector_store::{
6 InsertDocuments, VectorStoreError, VectorStoreIndex, request::VectorSearchRequest,
7 },
8};
9use serde::{Deserialize, Serialize};
10
/// Vector store backed by a Milvus collection, accessed through Milvus' REST endpoints
/// under `{base_url}/v2/vectordb/...`.
pub struct MilvusVectorStore<M> {
    /// Embedding model used to embed the query text of search requests.
    model: M,
    /// Base URL of the Milvus server; endpoint paths are appended to it verbatim.
    base_url: String,
    /// HTTP client reused for every request issued by this store.
    client: reqwest::Client,
    /// Milvus database name, sent as `dbName` in request bodies.
    database_name: String,
    /// Collection that holds the documents, sent as `collectionName` in request bodies.
    collection_name: String,
    /// Optional `username:password` credential set by `auth`; attached to requests as a
    /// `Bearer` token when present.
    token: Option<String>,
}
21
22#[derive(Debug, Serialize, Deserialize)]
23pub struct CreateRecord {
24 document: String,
25 embedded_text: String,
26 embedding: Vec<f64>,
27}
28
/// JSON body of `POST {base_url}/v2/vectordb/entities/insert`.
///
/// Keys serialize as camelCase: `data`, `collectionName`, `dbName`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct InsertRequest<'a> {
    /// Rows to insert.
    data: Vec<CreateRecord>,
    /// Target collection.
    collection_name: &'a str,
    /// Database containing the collection.
    db_name: &'a str,
}
36
/// JSON body of `POST {base_url}/v2/vectordb/entities/search`.
///
/// Keys serialize as camelCase (`collectionName`, `dbName`, `annsField`, `outputFields`, ...).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SearchRequest<'a> {
    /// Target collection.
    collection_name: &'a str,
    /// Database containing the collection.
    db_name: &'a str,
    /// The query embedding, sent as a single flat list of numbers.
    // NOTE(review): the Milvus REST v2 docs show `data` as a list of query vectors
    // (`[[...]]`); confirm that a flat vector is accepted by the server version in use.
    data: Vec<f64>,
    /// Boolean filter expression; the key is omitted from the JSON entirely when empty.
    #[serde(skip_serializing_if = "String::is_empty")]
    filter: String,
    /// Name of the vector field to search.
    anns_field: &'a str,
    /// Maximum number of hits to return.
    limit: usize,
    /// Fields Milvus should include in each hit.
    output_fields: Vec<&'a str>,
}
49
/// Response envelope of a search that returns full hits.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SearchResult<T> {
    /// Status code reported by Milvus in the response body.
    code: i64,
    /// One entry per hit.
    data: Vec<SearchResultData<T>>,
}
56
/// A single search hit including the stored document (keys in camelCase, e.g. `embeddedText`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SearchResultData<T> {
    /// Primary key of the matched row; stringified before being returned to callers.
    id: i64,
    /// Distance/score reported by Milvus for this hit.
    distance: f64,
    /// The stored document, deserialized directly into the caller's type `T`.
    document: T,
    /// The text the stored embedding was computed from.
    embedded_text: String,
}
65
/// Response envelope of a search that only requests ids and distances.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SearchResultOnlyId {
    /// Status code reported by Milvus in the response body.
    code: i64,
    /// One entry per hit.
    data: Vec<SearchResultDataOnlyId>,
}
72
/// A single search hit carrying only its primary key and distance.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SearchResultDataOnlyId {
    /// Primary key of the matched row; stringified before being returned to callers.
    id: i64,
    /// Distance/score reported by Milvus for this hit.
    distance: f64,
}
79
80impl<M: EmbeddingModel> MilvusVectorStore<M> {
81 pub fn new(model: M, base_url: String, database_name: String, collection_name: String) -> Self {
89 Self {
90 model,
91 base_url,
92 client: reqwest::Client::new(),
93 database_name,
94 collection_name,
95 token: None,
96 }
97 }
98
99 pub fn auth(mut self, username: String, password: String) -> Self {
101 let str = format!("{username}:{password}");
102 self.token = Some(str);
103
104 self
105 }
106
107 fn create_insert_request(&self, data: Vec<CreateRecord>) -> InsertRequest<'_> {
109 InsertRequest {
110 data,
111 collection_name: &self.collection_name,
112 db_name: &self.database_name,
113 }
114 }
115
116 fn create_search_request(
118 &self,
119 data: Vec<f64>,
120 limit: usize,
121 threshold: Option<f64>,
122 ) -> SearchRequest<'_> {
123 let filter = if let Some(threshold) = threshold {
124 format!("distance >= {threshold}")
125 } else {
126 String::new()
127 };
128 SearchRequest {
129 collection_name: &self.collection_name,
130 db_name: &self.database_name,
131 data,
132 filter,
133 anns_field: "embedding",
134 limit,
135 output_fields: vec!["id", "distance", "document", "embeddedText"],
136 }
137 }
138
139 fn create_search_request_id_only(
141 &self,
142 data: Vec<f64>,
143 limit: usize,
144 threshold: Option<f64>,
145 ) -> SearchRequest<'_> {
146 let filter = if let Some(threshold) = threshold {
147 format!("distance >= {threshold}")
148 } else {
149 String::new()
150 };
151 SearchRequest {
152 collection_name: &self.collection_name,
153 db_name: &self.database_name,
154 data,
155 filter,
156 anns_field: "embedding",
157 limit,
158 output_fields: vec!["id", "distance"],
159 }
160 }
161}
162
163impl<Model> InsertDocuments for MilvusVectorStore<Model>
164where
165 Model: EmbeddingModel + Send + Sync,
166{
167 async fn insert_documents<Doc: Serialize + Embed + Send>(
168 &self,
169 documents: Vec<(Doc, OneOrMany<Embedding>)>,
170 ) -> Result<(), VectorStoreError> {
171 let url = format!(
172 "{base_url}/v2/vectordb/entities/insert",
173 base_url = self.base_url
174 );
175
176 let data = documents
177 .into_iter()
178 .map(|(document, embeddings)| {
179 let json_document: serde_json::Value = serde_json::to_value(&document)?;
180 let json_document_as_string = serde_json::to_string(&json_document)?;
181
182 let embeddings = embeddings
183 .into_iter()
184 .map(|embedding| {
185 let embedded_text = embedding.document;
186 let embedding: Vec<f64> = embedding.vec;
187
188 CreateRecord {
189 document: json_document_as_string.clone(),
190 embedded_text,
191 embedding,
192 }
193 })
194 .collect::<Vec<CreateRecord>>();
195 Ok(embeddings)
196 })
197 .collect::<Result<Vec<Vec<CreateRecord>>, VectorStoreError>>()?
198 .into_iter()
199 .flatten()
200 .collect::<Vec<CreateRecord>>();
201
202 let mut client = self.client.post(url);
203 if let Some(ref token) = self.token {
204 client = client.header("Authentication", format!("Bearer {token}"));
205 }
206
207 let insert_request = self.create_insert_request(data);
208
209 let body = serde_json::to_string(&insert_request).unwrap();
210
211 let res = client.body(body).send().await?;
212
213 if res.status() != StatusCode::OK {
214 let status = res.status();
215 let text = res.text().await?;
216
217 return Err(VectorStoreError::ExternalAPIError(status, text));
218 }
219
220 Ok(())
221 }
222}
223
224impl<M: EmbeddingModel> VectorStoreIndex for MilvusVectorStore<M> {
225 async fn top_n<T: for<'a> Deserialize<'a> + Send>(
228 &self,
229 req: VectorSearchRequest,
230 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
231 let embedding = self.model.embed_text(req.query()).await?;
232 let url = format!(
233 "{base_url}/v2/vectordb/entities/search",
234 base_url = self.base_url
235 );
236
237 let body =
238 self.create_search_request(embedding.vec, req.samples() as usize, req.threshold());
239
240 let mut client = self.client.post(url);
241 if let Some(ref token) = self.token {
242 client = client.header("Authentication", format!("Bearer {token}"));
243 }
244
245 let body = serde_json::to_string(&body)?;
246
247 let res = client.body(body).send().await?;
248
249 if res.status() != StatusCode::OK {
250 let status = res.status();
251 let text = res.text().await?;
252
253 return Err(VectorStoreError::ExternalAPIError(status, text));
254 }
255
256 let json: SearchResult<T> = res.json().await?;
257
258 let res = json
259 .data
260 .into_iter()
261 .map(|x| (x.distance, x.id.to_string(), x.document))
262 .collect();
263
264 Ok(res)
265 }
266
267 async fn top_n_ids(
270 &self,
271 req: VectorSearchRequest,
272 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
273 let embedding = self.model.embed_text(req.query()).await?;
274 let url = format!(
275 "{base_url}/v2/vectordb/entities/search",
276 base_url = self.base_url
277 );
278
279 let body = self.create_search_request_id_only(
280 embedding.vec,
281 req.samples() as usize,
282 req.threshold(),
283 );
284
285 let mut client = self.client.post(url);
286 if let Some(ref token) = self.token {
287 client = client.header("Authentication", format!("Bearer {token}"));
288 }
289
290 let body = serde_json::to_string(&body)?;
291
292 let res = client.body(body).send().await?;
293
294 if res.status() != StatusCode::OK {
295 let status = res.status();
296 let text = res.text().await?;
297
298 return Err(VectorStoreError::ExternalAPIError(status, text));
299 }
300
301 let json: SearchResultOnlyId = res.json().await?;
302
303 let res = json
304 .data
305 .into_iter()
306 .map(|x| (x.distance, x.id.to_string()))
307 .collect();
308
309 Ok(res)
310 }
311}