ai_chain_qdrant/
lib.rs

1use std::{collections::HashMap, marker::PhantomData, sync::Arc};
2
3use async_trait::async_trait;
4use qdrant_client::{
5    prelude::QdrantClient,
6    qdrant::{
7        value::Kind, with_payload_selector::SelectorOptions, Filter, PayloadIncludeSelector,
8        PointId, PointStruct, ScoredPoint, SearchPoints, Value, Vectors, WithPayloadSelector,
9    },
10};
11use thiserror::Error;
12use uuid::Uuid;
13
14use ai_chain::{
15    schema::Document,
16    traits::{Embeddings, EmbeddingsError, VectorStore, VectorStoreError},
17};
18
19use serde::{de::DeserializeOwned, Serialize};
20
/// Payload key holding a point's page content when the caller does not choose one.
const DEFAULT_CONTENT_PAYLOAD_KEY: &str = "page_content";
/// Payload key holding a point's serialized metadata when the caller does not choose one.
const DEFAULT_METADATA_PAYLOAD_KEY: &str = "metadata";
23
24pub struct Qdrant<E, M>
25where
26    E: Embeddings,
27    M: Serialize + DeserializeOwned + Send + Sync,
28{
29    client: Arc<QdrantClient>,
30    collection_name: String,
31    embeddings: E,
32    content_payload_key: String,
33    metadata_payload_key: String,
34    filter: Option<Filter>,
35    _marker: PhantomData<M>,
36}
37
38impl<E, M> Qdrant<E, M>
39where
40    E: Embeddings,
41    M: Send + Sync + Serialize + DeserializeOwned,
42{
43    pub fn new(
44        client: Arc<QdrantClient>,
45        collection_name: String,
46        embeddings: E,
47        content_payload_key: Option<String>,
48        metadata_payload_key: Option<String>,
49        filter: Option<Filter>,
50    ) -> Self {
51        Qdrant {
52            client,
53            collection_name,
54            embeddings,
55            content_payload_key: content_payload_key
56                .unwrap_or(DEFAULT_CONTENT_PAYLOAD_KEY.to_string()),
57            metadata_payload_key: metadata_payload_key
58                .unwrap_or(DEFAULT_METADATA_PAYLOAD_KEY.to_string()),
59            filter,
60            _marker: Default::default(),
61        }
62    }
63
64    fn try_document_from_scored_point(
65        &self,
66        scored_point: ScoredPoint,
67    ) -> Result<Document<M>, QdrantError<E::Error>> {
68        let metadata = scored_point.payload.get(&self.metadata_payload_key);
69        let metadata: Option<M> = match metadata.cloned() {
70            Some(val) => {
71                let j = serde_json::to_value(val).map_err(QdrantError::Serde)?;
72                Some(serde_json::from_value(j).map_err(QdrantError::Serde)?)
73            }
74            None => None,
75        };
76        let page_content = scored_point
77            .payload
78            .get(&self.content_payload_key)
79            .ok_or::<QdrantError<E::Error>>(
80                ConversionError::PayloadKeyNotFound {
81                    payload_key: self.content_payload_key.clone(),
82                    point_id: scored_point.id.clone(),
83                }
84                .into(),
85            )?
86            .kind
87            .clone()
88            .ok_or::<QdrantError<E::Error>>(
89                ConversionError::InvalidPageContent {
90                    point_id: scored_point.id.clone(),
91                }
92                .into(),
93            )?;
94        if let Kind::StringValue(page_content) = page_content {
95            Ok(Document {
96                page_content,
97                metadata,
98            })
99        } else {
100            Err(ConversionError::InvalidPageContent {
101                point_id: scored_point.id,
102            }
103            .into())
104        }
105    }
106}
107
108#[derive(Debug, Error)]
109pub enum ConversionError {
110    #[error("Qdrant: Payload key {payload_key:?} not found in Scored Point with ID: {point_id:?}")]
111    PayloadKeyNotFound {
112        payload_key: String,
113        point_id: Option<PointId>,
114    },
115    #[error("Page content was not a valid string. Point ID: {point_id:?}")]
116    InvalidPageContent { point_id: Option<PointId> },
117    #[error("Could not convert metadata. Point ID: {point_id:?}")]
118    InvalidMetadata { point_id: Option<PointId> },
119}
120
121#[derive(Debug, Error)]
122pub enum QdrantError<E>
123where
124    E: std::fmt::Debug + std::error::Error + EmbeddingsError,
125{
126    #[error(transparent)]
127    Embeddings(#[from] E),
128    #[error("Qdrant Client Error")]
129    Client(anyhow::Error),
130    #[error(transparent)]
131    ConversionError(#[from] ConversionError),
132    #[error("Serde Error")]
133    Serde(serde_json::Error),
134}
135
136impl<E> VectorStoreError for QdrantError<E> where
137    E: std::fmt::Debug + std::error::Error + EmbeddingsError
138{
139}
140
141#[async_trait]
142impl<E, M> VectorStore<E, M> for Qdrant<E, M>
143where
144    E: Embeddings + Send + Sync,
145    M: Send + Sync + Serialize + DeserializeOwned,
146{
147    type Error = QdrantError<E::Error>;
148
149    async fn add_texts(&self, texts: Vec<String>) -> Result<Vec<String>, Self::Error> {
150        let embedding_vecs = self.embeddings.embed_texts(texts.clone()).await?;
151
152        let ids = (0..embedding_vecs.len())
153            .map(|_| Uuid::new_v4().to_string())
154            .collect::<Vec<String>>();
155        let points = embedding_vecs
156            .into_iter()
157            .zip(texts.into_iter())
158            .zip(ids.iter())
159            .map(|((vec, text), uuid)| {
160                let mut payload = HashMap::new();
161                payload.insert(self.content_payload_key.clone(), text.into());
162                PointStruct {
163                    id: Some(uuid.to_string().into()),
164                    payload,
165                    vectors: Some(Vectors::from(vec)),
166                }
167            })
168            .collect();
169        self.client
170            .upsert_points(&self.collection_name, None, points, None)
171            .await
172            .map_err(QdrantError::Client)?;
173        Ok(ids)
174    }
175
176    async fn add_documents(&self, documents: Vec<Document<M>>) -> Result<Vec<String>, Self::Error> {
177        let texts = documents.iter().map(|d| d.page_content.clone()).collect();
178        let embedding_vecs = self.embeddings.embed_texts(texts).await?;
179
180        let ids = (0..embedding_vecs.len())
181            .map(|_| Uuid::new_v4().to_string())
182            .collect::<Vec<String>>();
183
184        let points: Result<Vec<PointStruct>, Self::Error> = embedding_vecs
185            .into_iter()
186            .zip(documents.into_iter())
187            .zip(ids.iter())
188            .map(|((vec, document), uuid)| {
189                let mut payload: HashMap<String, Value> = HashMap::new();
190
191                if let Some(metadata) = document.metadata {
192                    let val = serde_json::to_value(metadata).map_err(Self::Error::Serde)?;
193                    payload.insert(self.metadata_payload_key.clone(), val.into());
194                } else {
195                    payload.insert(self.metadata_payload_key.clone(), Value { kind: None });
196                }
197                payload.insert(
198                    self.content_payload_key.clone(),
199                    document.page_content.clone().into(),
200                );
201                Ok(PointStruct {
202                    id: Some(uuid.to_string().into()),
203                    payload,
204                    vectors: Some(Vectors::from(vec)),
205                })
206            })
207            .collect();
208
209        let points = points?;
210
211        self.client
212            .upsert_points(self.collection_name.clone(), None, points, None)
213            .await
214            .map_err(QdrantError::Client)?;
215
216        Ok(ids)
217    }
218
219    async fn similarity_search(
220        &self,
221        query: String,
222        limit: u32,
223    ) -> Result<Vec<Document<M>>, Self::Error> {
224        let embedded_query = self.embeddings.embed_query(query).await?;
225        let res = self
226            .client
227            .search_points(&SearchPoints {
228                timeout: None,
229                shard_key_selector: None,
230                sparse_indices: None,
231                collection_name: self.collection_name.clone(),
232                vector: embedded_query,
233                filter: self.filter.clone(),
234                limit: limit.into(),
235                with_payload: Some(WithPayloadSelector {
236                    selector_options: Some(SelectorOptions::Include(PayloadIncludeSelector {
237                        fields: vec![
238                            self.content_payload_key.clone(),
239                            self.metadata_payload_key.clone(),
240                        ],
241                    })),
242                }),
243                params: None,
244                score_threshold: None,
245                offset: None,
246                vector_name: None,
247                with_vectors: None,
248                read_consistency: None,
249            })
250            .await
251            .map_err(QdrantError::Client)?;
252
253        let mut out = vec![];
254        for r in res.result.into_iter() {
255            let val = self.try_document_from_scored_point(r)?;
256            out.push(val);
257        }
258        Ok(out)
259    }
260}