1mod filter;
2
3use filter::*;
4use qdrant_client::{
5 Payload, Qdrant,
6 qdrant::{
7 Filter, PointId, PointStruct, Query, QueryPoints, UpsertPointsBuilder,
8 point_id::PointIdOptions,
9 },
10};
11use rig::{
12 Embed, OneOrMany,
13 embeddings::{Embedding, EmbeddingModel},
14 vector_store::{
15 InsertDocuments, VectorStoreError, VectorStoreIndex, request::VectorSearchRequest,
16 },
17};
18use serde::{Deserialize, Serialize};
19use uuid::Uuid;
20
21pub struct QdrantVectorStore<M: EmbeddingModel> {
23 model: M,
25 client: Qdrant,
27 query_params: QueryPoints,
29}
30
31impl<M> QdrantVectorStore<M>
32where
33 M: EmbeddingModel,
34{
35 pub fn new(client: Qdrant, model: M, query_params: QueryPoints) -> Self {
43 Self {
44 client,
45 model,
46 query_params,
47 }
48 }
49
50 pub fn client(&self) -> &Qdrant {
51 &self.client
52 }
53
54 async fn generate_query_vector(&self, query: &str) -> Result<Vec<f32>, VectorStoreError> {
56 let embedding = self.model.embed_text(query).await?;
57 Ok(embedding.vec.iter().map(|&x| x as f32).collect())
58 }
59
60 fn prepare_query_params(
62 &self,
63 query: Option<Query>,
64 limit: usize,
65 threshold: Option<f64>,
66 filter: Option<Filter>,
67 ) -> QueryPoints {
68 let mut params = self.query_params.clone();
69 params.query = query;
70 params.limit = Some(limit as u64);
71 params.score_threshold = threshold.map(|x| x as f32);
72 params.filter = filter;
73 params
74 }
75}
76
77impl<Model> InsertDocuments for QdrantVectorStore<Model>
78where
79 Model: EmbeddingModel + Send + Sync,
80{
81 async fn insert_documents<Doc: Serialize + Embed + Send>(
82 &self,
83 documents: Vec<(Doc, OneOrMany<Embedding>)>,
84 ) -> Result<(), VectorStoreError> {
85 let collection_name = self.query_params.collection_name.clone();
86
87 for (document, embeddings) in documents {
88 let json_document = serde_json::to_value(&document).unwrap();
89 let doc_as_payload = Payload::try_from(json_document).unwrap();
90
91 let embeddings_as_point_structs = embeddings
92 .into_iter()
93 .map(|embedding| {
94 let embedding_as_f32: Vec<f32> =
95 embedding.vec.into_iter().map(|x| x as f32).collect();
96 PointStruct::new(
97 Uuid::new_v4().to_string(),
98 embedding_as_f32,
99 doc_as_payload.clone(),
100 )
101 })
102 .collect::<Vec<PointStruct>>();
103
104 let request = UpsertPointsBuilder::new(&collection_name, embeddings_as_point_structs);
105 self.client.upsert_points(request).await.map_err(|err| {
106 VectorStoreError::DatastoreError(format!("Error while upserting: {err}").into())
107 })?;
108 }
109
110 Ok(())
111 }
112}
113
114fn stringify_id(id: PointId) -> Result<String, VectorStoreError> {
116 match id.point_id_options {
117 Some(PointIdOptions::Num(num)) => Ok(num.to_string()),
118 Some(PointIdOptions::Uuid(uuid)) => Ok(uuid.to_string()),
119 None => Err(VectorStoreError::DatastoreError(
120 "Invalid point ID format".into(),
121 )),
122 }
123}
124
125impl<M> VectorStoreIndex for QdrantVectorStore<M>
126where
127 M: EmbeddingModel + std::marker::Sync + Send,
128{
129 type Filter = QdrantFilter;
130
131 async fn top_n<T: for<'a> Deserialize<'a> + Send>(
134 &self,
135 req: VectorSearchRequest<Self::Filter>,
136 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
137 let query = match self.query_params.query {
138 Some(ref q) => Some(q.clone()),
139 None => Some(Query::new_nearest(
140 self.generate_query_vector(req.query()).await?,
141 )),
142 };
143
144 let filter = req
145 .filter()
146 .as_ref()
147 .cloned()
148 .map(QdrantFilter::interpret)
149 .transpose()?
150 .flatten();
151
152 let params =
153 self.prepare_query_params(query, req.samples() as usize, req.threshold(), filter);
154
155 let result = self
156 .client
157 .query(params)
158 .await
159 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
160
161 result
162 .result
163 .into_iter()
164 .map(|item| {
165 let id =
166 stringify_id(item.id.ok_or_else(|| {
167 VectorStoreError::DatastoreError("Missing point ID".into())
168 })?)?;
169 let score = item.score as f64;
170 let payload = serde_json::from_value(serde_json::to_value(item.payload)?)?;
171 Ok((score, id, payload))
172 })
173 .collect()
174 }
175
176 async fn top_n_ids(
179 &self,
180 req: VectorSearchRequest<Self::Filter>,
181 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
182 let query = match self.query_params.query {
183 Some(ref q) => Some(q.clone()),
184 None => Some(Query::new_nearest(
185 self.generate_query_vector(req.query()).await?,
186 )),
187 };
188
189 let filter = req
190 .filter()
191 .as_ref()
192 .cloned()
193 .map(QdrantFilter::interpret)
194 .transpose()?
195 .flatten();
196
197 let params =
198 self.prepare_query_params(query, req.samples() as usize, req.threshold(), filter);
199
200 let points = self
201 .client
202 .query(params)
203 .await
204 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?
205 .result;
206
207 points
208 .into_iter()
209 .map(|point| {
210 let id =
211 stringify_id(point.id.ok_or_else(|| {
212 VectorStoreError::DatastoreError("Missing point ID".into())
213 })?)?;
214 Ok((point.score as f64, id))
215 })
216 .collect()
217 }
218}