rig_milvus/
lib.rs

1use reqwest::StatusCode;
2use rig::{
3    Embed, OneOrMany,
4    embeddings::{Embedding, EmbeddingModel},
5    vector_store::{
6        InsertDocuments, VectorStoreError, VectorStoreIndex, request::VectorSearchRequest,
7    },
8};
9use serde::{Deserialize, Serialize};
10
11/// Represents a vector store implementation using Milvus - <https://milvus.io/> as the backend.
12pub struct MilvusVectorStore<M> {
13    /// Model used to generate embeddings for the vector store
14    model: M,
15    base_url: String,
16    client: reqwest::Client,
17    database_name: String,
18    collection_name: String,
19    token: Option<String>,
20}
21
22#[derive(Debug, Serialize, Deserialize)]
23pub struct CreateRecord {
24    document: String,
25    embedded_text: String,
26    embedding: Vec<f64>,
27}
28
29#[derive(Debug, Serialize, Deserialize)]
30#[serde(rename_all = "camelCase")]
31struct InsertRequest<'a> {
32    data: Vec<CreateRecord>,
33    collection_name: &'a str,
34    db_name: &'a str,
35}
36
37#[derive(Debug, Serialize, Deserialize)]
38#[serde(rename_all = "camelCase")]
39struct SearchRequest<'a> {
40    collection_name: &'a str,
41    db_name: &'a str,
42    data: Vec<f64>,
43    #[serde(skip_serializing_if = "String::is_empty")]
44    filter: String,
45    anns_field: &'a str,
46    limit: usize,
47    output_fields: Vec<&'a str>,
48}
49
50#[derive(Debug, Serialize, Deserialize)]
51#[serde(rename_all = "camelCase")]
52struct SearchResult<T> {
53    code: i64,
54    data: Vec<SearchResultData<T>>,
55}
56
57#[derive(Debug, Serialize, Deserialize)]
58#[serde(rename_all = "camelCase")]
59struct SearchResultData<T> {
60    id: i64,
61    distance: f64,
62    document: T,
63    embedded_text: String,
64}
65
66#[derive(Debug, Serialize, Deserialize)]
67#[serde(rename_all = "camelCase")]
68struct SearchResultOnlyId {
69    code: i64,
70    data: Vec<SearchResultDataOnlyId>,
71}
72
73#[derive(Debug, Serialize, Deserialize)]
74#[serde(rename_all = "camelCase")]
75struct SearchResultDataOnlyId {
76    id: i64,
77    distance: f64,
78}
79
80impl<M: EmbeddingModel> MilvusVectorStore<M> {
81    /// Creates a new instance of `MilvusVectorStore`.
82    ///
83    /// # Arguments
84    /// * `model` - Embedding model instance
85    /// * `base_url` - The URL of where your Milvus instance is located. Alternatively if you're using the Milvus offering provided by Zilliz, your cluster endpoint.
86    /// * `database_name` - The name of your database
87    /// * `collection_name` - The name of your collection
88    pub fn new(model: M, base_url: String, database_name: String, collection_name: String) -> Self {
89        Self {
90            model,
91            base_url,
92            client: reqwest::Client::new(),
93            database_name,
94            collection_name,
95            token: None,
96        }
97    }
98
99    /// Forms the auth token for Milvus from your username and password. Required if using a Milvus instance that requires authentication.
100    pub fn auth(mut self, username: String, password: String) -> Self {
101        let str = format!("{username}:{password}");
102        self.token = Some(str);
103
104        self
105    }
106
107    /// Creates a Milvus insertion request.
108    fn create_insert_request(&self, data: Vec<CreateRecord>) -> InsertRequest<'_> {
109        InsertRequest {
110            data,
111            collection_name: &self.collection_name,
112            db_name: &self.database_name,
113        }
114    }
115
116    /// Creates a Milvus semantic search request.
117    fn create_search_request(
118        &self,
119        data: Vec<f64>,
120        limit: usize,
121        threshold: Option<f64>,
122    ) -> SearchRequest<'_> {
123        let filter = if let Some(threshold) = threshold {
124            format!("distance >= {threshold}")
125        } else {
126            String::new()
127        };
128        SearchRequest {
129            collection_name: &self.collection_name,
130            db_name: &self.database_name,
131            data,
132            filter,
133            anns_field: "embedding",
134            limit,
135            output_fields: vec!["id", "distance", "document", "embeddedText"],
136        }
137    }
138
139    /// Creates a semantic search request, but only for IDs.
140    fn create_search_request_id_only(
141        &self,
142        data: Vec<f64>,
143        limit: usize,
144        threshold: Option<f64>,
145    ) -> SearchRequest<'_> {
146        let filter = if let Some(threshold) = threshold {
147            format!("distance >= {threshold}")
148        } else {
149            String::new()
150        };
151        SearchRequest {
152            collection_name: &self.collection_name,
153            db_name: &self.database_name,
154            data,
155            filter,
156            anns_field: "embedding",
157            limit,
158            output_fields: vec!["id", "distance"],
159        }
160    }
161}
162
163impl<Model> InsertDocuments for MilvusVectorStore<Model>
164where
165    Model: EmbeddingModel + Send + Sync,
166{
167    async fn insert_documents<Doc: Serialize + Embed + Send>(
168        &self,
169        documents: Vec<(Doc, OneOrMany<Embedding>)>,
170    ) -> Result<(), VectorStoreError> {
171        let url = format!(
172            "{base_url}/v2/vectordb/entities/insert",
173            base_url = self.base_url
174        );
175
176        let data = documents
177            .into_iter()
178            .map(|(document, embeddings)| {
179                let json_document: serde_json::Value = serde_json::to_value(&document)?;
180                let json_document_as_string = serde_json::to_string(&json_document)?;
181
182                let embeddings = embeddings
183                    .into_iter()
184                    .map(|embedding| {
185                        let embedded_text = embedding.document;
186                        let embedding: Vec<f64> = embedding.vec;
187
188                        CreateRecord {
189                            document: json_document_as_string.clone(),
190                            embedded_text,
191                            embedding,
192                        }
193                    })
194                    .collect::<Vec<CreateRecord>>();
195                Ok(embeddings)
196            })
197            .collect::<Result<Vec<Vec<CreateRecord>>, VectorStoreError>>()?
198            .into_iter()
199            .flatten()
200            .collect::<Vec<CreateRecord>>();
201
202        let mut client = self.client.post(url);
203        if let Some(ref token) = self.token {
204            client = client.header("Authentication", format!("Bearer {token}"));
205        }
206
207        let insert_request = self.create_insert_request(data);
208
209        let body = serde_json::to_string(&insert_request).unwrap();
210
211        let res = client.body(body).send().await?;
212
213        if res.status() != StatusCode::OK {
214            let status = res.status();
215            let text = res.text().await?;
216
217            return Err(VectorStoreError::ExternalAPIError(status, text));
218        }
219
220        Ok(())
221    }
222}
223
224impl<M: EmbeddingModel> VectorStoreIndex for MilvusVectorStore<M> {
225    /// Search for the top `n` nearest neighbors to the given query within the Milvus vector store.
226    /// Returns a vector of tuples containing the score, ID, and payload of the nearest neighbors.
227    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
228        &self,
229        req: VectorSearchRequest,
230    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
231        let embedding = self.model.embed_text(req.query()).await?;
232        let url = format!(
233            "{base_url}/v2/vectordb/entities/search",
234            base_url = self.base_url
235        );
236
237        let body =
238            self.create_search_request(embedding.vec, req.samples() as usize, req.threshold());
239
240        let mut client = self.client.post(url);
241        if let Some(ref token) = self.token {
242            client = client.header("Authentication", format!("Bearer {token}"));
243        }
244
245        let body = serde_json::to_string(&body)?;
246
247        let res = client.body(body).send().await?;
248
249        if res.status() != StatusCode::OK {
250            let status = res.status();
251            let text = res.text().await?;
252
253            return Err(VectorStoreError::ExternalAPIError(status, text));
254        }
255
256        let json: SearchResult<T> = res.json().await?;
257
258        let res = json
259            .data
260            .into_iter()
261            .map(|x| (x.distance, x.id.to_string(), x.document))
262            .collect();
263
264        Ok(res)
265    }
266
267    /// Search for the top `n` nearest neighbors to the given query within the Milvus vector store.
268    /// Returns a vector of tuples containing the score and ID of the nearest neighbors.
269    async fn top_n_ids(
270        &self,
271        req: VectorSearchRequest,
272    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
273        let embedding = self.model.embed_text(req.query()).await?;
274        let url = format!(
275            "{base_url}/v2/vectordb/entities/search",
276            base_url = self.base_url
277        );
278
279        let body = self.create_search_request_id_only(
280            embedding.vec,
281            req.samples() as usize,
282            req.threshold(),
283        );
284
285        let mut client = self.client.post(url);
286        if let Some(ref token) = self.token {
287            client = client.header("Authentication", format!("Bearer {token}"));
288        }
289
290        let body = serde_json::to_string(&body)?;
291
292        let res = client.body(body).send().await?;
293
294        if res.status() != StatusCode::OK {
295            let status = res.status();
296            let text = res.text().await?;
297
298            return Err(VectorStoreError::ExternalAPIError(status, text));
299        }
300
301        let json: SearchResultOnlyId = res.json().await?;
302
303        let res = json
304            .data
305            .into_iter()
306            .map(|x| (x.distance, x.id.to_string()))
307            .collect();
308
309        Ok(res)
310    }
311}