Skip to main content

rig_milvus/
lib.rs

1mod filter;
2
3use reqwest::StatusCode;
4use rig::{
5    Embed, OneOrMany,
6    embeddings::{Embedding, EmbeddingModel},
7    vector_store::{
8        InsertDocuments, TopNResults, VectorStoreError, VectorStoreIndex, VectorStoreIndexDyn,
9        request::{Filter as CoreFilter, SearchFilter, VectorSearchRequest},
10    },
11    wasm_compat::WasmBoxedFuture,
12};
13use serde::{Deserialize, Serialize};
14
15use crate::filter::Filter;
16
17/// Represents a vector store implementation using Milvus - <https://milvus.io/> as the backend.
18pub struct MilvusVectorStore<M> {
19    /// Model used to generate embeddings for the vector store
20    model: M,
21    base_url: String,
22    client: reqwest::Client,
23    database_name: String,
24    collection_name: String,
25    token: Option<String>,
26}
27
28#[derive(Debug, Serialize, Deserialize)]
29pub struct CreateRecord {
30    document: String,
31    embedded_text: String,
32    embedding: Vec<f64>,
33}
34
35#[derive(Debug, Serialize, Deserialize)]
36#[serde(rename_all = "camelCase")]
37struct InsertRequest<'a> {
38    data: Vec<CreateRecord>,
39    collection_name: &'a str,
40    db_name: &'a str,
41}
42
43#[derive(Debug, Serialize, Deserialize)]
44#[serde(rename_all = "camelCase")]
45struct SearchRequest<'a> {
46    collection_name: &'a str,
47    db_name: &'a str,
48    data: Vec<f64>,
49    #[serde(skip_serializing_if = "String::is_empty")]
50    filter: String,
51    anns_field: &'a str,
52    limit: usize,
53    output_fields: Vec<&'a str>,
54}
55
56#[derive(Debug, Serialize, Deserialize)]
57#[serde(rename_all = "camelCase")]
58struct SearchResult<T> {
59    code: i64,
60    data: Vec<SearchResultData<T>>,
61}
62
63#[derive(Debug, Serialize, Deserialize)]
64#[serde(rename_all = "camelCase")]
65struct SearchResultData<T> {
66    id: i64,
67    distance: f64,
68    document: T,
69    embedded_text: String,
70}
71
72#[derive(Debug, Serialize, Deserialize)]
73#[serde(rename_all = "camelCase")]
74struct SearchResultOnlyId {
75    code: i64,
76    data: Vec<SearchResultDataOnlyId>,
77}
78
79#[derive(Debug, Serialize, Deserialize)]
80#[serde(rename_all = "camelCase")]
81struct SearchResultDataOnlyId {
82    id: i64,
83    distance: f64,
84}
85
86impl<M> MilvusVectorStore<M>
87where
88    M: EmbeddingModel,
89{
90    /// Creates a new instance of `MilvusVectorStore`.
91    ///
92    /// # Arguments
93    /// * `model` - Embedding model instance
94    /// * `base_url` - The URL of where your Milvus instance is located. Alternatively if you're using the Milvus offering provided by Zilliz, your cluster endpoint.
95    /// * `database_name` - The name of your database
96    /// * `collection_name` - The name of your collection
97    pub fn new(model: M, base_url: String, database_name: String, collection_name: String) -> Self {
98        Self {
99            model,
100            base_url,
101            client: reqwest::Client::new(),
102            database_name,
103            collection_name,
104            token: None,
105        }
106    }
107
108    /// Forms the auth token for Milvus from your username and password. Required if using a Milvus instance that requires authentication.
109    pub fn auth(mut self, username: String, password: String) -> Self {
110        let str = format!("{username}:{password}");
111        self.token = Some(str);
112
113        self
114    }
115
116    /// Creates a Milvus insertion request.
117    fn create_insert_request(&self, data: Vec<CreateRecord>) -> InsertRequest<'_> {
118        InsertRequest {
119            data,
120            collection_name: &self.collection_name,
121            db_name: &self.database_name,
122        }
123    }
124
125    /// Creates a Milvus semantic search request.
126    fn create_search_request(
127        &self,
128        data: Vec<f64>,
129        req: &VectorSearchRequest<Filter>,
130        id_only: bool,
131    ) -> SearchRequest<'_> {
132        const OUTPUT_FIELDS: [&str; 4] = ["id", "distance", "document", "embeddedText"];
133        const OUTPUT_FIELDS_ID_ONLY: [&str; 2] = ["id", "distance"];
134
135        let output_fields = if id_only {
136            OUTPUT_FIELDS_ID_ONLY.to_vec()
137        } else {
138            OUTPUT_FIELDS.to_vec()
139        };
140
141        let threshold = req
142            .threshold()
143            .map(|thresh| Filter::gte("distance".into(), thresh.into()));
144
145        let filter = match (threshold, req.filter()) {
146            (Some(thresh), Some(filter)) => thresh.and(filter.clone()).into_inner(),
147            (Some(thresh), _) => thresh.into_inner(),
148            (_, Some(filter)) => filter.clone().into_inner(),
149            _ => String::new(),
150        };
151
152        SearchRequest {
153            collection_name: &self.collection_name,
154            db_name: &self.database_name,
155            data,
156            filter,
157            anns_field: "embedding",
158            limit: req.samples() as usize,
159            output_fields,
160        }
161    }
162}
163
164impl<Model> InsertDocuments for MilvusVectorStore<Model>
165where
166    Model: EmbeddingModel + Send + Sync,
167{
168    async fn insert_documents<Doc: Serialize + Embed + Send>(
169        &self,
170        documents: Vec<(Doc, OneOrMany<Embedding>)>,
171    ) -> Result<(), VectorStoreError> {
172        let url = format!(
173            "{base_url}/v2/vectordb/entities/insert",
174            base_url = self.base_url
175        );
176
177        let data = documents
178            .into_iter()
179            .map(|(document, embeddings)| {
180                let json_document: serde_json::Value = serde_json::to_value(&document)?;
181                let json_document_as_string = serde_json::to_string(&json_document)?;
182
183                let embeddings = embeddings
184                    .into_iter()
185                    .map(|embedding| {
186                        let embedded_text = embedding.document;
187                        let embedding: Vec<f64> = embedding.vec;
188
189                        CreateRecord {
190                            document: json_document_as_string.clone(),
191                            embedded_text,
192                            embedding,
193                        }
194                    })
195                    .collect::<Vec<CreateRecord>>();
196                Ok(embeddings)
197            })
198            .collect::<Result<Vec<Vec<CreateRecord>>, VectorStoreError>>()?
199            .into_iter()
200            .flatten()
201            .collect::<Vec<CreateRecord>>();
202
203        let mut client = self.client.post(url);
204        if let Some(ref token) = self.token {
205            client = client.header("Authentication", format!("Bearer {token}"));
206        }
207
208        let insert_request = self.create_insert_request(data);
209
210        let body = serde_json::to_string(&insert_request).unwrap();
211
212        let res = client.body(body).send().await?;
213
214        if res.status() != StatusCode::OK {
215            let status = res.status();
216            let text = res.text().await?;
217
218            return Err(VectorStoreError::ExternalAPIError(status, text));
219        }
220
221        Ok(())
222    }
223}
224
225impl<M> VectorStoreIndex for MilvusVectorStore<M>
226where
227    M: EmbeddingModel,
228{
229    type Filter = Filter;
230
231    /// Search for the top `n` nearest neighbors to the given query within the Milvus vector store.
232    /// Returns a vector of tuples containing the score, ID, and payload of the nearest neighbors.
233    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
234        &self,
235        req: VectorSearchRequest<Filter>,
236    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
237        let embedding = self.model.embed_text(req.query()).await?;
238        let url = format!(
239            "{base_url}/v2/vectordb/entities/search",
240            base_url = self.base_url
241        );
242
243        let body = self.create_search_request(embedding.vec, &req, false);
244
245        let mut client = self.client.post(url);
246        if let Some(ref token) = self.token {
247            client = client.header("Authentication", format!("Bearer {token}"));
248        }
249
250        let body = serde_json::to_string(&body)?;
251
252        let res = client.body(body).send().await?;
253
254        if res.status() != StatusCode::OK {
255            let status = res.status();
256            let text = res.text().await?;
257
258            return Err(VectorStoreError::ExternalAPIError(status, text));
259        }
260
261        let json: SearchResult<T> = res.json().await?;
262
263        let res = json
264            .data
265            .into_iter()
266            .map(|x| (x.distance, x.id.to_string(), x.document))
267            .collect();
268
269        Ok(res)
270    }
271
272    /// Search for the top `n` nearest neighbors to the given query within the Milvus vector store.
273    /// Returns a vector of tuples containing the score and ID of the nearest neighbors.
274    async fn top_n_ids(
275        &self,
276        req: VectorSearchRequest<Filter>,
277    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
278        let embedding = self.model.embed_text(req.query()).await?;
279        let url = format!(
280            "{base_url}/v2/vectordb/entities/search",
281            base_url = self.base_url
282        );
283
284        let body = self.create_search_request(embedding.vec, &req, true);
285
286        let mut client = self.client.post(url);
287        if let Some(ref token) = self.token {
288            client = client.header("Authentication", format!("Bearer {token}"));
289        }
290
291        let body = serde_json::to_string(&body)?;
292
293        let res = client.body(body).send().await?;
294
295        if res.status() != StatusCode::OK {
296            let status = res.status();
297            let text = res.text().await?;
298
299            return Err(VectorStoreError::ExternalAPIError(status, text));
300        }
301
302        let json: SearchResultOnlyId = res.json().await?;
303
304        let res = json
305            .data
306            .into_iter()
307            .map(|x| (x.distance, x.id.to_string()))
308            .collect();
309
310        Ok(res)
311    }
312}
313
314impl<M> VectorStoreIndexDyn for MilvusVectorStore<M>
315where
316    M: EmbeddingModel + Sync + Send,
317{
318    fn top_n<'a>(
319        &'a self,
320        req: VectorSearchRequest<CoreFilter<serde_json::Value>>,
321    ) -> WasmBoxedFuture<'a, TopNResults> {
322        Box::pin(async move {
323            let req = req.try_map_filter(Filter::try_from)?;
324            let results = <Self as VectorStoreIndex>::top_n::<serde_json::Value>(self, req).await?;
325
326            Ok(results)
327        })
328    }
329
330    /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`.
331    fn top_n_ids<'a>(
332        &'a self,
333        req: VectorSearchRequest<CoreFilter<serde_json::Value>>,
334    ) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
335        Box::pin(async move {
336            let req = req.try_map_filter(Filter::try_from)?;
337            let results = <Self as VectorStoreIndex>::top_n_ids(self, req).await?;
338
339            Ok(results)
340        })
341    }
342}