1use async_trait::async_trait;
2use errors::MilvusError;
3use llm_chain::{
4 schema::Document,
5 traits::{Embeddings, VectorStore},
6};
7use milvus::{
8 client::Client as MilvusClient,
9 collection::SearchOption,
10 data::FieldColumn,
11 proto::{milvus::MutationResult, schema::i_ds::IdField},
12 value::ValueVec,
13};
14use serde::{de::DeserializeOwned, Serialize};
15use std::{collections::HashMap, marker::PhantomData, sync::Arc};
16
17pub mod errors;
/// JSON key under which a document's text is stored inside the payload when
/// the caller does not supply a custom key.
const DEFAULT_CONTENT_PAYLOAD_KEY: &str = "page_content";
/// JSON key under which a document's serialized metadata is stored inside the
/// payload when the caller does not supply a custom key.
const DEFAULT_METADATA_PAYLOAD_KEY: &str = "metadata";
20
21pub struct Milvus<E, M>
22where
23 E: Embeddings,
24 M: Serialize + DeserializeOwned + Send + Sync,
25{
26 client: Arc<MilvusClient>,
27 collection_name: String,
28 vector_field_name: String,
29 payload_field_name: Option<String>,
30 content_payload_key: String,
31 metadata_payload_key: String,
32 embeddings: E,
33 _marker: PhantomData<M>,
34}
35
36impl<E, M> Milvus<E, M>
37where
38 E: Embeddings,
39 M: Serialize + DeserializeOwned + Send + Sync,
40{
41 pub fn new(
42 client: Arc<MilvusClient>,
43 collection_name: String,
44 vector_field_name: String,
45 payload_field_name: Option<String>,
46 content_payload_key: Option<String>,
47 metadata_payload_key: Option<String>,
48 embeddings: E,
49 ) -> Self {
50 Self {
51 client,
52 collection_name,
53 vector_field_name,
54 payload_field_name,
55 embeddings,
56 content_payload_key: content_payload_key
57 .unwrap_or(DEFAULT_CONTENT_PAYLOAD_KEY.to_string()),
58 metadata_payload_key: metadata_payload_key
59 .unwrap_or(DEFAULT_METADATA_PAYLOAD_KEY.to_string()),
60 _marker: Default::default(),
61 }
62 }
63
64 fn ids_from_milvus_results(
65 &self,
66 res: MutationResult,
67 ) -> Result<Vec<String>, MilvusError<E::Error>> {
68 let ids = res.i_ds.ok_or(errors::MilvusError::InsertionError)?;
69 match ids.id_field {
70 Some(IdField::IntId(arr)) => Ok(arr
71 .data
72 .into_iter()
73 .map(|x| x.to_string())
74 .collect::<Vec<String>>()),
75 Some(IdField::StrId(string_arr)) => Ok(string_arr.data),
76 None => Err(errors::MilvusError::InsertionError),
77 }
78 }
79}
80
81#[async_trait]
82impl<E, M> VectorStore<E, M> for Milvus<E, M>
83where
84 E: Embeddings + Send + Sync,
85 M: Send + Sync + Serialize + DeserializeOwned,
86{
87 type Error = errors::MilvusError<E::Error>;
88
89 async fn add_texts(&self, texts: Vec<String>) -> Result<Vec<String>, Self::Error> {
90 let embedding_vecs = self.embeddings.embed_texts(texts.clone()).await?;
91 let collection = self
92 .client
93 .get_collection(&self.collection_name)
94 .await
95 .map_err(errors::MilvusError::Client)?;
96
97 let embed_column = FieldColumn::new(
98 collection
99 .schema()
100 .get_field(&self.vector_field_name)
101 .unwrap(),
102 embedding_vecs.into_iter().flatten().collect::<Vec<_>>(),
103 );
104
105 let milvus_results = collection.insert(vec![embed_column], None).await.unwrap();
106 collection
107 .flush()
108 .await
109 .map_err(|_| errors::MilvusError::InsertionError)?;
110 self.ids_from_milvus_results(milvus_results)
111 }
112
113 async fn add_documents(&self, documents: Vec<Document<M>>) -> Result<Vec<String>, Self::Error> {
114 let collection = self
115 .client
116 .get_collection(&self.collection_name)
117 .await
118 .map_err(errors::MilvusError::Client)?;
119
120 let texts = documents.iter().map(|d| d.page_content.clone()).collect();
122 let embedding_vecs = self.embeddings.embed_texts(texts).await?;
123
124 let embed_column = FieldColumn::new(
126 collection
127 .schema()
128 .get_field(&self.vector_field_name)
129 .unwrap(),
130 embedding_vecs.into_iter().flatten().collect::<Vec<_>>(),
131 );
132 match &self.payload_field_name {
136 Some(payload_field_name) => {
137 let payload_column_name = collection
138 .schema()
139 .get_field(&payload_field_name)
140 .ok_or(errors::MilvusError::InvalidColumnName)?;
141 let payloads: Vec<String> = documents
142 .into_iter()
143 .map(|document| {
144 let mut payload: HashMap<String, Option<String>> = HashMap::new();
146
147 if let Some(metadata) = document.metadata {
148 let val =
149 serde_json::to_string(&metadata).map_err(Self::Error::Serde)?;
150
151 payload.insert(self.metadata_payload_key.clone(), val.into());
152 } else {
153 payload.insert(self.metadata_payload_key.clone(), None);
154 }
155 payload.insert(
156 self.content_payload_key.clone(),
157 document.page_content.clone().into(),
158 );
159 let payload =
160 serde_json::to_string(&payload).map_err(Self::Error::Serde)?;
161 Ok(payload)
162 })
163 .collect::<Result<Vec<_>, errors::MilvusError<_>>>()?;
164 let payload_column = FieldColumn::new(payload_column_name, payloads);
165 let milvus_results = collection
166 .insert(vec![embed_column, payload_column], None)
167 .await
168 .unwrap();
169
170 collection
171 .flush()
172 .await
173 .map_err(|_| errors::MilvusError::InsertionError)?;
174
175 self.ids_from_milvus_results(milvus_results)
176 }
177 None => {
178 let milvus_results = collection.insert(vec![embed_column], None).await.unwrap();
179 self.ids_from_milvus_results(milvus_results)
180 }
181 }
182 }
183
184 async fn similarity_search(
185 &self,
186 query: String,
187 limit: u32,
188 ) -> Result<Vec<Document<M>>, Self::Error> {
189 let collection = self
190 .client
191 .get_collection(&self.collection_name)
192 .await
193 .map_err(errors::MilvusError::Client)?;
194
195 let embedded_query = self.embeddings.embed_query(query).await?;
196
197 let indexes = collection
198 .describe_index(self.vector_field_name.clone())
199 .await
200 .unwrap();
201
202 let index = indexes
204 .first()
205 .ok_or(errors::MilvusError::EmptyIndexError)?;
206
207 match &self.payload_field_name {
208 Some(out_field) => {
209 let results = collection
210 .search(
211 vec![embedded_query.into()],
212 self.vector_field_name.clone(),
213 limit as i32,
214 index.params().metric_type(),
215 vec![out_field],
216 &SearchOption::default(),
217 )
218 .await
219 .map_err(Self::Error::Client)?;
220
221 let mut docs: Vec<Document<M>> = Vec::new();
223 for res in results {
224 for field in res.field.iter().filter(|f| &f.name == out_field) {
225 match &field.value {
226 ValueVec::String(val) => {
227 let payload: HashMap<String, Option<String>> =
228 serde_json::from_str(&val[0])
229 .map_err(errors::MilvusError::Serde)?;
230
231 let _metadata: Option<String> = payload .get(&self.metadata_payload_key)
234 .unwrap()
235 .clone()
236 .into();
237
238 let page_content = payload
239 .get(&self.content_payload_key)
240 .unwrap()
241 .clone()
242 .unwrap_or("".to_string());
243
244 docs.push(Document {
245 page_content: page_content,
246 metadata: None,
247 });
248 }
249 _ => return Err(errors::MilvusError::QueryError),
250 }
251 }
252 }
253 Ok(docs)
254 }
255 None => return Err(errors::MilvusError::QueryError),
256 }
257 }
258}