1use std::{collections::HashMap, marker::PhantomData, sync::Arc};
2
3use async_trait::async_trait;
4use qdrant_client::{
5 prelude::QdrantClient,
6 qdrant::{
7 value::Kind, with_payload_selector::SelectorOptions, Filter, PayloadIncludeSelector,
8 PointId, PointStruct, ScoredPoint, SearchPoints, Value, Vectors, WithPayloadSelector,
9 },
10};
11use thiserror::Error;
12use uuid::Uuid;
13
14use ai_chain::{
15 schema::Document,
16 traits::{Embeddings, EmbeddingsError, VectorStore, VectorStoreError},
17};
18
19use serde::{de::DeserializeOwned, Serialize};
20
/// Payload key used for a document's text when `Qdrant::new` is given no
/// explicit `content_payload_key`.
const DEFAULT_CONTENT_PAYLOAD_KEY: &str = "page_content";
/// Payload key used for a document's metadata when `Qdrant::new` is given no
/// explicit `metadata_payload_key`.
const DEFAULT_METADATA_PAYLOAD_KEY: &str = "metadata";
23
24pub struct Qdrant<E, M>
25where
26 E: Embeddings,
27 M: Serialize + DeserializeOwned + Send + Sync,
28{
29 client: Arc<QdrantClient>,
30 collection_name: String,
31 embeddings: E,
32 content_payload_key: String,
33 metadata_payload_key: String,
34 filter: Option<Filter>,
35 _marker: PhantomData<M>,
36}
37
38impl<E, M> Qdrant<E, M>
39where
40 E: Embeddings,
41 M: Send + Sync + Serialize + DeserializeOwned,
42{
43 pub fn new(
44 client: Arc<QdrantClient>,
45 collection_name: String,
46 embeddings: E,
47 content_payload_key: Option<String>,
48 metadata_payload_key: Option<String>,
49 filter: Option<Filter>,
50 ) -> Self {
51 Qdrant {
52 client,
53 collection_name,
54 embeddings,
55 content_payload_key: content_payload_key
56 .unwrap_or(DEFAULT_CONTENT_PAYLOAD_KEY.to_string()),
57 metadata_payload_key: metadata_payload_key
58 .unwrap_or(DEFAULT_METADATA_PAYLOAD_KEY.to_string()),
59 filter,
60 _marker: Default::default(),
61 }
62 }
63
64 fn try_document_from_scored_point(
65 &self,
66 scored_point: ScoredPoint,
67 ) -> Result<Document<M>, QdrantError<E::Error>> {
68 let metadata = scored_point.payload.get(&self.metadata_payload_key);
69 let metadata: Option<M> = match metadata.cloned() {
70 Some(val) => {
71 let j = serde_json::to_value(val).map_err(QdrantError::Serde)?;
72 Some(serde_json::from_value(j).map_err(QdrantError::Serde)?)
73 }
74 None => None,
75 };
76 let page_content = scored_point
77 .payload
78 .get(&self.content_payload_key)
79 .ok_or::<QdrantError<E::Error>>(
80 ConversionError::PayloadKeyNotFound {
81 payload_key: self.content_payload_key.clone(),
82 point_id: scored_point.id.clone(),
83 }
84 .into(),
85 )?
86 .kind
87 .clone()
88 .ok_or::<QdrantError<E::Error>>(
89 ConversionError::InvalidPageContent {
90 point_id: scored_point.id.clone(),
91 }
92 .into(),
93 )?;
94 if let Kind::StringValue(page_content) = page_content {
95 Ok(Document {
96 page_content,
97 metadata,
98 })
99 } else {
100 Err(ConversionError::InvalidPageContent {
101 point_id: scored_point.id,
102 }
103 .into())
104 }
105 }
106}
107
/// Errors raised while turning a Qdrant scored point into a `Document`.
#[derive(Debug, Error)]
pub enum ConversionError {
    /// The point's payload has no entry under the requested payload key.
    #[error("Qdrant: Payload key {payload_key:?} not found in Scored Point with ID: {point_id:?}")]
    PayloadKeyNotFound {
        /// The payload key that was looked up.
        payload_key: String,
        /// ID of the offending point, if the point carried one.
        point_id: Option<PointId>,
    },
    /// The content payload entry exists but does not hold a string value.
    #[error("Page content was not a valid string. Point ID: {point_id:?}")]
    InvalidPageContent { point_id: Option<PointId> },
    /// The metadata payload entry could not be converted.
    ///
    /// NOTE(review): this variant is not constructed anywhere in the code
    /// visible here; metadata failures currently surface as `QdrantError::Serde`.
    #[error("Could not convert metadata. Point ID: {point_id:?}")]
    InvalidMetadata { point_id: Option<PointId> },
}
120
121#[derive(Debug, Error)]
122pub enum QdrantError<E>
123where
124 E: std::fmt::Debug + std::error::Error + EmbeddingsError,
125{
126 #[error(transparent)]
127 Embeddings(#[from] E),
128 #[error("Qdrant Client Error")]
129 Client(anyhow::Error),
130 #[error(transparent)]
131 ConversionError(#[from] ConversionError),
132 #[error("Serde Error")]
133 Serde(serde_json::Error),
134}
135
136impl<E> VectorStoreError for QdrantError<E> where
137 E: std::fmt::Debug + std::error::Error + EmbeddingsError
138{
139}
140
141#[async_trait]
142impl<E, M> VectorStore<E, M> for Qdrant<E, M>
143where
144 E: Embeddings + Send + Sync,
145 M: Send + Sync + Serialize + DeserializeOwned,
146{
147 type Error = QdrantError<E::Error>;
148
149 async fn add_texts(&self, texts: Vec<String>) -> Result<Vec<String>, Self::Error> {
150 let embedding_vecs = self.embeddings.embed_texts(texts.clone()).await?;
151
152 let ids = (0..embedding_vecs.len())
153 .map(|_| Uuid::new_v4().to_string())
154 .collect::<Vec<String>>();
155 let points = embedding_vecs
156 .into_iter()
157 .zip(texts.into_iter())
158 .zip(ids.iter())
159 .map(|((vec, text), uuid)| {
160 let mut payload = HashMap::new();
161 payload.insert(self.content_payload_key.clone(), text.into());
162 PointStruct {
163 id: Some(uuid.to_string().into()),
164 payload,
165 vectors: Some(Vectors::from(vec)),
166 }
167 })
168 .collect();
169 self.client
170 .upsert_points(&self.collection_name, None, points, None)
171 .await
172 .map_err(QdrantError::Client)?;
173 Ok(ids)
174 }
175
176 async fn add_documents(&self, documents: Vec<Document<M>>) -> Result<Vec<String>, Self::Error> {
177 let texts = documents.iter().map(|d| d.page_content.clone()).collect();
178 let embedding_vecs = self.embeddings.embed_texts(texts).await?;
179
180 let ids = (0..embedding_vecs.len())
181 .map(|_| Uuid::new_v4().to_string())
182 .collect::<Vec<String>>();
183
184 let points: Result<Vec<PointStruct>, Self::Error> = embedding_vecs
185 .into_iter()
186 .zip(documents.into_iter())
187 .zip(ids.iter())
188 .map(|((vec, document), uuid)| {
189 let mut payload: HashMap<String, Value> = HashMap::new();
190
191 if let Some(metadata) = document.metadata {
192 let val = serde_json::to_value(metadata).map_err(Self::Error::Serde)?;
193 payload.insert(self.metadata_payload_key.clone(), val.into());
194 } else {
195 payload.insert(self.metadata_payload_key.clone(), Value { kind: None });
196 }
197 payload.insert(
198 self.content_payload_key.clone(),
199 document.page_content.clone().into(),
200 );
201 Ok(PointStruct {
202 id: Some(uuid.to_string().into()),
203 payload,
204 vectors: Some(Vectors::from(vec)),
205 })
206 })
207 .collect();
208
209 let points = points?;
210
211 self.client
212 .upsert_points(self.collection_name.clone(), None, points, None)
213 .await
214 .map_err(QdrantError::Client)?;
215
216 Ok(ids)
217 }
218
219 async fn similarity_search(
220 &self,
221 query: String,
222 limit: u32,
223 ) -> Result<Vec<Document<M>>, Self::Error> {
224 let embedded_query = self.embeddings.embed_query(query).await?;
225 let res = self
226 .client
227 .search_points(&SearchPoints {
228 timeout: None,
229 shard_key_selector: None,
230 sparse_indices: None,
231 collection_name: self.collection_name.clone(),
232 vector: embedded_query,
233 filter: self.filter.clone(),
234 limit: limit.into(),
235 with_payload: Some(WithPayloadSelector {
236 selector_options: Some(SelectorOptions::Include(PayloadIncludeSelector {
237 fields: vec![
238 self.content_payload_key.clone(),
239 self.metadata_payload_key.clone(),
240 ],
241 })),
242 }),
243 params: None,
244 score_threshold: None,
245 offset: None,
246 vector_name: None,
247 with_vectors: None,
248 read_consistency: None,
249 })
250 .await
251 .map_err(QdrantError::Client)?;
252
253 let mut out = vec![];
254 for r in res.result.into_iter() {
255 let val = self.try_document_from_scored_point(r)?;
256 out.push(val);
257 }
258 Ok(out)
259 }
260}