1mod filter;
11
12pub use filter::QdrantFilter;
13use qdrant_client::{
14 Payload, Qdrant,
15 qdrant::{
16 Filter, PointId, PointStruct, Query, QueryPoints, UpsertPointsBuilder,
17 point_id::PointIdOptions,
18 },
19};
20use rig_core::{
21 Embed, OneOrMany,
22 embeddings::{Embedding, EmbeddingModel},
23 vector_store::{
24 InsertDocuments, VectorStoreError, VectorStoreIndex, request::VectorSearchRequest,
25 },
26};
27use serde::{Deserialize, Serialize};
28use uuid::Uuid;
29
30pub struct QdrantVectorStore<M: EmbeddingModel> {
32 model: M,
34 client: Qdrant,
36 query_params: QueryPoints,
38}
39
40impl<M> QdrantVectorStore<M>
41where
42 M: EmbeddingModel,
43{
44 pub fn new(client: Qdrant, model: M, query_params: QueryPoints) -> Self {
52 Self {
53 client,
54 model,
55 query_params,
56 }
57 }
58
59 pub fn client(&self) -> &Qdrant {
60 &self.client
61 }
62
63 async fn generate_query_vector(&self, query: &str) -> Result<Vec<f32>, VectorStoreError> {
65 let embedding = self.model.embed_text(query).await?;
66 Ok(embedding.vec.iter().map(|&x| x as f32).collect())
67 }
68
69 fn prepare_query_params(
71 &self,
72 query: Option<Query>,
73 limit: usize,
74 threshold: Option<f64>,
75 filter: Option<Filter>,
76 ) -> QueryPoints {
77 let mut params = self.query_params.clone();
78 params.query = query;
79 params.limit = Some(limit as u64);
80 params.score_threshold = threshold.map(|x| x as f32);
81 params.filter = filter;
82 params
83 }
84}
85
86impl<Model> InsertDocuments for QdrantVectorStore<Model>
87where
88 Model: EmbeddingModel + Send + Sync,
89{
90 async fn insert_documents<Doc: Serialize + Embed + Send>(
91 &self,
92 documents: Vec<(Doc, OneOrMany<Embedding>)>,
93 ) -> Result<(), VectorStoreError> {
94 let collection_name = self.query_params.collection_name.clone();
95
96 for (document, embeddings) in documents {
97 let json_document = serde_json::to_value(&document)?;
98 let doc_as_payload = Payload::try_from(json_document)
99 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
100
101 let embeddings_as_point_structs = embeddings
102 .into_iter()
103 .map(|embedding| {
104 let embedding_as_f32: Vec<f32> =
105 embedding.vec.into_iter().map(|x| x as f32).collect();
106 PointStruct::new(
107 Uuid::new_v4().to_string(),
108 embedding_as_f32,
109 doc_as_payload.clone(),
110 )
111 })
112 .collect::<Vec<PointStruct>>();
113
114 let request =
115 UpsertPointsBuilder::new(&collection_name, embeddings_as_point_structs).wait(true);
116 self.client.upsert_points(request).await.map_err(|err| {
117 VectorStoreError::DatastoreError(format!("Error while upserting: {err}").into())
118 })?;
119 }
120
121 Ok(())
122 }
123}
124
125fn stringify_id(id: PointId) -> Result<String, VectorStoreError> {
127 match id.point_id_options {
128 Some(PointIdOptions::Num(num)) => Ok(num.to_string()),
129 Some(PointIdOptions::Uuid(uuid)) => Ok(uuid.to_string()),
130 None => Err(VectorStoreError::DatastoreError(
131 "Invalid point ID format".into(),
132 )),
133 }
134}
135
136impl<M> VectorStoreIndex for QdrantVectorStore<M>
137where
138 M: EmbeddingModel + std::marker::Sync + Send,
139{
140 type Filter = QdrantFilter;
141
142 async fn top_n<T: for<'a> Deserialize<'a> + Send>(
145 &self,
146 req: VectorSearchRequest<Self::Filter>,
147 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
148 let query = match self.query_params.query {
149 Some(ref q) => Some(q.clone()),
150 None => Some(Query::new_nearest(
151 self.generate_query_vector(req.query()).await?,
152 )),
153 };
154
155 let filter = req
156 .filter()
157 .as_ref()
158 .cloned()
159 .map(QdrantFilter::interpret)
160 .transpose()?
161 .flatten();
162
163 let params =
164 self.prepare_query_params(query, req.samples() as usize, req.threshold(), filter);
165
166 let result = self
167 .client
168 .query(params)
169 .await
170 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
171
172 result
173 .result
174 .into_iter()
175 .map(|item| {
176 let id =
177 stringify_id(item.id.ok_or_else(|| {
178 VectorStoreError::DatastoreError("Missing point ID".into())
179 })?)?;
180 let score = item.score as f64;
181 let payload = serde_json::from_value(serde_json::to_value(item.payload)?)?;
182 Ok((score, id, payload))
183 })
184 .collect()
185 }
186
187 async fn top_n_ids(
190 &self,
191 req: VectorSearchRequest<Self::Filter>,
192 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
193 let query = match self.query_params.query {
194 Some(ref q) => Some(q.clone()),
195 None => Some(Query::new_nearest(
196 self.generate_query_vector(req.query()).await?,
197 )),
198 };
199
200 let filter = req
201 .filter()
202 .as_ref()
203 .cloned()
204 .map(QdrantFilter::interpret)
205 .transpose()?
206 .flatten();
207
208 let params =
209 self.prepare_query_params(query, req.samples() as usize, req.threshold(), filter);
210
211 let points = self
212 .client
213 .query(params)
214 .await
215 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?
216 .result;
217
218 points
219 .into_iter()
220 .map(|point| {
221 let id =
222 stringify_id(point.id.ok_or_else(|| {
223 VectorStoreError::DatastoreError("Missing point ID".into())
224 })?)?;
225 Ok((point.score as f64, id))
226 })
227 .collect()
228 }
229}