llm_chain_milvus/
lib.rs

1use async_trait::async_trait;
2use errors::MilvusError;
3use llm_chain::{
4    schema::Document,
5    traits::{Embeddings, VectorStore},
6};
7use milvus::{
8    client::Client as MilvusClient,
9    collection::SearchOption,
10    data::FieldColumn,
11    proto::{milvus::MutationResult, schema::i_ds::IdField},
12    value::ValueVec,
13};
14use serde::{de::DeserializeOwned, Serialize};
15use std::{collections::HashMap, marker::PhantomData, sync::Arc};
16
17pub mod errors;
/// Key under which a document's page text is stored inside the JSON payload
/// when the caller does not choose one.
const DEFAULT_CONTENT_PAYLOAD_KEY: &str = "page_content";
/// Key under which a document's serialized metadata is stored inside the JSON
/// payload when the caller does not choose one.
const DEFAULT_METADATA_PAYLOAD_KEY: &str = "metadata";
20
21pub struct Milvus<E, M>
22where
23    E: Embeddings,
24    M: Serialize + DeserializeOwned + Send + Sync,
25{
26    client: Arc<MilvusClient>,
27    collection_name: String,
28    vector_field_name: String,
29    payload_field_name: Option<String>,
30    content_payload_key: String,
31    metadata_payload_key: String,
32    embeddings: E,
33    _marker: PhantomData<M>,
34}
35
36impl<E, M> Milvus<E, M>
37where
38    E: Embeddings,
39    M: Serialize + DeserializeOwned + Send + Sync,
40{
41    pub fn new(
42        client: Arc<MilvusClient>,
43        collection_name: String,
44        vector_field_name: String,
45        payload_field_name: Option<String>,
46        content_payload_key: Option<String>,
47        metadata_payload_key: Option<String>,
48        embeddings: E,
49    ) -> Self {
50        Self {
51            client,
52            collection_name,
53            vector_field_name,
54            payload_field_name,
55            embeddings,
56            content_payload_key: content_payload_key
57                .unwrap_or(DEFAULT_CONTENT_PAYLOAD_KEY.to_string()),
58            metadata_payload_key: metadata_payload_key
59                .unwrap_or(DEFAULT_METADATA_PAYLOAD_KEY.to_string()),
60            _marker: Default::default(),
61        }
62    }
63
64    fn ids_from_milvus_results(
65        &self,
66        res: MutationResult,
67    ) -> Result<Vec<String>, MilvusError<E::Error>> {
68        let ids = res.i_ds.ok_or(errors::MilvusError::InsertionError)?;
69        match ids.id_field {
70            Some(IdField::IntId(arr)) => Ok(arr
71                .data
72                .into_iter()
73                .map(|x| x.to_string())
74                .collect::<Vec<String>>()),
75            Some(IdField::StrId(string_arr)) => Ok(string_arr.data),
76            None => Err(errors::MilvusError::InsertionError),
77        }
78    }
79}
80
81#[async_trait]
82impl<E, M> VectorStore<E, M> for Milvus<E, M>
83where
84    E: Embeddings + Send + Sync,
85    M: Send + Sync + Serialize + DeserializeOwned,
86{
87    type Error = errors::MilvusError<E::Error>;
88
89    async fn add_texts(&self, texts: Vec<String>) -> Result<Vec<String>, Self::Error> {
90        let embedding_vecs = self.embeddings.embed_texts(texts.clone()).await?;
91        let collection = self
92            .client
93            .get_collection(&self.collection_name)
94            .await
95            .map_err(errors::MilvusError::Client)?;
96
97        let embed_column = FieldColumn::new(
98            collection
99                .schema()
100                .get_field(&self.vector_field_name)
101                .unwrap(),
102            embedding_vecs.into_iter().flatten().collect::<Vec<_>>(),
103        );
104
105        let milvus_results = collection.insert(vec![embed_column], None).await.unwrap();
106        collection
107            .flush()
108            .await
109            .map_err(|_| errors::MilvusError::InsertionError)?;
110        self.ids_from_milvus_results(milvus_results)
111    }
112
113    async fn add_documents(&self, documents: Vec<Document<M>>) -> Result<Vec<String>, Self::Error> {
114        let collection = self
115            .client
116            .get_collection(&self.collection_name)
117            .await
118            .map_err(errors::MilvusError::Client)?;
119
120        // Embedding documents' text
121        let texts = documents.iter().map(|d| d.page_content.clone()).collect();
122        let embedding_vecs = self.embeddings.embed_texts(texts).await?;
123
124        // Construct Milvus vector column
125        let embed_column = FieldColumn::new(
126            collection
127                .schema()
128                .get_field(&self.vector_field_name)
129                .unwrap(),
130            embedding_vecs.into_iter().flatten().collect::<Vec<_>>(),
131        );
132        // Inserting document in Milvus collection
133        // Note: To insert document metadata we need to be sure that
134        // the collection has a column `Datatype.JSON`
135        match &self.payload_field_name {
136            Some(payload_field_name) => {
137                let payload_column_name = collection
138                    .schema()
139                    .get_field(&payload_field_name)
140                    .ok_or(errors::MilvusError::InvalidColumnName)?;
141                let payloads: Vec<String> = documents
142                    .into_iter()
143                    .map(|document| {
144                        // Construct the
145                        let mut payload: HashMap<String, Option<String>> = HashMap::new();
146
147                        if let Some(metadata) = document.metadata {
148                            let val =
149                                serde_json::to_string(&metadata).map_err(Self::Error::Serde)?;
150
151                            payload.insert(self.metadata_payload_key.clone(), val.into());
152                        } else {
153                            payload.insert(self.metadata_payload_key.clone(), None);
154                        }
155                        payload.insert(
156                            self.content_payload_key.clone(),
157                            document.page_content.clone().into(),
158                        );
159                        let payload =
160                            serde_json::to_string(&payload).map_err(Self::Error::Serde)?;
161                        Ok(payload)
162                    })
163                    .collect::<Result<Vec<_>, errors::MilvusError<_>>>()?;
164                let payload_column = FieldColumn::new(payload_column_name, payloads);
165                let milvus_results = collection
166                    .insert(vec![embed_column, payload_column], None)
167                    .await
168                    .unwrap();
169
170                collection
171                    .flush()
172                    .await
173                    .map_err(|_| errors::MilvusError::InsertionError)?;
174
175                self.ids_from_milvus_results(milvus_results)
176            }
177            None => {
178                let milvus_results = collection.insert(vec![embed_column], None).await.unwrap();
179                self.ids_from_milvus_results(milvus_results)
180            }
181        }
182    }
183
184    async fn similarity_search(
185        &self,
186        query: String,
187        limit: u32,
188    ) -> Result<Vec<Document<M>>, Self::Error> {
189        let collection = self
190            .client
191            .get_collection(&self.collection_name)
192            .await
193            .map_err(errors::MilvusError::Client)?;
194
195        let embedded_query = self.embeddings.embed_query(query).await?;
196
197        let indexes = collection
198            .describe_index(self.vector_field_name.clone())
199            .await
200            .unwrap();
201
202        // Take the first index for now
203        let index = indexes
204            .first()
205            .ok_or(errors::MilvusError::EmptyIndexError)?;
206
207        match &self.payload_field_name {
208            Some(out_field) => {
209                let results = collection
210                    .search(
211                        vec![embedded_query.into()],
212                        self.vector_field_name.clone(),
213                        limit as i32,
214                        index.params().metric_type(),
215                        vec![out_field],
216                        &SearchOption::default(),
217                    )
218                    .await
219                    .map_err(Self::Error::Client)?;
220
221                // Convert Results to docs
222                let mut docs: Vec<Document<M>> = Vec::new();
223                for res in results {
224                    for field in res.field.iter().filter(|f| &f.name == out_field) {
225                        match &field.value {
226                            ValueVec::String(val) => {
227                                let payload: HashMap<String, Option<String>> =
228                                    serde_json::from_str(&val[0])
229                                        .map_err(errors::MilvusError::Serde)?;
230
231                                let _metadata: Option<String> = payload // XXX: temp fix since the
232                                                                       // var is not used rn
233                                    .get(&self.metadata_payload_key)
234                                    .unwrap()
235                                    .clone()
236                                    .into();
237
238                                let page_content = payload
239                                    .get(&self.content_payload_key)
240                                    .unwrap()
241                                    .clone()
242                                    .unwrap_or("".to_string());
243
244                                docs.push(Document {
245                                    page_content: page_content,
246                                    metadata: None,
247                                });
248                            }
249                            _ => return Err(errors::MilvusError::QueryError),
250                        }
251                    }
252                }
253                Ok(docs)
254            }
255            None => return Err(errors::MilvusError::QueryError),
256        }
257    }
258}