1use qdrant_client::{
2 Payload, Qdrant,
3 qdrant::{
4 PointId, PointStruct, Query, QueryPoints, UpsertPointsBuilder, point_id::PointIdOptions,
5 },
6};
7use rig::{
8 Embed, OneOrMany,
9 embeddings::{Embedding, EmbeddingModel},
10 vector_store::{
11 InsertDocuments, VectorStoreError, VectorStoreIndex, request::VectorSearchRequest,
12 },
13};
14use serde::{Deserialize, Serialize};
15use uuid::Uuid;
16
17pub struct QdrantVectorStore<M: EmbeddingModel> {
19 model: M,
21 client: Qdrant,
23 query_params: QueryPoints,
25}
26
27impl<M> QdrantVectorStore<M>
28where
29 M: EmbeddingModel,
30{
31 pub fn new(client: Qdrant, model: M, query_params: QueryPoints) -> Self {
39 Self {
40 client,
41 model,
42 query_params,
43 }
44 }
45
46 pub fn client(&self) -> &Qdrant {
47 &self.client
48 }
49
50 async fn generate_query_vector(&self, query: &str) -> Result<Vec<f32>, VectorStoreError> {
52 let embedding = self.model.embed_text(query).await?;
53 Ok(embedding.vec.iter().map(|&x| x as f32).collect())
54 }
55
56 fn prepare_query_params(
58 &self,
59 query: Option<Query>,
60 limit: usize,
61 threshold: Option<f64>,
62 ) -> QueryPoints {
63 let mut params = self.query_params.clone();
64 params.query = query;
65 params.limit = Some(limit as u64);
66 params.score_threshold = threshold.map(|x| x as f32);
67 params
68 }
69}
70
71impl<Model> InsertDocuments for QdrantVectorStore<Model>
72where
73 Model: EmbeddingModel + Send + Sync,
74{
75 async fn insert_documents<Doc: Serialize + Embed + Send>(
76 &self,
77 documents: Vec<(Doc, OneOrMany<Embedding>)>,
78 ) -> Result<(), VectorStoreError> {
79 let collection_name = self.query_params.collection_name.clone();
80
81 for (document, embeddings) in documents {
82 let json_document = serde_json::to_value(&document).unwrap();
83 let doc_as_payload = Payload::try_from(json_document).unwrap();
84
85 let embeddings_as_point_structs = embeddings
86 .into_iter()
87 .map(|embedding| {
88 let embedding_as_f32: Vec<f32> =
89 embedding.vec.into_iter().map(|x| x as f32).collect();
90 PointStruct::new(
91 Uuid::new_v4().to_string(),
92 embedding_as_f32,
93 doc_as_payload.clone(),
94 )
95 })
96 .collect::<Vec<PointStruct>>();
97
98 let request = UpsertPointsBuilder::new(&collection_name, embeddings_as_point_structs);
99 self.client.upsert_points(request).await.map_err(|err| {
100 VectorStoreError::DatastoreError(format!("Error while upserting: {err}").into())
101 })?;
102 }
103
104 Ok(())
105 }
106}
107
108fn stringify_id(id: PointId) -> Result<String, VectorStoreError> {
110 match id.point_id_options {
111 Some(PointIdOptions::Num(num)) => Ok(num.to_string()),
112 Some(PointIdOptions::Uuid(uuid)) => Ok(uuid.to_string()),
113 None => Err(VectorStoreError::DatastoreError(
114 "Invalid point ID format".into(),
115 )),
116 }
117}
118
119impl<M> VectorStoreIndex for QdrantVectorStore<M>
120where
121 M: EmbeddingModel + std::marker::Sync + Send,
122{
123 async fn top_n<T: for<'a> Deserialize<'a> + Send>(
126 &self,
127 req: VectorSearchRequest,
128 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
129 let query = match self.query_params.query {
130 Some(ref q) => Some(q.clone()),
131 None => Some(Query::new_nearest(
132 self.generate_query_vector(req.query()).await?,
133 )),
134 };
135
136 let params = self.prepare_query_params(query, req.samples() as usize, req.threshold());
137 let result = self
138 .client
139 .query(params)
140 .await
141 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
142
143 result
144 .result
145 .into_iter()
146 .map(|item| {
147 let id =
148 stringify_id(item.id.ok_or_else(|| {
149 VectorStoreError::DatastoreError("Missing point ID".into())
150 })?)?;
151 let score = item.score as f64;
152 let payload = serde_json::from_value(serde_json::to_value(item.payload)?)?;
153 Ok((score, id, payload))
154 })
155 .collect()
156 }
157
158 async fn top_n_ids(
161 &self,
162 req: VectorSearchRequest,
163 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
164 let query = match self.query_params.query {
165 Some(ref q) => Some(q.clone()),
166 None => Some(Query::new_nearest(
167 self.generate_query_vector(req.query()).await?,
168 )),
169 };
170
171 let params = self.prepare_query_params(query, req.samples() as usize, req.threshold());
172 let points = self
173 .client
174 .query(params)
175 .await
176 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?
177 .result;
178
179 points
180 .into_iter()
181 .map(|point| {
182 let id =
183 stringify_id(point.id.ok_or_else(|| {
184 VectorStoreError::DatastoreError("Missing point ID".into())
185 })?)?;
186 Ok((point.score as f64, id))
187 })
188 .collect()
189 }
190}