Skip to main content

rig_qdrant/
lib.rs

1mod filter;
2
3use filter::*;
4use qdrant_client::{
5    Payload, Qdrant,
6    qdrant::{
7        Filter, PointId, PointStruct, Query, QueryPoints, UpsertPointsBuilder,
8        point_id::PointIdOptions,
9    },
10};
11use rig::{
12    Embed, OneOrMany,
13    embeddings::{Embedding, EmbeddingModel},
14    vector_store::{
15        InsertDocuments, VectorStoreError, VectorStoreIndex, request::VectorSearchRequest,
16    },
17};
18use serde::{Deserialize, Serialize};
19use uuid::Uuid;
20
21/// Represents a vector store implementation using Qdrant - <https://qdrant.tech/> as the backend.
22pub struct QdrantVectorStore<M: EmbeddingModel> {
23    /// Model used to generate embeddings for the vector store
24    model: M,
25    /// Client instance for Qdrant server communication
26    client: Qdrant,
27    /// Default search parameters
28    query_params: QueryPoints,
29}
30
31impl<M> QdrantVectorStore<M>
32where
33    M: EmbeddingModel,
34{
35    /// Creates a new instance of `QdrantVectorStore`.
36    ///
37    /// # Arguments
38    /// * `client` - Qdrant client instance
39    /// * `model` - Embedding model instance
40    /// * `query_params` - Search parameters for vector queries
41    ///   Reference: <https://api.qdrant.tech/v-1-12-x/api-reference/search/query-points>
42    pub fn new(client: Qdrant, model: M, query_params: QueryPoints) -> Self {
43        Self {
44            client,
45            model,
46            query_params,
47        }
48    }
49
50    pub fn client(&self) -> &Qdrant {
51        &self.client
52    }
53
54    /// Embed query based on `QdrantVectorStore` model and modify the vector in the required format.
55    async fn generate_query_vector(&self, query: &str) -> Result<Vec<f32>, VectorStoreError> {
56        let embedding = self.model.embed_text(query).await?;
57        Ok(embedding.vec.iter().map(|&x| x as f32).collect())
58    }
59
60    /// Fill in query parameters with the given query and limit.
61    fn prepare_query_params(
62        &self,
63        query: Option<Query>,
64        limit: usize,
65        threshold: Option<f64>,
66        filter: Option<Filter>,
67    ) -> QueryPoints {
68        let mut params = self.query_params.clone();
69        params.query = query;
70        params.limit = Some(limit as u64);
71        params.score_threshold = threshold.map(|x| x as f32);
72        params.filter = filter;
73        params
74    }
75}
76
77impl<Model> InsertDocuments for QdrantVectorStore<Model>
78where
79    Model: EmbeddingModel + Send + Sync,
80{
81    async fn insert_documents<Doc: Serialize + Embed + Send>(
82        &self,
83        documents: Vec<(Doc, OneOrMany<Embedding>)>,
84    ) -> Result<(), VectorStoreError> {
85        let collection_name = self.query_params.collection_name.clone();
86
87        for (document, embeddings) in documents {
88            let json_document = serde_json::to_value(&document)?;
89            let doc_as_payload = Payload::try_from(json_document)
90                .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
91
92            let embeddings_as_point_structs = embeddings
93                .into_iter()
94                .map(|embedding| {
95                    let embedding_as_f32: Vec<f32> =
96                        embedding.vec.into_iter().map(|x| x as f32).collect();
97                    PointStruct::new(
98                        Uuid::new_v4().to_string(),
99                        embedding_as_f32,
100                        doc_as_payload.clone(),
101                    )
102                })
103                .collect::<Vec<PointStruct>>();
104
105            let request =
106                UpsertPointsBuilder::new(&collection_name, embeddings_as_point_structs).wait(true);
107            self.client.upsert_points(request).await.map_err(|err| {
108                VectorStoreError::DatastoreError(format!("Error while upserting: {err}").into())
109            })?;
110        }
111
112        Ok(())
113    }
114}
115
116/// Converts a `PointId` to its string representation.
117fn stringify_id(id: PointId) -> Result<String, VectorStoreError> {
118    match id.point_id_options {
119        Some(PointIdOptions::Num(num)) => Ok(num.to_string()),
120        Some(PointIdOptions::Uuid(uuid)) => Ok(uuid.to_string()),
121        None => Err(VectorStoreError::DatastoreError(
122            "Invalid point ID format".into(),
123        )),
124    }
125}
126
127impl<M> VectorStoreIndex for QdrantVectorStore<M>
128where
129    M: EmbeddingModel + std::marker::Sync + Send,
130{
131    type Filter = QdrantFilter;
132
133    /// Search for the top `n` nearest neighbors to the given query within the Qdrant vector store.
134    /// Returns a vector of tuples containing the score, ID, and payload of the nearest neighbors.
135    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
136        &self,
137        req: VectorSearchRequest<Self::Filter>,
138    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
139        let query = match self.query_params.query {
140            Some(ref q) => Some(q.clone()),
141            None => Some(Query::new_nearest(
142                self.generate_query_vector(req.query()).await?,
143            )),
144        };
145
146        let filter = req
147            .filter()
148            .as_ref()
149            .cloned()
150            .map(QdrantFilter::interpret)
151            .transpose()?
152            .flatten();
153
154        let params =
155            self.prepare_query_params(query, req.samples() as usize, req.threshold(), filter);
156
157        let result = self
158            .client
159            .query(params)
160            .await
161            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
162
163        result
164            .result
165            .into_iter()
166            .map(|item| {
167                let id =
168                    stringify_id(item.id.ok_or_else(|| {
169                        VectorStoreError::DatastoreError("Missing point ID".into())
170                    })?)?;
171                let score = item.score as f64;
172                let payload = serde_json::from_value(serde_json::to_value(item.payload)?)?;
173                Ok((score, id, payload))
174            })
175            .collect()
176    }
177
178    /// Search for the top `n` nearest neighbors to the given query within the Qdrant vector store.
179    /// Returns a vector of tuples containing the score and ID of the nearest neighbors.
180    async fn top_n_ids(
181        &self,
182        req: VectorSearchRequest<Self::Filter>,
183    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
184        let query = match self.query_params.query {
185            Some(ref q) => Some(q.clone()),
186            None => Some(Query::new_nearest(
187                self.generate_query_vector(req.query()).await?,
188            )),
189        };
190
191        let filter = req
192            .filter()
193            .as_ref()
194            .cloned()
195            .map(QdrantFilter::interpret)
196            .transpose()?
197            .flatten();
198
199        let params =
200            self.prepare_query_params(query, req.samples() as usize, req.threshold(), filter);
201
202        let points = self
203            .client
204            .query(params)
205            .await
206            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?
207            .result;
208
209        points
210            .into_iter()
211            .map(|point| {
212                let id =
213                    stringify_id(point.id.ok_or_else(|| {
214                        VectorStoreError::DatastoreError("Missing point ID".into())
215                    })?)?;
216                Ok((point.score as f64, id))
217            })
218            .collect()
219    }
220}