1mod filter;
2
3use filter::*;
4use qdrant_client::{
5 Payload, Qdrant,
6 qdrant::{
7 Filter, PointId, PointStruct, Query, QueryPoints, UpsertPointsBuilder,
8 point_id::PointIdOptions,
9 },
10};
11use rig::{
12 Embed, OneOrMany,
13 embeddings::{Embedding, EmbeddingModel},
14 vector_store::{
15 InsertDocuments, VectorStoreError, VectorStoreIndex, request::VectorSearchRequest,
16 },
17};
18use serde::{Deserialize, Serialize};
19use uuid::Uuid;
20
21pub struct QdrantVectorStore<M: EmbeddingModel> {
23 model: M,
25 client: Qdrant,
27 query_params: QueryPoints,
29}
30
31impl<M> QdrantVectorStore<M>
32where
33 M: EmbeddingModel,
34{
35 pub fn new(client: Qdrant, model: M, query_params: QueryPoints) -> Self {
43 Self {
44 client,
45 model,
46 query_params,
47 }
48 }
49
50 pub fn client(&self) -> &Qdrant {
51 &self.client
52 }
53
54 async fn generate_query_vector(&self, query: &str) -> Result<Vec<f32>, VectorStoreError> {
56 let embedding = self.model.embed_text(query).await?;
57 Ok(embedding.vec.iter().map(|&x| x as f32).collect())
58 }
59
60 fn prepare_query_params(
62 &self,
63 query: Option<Query>,
64 limit: usize,
65 threshold: Option<f64>,
66 filter: Option<Filter>,
67 ) -> QueryPoints {
68 let mut params = self.query_params.clone();
69 params.query = query;
70 params.limit = Some(limit as u64);
71 params.score_threshold = threshold.map(|x| x as f32);
72 params.filter = filter;
73 params
74 }
75}
76
77impl<Model> InsertDocuments for QdrantVectorStore<Model>
78where
79 Model: EmbeddingModel + Send + Sync,
80{
81 async fn insert_documents<Doc: Serialize + Embed + Send>(
82 &self,
83 documents: Vec<(Doc, OneOrMany<Embedding>)>,
84 ) -> Result<(), VectorStoreError> {
85 let collection_name = self.query_params.collection_name.clone();
86
87 for (document, embeddings) in documents {
88 let json_document = serde_json::to_value(&document)?;
89 let doc_as_payload = Payload::try_from(json_document)
90 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
91
92 let embeddings_as_point_structs = embeddings
93 .into_iter()
94 .map(|embedding| {
95 let embedding_as_f32: Vec<f32> =
96 embedding.vec.into_iter().map(|x| x as f32).collect();
97 PointStruct::new(
98 Uuid::new_v4().to_string(),
99 embedding_as_f32,
100 doc_as_payload.clone(),
101 )
102 })
103 .collect::<Vec<PointStruct>>();
104
105 let request =
106 UpsertPointsBuilder::new(&collection_name, embeddings_as_point_structs).wait(true);
107 self.client.upsert_points(request).await.map_err(|err| {
108 VectorStoreError::DatastoreError(format!("Error while upserting: {err}").into())
109 })?;
110 }
111
112 Ok(())
113 }
114}
115
116fn stringify_id(id: PointId) -> Result<String, VectorStoreError> {
118 match id.point_id_options {
119 Some(PointIdOptions::Num(num)) => Ok(num.to_string()),
120 Some(PointIdOptions::Uuid(uuid)) => Ok(uuid.to_string()),
121 None => Err(VectorStoreError::DatastoreError(
122 "Invalid point ID format".into(),
123 )),
124 }
125}
126
127impl<M> VectorStoreIndex for QdrantVectorStore<M>
128where
129 M: EmbeddingModel + std::marker::Sync + Send,
130{
131 type Filter = QdrantFilter;
132
133 async fn top_n<T: for<'a> Deserialize<'a> + Send>(
136 &self,
137 req: VectorSearchRequest<Self::Filter>,
138 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
139 let query = match self.query_params.query {
140 Some(ref q) => Some(q.clone()),
141 None => Some(Query::new_nearest(
142 self.generate_query_vector(req.query()).await?,
143 )),
144 };
145
146 let filter = req
147 .filter()
148 .as_ref()
149 .cloned()
150 .map(QdrantFilter::interpret)
151 .transpose()?
152 .flatten();
153
154 let params =
155 self.prepare_query_params(query, req.samples() as usize, req.threshold(), filter);
156
157 let result = self
158 .client
159 .query(params)
160 .await
161 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
162
163 result
164 .result
165 .into_iter()
166 .map(|item| {
167 let id =
168 stringify_id(item.id.ok_or_else(|| {
169 VectorStoreError::DatastoreError("Missing point ID".into())
170 })?)?;
171 let score = item.score as f64;
172 let payload = serde_json::from_value(serde_json::to_value(item.payload)?)?;
173 Ok((score, id, payload))
174 })
175 .collect()
176 }
177
178 async fn top_n_ids(
181 &self,
182 req: VectorSearchRequest<Self::Filter>,
183 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
184 let query = match self.query_params.query {
185 Some(ref q) => Some(q.clone()),
186 None => Some(Query::new_nearest(
187 self.generate_query_vector(req.query()).await?,
188 )),
189 };
190
191 let filter = req
192 .filter()
193 .as_ref()
194 .cloned()
195 .map(QdrantFilter::interpret)
196 .transpose()?
197 .flatten();
198
199 let params =
200 self.prepare_query_params(query, req.samples() as usize, req.threshold(), filter);
201
202 let points = self
203 .client
204 .query(params)
205 .await
206 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?
207 .result;
208
209 points
210 .into_iter()
211 .map(|point| {
212 let id =
213 stringify_id(point.id.ok_or_else(|| {
214 VectorStoreError::DatastoreError("Missing point ID".into())
215 })?)?;
216 Ok((point.score as f64, id))
217 })
218 .collect()
219 }
220}